// gemini_client_api/gemini/ask.rs

1use super::error::GeminiResponseError;
2use super::types::request::*;
3use super::types::response::*;
4use super::types::sessions::Session;
5use reqwest::Client;
6use serde_json::{Value, json};
7use std::time::Duration;
8
/// Root endpoint of the Gemini `generativelanguage` REST API (v1beta models collection);
/// the model name and action (`:generateContent` / `:streamGenerateContent`) are appended per request.
const BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta/models";
10
/// Client for Google's Gemini generative-language REST API.
#[derive(Clone, Default, Debug)]
pub struct Gemini {
    // Shared HTTP client; constructors configure it with a request timeout.
    client: Client,
    // Sent as the `key` query parameter of every request URL.
    api_key: String,
    // Interpolated into the request URL path, between BASE_URL and the action.
    model: String,
    // Optional system instruction included in every request body.
    sys_prompt: Option<SystemInstruction>,
    // Raw JSON generation config; `set_json_mode` also stores its keys here.
    generation_config: Option<Value>,
    // Tools the model is allowed to use; `None` sends no tools.
    tools: Option<Vec<Tool>>,
}
20impl Gemini {
21    /// # Arguments
22    /// `api_key` get one from [Google AI studio](https://aistudio.google.com/app/apikey)
23    /// `model` should be of those mentioned [here](https://ai.google.dev/gemini-api/docs/models#model-variations) in bold black color
24    /// `sys_prompt` should follow [gemini doc](https://ai.google.dev/gemini-api/docs/text-generation#image-input)
25    pub fn new(
26        api_key: impl Into<String>,
27        model: impl Into<String>,
28        sys_prompt: Option<SystemInstruction>,
29    ) -> Self {
30        Self {
31            client: Client::builder()
32                .timeout(Duration::from_secs(60))
33                .build()
34                .unwrap(),
35            api_key: api_key.into(),
36            model: model.into(),
37            sys_prompt,
38            generation_config: None,
39            tools: None,
40        }
41    }
42    /// `sys_prompt` should follow [gemini doc](https://ai.google.dev/gemini-api/docs/text-generation#image-input)
43    pub fn new_with_timeout(
44        api_key: impl Into<String>,
45        model: impl Into<String>,
46        sys_prompt: Option<SystemInstruction>,
47        api_timeout: Duration,
48    ) -> Self {
49        Self {
50            client: Client::builder().timeout(api_timeout).build().unwrap(),
51            api_key: api_key.into(),
52            model: model.into(),
53            sys_prompt,
54            generation_config: None,
55            tools: None,
56        }
57    }
58    /// The generation config Schema should follow [Gemini docs](https://ai.google.dev/gemini-api/docs/text-generation#configuration-parameters)
59    pub fn set_generation_config(&mut self, generation_config: Value) -> &mut Self {
60        self.generation_config = Some(generation_config);
61        self
62    }
63    pub fn set_model(&mut self, model: impl Into<String>) -> &mut Self {
64        self.model = model.into();
65        self
66    }
67    pub fn set_api_key(&mut self, api_key: impl Into<String>) -> &mut Self {
68        self.api_key = api_key.into();
69        self
70    }
71    /// `schema` should follow [Schema of gemini](https://ai.google.dev/api/caching#Schema)
72    pub fn set_json_mode(&mut self, schema: Value) -> &mut Self {
73        if let None = self.generation_config {
74            self.generation_config = Some(json!({
75                "response_mime_type": "application/json",
76                "response_schema":schema
77            }))
78        } else if let Some(config) = self.generation_config.as_mut() {
79            config["response_mime_type"] = "application/json".into();
80            config["response_schema"] = schema.into();
81        }
82        self
83    }
84    pub fn unset_json_mode(&mut self) -> &mut Self {
85        if let Some(ref mut generation_config) = self.generation_config {
86            generation_config["response_schema"] = None::<Value>.into();
87            generation_config["response_mime_type"] = None::<Value>.into();
88        }
89        self
90    }
91    ///- `tools` can be None to unset tools from using.  
92    ///- Or Vec tools to be allowed
93    pub fn set_tools(&mut self, tools: Option<Vec<Tool>>) -> &mut Self {
94        self.tools = tools;
95        self
96    }
97    pub fn unset_code_execution_mode(&mut self) -> &mut Self {
98        self.tools.take();
99        self
100    }
101
102    pub async fn ask(&self, session: &mut Session) -> Result<GeminiResponse, GeminiResponseError> {
103        let req_url = format!(
104            "{BASE_URL}/{}:generateContent?key={}",
105            self.model, self.api_key
106        );
107
108        let response = self
109            .client
110            .post(req_url)
111            .json(&GeminiRequestBody::new(
112                self.sys_prompt.as_ref(),
113                self.tools.as_deref(),
114                &session.get_history().as_slice(),
115                self.generation_config.as_ref(),
116            ))
117            .send()
118            .await?;
119
120        if !response.status().is_success() {
121            let text = response.text().await?;
122            return Err(text.into());
123        }
124
125        let reply = GeminiResponse::new(response).await?;
126        session.update(&reply);
127        Ok(reply)
128    }
129    /// # Warning
130    /// You must read the response stream to get reply stored context in `session`.
131    /// `data_extractor` is used to extract data that you get as a stream of futures.
132    /// # Example
133    ///```ignore
134    ///use futures::StreamExt
135    ///let mut response_stream = gemini.ask_as_stream_with_extractor(session,
136    ///|session, _gemini_response| session.get_last_message_text("").unwrap())
137    ///.await.unwrap(); // Use _gemini_response.get_text("") to just get the text received in every chunk
138    ///
139    ///while let Some(response) = response_stream.next().await {
140    ///    if let Ok(response) = response {
141    ///        println!("{}", response);
142    ///    }
143    ///}
144    ///```
145    pub async fn ask_as_stream_with_extractor<F, StreamType>(
146        &self,
147        session: Session,
148        data_extractor: F,
149    ) -> Result<ResponseStream<F, StreamType>, GeminiResponseError>
150    where
151        F: FnMut(&Session, GeminiResponse) -> StreamType,
152    {
153        let req_url = format!(
154            "{BASE_URL}/{}:streamGenerateContent?key={}",
155            self.model, self.api_key
156        );
157
158        let response = self
159            .client
160            .post(req_url)
161            .json(&GeminiRequestBody::new(
162                self.sys_prompt.as_ref(),
163                self.tools.as_deref(),
164                session.get_history().as_slice(),
165                self.generation_config.as_ref(),
166            ))
167            .send()
168            .await?;
169
170        if !response.status().is_success() {
171            let text = response.text().await?;
172            return Err(text.into());
173        }
174
175        Ok(ResponseStream::new(
176            Box::new(response.bytes_stream()),
177            session,
178            data_extractor,
179        ))
180    }
181    /// # Warning
182    /// You must read the response stream to get reply stored context in `session`.
183    /// # Example
184    ///```ignore
185    ///use futures::StreamExt
186    ///let mut response_stream = gemini.ask_as_stream(session).await.unwrap();
187    ///
188    ///while let Some(response) = response_stream.next().await {
189    ///    if let Ok(response) = response {
190    ///        println!("{}", response.get_text(""));
191    ///    }
192    ///}
193    ///```
194    pub async fn ask_as_stream(
195        &self,
196        session: Session,
197    ) -> Result<GeminiResponseStream, GeminiResponseError> {
198        self.ask_as_stream_with_extractor(
199            session,
200            (|_, gemini_response| gemini_response)
201                as fn(&Session, GeminiResponse) -> GeminiResponse,
202        )
203        .await
204    }
205}