Skip to main content

google_gemini_rs/client/
mod.rs

1use std::{path::Path, sync::Arc};
2
3use base64::prelude::*;
4use enum_iterator::all;
5use file_format::FileFormat;
6use rust_mcp_sdk::McpClient;
7use serde_json::Value;
8use thiserror::Error;
9
10use crate::google::{
11    GoogleModel, GoogleModelVariant,
12    common::{Blob, Content, FileData, FunctionCall, HarmCategory, Part, Role},
13    request::{
14        GenerateContentRequest, GenerationConfig, HarmBlockThreshold, SafetySettings,
15        UpdateGenConfig,
16    },
17    response::ContentResponse,
18};
19
/// Base endpoint of the Gemini `v1beta` models resource; `Client::url` appends the model name.
const URL_BASE: &str = "https://generativelanguage.googleapis.com/v1beta/models";
/// Method suffix appended after the model name to select the streaming generate-content call.
const URL_EXTENSION: &str = ":streamGenerateContent";
22
23#[derive(Error, Debug)]
24pub enum Error {
25    #[error(transparent)]
26    SerdeJson(#[from] serde_json::Error),
27    #[error(transparent)]
28    Reqwest(#[from] reqwest::Error),
29    #[error("Agent Request")]
30    Request { code: i32, message: String },
31    #[error(transparent)]
32    Io(#[from] std::io::Error),
33    #[error(transparent)]
34    MpcSdk(#[from] rust_mcp_sdk::error::McpSdkError),
35    #[error("{0}")]
36    UnsupportedConfig(String),
37    #[error("{0}")]
38    NotFound(String),
39}
40
41impl From<&Value> for Error {
42    fn from(value: &Value) -> Self {
43        let mut code = 0;
44        let mut message = String::new();
45        if let Ok(map) = serde_json::from_value::<serde_json::Map<String, Value>>(value.clone()) {
46            if let Some(cd) = map.get("code") {
47                code = serde_json::from_value::<i32>(cd.clone()).unwrap_or(0);
48            }
49            if let Some(msg) = map.get("message") {
50                message = serde_json::from_value::<String>(msg.clone())
51                    .unwrap_or_else(|_| "Unknown error".to_string());
52            }
53        }
54        Error::Request { code, message }
55    }
56}
57
/// Wrapper struct which stores the HTTP Reqwest client and the request history.  The `send`
/// methods are used to send text and images without having to manage the history manually.
#[derive(Clone)]
pub struct Client {
    /// Underlying HTTP client used to issue the POST requests.
    client: reqwest::Client,
    /// Model targeted by this client; its `name` is used to build the endpoint URL.
    pub model: GoogleModel,
    /// API key sent along with every request.
    key: String,
    /// Request body re-sent on every call; its `contents` accumulate the conversation history.
    request: GenerateContentRequest,
    /// MCP client runtimes.  `tool_call` looks a runtime up with the index of the matching
    /// entry in `request.tools`, so both vectors are expected to stay in the same order.
    mcps: Vec<Arc<rust_mcp_sdk::mcp_client::ClientRuntime>>,
}
68
/// The model may return more than one output since we use streaming.  This wrapper
/// is used as a helper to consolidate the outputs.
#[derive(Debug)]
pub struct Responses(Vec<ContentResponse>);

impl Responses {
    /// Borrows the wrapped responses as a slice, in the order they were stored.
    pub fn inner(&self) -> &[ContentResponse] {
        &self.0
    }
}
79
80impl Responses {
81    /// Squash multiple text responses into a single string.
82    pub fn text(&self) -> Option<String> {
83        let mut text = String::new();
84        for content in &self.0 {
85            for candidate in &content.candidates {
86                for part in &candidate.content.parts {
87                    if let Part::Text(txt) = part {
88                        text += txt
89                    }
90                }
91            }
92        }
93        if text.is_empty() { None } else { Some(text) }
94    }
95
96    /// Helper to extract the image mime types and Base64 encoded data.
97    pub fn images(&self) -> Vec<(String, String)> {
98        let mut images = Vec::new();
99        for content in &self.0 {
100            for candidate in &content.candidates {
101                for part in &candidate.content.parts {
102                    if let Part::InlineData(blob) = part {
103                        images.push((blob.mime_type.clone(), blob.data.clone()));
104                    }
105                }
106            }
107        }
108
109        images
110    }
111}
112
113impl Client {
114    /// Creates a new instance of a Reqwest client.  The client is setup to utilize the given
115    /// Google Gemini model.
116    pub async fn new(model: &GoogleModel, key: &str) -> Result<Self, Error> {
117        Ok(Client {
118            client: reqwest::Client::new(),
119            model: model.clone(),
120            key: key.to_string(),
121            request: GenerateContentRequest {
122                system_instruction: None,
123                contents: vec![],
124                tools: vec![],
125                tool_config: None,
126                safety_settings: vec![],
127                generation_config: None,
128                cached_content: None,
129            },
130            mcps: vec![],
131        })
132    }
133
134    /// Mutates the client by setting sane default configurations based on the model.
135    pub fn with_defaults(&mut self) -> Self {
136        let safety_settings = all::<HarmCategory>()
137            .collect::<Vec<_>>()
138            .into_iter()
139            .map(|cat| SafetySettings {
140                category: cat,
141                threshold: HarmBlockThreshold::default(),
142            })
143            .collect();
144
145        let generation_config = GenerationConfig {
146            response_modalities: self.model.output.clone(),
147            ..Default::default()
148        };
149
150        self.request.safety_settings = safety_settings;
151        self.request.generation_config = Some(generation_config);
152
153        self.to_owned()
154    }
155
156    pub async fn with_tools_client(
157        &mut self,
158        mcps: Vec<Arc<rust_mcp_sdk::mcp_client::ClientRuntime>>,
159    ) -> Result<Self, Error> {
160        let mut tools = Vec::new();
161
162        if matches!(
163            self.model.variant,
164            GoogleModelVariant::Gemini20FlashExpImageGen
165        ) {
166            return Err(Error::UnsupportedConfig(format!(
167                "Model {} does not support tool calls",
168                self.model
169            )));
170        }
171
172        self.mcps = mcps;
173
174        for client in &self.mcps {
175            tools.push(client.list_tools(None).await?.tools.into())
176        }
177
178        self.request.tools = tools;
179
180        Ok(self.to_owned())
181    }
182
183    /// Mutate the client by setting the specified safety settings.
184    pub fn with_safety(&mut self, safety_settings: &[SafetySettings]) -> Self {
185        self.request.safety_settings = safety_settings.to_vec();
186
187        self.to_owned()
188    }
189
190    pub fn update_options(&mut self, updates: &[UpdateGenConfig]) -> Self {
191        let mut gen_config = self.request.clone().generation_config.unwrap_or_default();
192
193        for update in updates {
194            match update {
195                UpdateGenConfig::StopSequences(items) => gen_config.stop_sequences = items.clone(),
196                UpdateGenConfig::ResponseMimeType(response_mime_type) => {
197                    gen_config.response_mime_type = response_mime_type.clone()
198                }
199                UpdateGenConfig::ResponseSchema(schema) => {
200                    gen_config.response_schema = schema.clone()
201                }
202                UpdateGenConfig::ResponseModalities(items) => {
203                    gen_config.response_modalities = items.clone()
204                }
205                UpdateGenConfig::CandidateCount(candidate_count) => {
206                    gen_config.candidate_count = *candidate_count
207                }
208                UpdateGenConfig::MaxOutputTokens(max_output_tokens) => {
209                    gen_config.max_output_tokens = *max_output_tokens
210                }
211                UpdateGenConfig::Temperature(temp) => gen_config.temperature = *temp,
212                UpdateGenConfig::TopP(topp) => gen_config.top_p = *topp,
213                UpdateGenConfig::TopK(topk) => gen_config.top_k = *topk,
214                UpdateGenConfig::Seed(seed) => gen_config.seed = *seed,
215                UpdateGenConfig::PresencePenalty(presence_penalty) => {
216                    gen_config.presence_penalty = *presence_penalty
217                }
218                UpdateGenConfig::FrequencyPenalty(frequency_penalty) => {
219                    gen_config.frequency_penalty = *frequency_penalty
220                }
221                UpdateGenConfig::ResponseLogprobs(response_logprobs) => {
222                    gen_config.response_logprobs = *response_logprobs
223                }
224                UpdateGenConfig::Logprobs(logprobs) => gen_config.logprobs = *logprobs,
225                UpdateGenConfig::EnableEnhancedCivicAnswers(enable_enhanced_civic_answers) => {
226                    gen_config.enable_enhanced_civic_answers = *enable_enhanced_civic_answers
227                }
228                UpdateGenConfig::SpeechConfig(speech_config) => {
229                    gen_config.speech_config = speech_config.clone()
230                }
231                UpdateGenConfig::ThinkingConfig(thinking_config) => {
232                    gen_config.thinking_config = thinking_config.clone()
233                }
234                UpdateGenConfig::MediaResolution(media_resolution) => {
235                    gen_config.media_resolution = media_resolution.clone()
236                }
237            }
238        }
239
240        self.request.generation_config = Some(gen_config);
241
242        self.to_owned()
243    }
244
245    /// Mutate the client by setting the specified system instructions.  Some models do
246    /// not support system instructions, so in these cases we front-load the system instructions
247    /// as user text content.
248    pub fn with_instructions(&mut self, system_instruction: &str) -> &mut Self {
249        match self.model.variant {
250            GoogleModelVariant::Gemini20FlashExpImageGen => {
251                // The 2.0 flash experimentation image gen model does not support system instructions
252                // as this time, so we'll front-load the instructions as a user message.
253                let mut contents = vec![Content {
254                    parts: vec![Part::Text(system_instruction.to_string())],
255                    role: Role::User,
256                }];
257
258                contents.extend(self.request.contents.clone());
259
260                self.request.contents = contents;
261            }
262            _ => {
263                self.request.system_instruction = Some(Content {
264                    role: Role::User,
265                    parts: vec![Part::Text(system_instruction.to_string())],
266                });
267            }
268        }
269
270        self
271    }
272
273    pub fn with_options(&mut self, options: &GenerationConfig) -> &mut Self {
274        self.request.generation_config = Some(options.clone());
275        self
276    }
277
278    /// Since we're dealing with streams it is possible (?) for the stream to contain
279    /// a mixture of successful responses and errors.  For simplicity we bail on error
280    /// and return just the error, while we reconsolidate all successful responses.
281    fn merge_response(
282        &mut self,
283        responses: &[ContentResponse],
284    ) -> Result<Vec<ContentResponse>, Error> {
285        let mut success = Vec::new();
286
287        for response in responses {
288            if let Some(error) = &response.error {
289                return Err(error.into());
290            } else {
291                for candidate in &response.candidates {
292                    if !candidate.content.parts.is_empty() {
293                        self.request.contents.push(candidate.content.clone());
294                    }
295                }
296                success.push(response.clone());
297            }
298        }
299
300        Ok(success)
301    }
302
303    async fn tool_call(&self, function_call: &FunctionCall) -> Result<Vec<Part>, Error> {
304        let mut parts = vec![];
305
306        let index = self
307            .request
308            .tools
309            .iter()
310            .enumerate()
311            .find(|(_i, t)| {
312                t.function_declarations
313                    .iter()
314                    .any(|f| f.name == function_call.name)
315            })
316            .ok_or_else(|| Error::NotFound(function_call.name.clone()))?
317            .0;
318
319        let t = self.mcps.get(index).ok_or_else(|| {
320            Error::NotFound(format!("Tool for function call {}", function_call.name))
321        })?;
322
323        let response = t
324            .call_tool(rust_mcp_sdk::schema::CallToolRequestParams {
325                arguments: function_call.args.clone(),
326                name: function_call.name.clone(),
327            })
328            .await?;
329
330        for content in &response.content {
331            let part = match content {
332                rust_mcp_sdk::schema::ContentBlock::TextContent(text_content) => {
333                    Part::FunctionResponse(crate::google::common::FunctionResponse {
334                        id: None,
335                        name: function_call.name.clone(),
336                        response: serde_json::from_str::<serde_json::Map<String, Value>>(
337                            &serde_json::to_string(text_content)?,
338                        )?,
339                    })
340                }
341                rust_mcp_sdk::schema::ContentBlock::ImageContent(image_content) => {
342                    Part::FunctionResponse(crate::google::common::FunctionResponse {
343                        id: None,
344                        name: function_call.name.clone(),
345                        response: serde_json::from_str::<serde_json::Map<String, Value>>(
346                            &serde_json::to_string(image_content)?,
347                        )?,
348                    })
349                }
350                rust_mcp_sdk::schema::ContentBlock::AudioContent(audio_content) => {
351                    Part::FunctionResponse(crate::google::common::FunctionResponse {
352                        id: None,
353                        name: function_call.name.clone(),
354                        response: serde_json::from_str::<serde_json::Map<String, Value>>(
355                            &serde_json::to_string(audio_content)?,
356                        )?,
357                    })
358                }
359                rust_mcp_sdk::schema::ContentBlock::EmbeddedResource(embedded_resource) => {
360                    Part::FunctionResponse(crate::google::common::FunctionResponse {
361                        id: None,
362                        name: function_call.name.clone(),
363                        response: serde_json::from_str::<serde_json::Map<String, Value>>(
364                            &serde_json::to_string(embedded_resource)?,
365                        )?,
366                    })
367                }
368                rust_mcp_sdk::schema::ContentBlock::ResourceLink(resource_link) => {
369                    Part::FunctionResponse(crate::google::common::FunctionResponse {
370                        id: None,
371                        name: function_call.name.clone(),
372                        response: serde_json::from_str::<serde_json::Map<String, Value>>(
373                            &serde_json::to_string(resource_link)?,
374                        )?,
375                    })
376                }
377            };
378
379            parts.push(part);
380        }
381
382        Ok(parts)
383    }
384
385    /// Processes tool requests from the model.  We need to push all results onto the content
386    /// request stack for the history.
387    async fn process_tools(&mut self, in_responses: &[ContentResponse]) -> Result<bool, Error> {
388        let mut fn_calls = Vec::new();
389
390        for in_response in in_responses {
391            for in_candidate in &in_response.candidates {
392                for in_part in &in_candidate.content.parts {
393                    match in_part {
394                        Part::Thought(_)
395                        | Part::Text(_)
396                        | Part::InlineData(_)
397                        | Part::FileData(_)
398                        | Part::ExecutableCode(_)
399                        | Part::CodeExecutionResult(_)
400                        | Part::FunctionResponse(_) => {}
401                        Part::FunctionCall(function_call) => {
402                            fn_calls.push(function_call.clone());
403                        }
404                    }
405                }
406            }
407        }
408
409        if !fn_calls.is_empty() {
410            for function_call in &fn_calls {
411                let parts = self.tool_call(function_call).await?;
412
413                self.request.contents.push(Content {
414                    parts,
415                    role: Role::User,
416                });
417            }
418            Ok(true)
419        } else {
420            Ok(false)
421        }
422    }
423
424    async fn do_post(&mut self) -> Result<Vec<ContentResponse>, Error> {
425        let request = self
426            .client
427            .post(self.url())
428            .header("Content-Type", "application/json")
429            .query(&[("key", &self.key)])
430            .json(&self.request);
431
432        let responses = request.send().await?.json::<Vec<ContentResponse>>().await?;
433
434        self.merge_response(&responses)
435    }
436
437    async fn post(&mut self) -> Result<Responses, Error> {
438        let mut responses = self.do_post().await?;
439
440        // Process all functions that the model maay be calling and feed the results
441        // back in.
442        while self.process_tools(&responses).await? {
443            responses = self.do_post().await?;
444        }
445
446        Ok(Responses(responses))
447    }
448
449    /// Send the given text to the model.  Returns the responses or an error
450    /// message if an error was returned.
451    pub async fn send_text(&mut self, text: &str) -> Result<Responses, Error> {
452        self.request.contents.push(Content {
453            parts: vec![Part::Text(text.to_string())],
454            role: Role::User,
455        });
456
457        self.post().await
458    }
459
460    pub async fn send_image(&mut self, blob: &Blob) -> Result<Responses, Error> {
461        self.request.contents.push(Content {
462            parts: vec![Part::InlineData(blob.clone())],
463            role: Role::User,
464        });
465
466        self.post().await
467    }
468
469    pub async fn send_file_data(&mut self, data: &FileData) -> Result<Responses, Error> {
470        self.request.contents.push(Content {
471            parts: vec![Part::FileData(data.clone())],
472            role: Role::User,
473        });
474
475        self.post().await
476    }
477
478    pub async fn send_image_file(
479        &mut self,
480        message: Option<String>,
481        img: &Path,
482    ) -> Result<Responses, Error> {
483        let format = FileFormat::from_file(img)?;
484
485        let data = BASE64_URL_SAFE.encode(&tokio::fs::read(img).await?);
486
487        self.send_image_bytes(message, format.media_type(), &data)
488            .await
489    }
490
491    pub async fn send_parts(&mut self, parts: &[Part]) -> Result<Responses, Error> {
492        self.request.contents.push(Content {
493            parts: parts.to_vec(),
494            role: Role::User,
495        });
496
497        self.post().await
498    }
499
500    /// Send the given image to the model.  This must be a UTF-8 Base64 encoded
501    /// string which is required by the Google API.  Optional text may be sent with
502    /// the image to create a single consolidated message.  Returns the responses
503    /// or an error message if an error was returned.
504    pub async fn send_image_bytes(
505        &mut self,
506        message: Option<String>,
507        mime_type: &str,
508        data: &str,
509    ) -> Result<Responses, Error> {
510        let mut parts = Vec::new();
511
512        if let Some(message) = message {
513            parts.push(Part::Text(message.to_string()));
514        }
515
516        parts.push(Part::InlineData(Blob {
517            mime_type: mime_type.to_string(),
518            data: data.to_string(),
519        }));
520
521        self.request.contents.push(Content {
522            parts,
523            role: Role::User,
524        });
525
526        self.post().await
527    }
528
529    fn url(&self) -> String {
530        format!("{URL_BASE}/{}{URL_EXTENSION}", self.model.name)
531    }
532
533    /// Returns the entire session content.
534    pub fn history(&self) -> &[Content] {
535        &self.request.contents
536    }
537}