// gemini_client_api/gemini/ask.rs

1use super::types::request::*;
2use super::types::response::*;
3use super::types::sessions::Session;
4use awc::Client;
5use serde_json::{Value, json};
6use std::time::Duration;
7
8const API_TIMEOUT: Duration = Duration::from_secs(60);
9const BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta/models";
10
/// Client for Google's Gemini `generateContent` API: an HTTP client plus
/// the per-request defaults (model, API key, system prompt, generation
/// config, and tools) shared by `ask` and `ask_as_stream`.
///
/// NOTE(review): the derived `Default` builds an `awc::Client` without the
/// 60 s timeout that `Gemini::new` configures — confirm that is intended.
#[derive(Clone, Default)]
pub struct Gemini {
    // Underlying HTTP client (60 s timeout when constructed via `new`).
    client: Client,
    // Sent as the `key` query parameter of every request URL.
    api_key: String,
    // Model identifier interpolated into the request URL path.
    model: String,
    // Optional system instruction included in every request body.
    sys_prompt: Option<SystemInstruction>,
    // Optional generation config JSON (see Gemini docs); `set_json_mode`
    // also stores its response-schema keys here.
    generation_config: Option<Value>,
    // Tools the model is allowed to use; `None` means no tools.
    tools: Option<Vec<Tool>>,
}
20impl Gemini {
21    /// `sys_prompt` should follow [gemini doc](https://ai.google.dev/gemini-api/docs/text-generation#image-input)
22    pub fn new(
23        api_key: impl Into<String>,
24        model: impl Into<String>,
25        sys_prompt: Option<SystemInstruction>,
26    ) -> Self {
27        Self {
28            client: Client::builder().timeout(API_TIMEOUT).finish(),
29            api_key: api_key.into(),
30            model: model.into(),
31            sys_prompt,
32            generation_config: None,
33            tools: None,
34        }
35    }
36    /// The generation config Schema should follow [Gemini docs](https://ai.google.dev/gemini-api/docs/text-generation#configuration-parameters)
37    pub fn set_generation_config(&mut self, generation_config: Value) -> &mut Self {
38        self.generation_config = Some(generation_config);
39        self
40    }
41    pub fn set_model(&mut self, model: impl Into<String>) -> &mut Self {
42        self.model = model.into();
43        self
44    }
45    pub fn set_api_key(&mut self, api_key: impl Into<String>) -> &mut Self {
46        self.api_key = api_key.into();
47        self
48    }
49    /// `schema` should follow [Schema of gemini](https://ai.google.dev/api/caching#Schema)
50    pub fn set_json_mode(&mut self, schema: Value) -> &mut Self {
51        if let None = self.generation_config {
52            self.generation_config = Some(json!({
53                "response_mime_type": "application/json",
54                "response_schema":schema
55            }))
56        } else if let Some(config) = self.generation_config.as_mut() {
57            config["response_mime_type"] = "application/json".into();
58            config["response_schema"] = schema.into();
59        }
60        self
61    }
62    pub fn unset_json_mode(&mut self) -> &mut Self {
63        if let Some(ref mut generation_config) = self.generation_config {
64            generation_config["response_schema"] = None::<Value>.into();
65            generation_config["response_mime_type"] = None::<Value>.into();
66        }
67        self
68    }
69    ///- `tools` can be None to unset tools from using.  
70    ///- Or Vec tools to be allowed
71    pub fn set_tools(&mut self, tools: Option<Vec<Tool>>) -> &mut Self {
72        self.tools = tools;
73        self
74    }
75    pub fn unset_code_execution_mode(&mut self) -> &mut Self {
76        self.tools.take();
77        self
78    }
79
80    pub async fn ask(
81        &self,
82        session: &mut Session,
83    ) -> Result<GeminiResponse, Box<dyn std::error::Error>> {
84        let req_url = format!(
85            "{BASE_URL}/{}:generateContent?key={}",
86            self.model, self.api_key
87        );
88
89        let mut response = self
90            .client
91            .post(req_url)
92            .send_json(&GeminiRequestBody::new(
93                self.sys_prompt.as_ref(),
94                self.tools.as_deref(),
95                &session.get_history().as_slice(),
96                self.generation_config.as_ref(),
97            ))
98            .await?;
99
100        if !response.status().is_success() {
101            let body = response.body().await?;
102            let text = std::str::from_utf8(&body)?;
103            return Err(text.into());
104        }
105
106        let reply = GeminiResponse::new(response).await?;
107        session.update(&reply);
108        Ok(reply)
109    }
110    /// # Warning
111    /// You must read the response stream to get reply stored context in sessions.
112    /// `data_extractor` is used to extract data that you get as a stream of futures.
113    /// # Example
114    ///```ignore
115    ///use futures::StreamExt
116    ///let mut response_stream = gemini.ask_as_stream(session, 
117    ///|_session, gemini_response| gemini_response).await.unwrap();
118    ///
119    ///while let Some(response) = response_stream.next().await {
120    ///    if let Ok(response) = response {
121    ///        println!("{}", response.get_text(""));
122    ///    }
123    ///}
124    ///```
125    pub async fn ask_as_stream<StreamType>(
126        &self,
127        session: Session,
128        data_extractor: StreamDataExtractor<StreamType>,
129    ) -> Result<GeminiResponseStream<StreamType>, Box<dyn std::error::Error>> {
130        let req_url = format!(
131            "{BASE_URL}/{}:streamGenerateContent?key={}",
132            self.model, self.api_key
133        );
134
135        let mut response = self
136            .client
137            .post(req_url)
138            .send_json(&GeminiRequestBody::new(
139                self.sys_prompt.as_ref(),
140                self.tools.as_deref(),
141                session.get_history().as_slice(),
142                self.generation_config.as_ref(),
143            ))
144            .await?;
145
146        if !response.status().is_success() {
147            let body = response.body().await?;
148            let text = std::str::from_utf8(&body)?;
149            return Err(text.into());
150        }
151
152        Ok(GeminiResponseStream::new(response, session, data_extractor))
153    }
154}