//! xAI model provider (agent 0.0.1).
//!
//! Part of a flexible AI Agent SDK for building intelligent agents.
//! See the crate-level documentation for details.
use crate::error::{AgentError, Result};
use crate::message::{Message, MessageRole};
use crate::provider::{ModelConfig, ModelProvider, ModelResponse, Usage};
use async_trait::async_trait;
use serde::Deserialize;
use serde_json::json;

/// Model provider backed by the xAI (Grok) chat-completions HTTP API.
pub struct XAIProvider {
    // Bearer token sent in the `Authorization` header on every request.
    api_key: String,
    // API root; `new` sets this to `https://api.x.ai/v1`.
    base_url: String,
    // Shared reqwest client, reused across requests for connection pooling.
    client: reqwest::Client,
}

impl XAIProvider {
    /// Creates a provider that talks to the hosted xAI endpoint
    /// (`https://api.x.ai/v1`) using the given API key.
    pub fn new(api_key: impl Into<String>) -> Self {
        let client = reqwest::Client::new();
        Self {
            api_key: api_key.into(),
            base_url: String::from("https://api.x.ai/v1"),
            client,
        }
    }

    /// Converts SDK messages into the JSON objects the xAI
    /// chat-completions endpoint expects: `{"role": ..., "content": ...}`.
    fn convert_messages(&self, messages: Vec<Message>) -> Vec<serde_json::Value> {
        let mut converted = Vec::with_capacity(messages.len());
        for msg in messages {
            // Map the SDK role enum onto the API's wire-format role strings.
            let role = match msg.role {
                MessageRole::System => "system",
                MessageRole::User => "user",
                MessageRole::Assistant => "assistant",
                MessageRole::Tool => "tool",
            };
            converted.push(json!({
                "role": role,
                "content": msg.content
            }));
        }
        converted
    }
}

/// Top-level shape of a successful xAI chat-completions response.
#[derive(Debug, Deserialize)]
struct XAIResponse {
    // Candidate completions; only the first is consumed by `complete`.
    choices: Vec<Choice>,
    // Token accounting returned by the API.
    usage: XAIUsage,
    // Model identifier the server actually used.
    model: String,
}

/// One candidate completion inside an `XAIResponse`.
#[derive(Debug, Deserialize)]
struct Choice {
    // The assistant message produced for this choice.
    message: MessageContent,
    // Why generation stopped (e.g. "stop", "length"); may be absent.
    finish_reason: Option<String>,
}

/// Message payload within a `Choice`.
#[derive(Debug, Deserialize)]
struct MessageContent {
    // NOTE(review): OpenAI-compatible APIs can return `content: null`
    // (e.g. for tool calls); a null here would fail deserialization.
    // Verify against live xAI responses — `Option<String>` may be safer.
    content: String,
}

/// Token-usage statistics reported by the xAI API.
#[derive(Debug, Deserialize)]
struct XAIUsage {
    // Tokens consumed by the input messages.
    prompt_tokens: usize,
    // Tokens generated in the completion.
    completion_tokens: usize,
    // prompt_tokens + completion_tokens, as reported by the server.
    total_tokens: usize,
}

#[async_trait]
impl ModelProvider for XAIProvider {
    /// Stable provider identifier for registry lookup and logging.
    fn name(&self) -> &str {
        "xai"
    }

    /// Sends a chat-completion request to the xAI API and returns the
    /// first choice as a `ModelResponse`.
    ///
    /// # Errors
    ///
    /// Returns `AgentError::ExecutionError` when the HTTP request fails,
    /// the API replies with a non-success status, the response body cannot
    /// be parsed, or the response contains no choices.
    async fn complete(&self, messages: Vec<Message>, config: &ModelConfig) -> Result<ModelResponse> {
        let url = format!("{}/chat/completions", self.base_url);
        let converted_messages = self.convert_messages(messages);

        let mut body = json!({
            "model": config.model,
            "messages": converted_messages,
            "temperature": config.temperature,
        });

        // Optional sampling parameters are added only when configured so
        // the API's server-side defaults are not overridden with nulls.
        if let Some(max_tokens) = config.max_tokens {
            body["max_tokens"] = json!(max_tokens);
        }
        if let Some(top_p) = config.top_p {
            body["top_p"] = json!(top_p);
        }

        let response = self
            .client
            .post(&url)
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&body)
            .send()
            .await
            .map_err(|e| AgentError::ExecutionError(format!("xAI API request failed: {}", e)))?;

        if !response.status().is_success() {
            // Capture the status before `text()` consumes the response, so
            // the error stays diagnosable even when the body is empty.
            // (Previously only the body was reported, which could be blank.)
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();
            return Err(AgentError::ExecutionError(format!(
                "xAI API error ({}): {}",
                status, error_text
            )));
        }

        let api_response: XAIResponse = response
            .json()
            .await
            .map_err(|e| AgentError::ExecutionError(format!("Failed to parse xAI response: {}", e)))?;

        // Take ownership of the first choice instead of borrowing it, which
        // avoids cloning `content` and `finish_reason` below.
        let choice = api_response
            .choices
            .into_iter()
            .next()
            .ok_or_else(|| AgentError::ExecutionError("No choices in xAI response".to_string()))?;

        Ok(ModelResponse {
            content: choice.message.content,
            model: api_response.model,
            usage: Some(Usage {
                prompt_tokens: api_response.usage.prompt_tokens,
                completion_tokens: api_response.usage.completion_tokens,
                total_tokens: api_response.usage.total_tokens,
            }),
            finish_reason: choice.finish_reason,
        })
    }

    /// Streaming is not implemented for this provider; always errors.
    async fn stream_complete(
        &self,
        _messages: Vec<Message>,
        _config: &ModelConfig,
    ) -> Result<Box<dyn futures::Stream<Item = Result<String>> + Unpin + Send>> {
        Err(AgentError::ExecutionError(
            "Streaming not yet implemented for xAI".to_string(),
        ))
    }
}