use snafu::{OptionExt, ResultExt, Snafu};
use std::{result::Result, sync::Arc};
use super::model::*;
use crate::{
GenerationResponse,
client::{Error as ClientError, GeminiClient},
files::handle::FileHandle,
};
/// Errors produced while querying a batch operation and collecting its results.
#[derive(Debug, Snafu)]
pub enum Error {
    /// The batch expired before it finished.
    ///
    /// NOTE(review): not constructed by `BatchStatus::from_operation`, which reports
    /// expiry as `BatchStatus::Expired` instead.
    #[snafu(display("batch '{name}' expired before finishing"))]
    BatchExpired { name: String },

    /// The finished batch operation carried an error result.
    #[snafu(display("batch '{name}' failed"))]
    BatchFailed { source: OperationError, name: String },

    /// A call through the Gemini client failed. The client error is boxed,
    /// presumably to keep `Error` itself small.
    #[snafu(display("client invocation error"))]
    Client { source: Box<ClientError> },

    /// Downloading the batch results file failed.
    #[snafu(display("failed to download batch result file '{file_name}'"))]
    FileDownload { source: crate::files::Error, file_name: String },

    /// The downloaded results file was not valid UTF-8.
    #[snafu(display("failed to decode batch result file content as UTF-8"))]
    FileDecode { source: std::string::FromUtf8Error },

    /// A line of the results file could not be deserialized; `line` holds its text.
    #[snafu(display("failed to parse line in batch result file"))]
    FileParse { source: serde_json::Error, line: String },

    /// The batch was reported as finished but carried neither a response nor an error.
    #[snafu(display("batch '{name}' completed but no result provided - API contract violation"))]
    MissingResult { name: String },
}
/// One entry of a finished batch: the outcome of a single request plus its metadata.
#[derive(Debug, Clone, PartialEq)]
pub struct BatchGenerationResponseItem {
/// The generated response, or the error reported for this individual request.
pub response: Result<GenerationResponse, IndividualRequestError>,
/// Request metadata; `BatchStatus::Succeeded` results are sorted by its `key`.
pub meta: RequestMetadata,
}
/// Snapshot of a batch job's state, as returned by `BatchHandle::status`.
#[derive(Debug, Clone, PartialEq)]
pub enum BatchStatus {
    /// The batch is not running yet. Also used for any other state of an
    /// unfinished operation that is not `Running`.
    Pending,
    /// The batch is running; counters come from the operation's batch statistics.
    Running {
        pending_count: i64,
        completed_count: i64,
        failed_count: i64,
        total_count: i64,
    },
    /// The batch finished; `results` holds one item per request, sorted by request key.
    Succeeded { results: Vec<BatchGenerationResponseItem> },
    /// The batch finished in the cancelled state.
    Cancelled,
    /// The batch finished in the expired state.
    Expired,
}
impl BatchStatus {
async fn parse_response_file(
response_file: crate::files::model::File,
client: Arc<GeminiClient>,
) -> Result<Vec<BatchGenerationResponseItem>, Error> {
let file = FileHandle::new(client.clone(), response_file);
let file_content_bytes =
file.download().await.context(FileDownloadSnafu { file_name: file.name() })?;
let file_content = String::from_utf8(file_content_bytes).context(FileDecodeSnafu)?;
let mut results = vec![];
for line in file_content.lines() {
if line.trim().is_empty() {
continue;
}
let item: BatchResponseFileItem =
serde_json::from_str(line).context(FileParseSnafu { line: line.to_string() })?;
results.push(BatchGenerationResponseItem {
response: item.response.into(),
meta: RequestMetadata { key: item.key },
});
}
Ok(results)
}
async fn process_successful_response(
response: BatchOperationResponse,
client: Arc<GeminiClient>,
) -> Result<Vec<BatchGenerationResponseItem>, Error> {
let results = match response {
BatchOperationResponse::InlinedResponses { inlined_responses } => inlined_responses
.inlined_responses
.into_iter()
.map(|item| BatchGenerationResponseItem {
response: item.result.into(),
meta: item.metadata,
})
.collect(),
BatchOperationResponse::ResponsesFile { responses_file } => {
let file = crate::files::model::File { name: responses_file, ..Default::default() };
Self::parse_response_file(file, client).await?
}
};
Ok(results)
}
async fn from_operation(
operation: BatchOperation,
client: Arc<GeminiClient>,
) -> Result<Self, Error> {
if operation.done {
let result =
operation.result.context(MissingResultSnafu { name: operation.name.clone() })?;
let response =
Result::from(result).context(BatchFailedSnafu { name: operation.name })?;
let mut results = Self::process_successful_response(response, client).await?;
results.sort_by_key(|r| r.meta.key);
match operation.metadata.state {
BatchState::BatchStateCancelled => Ok(BatchStatus::Cancelled),
BatchState::BatchStateExpired => Ok(BatchStatus::Expired),
_ => Ok(BatchStatus::Succeeded { results }),
}
} else {
match operation.metadata.state {
BatchState::BatchStatePending => Ok(BatchStatus::Pending),
BatchState::BatchStateRunning => {
let total_count = operation.metadata.batch_stats.request_count;
let pending_count =
operation.metadata.batch_stats.pending_request_count.unwrap_or(total_count);
let completed_count =
operation.metadata.batch_stats.completed_request_count.unwrap_or(0);
let failed_count =
operation.metadata.batch_stats.failed_request_count.unwrap_or(0);
Ok(BatchStatus::Running {
pending_count,
completed_count,
failed_count,
total_count,
})
}
_ => Ok(BatchStatus::Pending),
}
}
}
}
/// Handle to a single batch operation, used to poll, cancel, or delete it.
pub struct BatchHandle {
/// Name of the batch operation, passed to the client's get/cancel/delete batch calls.
pub name: String,
/// Shared client used to issue requests for this batch.
client: Arc<GeminiClient>,
}
impl BatchHandle {
pub(crate) fn new(name: String, client: Arc<GeminiClient>) -> Self {
Self { name, client }
}
pub fn name(&self) -> &str {
&self.name
}
pub async fn status(&self) -> Result<BatchStatus, Error> {
let operation: BatchOperation = self
.client
.get_batch_operation(&self.name)
.await
.map_err(Box::new)
.context(ClientSnafu)?;
BatchStatus::from_operation(operation, self.client.clone()).await
}
pub async fn cancel(self) -> Result<(), (Self, ClientError)> {
match self.client.cancel_batch_operation(&self.name).await {
Ok(()) => Ok(()),
Err(e) => Err((self, e)),
}
}
pub async fn delete(self) -> Result<(), (Self, ClientError)> {
match self.client.delete_batch_operation(&self.name).await {
Ok(()) => Ok(()),
Err(e) => Err((self, e)),
}
}
}