rig/providers/xai/
completion.rs

1// ================================================================
2//! xAI Completion Integration
3//! From [xAI Reference](https://docs.x.ai/docs/api-reference#chat-completions)
4// ================================================================
5
6use crate::{
7    completion::{self, CompletionError},
8    http_client::HttpClientExt,
9    providers::openai::Message,
10};
11
12use super::client::{Client, xai_api_types::ApiResponse};
13use crate::completion::CompletionRequest;
14use crate::providers::openai;
15use crate::streaming::StreamingCompletionResponse;
16use bytes::Bytes;
17use serde::{Deserialize, Serialize};
18use tracing::{Instrument, info_span};
19use xai_api_types::{CompletionResponse, ToolDefinition};
20
// xAI model identifiers accepted by [`CompletionModel::new`] / `Client::completion_model`.
//
// NOTE(review): this list was originally labelled "as of 2025-06-04", but `GROK_4` pins the
// dated snapshot `grok-4-0709`, which appears to postdate that; verify against xAI's current
// model list before relying on it being complete.

/// xAI model id `grok-2-1212`.
pub const GROK_2_1212: &str = "grok-2-1212";
/// xAI model id `grok-2-vision-1212`.
pub const GROK_2_VISION_1212: &str = "grok-2-vision-1212";
/// xAI model id `grok-3`.
pub const GROK_3: &str = "grok-3";
/// xAI model id `grok-3-fast`.
pub const GROK_3_FAST: &str = "grok-3-fast";
/// xAI model id `grok-3-mini`.
pub const GROK_3_MINI: &str = "grok-3-mini";
/// xAI model id `grok-3-mini-fast`.
pub const GROK_3_MINI_FAST: &str = "grok-3-mini-fast";
/// xAI model id `grok-2-image-1212`.
pub const GROK_2_IMAGE_1212: &str = "grok-2-image-1212";
/// xAI model id for the Grok 4 dated snapshot `grok-4-0709`.
pub const GROK_4: &str = "grok-4-0709";
30
31#[derive(Debug, Serialize, Deserialize)]
32pub(super) struct XAICompletionRequest {
33    model: String,
34    pub messages: Vec<Message>,
35    #[serde(flatten, skip_serializing_if = "Option::is_none")]
36    temperature: Option<f64>,
37    #[serde(skip_serializing_if = "Vec::is_empty")]
38    tools: Vec<ToolDefinition>,
39    #[serde(flatten, skip_serializing_if = "Option::is_none")]
40    tool_choice: Option<crate::providers::openrouter::ToolChoice>,
41    #[serde(flatten, skip_serializing_if = "Option::is_none")]
42    pub additional_params: Option<serde_json::Value>,
43}
44
45impl TryFrom<(&str, CompletionRequest)> for XAICompletionRequest {
46    type Error = CompletionError;
47
48    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
49        let mut full_history: Vec<Message> = match &req.preamble {
50            Some(preamble) => vec![Message::system(preamble)],
51            None => vec![],
52        };
53        if let Some(docs) = req.normalized_documents() {
54            let docs: Vec<Message> = docs.try_into()?;
55            full_history.extend(docs);
56        }
57
58        let chat_history: Vec<Message> = req
59            .chat_history
60            .clone()
61            .into_iter()
62            .map(|message| message.try_into())
63            .collect::<Result<Vec<Vec<Message>>, _>>()?
64            .into_iter()
65            .flatten()
66            .collect();
67
68        full_history.extend(chat_history);
69
70        let tool_choice = req
71            .tool_choice
72            .clone()
73            .map(crate::providers::openrouter::ToolChoice::try_from)
74            .transpose()?;
75
76        Ok(Self {
77            model: model.to_string(),
78            messages: full_history,
79            temperature: req.temperature,
80            tools: req
81                .tools
82                .clone()
83                .into_iter()
84                .map(ToolDefinition::from)
85                .collect::<Vec<_>>(),
86            tool_choice,
87            additional_params: req.additional_params,
88        })
89    }
90}
91
92#[derive(Clone)]
93pub struct CompletionModel<T = reqwest::Client> {
94    pub(crate) client: Client<T>,
95    pub model: String,
96}
97
98impl<T> CompletionModel<T> {
99    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
100        Self {
101            client,
102            model: model.into(),
103        }
104    }
105}
106
107impl<T> completion::CompletionModel for CompletionModel<T>
108where
109    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
110{
111    type Response = CompletionResponse;
112    type StreamingResponse = openai::StreamingCompletionResponse;
113
114    type Client = Client<T>;
115
116    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
117        Self::new(client.clone(), model)
118    }
119
120    #[cfg_attr(feature = "worker", worker::send)]
121    async fn completion(
122        &self,
123        completion_request: completion::CompletionRequest,
124    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
125        let preamble = completion_request.preamble.clone();
126        let request =
127            XAICompletionRequest::try_from((self.model.to_string().as_ref(), completion_request))?;
128        let request_messages_json_str = serde_json::to_string(&request.messages).unwrap();
129
130        let span = if tracing::Span::current().is_disabled() {
131            info_span!(
132                target: "rig::completions",
133                "chat",
134                gen_ai.operation.name = "chat",
135                gen_ai.provider.name = "xai",
136                gen_ai.request.model = self.model,
137                gen_ai.system_instructions = preamble,
138                gen_ai.response.id = tracing::field::Empty,
139                gen_ai.response.model = tracing::field::Empty,
140                gen_ai.usage.output_tokens = tracing::field::Empty,
141                gen_ai.usage.input_tokens = tracing::field::Empty,
142                gen_ai.input.messages = &request_messages_json_str,
143                gen_ai.output.messages = tracing::field::Empty,
144            )
145        } else {
146            tracing::Span::current()
147        };
148
149        tracing::debug!("xAI completion request: {request_messages_json_str}");
150
151        let body = serde_json::to_vec(&request)?;
152        let req = self
153            .client
154            .post("/v1/chat/completions")?
155            .body(body)
156            .map_err(|e| CompletionError::HttpError(e.into()))?;
157
158        async move {
159            let response = self.client.http_client().send::<_, Bytes>(req).await?;
160            let status = response.status();
161            let response_body = response.into_body().into_future().await?.to_vec();
162
163            if status.is_success() {
164                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
165                    ApiResponse::Ok(completion) => completion.try_into(),
166                    ApiResponse::Error(error) => {
167                        Err(CompletionError::ProviderError(error.message()))
168                    }
169                }
170            } else {
171                Err(CompletionError::ProviderError(
172                    String::from_utf8_lossy(&response_body).to_string(),
173                ))
174            }
175        }
176        .instrument(span)
177        .await
178    }
179
180    #[cfg_attr(feature = "worker", worker::send)]
181    async fn stream(
182        &self,
183        request: CompletionRequest,
184    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
185        CompletionModel::stream(self, request).await
186    }
187}
188
189pub mod xai_api_types {
190    use serde::{Deserialize, Serialize};
191
192    use crate::OneOrMany;
193    use crate::completion::{self, CompletionError};
194    use crate::providers::openai::{AssistantContent, Message};
195
196    impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
197        type Error = CompletionError;
198
199        fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
200            let choice = response.choices.first().ok_or_else(|| {
201                CompletionError::ResponseError("Response contained no choices".to_owned())
202            })?;
203            let content = match &choice.message {
204                Message::Assistant {
205                    content,
206                    tool_calls,
207                    ..
208                } => {
209                    let mut content = content
210                        .iter()
211                        .map(|c| match c {
212                            AssistantContent::Text { text } => {
213                                completion::AssistantContent::text(text)
214                            }
215                            AssistantContent::Refusal { refusal } => {
216                                completion::AssistantContent::text(refusal)
217                            }
218                        })
219                        .collect::<Vec<_>>();
220
221                    content.extend(
222                        tool_calls
223                            .iter()
224                            .map(|call| {
225                                completion::AssistantContent::tool_call(
226                                    &call.id,
227                                    &call.function.name,
228                                    call.function.arguments.clone(),
229                                )
230                            })
231                            .collect::<Vec<_>>(),
232                    );
233                    Ok(content)
234                }
235                _ => Err(CompletionError::ResponseError(
236                    "Response did not contain a valid message or tool call".into(),
237                )),
238            }?;
239
240            let choice = OneOrMany::many(content).map_err(|_| {
241                CompletionError::ResponseError(
242                    "Response contained no message or tool call (empty)".to_owned(),
243                )
244            })?;
245
246            let usage = completion::Usage {
247                input_tokens: response.usage.prompt_tokens as u64,
248                output_tokens: response.usage.completion_tokens as u64,
249                total_tokens: response.usage.total_tokens as u64,
250            };
251
252            Ok(completion::CompletionResponse {
253                choice,
254                usage,
255                raw_response: response,
256            })
257        }
258    }
259
260    impl From<completion::ToolDefinition> for ToolDefinition {
261        fn from(tool: completion::ToolDefinition) -> Self {
262            Self {
263                r#type: "function".into(),
264                function: tool,
265            }
266        }
267    }
268
269    #[derive(Clone, Debug, Deserialize, Serialize)]
270    pub struct ToolDefinition {
271        pub r#type: String,
272        pub function: completion::ToolDefinition,
273    }
274
275    #[derive(Debug, Deserialize)]
276    pub struct Function {
277        pub name: String,
278        pub arguments: String,
279    }
280
281    #[derive(Debug, Deserialize, Serialize)]
282    pub struct CompletionResponse {
283        pub id: String,
284        pub model: String,
285        pub choices: Vec<Choice>,
286        pub created: i64,
287        pub object: String,
288        pub system_fingerprint: String,
289        pub usage: Usage,
290    }
291
292    #[derive(Debug, Deserialize, Serialize)]
293    pub struct Choice {
294        pub finish_reason: String,
295        pub index: i32,
296        pub message: Message,
297    }
298
299    #[derive(Debug, Deserialize, Serialize)]
300    pub struct Usage {
301        pub completion_tokens: i32,
302        pub prompt_tokens: i32,
303        pub total_tokens: i32,
304    }
305}