use async_trait::async_trait;
use base64::Engine;
use futures::stream::{BoxStream, StreamExt};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use crate::error::{AgentError, Result};
use crate::models::LLM;
use crate::types::{File, GenerationChunk, GenerationResponse, Message, Role};
/// LLM backend for Google's Gemini models, talking to the
/// `generativelanguage.googleapis.com` REST API through `reqwest`.
pub struct GeminiLLM {
    // Shared HTTP client, reused across requests.
    client: Client,
    // API key sent with every request (see `new` / `with_api_key`).
    api_key: String,
    // Model identifier (e.g. "gemini-2.0-flash"), interpolated into the endpoint URL.
    model: String,
}
/// Top-level JSON request body for `generateContent` / `streamGenerateContent`.
#[derive(Debug, Serialize)]
struct GeminiRequest {
    // Conversation turns, in order.
    contents: Vec<GeminiContent>,
    // Tool declarations; key omitted from the JSON entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<serde_json::Value>>,
}
/// One conversation turn in the request body.
#[derive(Debug, Serialize)]
struct GeminiContent {
    // "user" or "model" — see `GeminiLLM::convert_role`.
    role: String,
    // Text and/or inline-data parts making up this turn.
    parts: Vec<GeminiPart>,
}
/// A single part of a turn. `untagged` makes each variant serialize as a bare
/// object (`{"text": …}` or `{"inline_data": …}`) with no discriminant field.
#[derive(Debug, Serialize)]
#[serde(untagged)]
enum GeminiPart {
    Text { text: String },
    InlineData { inline_data: GeminiBlob },
}
/// Inline binary payload attached to a turn.
#[derive(Debug, Serialize)]
struct GeminiBlob {
    // MIME type of the attachment, copied from `File::mime_type`.
    mime_type: String,
    // File bytes, standard-base64 encoded (see `prepare_request_body`).
    data: String,
}
/// Top-level response / SSE-event payload. All fields are `Option` so that
/// partial or empty events deserialize without error.
#[derive(Debug, Deserialize)]
struct GeminiResponse {
    candidates: Option<Vec<GeminiCandidate>>,
}
/// One generated candidate; only the first is used by this client.
#[derive(Debug, Deserialize)]
struct GeminiCandidate {
    content: Option<GeminiContentResponse>,
}
/// Content of a candidate: a list of response parts.
#[derive(Debug, Deserialize)]
struct GeminiContentResponse {
    parts: Option<Vec<GeminiPartResponse>>,
}
/// One response part; only the text field is consumed here, so non-text
/// parts deserialize with `text: None` and are skipped.
#[derive(Debug, Deserialize)]
struct GeminiPartResponse {
    text: Option<String>,
}
impl GeminiLLM {
    /// Creates a client for `model`, reading the API key from the
    /// `GOOGLE_API_KEY` environment variable, falling back to
    /// `GEMINI_API_KEY`.
    ///
    /// # Errors
    ///
    /// Returns [`AgentError::ConfigError`] when neither variable is set.
    pub fn new(model: impl Into<String>) -> Result<Self> {
        let api_key = std::env::var("GOOGLE_API_KEY")
            .or_else(|_| std::env::var("GEMINI_API_KEY"))
            .map_err(|_| {
                AgentError::ConfigError(
                    "GOOGLE_API_KEY or GEMINI_API_KEY environment variable not set".to_string(),
                )
            })?;
        Ok(Self {
            client: Client::new(),
            api_key,
            model: model.into(),
        })
    }

    /// Creates a client with an explicitly supplied API key, bypassing the
    /// environment lookup done by [`GeminiLLM::new`].
    pub fn with_api_key(api_key: impl Into<String>, model: impl Into<String>) -> Self {
        Self {
            client: Client::new(),
            api_key: api_key.into(),
            model: model.into(),
        }
    }

    /// Maps our internal [`Role`] onto the role strings the Gemini API
    /// accepts. Gemini only distinguishes "user" and "model", so system and
    /// tool messages are folded into "user".
    fn convert_role(role: &Role) -> &'static str {
        match role {
            Role::Assistant => "model",
            Role::User | Role::System | Role::Tool => "user",
        }
    }

    /// Builds the JSON request body for the given conversation.
    ///
    /// Attached files are base64-encoded and appended as inline-data parts to
    /// the last message. If there are no messages at all, a fresh "user" turn
    /// is created for them instead of silently dropping the files (which the
    /// previous implementation did).
    fn prepare_request_body(
        &self,
        messages: Vec<Message>,
        files: Option<Vec<File>>,
    ) -> GeminiRequest {
        let mut contents: Vec<GeminiContent> = messages
            .iter()
            .map(|m| GeminiContent {
                role: Self::convert_role(&m.role).to_string(),
                parts: vec![GeminiPart::Text {
                    text: m.content.clone(),
                }],
            })
            .collect();
        if let Some(files) = files {
            // Give files a turn of their own when there is no message to
            // attach them to.
            if !files.is_empty() && contents.is_empty() {
                contents.push(GeminiContent {
                    role: "user".to_string(),
                    parts: Vec::new(),
                });
            }
            if let Some(last_content) = contents.last_mut() {
                for file in files {
                    last_content.parts.push(GeminiPart::InlineData {
                        inline_data: GeminiBlob {
                            mime_type: file.mime_type,
                            data: base64::engine::general_purpose::STANDARD.encode(&file.data),
                        },
                    });
                }
            }
        }
        GeminiRequest {
            contents,
            tools: None,
        }
    }
}
#[async_trait]
impl LLM for GeminiLLM {
    /// Sends the conversation to Gemini's `generateContent` endpoint and
    /// returns the first candidate's text.
    ///
    /// # Errors
    ///
    /// Returns [`AgentError::ModelError`] on transport failures, non-success
    /// HTTP statuses, unparsable responses, or responses with no text part.
    async fn generate(
        &self,
        messages: Vec<Message>,
        files: Option<Vec<File>>,
    ) -> Result<GenerationResponse> {
        let request = self.prepare_request_body(messages, files);
        let url = format!(
            "https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent",
            self.model
        );
        let response = self
            .client
            .post(&url)
            // Send the key as a header rather than a query parameter so it
            // does not leak into access logs, proxies, or error messages.
            .header("x-goog-api-key", &self.api_key)
            .json(&request)
            .send()
            .await
            .map_err(|e| AgentError::ModelError(format!("Gemini API error: {}", e)))?;
        if !response.status().is_success() {
            let status = response.status();
            let text = response.text().await.unwrap_or_default();
            return Err(AgentError::ModelError(format!(
                "Gemini API error {}: {}",
                status, text
            )));
        }
        let gemini_response: GeminiResponse = response
            .json()
            .await
            .map_err(|e| AgentError::ModelError(format!("Failed to parse response: {}", e)))?;
        // Walk candidates[0].content.parts[0].text; any missing link means
        // the model returned no usable text.
        let content = gemini_response
            .candidates
            .as_ref()
            .and_then(|c| c.first())
            .and_then(|c| c.content.as_ref())
            .and_then(|c| c.parts.as_ref())
            .and_then(|p| p.first())
            .and_then(|p| p.text.clone())
            .ok_or_else(|| AgentError::ModelError("No content in response".to_string()))?;
        Ok(GenerationResponse {
            content,
            metadata: None,
        })
    }

    /// Streams text chunks from Gemini's `streamGenerateContent` endpoint
    /// using server-sent events (`alt=sse`).
    ///
    /// Each complete `data: {...}` line is parsed into a [`GenerationChunk`];
    /// lines that fail to parse or carry no text are skipped. A final event
    /// that arrives without a trailing newline is still processed when the
    /// byte stream ends (the previous implementation dropped it).
    ///
    /// # Errors
    ///
    /// Returns [`AgentError::ModelError`] if the request fails or the server
    /// responds with a non-success status; transport errors mid-stream are
    /// yielded as `Err` items of the stream.
    async fn stream_generate(
        &self,
        messages: Vec<Message>,
        files: Option<Vec<File>>,
    ) -> Result<BoxStream<'static, Result<GenerationChunk>>> {
        let request = self.prepare_request_body(messages, files);
        let url = format!(
            "https://generativelanguage.googleapis.com/v1beta/models/{}:streamGenerateContent?alt=sse",
            self.model
        );
        let response = self
            .client
            .post(&url)
            // Header keeps the key out of logged URLs; see `generate`.
            .header("x-goog-api-key", &self.api_key)
            .json(&request)
            .send()
            .await
            .map_err(|e| AgentError::ModelError(format!("Gemini API error: {}", e)))?;
        if !response.status().is_success() {
            let status = response.status();
            let text = response.text().await.unwrap_or_default();
            return Err(AgentError::ModelError(format!(
                "Gemini API error {}: {}",
                status, text
            )));
        }
        let stream = response.bytes_stream();
        let buffer: Vec<u8> = Vec::new();
        // State: (byte stream, line buffer, stream-exhausted flag).
        let s = futures::stream::unfold(
            (stream, buffer, false),
            |(mut stream, mut buffer, mut done)| async move {
                loop {
                    // Pull the next complete line out of the buffer, or the
                    // final partial line once the byte stream is exhausted.
                    let line: Option<Vec<u8>> =
                        if let Some(pos) = buffer.iter().position(|&b| b == b'\n') {
                            Some(buffer.drain(0..=pos).collect())
                        } else if done && !buffer.is_empty() {
                            Some(std::mem::take(&mut buffer))
                        } else {
                            None
                        };
                    if let Some(line) = line {
                        let text = String::from_utf8_lossy(&line);
                        let trimmed = text.trim();
                        // SSE data lines look like `data: {...json...}`.
                        if let Some(json_str) = trimmed.strip_prefix("data: ") {
                            if let Ok(resp) =
                                serde_json::from_str::<GeminiResponse>(json_str.trim())
                            {
                                let content_opt = resp
                                    .candidates
                                    .as_ref()
                                    .and_then(|c| c.first())
                                    .and_then(|c| c.content.as_ref())
                                    .and_then(|c| c.parts.as_ref())
                                    .and_then(|p| p.first())
                                    .and_then(|p| p.text.clone());
                                if let Some(content) = content_opt {
                                    return Some((
                                        Ok(GenerationChunk {
                                            content,
                                            metadata: None,
                                        }),
                                        (stream, buffer, done),
                                    ));
                                }
                            }
                        }
                        // Non-data or empty line: keep scanning.
                        continue;
                    }
                    if done {
                        return None;
                    }
                    match stream.next().await {
                        Some(Ok(chunk)) => buffer.extend_from_slice(&chunk),
                        Some(Err(e)) => {
                            return Some((
                                Err(AgentError::ModelError(e.to_string())),
                                (stream, buffer, done),
                            ))
                        }
                        // Flush whatever is left in the buffer on the next
                        // pass, then terminate.
                        None => done = true,
                    }
                }
            },
        );
        Ok(Box::pin(s))
    }

    /// Returns the configured model identifier.
    fn model_name(&self) -> &str {
        &self.model
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Live-API smoke test: ignored by default because it needs network
    /// access and a valid key in the environment. Run with
    /// `cargo test -- --ignored`.
    #[tokio::test]
    #[ignore]
    async fn test_gemini_generate() {
        let llm = GeminiLLM::new("gemini-2.0-flash").unwrap();
        let prompt = Message {
            role: Role::User,
            content: "Say 'Hello, World!' and nothing else.".to_string(),
            metadata: None,
        };
        let response = llm.generate(vec![prompt], None).await.unwrap();
        assert!(response.content.contains("Hello"));
    }
}