rsclaw 0.0.1-alpha.1

rsclaw: High-performance AI agent (beta). Optimized for M4 Max and 2 GB VPS. 100% compatible with openclaw.
use super::traits::{ChatRequest, ChatResponse, FinishReason, LlmProvider, MessageRole, TokenUsage};
use anyhow::{Context, Result};
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::sync::Arc;

/// OpenAI API provider implementation.
pub struct OpenAiProvider {
    name: Arc<str>,
    api_key: Arc<str>,
    base_url: Arc<str>,
    client: Client,
}

impl OpenAiProvider {
    /// Create a new OpenAI provider.
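    ///
    /// A minimal construction sketch; the `"openai"` provider name and the
    /// `OPENAI_API_KEY` environment variable below are illustrative, not part
    /// of this crate's API.
    ///
    /// ```ignore
    /// let key = std::env::var("OPENAI_API_KEY").unwrap_or_default();
    /// let provider = OpenAiProvider::new(
    ///     Arc::from("openai"),
    ///     Arc::from(key.as_str()),
    ///     None, // defaults to https://api.openai.com/v1
    /// );
    /// ```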
    pub fn new(name: Arc<str>, api_key: Arc<str>, base_url: Option<Arc<str>>) -> Self {
        Self {
            name,
            api_key,
            base_url: base_url.unwrap_or_else(|| Arc::from("https://api.openai.com/v1")),
            client: Client::new(),
        }
    }
}

#[async_trait]
impl LlmProvider for OpenAiProvider {
    fn name(&self) -> &str {
        &self.name
    }

    fn provider_type(&self) -> &str {
        "openai"
    }

    async fn chat(&self, request: ChatRequest) -> Result<ChatResponse> {
        let url = format!("{}/chat/completions", self.base_url);

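        // Convert internal messages into the OpenAI wire-format roles and fields.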
        let messages: Vec<OpenAiMessage> = request
            .messages
            .into_iter()
            .map(|m| OpenAiMessage {
                role: match m.role {
                    MessageRole::System => "system".to_string(),
                    MessageRole::User => "user".to_string(),
                    MessageRole::Assistant => "assistant".to_string(),
                    MessageRole::Tool => "tool".to_string(),
                },
                content: m.content,
                name: m.name,
                tool_call_id: m.tool_call_id,
            })
            .collect();

        let body = OpenAiRequest {
            model: request.model,
            messages,
            max_tokens: request.max_tokens,
            temperature: request.temperature,
        };

        let response = self
            .client
            .post(&url)
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&body)
            .send()
            .await
            .context("Failed to send request to OpenAI")?;

        let status = response.status();
        if !status.is_success() {
            let error_text = response.text().await.unwrap_or_default();
            anyhow::bail!("OpenAI API error ({}): {}", status, error_text);
        }

        let api_response: OpenAiResponse = response
            .json()
            .await
            .context("Failed to parse OpenAI response")?;

        let choice = api_response
            .choices
            .into_iter()
            .next()
            .context("No choices in OpenAI response")?;

        Ok(ChatResponse {
            id: api_response.id,
            model: api_response.model,
            content: choice.message.content,
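            // Any finish reason this enum does not model falls back to Stop.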
            finish_reason: match choice.finish_reason.as_str() {
                "stop" => FinishReason::Stop,
                "length" => FinishReason::Length,
                "tool_calls" => FinishReason::ToolCalls,
                _ => FinishReason::Stop,
            },
            usage: TokenUsage {
                prompt_tokens: api_response.usage.prompt_tokens,
                completion_tokens: api_response.usage.completion_tokens,
                total_tokens: api_response.usage.total_tokens,
            },
            tool_calls: None,
        })
    }

    async fn is_available(&self) -> bool {
        if self.api_key.is_empty() {
            return false;
        }

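        // Probe the models endpoint as a lightweight availability check.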
        let url = format!("{}/models", self.base_url);
        let response = self
            .client
            .get(&url)
            .header("Authorization", format!("Bearer {}", self.api_key))
            .send()
            .await;

        match response {
            Ok(resp) => resp.status().is_success(),
            Err(_) => false,
        }
    }
}

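/// Request body for the OpenAI chat completions endpoint; only the fields this
/// provider actually sets are modeled.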
#[derive(Debug, Serialize)]
struct OpenAiRequest {
    model: String,
    messages: Vec<OpenAiMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
}

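/// A chat message in the OpenAI wire format, used both for the request payload
/// and for the message inside each response choice.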
#[derive(Debug, Serialize, Deserialize)]
struct OpenAiMessage {
    role: String,
    content: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    name: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
}

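/// The subset of the chat completions response that this provider consumes.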
#[derive(Debug, Deserialize)]
struct OpenAiResponse {
    id: String,
    model: String,
    choices: Vec<OpenAiChoice>,
    usage: OpenAiUsage,
}

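/// A single completion choice; only the first one is read.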
#[derive(Debug, Deserialize)]
struct OpenAiChoice {
    message: OpenAiMessage,
    finish_reason: String,
}

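/// Token accounting reported by the API.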
#[derive(Debug, Deserialize)]
struct OpenAiUsage {
    prompt_tokens: u32,
    completion_tokens: u32,
    total_tokens: u32,
}
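
// A minimal smoke-test sketch; it assumes `tokio` is available as a dev
// dependency to drive the async test runtime, which this file does not confirm.
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn is_available_is_false_without_api_key() {
        // An empty key short-circuits to `false` before any network request is made.
        let provider = OpenAiProvider::new(Arc::from("openai"), Arc::from(""), None);
        assert!(!provider.is_available().await);
    }
}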