gemini_rust/batch/handle.rs
1//! The Batch module for managing batch operations.
2//!
3//! This module provides the [`BatchHandle`] struct, which is a handle to a long-running batch
4//! operation on the Gemini API. It allows for checking the status, canceling, and deleting
5//! the operation.
6//!
7//! The status of a batch operation is represented by the [`BatchStatus`] enum, which can be
8//! retrieved using the [`BatchHandle::status()`] method. When a batch completes successfully,
9//! it transitions to the [`BatchStatus::Succeeded`] state, which contains a vector of
10//! [`BatchGenerationResponseItem`].
11//!
12//! ## Batch Results
13//!
//! Each [`BatchGenerationResponseItem`] represents the outcome of a single request within the batch:
//! - `response`: `Ok` with the generated `GenerationResponse`, or `Err` with an
//!   `IndividualRequestError` if that particular request failed.
//! - `meta`: the request's `RequestMetadata`, including the original request key.
17//!
//! Results can be delivered in two ways, depending on the size of the batch job:
//! 1. **Inlined Responses**: For smaller jobs, the results are included directly in the
//!    batch operation's response.
//! 2. **Response File**: For larger jobs (typically >20MB), the results are written to a
//!    file, and the batch operation's response contains a reference to this file. The SDK
//!    handles downloading and parsing this file automatically when you call
//!    `status()` on a completed batch.
25//!
26//! The results are automatically sorted by their original request key (as a number) to ensure
27//! a consistent and predictable order.
28//!
29//! For more information, see the official Google AI documentation:
30//! - [Batch Mode Guide](https://ai.google.dev/gemini-api/docs/batch-mode)
31//! - [Batch API Reference](https://ai.google.dev/api/batch-mode)
32//!
33//! # Design Note: Resource Management in Batch Operations
34//!
35//! The Batch API methods that consume the [`BatchHandle`] struct (`cancel`, `delete`)
36//! return `std::result::Result<T, (Self, crate::Error)>` instead of the crate's `Result<T>`.
37//! This design follows patterns used in channel libraries (e.g., `std::sync::mpsc::Receiver`)
38//! and provides two key benefits:
39//!
40//! 1. **Resource Safety**: Once a [`BatchHandle`] is consumed by an operation, it cannot be used again,
41//! preventing invalid operations on deleted or canceled batches.
42//!
43//! 2. **Error Recovery**: If an operation fails due to transient network issues, both the
44//! [`BatchHandle`] and error information are returned, allowing callers to retry the operation.
45//!
46//! ## Example usage:
47//! ```rust,no_run
48//! use gemini_rust::{Gemini, Message};
49//!
50//! #[tokio::main]
51//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
52//! let client = Gemini::new(std::env::var("GEMINI_API_KEY")?)?;
53//! let request = client.generate_content().with_user_message("Why is the sky blue?").build();
54//! let batch = client.batch_generate_content().with_request(request).execute().await?;
55//!
56//! match batch.delete().await {
57//! Ok(()) => println!("Batch deleted successfully!"),
58//! Err((batch, error)) => {
59//! println!("Failed to delete batch: {}", error);
60//! // Can retry: batch.delete().await
61//! }
62//! }
63//! Ok(())
64//! }
65//! ```
66
67use snafu::{OptionExt, ResultExt, Snafu};
68use std::{result::Result, sync::Arc};
69
70use super::model::*;
71use crate::{
72 client::{Error as ClientError, GeminiClient},
73 files::handle::FileHandle,
74 GenerationResponse,
75};
76
77#[derive(Debug, Snafu)]
78pub enum Error {
79 #[snafu(display("batch '{name}' expired before finishing"))]
80 BatchExpired {
81 /// Batch name.
82 name: String,
83 },
84
85 #[snafu(display("batch '{name}' failed"))]
86 BatchFailed {
87 source: OperationError,
88 /// Batch name.
89 name: String,
90 },
91
92 #[snafu(display("client invocation error"))]
93 Client { source: Box<ClientError> },
94
95 #[snafu(display("failed to download batch result file '{file_name}'"))]
96 FileDownload {
97 source: crate::files::Error,
98 file_name: String,
99 },
100
101 #[snafu(display("failed to decode batch result file content as UTF-8"))]
102 FileDecode { source: std::string::FromUtf8Error },
103
104 #[snafu(display("failed to parse line in batch result file"))]
105 FileParse {
106 source: serde_json::Error,
107 line: String,
108 },
109
110 /// This error should never occur, as the Google API contract
111 /// guarantees that a result will always be provided.
112 ///
113 /// I put it here anyway to avoid potential panic in case of
114 /// Google's dishonesty or GCP internal errors.
115 #[snafu(display("batch '{name}' completed but no result provided - API contract violation"))]
116 MissingResult {
117 /// Batch name.
118 name: String,
119 },
120}
121
122#[derive(Debug, Clone, PartialEq)]
123pub struct BatchGenerationResponseItem {
124 pub response: Result<GenerationResponse, IndividualRequestError>,
125 pub meta: RequestMetadata,
126}
127
128/// Represents the overall status of a batch operation.
129#[derive(Debug, Clone, PartialEq)]
130pub enum BatchStatus {
131 /// The operation is waiting to be processed.
132 Pending,
133 /// The operation is currently being processed.
134 Running {
135 pending_count: i64,
136 completed_count: i64,
137 failed_count: i64,
138 total_count: i64,
139 },
140 /// The operation has completed successfully.
141 Succeeded {
142 results: Vec<BatchGenerationResponseItem>,
143 },
144 /// The operation was cancelled by the user.
145 Cancelled,
146 /// The operation has expired.
147 Expired,
148}
149
150impl BatchStatus {
151 async fn parse_response_file(
152 response_file: crate::files::model::File,
153 client: Arc<GeminiClient>,
154 ) -> Result<Vec<BatchGenerationResponseItem>, Error> {
155 let file = FileHandle::new(client.clone(), response_file);
156 let file_content_bytes = file.download().await.context(FileDownloadSnafu {
157 file_name: file.name(),
158 })?;
159 let file_content = String::from_utf8(file_content_bytes).context(FileDecodeSnafu)?;
160
161 let mut results = vec![];
162 for line in file_content.lines() {
163 if line.trim().is_empty() {
164 continue;
165 }
166 let item: BatchResponseFileItem =
167 serde_json::from_str(line).context(FileParseSnafu {
168 line: line.to_string(),
169 })?;
170
171 results.push(BatchGenerationResponseItem {
172 response: item.response.into(),
173 meta: RequestMetadata { key: item.key },
174 });
175 }
176 Ok(results)
177 }
178
179 async fn process_successful_response(
180 response: BatchOperationResponse,
181 client: Arc<GeminiClient>,
182 ) -> Result<Vec<BatchGenerationResponseItem>, Error> {
183 let results = match response {
184 BatchOperationResponse::InlinedResponses { inlined_responses } => inlined_responses
185 .inlined_responses
186 .into_iter()
187 .map(|item| BatchGenerationResponseItem {
188 response: item.result.into(),
189 meta: item.metadata,
190 })
191 .collect(),
192 BatchOperationResponse::ResponsesFile { responses_file } => {
193 let file = crate::files::model::File {
194 name: responses_file,
195 ..Default::default()
196 };
197 Self::parse_response_file(file, client).await?
198 }
199 };
200 Ok(results)
201 }
202
203 async fn from_operation(
204 operation: BatchOperation,
205 client: Arc<GeminiClient>,
206 ) -> Result<Self, Error> {
207 if operation.done {
208 // According to Google API documentation, when done=true, result must be present
209 let result = operation.result.context(MissingResultSnafu {
210 name: operation.name.clone(),
211 })?;
212
213 let response = Result::from(result).context(BatchFailedSnafu {
214 name: operation.name,
215 })?;
216
217 let mut results = Self::process_successful_response(response, client).await?;
218 results.sort_by_key(|r| r.meta.key);
219
220 // Handle terminal states based on metadata for edge cases
221 match operation.metadata.state {
222 BatchState::BatchStateCancelled => Ok(BatchStatus::Cancelled),
223 BatchState::BatchStateExpired => Ok(BatchStatus::Expired),
224 _ => Ok(BatchStatus::Succeeded { results }),
225 }
226 } else {
227 // The operation is still in progress.
228 match operation.metadata.state {
229 BatchState::BatchStatePending => Ok(BatchStatus::Pending),
230 BatchState::BatchStateRunning => {
231 let total_count = operation.metadata.batch_stats.request_count;
232 let pending_count = operation
233 .metadata
234 .batch_stats
235 .pending_request_count
236 .unwrap_or(total_count);
237 let completed_count = operation
238 .metadata
239 .batch_stats
240 .completed_request_count
241 .unwrap_or(0);
242 let failed_count = operation
243 .metadata
244 .batch_stats
245 .failed_request_count
246 .unwrap_or(0);
247 Ok(BatchStatus::Running {
248 pending_count,
249 completed_count,
250 failed_count,
251 total_count,
252 })
253 }
254 // For non-running states when done=false, treat as pending
255 _ => Ok(BatchStatus::Pending),
256 }
257 }
258 }
259}
260
261/// Represents a long-running batch operation, providing methods to manage its lifecycle.
262///
263/// A `Batch` object is a handle to a batch operation on the Gemini API. It allows you to
264/// check the status, cancel the operation, or delete it once it's no longer needed.
265pub struct BatchHandle {
266 /// The unique resource name of the batch operation, e.g., `operations/batch-xxxxxxxx`.
267 pub name: String,
268 client: Arc<GeminiClient>,
269}
270
271impl BatchHandle {
272 /// Creates a new Batch instance.
273 pub(crate) fn new(name: String, client: Arc<GeminiClient>) -> Self {
274 Self { name, client }
275 }
276
277 /// Returns the unique resource name of the batch operation.
278 pub fn name(&self) -> &str {
279 &self.name
280 }
281
282 /// Retrieves the current status of the batch operation by making an API call.
283 ///
284 /// This method provides a snapshot of the batch's state at a single point in time.
285 pub async fn status(&self) -> Result<BatchStatus, Error> {
286 let operation: BatchOperation = self
287 .client
288 .get_batch_operation(&self.name)
289 .await
290 .map_err(Box::new)
291 .context(ClientSnafu)?;
292
293 BatchStatus::from_operation(operation, self.client.clone()).await
294 }
295
296 /// Sends a request to the API to cancel the batch operation.
297 ///
298 /// Cancellation is not guaranteed to be instantaneous. The operation may continue to run for
299 /// some time after the cancellation request is made.
300 ///
301 /// Consumes the batch. If cancellation fails, returns the batch and error information
302 /// so it can be retried.
303 pub async fn cancel(self) -> Result<(), (Self, ClientError)> {
304 match self.client.cancel_batch_operation(&self.name).await {
305 Ok(()) => Ok(()),
306 Err(e) => Err((self, e)),
307 }
308 }
309
310 /// Deletes the batch operation resource from the server.
311 ///
312 /// Note: This method indicates that the client is no longer interested in the operation result.
313 /// It does not cancel a running operation. To stop a running batch, use the `cancel` method.
314 /// This method should typically be used after the batch has completed.
315 ///
316 /// Consumes the batch. If deletion fails, returns the batch and error information
317 /// so it can be retried.
318 pub async fn delete(self) -> Result<(), (Self, ClientError)> {
319 match self.client.delete_batch_operation(&self.name).await {
320 Ok(()) => Ok(()),
321 Err(e) => Err((self, e)),
322 }
323 }
324}