1crate::ix!();
19
/// A single entry of a language-model batch job: one chat-completions call tagged with a
/// caller-visible custom id.
///
/// Fields serialize in declaration order (`custom_id`, `method`, `url`, `body`). Public getters
/// are generated through `#[getset(get = "pub")]`, and `derive_builder` generates
/// `LanguageModelBatchAPIRequestBuilder` whose setters accept `Into<FieldType>`.
#[derive(Builder,Getters,Setters,Clone,Debug, Serialize, Deserialize)]
#[getset(get="pub")]
#[builder(setter(into))]
pub struct LanguageModelBatchAPIRequest {

    /// Identifier for this request; copied (via `to_string`) into `BatchRequestInput::custom_id`
    /// and returned by `SeedManifestEntry::custom_id`.
    custom_id: CustomRequestId,

    /// HTTP method, (de)serialized as a string through the `http_method` adapter module.
    #[serde(with = "http_method")]
    method: HttpMethod,

    /// Target endpoint, (de)serialized as a string through the `api_url` adapter module.
    #[serde(with = "api_url")]
    url: LanguageModelApiUrl,

    /// Request payload; converted to a `serde_json::Value` when a `BatchRequestInput` is built.
    body: LanguageModelRequestBody,
}
40
41impl LanguageModelBatchAPIRequest {
42 pub fn chat_completion_with_id(
43 custom_id: impl Into<String>,
44 system_msg: impl Into<String>,
45 user_msg: impl Into<String>,
46 model: LanguageModelType,
47 ) -> Self {
48 let system_msg = system_msg.into();
49 let user_msg = user_msg.into();
50
51 LanguageModelBatchAPIRequestBuilder::default()
52 .custom_id(CustomRequestId::new(custom_id))
53 .method(HttpMethod::Post)
54 .url(LanguageModelApiUrl::ChatCompletions)
55 .body(LanguageModelRequestBody::new_basic(
56 model,
57 &system_msg,
58 &user_msg,
59 ))
60 .build()
61 .expect("LanguageModelBatchAPIRequest should build without error")
62 }
63}
64
65impl SeedManifestEntry for LanguageModelBatchAPIRequest {
66
67 fn custom_id(&self) -> String {
68 self.custom_id().as_str().to_string()
69 }
70}
71
72impl LanguageModelBatchAPIRequest {
73
74 pub fn mock(custom_id: &str) -> Self {
75 LanguageModelBatchAPIRequest {
76 custom_id: CustomRequestId::new(custom_id),
77 method: HttpMethod::Post,
78 url: LanguageModelApiUrl::ChatCompletions,
79 body: LanguageModelRequestBody::mock(),
80 }
81 }
82}
83
84impl From<LanguageModelBatchAPIRequest> for BatchRequestInput {
85
86 fn from(request: LanguageModelBatchAPIRequest) -> Self {
87 BatchRequestInput {
88 custom_id: request.custom_id.to_string(),
89 method: BatchRequestInputMethod::POST,
90 url: match request.url {
91 LanguageModelApiUrl::ChatCompletions => BatchEndpoint::V1ChatCompletions,
92 },
93 body: Some(serde_json::to_value(&request.body).unwrap()),
94 }
95 }
96}
97
98pub fn create_batch_input_file(
99 requests: &[LanguageModelBatchAPIRequest],
100 batch_input_filename: impl AsRef<Path>,
101
102) -> Result<(), BatchInputCreationError> {
103
104 use std::io::{BufWriter,Write};
105 use std::fs::File;
106
107 let file = File::create(batch_input_filename.as_ref())?;
108 let mut writer = BufWriter::new(file);
109
110 for request in requests {
111 let batch_input = BatchRequestInput {
112 custom_id: request.custom_id.to_string(),
113 method: match request.method {
114 HttpMethod::Post => BatchRequestInputMethod::POST,
115 _ => unimplemented!("Only POST method is supported"),
116 },
117 url: match request.url {
118 LanguageModelApiUrl::ChatCompletions => BatchEndpoint::V1ChatCompletions,
119 },
121 body: Some(serde_json::to_value(&request.body)?),
122 };
123 let line = serde_json::to_string(&batch_input)?;
124 writeln!(writer, "{}", line)?;
125 }
126
127 Ok(())
128}
129
130impl LanguageModelBatchAPIRequest {
131
132 pub fn requests_from_query_strings(system_message: &str, model: LanguageModelType, queries: &[String]) -> Vec<Self> {
133 queries.iter().enumerate().map(|(idx,query)| Self::new_basic(model,idx,system_message,&query)).collect()
134 }
135
136 pub fn new_basic(model: LanguageModelType, idx: usize, system_message: &str, user_message: &str) -> Self {
137 Self {
138 custom_id: Self::custom_id_for_idx(idx),
139 method: HttpMethod::Post,
140 url: LanguageModelApiUrl::ChatCompletions,
141 body: LanguageModelRequestBody::new_basic(model,system_message,user_message),
142 }
143 }
144
145 pub fn new_with_image(model: LanguageModelType, idx: usize, system_message: &str, user_message: &str, image_b64: &str) -> Self {
146 Self {
147 custom_id: Self::custom_id_for_idx(idx),
148 method: HttpMethod::Post,
149 url: LanguageModelApiUrl::ChatCompletions,
150 body: LanguageModelRequestBody::new_with_image(model,system_message,user_message,image_b64),
151 }
152 }
153}
154
155impl Display for LanguageModelBatchAPIRequest {
156 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
157 match serde_json::to_string(self) {
158 Ok(json) => write!(f, "{}", json),
159 Err(e) => {
160 write!(f, "Error serializing to JSON: {}", e)
162 }
163 }
164 }
165}
166
167impl LanguageModelBatchAPIRequest {
168
169 pub(crate) fn custom_id_for_idx(idx: usize) -> CustomRequestId {
170 CustomRequestId::new(format!("request-{}",idx))
171 }
172}
173
174mod http_method {
176
177 use super::*;
178
179 pub fn serialize<S>(value: &HttpMethod, serializer: S) -> Result<S::Ok, S::Error>
180 where
181 S: Serializer,
182 {
183 serializer.serialize_str(&value.to_string())
184 }
185
186 pub fn deserialize<'de, D>(deserializer: D) -> Result<HttpMethod, D::Error>
187 where
188 D: Deserializer<'de>,
189 {
190 let s: String = Deserialize::deserialize(deserializer)?;
191 match s.as_ref() {
192 "POST" => Ok(HttpMethod::Post),
193 _ => Err(serde::de::Error::custom("unknown method")),
194 }
195 }
196}
197
198mod api_url {
199
200 use super::*;
201
202 pub fn serialize<S>(value: &LanguageModelApiUrl, serializer: S) -> Result<S::Ok, S::Error>
203 where
204 S: Serializer,
205 {
206 serializer.serialize_str(&value.to_string())
207 }
208
209 pub fn deserialize<'de, D>(deserializer: D) -> Result<LanguageModelApiUrl, D::Error>
210 where
211 D: Deserializer<'de>,
212 {
213 let s: String = Deserialize::deserialize(deserializer)?;
214 match s.as_ref() {
215 "/v1/chat/completions" => Ok(LanguageModelApiUrl::ChatCompletions),
216 _ => Err(serde::de::Error::custom("unknown URL")),
217 }
218 }
219}
220
221pub fn make_valid_lmb_api_request_json_mock(custom_id: &str) -> String {
223 let request = LanguageModelBatchAPIRequest::mock(custom_id);
224 serde_json::to_string(&request).unwrap()
225}
226
#[cfg(test)]
mod language_model_batch_api_request_exhaustive_tests {
    //! Tests for the constructors, conversions, formatting and file output of
    //! `LanguageModelBatchAPIRequest`.
    use super::*;

    #[traced_test]
    fn mock_produces_expected_fields() {
        trace!("===== BEGIN TEST: mock_produces_expected_fields =====");
        let custom_id_str = "test_id";
        let request = LanguageModelBatchAPIRequest::mock(custom_id_str);
        debug!("Mock request: {:?}", request);

        pretty_assert_eq!(request.custom_id().to_string(), custom_id_str, "Custom ID mismatch");
        match request.method {
            HttpMethod::Post => trace!("Method is POST as expected"),
            _ => panic!("Expected POST method"),
        }
        match request.url {
            LanguageModelApiUrl::ChatCompletions => trace!("URL is ChatCompletions as expected"),
        }
        let body = &request.body;
        match body.model() {
            LanguageModelType::Gpt4o => trace!("Body model is Gpt4o as expected"),
            _ => panic!("Expected LanguageModelType::Gpt4o"),
        }
        assert!(body.messages().is_empty(), "Mock body should start with no messages");
        pretty_assert_eq!(
            *body.max_completion_tokens(), 128,
            "Mock body should have max_completion_tokens=128"
        );

        trace!("===== END TEST: mock_produces_expected_fields =====");
    }

    #[traced_test]
    fn custom_id_for_idx_produces_expected_format() {
        trace!("===== BEGIN TEST: custom_id_for_idx_produces_expected_format =====");
        let idx = 5;
        let custom_id = LanguageModelBatchAPIRequest::custom_id_for_idx(idx);
        debug!("Produced CustomRequestId: {:?}", custom_id);
        pretty_assert_eq!(
            custom_id.to_string(),
            "request-5",
            "Expected custom ID format 'request-<idx>'"
        );
        trace!("===== END TEST: custom_id_for_idx_produces_expected_format =====");
    }

    #[traced_test]
    fn new_basic_produces_correct_fields() {
        trace!("===== BEGIN TEST: new_basic_produces_correct_fields =====");
        let idx = 2;
        let model = LanguageModelType::Gpt4o;
        let system_msg = "System basic";
        let user_msg = "User basic request";
        let request = LanguageModelBatchAPIRequest::new_basic(model.clone(), idx, system_msg, user_msg);
        debug!("Constructed request: {:?}", request);

        pretty_assert_eq!(request.custom_id().to_string(), "request-2");
        match request.method {
            HttpMethod::Post => trace!("Method is POST as expected"),
            _ => panic!("Expected POST method"),
        }
        match request.url {
            LanguageModelApiUrl::ChatCompletions => trace!("URL is ChatCompletions as expected"),
        }
        pretty_assert_eq!(
            request.body.messages().len(),
            2,
            "Should have system + user messages"
        );

        trace!("===== END TEST: new_basic_produces_correct_fields =====");
    }

    #[traced_test]
    fn new_with_image_produces_correct_fields() {
        trace!("===== BEGIN TEST: new_with_image_produces_correct_fields =====");
        let idx = 3;
        let model = LanguageModelType::Gpt4o;
        let system_msg = "System with image";
        let user_msg = "User with image request";
        let image_b64 = "fake_image_data";
        let request = LanguageModelBatchAPIRequest::new_with_image(model.clone(), idx, system_msg, user_msg, image_b64);
        debug!("Constructed request with image: {:?}", request);

        pretty_assert_eq!(request.custom_id().to_string(), "request-3");
        match request.method {
            HttpMethod::Post => trace!("Method is POST as expected"),
            _ => panic!("Expected POST method"),
        }
        match request.url {
            LanguageModelApiUrl::ChatCompletions => trace!("URL is ChatCompletions as expected"),
        }
        pretty_assert_eq!(
            request.body.messages().len(),
            2,
            "Should have system + user-with-image messages"
        );
        trace!("===== END TEST: new_with_image_produces_correct_fields =====");
    }

    #[traced_test]
    fn requests_from_query_strings_creates_requests_for_each_query() {
        trace!("===== BEGIN TEST: requests_from_query_strings_creates_requests_for_each_query =====");
        let system_message = "System greeting";
        let model = LanguageModelType::Gpt4o;
        let queries = vec!["Hello".to_string(), "World".to_string(), "Third".to_string()];
        let requests = LanguageModelBatchAPIRequest::requests_from_query_strings(system_message, model.clone(), &queries);
        debug!("Constructed requests: {:?}", requests);

        pretty_assert_eq!(
            requests.len(),
            queries.len(),
            "Number of requests should match number of queries"
        );
        for (idx, req) in requests.iter().enumerate() {
            let expected_custom_id = format!("request-{}", idx);
            pretty_assert_eq!(req.custom_id().to_string(), expected_custom_id);
            match req.url {
                LanguageModelApiUrl::ChatCompletions => (),
            }
        }
        trace!("===== END TEST: requests_from_query_strings_creates_requests_for_each_query =====");
    }

    #[traced_test]
    fn display_formats_as_json() {
        trace!("===== BEGIN TEST: display_formats_as_json =====");
        let request = LanguageModelBatchAPIRequest::mock("test_display");
        let displayed = format!("{}", request);
        debug!("Display output: {}", displayed);

        let parsed: serde_json::Value = serde_json::from_str(&displayed)
            .expect("Display output should be valid JSON");
        debug!("Parsed JSON: {:?}", parsed);
        assert!(parsed.is_object(), "Top-level display output should be an object");
        trace!("===== END TEST: display_formats_as_json =====");
    }

    #[traced_test]
    fn into_batch_request_input_sets_expected_fields() {
        trace!("===== BEGIN TEST: into_batch_request_input_sets_expected_fields =====");
        let request = LanguageModelBatchAPIRequest::mock("test_conversion");
        let converted: BatchRequestInput = request.clone().into();
        debug!("Converted BatchRequestInput: {:?}", converted);

        pretty_assert_eq!(
            converted.custom_id,
            request.custom_id().to_string(),
            "Custom ID should match"
        );
        pretty_assert_eq!(
            converted.method,
            BatchRequestInputMethod::POST,
            "HTTP method should be POST"
        );
        pretty_assert_eq!(
            converted.url,
            BatchEndpoint::V1ChatCompletions,
            "URL should be V1ChatCompletions"
        );
        assert!(
            converted.body.is_some(),
            "Body should be present in the conversion"
        );
        trace!("===== END TEST: into_batch_request_input_sets_expected_fields =====");
    }

    #[traced_test]
    fn create_batch_input_file_writes_valid_json_lines() {
        trace!("===== BEGIN TEST: create_batch_input_file_writes_valid_json_lines =====");
        let requests = vec![
            LanguageModelBatchAPIRequest::mock("id-1"),
            LanguageModelBatchAPIRequest::mock("id-2"),
        ];
        // Include the process id so concurrent runs on the same machine do not overwrite each
        // other's file in the shared temp directory.
        let output_file = std::env::temp_dir().join(format!(
            "test_batch_input_file_{}.jsonl",
            std::process::id()
        ));
        debug!("Temporary output file: {:?}", output_file);

        let result = create_batch_input_file(&requests, &output_file);
        let read_result = std::fs::read_to_string(&output_file);

        // Remove the file before asserting so a failing assertion does not leave it behind.
        if let Err(err) = std::fs::remove_file(&output_file) {
            warn!("Failed to remove temp file: {:?}", err);
        }

        assert!(result.is_ok(), "create_batch_input_file should succeed");

        let contents = read_result.expect("Failed to read the output file");
        debug!("File contents:\n{}", contents);
        let lines: Vec<&str> = contents.lines().collect();
        pretty_assert_eq!(lines.len(), 2, "Should have exactly 2 lines for 2 requests");

        for (i, line) in lines.iter().enumerate() {
            let parsed: serde_json::Value = serde_json::from_str(line)
                .expect("Line should be valid JSON");
            assert!(
                parsed.is_object(),
                "Each line should represent a JSON object"
            );
            let custom_id = parsed.get("custom_id")
                .and_then(|v| v.as_str())
                .unwrap_or("<missing>");
            debug!("Parsed line {} custom_id={}", i, custom_id);
            let expected = format!("id-{}", i + 1);
            pretty_assert_eq!(
                custom_id,
                expected.as_str(),
                "Expected custom_id to match 'id-<i+1>'"
            );
        }

        trace!("===== END TEST: create_batch_input_file_writes_valid_json_lines =====");
    }
}