batch_mode_batch_scribe/
language_model_batch_api_request.rs

1// ---------------- [ File: batch-mode-batch-scribe/src/language_model_batch_api_request.rs ]
2/*!
3  | basic example:
4  | {
5  |     "custom_id": "request-1", 
6  |     "method": "POST", 
7  |     "url": "/v1/chat/completions", 
8  |     "body": {
9  |         "model": "gpt-4", 
10  |         "messages": [
11  |             {"role": "system", "content": "You are a helpful assistant."},
12  |             {"role": "user", "content": "Hello world!"}
13  |         ],
14  |         "max_tokens": 1000
15  |     }
16  | }
17  */
18crate::ix!();
19
20/// Represents the complete request structure.
21#[derive(Getters,Setters,Clone,Debug, Serialize, Deserialize)]
22#[getset(get="pub")]
23pub struct LanguageModelBatchAPIRequest {
24
25    /// Identifier for the custom request.
26    custom_id: CustomRequestId,
27
28    /// HTTP method used for the request.
29    #[serde(with = "http_method")]
30    method: HttpMethod,
31
32    /// URL of the API endpoint.
33    #[serde(with = "api_url")]
34    url:  LanguageModelApiUrl,
35
36    /// Body of the request.
37    body: LanguageModelRequestBody,
38}
39
40impl LanguageModelBatchAPIRequest {
41
42    pub fn mock(custom_id: &str) -> Self {
43        LanguageModelBatchAPIRequest {
44            custom_id: CustomRequestId::new(custom_id),
45            method:    HttpMethod::Post,
46            url:       LanguageModelApiUrl::ChatCompletions,
47            body:      LanguageModelRequestBody::mock(),
48        }
49    }
50}
51
52impl From<LanguageModelBatchAPIRequest> for BatchRequestInput {
53
54    fn from(request: LanguageModelBatchAPIRequest) -> Self {
55        BatchRequestInput {
56            custom_id: request.custom_id.to_string(),
57            method: BatchRequestInputMethod::POST,
58            url: match request.url {
59                LanguageModelApiUrl::ChatCompletions => BatchEndpoint::V1ChatCompletions,
60            },
61            body: Some(serde_json::to_value(&request.body).unwrap()),
62        }
63    }
64}
65
66pub fn create_batch_input_file(
67    requests:             &[LanguageModelBatchAPIRequest],
68    batch_input_filename: impl AsRef<Path>,
69
70) -> Result<(), BatchInputCreationError> {
71
72    use std::io::{BufWriter,Write};
73    use std::fs::File;
74
75    let file = File::create(batch_input_filename.as_ref())?;
76    let mut writer = BufWriter::new(file);
77
78    for request in requests {
79        let batch_input = BatchRequestInput {
80            custom_id: request.custom_id.to_string(),
81            method: match request.method {
82                HttpMethod::Post => BatchRequestInputMethod::POST,
83                _ => unimplemented!("Only POST method is supported"),
84            },
85            url: match request.url {
86                LanguageModelApiUrl::ChatCompletions => BatchEndpoint::V1ChatCompletions,
87                // Handle other endpoints if necessary
88            },
89            body: Some(serde_json::to_value(&request.body)?),
90        };
91        let line = serde_json::to_string(&batch_input)?;
92        writeln!(writer, "{}", line)?;
93    }
94
95    Ok(())
96}
97
98impl LanguageModelBatchAPIRequest {
99
100    pub fn requests_from_query_strings(system_message: &str, model: LanguageModelType, queries: &[String]) -> Vec<Self> {
101        queries.iter().enumerate().map(|(idx,query)| Self::new_basic(model,idx,system_message,&query)).collect()
102    }
103
104    pub fn new_basic(model: LanguageModelType, idx: usize, system_message: &str, user_message: &str) -> Self {
105        Self {
106            custom_id: Self::custom_id_for_idx(idx),
107            method:    HttpMethod::Post,
108            url:       LanguageModelApiUrl::ChatCompletions,
109            body:      LanguageModelRequestBody::new_basic(model,system_message,user_message),
110        }
111    }
112
113    pub fn new_with_image(model: LanguageModelType, idx: usize, system_message: &str, user_message: &str, image_b64: &str) -> Self {
114        Self {
115            custom_id: Self::custom_id_for_idx(idx),
116            method:    HttpMethod::Post,
117            url:       LanguageModelApiUrl::ChatCompletions,
118            body:      LanguageModelRequestBody::new_with_image(model,system_message,user_message,image_b64),
119        }
120    }
121}
122
123impl Display for LanguageModelBatchAPIRequest {
124    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
125        match serde_json::to_string(self) {
126            Ok(json) => write!(f, "{}", json),
127            Err(e) => {
128                // Handle JSON serialization errors, though they shouldn't occur with proper struct definitions
129                write!(f, "Error serializing to JSON: {}", e)
130            }
131        }
132    }
133}
134
135impl LanguageModelBatchAPIRequest {
136
137    pub(crate) fn custom_id_for_idx(idx: usize) -> CustomRequestId {
138        CustomRequestId::new(format!("request-{}",idx))
139    }
140}
141
142/// Custom serialization modules for enum string representations.
143mod http_method {
144
145    use super::*;
146
147    pub fn serialize<S>(value: &HttpMethod, serializer: S) -> Result<S::Ok, S::Error>
148    where
149        S: Serializer,
150    {
151        serializer.serialize_str(&value.to_string())
152    }
153
154    pub fn deserialize<'de, D>(deserializer: D) -> Result<HttpMethod, D::Error>
155    where
156        D: Deserializer<'de>,
157    {
158        let s: String = Deserialize::deserialize(deserializer)?;
159        match s.as_ref() {
160            "POST" => Ok(HttpMethod::Post),
161            _ => Err(serde::de::Error::custom("unknown method")),
162        }
163    }
164}
165
166mod api_url {
167
168    use super::*;
169
170    pub fn serialize<S>(value: &LanguageModelApiUrl, serializer: S) -> Result<S::Ok, S::Error>
171    where
172        S: Serializer,
173    {
174        serializer.serialize_str(&value.to_string())
175    }
176
177    pub fn deserialize<'de, D>(deserializer: D) -> Result<LanguageModelApiUrl, D::Error>
178    where
179        D: Deserializer<'de>,
180    {
181        let s: String = Deserialize::deserialize(deserializer)?;
182        match s.as_ref() {
183            "/v1/chat/completions" => Ok(LanguageModelApiUrl::ChatCompletions),
184            _ => Err(serde::de::Error::custom("unknown URL")),
185        }
186    }
187}
188
189// Updated: Provide a minimal request body that matches the struct shape.
190pub fn make_valid_lmb_api_request_json_mock(custom_id: &str) -> String {
191    let request = LanguageModelBatchAPIRequest::mock(custom_id);
192    serde_json::to_string(&request).unwrap()
193}
194
#[cfg(test)]
mod language_model_batch_api_request_exhaustive_tests {
    use super::*;

    /// `mock` should produce a POST chat-completions request carrying the
    /// given id and the mock body (Gpt4o, no messages, 128 completion tokens).
    #[traced_test]
    fn mock_produces_expected_fields() {
        trace!("===== BEGIN TEST: mock_produces_expected_fields =====");
        let custom_id_str = "test_id";
        let request = LanguageModelBatchAPIRequest::mock(custom_id_str);
        debug!("Mock request: {:?}", request);

        pretty_assert_eq!(request.custom_id().to_string(), custom_id_str, "Custom ID mismatch");
        match request.method {
            HttpMethod::Post => trace!("Method is POST as expected"),
            _ => panic!("Expected POST method"),
        }
        // Exhaustive match: fails to compile if a new endpoint is added
        // without this test being revisited.
        match request.url {
            LanguageModelApiUrl::ChatCompletions => trace!("URL is ChatCompletions as expected"),
        }
        let body = &request.body;
        match body.model() {
            LanguageModelType::Gpt4o => trace!("Body model is Gpt4o as expected"),
            _ => panic!("Expected LanguageModelType::Gpt4o"),
        }
        assert!(body.messages().is_empty(), "Mock body should start with no messages");
        pretty_assert_eq!(
            *body.max_completion_tokens(), 128,
            "Mock body should have max_completion_tokens=128"
        );

        trace!("===== END TEST: mock_produces_expected_fields =====");
    }

    /// `custom_id_for_idx` should format ids as `request-<idx>`.
    #[traced_test]
    fn custom_id_for_idx_produces_expected_format() {
        trace!("===== BEGIN TEST: custom_id_for_idx_produces_expected_format =====");
        let idx = 5;
        let custom_id = LanguageModelBatchAPIRequest::custom_id_for_idx(idx);
        debug!("Produced CustomRequestId: {:?}", custom_id);
        pretty_assert_eq!(
            custom_id.to_string(),
            "request-5",
            "Expected custom ID format 'request-<idx>'"
        );
        trace!("===== END TEST: custom_id_for_idx_produces_expected_format =====");
    }

    /// `new_basic` should yield a POST chat-completions request with two
    /// messages (system + user).
    #[traced_test]
    fn new_basic_produces_correct_fields() {
        trace!("===== BEGIN TEST: new_basic_produces_correct_fields =====");
        let idx = 2;
        let model = LanguageModelType::Gpt4o;
        let system_msg = "System basic";
        let user_msg = "User basic request";
        let request = LanguageModelBatchAPIRequest::new_basic(model, idx, system_msg, user_msg);
        debug!("Constructed request: {:?}", request);

        pretty_assert_eq!(request.custom_id().to_string(), "request-2");
        match request.method {
            HttpMethod::Post => trace!("Method is POST as expected"),
            _ => panic!("Expected POST method"),
        }
        match request.url {
            LanguageModelApiUrl::ChatCompletions => trace!("URL is ChatCompletions as expected"),
        }
        pretty_assert_eq!(
            request.body.messages().len(),
            2,
            "Should have system + user messages"
        );

        trace!("===== END TEST: new_basic_produces_correct_fields =====");
    }

    /// `new_with_image` should yield a POST chat-completions request with two
    /// messages (system + user-with-image).
    #[traced_test]
    fn new_with_image_produces_correct_fields() {
        trace!("===== BEGIN TEST: new_with_image_produces_correct_fields =====");
        let idx = 3;
        let model = LanguageModelType::Gpt4o;
        let system_msg = "System with image";
        let user_msg = "User with image request";
        let image_b64 = "fake_image_data";
        let request = LanguageModelBatchAPIRequest::new_with_image(model, idx, system_msg, user_msg, image_b64);
        debug!("Constructed request with image: {:?}", request);

        pretty_assert_eq!(request.custom_id().to_string(), "request-3");
        match request.method {
            HttpMethod::Post => trace!("Method is POST as expected"),
            _ => panic!("Expected POST method"),
        }
        match request.url {
            LanguageModelApiUrl::ChatCompletions => trace!("URL is ChatCompletions as expected"),
        }
        pretty_assert_eq!(
            request.body.messages().len(),
            2,
            "Should have system + user-with-image messages"
        );
        trace!("===== END TEST: new_with_image_produces_correct_fields =====");
    }

    /// `requests_from_query_strings` should produce one request per query,
    /// with ids numbered by position.
    #[traced_test]
    fn requests_from_query_strings_creates_requests_for_each_query() {
        trace!("===== BEGIN TEST: requests_from_query_strings_creates_requests_for_each_query =====");
        let system_message = "System greeting";
        let model = LanguageModelType::Gpt4o;
        let queries = vec!["Hello".to_string(), "World".to_string(), "Third".to_string()];
        let requests = LanguageModelBatchAPIRequest::requests_from_query_strings(system_message, model, &queries);
        debug!("Constructed requests: {:?}", requests);

        pretty_assert_eq!(
            requests.len(),
            queries.len(),
            "Number of requests should match number of queries"
        );
        for (idx, req) in requests.iter().enumerate() {
            let expected_custom_id = format!("request-{}", idx);
            pretty_assert_eq!(req.custom_id().to_string(), expected_custom_id);
            match req.url {
                LanguageModelApiUrl::ChatCompletions => (),
            }
        }
        trace!("===== END TEST: requests_from_query_strings_creates_requests_for_each_query =====");
    }

    /// The `Display` output should parse back as a JSON object.
    #[traced_test]
    fn display_formats_as_json() {
        trace!("===== BEGIN TEST: display_formats_as_json =====");
        let request = LanguageModelBatchAPIRequest::mock("test_display");
        let displayed = format!("{}", request);
        debug!("Display output: {}", displayed);

        // Just ensure it's valid JSON
        let parsed: serde_json::Value = serde_json::from_str(&displayed)
            .expect("Display output should be valid JSON");
        debug!("Parsed JSON: {:?}", parsed);
        assert!(parsed.is_object(), "Top-level display output should be an object");
        trace!("===== END TEST: display_formats_as_json =====");
    }

    /// Converting into `BatchRequestInput` should carry over the id and set
    /// POST, the chat-completions endpoint, and a body.
    #[traced_test]
    fn into_batch_request_input_sets_expected_fields() {
        trace!("===== BEGIN TEST: into_batch_request_input_sets_expected_fields =====");
        let request = LanguageModelBatchAPIRequest::mock("test_conversion");
        let converted: BatchRequestInput = request.clone().into();
        debug!("Converted BatchRequestInput: {:?}", converted);

        pretty_assert_eq!(
            converted.custom_id,
            request.custom_id().to_string(),
            "Custom ID should match"
        );
        pretty_assert_eq!(
            converted.method,
            BatchRequestInputMethod::POST,
            "HTTP method should be POST"
        );
        pretty_assert_eq!(
            converted.url,
            BatchEndpoint::V1ChatCompletions,
            "URL should be V1ChatCompletions"
        );
        assert!(
            converted.body.is_some(),
            "Body should be present in the conversion"
        );
        trace!("===== END TEST: into_batch_request_input_sets_expected_fields =====");
    }

    /// `create_batch_input_file` should write exactly one JSON object per
    /// request, in order, each carrying the request's custom id.
    #[traced_test]
    fn create_batch_input_file_writes_valid_json_lines() {
        trace!("===== BEGIN TEST: create_batch_input_file_writes_valid_json_lines =====");
        let requests = vec![
            LanguageModelBatchAPIRequest::mock("id-1"),
            LanguageModelBatchAPIRequest::mock("id-2"),
        ];
        // Include the process id so concurrent test processes sharing the
        // system temp directory do not clobber each other's file.
        let output_file = std::env::temp_dir().join(format!(
            "test_batch_input_file_{}.jsonl",
            std::process::id()
        ));
        debug!("Temporary output file: {:?}", output_file);

        let result = create_batch_input_file(&requests, &output_file);
        let read_back = std::fs::read_to_string(&output_file);

        // Clean up before asserting so a failing assertion below does not
        // leave the temp file behind.
        if let Err(err) = std::fs::remove_file(&output_file) {
            warn!("Failed to remove temp file: {:?}", err);
        }

        assert!(result.is_ok(), "create_batch_input_file should succeed");
        let contents = read_back.expect("Failed to read the output file");
        debug!("File contents:\n{}", contents);

        let lines: Vec<&str> = contents.lines().collect();
        pretty_assert_eq!(lines.len(), 2, "Should have exactly 2 lines for 2 requests");

        for (i, line) in lines.iter().enumerate() {
            let parsed: serde_json::Value = serde_json::from_str(line)
                .expect("Line should be valid JSON");
            assert!(
                parsed.is_object(),
                "Each line should represent a JSON object"
            );
            let custom_id = parsed.get("custom_id")
                .and_then(|v| v.as_str())
                .unwrap_or("<missing>");
            debug!("Parsed line {} custom_id={}", i, custom_id);

            // Exact match: a substring check would also accept ids such as
            // "id-10" or "xid-1".
            let expected_custom_id = format!("id-{}", i + 1);
            pretty_assert_eq!(
                custom_id,
                expected_custom_id.as_str(),
                "Expected custom_id to match 'id-<i+1>'"
            );
        }

        trace!("===== END TEST: create_batch_input_file_writes_valid_json_lines =====");
    }
}