use std::pin::Pin;
use crate::llm_client::{
ClientError, LlmClient, Request, event::Event, scheme::gemini::GeminiScheme,
};
use async_trait::async_trait;
use eventsource_stream::Eventsource;
use futures::{Stream, StreamExt, TryStreamExt};
use reqwest::header::{CONTENT_TYPE, HeaderMap, HeaderValue};
/// Streaming client for the Gemini `streamGenerateContent` HTTP endpoint.
///
/// Create one with [`GeminiClient::new`] and customise it through the `with_*`
/// builder methods.
pub struct GeminiClient {
/// HTTP client used to send every request.
http_client: reqwest::Client,
/// API key; sent as the `key` query parameter of the request URL.
api_key: String,
/// Model name interpolated into the request path (`/v1beta/models/{model}`).
model: String,
/// Builds the request body and parses each streamed event payload.
scheme: GeminiScheme,
/// URL prefix that the `/v1beta/...` request path is appended to.
base_url: String,
}
impl GeminiClient {
    /// Endpoint used unless [`GeminiClient::with_base_url`] overrides it.
    const DEFAULT_BASE_URL: &'static str = "https://generativelanguage.googleapis.com";

    /// Creates a client for `model` that authenticates with `api_key`.
    ///
    /// The client starts with a fresh `reqwest::Client`, the default
    /// [`GeminiScheme`] and the public Gemini endpoint; each of these can be
    /// replaced with the corresponding `with_*` method.
    pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
        Self {
            http_client: reqwest::Client::new(),
            api_key: api_key.into(),
            model: model.into(),
            scheme: GeminiScheme::default(),
            base_url: Self::DEFAULT_BASE_URL.to_string(),
        }
    }

    /// Replaces the HTTP client used to send requests.
    #[must_use]
    pub fn with_http_client(mut self, client: reqwest::Client) -> Self {
        self.http_client = client;
        self
    }

    /// Replaces the scheme that builds request bodies and parses streamed events.
    #[must_use]
    pub fn with_scheme(mut self, scheme: GeminiScheme) -> Self {
        self.scheme = scheme;
        self
    }

    /// Replaces the URL prefix that request paths are appended to.
    ///
    /// Trailing `/` characters are removed: the request path always starts with
    /// `/v1beta/...`, so keeping them would produce a `//` in the final URL.
    #[must_use]
    pub fn with_base_url(mut self, url: impl Into<String>) -> Self {
        let mut base_url: String = url.into();
        // Truncate in place rather than allocating a second string.
        let trimmed_len = base_url.trim_end_matches('/').len();
        base_url.truncate(trimmed_len);
        self.base_url = base_url;
        self
    }

    /// Builds the headers sent with every request.
    ///
    /// Only `Content-Type: application/json` is set, so this never returns
    /// `Err` today; the `Result` return type is kept so existing callers using
    /// `?` continue to compile.
    fn build_headers(&self) -> Result<HeaderMap, ClientError> {
        let mut headers = HeaderMap::new();
        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
        Ok(headers)
    }
}
#[async_trait]
impl LlmClient for GeminiClient {
async fn stream(
&self,
request: Request,
) -> Result<Pin<Box<dyn Stream<Item = Result<Event, ClientError>> + Send>>, ClientError> {
let url = format!(
"{}/v1beta/models/{}:streamGenerateContent?alt=sse&key={}",
self.base_url, self.model, self.api_key
);
let headers = self.build_headers()?;
let body = self.scheme.build_request(&request);
let response = self
.http_client
.post(&url)
.headers(headers)
.json(&body)
.send()
.await?;
if !response.status().is_success() {
let status = response.status().as_u16();
let text = response.text().await.unwrap_or_default();
if let Ok(json) = serde_json::from_str::<serde_json::Value>(&text) {
let error = json.get("error").unwrap_or(&json);
let code = error
.get("status")
.and_then(|v| v.as_str())
.map(String::from);
let message = error
.get("message")
.and_then(|v| v.as_str())
.unwrap_or(&text)
.to_string();
return Err(ClientError::Api {
status: Some(status),
code,
message,
});
}
return Err(ClientError::Api {
status: Some(status),
code: None,
message: text,
});
}
let scheme = self.scheme.clone();
let byte_stream = response
.bytes_stream()
.map_err(|e| std::io::Error::other(e));
let event_stream = byte_stream.eventsource();
let stream = event_stream
.map(move |result| {
match result {
Ok(event) => {
match scheme.parse_event(&event.data) {
Ok(Some(events)) => Ok(Some(events)),
Ok(None) => Ok(None),
Err(e) => Err(e),
}
}
Err(e) => Err(ClientError::Sse(e.to_string())),
}
})
.map(|res| {
let s: Pin<Box<dyn Stream<Item = Result<Event, ClientError>> + Send>> = match res {
Ok(Some(events)) => Box::pin(futures::stream::iter(events.into_iter().map(Ok))),
Ok(None) => Box::pin(futures::stream::empty()),
Err(e) => Box::pin(futures::stream::once(async move { Err(e) })),
};
s
})
.flatten();
Ok(Box::pin(stream))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // Fixture values shared by every test; no request is sent with them.
    const API_KEY: &str = "test-key";
    const MODEL: &str = "gemini-2.0-flash";

    /// Builds a client with the fixture key and model.
    fn sample_client() -> GeminiClient {
        GeminiClient::new(API_KEY, MODEL)
    }

    /// `new` stores the model name it was given.
    #[test]
    fn test_client_creation() {
        let gemini = sample_client();
        assert_eq!(gemini.model, "gemini-2.0-flash");
    }

    /// The default headers include a content type.
    #[test]
    fn test_build_headers() {
        let header_map = sample_client().build_headers().unwrap();
        assert!(header_map.contains_key("content-type"));
    }

    /// `with_base_url` overrides the default endpoint.
    #[test]
    fn test_custom_base_url() {
        let custom = "https://custom.api.example.com";
        let gemini = sample_client().with_base_url(custom);
        assert_eq!(gemini.base_url, "https://custom.api.example.com");
    }
}