1crate::ix!();
19
20#[derive(Getters,Setters,Clone,Debug, Serialize, Deserialize)]
22#[getset(get="pub")]
23pub struct LanguageModelBatchAPIRequest {
24
25 custom_id: CustomRequestId,
27
28 #[serde(with = "http_method")]
30 method: HttpMethod,
31
32 #[serde(with = "api_url")]
34 url: LanguageModelApiUrl,
35
36 body: LanguageModelRequestBody,
38}
39
40impl LanguageModelBatchAPIRequest {
41
42 pub fn mock(custom_id: &str) -> Self {
43 LanguageModelBatchAPIRequest {
44 custom_id: CustomRequestId::new(custom_id),
45 method: HttpMethod::Post,
46 url: LanguageModelApiUrl::ChatCompletions,
47 body: LanguageModelRequestBody::mock(),
48 }
49 }
50}
51
52impl From<LanguageModelBatchAPIRequest> for BatchRequestInput {
53
54 fn from(request: LanguageModelBatchAPIRequest) -> Self {
55 BatchRequestInput {
56 custom_id: request.custom_id.to_string(),
57 method: BatchRequestInputMethod::POST,
58 url: match request.url {
59 LanguageModelApiUrl::ChatCompletions => BatchEndpoint::V1ChatCompletions,
60 },
61 body: Some(serde_json::to_value(&request.body).unwrap()),
62 }
63 }
64}
65
66pub fn create_batch_input_file(
67 requests: &[LanguageModelBatchAPIRequest],
68 batch_input_filename: impl AsRef<Path>,
69
70) -> Result<(), BatchInputCreationError> {
71
72 use std::io::{BufWriter,Write};
73 use std::fs::File;
74
75 let file = File::create(batch_input_filename.as_ref())?;
76 let mut writer = BufWriter::new(file);
77
78 for request in requests {
79 let batch_input = BatchRequestInput {
80 custom_id: request.custom_id.to_string(),
81 method: match request.method {
82 HttpMethod::Post => BatchRequestInputMethod::POST,
83 _ => unimplemented!("Only POST method is supported"),
84 },
85 url: match request.url {
86 LanguageModelApiUrl::ChatCompletions => BatchEndpoint::V1ChatCompletions,
87 },
89 body: Some(serde_json::to_value(&request.body)?),
90 };
91 let line = serde_json::to_string(&batch_input)?;
92 writeln!(writer, "{}", line)?;
93 }
94
95 Ok(())
96}
97
98impl LanguageModelBatchAPIRequest {
99
100 pub fn requests_from_query_strings(system_message: &str, model: LanguageModelType, queries: &[String]) -> Vec<Self> {
101 queries.iter().enumerate().map(|(idx,query)| Self::new_basic(model,idx,system_message,&query)).collect()
102 }
103
104 pub fn new_basic(model: LanguageModelType, idx: usize, system_message: &str, user_message: &str) -> Self {
105 Self {
106 custom_id: Self::custom_id_for_idx(idx),
107 method: HttpMethod::Post,
108 url: LanguageModelApiUrl::ChatCompletions,
109 body: LanguageModelRequestBody::new_basic(model,system_message,user_message),
110 }
111 }
112
113 pub fn new_with_image(model: LanguageModelType, idx: usize, system_message: &str, user_message: &str, image_b64: &str) -> Self {
114 Self {
115 custom_id: Self::custom_id_for_idx(idx),
116 method: HttpMethod::Post,
117 url: LanguageModelApiUrl::ChatCompletions,
118 body: LanguageModelRequestBody::new_with_image(model,system_message,user_message,image_b64),
119 }
120 }
121}
122
123impl Display for LanguageModelBatchAPIRequest {
124 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
125 match serde_json::to_string(self) {
126 Ok(json) => write!(f, "{}", json),
127 Err(e) => {
128 write!(f, "Error serializing to JSON: {}", e)
130 }
131 }
132 }
133}
134
135impl LanguageModelBatchAPIRequest {
136
137 pub(crate) fn custom_id_for_idx(idx: usize) -> CustomRequestId {
138 CustomRequestId::new(format!("request-{}",idx))
139 }
140}
141
142mod http_method {
144
145 use super::*;
146
147 pub fn serialize<S>(value: &HttpMethod, serializer: S) -> Result<S::Ok, S::Error>
148 where
149 S: Serializer,
150 {
151 serializer.serialize_str(&value.to_string())
152 }
153
154 pub fn deserialize<'de, D>(deserializer: D) -> Result<HttpMethod, D::Error>
155 where
156 D: Deserializer<'de>,
157 {
158 let s: String = Deserialize::deserialize(deserializer)?;
159 match s.as_ref() {
160 "POST" => Ok(HttpMethod::Post),
161 _ => Err(serde::de::Error::custom("unknown method")),
162 }
163 }
164}
165
166mod api_url {
167
168 use super::*;
169
170 pub fn serialize<S>(value: &LanguageModelApiUrl, serializer: S) -> Result<S::Ok, S::Error>
171 where
172 S: Serializer,
173 {
174 serializer.serialize_str(&value.to_string())
175 }
176
177 pub fn deserialize<'de, D>(deserializer: D) -> Result<LanguageModelApiUrl, D::Error>
178 where
179 D: Deserializer<'de>,
180 {
181 let s: String = Deserialize::deserialize(deserializer)?;
182 match s.as_ref() {
183 "/v1/chat/completions" => Ok(LanguageModelApiUrl::ChatCompletions),
184 _ => Err(serde::de::Error::custom("unknown URL")),
185 }
186 }
187}
188
189pub fn make_valid_lmb_api_request_json_mock(custom_id: &str) -> String {
191 let request = LanguageModelBatchAPIRequest::mock(custom_id);
192 serde_json::to_string(&request).unwrap()
193}
194
#[cfg(test)]
mod language_model_batch_api_request_exhaustive_tests {
    use super::*;

    /// Deletes the wrapped path when dropped, so a temporary file is cleaned
    /// up even if an assertion panics before the end of the test.
    struct RemoveFileOnDrop(std::path::PathBuf);

    impl Drop for RemoveFileOnDrop {
        fn drop(&mut self) {
            match std::fs::remove_file(&self.0) {
                Ok(()) => {}
                // The file was never created (e.g. the write itself failed),
                // so there is nothing to clean up.
                Err(err) if err.kind() == std::io::ErrorKind::NotFound => {}
                Err(err) => {
                    warn!("Failed to remove temp file: {:?}", err);
                }
            }
        }
    }

    #[traced_test]
    fn mock_produces_expected_fields() {
        trace!("===== BEGIN TEST: mock_produces_expected_fields =====");
        let custom_id_str = "test_id";
        let request = LanguageModelBatchAPIRequest::mock(custom_id_str);
        debug!("Mock request: {:?}", request);

        pretty_assert_eq!(request.custom_id().to_string(), custom_id_str, "Custom ID mismatch");
        match request.method {
            HttpMethod::Post => trace!("Method is POST as expected"),
            _ => panic!("Expected POST method"),
        }
        match request.url {
            LanguageModelApiUrl::ChatCompletions => trace!("URL is ChatCompletions as expected"),
        }
        let body = &request.body;
        match body.model() {
            LanguageModelType::Gpt4o => trace!("Body model is Gpt4o as expected"),
            _ => panic!("Expected LanguageModelType::Gpt4o"),
        }
        assert!(body.messages().is_empty(), "Mock body should start with no messages");
        pretty_assert_eq!(
            *body.max_completion_tokens(), 128,
            "Mock body should have max_completion_tokens=128"
        );

        trace!("===== END TEST: mock_produces_expected_fields =====");
    }

    #[traced_test]
    fn custom_id_for_idx_produces_expected_format() {
        trace!("===== BEGIN TEST: custom_id_for_idx_produces_expected_format =====");
        let idx = 5;
        let custom_id = LanguageModelBatchAPIRequest::custom_id_for_idx(idx);
        debug!("Produced CustomRequestId: {:?}", custom_id);
        pretty_assert_eq!(
            custom_id.to_string(),
            "request-5",
            "Expected custom ID format 'request-<idx>'"
        );
        trace!("===== END TEST: custom_id_for_idx_produces_expected_format =====");
    }

    #[traced_test]
    fn new_basic_produces_correct_fields() {
        trace!("===== BEGIN TEST: new_basic_produces_correct_fields =====");
        let idx = 2;
        let model = LanguageModelType::Gpt4o;
        let system_msg = "System basic";
        let user_msg = "User basic request";
        let request = LanguageModelBatchAPIRequest::new_basic(model.clone(), idx, system_msg, user_msg);
        debug!("Constructed request: {:?}", request);

        pretty_assert_eq!(request.custom_id().to_string(), "request-2");
        match request.method {
            HttpMethod::Post => trace!("Method is POST as expected"),
            _ => panic!("Expected POST method"),
        }
        match request.url {
            LanguageModelApiUrl::ChatCompletions => trace!("URL is ChatCompletions as expected"),
        }
        pretty_assert_eq!(
            request.body.messages().len(),
            2,
            "Should have system + user messages"
        );

        trace!("===== END TEST: new_basic_produces_correct_fields =====");
    }

    #[traced_test]
    fn new_with_image_produces_correct_fields() {
        trace!("===== BEGIN TEST: new_with_image_produces_correct_fields =====");
        let idx = 3;
        let model = LanguageModelType::Gpt4o;
        let system_msg = "System with image";
        let user_msg = "User with image request";
        let image_b64 = "fake_image_data";
        let request = LanguageModelBatchAPIRequest::new_with_image(model.clone(), idx, system_msg, user_msg, image_b64);
        debug!("Constructed request with image: {:?}", request);

        pretty_assert_eq!(request.custom_id().to_string(), "request-3");
        match request.method {
            HttpMethod::Post => trace!("Method is POST as expected"),
            _ => panic!("Expected POST method"),
        }
        match request.url {
            LanguageModelApiUrl::ChatCompletions => trace!("URL is ChatCompletions as expected"),
        }
        pretty_assert_eq!(
            request.body.messages().len(),
            2,
            "Should have system + user-with-image messages"
        );
        trace!("===== END TEST: new_with_image_produces_correct_fields =====");
    }

    #[traced_test]
    fn requests_from_query_strings_creates_requests_for_each_query() {
        trace!("===== BEGIN TEST: requests_from_query_strings_creates_requests_for_each_query =====");
        let system_message = "System greeting";
        let model = LanguageModelType::Gpt4o;
        let queries = vec!["Hello".to_string(), "World".to_string(), "Third".to_string()];
        let requests = LanguageModelBatchAPIRequest::requests_from_query_strings(system_message, model.clone(), &queries);
        debug!("Constructed requests: {:?}", requests);

        pretty_assert_eq!(
            requests.len(),
            queries.len(),
            "Number of requests should match number of queries"
        );
        for (idx, req) in requests.iter().enumerate() {
            let expected_custom_id = format!("request-{}", idx);
            pretty_assert_eq!(req.custom_id().to_string(), expected_custom_id);
            match req.url {
                LanguageModelApiUrl::ChatCompletions => (),
            }
        }
        trace!("===== END TEST: requests_from_query_strings_creates_requests_for_each_query =====");
    }

    #[traced_test]
    fn display_formats_as_json() {
        trace!("===== BEGIN TEST: display_formats_as_json =====");
        let request = LanguageModelBatchAPIRequest::mock("test_display");
        let displayed = format!("{}", request);
        debug!("Display output: {}", displayed);

        let parsed: serde_json::Value = serde_json::from_str(&displayed)
            .expect("Display output should be valid JSON");
        debug!("Parsed JSON: {:?}", parsed);
        assert!(parsed.is_object(), "Top-level display output should be an object");
        trace!("===== END TEST: display_formats_as_json =====");
    }

    #[traced_test]
    fn into_batch_request_input_sets_expected_fields() {
        trace!("===== BEGIN TEST: into_batch_request_input_sets_expected_fields =====");
        let request = LanguageModelBatchAPIRequest::mock("test_conversion");
        let converted: BatchRequestInput = request.clone().into();
        debug!("Converted BatchRequestInput: {:?}", converted);

        pretty_assert_eq!(
            converted.custom_id,
            request.custom_id().to_string(),
            "Custom ID should match"
        );
        pretty_assert_eq!(
            converted.method,
            BatchRequestInputMethod::POST,
            "HTTP method should be POST"
        );
        pretty_assert_eq!(
            converted.url,
            BatchEndpoint::V1ChatCompletions,
            "URL should be V1ChatCompletions"
        );
        assert!(
            converted.body.is_some(),
            "Body should be present in the conversion"
        );
        trace!("===== END TEST: into_batch_request_input_sets_expected_fields =====");
    }

    #[traced_test]
    fn create_batch_input_file_writes_valid_json_lines() {
        trace!("===== BEGIN TEST: create_batch_input_file_writes_valid_json_lines =====");
        let requests = vec![
            LanguageModelBatchAPIRequest::mock("id-1"),
            LanguageModelBatchAPIRequest::mock("id-2"),
        ];
        // Include the process id so concurrent test runs sharing the system
        // temp dir (e.g. several `cargo test` invocations) do not clobber each
        // other's output file.
        let output_file = std::env::temp_dir().join(format!(
            "test_batch_input_file_{}.jsonl",
            std::process::id()
        ));
        // Removes the file when the test ends, including on assertion failure.
        let _cleanup = RemoveFileOnDrop(output_file.clone());
        debug!("Temporary output file: {:?}", output_file);

        let result = create_batch_input_file(&requests, &output_file);
        assert!(result.is_ok(), "create_batch_input_file should succeed");

        let contents = std::fs::read_to_string(&output_file)
            .expect("Failed to read the output file");
        debug!("File contents:\n{}", contents);
        let lines: Vec<&str> = contents.lines().collect();
        pretty_assert_eq!(lines.len(), 2, "Should have exactly 2 lines for 2 requests");

        for (i, line) in lines.iter().enumerate() {
            let parsed: serde_json::Value = serde_json::from_str(line)
                .expect("Line should be valid JSON");
            assert!(
                parsed.is_object(),
                "Each line should represent a JSON object"
            );
            let custom_id = parsed.get("custom_id")
                .and_then(|v| v.as_str())
                .unwrap_or("<missing>");
            debug!("Parsed line {} custom_id={}", i, custom_id);
            // Exact match: a substring check would also accept e.g. "id-10" for "id-1".
            let expected_custom_id = format!("id-{}", i + 1);
            pretty_assert_eq!(
                custom_id,
                expected_custom_id.as_str(),
                "custom_id on line {} should be exactly '{}'",
                i,
                expected_custom_id
            );
        }

        trace!("===== END TEST: create_batch_input_file_writes_valid_json_lines =====");
    }
}