use super::api_client::{ApiClient, AuthMethod};
use super::errors::ProviderError;
use super::retry::ProviderRetry;
use super::utils::{
    get_model, handle_response_openai_compat, handle_status_openai_compat, stream_openai_compat,
    RequestLog,
};
use crate::conversation::message::Message;
use crate::model::ModelConfig;
use crate::providers::base::{
    ConfigKey, MessageStream, Provider, ProviderMetadata, ProviderUsage, Usage,
};
use crate::providers::formats::openai::{create_request, get_usage, response_to_message};
use anyhow::Result;
use async_trait::async_trait;
use rmcp::model::Tool;
use serde_json::Value;
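/// Default xAI API host, used when the `XAI_HOST` config override is not set.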
pub const XAI_API_HOST: &str = "https://api.x.ai/v1";
pub const XAI_DEFAULT_MODEL: &str = "grok-code-fast-1";
pub const XAI_KNOWN_MODELS: &[&str] = &[
    "grok-code-fast-1",
    "grok-4-0709",
    "grok-3",
    "grok-3-fast",
    "grok-3-mini",
    "grok-3-mini-fast",
    "grok-2-vision-1212",
    "grok-2-image-1212",
    "grok-3-latest",
    "grok-3-fast-latest",
    "grok-3-mini-latest",
    "grok-3-mini-fast-latest",
    "grok-2-vision",
    "grok-2-vision-latest",
    "grok-2-image",
    "grok-2-image-latest",
    "grok-2",
    "grok-2-latest",
];

pub const XAI_DOC_URL: &str = "https://docs.x.ai/docs/overview";

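/// Provider for xAI's Grok models, accessed through xAI's OpenAI-compatible
/// `chat/completions` API.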
#[derive(serde::Serialize)]
pub struct XaiProvider {
    #[serde(skip)]
    api_client: ApiClient,
    model: ModelConfig,
    supports_streaming: bool,
    #[serde(skip)]
    name: String,
}

impl XaiProvider {
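    /// Builds a provider from global config: reads the `XAI_API_KEY` secret and an
    /// optional `XAI_HOST` override, falling back to [`XAI_API_HOST`].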
    pub async fn from_env(model: ModelConfig) -> Result<Self> {
        let config = crate::config::Config::global();
        let api_key: String = config.get_secret("XAI_API_KEY")?;
        let host: String = config
            .get_param("XAI_HOST")
            .unwrap_or_else(|_| XAI_API_HOST.to_string());

        let auth = AuthMethod::BearerToken(api_key);
        let api_client = ApiClient::new(host, auth)?;

        Ok(Self {
            api_client,
            model,
            supports_streaming: true,
            name: Self::metadata().name,
        })
    }

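    /// Posts a chat completion payload to `chat/completions` and maps the
    /// OpenAI-compatible response into a `Value`, surfacing API errors as `ProviderError`.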
    async fn post(&self, payload: Value) -> Result<Value, ProviderError> {
        let response = self
            .api_client
            .response_post("chat/completions", &payload)
            .await?;

        handle_response_openai_compat(response).await
    }
}

#[async_trait]
impl Provider for XaiProvider {
    fn metadata() -> ProviderMetadata {
        ProviderMetadata::new(
            "xai",
            "xAI",
            "Grok models from xAI, including reasoning and multimodal capabilities",
            XAI_DEFAULT_MODEL,
            XAI_KNOWN_MODELS.to_vec(),
            XAI_DOC_URL,
            vec![
                ConfigKey::new("XAI_API_KEY", true, true, None),
                ConfigKey::new("XAI_HOST", false, false, Some(XAI_API_HOST)),
            ],
        )
    }

    fn get_name(&self) -> &str {
        &self.name
    }

    fn get_model_config(&self) -> ModelConfig {
        self.model.clone()
    }

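    /// Sends a non-streaming chat completion request with retry, then extracts the
    /// assistant message and token usage from the response.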
    #[tracing::instrument(
        skip(self, model_config, system, messages, tools),
        fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
    )]
    async fn complete_with_model(
        &self,
        model_config: &ModelConfig,
        system: &str,
        messages: &[Message],
        tools: &[Tool],
    ) -> Result<(Message, ProviderUsage), ProviderError> {
        let payload = create_request(
            model_config,
            system,
            messages,
            tools,
            &super::utils::ImageFormat::OpenAi,
            false,
        )?;

        let mut log = RequestLog::start(&self.model, &payload)?;
        let response = self.with_retry(|| self.post(payload.clone())).await?;

        let message = response_to_message(&response)?;
        let usage = response.get("usage").map(get_usage).unwrap_or_else(|| {
            tracing::debug!("Failed to get usage data");
            Usage::default()
        });
        let response_model = get_model(&response);
        log.write(&response, Some(&usage))?;
        Ok((message, ProviderUsage::new(response_model, usage)))
    }

    fn supports_streaming(&self) -> bool {
        self.supports_streaming
    }

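    /// Sends a streaming chat completion request with retry and returns the
    /// OpenAI-compatible event stream as a [`MessageStream`], logging any request error.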
    async fn stream(
        &self,
        system: &str,
        messages: &[Message],
        tools: &[Tool],
    ) -> Result<MessageStream, ProviderError> {
        let payload = create_request(
            &self.model,
            system,
            messages,
            tools,
            &super::utils::ImageFormat::OpenAi,
            true,
        )?;
        let mut log = RequestLog::start(&self.model, &payload)?;

        let response = self
            .with_retry(|| async {
                let resp = self
                    .api_client
                    .response_post("chat/completions", &payload)
                    .await?;
                handle_status_openai_compat(resp).await
            })
            .await
            .inspect_err(|e| {
                let _ = log.error(e);
            })?;

        stream_openai_compat(response, log)
    }
}