batch_mode_batch_scribe/
language_model_batch_api_request.rs

1// ---------------- [ File: batch-mode-batch-scribe/src/language_model_batch_api_request.rs ]
2/*!
3  | basic example:
4  | {
5  |     "custom_id": "request-1", 
6  |     "method": "POST", 
7  |     "url": "/v1/chat/completions", 
8  |     "body": {
9  |         "model": "gpt-4", 
10  |         "messages": [
11  |             {"role": "system", "content": "You are a helpful assistant."},
12  |             {"role": "user", "content": "Hello world!"}
13  |         ],
14  |         "max_tokens": 1000
15  |     }
16  | }
17  */
18crate::ix!();
19
20/// Represents the complete request structure.
21#[derive(Builder,Getters,Setters,Clone,Debug, Serialize, Deserialize)]
22#[getset(get="pub")]
23#[builder(setter(into))]
24pub struct LanguageModelBatchAPIRequest {
25
26    /// Identifier for the custom request.
27    custom_id: CustomRequestId,
28
29    /// HTTP method used for the request.
30    #[serde(with = "http_method")]
31    method: HttpMethod,
32
33    /// URL of the API endpoint.
34    #[serde(with = "api_url")]
35    url:  LanguageModelApiUrl,
36
37    /// Body of the request.
38    body: LanguageModelRequestBody,
39}
40
41impl LanguageModelBatchAPIRequest {
42    pub fn chat_completion_with_id(
43        custom_id: impl Into<String>,
44        system_msg: impl Into<String>,
45        user_msg: impl Into<String>,
46        model: LanguageModelType,
47    ) -> Self {
48        let system_msg = system_msg.into();
49        let user_msg   = user_msg.into();
50
51        LanguageModelBatchAPIRequestBuilder::default()
52            .custom_id(CustomRequestId::new(custom_id))
53            .method(HttpMethod::Post)
54            .url(LanguageModelApiUrl::ChatCompletions)
55            .body(LanguageModelRequestBody::new_basic(
56                model,
57                &system_msg,
58                &user_msg,
59            ))
60            .build()
61            .expect("LanguageModelBatchAPIRequest should build without error")
62    }
63}
64
65impl SeedManifestEntry for LanguageModelBatchAPIRequest {
66
67    fn custom_id(&self) -> String {
68        self.custom_id().as_str().to_string()
69    }
70}
71
72impl LanguageModelBatchAPIRequest {
73
74    pub fn mock(custom_id: &str) -> Self {
75        LanguageModelBatchAPIRequest {
76            custom_id: CustomRequestId::new(custom_id),
77            method:    HttpMethod::Post,
78            url:       LanguageModelApiUrl::ChatCompletions,
79            body:      LanguageModelRequestBody::mock(),
80        }
81    }
82}
83
84impl From<LanguageModelBatchAPIRequest> for BatchRequestInput {
85
86    fn from(request: LanguageModelBatchAPIRequest) -> Self {
87        BatchRequestInput {
88            custom_id: request.custom_id.to_string(),
89            method: BatchRequestInputMethod::POST,
90            url: match request.url {
91                LanguageModelApiUrl::ChatCompletions => BatchEndpoint::V1ChatCompletions,
92            },
93            body: Some(serde_json::to_value(&request.body).unwrap()),
94        }
95    }
96}
97
98pub fn create_batch_input_file(
99    requests:             &[LanguageModelBatchAPIRequest],
100    batch_input_filename: impl AsRef<Path>,
101
102) -> Result<(), BatchInputCreationError> {
103
104    use std::io::{BufWriter,Write};
105    use std::fs::File;
106
107    let file = File::create(batch_input_filename.as_ref())?;
108    let mut writer = BufWriter::new(file);
109
110    for request in requests {
111        let batch_input = BatchRequestInput {
112            custom_id: request.custom_id.to_string(),
113            method: match request.method {
114                HttpMethod::Post => BatchRequestInputMethod::POST,
115                _ => unimplemented!("Only POST method is supported"),
116            },
117            url: match request.url {
118                LanguageModelApiUrl::ChatCompletions => BatchEndpoint::V1ChatCompletions,
119                // Handle other endpoints if necessary
120            },
121            body: Some(serde_json::to_value(&request.body)?),
122        };
123        let line = serde_json::to_string(&batch_input)?;
124        writeln!(writer, "{}", line)?;
125    }
126
127    Ok(())
128}
129
130impl LanguageModelBatchAPIRequest {
131
132    pub fn requests_from_query_strings(system_message: &str, model: LanguageModelType, queries: &[String]) -> Vec<Self> {
133        queries.iter().enumerate().map(|(idx,query)| Self::new_basic(model,idx,system_message,&query)).collect()
134    }
135
136    pub fn new_basic(model: LanguageModelType, idx: usize, system_message: &str, user_message: &str) -> Self {
137        Self {
138            custom_id: Self::custom_id_for_idx(idx),
139            method:    HttpMethod::Post,
140            url:       LanguageModelApiUrl::ChatCompletions,
141            body:      LanguageModelRequestBody::new_basic(model,system_message,user_message),
142        }
143    }
144
145    pub fn new_with_image(model: LanguageModelType, idx: usize, system_message: &str, user_message: &str, image_b64: &str) -> Self {
146        Self {
147            custom_id: Self::custom_id_for_idx(idx),
148            method:    HttpMethod::Post,
149            url:       LanguageModelApiUrl::ChatCompletions,
150            body:      LanguageModelRequestBody::new_with_image(model,system_message,user_message,image_b64),
151        }
152    }
153}
154
155impl Display for LanguageModelBatchAPIRequest {
156    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
157        match serde_json::to_string(self) {
158            Ok(json) => write!(f, "{}", json),
159            Err(e) => {
160                // Handle JSON serialization errors, though they shouldn't occur with proper struct definitions
161                write!(f, "Error serializing to JSON: {}", e)
162            }
163        }
164    }
165}
166
167impl LanguageModelBatchAPIRequest {
168
169    pub(crate) fn custom_id_for_idx(idx: usize) -> CustomRequestId {
170        CustomRequestId::new(format!("request-{}",idx))
171    }
172}
173
174/// Custom serialization modules for enum string representations.
175mod http_method {
176
177    use super::*;
178
179    pub fn serialize<S>(value: &HttpMethod, serializer: S) -> Result<S::Ok, S::Error>
180    where
181        S: Serializer,
182    {
183        serializer.serialize_str(&value.to_string())
184    }
185
186    pub fn deserialize<'de, D>(deserializer: D) -> Result<HttpMethod, D::Error>
187    where
188        D: Deserializer<'de>,
189    {
190        let s: String = Deserialize::deserialize(deserializer)?;
191        match s.as_ref() {
192            "POST" => Ok(HttpMethod::Post),
193            _ => Err(serde::de::Error::custom("unknown method")),
194        }
195    }
196}
197
198mod api_url {
199
200    use super::*;
201
202    pub fn serialize<S>(value: &LanguageModelApiUrl, serializer: S) -> Result<S::Ok, S::Error>
203    where
204        S: Serializer,
205    {
206        serializer.serialize_str(&value.to_string())
207    }
208
209    pub fn deserialize<'de, D>(deserializer: D) -> Result<LanguageModelApiUrl, D::Error>
210    where
211        D: Deserializer<'de>,
212    {
213        let s: String = Deserialize::deserialize(deserializer)?;
214        match s.as_ref() {
215            "/v1/chat/completions" => Ok(LanguageModelApiUrl::ChatCompletions),
216            _ => Err(serde::de::Error::custom("unknown URL")),
217        }
218    }
219}
220
221// Updated: Provide a minimal request body that matches the struct shape.
222pub fn make_valid_lmb_api_request_json_mock(custom_id: &str) -> String {
223    let request = LanguageModelBatchAPIRequest::mock(custom_id);
224    serde_json::to_string(&request).unwrap()
225}
226
#[cfg(test)]
mod language_model_batch_api_request_exhaustive_tests {
    use super::*;

    // `mock` yields a POST chat-completions request carrying the given id and
    // the mock body.
    #[traced_test]
    fn mock_produces_expected_fields() {
        trace!("===== BEGIN TEST: mock_produces_expected_fields =====");
        let custom_id_str = "test_id";
        let request = LanguageModelBatchAPIRequest::mock(custom_id_str);
        debug!("Mock request: {:?}", request);

        pretty_assert_eq!(request.custom_id().to_string(), custom_id_str, "Custom ID mismatch");
        match request.method {
            HttpMethod::Post => trace!("Method is POST as expected"),
            _ => panic!("Expected POST method"),
        }
        // Exhaustive match on purpose: adding an endpoint variant forces this
        // test to be revisited.
        match request.url {
            LanguageModelApiUrl::ChatCompletions => trace!("URL is ChatCompletions as expected"),
        }
        let body = &request.body;
        match body.model() {
            LanguageModelType::Gpt4o => trace!("Body model is Gpt4o as expected"),
            _ => panic!("Expected LanguageModelType::Gpt4o"),
        }
        assert!(body.messages().is_empty(), "Mock body should start with no messages");
        pretty_assert_eq!(
            *body.max_completion_tokens(), 128,
            "Mock body should have max_completion_tokens=128"
        );

        trace!("===== END TEST: mock_produces_expected_fields =====");
    }

    // The index-derived custom id has the form `request-<idx>`.
    #[traced_test]
    fn custom_id_for_idx_produces_expected_format() {
        trace!("===== BEGIN TEST: custom_id_for_idx_produces_expected_format =====");
        let idx = 5;
        let custom_id = LanguageModelBatchAPIRequest::custom_id_for_idx(idx);
        debug!("Produced CustomRequestId: {:?}", custom_id);
        pretty_assert_eq!(
            custom_id.to_string(),
            "request-5",
            "Expected custom ID format 'request-<idx>'"
        );
        trace!("===== END TEST: custom_id_for_idx_produces_expected_format =====");
    }

    // `new_basic` sets id, method and url, and produces a two-message body.
    #[traced_test]
    fn new_basic_produces_correct_fields() {
        trace!("===== BEGIN TEST: new_basic_produces_correct_fields =====");
        let idx = 2;
        let model = LanguageModelType::Gpt4o;
        let system_msg = "System basic";
        let user_msg = "User basic request";
        // `model` is not used again, so it is passed by value without a clone.
        let request = LanguageModelBatchAPIRequest::new_basic(model, idx, system_msg, user_msg);
        debug!("Constructed request: {:?}", request);

        pretty_assert_eq!(request.custom_id().to_string(), "request-2");
        match request.method {
            HttpMethod::Post => trace!("Method is POST as expected"),
            _ => panic!("Expected POST method"),
        }
        match request.url {
            LanguageModelApiUrl::ChatCompletions => trace!("URL is ChatCompletions as expected"),
        }
        pretty_assert_eq!(
            request.body.messages().len(),
            2,
            "Should have system + user messages"
        );

        trace!("===== END TEST: new_basic_produces_correct_fields =====");
    }

    // `new_with_image` sets id, method and url, and produces a two-message body.
    #[traced_test]
    fn new_with_image_produces_correct_fields() {
        trace!("===== BEGIN TEST: new_with_image_produces_correct_fields =====");
        let idx = 3;
        let model = LanguageModelType::Gpt4o;
        let system_msg = "System with image";
        let user_msg = "User with image request";
        let image_b64 = "fake_image_data";
        let request = LanguageModelBatchAPIRequest::new_with_image(model, idx, system_msg, user_msg, image_b64);
        debug!("Constructed request with image: {:?}", request);

        pretty_assert_eq!(request.custom_id().to_string(), "request-3");
        match request.method {
            HttpMethod::Post => trace!("Method is POST as expected"),
            _ => panic!("Expected POST method"),
        }
        match request.url {
            LanguageModelApiUrl::ChatCompletions => trace!("URL is ChatCompletions as expected"),
        }
        pretty_assert_eq!(
            request.body.messages().len(),
            2,
            "Should have system + user-with-image messages"
        );
        trace!("===== END TEST: new_with_image_produces_correct_fields =====");
    }

    // One request is produced per query, with ids following slice order.
    #[traced_test]
    fn requests_from_query_strings_creates_requests_for_each_query() {
        trace!("===== BEGIN TEST: requests_from_query_strings_creates_requests_for_each_query =====");
        let system_message = "System greeting";
        let model = LanguageModelType::Gpt4o;
        let queries = vec!["Hello".to_string(), "World".to_string(), "Third".to_string()];
        let requests = LanguageModelBatchAPIRequest::requests_from_query_strings(system_message, model, &queries);
        debug!("Constructed requests: {:?}", requests);

        pretty_assert_eq!(
            requests.len(),
            queries.len(),
            "Number of requests should match number of queries"
        );
        for (idx, req) in requests.iter().enumerate() {
            let expected_custom_id = format!("request-{}", idx);
            pretty_assert_eq!(req.custom_id().to_string(), expected_custom_id);
            match req.url {
                LanguageModelApiUrl::ChatCompletions => (),
            }
        }
        trace!("===== END TEST: requests_from_query_strings_creates_requests_for_each_query =====");
    }

    // `Display` output parses back as a JSON object.
    #[traced_test]
    fn display_formats_as_json() {
        trace!("===== BEGIN TEST: display_formats_as_json =====");
        let request = LanguageModelBatchAPIRequest::mock("test_display");
        let displayed = format!("{}", request);
        debug!("Display output: {}", displayed);

        // Just ensure it's valid JSON
        let parsed: serde_json::Value = serde_json::from_str(&displayed)
            .expect("Display output should be valid JSON");
        debug!("Parsed JSON: {:?}", parsed);
        assert!(parsed.is_object(), "Top-level display output should be an object");
        trace!("===== END TEST: display_formats_as_json =====");
    }

    // Converting into `BatchRequestInput` keeps the id and maps method, url and body.
    #[traced_test]
    fn into_batch_request_input_sets_expected_fields() {
        trace!("===== BEGIN TEST: into_batch_request_input_sets_expected_fields =====");
        let request = LanguageModelBatchAPIRequest::mock("test_conversion");
        let converted: BatchRequestInput = request.clone().into();
        debug!("Converted BatchRequestInput: {:?}", converted);

        pretty_assert_eq!(
            converted.custom_id,
            request.custom_id().to_string(),
            "Custom ID should match"
        );
        pretty_assert_eq!(
            converted.method,
            BatchRequestInputMethod::POST,
            "HTTP method should be POST"
        );
        pretty_assert_eq!(
            converted.url,
            BatchEndpoint::V1ChatCompletions,
            "URL should be V1ChatCompletions"
        );
        assert!(
            converted.body.is_some(),
            "Body should be present in the conversion"
        );
        trace!("===== END TEST: into_batch_request_input_sets_expected_fields =====");
    }

    // The batch input file contains one valid JSON object per request, in order.
    #[traced_test]
    fn create_batch_input_file_writes_valid_json_lines() {
        trace!("===== BEGIN TEST: create_batch_input_file_writes_valid_json_lines =====");
        let requests = vec![
            LanguageModelBatchAPIRequest::mock("id-1"),
            LanguageModelBatchAPIRequest::mock("id-2"),
        ];
        // Include the process id so concurrent test processes sharing the
        // system temp directory do not overwrite each other's file.
        let output_file = std::env::temp_dir()
            .join(format!("test_batch_input_file_{}.json", std::process::id()));
        debug!("Temporary output file: {:?}", output_file);

        let result = create_batch_input_file(&requests, &output_file);
        let read_back = std::fs::read_to_string(&output_file);

        // Clean up before asserting so a failing assertion cannot leave the
        // temp file behind.
        if let Err(err) = std::fs::remove_file(&output_file) {
            warn!("Failed to remove temp file: {:?}", err);
        }

        assert!(result.is_ok(), "create_batch_input_file should succeed");
        let contents = read_back.expect("Failed to read the output file");
        debug!("File contents:\n{}", contents);
        let lines: Vec<&str> = contents.lines().collect();
        pretty_assert_eq!(lines.len(), 2, "Should have exactly 2 lines for 2 requests");

        for (i, line) in lines.iter().enumerate() {
            let parsed: serde_json::Value = serde_json::from_str(line)
                .expect("Line should be valid JSON");
            assert!(
                parsed.is_object(),
                "Each line should represent a JSON object"
            );
            let custom_id = parsed.get("custom_id")
                .and_then(|v| v.as_str())
                .unwrap_or("<missing>");
            debug!("Parsed line {} custom_id={}", i, custom_id);
            assert!(
                custom_id.contains(&format!("id-{}", i+1)),
                "Expected custom_id to match 'id-<i+1>'"
            );
        }

        trace!("===== END TEST: create_batch_input_file_writes_valid_json_lines =====");
    }
}