1use std::{path::Path, sync::Arc};
2
3use base64::prelude::*;
4use enum_iterator::all;
5use file_format::FileFormat;
6use rust_mcp_sdk::McpClient;
7use serde_json::Value;
8use thiserror::Error;
9
10use crate::google::{
11 GoogleModel, GoogleModelVariant,
12 common::{Blob, Content, FileData, FunctionCall, HarmCategory, Part, Role},
13 request::{
14 GenerateContentRequest, GenerationConfig, HarmBlockThreshold, SafetySettings,
15 UpdateGenConfig,
16 },
17 response::ContentResponse,
18};
19
/// Base endpoint under which each model is addressed; the model name is appended as a path
/// segment.
const URL_BASE: &str = "https://generativelanguage.googleapis.com/v1beta/models";
/// Method suffix placed directly after the model name when building the request URL.
const URL_EXTENSION: &str = ":streamGenerateContent";
22
23#[derive(Error, Debug)]
24pub enum Error {
25 #[error(transparent)]
26 SerdeJson(#[from] serde_json::Error),
27 #[error(transparent)]
28 Reqwest(#[from] reqwest::Error),
29 #[error("Agent Request")]
30 Request { code: i32, message: String },
31 #[error(transparent)]
32 Io(#[from] std::io::Error),
33 #[error(transparent)]
34 MpcSdk(#[from] rust_mcp_sdk::error::McpSdkError),
35 #[error("{0}")]
36 UnsupportedConfig(String),
37 #[error("{0}")]
38 NotFound(String),
39}
40
41impl From<&Value> for Error {
42 fn from(value: &Value) -> Self {
43 let mut code = 0;
44 let mut message = String::new();
45 if let Ok(map) = serde_json::from_value::<serde_json::Map<String, Value>>(value.clone()) {
46 if let Some(cd) = map.get("code") {
47 code = serde_json::from_value::<i32>(cd.clone()).unwrap_or(0);
48 }
49 if let Some(msg) = map.get("message") {
50 message = serde_json::from_value::<String>(msg.clone())
51 .unwrap_or_else(|_| "Unknown error".to_string());
52 }
53 }
54 Error::Request { code, message }
55 }
56}
57
58#[derive(Clone)]
61pub struct Client {
62 client: reqwest::Client,
63 pub model: GoogleModel,
64 key: String,
65 request: GenerateContentRequest,
66 mcps: Vec<Arc<rust_mcp_sdk::mcp_client::ClientRuntime>>,
67}
68
69#[derive(Debug)]
72pub struct Responses(Vec<ContentResponse>);
73
74impl Responses {
75 pub fn inner(&self) -> &[ContentResponse] {
76 &self.0
77 }
78}
79
80impl Responses {
81 pub fn text(&self) -> Option<String> {
83 let mut text = String::new();
84 for content in &self.0 {
85 for candidate in &content.candidates {
86 for part in &candidate.content.parts {
87 if let Part::Text(txt) = part {
88 text += txt
89 }
90 }
91 }
92 }
93 if text.is_empty() { None } else { Some(text) }
94 }
95
96 pub fn images(&self) -> Vec<(String, String)> {
98 let mut images = Vec::new();
99 for content in &self.0 {
100 for candidate in &content.candidates {
101 for part in &candidate.content.parts {
102 if let Part::InlineData(blob) = part {
103 images.push((blob.mime_type.clone(), blob.data.clone()));
104 }
105 }
106 }
107 }
108
109 images
110 }
111}
112
113impl Client {
114 pub async fn new(model: &GoogleModel, key: &str) -> Result<Self, Error> {
117 Ok(Client {
118 client: reqwest::Client::new(),
119 model: model.clone(),
120 key: key.to_string(),
121 request: GenerateContentRequest {
122 system_instruction: None,
123 contents: vec![],
124 tools: vec![],
125 tool_config: None,
126 safety_settings: vec![],
127 generation_config: None,
128 cached_content: None,
129 },
130 mcps: vec![],
131 })
132 }
133
134 pub fn with_defaults(&mut self) -> Self {
136 let safety_settings = all::<HarmCategory>()
137 .collect::<Vec<_>>()
138 .into_iter()
139 .map(|cat| SafetySettings {
140 category: cat,
141 threshold: HarmBlockThreshold::default(),
142 })
143 .collect();
144
145 let generation_config = GenerationConfig {
146 response_modalities: self.model.output.clone(),
147 ..Default::default()
148 };
149
150 self.request.safety_settings = safety_settings;
151 self.request.generation_config = Some(generation_config);
152
153 self.to_owned()
154 }
155
156 pub async fn with_tools_client(
157 &mut self,
158 mcps: Vec<Arc<rust_mcp_sdk::mcp_client::ClientRuntime>>,
159 ) -> Result<Self, Error> {
160 let mut tools = Vec::new();
161
162 if matches!(
163 self.model.variant,
164 GoogleModelVariant::Gemini20FlashExpImageGen
165 ) {
166 return Err(Error::UnsupportedConfig(format!(
167 "Model {} does not support tool calls",
168 self.model
169 )));
170 }
171
172 self.mcps = mcps;
173
174 for client in &self.mcps {
175 tools.push(client.list_tools(None).await?.tools.into())
176 }
177
178 self.request.tools = tools;
179
180 Ok(self.to_owned())
181 }
182
183 pub fn with_safety(&mut self, safety_settings: &[SafetySettings]) -> Self {
185 self.request.safety_settings = safety_settings.to_vec();
186
187 self.to_owned()
188 }
189
190 pub fn update_options(&mut self, updates: &[UpdateGenConfig]) -> Self {
191 let mut gen_config = self.request.clone().generation_config.unwrap_or_default();
192
193 for update in updates {
194 match update {
195 UpdateGenConfig::StopSequences(items) => gen_config.stop_sequences = items.clone(),
196 UpdateGenConfig::ResponseMimeType(response_mime_type) => {
197 gen_config.response_mime_type = response_mime_type.clone()
198 }
199 UpdateGenConfig::ResponseSchema(schema) => {
200 gen_config.response_schema = schema.clone()
201 }
202 UpdateGenConfig::ResponseModalities(items) => {
203 gen_config.response_modalities = items.clone()
204 }
205 UpdateGenConfig::CandidateCount(candidate_count) => {
206 gen_config.candidate_count = *candidate_count
207 }
208 UpdateGenConfig::MaxOutputTokens(max_output_tokens) => {
209 gen_config.max_output_tokens = *max_output_tokens
210 }
211 UpdateGenConfig::Temperature(temp) => gen_config.temperature = *temp,
212 UpdateGenConfig::TopP(topp) => gen_config.top_p = *topp,
213 UpdateGenConfig::TopK(topk) => gen_config.top_k = *topk,
214 UpdateGenConfig::Seed(seed) => gen_config.seed = *seed,
215 UpdateGenConfig::PresencePenalty(presence_penalty) => {
216 gen_config.presence_penalty = *presence_penalty
217 }
218 UpdateGenConfig::FrequencyPenalty(frequency_penalty) => {
219 gen_config.frequency_penalty = *frequency_penalty
220 }
221 UpdateGenConfig::ResponseLogprobs(response_logprobs) => {
222 gen_config.response_logprobs = *response_logprobs
223 }
224 UpdateGenConfig::Logprobs(logprobs) => gen_config.logprobs = *logprobs,
225 UpdateGenConfig::EnableEnhancedCivicAnswers(enable_enhanced_civic_answers) => {
226 gen_config.enable_enhanced_civic_answers = *enable_enhanced_civic_answers
227 }
228 UpdateGenConfig::SpeechConfig(speech_config) => {
229 gen_config.speech_config = speech_config.clone()
230 }
231 UpdateGenConfig::ThinkingConfig(thinking_config) => {
232 gen_config.thinking_config = thinking_config.clone()
233 }
234 UpdateGenConfig::MediaResolution(media_resolution) => {
235 gen_config.media_resolution = media_resolution.clone()
236 }
237 }
238 }
239
240 self.request.generation_config = Some(gen_config);
241
242 self.to_owned()
243 }
244
245 pub fn with_instructions(&mut self, system_instruction: &str) -> &mut Self {
249 match self.model.variant {
250 GoogleModelVariant::Gemini20FlashExpImageGen => {
251 let mut contents = vec![Content {
254 parts: vec![Part::Text(system_instruction.to_string())],
255 role: Role::User,
256 }];
257
258 contents.extend(self.request.contents.clone());
259
260 self.request.contents = contents;
261 }
262 _ => {
263 self.request.system_instruction = Some(Content {
264 role: Role::User,
265 parts: vec![Part::Text(system_instruction.to_string())],
266 });
267 }
268 }
269
270 self
271 }
272
273 pub fn with_options(&mut self, options: &GenerationConfig) -> &mut Self {
274 self.request.generation_config = Some(options.clone());
275 self
276 }
277
278 fn merge_response(
282 &mut self,
283 responses: &[ContentResponse],
284 ) -> Result<Vec<ContentResponse>, Error> {
285 let mut success = Vec::new();
286
287 for response in responses {
288 if let Some(error) = &response.error {
289 return Err(error.into());
290 } else {
291 for candidate in &response.candidates {
292 if !candidate.content.parts.is_empty() {
293 self.request.contents.push(candidate.content.clone());
294 }
295 }
296 success.push(response.clone());
297 }
298 }
299
300 Ok(success)
301 }
302
303 async fn tool_call(&self, function_call: &FunctionCall) -> Result<Vec<Part>, Error> {
304 let mut parts = vec![];
305
306 let index = self
307 .request
308 .tools
309 .iter()
310 .enumerate()
311 .find(|(_i, t)| {
312 t.function_declarations
313 .iter()
314 .any(|f| f.name == function_call.name)
315 })
316 .ok_or_else(|| Error::NotFound(function_call.name.clone()))?
317 .0;
318
319 let t = self.mcps.get(index).ok_or_else(|| {
320 Error::NotFound(format!("Tool for function call {}", function_call.name))
321 })?;
322
323 let response = t
324 .call_tool(rust_mcp_sdk::schema::CallToolRequestParams {
325 arguments: function_call.args.clone(),
326 name: function_call.name.clone(),
327 })
328 .await?;
329
330 for content in &response.content {
331 let part = match content {
332 rust_mcp_sdk::schema::ContentBlock::TextContent(text_content) => {
333 Part::FunctionResponse(crate::google::common::FunctionResponse {
334 id: None,
335 name: function_call.name.clone(),
336 response: serde_json::from_str::<serde_json::Map<String, Value>>(
337 &serde_json::to_string(text_content)?,
338 )?,
339 })
340 }
341 rust_mcp_sdk::schema::ContentBlock::ImageContent(image_content) => {
342 Part::FunctionResponse(crate::google::common::FunctionResponse {
343 id: None,
344 name: function_call.name.clone(),
345 response: serde_json::from_str::<serde_json::Map<String, Value>>(
346 &serde_json::to_string(image_content)?,
347 )?,
348 })
349 }
350 rust_mcp_sdk::schema::ContentBlock::AudioContent(audio_content) => {
351 Part::FunctionResponse(crate::google::common::FunctionResponse {
352 id: None,
353 name: function_call.name.clone(),
354 response: serde_json::from_str::<serde_json::Map<String, Value>>(
355 &serde_json::to_string(audio_content)?,
356 )?,
357 })
358 }
359 rust_mcp_sdk::schema::ContentBlock::EmbeddedResource(embedded_resource) => {
360 Part::FunctionResponse(crate::google::common::FunctionResponse {
361 id: None,
362 name: function_call.name.clone(),
363 response: serde_json::from_str::<serde_json::Map<String, Value>>(
364 &serde_json::to_string(embedded_resource)?,
365 )?,
366 })
367 }
368 rust_mcp_sdk::schema::ContentBlock::ResourceLink(resource_link) => {
369 Part::FunctionResponse(crate::google::common::FunctionResponse {
370 id: None,
371 name: function_call.name.clone(),
372 response: serde_json::from_str::<serde_json::Map<String, Value>>(
373 &serde_json::to_string(resource_link)?,
374 )?,
375 })
376 }
377 };
378
379 parts.push(part);
380 }
381
382 Ok(parts)
383 }
384
385 async fn process_tools(&mut self, in_responses: &[ContentResponse]) -> Result<bool, Error> {
388 let mut fn_calls = Vec::new();
389
390 for in_response in in_responses {
391 for in_candidate in &in_response.candidates {
392 for in_part in &in_candidate.content.parts {
393 match in_part {
394 Part::Thought(_)
395 | Part::Text(_)
396 | Part::InlineData(_)
397 | Part::FileData(_)
398 | Part::ExecutableCode(_)
399 | Part::CodeExecutionResult(_)
400 | Part::FunctionResponse(_) => {}
401 Part::FunctionCall(function_call) => {
402 fn_calls.push(function_call.clone());
403 }
404 }
405 }
406 }
407 }
408
409 if !fn_calls.is_empty() {
410 for function_call in &fn_calls {
411 let parts = self.tool_call(function_call).await?;
412
413 self.request.contents.push(Content {
414 parts,
415 role: Role::User,
416 });
417 }
418 Ok(true)
419 } else {
420 Ok(false)
421 }
422 }
423
424 async fn do_post(&mut self) -> Result<Vec<ContentResponse>, Error> {
425 let request = self
426 .client
427 .post(self.url())
428 .header("Content-Type", "application/json")
429 .query(&[("key", &self.key)])
430 .json(&self.request);
431
432 let responses = request.send().await?.json::<Vec<ContentResponse>>().await?;
433
434 self.merge_response(&responses)
435 }
436
437 async fn post(&mut self) -> Result<Responses, Error> {
438 let mut responses = self.do_post().await?;
439
440 while self.process_tools(&responses).await? {
443 responses = self.do_post().await?;
444 }
445
446 Ok(Responses(responses))
447 }
448
449 pub async fn send_text(&mut self, text: &str) -> Result<Responses, Error> {
452 self.request.contents.push(Content {
453 parts: vec![Part::Text(text.to_string())],
454 role: Role::User,
455 });
456
457 self.post().await
458 }
459
460 pub async fn send_image(&mut self, blob: &Blob) -> Result<Responses, Error> {
461 self.request.contents.push(Content {
462 parts: vec![Part::InlineData(blob.clone())],
463 role: Role::User,
464 });
465
466 self.post().await
467 }
468
469 pub async fn send_file_data(&mut self, data: &FileData) -> Result<Responses, Error> {
470 self.request.contents.push(Content {
471 parts: vec![Part::FileData(data.clone())],
472 role: Role::User,
473 });
474
475 self.post().await
476 }
477
478 pub async fn send_image_file(
479 &mut self,
480 message: Option<String>,
481 img: &Path,
482 ) -> Result<Responses, Error> {
483 let format = FileFormat::from_file(img)?;
484
485 let data = BASE64_URL_SAFE.encode(&tokio::fs::read(img).await?);
486
487 self.send_image_bytes(message, format.media_type(), &data)
488 .await
489 }
490
491 pub async fn send_parts(&mut self, parts: &[Part]) -> Result<Responses, Error> {
492 self.request.contents.push(Content {
493 parts: parts.to_vec(),
494 role: Role::User,
495 });
496
497 self.post().await
498 }
499
500 pub async fn send_image_bytes(
505 &mut self,
506 message: Option<String>,
507 mime_type: &str,
508 data: &str,
509 ) -> Result<Responses, Error> {
510 let mut parts = Vec::new();
511
512 if let Some(message) = message {
513 parts.push(Part::Text(message.to_string()));
514 }
515
516 parts.push(Part::InlineData(Blob {
517 mime_type: mime_type.to_string(),
518 data: data.to_string(),
519 }));
520
521 self.request.contents.push(Content {
522 parts,
523 role: Role::User,
524 });
525
526 self.post().await
527 }
528
529 fn url(&self) -> String {
530 format!("{URL_BASE}/{}{URL_EXTENSION}", self.model.name)
531 }
532
533 pub fn history(&self) -> &[Content] {
535 &self.request.contents
536 }
537}