gemini_rust/batch/
builder.rs1use snafu::ResultExt;
2use std::sync::Arc;
3use tracing::{instrument, Span};
4
5use super::handle::BatchHandle;
6use super::model::*;
7use super::*;
8use crate::{client::GeminiClient, generation::GenerateContentRequest};
9
10#[derive(Clone)]
16pub struct BatchBuilder {
17 client: Arc<GeminiClient>,
18 display_name: String,
19 requests: Vec<GenerateContentRequest>,
20}
21
22impl BatchBuilder {
23 pub(crate) fn new(client: Arc<GeminiClient>) -> Self {
25 Self {
26 client,
27 display_name: "RustBatch".to_string(),
28 requests: Vec::new(),
29 }
30 }
31
32 pub fn with_name(mut self, name: String) -> Self {
34 self.display_name = name;
35 self
36 }
37
38 pub fn with_requests(mut self, requests: Vec<GenerateContentRequest>) -> Self {
40 self.requests = requests;
41 self
42 }
43
44 pub fn with_request(mut self, request: GenerateContentRequest) -> Self {
46 self.requests.push(request);
47 self
48 }
49
50 pub fn build(self) -> BatchGenerateContentRequest {
54 let batch_requests: Vec<BatchRequestItem> = self
55 .requests
56 .into_iter()
57 .enumerate()
58 .map(|(key, request)| BatchRequestItem {
59 request,
60 metadata: RequestMetadata { key },
61 })
62 .collect();
63
64 BatchGenerateContentRequest {
65 batch: BatchConfig {
66 display_name: self.display_name,
67 input_config: InputConfig::Requests(RequestsContainer {
68 requests: batch_requests,
69 }),
70 },
71 }
72 }
73
74 #[instrument(skip_all, fields(
78 batch.display_name = self.display_name,
79 batch.size = self.requests.len()
80 ))]
81 pub async fn execute(self) -> Result<BatchHandle, Error> {
82 let client = self.client.clone();
83 let request = self.build();
84 let response = client
85 .batch_generate_content(request)
86 .await
87 .context(ClientSnafu)?;
88 Ok(BatchHandle::new(response.name, client))
89 }
90
91 #[instrument(skip_all, fields(
97 batch.display_name = self.display_name,
98 batch.size = self.requests.len()
99 ))]
100 pub async fn execute_as_file(self) -> Result<BatchHandle, Error> {
101 let mut json_lines = String::new();
102 for (index, item) in self.requests.into_iter().enumerate() {
103 let item = BatchRequestFileItem {
104 request: item,
105 key: index,
106 };
107
108 let line = serde_json::to_string(&item).context(SerializeSnafu)?;
109 json_lines.push_str(&line);
110 json_lines.push('\n');
111 }
112 let json_bytes = json_lines.into_bytes();
113 Span::current().record("file.size", json_bytes.len());
114
115 let file_display_name = format!("{}-input.jsonl", self.display_name);
116 let file = crate::files::builder::FileBuilder::new(self.client.clone(), json_bytes)
117 .display_name(file_display_name)
118 .with_mime_type(
119 "application/jsonl"
120 .parse()
121 .expect("failed to parse MIME type 'application/jsonl'"),
122 )
123 .upload()
124 .await
125 .context(FileSnafu)?;
126
127 let request = BatchGenerateContentRequest {
128 batch: BatchConfig {
129 display_name: self.display_name,
130 input_config: InputConfig::FileName(file.name().to_string()),
131 },
132 };
133
134 let client = self.client.clone();
135 let response = client
136 .batch_generate_content(request)
137 .await
138 .context(ClientSnafu)?;
139
140 Ok(BatchHandle::new(response.name, client))
141 }
142}