gemini_client_api/gemini/ask.rs

1use super::types::request::*;
2use super::types::response::*;
3use super::types::sessions::Session;
4use awc::Client;
5use serde_json::{Value, json};
6use std::time::Duration;
7
/// Per-request timeout applied to the HTTP client built in `Gemini::new`.
const API_TIMEOUT: Duration = Duration::from_millis(60_000);
/// Root of the Gemini v1beta model endpoints; requests append `/{model}:{method}`.
const BASE_URL: &str = concat!("https://generativelanguage.googleapis.com", "/v1beta/models");
10
11#[derive(Clone, Default)]
12pub struct Gemini {
13    client: Client,
14    api_key: String,
15    model: String,
16    sys_prompt: Option<SystemInstruction>,
17    generation_config: Option<Value>,
18    tools: Option<Vec<Tool>>,
19}
20impl Gemini {
21    /// `sys_prompt` should follow [gemini doc](https://ai.google.dev/gemini-api/docs/text-generation#image-input)
22    pub fn new(api_key: String, model: String, sys_prompt: Option<SystemInstruction>) -> Self {
23        Self {
24            client: Client::builder().timeout(API_TIMEOUT).finish(),
25            api_key,
26            model,
27            sys_prompt,
28            generation_config: None,
29            tools: None,
30        }
31    }
32    /// The generation config Schema should follow [Gemini docs](https://ai.google.dev/gemini-api/docs/text-generation#configuration-parameters)
33    pub fn set_generation_config(&mut self, generation_config: Value) -> &mut Self {
34        self.generation_config = Some(generation_config);
35        self
36    }
37    pub fn set_model(&mut self, model: String) -> &mut Self {
38        self.model = model;
39        self
40    }
41    pub fn set_api_key(&mut self, api_key: String) -> &mut Self {
42        self.api_key = api_key;
43        self
44    }
45    /// `schema` should follow [Schema of gemini](https://ai.google.dev/api/caching#Schema)
46    pub fn set_json_mode(&mut self, schema: Value) -> &mut Self {
47        if let None = self.generation_config {
48            self.generation_config = Some(json!({
49                "response_mime_type": "application/json",
50                "response_schema":schema
51            }))
52        } else if let Some(config) = self.generation_config.as_mut() {
53            config["response_mime_type"] = "application/json".into();
54            config["response_schema"] = schema.into();
55        }
56        self
57    }
58    pub fn unset_json_mode(&mut self) -> &mut Self {
59        if let Some(ref mut generation_config) = self.generation_config {
60            generation_config["response_schema"] = None::<Value>.into();
61            generation_config["response_mime_type"] = None::<Value>.into();
62        }
63        self
64    }
65    ///- `tools` can be None to unset tools from using.  
66    ///- Or Vec tools to be allowed
67    pub fn set_tools(&mut self, tools: Option<Vec<Tool>>) -> &mut Self {
68        self.tools = tools;
69        self
70    }
71    pub fn unset_code_execution_mode(&mut self) -> &mut Self {
72        self.tools.take();
73        self
74    }
75
76    pub async fn ask<'b>(
77        &self,
78        session: &'b mut Session,
79    ) -> Result<GeminiResponse, Box<dyn std::error::Error>> {
80        let req_url = format!(
81            "{BASE_URL}/{}:generateContent?key={}",
82            self.model, self.api_key
83        );
84
85        let mut response = self
86            .client
87            .post(req_url)
88            .send_json(&GeminiRequestBody::new(
89                self.sys_prompt.as_ref(),
90                self.tools.as_deref(),
91                &session.get_history().as_slice(),
92                self.generation_config.as_ref(),
93            ))
94            .await?;
95
96        if !response.status().is_success() {
97            let body = response.body().await?;
98            let text = std::str::from_utf8(&body)?;
99            return Err(text.into());
100        }
101
102        let reply = GeminiResponse::new(response).await?;
103        session.update(&reply);
104        Ok(reply)
105    }
106    /// # Warining
107    /// You must read the response stream to get reply stored context in sessions.
108    /// # Example
109    ///```ignore
110    ///let mut response_stream = gemini.ask_as_stream(&mut session).await.unwrap();
111    ///while let Some(response) = response_stream.next().await {
112    ///    if let Ok(response) = response {
113    ///        println!("{}", response.get_text(""));
114    ///    }
115    ///}
116    ///```
117    pub async fn ask_as_stream<'b>(
118        &self,
119        session: &'b mut Session,
120    ) -> Result<GeminiResponseStream<'b>, Box<dyn std::error::Error>> {
121        let req_url = format!(
122            "{BASE_URL}/{}:streamGenerateContent?key={}",
123            self.model, self.api_key
124        );
125
126        let mut response = self
127            .client
128            .post(req_url)
129            .send_json(&GeminiRequestBody::new(
130                self.sys_prompt.as_ref(),
131                self.tools.as_deref(),
132                session.get_history().as_slice(),
133                self.generation_config.as_ref(),
134            ))
135            .await?;
136
137        if !response.status().is_success() {
138            let body = response.body().await?;
139            let text = std::str::from_utf8(&body)?;
140            return Err(text.into());
141        }
142
143        Ok(GeminiResponseStream::new(response, session))
144    }
145}