gemini_client_api/gemini/
ask.rs1use super::types::*;
2use actix_web::dev::{Decompress, Payload};
3use awc::{Client, ClientResponse};
4use futures::Stream;
5use serde_json::{Value, json};
6use std::{
7 pin::Pin,
8 task::{Context, Poll},
9 time::Duration,
10};

/// Timeout applied to every request made by the HTTP client built in `Gemini::new`.
const API_TIMEOUT: Duration = Duration::from_secs(30);
/// Base URL of the Gemini REST API; `/{model}:{method}` is appended per request.
const BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta/models";

15pin_project_lite::pin_project! {
16 pub struct GeminiResponseStream<'a>{
17 #[pin]
18 response_stream:ClientResponse<Decompress<Payload>>,
19 reply_storage: &'a mut String
20 }
21}
22impl<'a> GeminiResponseStream<'a> {
23 fn new(
24 response_stream: ClientResponse<Decompress<Payload>>,
25 reply_storage: &'a mut String,
26 ) -> Self {
27 Self {
28 response_stream,
29 reply_storage,
30 }
31 }
32 pub fn parse_json(text: &str) -> Result<Value, serde_json::Error> {
33 let unescaped_str = text.replace("\\\"", "\"").replace("\\n", "\n");
34 serde_json::from_str::<Value>(&unescaped_str)
35 }
36 fn get_response_text(response: &Value) -> Option<&str> {
37 response["candidates"][0]["content"]["parts"][0]["text"].as_str()
38 }
39}
40impl<'a> Stream for GeminiResponseStream<'a> {
41 type Item = Result<String, Box<dyn std::error::Error>>;
42
43 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
44 let this = self.project();
45
46 match this.response_stream.poll_next(cx) {
47 Poll::Ready(Some(Ok(bytes))) => {
48 let text = String::from_utf8_lossy(&bytes);
49 if text == "]" {
50 Poll::Ready(None)
51 } else {
52 match serde_json::from_str(text[1..].trim()) {
53 Ok(ref response) => {
54 let reply = GeminiResponseStream::get_response_text(response)
55 .map(|response| {
56 this.reply_storage.push_str(response);
57 response.to_string()
58 })
59 .ok_or(
60 format!("Gemini API sent invalid response:\n{response}").into(),
61 );
62 Poll::Ready(Some(reply))
63 }
64 Err(error) => Poll::Ready(Some(Err(error.into()))),
65 }
66 }
67 }
68 Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e.into()))),
69 Poll::Ready(None) => Poll::Ready(None),
70 Poll::Pending => Poll::Pending,
71 }
72 }
73}
74
75pub struct Gemini<'a> {
76 client: Client,
77 api_key: String,
78 model: String,
79 sys_prompt: Option<SystemInstruction<'a>>,
80 generation_config: Option<Value>,
81}
82impl<'a> Gemini<'a> {
83 pub fn new(api_key: String, model: String, sys_prompt: Option<SystemInstruction<'a>>) -> Self {
84 Self {
85 client: Client::builder().timeout(API_TIMEOUT).finish(),
86 api_key,
87 model,
88 sys_prompt,
89 generation_config: None,
90 }
91 }
92 pub fn set_generation_config(&mut self, generation_config: Value) -> &mut Self {
93 self.generation_config = Some(generation_config);
94 self
95 }
96 pub fn set_model(&mut self, model: String) {
97 self.model = model;
98 }
99 pub fn set_api_key(&mut self, api_key: String) {
100 self.api_key = api_key;
101 }
102 pub fn set_json_mode(&mut self, schema: Value) -> &Self {
103 if let None = self.generation_config {
104 self.generation_config = Some(json!({
105 "response_mime_type": "application/json",
106 "response_schema":schema
107 }))
108 } else if let Some(config) = self.generation_config.as_mut() {
109 config["response_mime_type"] = "application/json".into();
110 config["response_schema"] = schema.into();
111 }
112 self
113 }
114
115 pub async fn ask<'b>(&self, session: &'b mut Session) -> Result<&'b str, Box<dyn std::error::Error>> {
116 let req_url = format!(
117 "{BASE_URL}/{}:generateContent?key={}",
118 self.model, self.api_key
119 );
120
121 let response: Value = self
122 .client
123 .post(req_url)
124 .send_json(&GeminiBody::new(
125 self.sys_prompt.as_ref(),
126 &session.get_history().as_slice(),
127 self.generation_config.as_ref(),
128 ))
129 .await?
130 .json()
131 .await?;
132 let reply = GeminiResponseStream::get_response_text(&response)
133 .ok_or::<Box<dyn std::error::Error>>(format!("Gemini API sent invalid response:\n{response}").into())?;
134 session.update(reply);
135
136 let destination_string = session
137 .last_reply()
138 .ok_or::<Box<dyn std::error::Error>>(
139 "Something went wrong in ask_as_stream, sorry".into(),
140 )?;
141 Ok(destination_string)
142 }
143 pub async fn ask_as_stream<'b>(
144 &self,
145 session: &'b mut Session,
146 ) -> Result<GeminiResponseStream<'b>, Box<dyn std::error::Error>> {
147 let req_url = format!(
148 "{BASE_URL}/{}:streamGenerateContent?key={}",
149 self.model, self.api_key
150 );
151
152 let response = self
153 .client
154 .post(req_url)
155 .send_json(&GeminiBody::new(
156 self.sys_prompt.as_ref(),
157 session.get_history().as_slice(),
158 self.generation_config.as_ref(),
159 ))
160 .await?;
161 if !response.status().is_success() {
162 return Err(format!(
163 "Found status due to {} from Gemini endpoint",
164 response.status()
165 )
166 .into());
167 }
168 session.update("");
169 let destination_string = session
170 .last_reply_mut()
171 .ok_or::<Box<dyn std::error::Error>>(
172 "Something went wrong in ask_as_stream, sorry".into(),
173 )?;
174
175 Ok(GeminiResponseStream::new(response, destination_string))
176 }
177}