// gemini_client_api/gemini/ask.rs

use super::types::*;
use actix_web::dev::{Decompress, Payload};
use awc::{Client, ClientResponse, SendClientRequest};
use futures::Stream;
use serde_json::{Value, json};
use std::{
    pin::Pin,
    task::{Context, Poll},
    time::Duration,
};

/// Timeout configured on the HTTP client used for every Gemini request.
const API_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Root of the Gemini REST API; `/{model}:{method}` is appended to form each endpoint.
const BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta/models";

15pin_project_lite::pin_project! {
16    pub struct GeminiResponseStream<'a>{
17        #[pin]
18        response_stream:ClientResponse<Decompress<Payload>>,
19        reply_storage: &'a mut String
20    }
21}
22impl<'a> GeminiResponseStream<'a> {
23    fn new(
24        response_stream: ClientResponse<Decompress<Payload>>,
25        reply_storage: &'a mut String,
26    ) -> Self {
27        Self {
28            response_stream,
29            reply_storage,
30        }
31    }
32    pub fn parse_json(text: &str) -> Result<Value, serde_json::Error> {
33        let unescaped_str = text.replace("\\\"", "\"").replace("\\n", "\n");
34        serde_json::from_str::<Value>(&unescaped_str)
35    }
36    fn get_response_text(response: &Value) -> Option<&str> {
37        response["candidates"][0]["content"]["parts"][0]["text"].as_str()
38    }
39}
40impl<'a> Stream for GeminiResponseStream<'a> {
41    type Item = Result<String, Box<dyn std::error::Error>>;
42
43    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
44        let this = self.project();
45
46        match this.response_stream.poll_next(cx) {
47            Poll::Ready(Some(Ok(bytes))) => {
48                let text = String::from_utf8_lossy(&bytes);
49                if text == "]" {
50                    Poll::Ready(None)
51                } else {
52                    match serde_json::from_str(text[1..].trim()) {
53                        Ok(ref response) => {
54                            let reply = GeminiResponseStream::get_response_text(response)
55                                .map(|response| {
56                                    this.reply_storage.push_str(response);
57                                    response.to_string()
58                                })
59                                .ok_or(
60                                    format!("Gemini API sent invalid response:\n{response}").into(),
61                                );
62                            Poll::Ready(Some(reply))
63                        }
64                        Err(error) => Poll::Ready(Some(Err(error.into()))),
65                    }
66                }
67            }
68            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e.into()))),
69            Poll::Ready(None) => Poll::Ready(None),
70            Poll::Pending => Poll::Pending,
71        }
72    }
73}
74
75pub struct Gemini<'a> {
76    client: Client,
77    api_key: String,
78    model: String,
79    sys_prompt: Option<SystemInstruction<'a>>,
80    generation_config: Option<Value>,
81}
82impl<'a> Gemini<'a> {
83    pub fn new(api_key: String, model: String, sys_prompt: Option<SystemInstruction<'a>>) -> Self {
84        Self {
85            client: Client::builder().timeout(API_TIMEOUT).finish(),
86            api_key,
87            model,
88            sys_prompt,
89            generation_config: None,
90        }
91    }
92    pub fn set_generation_config(&mut self, generation_config: Value) -> &mut Self {
93        self.generation_config = Some(generation_config);
94        self
95    }
96    pub fn set_model(&mut self, model: String) {
97        self.model = model;
98    }
99    pub fn set_api_key(&mut self, api_key: String) {
100        self.api_key = api_key;
101    }
102    pub fn set_json_mode(&mut self, schema: Value) -> &Self {
103        if let None = self.generation_config {
104            self.generation_config = Some(json!({
105                "response_mime_type": "application/json",
106                "response_schema":schema
107            }))
108        } else if let Some(config) = self.generation_config.as_mut() {
109            config["response_mime_type"] = "application/json".into();
110            config["response_schema"] = schema.into();
111        }
112        self
113    }
114
115    pub async fn ask<'b>(&self, session: &'b mut Session) -> Result<&'b str, Box<dyn std::error::Error>> {
116        let req_url = format!(
117            "{BASE_URL}/{}:generateContent?key={}",
118            self.model, self.api_key
119        );
120
121        let response: Value = self
122            .client
123            .post(req_url)
124            .send_json(&GeminiBody::new(
125                self.sys_prompt.as_ref(),
126                &session.get_history().as_slice(),
127                self.generation_config.as_ref(),
128            ))
129            .await?
130            .json()
131            .await?;
132        let reply = GeminiResponseStream::get_response_text(&response)
133            .ok_or::<Box<dyn std::error::Error>>(format!("Gemini API sent invalid response:\n{response}").into())?;
134        session.update(reply);
135
136        let destination_string = session
137            .last_reply()
138            .ok_or::<Box<dyn std::error::Error>>(
139                "Something went wrong in ask_as_stream, sorry".into(),
140            )?;
141        Ok(destination_string)
142    }
143    pub async fn ask_as_stream<'b>(
144        &self,
145        session: &'b mut Session,
146    ) -> Result<GeminiResponseStream<'b>, Box<dyn std::error::Error>> {
147        let req_url = format!(
148            "{BASE_URL}/{}:streamGenerateContent?key={}",
149            self.model, self.api_key
150        );
151
152        let response = self
153            .client
154            .post(req_url)
155            .send_json(&GeminiBody::new(
156                self.sys_prompt.as_ref(),
157                session.get_history().as_slice(),
158                self.generation_config.as_ref(),
159            ))
160            .await?;
161        if !response.status().is_success() {
162            return Err(format!(
163                "Found status due to {} from Gemini endpoint",
164                response.status()
165            )
166            .into());
167        }
168        session.update("");
169        let destination_string = session
170            .last_reply_mut()
171            .ok_or::<Box<dyn std::error::Error>>(
172                "Something went wrong in ask_as_stream, sorry".into(),
173            )?;
174
175        Ok(GeminiResponseStream::new(response, destination_string))
176    }
177}