1#![allow(dead_code)]
3#![allow(clippy::too_many_lines)]
4#![allow(clippy::missing_const_for_fn)]
5#![allow(clippy::unused_async)]
6#![allow(clippy::unreadable_literal)]
7#![allow(clippy::cast_possible_truncation)]
8#![allow(clippy::cast_precision_loss)]
9#![allow(clippy::uninlined_format_args)]
10#![allow(clippy::doc_markdown)]
11#![allow(clippy::needless_pass_by_value)]
12#![allow(clippy::redundant_closure_for_method_calls)]
13#![allow(clippy::unused_self)]
14#![allow(clippy::cast_lossless)]
15#![allow(clippy::struct_field_names)]
16#![allow(clippy::cast_possible_wrap)]
17use openai_ergonomic::{Client, Error, Result};
34use serde::{Deserialize, Serialize};
35use std::collections::HashMap;
36use std::fs;
37use std::path::Path;
38use std::time::{Duration, Instant};
39use tokio::time::sleep;
40use tracing::{debug, error, info, warn};
41
/// One line of a JSONL batch input file, matching the OpenAI Batch API
/// request envelope: a caller-supplied ID plus the HTTP call to replay.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct BatchRequest {
    /// Caller-chosen ID used to correlate results back to requests.
    custom_id: String,
    /// HTTP method of the wrapped call (e.g. "POST").
    method: String,
    /// Endpoint path (e.g. "/v1/chat/completions").
    url: String,
    /// Chat-completion payload submitted for this request.
    body: BatchRequestBody,
}
54
/// Chat-completion payload embedded in a [`BatchRequest`].
#[derive(Debug, Clone, Serialize, Deserialize)]
struct BatchRequestBody {
    /// Model name (e.g. "gpt-3.5-turbo").
    model: String,
    /// Conversation messages, in order.
    messages: Vec<ChatMessage>,
    /// Optional cap on completion tokens.
    max_tokens: Option<i32>,
    /// Optional sampling temperature.
    temperature: Option<f64>,
}
67
/// A single chat message: a role ("system" / "user" / "assistant") plus text.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ChatMessage {
    /// Message role.
    role: String,
    /// Message text.
    content: String,
}
76
/// Snapshot of a batch job's server-side state as seen by this example.
#[derive(Debug, Clone)]
struct BatchJob {
    /// Server-assigned batch ID.
    id: String,
    /// Lifecycle status, e.g. "validating", "in_progress", "completed".
    status: String,
    /// ID of the uploaded JSONL input file.
    input_file_id: String,
    /// Output file ID, present once the batch completes.
    output_file_id: Option<String>,
    /// Error file ID, present when some requests failed.
    error_file_id: Option<String>,
    /// Creation time (UNIX seconds).
    created_at: i64,
    /// Completion time (UNIX seconds), once finished.
    completed_at: Option<i64>,
    /// Per-batch request tallies.
    request_counts: BatchRequestCounts,
}
97
/// Request tallies reported for a batch job.
#[derive(Debug, Clone)]
struct BatchRequestCounts {
    /// Total requests submitted in the batch.
    total: i32,
    /// Requests that finished successfully.
    completed: i32,
    /// Requests that failed.
    failed: i32,
}
108
/// Orchestrates batch jobs end-to-end: JSONL file creation, upload, job
/// creation, status polling, and result download.
#[derive(Debug)]
struct BatchProcessor {
    /// API client (credentials and transport).
    client: Client,
    /// Directory where JSONL input files are written.
    batch_dir: String,
    /// Maximum number of requests per batch file.
    max_batch_size: usize,
    /// Delay between status polls.
    poll_interval: Duration,
    /// Hard ceiling on how long to wait for one batch to finish.
    max_wait_time: Duration,
}
122
123impl BatchProcessor {
124 fn new(client: Client) -> Self {
126 Self {
127 client,
128 batch_dir: "./batch_files".to_string(),
129 max_batch_size: 50_000,
130 poll_interval: Duration::from_secs(30),
131 max_wait_time: Duration::from_secs(25 * 60 * 60), }
133 }
134
135 fn with_config(
137 mut self,
138 batch_dir: &str,
139 max_batch_size: usize,
140 poll_interval: Duration,
141 max_wait_time: Duration,
142 ) -> Self {
143 self.batch_dir = batch_dir.to_string();
144 self.max_batch_size = max_batch_size;
145 self.poll_interval = poll_interval;
146 self.max_wait_time = max_wait_time;
147 self
148 }
149
150 async fn process_batch_requests(
152 &self,
153 requests: Vec<BatchRequest>,
154 batch_name: &str,
155 ) -> Result<Vec<BatchProcessingResult>> {
156 info!("Starting batch processing for {} requests", requests.len());
157
158 fs::create_dir_all(&self.batch_dir).map_err(|e| {
160 Error::InvalidRequest(format!("Failed to create batch directory: {}", e))
161 })?;
162
163 let batches = self.split_into_batches(requests);
165 let mut all_results = Vec::new();
166
167 for (batch_idx, batch_requests) in batches.into_iter().enumerate() {
168 let batch_id = format!("{}_batch_{}", batch_name, batch_idx);
169 info!(
170 "Processing batch {}/{}: {} requests",
171 batch_idx + 1,
172 batch_requests.len(),
173 batch_requests.len()
174 );
175
176 let results = self.process_single_batch(batch_requests, &batch_id).await?;
177 all_results.extend(results);
178 }
179
180 info!(
181 "Completed batch processing with {} total results",
182 all_results.len()
183 );
184 Ok(all_results)
185 }
186
187 async fn process_single_batch(
189 &self,
190 requests: Vec<BatchRequest>,
191 batch_id: &str,
192 ) -> Result<Vec<BatchProcessingResult>> {
193 let input_file_path = format!("{}/{}_input.jsonl", self.batch_dir, batch_id);
195 self.create_batch_file(&requests, &input_file_path)?;
196
197 let file_upload_result = self.upload_batch_file(&input_file_path).await?;
199 info!("Uploaded batch file with ID: {}", file_upload_result.id);
200
201 let batch_job = self
203 .create_batch_job(&file_upload_result.id, batch_id)
204 .await?;
205 info!("Created batch job with ID: {}", batch_job.id);
206
207 let completed_batch = self.monitor_batch_progress(batch_job).await?;
209
210 let results = self.download_batch_results(&completed_batch).await?;
212
213 info!(
214 "Successfully processed batch with {} results",
215 results.len()
216 );
217 Ok(results)
218 }
219
220 fn split_into_batches(&self, requests: Vec<BatchRequest>) -> Vec<Vec<BatchRequest>> {
222 requests
223 .chunks(self.max_batch_size)
224 .map(|chunk| chunk.to_vec())
225 .collect()
226 }
227
228 fn create_batch_file(&self, requests: &[BatchRequest], file_path: &str) -> Result<()> {
230 let mut content = String::new();
231 for request in requests {
232 let json_line = serde_json::to_string(request).map_err(|e| {
233 Error::InvalidRequest(format!("Failed to serialize request: {}", e))
234 })?;
235 content.push_str(&json_line);
236 content.push('\n');
237 }
238
239 fs::write(file_path, content)
240 .map_err(|e| Error::InvalidRequest(format!("Failed to write batch file: {}", e)))?;
241
242 debug!(
243 "Created batch file: {} ({} requests)",
244 file_path,
245 requests.len()
246 );
247 Ok(())
248 }
249
250 async fn upload_batch_file(&self, file_path: &str) -> Result<FileUploadResult> {
252 info!("Uploading batch file: {}", file_path);
253
254 let file_id = format!("file-{}", uuid::Uuid::new_v4());
260
261 Ok(FileUploadResult {
262 id: file_id,
263 bytes: fs::metadata(file_path)
264 .map_err(|e| Error::InvalidRequest(format!("Failed to get file size: {}", e)))?
265 .len(),
266 filename: Path::new(file_path)
267 .file_name()
268 .unwrap()
269 .to_string_lossy()
270 .to_string(),
271 })
272 }
273
274 async fn create_batch_job(&self, input_file_id: &str, _batch_name: &str) -> Result<BatchJob> {
276 info!("Creating batch job for file: {}", input_file_id);
277
278 let batch_id = format!("batch_{}", uuid::Uuid::new_v4());
284
285 Ok(BatchJob {
286 id: batch_id,
287 status: "validating".to_string(),
288 input_file_id: input_file_id.to_string(),
289 output_file_id: None,
290 error_file_id: None,
291 created_at: chrono::Utc::now().timestamp(),
292 completed_at: None,
293 request_counts: BatchRequestCounts {
294 total: 0,
295 completed: 0,
296 failed: 0,
297 },
298 })
299 }
300
301 async fn monitor_batch_progress(&self, mut batch_job: BatchJob) -> Result<BatchJob> {
303 let start_time = Instant::now();
304
305 loop {
306 if start_time.elapsed() > self.max_wait_time {
308 return Err(Error::InvalidRequest(format!(
309 "Batch processing timed out after {:?}",
310 self.max_wait_time
311 )));
312 }
313
314 batch_job = self.get_batch_status(&batch_job.id).await?;
316
317 match batch_job.status.as_str() {
318 "completed" => {
319 info!("Batch {} completed successfully", batch_job.id);
320 return Ok(batch_job);
321 }
322 "failed" | "expired" | "cancelled" => {
323 return Err(Error::InvalidRequest(format!(
324 "Batch {} failed with status: {}",
325 batch_job.id, batch_job.status
326 )));
327 }
328 "validating" | "in_progress" | "finalizing" => {
329 info!(
330 "Batch {} status: {} ({}s elapsed)",
331 batch_job.id,
332 batch_job.status,
333 start_time.elapsed().as_secs()
334 );
335 }
336 _ => {
337 warn!("Unknown batch status: {}", batch_job.status);
338 }
339 }
340
341 sleep(self.poll_interval).await;
343 }
344 }
345
346 async fn get_batch_status(&self, batch_id: &str) -> Result<BatchJob> {
348 debug!("Checking status for batch: {}", batch_id);
349
350 let current_time = chrono::Utc::now().timestamp();
357
358 Ok(BatchJob {
359 id: batch_id.to_string(),
360 status: "completed".to_string(), input_file_id: format!("file-input-{}", batch_id),
362 output_file_id: Some(format!("file-output-{}", batch_id)),
363 error_file_id: None,
364 created_at: current_time - 3600, completed_at: Some(current_time),
366 request_counts: BatchRequestCounts {
367 total: 100,
368 completed: 98,
369 failed: 2,
370 },
371 })
372 }
373
374 async fn download_batch_results(
376 &self,
377 batch_job: &BatchJob,
378 ) -> Result<Vec<BatchProcessingResult>> {
379 let output_file_id = batch_job
380 .output_file_id
381 .as_ref()
382 .ok_or_else(|| Error::InvalidRequest("No output file available".to_string()))?;
383
384 info!("Downloading results from file: {}", output_file_id);
385
386 let mut results = Vec::new();
392 for i in 0..batch_job.request_counts.completed {
393 results.push(BatchProcessingResult {
394 custom_id: format!("request_{}", i),
395 status: "completed".to_string(),
396 response: Some(BatchResponseData {
397 id: format!("chatcmpl_{}", uuid::Uuid::new_v4()),
398 object: "chat.completion".to_string(),
399 model: "gpt-3.5-turbo".to_string(),
400 choices: vec![BatchChoiceData {
401 index: 0,
402 message: BatchMessageData {
403 role: "assistant".to_string(),
404 content: format!("This is a sample response for request {}", i),
405 },
406 finish_reason: "stop".to_string(),
407 }],
408 usage: BatchUsageData {
409 prompt_tokens: 20,
410 completion_tokens: 15,
411 total_tokens: 35,
412 },
413 }),
414 error: None,
415 });
416 }
417
418 for i in 0..batch_job.request_counts.failed {
420 results.push(BatchProcessingResult {
421 custom_id: format!("failed_request_{}", i),
422 status: "failed".to_string(),
423 response: None,
424 error: Some(BatchErrorData {
425 code: "rate_limit_exceeded".to_string(),
426 message: "Rate limit exceeded, please try again later".to_string(),
427 }),
428 });
429 }
430
431 info!("Downloaded {} batch results", results.len());
432 Ok(results)
433 }
434
435 fn calculate_cost_savings(&self, results: &[BatchProcessingResult]) -> CostAnalysis {
437 let successful_requests = results.iter().filter(|r| r.response.is_some()).count();
438 let total_tokens: i32 = results
439 .iter()
440 .filter_map(|r| r.response.as_ref())
441 .map(|resp| resp.usage.total_tokens)
442 .sum();
443
444 let synchronous_cost = total_tokens as f64 * 0.002; let batch_cost = synchronous_cost * 0.5; let savings = synchronous_cost - batch_cost;
448
449 CostAnalysis {
450 successful_requests,
451 total_tokens,
452 synchronous_cost,
453 batch_cost,
454 savings,
455 savings_percentage: (savings / synchronous_cost) * 100.0,
456 }
457 }
458}
459
/// Metadata returned after uploading a batch input file.
#[derive(Debug)]
struct FileUploadResult {
    /// Server-assigned (here: simulated) file ID.
    id: String,
    /// File size in bytes.
    bytes: u64,
    /// Final path component of the uploaded file.
    filename: String,
}
467
/// Outcome of one batched request: exactly one of `response`/`error` is set.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct BatchProcessingResult {
    /// The `custom_id` of the originating request.
    custom_id: String,
    /// "completed" or "failed".
    status: String,
    /// Successful completion payload, if any.
    response: Option<BatchResponseData>,
    /// Failure details, if any.
    error: Option<BatchErrorData>,
}
476
/// Chat-completion response body for one successful batched request.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct BatchResponseData {
    /// Completion ID.
    id: String,
    /// Object type, e.g. "chat.completion".
    object: String,
    /// Model that produced the completion.
    model: String,
    /// Generated choices (usually one).
    choices: Vec<BatchChoiceData>,
    /// Token accounting for this completion.
    usage: BatchUsageData,
}
486
/// A single generated choice within a completion response.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct BatchChoiceData {
    /// Position of this choice in the response.
    index: i32,
    /// The generated message.
    message: BatchMessageData,
    /// Why generation stopped, e.g. "stop".
    finish_reason: String,
}
494
/// Role + text of a generated message.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct BatchMessageData {
    /// Message role, e.g. "assistant".
    role: String,
    /// Generated text.
    content: String,
}
501
/// Token counts for one completion.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct BatchUsageData {
    /// Tokens consumed by the prompt.
    prompt_tokens: i32,
    /// Tokens generated in the completion.
    completion_tokens: i32,
    /// Sum of prompt and completion tokens.
    total_tokens: i32,
}
509
/// Error details for one failed batched request.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct BatchErrorData {
    /// Machine-readable error code, e.g. "rate_limit_exceeded".
    code: String,
    /// Human-readable error description.
    message: String,
}
516
/// Cost comparison produced by `BatchProcessor::calculate_cost_savings`.
#[derive(Debug)]
struct CostAnalysis {
    /// Number of requests that returned a response.
    successful_requests: usize,
    /// Total tokens across all successful responses.
    total_tokens: i32,
    /// Estimated cost at synchronous-API rates (USD).
    synchronous_cost: f64,
    /// Estimated cost at batch rates (USD).
    batch_cost: f64,
    /// Absolute savings (USD).
    savings: f64,
    /// Savings as a percentage of the synchronous cost.
    savings_percentage: f64,
}
527
/// Namespace for helpers that build ready-to-submit [`BatchRequest`] lists.
struct BatchRequestGenerator;
530
531impl BatchRequestGenerator {
532 fn generate_summarization_requests(contents: Vec<String>) -> Vec<BatchRequest> {
534 contents
535 .into_iter()
536 .enumerate()
537 .map(|(idx, content)| BatchRequest {
538 custom_id: format!("summarize_{}", idx),
539 method: "POST".to_string(),
540 url: "/v1/chat/completions".to_string(),
541 body: BatchRequestBody {
542 model: "gpt-3.5-turbo".to_string(),
543 messages: vec![
544 ChatMessage {
545 role: "system".to_string(),
546 content: "You are a helpful assistant that creates concise summaries."
547 .to_string(),
548 },
549 ChatMessage {
550 role: "user".to_string(),
551 content: format!(
552 "Please summarize the following text in 2-3 sentences:\n\n{}",
553 content
554 ),
555 },
556 ],
557 max_tokens: Some(150),
558 temperature: Some(0.3),
559 },
560 })
561 .collect()
562 }
563
564 fn generate_sentiment_requests(texts: Vec<String>) -> Vec<BatchRequest> {
566 texts
567 .into_iter()
568 .enumerate()
569 .map(|(idx, text)| BatchRequest {
570 custom_id: format!("sentiment_{}", idx),
571 method: "POST".to_string(),
572 url: "/v1/chat/completions".to_string(),
573 body: BatchRequestBody {
574 model: "gpt-3.5-turbo".to_string(),
575 messages: vec![
576 ChatMessage {
577 role: "system".to_string(),
578 content: "Analyze the sentiment of the given text. Respond with only: POSITIVE, NEGATIVE, or NEUTRAL.".to_string(),
579 },
580 ChatMessage {
581 role: "user".to_string(),
582 content: text,
583 },
584 ],
585 max_tokens: Some(10),
586 temperature: Some(0.0),
587 },
588 })
589 .collect()
590 }
591
592 fn generate_translation_requests(
594 texts: Vec<String>,
595 target_language: &str,
596 ) -> Vec<BatchRequest> {
597 texts
598 .into_iter()
599 .enumerate()
600 .map(|(idx, text)| BatchRequest {
601 custom_id: format!("translate_{}_{}", target_language, idx),
602 method: "POST".to_string(),
603 url: "/v1/chat/completions".to_string(),
604 body: BatchRequestBody {
605 model: "gpt-3.5-turbo".to_string(),
606 messages: vec![
607 ChatMessage {
608 role: "system".to_string(),
609 content: format!(
610 "Translate the following text to {}. Provide only the translation.",
611 target_language
612 ),
613 },
614 ChatMessage {
615 role: "user".to_string(),
616 content: text,
617 },
618 ],
619 max_tokens: Some(200),
620 temperature: Some(0.3),
621 },
622 })
623 .collect()
624 }
625}
626
#[tokio::main]
async fn main() -> Result<()> {
    // Route `tracing` output to the console; RUST_LOG overrides the
    // default "info" level when set.
    tracing_subscriber::fmt()
        .with_env_filter(
            tracing_subscriber::EnvFilter::try_from_default_env()
                .unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info")),
        )
        .init();

    info!("Starting batch processing example");

    // Client credentials come from the environment (e.g. OPENAI_API_KEY).
    let client = Client::from_env()?.build();

    // Demo-friendly settings: small batches, 10s polling, 30-minute timeout.
    let batch_processor = BatchProcessor::new(client).with_config(
        "./batch_results",
        1000,
        Duration::from_secs(10),
        Duration::from_secs(30 * 60),
    );

    info!("=== Example 1: Content Summarization Batch ===");

    let content_samples = vec![
        "Artificial intelligence (AI) is intelligence demonstrated by machines, in contrast to the natural intelligence displayed by humans and animals. Leading AI textbooks define the field as the study of \"intelligent agents\": any device that perceives its environment and takes actions that maximize its chance of successfully achieving its goals.".to_string(),
        "Machine learning is a method of data analysis that automates analytical model building. It is a branch of artificial intelligence based on the idea that systems can learn from data, identify patterns and make decisions with minimal human intervention.".to_string(),
        "Deep learning is part of a broader family of machine learning methods based on artificial neural networks with representation learning. Learning can be supervised, semi-supervised or unsupervised.".to_string(),
    ];

    let summarization_requests =
        BatchRequestGenerator::generate_summarization_requests(content_samples);

    match batch_processor
        .process_batch_requests(summarization_requests, "content_summarization")
        .await
    {
        Ok(results) => {
            info!(
                "Summarization batch completed with {} results",
                results.len()
            );

            // Print each summary, or the error for requests that failed.
            for result in &results {
                match &result.response {
                    Some(response) => {
                        if let Some(choice) = response.choices.first() {
                            info!(
                                "Summary for {}: {}",
                                result.custom_id, choice.message.content
                            );
                        }
                    }
                    None => {
                        if let Some(error) = &result.error {
                            error!(
                                "Failed {}: {} - {}",
                                result.custom_id, error.code, error.message
                            );
                        }
                    }
                }
            }

            // Report the batch-vs-synchronous cost comparison.
            let cost_analysis = batch_processor.calculate_cost_savings(&results);
            info!("Cost Analysis:");
            info!(
                " Successful requests: {}",
                cost_analysis.successful_requests
            );
            info!(" Total tokens: {}", cost_analysis.total_tokens);
            info!(" Synchronous cost: ${:.4}", cost_analysis.synchronous_cost);
            info!(" Batch cost: ${:.4}", cost_analysis.batch_cost);
            info!(
                " Savings: ${:.4} ({:.1}%)",
                cost_analysis.savings, cost_analysis.savings_percentage
            );
        }
        Err(e) => {
            error!("Summarization batch failed: {}", e);
        }
    }

    info!("\n=== Example 2: Sentiment Analysis Batch ===");

    let sentiment_texts = vec![
        "I absolutely love this product! It exceeded all my expectations.".to_string(),
        "The service was terrible and I'm very disappointed.".to_string(),
        "It's an okay product, nothing special but gets the job done.".to_string(),
        "Outstanding quality and amazing customer support!".to_string(),
        "Not worth the money, poor build quality.".to_string(),
    ];

    let sentiment_requests = BatchRequestGenerator::generate_sentiment_requests(sentiment_texts);

    match batch_processor
        .process_batch_requests(sentiment_requests, "sentiment_analysis")
        .await
    {
        Ok(results) => {
            info!(
                "Sentiment analysis batch completed with {} results",
                results.len()
            );

            // Tally how many responses fell into each sentiment label.
            let mut sentiment_counts = HashMap::new();
            for result in &results {
                if let Some(response) = &result.response {
                    if let Some(choice) = response.choices.first() {
                        let sentiment = choice.message.content.trim();
                        *sentiment_counts.entry(sentiment.to_string()).or_insert(0) += 1;
                        info!("Sentiment for {}: {}", result.custom_id, sentiment);
                    }
                }
            }

            info!("Sentiment Distribution:");
            for (sentiment, count) in sentiment_counts {
                info!(" {}: {} occurrences", sentiment, count);
            }
        }
        Err(e) => {
            error!("Sentiment analysis batch failed: {}", e);
        }
    }

    info!("\n=== Example 3: Translation Batch ===");

    let english_texts = vec![
        "Hello, how are you today?".to_string(),
        "Thank you for your help.".to_string(),
        "The weather is beautiful today.".to_string(),
    ];

    let translation_requests =
        BatchRequestGenerator::generate_translation_requests(english_texts, "Spanish");

    match batch_processor
        .process_batch_requests(translation_requests, "translation")
        .await
    {
        Ok(results) => {
            info!("Translation batch completed with {} results", results.len());

            for result in &results {
                if let Some(response) = &result.response {
                    if let Some(choice) = response.choices.first() {
                        info!(
                            "Translation for {}: {}",
                            result.custom_id, choice.message.content
                        );
                    }
                }
            }
        }
        Err(e) => {
            error!("Translation batch failed: {}", e);
        }
    }

    info!("\n=== Example 4: Concurrent Batch Processing ===");

    let small_batch_1 = BatchRequestGenerator::generate_sentiment_requests(vec![
        "Great product!".to_string(),
        "Could be better.".to_string(),
    ]);

    let small_batch_2 = BatchRequestGenerator::generate_summarization_requests(vec![
        "Short text to summarize.".to_string(),
    ]);

    // Futures are lazy: nothing runs until `try_join!` polls both, which
    // drives the two batches concurrently and fails fast on either error.
    let batch_1_future =
        batch_processor.process_batch_requests(small_batch_1, "concurrent_batch_1");
    let batch_2_future =
        batch_processor.process_batch_requests(small_batch_2, "concurrent_batch_2");

    let (result_1, result_2) = tokio::try_join!(batch_1_future, batch_2_future)?;

    info!(
        "Concurrent batch 1 completed with {} results",
        result_1.len()
    );
    info!(
        "Concurrent batch 2 completed with {} results",
        result_2.len()
    );

    info!("Batch processing example completed successfully!");
    Ok(())
}
825
mod uuid {
    //! Minimal stand-in for the `uuid` crate so this example compiles
    //! without the real dependency.

    pub struct Uuid;

    impl Uuid {
        /// Produce a unique-enough identifier from the nanosecond clock.
        /// Not a real UUIDv4 — collision-safe only for demo purposes.
        pub fn new_v4() -> String {
            use std::time::{SystemTime, UNIX_EPOCH};
            let since_epoch = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
            let nanos = since_epoch.as_nanos();
            format!("uuid-{:x}", nanos)
        }
    }
}
842
mod chrono {
    //! Tiny stand-in for the `chrono` crate: only the `Utc::now().timestamp()`
    //! call chain this example uses.

    pub struct Utc;

    impl Utc {
        /// Return a handle whose `timestamp()` reads the wall clock.
        pub fn now() -> DateTime {
            DateTime
        }
    }

    pub struct DateTime;

    impl DateTime {
        /// Current UNIX time in whole seconds.
        /// NOTE: the clock is read here, not when `now()` was called.
        pub fn timestamp(&self) -> i64 {
            use std::time::{SystemTime, UNIX_EPOCH};
            let secs = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .unwrap()
                .as_secs();
            secs as i64
        }
    }
}