gemini_rust/batch/
handle.rs

1//! The Batch module for managing batch operations.
2//!
3//! This module provides the [`BatchHandle`] struct, which is a handle to a long-running batch
4//! operation on the Gemini API. It allows for checking the status, canceling, and deleting
5//! the operation.
6//!
7//! The status of a batch operation is represented by the [`BatchStatus`] enum, which can be
8//! retrieved using the [`BatchHandle::status()`] method. When a batch completes successfully,
9//! it transitions to the [`BatchStatus::Succeeded`] state, which contains a vector of
10//! [`BatchGenerationResponseItem`].
11//!
12//! ## Batch Results
13//!
//! The [`BatchGenerationResponseItem`] struct represents the outcome of a single request within the batch:
//! - `response`: `Ok(GenerationResponse)` if the request succeeded, or `Err(IndividualRequestError)` if it failed.
//! - `meta`: the request's `RequestMetadata`, which carries the original request key.
17//!
18//! Results can be delivered in two ways, depending on the size of the batch job:
19//! 1.  **Inlined Responses**: For smaller jobs, the results are included directly in the
20//!     batch operation's metadata.
21//! 2.  **Response File**: For larger jobs (typically >20MB), the results are written to a
22//!     file, and the batch metadata will contain a reference to this file. The SDK
23//!     handles the downloading and parsing of this file automatically when you call
24//!     `status()` on a completed batch.
25//!
26//! The results are automatically sorted by their original request key (as a number) to ensure
27//! a consistent and predictable order.
28//!
29//! For more information, see the official Google AI documentation:
30//! - [Batch Mode Guide](https://ai.google.dev/gemini-api/docs/batch-mode)
31//! - [Batch API Reference](https://ai.google.dev/api/batch-mode)
32//!
33//! # Design Note: Resource Management in Batch Operations
34//!
//! The Batch API methods that consume the [`BatchHandle`] struct (`cancel`, `delete`)
//! return `std::result::Result<(), (Self, crate::client::Error)>` instead of the crate's `Result<T>`.
//! This design follows a pattern used in channel libraries, where a failed operation hands the
//! consumed value back to the caller (e.g., `std::sync::mpsc::Sender::send` returns the unsent
//! value inside `SendError<T>`), and provides two key benefits:
39//!
40//! 1. **Resource Safety**: Once a [`BatchHandle`] is consumed by an operation, it cannot be used again,
41//!    preventing invalid operations on deleted or canceled batches.
42//!
43//! 2. **Error Recovery**: If an operation fails due to transient network issues, both the
44//!    [`BatchHandle`] and error information are returned, allowing callers to retry the operation.
45//!
46//! ## Example usage:
47//! ```rust,no_run
48//! use gemini_rust::{Gemini, Message};
49//!
50//! #[tokio::main]
51//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
52//!     let client = Gemini::new(std::env::var("GEMINI_API_KEY")?)?;
53//!     let request = client.generate_content().with_user_message("Why is the sky blue?").build();
54//!     let batch = client.batch_generate_content().with_request(request).execute().await?;
55//!
56//!     match batch.delete().await {
57//!         Ok(()) => println!("Batch deleted successfully!"),
58//!         Err((batch, error)) => {
59//!             println!("Failed to delete batch: {}", error);
60//!             // Can retry: batch.delete().await
61//!         }
62//!     }
63//!     Ok(())
64//! }
65//! ```
66
67use snafu::{OptionExt, ResultExt, Snafu};
68use std::{result::Result, sync::Arc};
69
70use super::model::*;
71use crate::{
72    client::{Error as ClientError, GeminiClient},
73    files::handle::FileHandle,
74    GenerationResponse,
75};
76
/// Errors that can occur while querying a batch operation or processing its results.
#[derive(Debug, Snafu)]
pub enum Error {
    /// The batch expired before it finished processing.
    #[snafu(display("batch '{name}' expired before finishing"))]
    BatchExpired {
        /// Batch name.
        name: String,
    },

    /// The batch operation finished with an error result.
    #[snafu(display("batch '{name}' failed"))]
    BatchFailed {
        /// The error reported for the operation as a whole.
        source: OperationError,
        /// Batch name.
        name: String,
    },

    /// The underlying API client call failed; the client error is stored boxed.
    #[snafu(display("client invocation error"))]
    Client { source: Box<ClientError> },

    /// The file holding the batch results could not be downloaded.
    #[snafu(display("failed to download batch result file '{file_name}'"))]
    FileDownload {
        source: crate::files::Error,
        /// Name of the results file that failed to download.
        file_name: String,
    },

    /// The downloaded results file is not valid UTF-8.
    #[snafu(display("failed to decode batch result file content as UTF-8"))]
    FileDecode { source: std::string::FromUtf8Error },

    /// A line of the results file could not be deserialized.
    #[snafu(display("failed to parse line in batch result file"))]
    FileParse {
        source: serde_json::Error,
        /// The raw line that failed to parse.
        line: String,
    },

    /// A finished operation carried no result.
    ///
    /// The Google API contract guarantees that a finished (`done`) operation always
    /// includes a result, so this should never occur. It is reported as an error
    /// rather than a panic in case the API (or an internal GCP failure) violates
    /// that contract.
    #[snafu(display("batch '{name}' completed but no result provided - API contract violation"))]
    MissingResult {
        /// Batch name.
        name: String,
    },
}
121
/// The outcome of a single request within a batch.
#[derive(Debug, Clone, PartialEq)]
pub struct BatchGenerationResponseItem {
    /// The generated response on success, or the error reported for this individual request.
    pub response: Result<GenerationResponse, IndividualRequestError>,
    /// Metadata identifying the originating request, including its request key.
    pub meta: RequestMetadata,
}
127
/// Represents the overall status of a batch operation.
#[derive(Debug, Clone, PartialEq)]
pub enum BatchStatus {
    /// The operation is waiting to be processed.
    ///
    /// Also reported for an unfinished operation whose state is neither pending nor running.
    Pending,
    /// The operation is currently being processed.
    ///
    /// The counts come from the operation's batch statistics.
    Running {
        /// Requests not yet processed (equals `total_count` when the API omits this figure).
        pending_count: i64,
        /// Requests reported as completed (0 when the API omits this figure).
        completed_count: i64,
        /// Requests reported as failed (0 when the API omits this figure).
        failed_count: i64,
        /// Total number of requests in the batch.
        total_count: i64,
    },
    /// The operation has completed successfully.
    Succeeded {
        /// One item per request, sorted by request key.
        results: Vec<BatchGenerationResponseItem>,
    },
    /// The operation was cancelled by the user.
    Cancelled,
    /// The operation has expired.
    Expired,
}
149
150impl BatchStatus {
151    async fn parse_response_file(
152        response_file: crate::files::model::File,
153        client: Arc<GeminiClient>,
154    ) -> Result<Vec<BatchGenerationResponseItem>, Error> {
155        let file = FileHandle::new(client.clone(), response_file);
156        let file_content_bytes = file.download().await.context(FileDownloadSnafu {
157            file_name: file.name(),
158        })?;
159        let file_content = String::from_utf8(file_content_bytes).context(FileDecodeSnafu)?;
160
161        let mut results = vec![];
162        for line in file_content.lines() {
163            if line.trim().is_empty() {
164                continue;
165            }
166            let item: BatchResponseFileItem =
167                serde_json::from_str(line).context(FileParseSnafu {
168                    line: line.to_string(),
169                })?;
170
171            results.push(BatchGenerationResponseItem {
172                response: item.response.into(),
173                meta: RequestMetadata { key: item.key },
174            });
175        }
176        Ok(results)
177    }
178
179    async fn process_successful_response(
180        response: BatchOperationResponse,
181        client: Arc<GeminiClient>,
182    ) -> Result<Vec<BatchGenerationResponseItem>, Error> {
183        let results = match response {
184            BatchOperationResponse::InlinedResponses { inlined_responses } => inlined_responses
185                .inlined_responses
186                .into_iter()
187                .map(|item| BatchGenerationResponseItem {
188                    response: item.result.into(),
189                    meta: item.metadata,
190                })
191                .collect(),
192            BatchOperationResponse::ResponsesFile { responses_file } => {
193                let file = crate::files::model::File {
194                    name: responses_file,
195                    ..Default::default()
196                };
197                Self::parse_response_file(file, client).await?
198            }
199        };
200        Ok(results)
201    }
202
203    async fn from_operation(
204        operation: BatchOperation,
205        client: Arc<GeminiClient>,
206    ) -> Result<Self, Error> {
207        if operation.done {
208            // According to Google API documentation, when done=true, result must be present
209            let result = operation.result.context(MissingResultSnafu {
210                name: operation.name.clone(),
211            })?;
212
213            let response = Result::from(result).context(BatchFailedSnafu {
214                name: operation.name,
215            })?;
216
217            let mut results = Self::process_successful_response(response, client).await?;
218            results.sort_by_key(|r| r.meta.key);
219
220            // Handle terminal states based on metadata for edge cases
221            match operation.metadata.state {
222                BatchState::BatchStateCancelled => Ok(BatchStatus::Cancelled),
223                BatchState::BatchStateExpired => Ok(BatchStatus::Expired),
224                _ => Ok(BatchStatus::Succeeded { results }),
225            }
226        } else {
227            // The operation is still in progress.
228            match operation.metadata.state {
229                BatchState::BatchStatePending => Ok(BatchStatus::Pending),
230                BatchState::BatchStateRunning => {
231                    let total_count = operation.metadata.batch_stats.request_count;
232                    let pending_count = operation
233                        .metadata
234                        .batch_stats
235                        .pending_request_count
236                        .unwrap_or(total_count);
237                    let completed_count = operation
238                        .metadata
239                        .batch_stats
240                        .completed_request_count
241                        .unwrap_or(0);
242                    let failed_count = operation
243                        .metadata
244                        .batch_stats
245                        .failed_request_count
246                        .unwrap_or(0);
247                    Ok(BatchStatus::Running {
248                        pending_count,
249                        completed_count,
250                        failed_count,
251                        total_count,
252                    })
253                }
254                // For non-running states when done=false, treat as pending
255                _ => Ok(BatchStatus::Pending),
256            }
257        }
258    }
259}
260
/// Represents a long-running batch operation, providing methods to manage its lifecycle.
///
/// A `BatchHandle` is a handle to a batch operation on the Gemini API. It allows you to
/// check the status, cancel the operation, or delete it once it's no longer needed.
pub struct BatchHandle {
    /// The unique resource name of the batch operation, e.g., `operations/batch-xxxxxxxx`.
    pub name: String,
    /// Shared API client used for every request made through this handle.
    client: Arc<GeminiClient>,
}
270
271impl BatchHandle {
272    /// Creates a new Batch instance.
273    pub(crate) fn new(name: String, client: Arc<GeminiClient>) -> Self {
274        Self { name, client }
275    }
276
277    /// Returns the unique resource name of the batch operation.
278    pub fn name(&self) -> &str {
279        &self.name
280    }
281
282    /// Retrieves the current status of the batch operation by making an API call.
283    ///
284    /// This method provides a snapshot of the batch's state at a single point in time.
285    pub async fn status(&self) -> Result<BatchStatus, Error> {
286        let operation: BatchOperation = self
287            .client
288            .get_batch_operation(&self.name)
289            .await
290            .map_err(Box::new)
291            .context(ClientSnafu)?;
292
293        BatchStatus::from_operation(operation, self.client.clone()).await
294    }
295
296    /// Sends a request to the API to cancel the batch operation.
297    ///
298    /// Cancellation is not guaranteed to be instantaneous. The operation may continue to run for
299    /// some time after the cancellation request is made.
300    ///
301    /// Consumes the batch. If cancellation fails, returns the batch and error information
302    /// so it can be retried.
303    pub async fn cancel(self) -> Result<(), (Self, ClientError)> {
304        match self.client.cancel_batch_operation(&self.name).await {
305            Ok(()) => Ok(()),
306            Err(e) => Err((self, e)),
307        }
308    }
309
310    /// Deletes the batch operation resource from the server.
311    ///
312    /// Note: This method indicates that the client is no longer interested in the operation result.
313    /// It does not cancel a running operation. To stop a running batch, use the `cancel` method.
314    /// This method should typically be used after the batch has completed.
315    ///
316    /// Consumes the batch. If deletion fails, returns the batch and error information
317    /// so it can be retried.
318    pub async fn delete(self) -> Result<(), (Self, ClientError)> {
319        match self.client.delete_batch_operation(&self.name).await {
320            Ok(()) => Ok(()),
321            Err(e) => Err((self, e)),
322        }
323    }
324}