// st/proxy/grok.rs
//! 🤖 Grok Provider Implementation (X.AI)
//!
//! "Elon's AI enters the chat!" - The Cheet 😺
//!
//! Grok uses an OpenAI-compatible API at https://api.x.ai/v1

use crate::proxy::{LlmMessage, LlmProvider, LlmRequest, LlmResponse, LlmRole, LlmUsage};
use anyhow::{Context, Result};
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};

/// LLM provider backed by X.AI's Grok chat-completions endpoint.
///
/// Grok exposes an OpenAI-compatible API, so request/response shapes below
/// mirror the OpenAI wire format.
pub struct GrokProvider {
    /// Shared reqwest HTTP client (connection pooling across requests).
    client: Client,
    /// Bearer token sent in the `Authorization` header.
    api_key: String,
    /// API root, e.g. `https://api.x.ai/v1`; overridable for tests/proxies.
    base_url: String,
}
18
19impl GrokProvider {
20    pub fn new(api_key: String) -> Self {
21        Self {
22            client: Client::new(),
23            api_key,
24            base_url: "https://api.x.ai/v1".to_string(),
25        }
26    }
27
28    /// Create with custom base URL (for testing or proxies)
29    pub fn with_base_url(api_key: String, base_url: String) -> Self {
30        Self {
31            client: Client::new(),
32            api_key,
33            base_url,
34        }
35    }
36}
37
38impl Default for GrokProvider {
39    fn default() -> Self {
40        let api_key = std::env::var("XAI_API_KEY")
41            .or_else(|_| std::env::var("GROK_API_KEY"))
42            .unwrap_or_default();
43        Self::new(api_key)
44    }
45}
46
47#[async_trait]
48impl LlmProvider for GrokProvider {
49    async fn complete(&self, request: LlmRequest) -> Result<LlmResponse> {
50        let url = format!("{}/chat/completions", self.base_url);
51
52        // Default to grok-beta if no model specified
53        let model = if request.model.is_empty() || request.model == "default" {
54            "grok-beta".to_string()
55        } else {
56            request.model.clone()
57        };
58
59        let grok_request = GrokChatRequest {
60            model,
61            messages: request.messages.into_iter().map(Into::into).collect(),
62            temperature: request.temperature,
63            max_tokens: request.max_tokens,
64            stream: request.stream,
65        };
66
67        let response = self
68            .client
69            .post(&url)
70            .header("Authorization", format!("Bearer {}", self.api_key))
71            .header("Content-Type", "application/json")
72            .json(&grok_request)
73            .send()
74            .await
75            .context("Failed to send request to Grok API")?;
76
77        if !response.status().is_success() {
78            let error_text = response.text().await?;
79            return Err(anyhow::anyhow!("Grok API error: {}", error_text));
80        }
81
82        let grok_response: GrokChatResponse = response.json().await?;
83
84        let content = grok_response
85            .choices
86            .first()
87            .map(|c| c.message.content.clone())
88            .unwrap_or_default();
89
90        Ok(LlmResponse {
91            content,
92            model: grok_response.model,
93            usage: grok_response.usage.map(Into::into),
94        })
95    }
96
97    fn name(&self) -> &'static str {
98        "Grok"
99    }
100}
101
/// Outgoing request body for `POST /chat/completions` (OpenAI-compatible).
#[derive(Debug, Serialize)]
struct GrokChatRequest {
    model: String,
    messages: Vec<GrokMessage>,
    // Optional knobs are omitted from the JSON entirely when unset,
    // letting the API apply its own defaults.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<usize>,
    stream: bool,
}
112
/// Single chat message on the wire; used in both requests and responses.
#[derive(Debug, Serialize, Deserialize)]
struct GrokMessage {
    /// One of "system" / "user" / "assistant".
    role: String,
    content: String,
}
118
119impl From<LlmMessage> for GrokMessage {
120    fn from(msg: LlmMessage) -> Self {
121        Self {
122            role: match msg.role {
123                LlmRole::System => "system".to_string(),
124                LlmRole::User => "user".to_string(),
125                LlmRole::Assistant => "assistant".to_string(),
126            },
127            content: msg.content,
128        }
129    }
130}
131
/// Successful response body from `/chat/completions`.
/// Only the fields this provider consumes are modeled; extra JSON keys are ignored.
#[derive(Debug, Deserialize)]
struct GrokChatResponse {
    model: String,
    choices: Vec<GrokChoice>,
    /// Token accounting; optional because streaming responses may omit it.
    usage: Option<GrokUsage>,
}
138
/// One completion alternative; only the message content is used downstream.
#[derive(Debug, Deserialize)]
struct GrokChoice {
    message: GrokMessage,
}
143
/// Token-usage block reported by the API (OpenAI-compatible field names).
#[derive(Debug, Deserialize)]
struct GrokUsage {
    prompt_tokens: usize,
    completion_tokens: usize,
    total_tokens: usize,
}
150
151impl From<GrokUsage> for LlmUsage {
152    fn from(usage: GrokUsage) -> Self {
153        Self {
154            prompt_tokens: usage.prompt_tokens,
155            completion_tokens: usage.completion_tokens,
156            total_tokens: usage.total_tokens,
157        }
158    }
159}