Skip to main content

rig_core/providers/xai/
completion.rs

1//! xAI Completion Integration
2//!
3//! Uses the xAI Responses API: <https://docs.x.ai/docs/guides/chat>
4
5use bytes::Bytes;
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8use tracing::{Instrument, Level, enabled, info_span};
9
10use super::api::{ApiResponse, Message, ToolDefinition};
11use super::client::Client;
12use crate::OneOrMany;
13use crate::completion::{self, CompletionError, CompletionRequest};
14use crate::http_client::HttpClientExt;
15use crate::providers::openai::completion::ToolChoice;
16use crate::providers::openai::responses_api::streaming::StreamingCompletionResponse;
17use crate::providers::openai::responses_api::{Output, ResponsesUsage};
18use crate::streaming::StreamingCompletionResponse as BaseStreamingCompletionResponse;
19
// xAI completion models as of 2025-06-04, grouped by model family.

/// Model identifier `grok-2-1212`.
pub const GROK_2_1212: &str = "grok-2-1212";
/// Model identifier `grok-2-vision-1212`.
pub const GROK_2_VISION_1212: &str = "grok-2-vision-1212";
/// Model identifier `grok-2-image-1212`.
///
/// NOTE(review): the name suggests an image model; confirm it is usable through
/// the completion endpoint this module targets.
pub const GROK_2_IMAGE_1212: &str = "grok-2-image-1212";

/// Model identifier `grok-3`.
pub const GROK_3: &str = "grok-3";
/// Model identifier `grok-3-fast`.
pub const GROK_3_FAST: &str = "grok-3-fast";
/// Model identifier `grok-3-mini`.
pub const GROK_3_MINI: &str = "grok-3-mini";
/// Model identifier `grok-3-mini-fast`.
pub const GROK_3_MINI_FAST: &str = "grok-3-mini-fast";

/// Model identifier `grok-4-0709`.
pub const GROK_4: &str = "grok-4-0709";
29
30// ================================================================
31// Request Types
32// ================================================================
33
34#[derive(Debug, Serialize, Deserialize)]
35pub(super) struct XAICompletionRequest {
36    model: String,
37    pub input: Vec<Message>,
38    #[serde(skip_serializing_if = "Option::is_none")]
39    temperature: Option<f64>,
40    #[serde(skip_serializing_if = "Option::is_none")]
41    max_output_tokens: Option<u64>,
42    #[serde(skip_serializing_if = "Vec::is_empty")]
43    tools: Vec<Value>,
44    #[serde(skip_serializing_if = "Option::is_none")]
45    tool_choice: Option<ToolChoice>,
46    #[serde(flatten, skip_serializing_if = "Option::is_none")]
47    pub additional_params: Option<serde_json::Value>,
48}
49
50impl TryFrom<(&str, CompletionRequest)> for XAICompletionRequest {
51    type Error = CompletionError;
52
53    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
54        let chat_history = req.chat_history_with_documents();
55        if req.output_schema.is_some() {
56            tracing::warn!("Structured outputs currently not supported for xAI");
57        }
58        let model = req.model.clone().unwrap_or_else(|| model.to_string());
59        let mut input: Vec<Message> = req
60            .preamble
61            .as_ref()
62            .map_or_else(Vec::new, |p| vec![Message::system(p)]);
63
64        let mut additional_params_payload = req.additional_params.unwrap_or(Value::Null);
65
66        for msg in chat_history {
67            let msg: Vec<Message> = msg.try_into()?;
68            input.extend(msg);
69        }
70
71        let tool_choice = req.tool_choice.map(ToolChoice::try_from).transpose()?;
72        let mut additional_tools =
73            extract_tools_from_additional_params(&mut additional_params_payload)?;
74        let mut tools = req
75            .tools
76            .into_iter()
77            .map(ToolDefinition::from)
78            .map(serde_json::to_value)
79            .collect::<Result<Vec<_>, _>>()?;
80        tools.append(&mut additional_tools);
81        let additional_params = if additional_params_payload.is_null() {
82            None
83        } else {
84            Some(additional_params_payload)
85        };
86
87        Ok(Self {
88            model: model.to_string(),
89            input,
90            temperature: req.temperature,
91            max_output_tokens: req.max_tokens,
92            tools,
93            tool_choice,
94            additional_params,
95        })
96    }
97}
98
99fn extract_tools_from_additional_params(
100    additional_params: &mut Value,
101) -> Result<Vec<Value>, CompletionError> {
102    if let Some(map) = additional_params.as_object_mut()
103        && let Some(raw_tools) = map.remove("tools")
104    {
105        return serde_json::from_value::<Vec<Value>>(raw_tools).map_err(|err| {
106            CompletionError::RequestError(
107                format!("Invalid xAI `additional_params.tools` payload: {err}").into(),
108            )
109        });
110    }
111
112    Ok(Vec::new())
113}
114
115// ================================================================
116// Response Types
117// ================================================================
118
119#[derive(Debug, Deserialize, Serialize)]
120pub struct CompletionResponse {
121    pub id: String,
122    pub model: String,
123    pub output: Vec<Output>,
124    #[serde(default)]
125    pub created: i64,
126    #[serde(default)]
127    pub object: String,
128    #[serde(default)]
129    pub status: Option<String>,
130    pub usage: Option<ResponsesUsage>,
131}
132
133impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
134    type Error = CompletionError;
135
136    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
137        let content: Vec<completion::AssistantContent> = response
138            .output
139            .iter()
140            .cloned()
141            .flat_map(<Vec<completion::AssistantContent>>::from)
142            .collect();
143
144        let choice = OneOrMany::many(content).map_err(|_| {
145            CompletionError::ResponseError("Response contained no output".to_owned())
146        })?;
147
148        let usage = response
149            .usage
150            .as_ref()
151            .map(|u| completion::Usage {
152                input_tokens: u.input_tokens,
153                output_tokens: u.output_tokens,
154                total_tokens: u.total_tokens,
155                cached_input_tokens: u
156                    .input_tokens_details
157                    .clone()
158                    .map(|x| x.cached_tokens)
159                    .unwrap_or_default(),
160                cache_creation_input_tokens: 0,
161                tool_use_prompt_tokens: 0,
162                reasoning_tokens: 0,
163            })
164            .unwrap_or_default();
165
166        Ok(completion::CompletionResponse {
167            choice,
168            usage,
169            raw_response: response,
170            message_id: None,
171        })
172    }
173}
174
175// ================================================================
176// Completion Model
177// ================================================================
178
179#[derive(Clone)]
180pub struct CompletionModel<T = reqwest::Client> {
181    pub(crate) client: Client<T>,
182    pub model: String,
183}
184
185impl<T> CompletionModel<T> {
186    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
187        Self {
188            client,
189            model: model.into(),
190        }
191    }
192}
193
194impl<T> completion::CompletionModel for CompletionModel<T>
195where
196    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
197{
198    type Response = CompletionResponse;
199    type StreamingResponse = StreamingCompletionResponse;
200
201    type Client = Client<T>;
202
203    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
204        Self::new(client.clone(), model)
205    }
206
207    async fn completion(
208        &self,
209        completion_request: completion::CompletionRequest,
210    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
211        let span = if tracing::Span::current().is_disabled() {
212            info_span!(
213                target: "rig::completions",
214                "chat",
215                gen_ai.operation.name = "chat",
216                gen_ai.provider.name = "xai",
217                gen_ai.request.model = self.model,
218                gen_ai.system_instructions = tracing::field::Empty,
219                gen_ai.response.id = tracing::field::Empty,
220                gen_ai.response.model = tracing::field::Empty,
221                gen_ai.usage.output_tokens = tracing::field::Empty,
222                gen_ai.usage.input_tokens = tracing::field::Empty,
223                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
224            )
225        } else {
226            tracing::Span::current()
227        };
228
229        span.record("gen_ai.system_instructions", &completion_request.preamble);
230
231        let request =
232            XAICompletionRequest::try_from((self.model.to_string().as_ref(), completion_request))?;
233
234        if enabled!(Level::TRACE) {
235            tracing::trace!(target: "rig::completions",
236                "xAI completion request: {}",
237                serde_json::to_string_pretty(&request)?
238            );
239        }
240
241        let body = serde_json::to_vec(&request)?;
242        let req = self
243            .client
244            .post("/v1/responses")?
245            .body(body)
246            .map_err(|e| CompletionError::HttpError(e.into()))?;
247
248        async move {
249            let response = self.client.send::<_, Bytes>(req).await?;
250            let status = response.status();
251            let response_body = response.into_body().into_future().await?.to_vec();
252
253            if status.is_success() {
254                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
255                    ApiResponse::Ok(response) => {
256                        if enabled!(Level::TRACE) {
257                            tracing::trace!(target: "rig::completions",
258                                "xAI completion response: {}",
259                                serde_json::to_string_pretty(&response)?
260                            );
261                        }
262
263                        response.try_into()
264                    }
265                    ApiResponse::Error(error) => {
266                        Err(CompletionError::ProviderError(error.message()))
267                    }
268                }
269            } else {
270                Err(CompletionError::ProviderError(
271                    String::from_utf8_lossy(&response_body).to_string(),
272                ))
273            }
274        }
275        .instrument(span)
276        .await
277    }
278
279    async fn stream(
280        &self,
281        request: CompletionRequest,
282    ) -> Result<BaseStreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
283        self.stream(request).await
284    }
285}
286
#[cfg(test)]
mod tests {
    use super::XAICompletionRequest;
    use crate::OneOrMany;
    use crate::completion::request::Document;
    use crate::completion::{CompletionRequest, CompletionRequestBuilder, Message};
    use crate::test_utils::MockCompletionModel;

    /// Document fixture shared by both tests.
    fn glarb_document() -> Document {
        Document {
            id: "doc_1".to_string(),
            text: "Definition of glarb-glarb: an ancient tool.".to_string(),
            additional_props: Default::default(),
        }
    }

    /// Converts `request` into the xAI wire format, serializes it, and returns
    /// the entries of its `input` array.
    fn serialized_input(request: CompletionRequest) -> Vec<serde_json::Value> {
        let xai_request = XAICompletionRequest::try_from(("grok-4-0709", request))
            .expect("request conversion should succeed");
        let serialized = serde_json::to_value(xai_request).expect("serialization should succeed");
        serialized["input"]
            .as_array()
            .expect("xAI request input should be an array")
            .clone()
    }

    /// Documents attached through the builder must show up in the xAI input.
    #[test]
    fn xai_request_includes_normalized_documents() {
        let builder =
            CompletionRequestBuilder::new(MockCompletionModel::default(), "What is glarb-glarb?");
        let request = builder
            .message(Message::system("Use the provided context."))
            .document(glarb_document())
            .build();

        let input = serialized_input(request);
        let mentions_document = input
            .iter()
            .any(|message| message.to_string().contains("glarb-glarb"));

        assert!(
            mentions_document,
            "normalized documents should be forwarded into xAI input"
        );
    }

    /// With a hand-built request, the document entry must come right after the
    /// leading system message and must appear exactly once.
    #[test]
    fn xai_direct_request_keeps_documents_after_system_messages() {
        let history = vec![
            Message::system("System prompt"),
            Message::assistant("Earlier assistant turn"),
            Message::system("Mid-conversation instruction"),
            Message::user("What is glarb-glarb?"),
        ];
        let request = CompletionRequest {
            model: None,
            preamble: None,
            chat_history: OneOrMany::many(history).unwrap(),
            documents: vec![glarb_document()],
            tools: vec![],
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            output_schema: None,
        };

        let input = serialized_input(request);

        assert_eq!(input.len(), 5);
        assert_eq!(input[0]["role"], "system");
        assert_eq!(input[1]["role"], "user");
        assert!(
            input[1].to_string().contains("<file id: doc_1>"),
            "document input should follow leading system input: {input:?}"
        );
        assert_eq!(input[2]["role"], "assistant");
        assert_eq!(input[3]["role"], "system");
        assert_eq!(input[4]["role"], "user");

        let document_entries = input
            .iter()
            .filter(|message| message.to_string().contains("<file id: doc_1>"))
            .count();
        assert_eq!(
            document_entries, 1,
            "document input should appear exactly once: {input:?}"
        );
    }
}