gemini_client_api/gemini/ask.rs

1use super::types::request::*;
2use super::types::response::*;
3use super::types::sessions::Session;
4use awc::Client;
5use serde_json::{Value, json};
6use std::time::Duration;
7
/// Per-request timeout handed to the `awc` client builder (60 seconds).
const API_TIMEOUT: Duration = Duration::from_secs(60);
/// Root of the Gemini `v1beta` model endpoints; the model name and RPC method are appended.
const BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta/models";
10
11pub struct Gemini {
12    client: Client,
13    api_key: String,
14    model: String,
15    sys_prompt: Option<SystemInstruction>,
16    generation_config: Option<Value>,
17    tools: Option<Vec<Tool>>,
18}
19impl Gemini {
20    /// `sys_prompt` should follow [gemini doc](https://ai.google.dev/gemini-api/docs/text-generation#image-input)
21    pub fn new(api_key: String, model: String, sys_prompt: Option<SystemInstruction>) -> Self {
22        Self {
23            client: Client::builder().timeout(API_TIMEOUT).finish(),
24            api_key,
25            model,
26            sys_prompt,
27            generation_config: None,
28            tools: None,
29        }
30    }
31    /// The generation config Schema should follow [Gemini docs](https://ai.google.dev/gemini-api/docs/text-generation#configuration-parameters)
32    pub fn set_generation_config(&mut self, generation_config: Value) -> &mut Self {
33        self.generation_config = Some(generation_config);
34        self
35    }
36    pub fn set_model(&mut self, model: String) -> &mut Self {
37        self.model = model;
38        self
39    }
40    pub fn set_api_key(&mut self, api_key: String) -> &mut Self {
41        self.api_key = api_key;
42        self
43    }
44    /// `schema` should follow [Schema of gemini](https://ai.google.dev/api/caching#Schema)
45    pub fn set_json_mode(&mut self, schema: Value) -> &mut Self {
46        if let None = self.generation_config {
47            self.generation_config = Some(json!({
48                "response_mime_type": "application/json",
49                "response_schema":schema
50            }))
51        } else if let Some(config) = self.generation_config.as_mut() {
52            config["response_mime_type"] = "application/json".into();
53            config["response_schema"] = schema.into();
54        }
55        self
56    }
57    pub fn unset_json_mode(&mut self) -> &mut Self {
58        if let Some(ref mut generation_config) = self.generation_config {
59            generation_config["response_schema"] = None::<Value>.into();
60            generation_config["response_mime_type"] = None::<Value>.into();
61        }
62        self
63    }
64    ///- `tools` can be None to unset tools from using.  
65    ///- Or Vec tools to be allowed
66    pub fn set_tools(&mut self, tools: Option<Vec<Tool>>) -> &mut Self {
67        self.tools = tools;
68        self
69    }
70    pub fn unset_code_execution_mode(&mut self) -> &mut Self {
71        self.tools.take();
72        self
73    }
74
75    pub async fn ask<'b>(
76        &self,
77        session: &'b mut Session,
78    ) -> Result<GeminiResponse, Box<dyn std::error::Error>> {
79        let req_url = format!(
80            "{BASE_URL}/{}:generateContent?key={}",
81            self.model, self.api_key
82        );
83
84        let mut response = self
85            .client
86            .post(req_url)
87            .send_json(&GeminiRequestBody::new(
88                self.sys_prompt.as_ref(),
89                self.tools.as_deref(),
90                &session.get_history().as_slice(),
91                self.generation_config.as_ref(),
92            ))
93            .await?;
94
95        if !response.status().is_success() {
96            let body = response.body().await?;
97            let text = std::str::from_utf8(&body)?;
98            return Err(text.into());
99        }
100
101        let reply = GeminiResponse::new(response).await?;
102        session.update(&reply);
103        Ok(reply)
104    }
105    /// # Warining
106    /// You must read the response stream to get reply stored context in sessions.
107    /// # Example
108    ///```ignore
109    ///let mut response_stream = gemini.ask_as_stream(&mut session).await.unwrap();
110    ///while let Some(response) = response_stream.next().await {
111    ///    if let Ok(response) = response {
112    ///        println!("{}", response.get_text(""));
113    ///    }
114    ///}
115    ///```
116    pub async fn ask_as_stream<'b>(
117        &self,
118        session: &'b mut Session,
119    ) -> Result<GeminiResponseStream<'b>, Box<dyn std::error::Error>> {
120        let req_url = format!(
121            "{BASE_URL}/{}:streamGenerateContent?key={}",
122            self.model, self.api_key
123        );
124
125        let mut response = self
126            .client
127            .post(req_url)
128            .send_json(&GeminiRequestBody::new(
129                self.sys_prompt.as_ref(),
130                self.tools.as_deref(),
131                session.get_history().as_slice(),
132                self.generation_config.as_ref(),
133            ))
134            .await?;
135
136        if !response.status().is_success() {
137            let body = response.body().await?;
138            let text = std::str::from_utf8(&body)?;
139            return Err(text.into());
140        }
141
142        Ok(GeminiResponseStream::new(response, session))
143    }
144}