#![allow(dead_code)]
#![allow(clippy::too_many_lines)]
#![allow(clippy::missing_const_for_fn)]
#![allow(clippy::unused_async)]
#![allow(clippy::unreadable_literal)]
#![allow(clippy::cast_possible_truncation)]
#![allow(clippy::cast_precision_loss)]
#![allow(clippy::uninlined_format_args)]
#![allow(clippy::doc_markdown)]
#![allow(clippy::needless_pass_by_value)]
#![allow(clippy::redundant_closure_for_method_calls)]
#![allow(clippy::unused_self)]
#![allow(clippy::cast_lossless)]
#![allow(clippy::struct_field_names)]
#![allow(clippy::cast_possible_wrap)]
use openai_ergonomic::{Client, Error, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs;
use std::path::Path;
use std::time::{Duration, Instant};
use tokio::time::sleep;
use tracing::{debug, error, info, warn};
/// One request entry of a batch input file.
///
/// `BatchProcessor::create_batch_file` serializes each value with
/// `serde_json::to_string` and writes it as a single line, so the field names
/// below are the JSON keys that end up in the file.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct BatchRequest {
/// Caller-chosen identifier; the generators in this file use values such as
/// `summarize_0` or `sentiment_3`.
custom_id: String,
/// HTTP method; the generators in this file always set `"POST"`.
method: String,
/// Endpoint path; the generators in this file always set `"/v1/chat/completions"`.
url: String,
/// Chat-completion payload for this request.
body: BatchRequestBody,
}
/// Payload of a single [`BatchRequest`].
///
/// No `skip_serializing_if` attribute is present, so a `None` in either
/// optional field is written as JSON `null` rather than omitted.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct BatchRequestBody {
/// Model name; the generators in this file use `"gpt-3.5-turbo"`.
model: String,
/// Conversation messages, in order.
messages: Vec<ChatMessage>,
/// Optional cap on generated tokens.
max_tokens: Option<i32>,
/// Optional sampling temperature.
temperature: Option<f64>,
}
/// A single chat message inside a [`BatchRequestBody`].
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ChatMessage {
/// Author of the message; this file uses `"system"` and `"user"`.
role: String,
/// Text of the message.
content: String,
}
/// Snapshot of a batch job as tracked by [`BatchProcessor`].
#[derive(Debug, Clone)]
struct BatchJob {
/// Job identifier; used by `monitor_batch_progress` to poll for status.
id: String,
/// Status string. `monitor_batch_progress` recognises `"completed"`,
/// `"failed"`, `"expired"`, `"cancelled"`, `"validating"`, `"in_progress"`
/// and `"finalizing"`; anything else is logged as unknown.
status: String,
/// Identifier of the uploaded input file.
input_file_id: String,
/// Identifier of the results file; `download_batch_results` fails if this is `None`.
output_file_id: Option<String>,
/// Identifier of an error file, if any.
error_file_id: Option<String>,
/// Creation time as seconds since the Unix epoch.
created_at: i64,
/// Completion time as seconds since the Unix epoch, once known.
completed_at: Option<i64>,
/// Per-job request tallies.
request_counts: BatchRequestCounts,
}
/// Request tallies attached to a [`BatchJob`].
///
/// `download_batch_results` uses `completed` and `failed` to decide how many
/// result entries of each kind to produce.
#[derive(Debug, Clone)]
struct BatchRequestCounts {
/// Number of requests in the job.
total: i32,
/// Number of requests that finished successfully.
completed: i32,
/// Number of requests that failed.
failed: i32,
}
/// Drives the batch workflow: write a JSONL input file, upload it, create a
/// job, poll until it finishes, then collect the results.
#[derive(Debug)]
struct BatchProcessor {
/// API client handed in at construction.
client: Client,
/// Directory in which batch input files are written.
batch_dir: String,
/// Maximum number of requests placed in one batch.
max_batch_size: usize,
/// Delay between two status polls.
poll_interval: Duration,
/// Upper bound on how long a single job is monitored before giving up.
max_wait_time: Duration,
}
impl BatchProcessor {
    /// Creates a processor with the default settings: files under
    /// `./batch_files`, at most 50 000 requests per batch, a 30 second poll
    /// interval and a 25 hour monitoring limit.
    fn new(client: Client) -> Self {
        Self {
            client,
            batch_dir: "./batch_files".to_string(),
            max_batch_size: 50_000,
            poll_interval: Duration::from_secs(30),
            max_wait_time: Duration::from_secs(25 * 60 * 60),
        }
    }

    /// Replaces every tunable setting and returns the updated processor.
    fn with_config(
        mut self,
        batch_dir: &str,
        max_batch_size: usize,
        poll_interval: Duration,
        max_wait_time: Duration,
    ) -> Self {
        self.batch_dir = batch_dir.to_string();
        self.max_batch_size = max_batch_size;
        self.poll_interval = poll_interval;
        self.max_wait_time = max_wait_time;
        self
    }

    /// Splits `requests` into batches of at most `max_batch_size` and runs
    /// them one after another, returning the concatenated results.
    ///
    /// `batch_name` is used as the prefix of each batch's identifier
    /// (`<batch_name>_batch_<index>`).
    ///
    /// # Errors
    ///
    /// Fails if the batch directory cannot be created, and stops at the first
    /// batch that fails; results of earlier batches are discarded in that case.
    async fn process_batch_requests(
        &self,
        requests: Vec<BatchRequest>,
        batch_name: &str,
    ) -> Result<Vec<BatchProcessingResult>> {
        info!("Starting batch processing for {} requests", requests.len());
        fs::create_dir_all(&self.batch_dir).map_err(|e| {
            Error::InvalidRequest(format!("Failed to create batch directory: {}", e))
        })?;

        let batches = self.split_into_batches(requests);
        // Captured up front: `into_iter` below consumes `batches`, and the
        // progress log needs the number of batches, not the size of one.
        let total_batches = batches.len();
        let mut all_results = Vec::new();
        for (batch_idx, batch_requests) in batches.into_iter().enumerate() {
            let batch_id = format!("{}_batch_{}", batch_name, batch_idx);
            info!(
                "Processing batch {}/{}: {} requests",
                batch_idx + 1,
                total_batches,
                batch_requests.len()
            );
            let results = self.process_single_batch(batch_requests, &batch_id).await?;
            all_results.extend(results);
        }

        info!(
            "Completed batch processing with {} total results",
            all_results.len()
        );
        Ok(all_results)
    }

    /// Runs one batch end to end: write the input file, upload it, create the
    /// job, wait for it to finish and fetch its results.
    ///
    /// # Errors
    ///
    /// Propagates the first error raised by any of those steps.
    async fn process_single_batch(
        &self,
        requests: Vec<BatchRequest>,
        batch_id: &str,
    ) -> Result<Vec<BatchProcessingResult>> {
        let input_file_path = format!("{}/{}_input.jsonl", self.batch_dir, batch_id);
        self.create_batch_file(&requests, &input_file_path)?;

        let file_upload_result = self.upload_batch_file(&input_file_path).await?;
        info!("Uploaded batch file with ID: {}", file_upload_result.id);

        let batch_job = self
            .create_batch_job(&file_upload_result.id, batch_id)
            .await?;
        info!("Created batch job with ID: {}", batch_job.id);

        let completed_batch = self.monitor_batch_progress(batch_job).await?;
        let results = self.download_batch_results(&completed_batch).await?;
        info!(
            "Successfully processed batch with {} results",
            results.len()
        );
        Ok(results)
    }

    /// Partitions `requests` into consecutive groups of at most
    /// `max_batch_size` elements, preserving order.
    ///
    /// A configured size of 0 is treated as 1 so that a misconfiguration
    /// cannot panic. Requests are moved into their batch rather than cloned.
    fn split_into_batches(&self, requests: Vec<BatchRequest>) -> Vec<Vec<BatchRequest>> {
        let chunk_size = self.max_batch_size.max(1);
        let mut batches: Vec<Vec<BatchRequest>> =
            Vec::with_capacity(requests.len() / chunk_size + 1);
        let mut remaining = requests.into_iter().peekable();
        while remaining.peek().is_some() {
            batches.push(remaining.by_ref().take(chunk_size).collect());
        }
        batches
    }

    /// Writes `requests` to `file_path` as JSON Lines (one serialized request
    /// per line, each terminated by `\n`).
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidRequest` if a request cannot be serialized or
    /// the file cannot be written.
    fn create_batch_file(&self, requests: &[BatchRequest], file_path: &str) -> Result<()> {
        let mut content = String::new();
        for request in requests {
            let json_line = serde_json::to_string(request).map_err(|e| {
                Error::InvalidRequest(format!("Failed to serialize request: {}", e))
            })?;
            content.push_str(&json_line);
            content.push('\n');
        }
        fs::write(file_path, content)
            .map_err(|e| Error::InvalidRequest(format!("Failed to write batch file: {}", e)))?;
        debug!(
            "Created batch file: {} ({} requests)",
            file_path,
            requests.len()
        );
        Ok(())
    }

    /// Simulates uploading the batch file: nothing is sent anywhere, the
    /// returned record carries a freshly generated id plus the size and file
    /// name read from the local file.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidRequest` if the file's metadata cannot be read
    /// or if `file_path` has no final file-name component.
    async fn upload_batch_file(&self, file_path: &str) -> Result<FileUploadResult> {
        info!("Uploading batch file: {}", file_path);
        let bytes = fs::metadata(file_path)
            .map_err(|e| Error::InvalidRequest(format!("Failed to get file size: {}", e)))?
            .len();
        // `file_name()` is `None` for paths such as `..`; report that as an
        // error instead of panicking.
        let filename = Path::new(file_path)
            .file_name()
            .ok_or_else(|| {
                Error::InvalidRequest(format!(
                    "Batch file path has no file name: {}",
                    file_path
                ))
            })?
            .to_string_lossy()
            .into_owned();
        Ok(FileUploadResult {
            id: format!("file-{}", uuid::Uuid::new_v4()),
            bytes,
            filename,
        })
    }

    /// Simulates creating a batch job for `input_file_id`: returns a job with
    /// a freshly generated id, status `"validating"` and zeroed counts.
    /// `_batch_name` is currently unused.
    async fn create_batch_job(&self, input_file_id: &str, _batch_name: &str) -> Result<BatchJob> {
        info!("Creating batch job for file: {}", input_file_id);
        Ok(BatchJob {
            id: format!("batch_{}", uuid::Uuid::new_v4()),
            status: "validating".to_string(),
            input_file_id: input_file_id.to_string(),
            output_file_id: None,
            error_file_id: None,
            created_at: chrono::Utc::now().timestamp(),
            completed_at: None,
            request_counts: BatchRequestCounts {
                total: 0,
                completed: 0,
                failed: 0,
            },
        })
    }

    /// Polls the job's status every `poll_interval` until it reaches a
    /// terminal state, and returns the final job on `"completed"`.
    ///
    /// Only the `id` of the job passed in is used; each poll replaces the
    /// whole value with what `get_batch_status` returns.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidRequest` when the job ends as `"failed"`,
    /// `"expired"` or `"cancelled"`, or when `max_wait_time` has elapsed
    /// before a poll; also propagates errors from `get_batch_status`.
    async fn monitor_batch_progress(&self, mut batch_job: BatchJob) -> Result<BatchJob> {
        let start_time = Instant::now();
        loop {
            // The deadline is checked before each poll, so a job is never
            // polled once the limit has passed.
            if start_time.elapsed() > self.max_wait_time {
                return Err(Error::InvalidRequest(format!(
                    "Batch processing timed out after {:?}",
                    self.max_wait_time
                )));
            }

            batch_job = self.get_batch_status(&batch_job.id).await?;
            match batch_job.status.as_str() {
                "completed" => {
                    info!("Batch {} completed successfully", batch_job.id);
                    return Ok(batch_job);
                }
                "failed" | "expired" | "cancelled" => {
                    return Err(Error::InvalidRequest(format!(
                        "Batch {} failed with status: {}",
                        batch_job.id, batch_job.status
                    )));
                }
                "validating" | "in_progress" | "finalizing" => {
                    info!(
                        "Batch {} status: {} ({}s elapsed)",
                        batch_job.id,
                        batch_job.status,
                        start_time.elapsed().as_secs()
                    );
                }
                // Unrecognised statuses are logged and polling continues.
                _ => {
                    warn!("Unknown batch status: {}", batch_job.status);
                }
            }
            sleep(self.poll_interval).await;
        }
    }

    /// Simulates a status lookup: always reports the job as `"completed"`
    /// with fixed counts (100 total, 98 completed, 2 failed), a creation time
    /// one hour before now and file ids derived from `batch_id`.
    async fn get_batch_status(&self, batch_id: &str) -> Result<BatchJob> {
        debug!("Checking status for batch: {}", batch_id);
        let current_time = chrono::Utc::now().timestamp();
        Ok(BatchJob {
            id: batch_id.to_string(),
            status: "completed".to_string(),
            input_file_id: format!("file-input-{}", batch_id),
            output_file_id: Some(format!("file-output-{}", batch_id)),
            error_file_id: None,
            created_at: current_time - 3600,
            completed_at: Some(current_time),
            request_counts: BatchRequestCounts {
                total: 100,
                completed: 98,
                failed: 2,
            },
        })
    }

    /// Simulates downloading results: fabricates one successful entry per
    /// `request_counts.completed` and one failed entry per
    /// `request_counts.failed`, successes first.
    ///
    /// The fabricated `custom_id`s (`request_<i>` / `failed_request_<i>`) are
    /// not the ids of the submitted requests.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidRequest` if the job has no `output_file_id`.
    async fn download_batch_results(
        &self,
        batch_job: &BatchJob,
    ) -> Result<Vec<BatchProcessingResult>> {
        let output_file_id = batch_job
            .output_file_id
            .as_ref()
            .ok_or_else(|| Error::InvalidRequest("No output file available".to_string()))?;
        info!("Downloading results from file: {}", output_file_id);

        let mut results = Vec::new();
        results.extend(
            (0..batch_job.request_counts.completed).map(|i| BatchProcessingResult {
                custom_id: format!("request_{}", i),
                status: "completed".to_string(),
                response: Some(BatchResponseData {
                    id: format!("chatcmpl_{}", uuid::Uuid::new_v4()),
                    object: "chat.completion".to_string(),
                    model: "gpt-3.5-turbo".to_string(),
                    choices: vec![BatchChoiceData {
                        index: 0,
                        message: BatchMessageData {
                            role: "assistant".to_string(),
                            content: format!("This is a sample response for request {}", i),
                        },
                        finish_reason: "stop".to_string(),
                    }],
                    usage: BatchUsageData {
                        prompt_tokens: 20,
                        completion_tokens: 15,
                        total_tokens: 35,
                    },
                }),
                error: None,
            }),
        );
        results.extend(
            (0..batch_job.request_counts.failed).map(|i| BatchProcessingResult {
                custom_id: format!("failed_request_{}", i),
                status: "failed".to_string(),
                response: None,
                error: Some(BatchErrorData {
                    code: "rate_limit_exceeded".to_string(),
                    message: "Rate limit exceeded, please try again later".to_string(),
                }),
            }),
        );

        info!("Downloaded {} batch results", results.len());
        Ok(results)
    }

    /// Estimates what `results` would have cost synchronously versus through
    /// the batch path, based on the token usage of the successful entries.
    ///
    /// When there is no usage to price, `savings_percentage` is reported as
    /// `0.0` rather than NaN.
    fn calculate_cost_savings(&self, results: &[BatchProcessingResult]) -> CostAnalysis {
        // NOTE(review): this rate is applied per token; confirm whether the
        // intended price is per token or per 1K tokens.
        const SYNC_COST_PER_TOKEN: f64 = 0.002;
        // Batch cost is modelled as half of the synchronous cost.
        const BATCH_COST_FACTOR: f64 = 0.5;

        let successful_requests = results.iter().filter(|r| r.response.is_some()).count();
        // Saturating so that a very large batch cannot overflow the i32 total.
        let total_tokens = results
            .iter()
            .filter_map(|r| r.response.as_ref())
            .map(|resp| resp.usage.total_tokens)
            .fold(0_i32, i32::saturating_add);

        let synchronous_cost = total_tokens as f64 * SYNC_COST_PER_TOKEN;
        let batch_cost = synchronous_cost * BATCH_COST_FACTOR;
        let savings = synchronous_cost - batch_cost;
        let savings_percentage = if synchronous_cost > 0.0 {
            (savings / synchronous_cost) * 100.0
        } else {
            0.0
        };

        CostAnalysis {
            successful_requests,
            total_tokens,
            synchronous_cost,
            batch_cost,
            savings,
            savings_percentage,
        }
    }
}
/// Record returned by `BatchProcessor::upload_batch_file`.
#[derive(Debug)]
struct FileUploadResult {
/// Identifier assigned to the uploaded file.
id: String,
/// Size of the file in bytes, as reported by `fs::metadata`.
bytes: u64,
/// Final component of the uploaded file's path.
filename: String,
}
/// Outcome of one request within a batch.
///
/// The entries built by `download_batch_results` set exactly one of
/// `response` and `error`; the type itself does not enforce that.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct BatchProcessingResult {
/// Identifier of the request this result belongs to.
custom_id: String,
/// Outcome label; this file produces `"completed"` and `"failed"`.
status: String,
/// Response payload, present for successful requests.
response: Option<BatchResponseData>,
/// Error details, present for failed requests.
error: Option<BatchErrorData>,
}
/// Response payload of a successful batch request.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct BatchResponseData {
/// Identifier of the completion.
id: String,
/// Object type label; this file sets `"chat.completion"`.
object: String,
/// Name of the model that produced the response.
model: String,
/// Generated choices; callers in this file read only the first one.
choices: Vec<BatchChoiceData>,
/// Token accounting for the request.
usage: BatchUsageData,
}
/// One generated choice within a [`BatchResponseData`].
#[derive(Debug, Clone, Serialize, Deserialize)]
struct BatchChoiceData {
/// Position of this choice in the response.
index: i32,
/// The generated message.
message: BatchMessageData,
/// Why generation ended; this file sets `"stop"`.
finish_reason: String,
}
/// Message carried by a [`BatchChoiceData`].
#[derive(Debug, Clone, Serialize, Deserialize)]
struct BatchMessageData {
/// Author of the message; this file sets `"assistant"`.
role: String,
/// Text of the message.
content: String,
}
/// Token accounting attached to a [`BatchResponseData`].
#[derive(Debug, Clone, Serialize, Deserialize)]
struct BatchUsageData {
/// Tokens consumed by the prompt.
prompt_tokens: i32,
/// Tokens produced in the completion.
completion_tokens: i32,
/// Total tokens; this is the figure `calculate_cost_savings` sums.
total_tokens: i32,
}
/// Error details of a failed batch request.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct BatchErrorData {
/// Short error code, e.g. `"rate_limit_exceeded"`.
code: String,
/// Human-readable description of the failure.
message: String,
}
/// Cost comparison produced by `BatchProcessor::calculate_cost_savings`.
#[derive(Debug)]
struct CostAnalysis {
/// Number of results that carry a response.
successful_requests: usize,
/// Sum of `usage.total_tokens` over the successful results.
total_tokens: i32,
/// Estimated cost of running the same requests synchronously.
synchronous_cost: f64,
/// Estimated cost of running them as a batch.
batch_cost: f64,
/// `synchronous_cost - batch_cost`.
savings: f64,
/// `savings` expressed as a percentage of `synchronous_cost`.
savings_percentage: f64,
}
/// Namespace for functions that build [`BatchRequest`] lists for common tasks.
struct BatchRequestGenerator;

impl BatchRequestGenerator {
    /// Model used by every generated request.
    const MODEL: &'static str = "gpt-3.5-turbo";
    /// HTTP method used by every generated request.
    const METHOD: &'static str = "POST";
    /// Endpoint targeted by every generated request.
    const ENDPOINT: &'static str = "/v1/chat/completions";

    /// Builds one chat-completion request consisting of a system prompt
    /// followed by a single user message.
    fn chat_request(
        custom_id: String,
        system_prompt: String,
        user_content: String,
        max_tokens: i32,
        temperature: f64,
    ) -> BatchRequest {
        BatchRequest {
            custom_id,
            method: Self::METHOD.to_string(),
            url: Self::ENDPOINT.to_string(),
            body: BatchRequestBody {
                model: Self::MODEL.to_string(),
                messages: vec![
                    ChatMessage {
                        role: "system".to_string(),
                        content: system_prompt,
                    },
                    ChatMessage {
                        role: "user".to_string(),
                        content: user_content,
                    },
                ],
                max_tokens: Some(max_tokens),
                temperature: Some(temperature),
            },
        }
    }

    /// Builds one summarization request per entry of `contents`.
    ///
    /// Requests are identified as `summarize_<index>` and ask for a 2-3
    /// sentence summary with at most 150 tokens at temperature 0.3.
    fn generate_summarization_requests(contents: Vec<String>) -> Vec<BatchRequest> {
        contents
            .into_iter()
            .enumerate()
            .map(|(idx, content)| {
                Self::chat_request(
                    format!("summarize_{}", idx),
                    "You are a helpful assistant that creates concise summaries.".to_string(),
                    format!(
                        "Please summarize the following text in 2-3 sentences:\n\n{}",
                        content
                    ),
                    150,
                    0.3,
                )
            })
            .collect()
    }

    /// Builds one sentiment-classification request per entry of `texts`.
    ///
    /// Requests are identified as `sentiment_<index>`, pass the text through
    /// unchanged as the user message, and allow at most 10 tokens at
    /// temperature 0.0.
    fn generate_sentiment_requests(texts: Vec<String>) -> Vec<BatchRequest> {
        texts
            .into_iter()
            .enumerate()
            .map(|(idx, text)| {
                Self::chat_request(
                    format!("sentiment_{}", idx),
                    "Analyze the sentiment of the given text. Respond with only: POSITIVE, NEGATIVE, or NEUTRAL.".to_string(),
                    text,
                    10,
                    0.0,
                )
            })
            .collect()
    }

    /// Builds one translation request per entry of `texts`.
    ///
    /// Requests are identified as `translate_<target_language>_<index>`, name
    /// `target_language` in the system prompt, and allow at most 200 tokens
    /// at temperature 0.3.
    fn generate_translation_requests(
        texts: Vec<String>,
        target_language: &str,
    ) -> Vec<BatchRequest> {
        texts
            .into_iter()
            .enumerate()
            .map(|(idx, text)| {
                Self::chat_request(
                    format!("translate_{}_{}", target_language, idx),
                    format!(
                        "Translate the following text to {}. Provide only the translation.",
                        target_language
                    ),
                    text,
                    200,
                    0.3,
                )
            })
            .collect()
    }
}
#[tokio::main]
async fn main() -> Result<()> {
tracing_subscriber::fmt()
.with_env_filter(
tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info")),
)
.init();
info!("Starting batch processing example");
let client = Client::from_env()?.build();
let batch_processor = BatchProcessor::new(client).with_config(
"./batch_results", 1000, Duration::from_secs(10), Duration::from_secs(30 * 60), );
info!("=== Example 1: Content Summarization Batch ===");
let content_samples = vec![
"Artificial intelligence (AI) is intelligence demonstrated by machines, in contrast to the natural intelligence displayed by humans and animals. Leading AI textbooks define the field as the study of \"intelligent agents\": any device that perceives its environment and takes actions that maximize its chance of successfully achieving its goals.".to_string(),
"Machine learning is a method of data analysis that automates analytical model building. It is a branch of artificial intelligence based on the idea that systems can learn from data, identify patterns and make decisions with minimal human intervention.".to_string(),
"Deep learning is part of a broader family of machine learning methods based on artificial neural networks with representation learning. Learning can be supervised, semi-supervised or unsupervised.".to_string(),
];
let summarization_requests =
BatchRequestGenerator::generate_summarization_requests(content_samples);
match batch_processor
.process_batch_requests(summarization_requests, "content_summarization")
.await
{
Ok(results) => {
info!(
"Summarization batch completed with {} results",
results.len()
);
for result in &results {
match &result.response {
Some(response) => {
if let Some(choice) = response.choices.first() {
info!(
"Summary for {}: {}",
result.custom_id, choice.message.content
);
}
}
None => {
if let Some(error) = &result.error {
error!(
"Failed {}: {} - {}",
result.custom_id, error.code, error.message
);
}
}
}
}
let cost_analysis = batch_processor.calculate_cost_savings(&results);
info!("Cost Analysis:");
info!(
" Successful requests: {}",
cost_analysis.successful_requests
);
info!(" Total tokens: {}", cost_analysis.total_tokens);
info!(" Synchronous cost: ${:.4}", cost_analysis.synchronous_cost);
info!(" Batch cost: ${:.4}", cost_analysis.batch_cost);
info!(
" Savings: ${:.4} ({:.1}%)",
cost_analysis.savings, cost_analysis.savings_percentage
);
}
Err(e) => {
error!("Summarization batch failed: {}", e);
}
}
info!("\n=== Example 2: Sentiment Analysis Batch ===");
let sentiment_texts = vec![
"I absolutely love this product! It exceeded all my expectations.".to_string(),
"The service was terrible and I'm very disappointed.".to_string(),
"It's an okay product, nothing special but gets the job done.".to_string(),
"Outstanding quality and amazing customer support!".to_string(),
"Not worth the money, poor build quality.".to_string(),
];
let sentiment_requests = BatchRequestGenerator::generate_sentiment_requests(sentiment_texts);
match batch_processor
.process_batch_requests(sentiment_requests, "sentiment_analysis")
.await
{
Ok(results) => {
info!(
"Sentiment analysis batch completed with {} results",
results.len()
);
let mut sentiment_counts = HashMap::new();
for result in &results {
if let Some(response) = &result.response {
if let Some(choice) = response.choices.first() {
let sentiment = choice.message.content.trim();
*sentiment_counts.entry(sentiment.to_string()).or_insert(0) += 1;
info!("Sentiment for {}: {}", result.custom_id, sentiment);
}
}
}
info!("Sentiment Distribution:");
for (sentiment, count) in sentiment_counts {
info!(" {}: {} occurrences", sentiment, count);
}
}
Err(e) => {
error!("Sentiment analysis batch failed: {}", e);
}
}
info!("\n=== Example 3: Translation Batch ===");
let english_texts = vec![
"Hello, how are you today?".to_string(),
"Thank you for your help.".to_string(),
"The weather is beautiful today.".to_string(),
];
let translation_requests =
BatchRequestGenerator::generate_translation_requests(english_texts, "Spanish");
match batch_processor
.process_batch_requests(translation_requests, "translation")
.await
{
Ok(results) => {
info!("Translation batch completed with {} results", results.len());
for result in &results {
if let Some(response) = &result.response {
if let Some(choice) = response.choices.first() {
info!(
"Translation for {}: {}",
result.custom_id, choice.message.content
);
}
}
}
}
Err(e) => {
error!("Translation batch failed: {}", e);
}
}
info!("\n=== Example 4: Concurrent Batch Processing ===");
let small_batch_1 = BatchRequestGenerator::generate_sentiment_requests(vec![
"Great product!".to_string(),
"Could be better.".to_string(),
]);
let small_batch_2 = BatchRequestGenerator::generate_summarization_requests(vec![
"Short text to summarize.".to_string(),
]);
let batch_1_future =
batch_processor.process_batch_requests(small_batch_1, "concurrent_batch_1");
let batch_2_future =
batch_processor.process_batch_requests(small_batch_2, "concurrent_batch_2");
let (result_1, result_2) = tokio::try_join!(batch_1_future, batch_2_future)?;
info!(
"Concurrent batch 1 completed with {} results",
result_1.len()
);
info!(
"Concurrent batch 2 completed with {} results",
result_2.len()
);
info!("Batch processing example completed successfully!");
Ok(())
}
/// Minimal local module providing the `uuid::Uuid::new_v4()` call shape used
/// in this file. The ids it produces are NOT RFC 4122 UUIDs.
mod uuid {
    /// Namespace type for [`Uuid::new_v4`].
    pub struct Uuid;

    impl Uuid {
        /// Returns an identifier of the form `uuid-<nanos-hex>-<sequence-hex>`.
        ///
        /// The first part is the wall-clock time in nanoseconds since the
        /// Unix epoch (0 if the clock reads earlier than the epoch); the
        /// second is a process-wide counter, so two calls never return the
        /// same string within one process even when they land on the same
        /// clock reading.
        pub fn new_v4() -> String {
            use std::sync::atomic::{AtomicU64, Ordering};
            use std::time::{SystemTime, UNIX_EPOCH};

            static SEQUENCE: AtomicU64 = AtomicU64::new(0);

            let nanos = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .map_or(0, |since_epoch| since_epoch.as_nanos());
            // Relaxed is enough: only the uniqueness of the fetched value
            // matters, not its ordering relative to other memory operations.
            let sequence = SEQUENCE.fetch_add(1, Ordering::Relaxed);
            format!("uuid-{:x}-{:x}", nanos, sequence)
        }
    }
}
/// Minimal local module providing the `chrono::Utc::now().timestamp()` call
/// shape used in this file.
mod chrono {
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Namespace type for [`Utc::now`].
    pub struct Utc;

    impl Utc {
        /// Captures the current wall-clock time, truncated to whole seconds.
        pub fn now() -> DateTime {
            let secs = match SystemTime::now().duration_since(UNIX_EPOCH) {
                Ok(since_epoch) => since_epoch.as_secs() as i64,
                // Clock set before 1970: report a negative offset instead of
                // panicking.
                Err(before_epoch) => -(before_epoch.duration().as_secs() as i64),
            };
            DateTime { secs }
        }
    }

    /// A point in time captured by [`Utc::now`].
    #[derive(Debug, Clone, Copy)]
    pub struct DateTime {
        secs: i64,
    }

    impl DateTime {
        /// Returns the captured instant as seconds since the Unix epoch.
        ///
        /// The value is fixed when the `DateTime` is created, so repeated
        /// calls on the same value always agree.
        pub fn timestamp(&self) -> i64 {
            self.secs
        }
    }
}