use crate::error::{AgentError, Result};
use crate::message::{Message, MessageRole};
use crate::provider::{ModelConfig, ModelProvider, ModelResponse, Usage};
use async_trait::async_trait;
use serde::Deserialize;
use serde_json::json;
/// [`ModelProvider`] backed by Google's Gemini `generateContent` REST API.
pub struct GoogleProvider {
    /// Credential sent with every request to authenticate against the API.
    api_key: String,
    /// API root, without a trailing slash; `/models/...` paths are appended to it.
    base_url: String,
    /// HTTP client stored once and reused for every request this provider makes.
    client: reqwest::Client,
}
impl GoogleProvider {
    /// Creates a provider that talks to the public Gemini `v1beta` endpoint using `api_key`.
    pub fn new(api_key: impl Into<String>) -> Self {
        Self {
            api_key: api_key.into(),
            base_url: "https://generativelanguage.googleapis.com/v1beta".to_string(),
            client: reqwest::Client::new(),
        }
    }

    /// Splits `messages` into Gemini's separate system-instruction text and the
    /// `contents` array of `{ "role", "parts" }` turns.
    ///
    /// Returns `(system_instruction, contents)`:
    /// - `system_instruction` is `None` when there are no system messages. Otherwise it
    ///   is all system messages joined in their original order, separated by a blank
    ///   line. Because a single string is returned, joining is the only way to keep
    ///   every one of them; earlier system messages are never silently dropped.
    /// - `contents` keeps the order of the non-system messages. Assistant turns use
    ///   Gemini's `"model"` role. Tool messages are sent as plain `"user"` text.
    fn convert_messages(&self, messages: Vec<Message>) -> (Option<String>, Vec<serde_json::Value>) {
        let mut system_parts: Vec<String> = Vec::new();
        let mut converted = Vec::with_capacity(messages.len());

        for msg in messages {
            let role = match msg.role {
                MessageRole::System => {
                    system_parts.push(msg.content);
                    continue;
                }
                // No function-calling support here: tool output is relayed as user text.
                MessageRole::User | MessageRole::Tool => "user",
                MessageRole::Assistant => "model",
            };
            converted.push(json!({
                "role": role,
                "parts": [{ "text": msg.content }]
            }));
        }

        let system_instruction = (!system_parts.is_empty()).then(|| system_parts.join("\n\n"));
        (system_instruction, converted)
    }
}
/// Top-level body of a `generateContent` response (only the fields this provider reads).
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GoogleResponse {
    // The API may omit `candidates` entirely (e.g. when the prompt itself is blocked).
    // Defaulting to an empty list lets the caller report "no candidates" instead of
    // failing with an opaque parse error.
    #[serde(default)]
    candidates: Vec<Candidate>,
    usage_metadata: Option<GoogleUsage>,
}

/// One generated alternative.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct Candidate {
    // A candidate stopped early (e.g. for safety) may carry no `content`; treat it as empty.
    #[serde(default)]
    content: Content,
    finish_reason: Option<String>,
}

/// Content of a candidate: an ordered list of parts.
#[derive(Debug, Default, Deserialize)]
struct Content {
    #[serde(default)]
    parts: Vec<Part>,
}

/// A single content part. Only text is read.
#[derive(Debug, Deserialize)]
struct Part {
    // Non-text parts (e.g. function calls) have no `text`; they deserialize as "".
    #[serde(default)]
    text: String,
}

/// Token accounting reported in `usageMetadata`.
///
/// Counts absent from the payload (for example when nothing was generated) default to 0.
#[derive(Debug, Default, Deserialize)]
#[serde(default, rename_all = "camelCase")]
struct GoogleUsage {
    prompt_token_count: usize,
    candidates_token_count: usize,
    total_token_count: usize,
}
#[async_trait]
impl ModelProvider for GoogleProvider {
    /// Provider identifier: `"google"`.
    fn name(&self) -> &str {
        "google"
    }

    /// Sends a non-streaming `generateContent` request and maps the first candidate
    /// into a [`ModelResponse`].
    ///
    /// The returned `content` is the concatenation of all text parts of the first
    /// candidate. It is empty if the candidate carried no text.
    ///
    /// # Errors
    ///
    /// Returns [`AgentError::ExecutionError`] when:
    /// - the HTTP request fails;
    /// - the API answers with a non-success status (the status and body are included);
    /// - the response body cannot be parsed;
    /// - the response contains no candidates.
    async fn complete(&self, messages: Vec<Message>, config: &ModelConfig) -> Result<ModelResponse> {
        // The key goes in the `x-goog-api-key` header, not a `?key=` query parameter.
        // reqwest includes the request URL in its error messages, so a key embedded in
        // the URL would leak into the `AgentError` text and any logs that record it.
        let url = format!("{}/models/{}:generateContent", self.base_url, config.model);

        let (system_instruction, converted_messages) = self.convert_messages(messages);
        let mut body = json!({
            "contents": converted_messages,
            "generationConfig": {
                "temperature": config.temperature,
            }
        });
        if let Some(system) = system_instruction {
            body["systemInstruction"] = json!({ "parts": [{ "text": system }] });
        }
        if let Some(max_tokens) = config.max_tokens {
            body["generationConfig"]["maxOutputTokens"] = json!(max_tokens);
        }

        // `.json()` also sets `Content-Type: application/json`.
        let response = self
            .client
            .post(&url)
            .header("x-goog-api-key", self.api_key.as_str())
            .json(&body)
            .send()
            .await
            .map_err(|e| AgentError::ExecutionError(format!("Google API request failed: {}", e)))?;

        let status = response.status();
        if !status.is_success() {
            let error_text = response.text().await.unwrap_or_default();
            return Err(AgentError::ExecutionError(format!(
                "Google API error ({}): {}",
                status, error_text
            )));
        }

        let api_response: GoogleResponse = response
            .json()
            .await
            .map_err(|e| AgentError::ExecutionError(format!("Failed to parse Google response: {}", e)))?;

        let usage = api_response.usage_metadata.map(|u| Usage {
            prompt_tokens: u.prompt_token_count,
            completion_tokens: u.candidates_token_count,
            total_tokens: u.total_token_count,
        });

        // Take ownership of the first candidate so its fields can be moved, not cloned.
        let candidate = api_response
            .candidates
            .into_iter()
            .next()
            .ok_or_else(|| AgentError::ExecutionError("No candidates in Google response".to_string()))?;

        // A reply can be split across several text parts; returning only the first one
        // would silently truncate it.
        let content: String = candidate
            .content
            .parts
            .iter()
            .map(|p| p.text.as_str())
            .collect();

        Ok(ModelResponse {
            content,
            model: config.model.clone(),
            usage,
            finish_reason: candidate.finish_reason,
        })
    }

    /// Streaming is not supported by this provider yet.
    ///
    /// # Errors
    ///
    /// Always returns [`AgentError::ExecutionError`].
    async fn stream_complete(
        &self,
        _messages: Vec<Message>,
        _config: &ModelConfig,
    ) -> Result<Box<dyn futures::Stream<Item = Result<String>> + Unpin + Send>> {
        Err(AgentError::ExecutionError(
            "Streaming not yet implemented for Google".to_string(),
        ))
    }
}