gemini_client_api/gemini/ask.rs

use std::time::Duration;

use reqwest::Client;
use serde_json::{Value, json};

use super::error::GeminiResponseError;
use super::types::request::*;
use super::types::response::*;
use super::types::sessions::Session;
8
/// Per-request timeout applied to the HTTP client built by `Gemini::new`.
const API_TIMEOUT: Duration = Duration::from_secs(60);
/// Root of the Gemini `v1beta` models REST endpoint; request URLs append
/// `/<model>:<method>?key=<api_key>`.
const BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta/models";
11
12#[derive(Clone, Default)]
13pub struct Gemini {
14    client: Client,
15    api_key: String,
16    model: String,
17    sys_prompt: Option<SystemInstruction>,
18    generation_config: Option<Value>,
19    tools: Option<Vec<Tool>>,
20}
21impl Gemini {
22    /// `sys_prompt` should follow [gemini doc](https://ai.google.dev/gemini-api/docs/text-generation#image-input)
23    pub fn new(
24        api_key: impl Into<String>,
25        model: impl Into<String>,
26        sys_prompt: Option<SystemInstruction>,
27    ) -> Self {
28        Self {
29            client: Client::builder().timeout(API_TIMEOUT).build().unwrap(),
30            api_key: api_key.into(),
31            model: model.into(),
32            sys_prompt,
33            generation_config: None,
34            tools: None,
35        }
36    }
37    /// The generation config Schema should follow [Gemini docs](https://ai.google.dev/gemini-api/docs/text-generation#configuration-parameters)
38    pub fn set_generation_config(&mut self, generation_config: Value) -> &mut Self {
39        self.generation_config = Some(generation_config);
40        self
41    }
42    pub fn set_model(&mut self, model: impl Into<String>) -> &mut Self {
43        self.model = model.into();
44        self
45    }
46    pub fn set_api_key(&mut self, api_key: impl Into<String>) -> &mut Self {
47        self.api_key = api_key.into();
48        self
49    }
50    /// `schema` should follow [Schema of gemini](https://ai.google.dev/api/caching#Schema)
51    pub fn set_json_mode(&mut self, schema: Value) -> &mut Self {
52        if let None = self.generation_config {
53            self.generation_config = Some(json!({
54                "response_mime_type": "application/json",
55                "response_schema":schema
56            }))
57        } else if let Some(config) = self.generation_config.as_mut() {
58            config["response_mime_type"] = "application/json".into();
59            config["response_schema"] = schema.into();
60        }
61        self
62    }
63    pub fn unset_json_mode(&mut self) -> &mut Self {
64        if let Some(ref mut generation_config) = self.generation_config {
65            generation_config["response_schema"] = None::<Value>.into();
66            generation_config["response_mime_type"] = None::<Value>.into();
67        }
68        self
69    }
70    ///- `tools` can be None to unset tools from using.  
71    ///- Or Vec tools to be allowed
72    pub fn set_tools(&mut self, tools: Option<Vec<Tool>>) -> &mut Self {
73        self.tools = tools;
74        self
75    }
76    pub fn unset_code_execution_mode(&mut self) -> &mut Self {
77        self.tools.take();
78        self
79    }
80
81    pub async fn ask(&self, session: &mut Session) -> Result<GeminiResponse, GeminiResponseError> {
82        let req_url = format!(
83            "{BASE_URL}/{}:generateContent?key={}",
84            self.model, self.api_key
85        );
86
87        let response = self
88            .client
89            .post(req_url)
90            .json(&GeminiRequestBody::new(
91                self.sys_prompt.as_ref(),
92                self.tools.as_deref(),
93                &session.get_history().as_slice(),
94                self.generation_config.as_ref(),
95            ))
96            .send()
97            .await?;
98
99        if !response.status().is_success() {
100            let text = response.text().await?;
101            return Err(text.into());
102        }
103
104        let reply = GeminiResponse::new(response).await?;
105        session.update(&reply);
106        Ok(reply)
107    }
108    /// # Warning
109    /// You must read the response stream to get reply stored context in sessions.
110    /// `data_extractor` is used to extract data that you get as a stream of futures.
111    /// # Example
112    ///```ignore
113    ///use futures::StreamExt
114    ///let mut response_stream = gemini.ask_as_stream(session,
115    ///|_session, gemini_response| gemini_response).await.unwrap();
116    ///
117    ///while let Some(response) = response_stream.next().await {
118    ///    if let Ok(response) = response {
119    ///        println!("{}", response.get_text(""));
120    ///    }
121    ///}
122    ///```
123    pub async fn ask_as_stream<StreamType>(
124        &self,
125        session: Session,
126        data_extractor: StreamDataExtractor<StreamType>,
127    ) -> Result<GeminiResponseStream<StreamType>, GeminiResponseError> {
128        let req_url = format!(
129            "{BASE_URL}/{}:streamGenerateContent?key={}",
130            self.model, self.api_key
131        );
132
133        let response = self
134            .client
135            .post(req_url)
136            .json(&GeminiRequestBody::new(
137                self.sys_prompt.as_ref(),
138                self.tools.as_deref(),
139                session.get_history().as_slice(),
140                self.generation_config.as_ref(),
141            ))
142            .send()
143            .await?;
144
145        if !response.status().is_success() {
146            let text = response.text().await?;
147            return Err(text.into());
148        }
149
150        Ok(GeminiResponseStream::new(
151            Box::new(response.bytes_stream()),
152            session,
153            data_extractor,
154        ))
155    }
156}