// aster/providers/xai.rs — xAI (Grok) provider implementation.
// Shared provider plumbing: HTTP client + auth, provider error type,
// retry policy, and the OpenAI-compatible request/response helpers that
// this provider reuses (xAI exposes an OpenAI-compatible API).
use super::api_client::{ApiClient, AuthMethod};
use super::errors::ProviderError;
use super::retry::ProviderRetry;
use super::utils::{
    get_model, handle_response_openai_compat, handle_status_openai_compat, stream_openai_compat,
    RequestLog,
};
use crate::conversation::message::Message;
use crate::model::ModelConfig;
use crate::providers::base::{
    ConfigKey, MessageStream, Provider, ProviderMetadata, ProviderUsage, Usage,
};
use crate::providers::formats::openai::{create_request, get_usage, response_to_message};
use anyhow::Result;
use async_trait::async_trait;
use rmcp::model::Tool;
use serde_json::Value;
/// Default base URL for xAI's OpenAI-compatible API (versioned path).
/// Overridable via the `XAI_HOST` config key.
pub const XAI_API_HOST: &str = "https://api.x.ai/v1";

/// Model used when the caller does not configure one explicitly.
pub const XAI_DEFAULT_MODEL: &str = "grok-code-fast-1";

/// Model identifiers known to this provider; passed into
/// [`ProviderMetadata`] (presumably surfaced as suggestions in
/// configuration UIs — confirm against metadata consumers).
pub const XAI_KNOWN_MODELS: &[&str] = &[
    "grok-code-fast-1",
    "grok-4-0709",
    "grok-3",
    "grok-3-fast",
    "grok-3-mini",
    "grok-3-mini-fast",
    "grok-2-vision-1212",
    "grok-2-image-1212",
    "grok-3-latest",
    "grok-3-fast-latest",
    "grok-3-mini-latest",
    "grok-3-mini-fast-latest",
    "grok-2-vision",
    "grok-2-vision-latest",
    "grok-2-image",
    "grok-2-image-latest",
    "grok-2",
    "grok-2-latest",
];
42
/// Provider for xAI's Grok models, speaking the OpenAI-compatible
/// chat-completions API rooted at [`XAI_API_HOST`] (or the configured host).
#[derive(serde::Serialize)]
pub struct XaiProvider {
    // HTTP client carrying the base URL and bearer-token auth;
    // excluded from serialization.
    #[serde(skip)]
    api_client: ApiClient,
    // Model configuration this provider instance was created with.
    model: ModelConfig,
    // Set to `true` in `from_env`; reported by `supports_streaming()`.
    supports_streaming: bool,
    // Display name taken from `Self::metadata().name`; excluded from
    // serialization.
    #[serde(skip)]
    name: String,
}
52
53impl XaiProvider {
54    pub async fn from_env(model: ModelConfig) -> Result<Self> {
55        let config = crate::config::Config::global();
56        let api_key: String = config.get_secret("XAI_API_KEY")?;
57        let host: String = config
58            .get_param("XAI_HOST")
59            .unwrap_or_else(|_| XAI_API_HOST.to_string());
60
61        let auth = AuthMethod::BearerToken(api_key);
62        let api_client = ApiClient::new(host, auth)?;
63
64        Ok(Self {
65            api_client,
66            model,
67            supports_streaming: true,
68            name: Self::metadata().name,
69        })
70    }
71
72    async fn post(&self, payload: Value) -> Result<Value, ProviderError> {
73        let response = self
74            .api_client
75            .response_post("chat/completions", &payload)
76            .await?;
77
78        handle_response_openai_compat(response).await
79    }
80}
81
82#[async_trait]
83impl Provider for XaiProvider {
84    fn metadata() -> ProviderMetadata {
85        ProviderMetadata::new(
86            "xai",
87            "xAI",
88            "Grok models from xAI, including reasoning and multimodal capabilities",
89            XAI_DEFAULT_MODEL,
90            XAI_KNOWN_MODELS.to_vec(),
91            XAI_DOC_URL,
92            vec![
93                ConfigKey::new("XAI_API_KEY", true, true, None),
94                ConfigKey::new("XAI_HOST", false, false, Some(XAI_API_HOST)),
95            ],
96        )
97    }
98
99    fn get_name(&self) -> &str {
100        &self.name
101    }
102
103    fn get_model_config(&self) -> ModelConfig {
104        self.model.clone()
105    }
106
107    #[tracing::instrument(
108        skip(self, model_config, system, messages, tools),
109        fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
110    )]
111    async fn complete_with_model(
112        &self,
113        model_config: &ModelConfig,
114        system: &str,
115        messages: &[Message],
116        tools: &[Tool],
117    ) -> Result<(Message, ProviderUsage), ProviderError> {
118        let payload = create_request(
119            model_config,
120            system,
121            messages,
122            tools,
123            &super::utils::ImageFormat::OpenAi,
124            false,
125        )?;
126
127        let mut log = RequestLog::start(&self.model, &payload)?;
128        let response = self.with_retry(|| self.post(payload.clone())).await?;
129
130        let message = response_to_message(&response)?;
131        let usage = response.get("usage").map(get_usage).unwrap_or_else(|| {
132            tracing::debug!("Failed to get usage data");
133            Usage::default()
134        });
135        let response_model = get_model(&response);
136        log.write(&response, Some(&usage))?;
137        Ok((message, ProviderUsage::new(response_model, usage)))
138    }
139
140    fn supports_streaming(&self) -> bool {
141        self.supports_streaming
142    }
143
144    async fn stream(
145        &self,
146        system: &str,
147        messages: &[Message],
148        tools: &[Tool],
149    ) -> Result<MessageStream, ProviderError> {
150        let payload = create_request(
151            &self.model,
152            system,
153            messages,
154            tools,
155            &super::utils::ImageFormat::OpenAi,
156            true,
157        )?;
158        let mut log = RequestLog::start(&self.model, &payload)?;
159
160        let response = self
161            .with_retry(|| async {
162                let resp = self
163                    .api_client
164                    .response_post("chat/completions", &payload)
165                    .await?;
166                handle_status_openai_compat(resp).await
167            })
168            .await
169            .inspect_err(|e| {
170                let _ = log.error(e);
171            })?;
172
173        stream_openai_compat(response, log)
174    }
175}