//! Google Gemini provider for the rsclaw agent.
use super::traits::{
    ChatRequest, ChatResponse, FinishReason, LlmProvider, MessageRole, TokenUsage,
};
use anyhow::{Context, Result};
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::sync::Arc;

/// Google Gemini API provider implementation.
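///
/// Sends chat requests to the `models/{model}:generateContent` REST endpoint
/// and maps the response into the provider-agnostic `ChatResponse` type.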
pub struct GeminiProvider {
    name: Arc<str>,
    api_key: Arc<str>,
    base_url: Arc<str>,
    client: Client,
}

impl GeminiProvider {
    /// Create a new Gemini provider.
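    ///
    /// # Example
    ///
    /// A minimal construction sketch; the module path and the
    /// `GEMINI_API_KEY` environment variable name are assumptions:
    ///
    /// ```no_run
    /// # use std::sync::Arc;
    /// # use rsclaw::providers::gemini::GeminiProvider;
    /// let api_key = std::env::var("GEMINI_API_KEY").expect("GEMINI_API_KEY not set");
    /// let provider = GeminiProvider::new(
    ///     Arc::from("gemini"),
    ///     Arc::from(api_key),
    ///     None, // use the default public endpoint
    /// );
    /// ```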
    pub fn new(name: Arc<str>, api_key: Arc<str>, base_url: Option<Arc<str>>) -> Self {
        Self {
            name,
            api_key,
            base_url: base_url.unwrap_or_else(|| {
                Arc::from("https://generativelanguage.googleapis.com/v1beta")
            }),
            client: Client::new(),
        }
    }
}

#[async_trait]
impl LlmProvider for GeminiProvider {
    fn name(&self) -> &str {
        &self.name
    }

    fn provider_type(&self) -> &str {
        "gemini"
    }

    async fn chat(&self, request: ChatRequest) -> Result<ChatResponse> {
        let url = format!(
            "{}/models/{}:generateContent",
            self.base_url, request.model
        );

        // Convert messages to Gemini format
        let contents: Vec<GeminiContent> = request
            .messages
            .iter()
            .filter(|m| m.role != MessageRole::System)
            .map(|m| GeminiContent {
                role: match m.role {
                    // Gemini has no dedicated tool role in this text-only
                    // mapping, so tool results are sent as user turns.
                    MessageRole::User | MessageRole::Tool => "user".to_string(),
                    MessageRole::Assistant => "model".to_string(),
                    // Unreachable: system messages are filtered out above and
                    // delivered via `system_instruction` instead.
                    MessageRole::System => "user".to_string(),
                },
                parts: vec![GeminiPart {
                    text: m.content.clone(),
                }],
            })
            .collect();

        // Extract system instruction
        let system_instruction = request
            .messages
            .iter()
            .find(|m| m.role == MessageRole::System)
            .map(|m| GeminiSystemInstruction {
                parts: vec![GeminiPart {
                    text: m.content.clone(),
                }],
            });

        let body = GeminiRequest {
            contents,
            generation_config: GeminiGenerationConfig {
                max_output_tokens: request.max_tokens,
                temperature: request.temperature,
            },
            system_instruction,
        };

        let response = self
            .client
            .post(&url)
            // Send the key as a header rather than a query parameter so it
            // cannot leak through logged URLs or error messages; `.json()`
            // already sets the Content-Type header.
            .header("x-goog-api-key", self.api_key.as_ref())
            .json(&body)
            .send()
            .await
            .context("Failed to send request to Gemini")?;

        let status = response.status();
        if !status.is_success() {
            let error_text = response.text().await.unwrap_or_default();
            anyhow::bail!("Gemini API error ({status}): {error_text}");
        }

        let api_response: GeminiResponse = response
            .json()
            .await
            .context("Failed to parse Gemini response")?;

        let candidate = api_response
            .candidates
            .into_iter()
            .next()
            .context("No candidates in Gemini response")?;

        let content = candidate
            .content
            .parts
            .into_iter()
            .map(|p| p.text)
            .collect::<Vec<String>>()
            .join("");

        let finish_reason = match candidate.finish_reason.as_deref() {
            Some("STOP") => FinishReason::Stop,
            Some("MAX_TOKENS") => FinishReason::Length,
            // Other reasons (e.g. "SAFETY", "RECITATION") have no direct
            // counterpart here, so fall back to Stop.
            _ => FinishReason::Stop,
        };

        Ok(ChatResponse {
            id: uuid::Uuid::new_v4().to_string(),
            model: request.model,
            content,
            finish_reason,
            usage: TokenUsage {
                prompt_tokens: api_response.usage_metadata.prompt_token_count,
                completion_tokens: api_response.usage_metadata.candidates_token_count,
                total_tokens: api_response.usage_metadata.total_token_count,
            },
            tool_calls: None,
        })
    }

    async fn is_available(&self) -> bool {
        if self.api_key.is_empty() {
            return false;
        }

        // Cheap availability probe: list models using the configured key.
        let url = format!("{}/models", self.base_url);
        let response = self
            .client
            .get(&url)
            .header("x-goog-api-key", self.api_key.as_ref())
            .send()
            .await;

        match response {
            Ok(resp) => resp.status().is_success(),
            Err(_) => false,
        }
    }
}

#[derive(Debug, Serialize)]
// The Gemini REST API uses lowerCamelCase JSON field names on the wire.
#[serde(rename_all = "camelCase")]
struct GeminiRequest {
    contents: Vec<GeminiContent>,
    generation_config: GeminiGenerationConfig,
    #[serde(skip_serializing_if = "Option::is_none")]
    system_instruction: Option<GeminiSystemInstruction>,
}

#[derive(Debug, Serialize, Deserialize)]
struct GeminiContent {
    role: String,
    parts: Vec<GeminiPart>,
}

#[derive(Debug, Serialize, Deserialize)]
struct GeminiPart {
    text: String,
}

#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct GeminiGenerationConfig {
    #[serde(skip_serializing_if = "Option::is_none")]
    max_output_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
}

#[derive(Debug, Serialize)]
struct GeminiSystemInstruction {
    parts: Vec<GeminiPart>,
}

#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GeminiResponse {
    candidates: Vec<GeminiCandidate>,
    usage_metadata: GeminiUsageMetadata,
}

#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GeminiCandidate {
    content: GeminiContent,
    finish_reason: Option<String>,
}

#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GeminiUsageMetadata {
    prompt_token_count: u32,
    candidates_token_count: u32,
    total_token_count: u32,
}
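
// A minimal round-trip sketch of the wire format, assuming `serde_json` is
// available as a dev-dependency. The payloads below are illustrative, not
// captured from the live API.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn request_serializes_to_camel_case() {
        let body = GeminiRequest {
            contents: vec![GeminiContent {
                role: "user".to_string(),
                parts: vec![GeminiPart {
                    text: "hello".to_string(),
                }],
            }],
            generation_config: GeminiGenerationConfig {
                max_output_tokens: Some(64),
                temperature: Some(0.2),
            },
            system_instruction: None,
        };
        let json = serde_json::to_string(&body).expect("request should serialize");
        // The Gemini REST API expects lowerCamelCase field names.
        assert!(json.contains("generationConfig"));
        assert!(json.contains("maxOutputTokens"));
    }

    #[test]
    fn response_deserializes_camel_case_fields() {
        let json = r#"{
            "candidates": [{
                "content": {"role": "model", "parts": [{"text": "hi"}]},
                "finishReason": "STOP"
            }],
            "usageMetadata": {
                "promptTokenCount": 3,
                "candidatesTokenCount": 1,
                "totalTokenCount": 4
            }
        }"#;
        let parsed: GeminiResponse = serde_json::from_str(json).expect("response should parse");
        assert_eq!(parsed.usage_metadata.total_token_count, 4);
        assert_eq!(parsed.candidates[0].finish_reason.as_deref(), Some("STOP"));
    }
}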