agent 0.0.1

A flexible AI Agent SDK for building intelligent agents
Documentation
use crate::error::{AgentError, Result};
use crate::message::{Message, MessageRole};
use crate::provider::{ModelConfig, ModelProvider, ModelResponse, Usage};
use async_trait::async_trait;
use serde::Deserialize;
use serde_json::json;

/// [`ModelProvider`] backed by Google's Generative Language `generateContent` REST endpoint.
pub struct GoogleProvider {
    /// HTTP client used to issue requests.
    client: reqwest::Client,
    /// API root that request paths (`/models/...`) are appended to.
    base_url: String,
    /// Credential attached to each request.
    api_key: String,
}

impl GoogleProvider {
    /// Default API root of the Generative Language REST API.
    const DEFAULT_BASE_URL: &'static str = "https://generativelanguage.googleapis.com/v1beta";

    /// Creates a provider that authenticates with `api_key` against the `v1beta`
    /// Generative Language endpoint.
    pub fn new(api_key: impl Into<String>) -> Self {
        Self {
            api_key: api_key.into(),
            base_url: Self::DEFAULT_BASE_URL.to_string(),
            client: reqwest::Client::new(),
        }
    }

    /// Overrides the API root (for example a different API version, a proxy, or a mock
    /// server).
    ///
    /// A trailing `/` is stripped because request URLs are built as
    /// `{base_url}/models/...`.
    #[must_use]
    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
        let base_url = base_url.into();
        self.base_url = base_url.trim_end_matches('/').to_string();
        self
    }

    /// Converts SDK messages into the request shape: an optional system instruction plus
    /// a list of `{ "role", "parts" }` contents.
    ///
    /// * All `System` messages are joined (blank-line separated) into one system
    ///   instruction; previously only the last one survived.
    /// * `User` and `Tool` messages map to role `"user"`, `Assistant` to `"model"`.
    ///   Tool output is sent as ordinary text, not as a structured function response.
    /// * Consecutive messages that map to the same role are merged into a single turn
    ///   with several text parts, so e.g. a tool result followed by a user message does
    ///   not produce two back-to-back `"user"` turns.
    fn convert_messages(&self, messages: Vec<Message>) -> (Option<String>, Vec<serde_json::Value>) {
        let mut system_texts: Vec<String> = Vec::new();
        // (role, parts) for each turn, built up before being rendered as JSON.
        let mut turns: Vec<(&'static str, Vec<serde_json::Value>)> = Vec::new();

        for msg in messages {
            let role = match msg.role {
                MessageRole::System => {
                    system_texts.push(msg.content);
                    continue;
                }
                MessageRole::User | MessageRole::Tool => "user",
                MessageRole::Assistant => "model",
            };
            let part = json!({ "text": msg.content });

            match turns.last_mut() {
                Some((last_role, parts)) if *last_role == role => parts.push(part),
                _ => turns.push((role, vec![part])),
            }
        }

        let system_instruction = (!system_texts.is_empty()).then(|| system_texts.join("\n\n"));
        let converted = turns
            .into_iter()
            .map(|(role, parts)| json!({ "role": role, "parts": parts }))
            .collect();

        (system_instruction, converted)
    }
}

/// Body of a successful `generateContent` response (only the fields this provider uses).
///
/// Deserialization is deliberately lenient: fields the API may leave out are defaulted
/// rather than failing the whole parse. Otherwise a well-formed but sparse response
/// (no candidates, empty content, non-text parts, partial usage) would surface as a
/// "Failed to parse Google response" error instead of being handled downstream.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GoogleResponse {
    /// Generated candidates; an absent field deserializes as an empty list.
    #[serde(default)]
    candidates: Vec<Candidate>,
    /// Token accounting (`usageMetadata`), if present.
    usage_metadata: Option<GoogleUsage>,
}

/// One generated alternative.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct Candidate {
    /// Generated content; an absent field deserializes as empty content.
    #[serde(default)]
    content: Content,
    /// Raw `finishReason` string, passed through unchanged.
    finish_reason: Option<String>,
}

/// Content of a candidate: an ordered list of parts.
#[derive(Debug, Default, Deserialize)]
struct Content {
    /// Parts in order; an absent field deserializes as an empty list.
    #[serde(default)]
    parts: Vec<Part>,
}

/// A single content part. Only text is extracted.
#[derive(Debug, Deserialize)]
struct Part {
    /// Text of the part; parts without a `text` field deserialize as `""`.
    #[serde(default)]
    text: String,
}

/// Token counts from `usageMetadata`; any missing count defaults to 0.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GoogleUsage {
    /// `promptTokenCount`.
    #[serde(default)]
    prompt_token_count: usize,
    /// `candidatesTokenCount`.
    #[serde(default)]
    candidates_token_count: usize,
    /// `totalTokenCount`.
    #[serde(default)]
    total_token_count: usize,
}

/// [`ModelProvider`] implementation for Google's `generateContent` endpoint.
/// Only non-streaming completion is supported.
#[async_trait]
impl ModelProvider for GoogleProvider {
    fn name(&self) -> &str {
        "google"
    }

    /// Sends `messages` to `models/{model}:generateContent` and returns the first
    /// candidate, with all of its text parts concatenated into `content`.
    ///
    /// `config.model` may be a bare id (`gemini-1.5-pro`) or a resource name already
    /// prefixed with `models/`.
    ///
    /// # Errors
    /// Returns [`AgentError::ExecutionError`] when the request cannot be sent, the API
    /// answers with a non-success status (the message includes the status and body), the
    /// body cannot be parsed, or the response contains no candidates.
    async fn complete(&self, messages: Vec<Message>, config: &ModelConfig) -> Result<ModelResponse> {
        // Avoid building ".../models/models/..." when a fully-qualified name is passed.
        let model = config.model.strip_prefix("models/").unwrap_or(&config.model);
        let url = format!("{}/models/{}:generateContent", self.base_url, model);

        let (system_instruction, converted_messages) = self.convert_messages(messages);

        let mut body = json!({
            "contents": converted_messages,
            "generationConfig": {
                "temperature": config.temperature,
            }
        });

        if let Some(system) = system_instruction {
            body["systemInstruction"] = json!({ "parts": [{ "text": system }] });
        }

        if let Some(max_tokens) = config.max_tokens {
            body["generationConfig"]["maxOutputTokens"] = json!(max_tokens);
        }

        // The key travels in a header instead of a `?key=` query parameter: reqwest
        // includes the request URL in its error messages, so a query-string key would end
        // up inside the `AgentError` below and from there in logs.
        // `.json()` already sets `Content-Type: application/json`.
        let response = self
            .client
            .post(&url)
            .header("x-goog-api-key", self.api_key.as_str())
            .json(&body)
            .send()
            .await
            .map_err(|e| AgentError::ExecutionError(format!("Google API request failed: {}", e)))?;

        let status = response.status();
        if !status.is_success() {
            // Best effort: an unreadable body still yields an error carrying the status.
            let error_text = response.text().await.unwrap_or_default();
            return Err(AgentError::ExecutionError(format!(
                "Google API error ({}): {}",
                status, error_text
            )));
        }

        let api_response: GoogleResponse = response.json().await.map_err(|e| {
            AgentError::ExecutionError(format!("Failed to parse Google response: {}", e))
        })?;

        let candidate = api_response
            .candidates
            .into_iter()
            .next()
            .ok_or_else(|| AgentError::ExecutionError("No candidates in Google response".to_string()))?;

        // A reply may be split across several text parts; reading only the first one
        // would silently truncate it.
        let content: String = candidate
            .content
            .parts
            .into_iter()
            .map(|part| part.text)
            .collect();

        let usage = api_response.usage_metadata.map(|u| Usage {
            prompt_tokens: u.prompt_token_count,
            completion_tokens: u.candidates_token_count,
            total_tokens: u.total_token_count,
        });

        Ok(ModelResponse {
            content,
            model: config.model.clone(),
            usage,
            finish_reason: candidate.finish_reason,
        })
    }

    /// Streaming is not implemented for this provider.
    ///
    /// # Errors
    /// Always returns [`AgentError::ExecutionError`].
    async fn stream_complete(
        &self,
        _messages: Vec<Message>,
        _config: &ModelConfig,
    ) -> Result<Box<dyn futures::Stream<Item = Result<String>> + Unpin + Send>> {
        Err(AgentError::ExecutionError(
            "Streaming not yet implemented for Google".to_string(),
        ))
    }
}