gemini_rust/batch/
builder.rs1use snafu::ResultExt;
2use std::sync::Arc;
3use tracing::{instrument, Span};
4
5use super::handle::BatchHandle;
6use super::model::*;
7use super::*;
8use crate::{client::GeminiClient, generation::GenerateContentRequest};
9
10pub struct BatchBuilder {
16 client: Arc<GeminiClient>,
17 display_name: String,
18 requests: Vec<GenerateContentRequest>,
19}
20
21impl BatchBuilder {
22 pub(crate) fn new(client: Arc<GeminiClient>) -> Self {
24 Self {
25 client,
26 display_name: "RustBatch".to_string(),
27 requests: Vec::new(),
28 }
29 }
30
31 pub fn with_name(mut self, name: String) -> Self {
33 self.display_name = name;
34 self
35 }
36
37 pub fn with_requests(mut self, requests: Vec<GenerateContentRequest>) -> Self {
39 self.requests = requests;
40 self
41 }
42
43 pub fn with_request(mut self, request: GenerateContentRequest) -> Self {
45 self.requests.push(request);
46 self
47 }
48
49 pub fn build(self) -> BatchGenerateContentRequest {
53 let batch_requests: Vec<BatchRequestItem> = self
54 .requests
55 .into_iter()
56 .enumerate()
57 .map(|(key, request)| BatchRequestItem {
58 request,
59 metadata: RequestMetadata { key },
60 })
61 .collect();
62
63 BatchGenerateContentRequest {
64 batch: BatchConfig {
65 display_name: self.display_name,
66 input_config: InputConfig::Requests(RequestsContainer {
67 requests: batch_requests,
68 }),
69 },
70 }
71 }
72
73 #[instrument(skip_all, fields(
77 batch.display_name = self.display_name,
78 batch.size = self.requests.len()
79 ))]
80 pub async fn execute(self) -> Result<BatchHandle, Error> {
81 let client = self.client.clone();
82 let request = self.build();
83 let response = client
84 .batch_generate_content(request)
85 .await
86 .context(ClientSnafu)?;
87 Ok(BatchHandle::new(response.name, client))
88 }
89
90 #[instrument(skip_all, fields(
96 batch.display_name = self.display_name,
97 batch.size = self.requests.len()
98 ))]
99 pub async fn execute_as_file(self) -> Result<BatchHandle, Error> {
100 let mut json_lines = String::new();
101 for (index, item) in self.requests.into_iter().enumerate() {
102 let item = BatchRequestFileItem {
103 request: item,
104 key: index,
105 };
106
107 let line = serde_json::to_string(&item).context(SerializeSnafu)?;
108 json_lines.push_str(&line);
109 json_lines.push('\n');
110 }
111 let json_bytes = json_lines.into_bytes();
112 Span::current().record("file.size", json_bytes.len());
113
114 let file_display_name = format!("{}-input.jsonl", self.display_name);
115 let file = crate::files::builder::FileBuilder::new(self.client.clone(), json_bytes)
116 .display_name(file_display_name)
117 .with_mime_type(
118 "application/jsonl"
119 .parse()
120 .expect("failed to parse MIME type 'application/jsonl'"),
121 )
122 .upload()
123 .await
124 .context(FileSnafu)?;
125
126 let request = BatchGenerateContentRequest {
127 batch: BatchConfig {
128 display_name: self.display_name,
129 input_config: InputConfig::FileName(file.name().to_string()),
130 },
131 };
132
133 let client = self.client.clone();
134 let response = client
135 .batch_generate_content(request)
136 .await
137 .context(ClientSnafu)?;
138
139 Ok(BatchHandle::new(response.name, client))
140 }
141}