1use crate::gemini::types::{
2 Content, GenerateContentRequest, GenerateContentResponse, GenerationConfig, Part,
3 StreamGenerateContentResponse,
4};
5use futures_util::{Stream, StreamExt};
6use log::{debug, error, info, warn};
7use reqwest::Client;
8use serde::de::Error as SerdeError;
9use serde_json::json;
10use std::fmt;
11use std::pin::Pin;
12use std::task::{Context, Poll};
13use tokio::sync::mpsc;
14use tokio_stream::wrappers::ReceiverStream;
15
16#[derive(Debug)]
18pub enum GeminiClientError {
19 RequestError(String),
21 NetworkError(reqwest::Error),
23 ParseError(serde_json::Error),
25 ApiError(String),
27}
28
29impl fmt::Display for GeminiClientError {
30 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
31 match self {
32 GeminiClientError::RequestError(msg) => write!(f, "Request error: {}", msg),
33 GeminiClientError::NetworkError(err) => write!(f, "Network error: {}", err),
34 GeminiClientError::ParseError(err) => write!(f, "Parse error: {}", err),
35 GeminiClientError::ApiError(msg) => write!(f, "API error: {}", msg),
36 }
37 }
38}
39
40impl std::error::Error for GeminiClientError {}
41
42impl From<reqwest::Error> for GeminiClientError {
43 fn from(err: reqwest::Error) -> Self {
44 GeminiClientError::NetworkError(err)
45 }
46}
47
48impl From<serde_json::Error> for GeminiClientError {
49 fn from(err: serde_json::Error) -> Self {
50 GeminiClientError::ParseError(err)
51 }
52}
53
54#[derive(Debug)]
56pub struct GeminiClient {
57 api_key: String,
58 model: String,
59 base_url: String,
60 client: Client,
61}
62
63impl GeminiClient {
64 pub fn new(api_key: &str, model: &str) -> Self {
75 info!("Creating new GeminiClient with model: {}", model);
76 GeminiClient {
77 api_key: api_key.to_string(),
78 model: model.to_string(),
79 base_url: "https://generativelanguage.googleapis.com/v1beta".to_string(),
80 client: Client::new(),
81 }
82 }
83
84 pub fn setup(api_key: &str) -> Self {
86 Self::new(api_key, "gemini-1.5-pro")
87 }
88
89 pub fn model(mut self, model: &str) -> Self {
91 info!("Setting model to {}", model);
92 self.model = model.to_string();
93 self
94 }
95
96 pub async fn generate_content(
106 &self,
107 prompt: &str,
108 ) -> Result<GenerateContentResponse, GeminiClientError> {
109 let request = GenerateContentRequest {
110 contents: vec![Content {
111 role: "user".to_string(),
112 parts: vec![Part {
113 text: Some(prompt.to_string()),
114 inline_data: None,
115 }],
116 }],
117 generation_config: None,
118 safety_settings: None,
119 tools: None,
120 };
121
122 self.generate_content_with_request(request).await
123 }
124
125 pub async fn generate_content_with_request(
135 &self,
136 request: GenerateContentRequest,
137 ) -> Result<GenerateContentResponse, GeminiClientError> {
138 let url = format!("{}/models/{}:generateContent", self.base_url, self.model);
139 info!("Generating content with URL: {}", url);
140 debug!("GenerateContentRequest: {:?}", request);
141
142 let response = self
143 .client
144 .post(&url)
145 .header("x-goog-api-key", &self.api_key)
146 .json(&request)
147 .send()
148 .await?;
149
150 if response.status().is_success() {
151 let response_json: serde_json::Value = response.json().await?;
152 debug!("Response JSON: {:?}", response_json);
153
154 if let Some(error) = response_json.get("error") {
156 let error_message = error.to_string();
157 error!("Gemini API error: {}", error_message);
158 return Err(GeminiClientError::ApiError(error_message));
159 }
160
161 let generate_response: GenerateContentResponse = serde_json::from_value(response_json)?;
162 info!("Successfully generated content.");
163 debug!("GenerateContentResponse: {:?}", generate_response);
164 Ok(generate_response)
165 } else {
166 let error_message = response.text().await?;
167 error!("Failed to generate content: {}", error_message);
168 Err(GeminiClientError::RequestError(error_message))
169 }
170 }
171
172 pub async fn stream_content(
182 &self,
183 prompt: &str,
184 ) -> Result<
185 impl Stream<Item = Result<StreamGenerateContentResponse, GeminiClientError>>,
186 GeminiClientError,
187 > {
188 let request = GenerateContentRequest {
189 contents: vec![Content {
190 role: "user".to_string(),
191 parts: vec![Part {
192 text: Some(prompt.to_string()),
193 inline_data: None,
194 }],
195 }],
196 generation_config: None,
197 safety_settings: None,
198 tools: None,
199 };
200
201 self.stream_content_with_request(request).await
202 }
203
204 pub async fn stream_content_with_request(
214 &self,
215 request: GenerateContentRequest,
216 ) -> Result<
217 impl Stream<Item = Result<StreamGenerateContentResponse, GeminiClientError>>,
218 GeminiClientError,
219 > {
220 let url = format!(
221 "{}/models/{}:streamGenerateContent",
222 self.base_url, self.model
223 );
224 info!("Streaming content with URL: {}", url);
225 debug!("StreamRequest: {:?}", request);
226
227 let response = self
228 .client
229 .post(&url)
230 .header("x-goog-api-key", &self.api_key)
231 .json(&request)
232 .send()
233 .await?;
234
235 if response.status().is_success() {
236 let (tx, rx) = mpsc::channel(100);
237 let stream = response.bytes_stream();
238
239 tokio::spawn(async move {
240 let mut stream = stream;
241 while let Some(chunk) = stream.next().await {
242 match chunk {
243 Ok(bytes) => {
244 let chunk_str = String::from_utf8_lossy(&bytes);
245 debug!("Received chunk: {}", chunk_str);
246
247 for line in chunk_str.lines() {
249 if line.trim().is_empty() {
250 continue;
251 }
252
253 let json_str = if line.starts_with("data: ") {
255 &line[6..]
256 } else {
257 line
258 };
259
260 if json_str.trim() == "[DONE]" {
261 break;
262 }
263
264 match serde_json::from_str::<StreamGenerateContentResponse>(
265 json_str,
266 ) {
267 Ok(stream_response) => {
268 if let Err(e) = tx.send(Ok(stream_response)).await {
269 error!("Failed to send stream response: {}", e);
270 break;
271 }
272 }
273 Err(e) => {
274 error!("Failed to parse stream response: {}", e);
275 if let Err(e) =
276 tx.send(Err(GeminiClientError::ParseError(e))).await
277 {
278 error!("Failed to send error: {}", e);
279 break;
280 }
281 }
282 }
283 }
284 }
285 Err(e) => {
286 error!("Stream error: {}", e);
287 if let Err(e) = tx.send(Err(GeminiClientError::NetworkError(e))).await {
288 error!("Failed to send network error: {}", e);
289 }
290 break;
291 }
292 }
293 }
294 });
295
296 Ok(ReceiverStream::new(rx))
297 } else {
298 let error_message = response.text().await?;
299 error!("Failed to start streaming: {}", error_message);
300 Err(GeminiClientError::RequestError(error_message))
301 }
302 }
303
304 pub async fn generate_content_with_config(
315 &self,
316 prompt: &str,
317 config: GenerationConfig,
318 ) -> Result<GenerateContentResponse, GeminiClientError> {
319 let request = GenerateContentRequest {
320 contents: vec![Content {
321 role: "user".to_string(),
322 parts: vec![Part {
323 text: Some(prompt.to_string()),
324 inline_data: None,
325 }],
326 }],
327 generation_config: Some(config),
328 safety_settings: None,
329 tools: None,
330 };
331
332 self.generate_content_with_request(request).await
333 }
334
335 pub fn generate_content_sync(&self, prompt: &str) -> String {
337 let rt = tokio::runtime::Runtime::new().unwrap();
340 match rt.block_on(self.generate_content(prompt)) {
341 Ok(response) => response
342 .get_text()
343 .unwrap_or_else(|| "No response generated".to_string()),
344 Err(e) => format!("Error: {}", e),
345 }
346 }
347}