rig/providers/xai/
completion.rs

1// ================================================================
2//! xAI Completion Integration
3//! From [xAI Reference](https://docs.x.ai/docs/api-reference#chat-completions)
4// ================================================================
5
6use crate::{
7    completion::{self, CompletionError},
8    http_client::HttpClientExt,
9    providers::openai::Message,
10};
11
12use super::client::{Client, xai_api_types::ApiResponse};
13use crate::completion::CompletionRequest;
14use crate::providers::openai;
15use crate::streaming::StreamingCompletionResponse;
16use bytes::Bytes;
17use serde::{Deserialize, Serialize};
18use tracing::{Instrument, Level, enabled, info_span};
19use xai_api_types::{CompletionResponse, ToolDefinition};
20
// xAI model identifiers for use with `CompletionModel`.
// NOTE(review): the list is intentionally not date-stamped (a date goes stale as
// soon as a model such as `grok-4-0709` is added); verify against xAI's model list
// when updating it.

/// The `grok-2-1212` model.
pub const GROK_2_1212: &str = "grok-2-1212";
/// The `grok-2-vision-1212` model.
pub const GROK_2_VISION_1212: &str = "grok-2-vision-1212";
/// The `grok-3` model.
pub const GROK_3: &str = "grok-3";
/// The `grok-3-fast` model.
pub const GROK_3_FAST: &str = "grok-3-fast";
/// The `grok-3-mini` model.
pub const GROK_3_MINI: &str = "grok-3-mini";
/// The `grok-3-mini-fast` model.
pub const GROK_3_MINI_FAST: &str = "grok-3-mini-fast";
/// The `grok-2-image-1212` model.
///
/// NOTE(review): the name suggests an image-generation model; confirm it is valid
/// for the chat completions endpoint before using it with `CompletionModel`.
pub const GROK_2_IMAGE_1212: &str = "grok-2-image-1212";
/// The `grok-4-0709` model.
pub const GROK_4: &str = "grok-4-0709";
30
31#[derive(Debug, Serialize, Deserialize)]
32pub(super) struct XAICompletionRequest {
33    model: String,
34    pub messages: Vec<Message>,
35    #[serde(flatten, skip_serializing_if = "Option::is_none")]
36    temperature: Option<f64>,
37    #[serde(skip_serializing_if = "Vec::is_empty")]
38    tools: Vec<ToolDefinition>,
39    #[serde(flatten, skip_serializing_if = "Option::is_none")]
40    tool_choice: Option<crate::providers::openrouter::ToolChoice>,
41    #[serde(flatten, skip_serializing_if = "Option::is_none")]
42    pub additional_params: Option<serde_json::Value>,
43}
44
45impl TryFrom<(&str, CompletionRequest)> for XAICompletionRequest {
46    type Error = CompletionError;
47
48    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
49        let mut full_history: Vec<Message> = match &req.preamble {
50            Some(preamble) => vec![Message::system(preamble)],
51            None => vec![],
52        };
53        if let Some(docs) = req.normalized_documents() {
54            let docs: Vec<Message> = docs.try_into()?;
55            full_history.extend(docs);
56        }
57
58        let chat_history: Vec<Message> = req
59            .chat_history
60            .clone()
61            .into_iter()
62            .map(|message| message.try_into())
63            .collect::<Result<Vec<Vec<Message>>, _>>()?
64            .into_iter()
65            .flatten()
66            .collect();
67
68        full_history.extend(chat_history);
69
70        let tool_choice = req
71            .tool_choice
72            .clone()
73            .map(crate::providers::openrouter::ToolChoice::try_from)
74            .transpose()?;
75
76        Ok(Self {
77            model: model.to_string(),
78            messages: full_history,
79            temperature: req.temperature,
80            tools: req
81                .tools
82                .clone()
83                .into_iter()
84                .map(ToolDefinition::from)
85                .collect::<Vec<_>>(),
86            tool_choice,
87            additional_params: req.additional_params,
88        })
89    }
90}
91
92#[derive(Clone)]
93pub struct CompletionModel<T = reqwest::Client> {
94    pub(crate) client: Client<T>,
95    pub model: String,
96}
97
98impl<T> CompletionModel<T> {
99    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
100        Self {
101            client,
102            model: model.into(),
103        }
104    }
105}
106
107impl<T> completion::CompletionModel for CompletionModel<T>
108where
109    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
110{
111    type Response = CompletionResponse;
112    type StreamingResponse = openai::StreamingCompletionResponse;
113
114    type Client = Client<T>;
115
116    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
117        Self::new(client.clone(), model)
118    }
119
120    #[cfg_attr(feature = "worker", worker::send)]
121    async fn completion(
122        &self,
123        completion_request: completion::CompletionRequest,
124    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
125        let span = if tracing::Span::current().is_disabled() {
126            info_span!(
127                target: "rig::completions",
128                "chat",
129                gen_ai.operation.name = "chat",
130                gen_ai.provider.name = "xai",
131                gen_ai.request.model = self.model,
132                gen_ai.system_instructions = tracing::field::Empty,
133                gen_ai.response.id = tracing::field::Empty,
134                gen_ai.response.model = tracing::field::Empty,
135                gen_ai.usage.output_tokens = tracing::field::Empty,
136                gen_ai.usage.input_tokens = tracing::field::Empty,
137            )
138        } else {
139            tracing::Span::current()
140        };
141
142        span.record("gen_ai.system_instructions", &completion_request.preamble);
143
144        let request =
145            XAICompletionRequest::try_from((self.model.to_string().as_ref(), completion_request))?;
146
147        if enabled!(Level::TRACE) {
148            tracing::trace!(target: "rig::completions",
149                "xAI completion request: {}",
150                serde_json::to_string_pretty(&request)?
151            );
152        }
153
154        let body = serde_json::to_vec(&request)?;
155        let req = self
156            .client
157            .post("/v1/chat/completions")?
158            .body(body)
159            .map_err(|e| CompletionError::HttpError(e.into()))?;
160
161        async move {
162            let response = self.client.send::<_, Bytes>(req).await?;
163            let status = response.status();
164            let response_body = response.into_body().into_future().await?.to_vec();
165
166            if status.is_success() {
167                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
168                    ApiResponse::Ok(response) => {
169                        if enabled!(Level::TRACE) {
170                            tracing::trace!(target: "rig::completions",
171                                "xAI completion response: {}",
172                                serde_json::to_string_pretty(&response)?
173                            );
174                        }
175
176                        response.try_into()
177                    }
178                    ApiResponse::Error(error) => {
179                        Err(CompletionError::ProviderError(error.message()))
180                    }
181                }
182            } else {
183                Err(CompletionError::ProviderError(
184                    String::from_utf8_lossy(&response_body).to_string(),
185                ))
186            }
187        }
188        .instrument(span)
189        .await
190    }
191
192    #[cfg_attr(feature = "worker", worker::send)]
193    async fn stream(
194        &self,
195        request: CompletionRequest,
196    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
197        CompletionModel::stream(self, request).await
198    }
199}
200
201pub mod xai_api_types {
202    use serde::{Deserialize, Serialize};
203
204    use crate::OneOrMany;
205    use crate::completion::{self, CompletionError};
206    use crate::providers::openai::{AssistantContent, Message};
207
208    impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
209        type Error = CompletionError;
210
211        fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
212            let choice = response.choices.first().ok_or_else(|| {
213                CompletionError::ResponseError("Response contained no choices".to_owned())
214            })?;
215            let content = match &choice.message {
216                Message::Assistant {
217                    content,
218                    tool_calls,
219                    ..
220                } => {
221                    let mut content = content
222                        .iter()
223                        .map(|c| match c {
224                            AssistantContent::Text { text } => {
225                                completion::AssistantContent::text(text)
226                            }
227                            AssistantContent::Refusal { refusal } => {
228                                completion::AssistantContent::text(refusal)
229                            }
230                        })
231                        .collect::<Vec<_>>();
232
233                    content.extend(
234                        tool_calls
235                            .iter()
236                            .map(|call| {
237                                completion::AssistantContent::tool_call(
238                                    &call.id,
239                                    &call.function.name,
240                                    call.function.arguments.clone(),
241                                )
242                            })
243                            .collect::<Vec<_>>(),
244                    );
245                    Ok(content)
246                }
247                _ => Err(CompletionError::ResponseError(
248                    "Response did not contain a valid message or tool call".into(),
249                )),
250            }?;
251
252            let choice = OneOrMany::many(content).map_err(|_| {
253                CompletionError::ResponseError(
254                    "Response contained no message or tool call (empty)".to_owned(),
255                )
256            })?;
257
258            let usage = completion::Usage {
259                input_tokens: response.usage.prompt_tokens as u64,
260                output_tokens: response.usage.completion_tokens as u64,
261                total_tokens: response.usage.total_tokens as u64,
262            };
263
264            Ok(completion::CompletionResponse {
265                choice,
266                usage,
267                raw_response: response,
268            })
269        }
270    }
271
272    impl From<completion::ToolDefinition> for ToolDefinition {
273        fn from(tool: completion::ToolDefinition) -> Self {
274            Self {
275                r#type: "function".into(),
276                function: tool,
277            }
278        }
279    }
280
281    #[derive(Clone, Debug, Deserialize, Serialize)]
282    pub struct ToolDefinition {
283        pub r#type: String,
284        pub function: completion::ToolDefinition,
285    }
286
287    #[derive(Debug, Deserialize)]
288    pub struct Function {
289        pub name: String,
290        pub arguments: String,
291    }
292
293    #[derive(Debug, Deserialize, Serialize)]
294    pub struct CompletionResponse {
295        pub id: String,
296        pub model: String,
297        pub choices: Vec<Choice>,
298        pub created: i64,
299        pub object: String,
300        pub system_fingerprint: String,
301        pub usage: Usage,
302    }
303
304    #[derive(Debug, Deserialize, Serialize)]
305    pub struct Choice {
306        pub finish_reason: String,
307        pub index: i32,
308        pub message: Message,
309    }
310
311    #[derive(Debug, Deserialize, Serialize)]
312    pub struct Usage {
313        pub completion_tokens: i32,
314        pub prompt_tokens: i32,
315        pub total_tokens: i32,
316    }
317}