Skip to main content

rig_core/providers/xai/
completion.rs

1//! xAI Completion Integration
2//!
3//! Uses the xAI Responses API: <https://docs.x.ai/docs/guides/chat>
4
5use bytes::Bytes;
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8use tracing::{Instrument, Level, enabled, info_span};
9
10use super::api::{ApiResponse, Message, ToolDefinition};
11use super::client::Client;
12use crate::OneOrMany;
13use crate::completion::{self, CompletionError, CompletionRequest};
14use crate::http_client::HttpClientExt;
15use crate::providers::openai::completion::ToolChoice;
16use crate::providers::openai::responses_api::streaming::StreamingCompletionResponse;
17use crate::providers::openai::responses_api::{Output, ResponsesUsage};
18use crate::streaming::StreamingCompletionResponse as BaseStreamingCompletionResponse;
19
// xAI completion model identifiers, as catalogued on 2025-06-04.
// NOTE(review): judging by its date suffix, `GROK_4` (`grok-4-0709`) was added after that
// catalogue date; re-check this list against the xAI model documentation when updating it.

/// Grok 2 model snapshot `grok-2-1212`.
pub const GROK_2_1212: &str = "grok-2-1212";
/// Grok 2 vision model snapshot `grok-2-vision-1212`.
pub const GROK_2_VISION_1212: &str = "grok-2-vision-1212";
/// Grok 3 model `grok-3`.
pub const GROK_3: &str = "grok-3";
/// Grok 3 "fast" variant `grok-3-fast`.
pub const GROK_3_FAST: &str = "grok-3-fast";
/// Grok 3 mini model `grok-3-mini`.
pub const GROK_3_MINI: &str = "grok-3-mini";
/// Grok 3 mini "fast" variant `grok-3-mini-fast`.
pub const GROK_3_MINI_FAST: &str = "grok-3-mini-fast";
/// Grok 2 image model snapshot `grok-2-image-1212`.
///
/// NOTE(review): the name suggests an image-generation model; confirm the Responses API
/// accepts it before using it for text completions.
pub const GROK_2_IMAGE_1212: &str = "grok-2-image-1212";
/// Grok 4 model snapshot `grok-4-0709`.
pub const GROK_4: &str = "grok-4-0709";
29
30// ================================================================
31// Request Types
32// ================================================================
33
34#[derive(Debug, Serialize, Deserialize)]
35pub(super) struct XAICompletionRequest {
36    model: String,
37    pub input: Vec<Message>,
38    #[serde(skip_serializing_if = "Option::is_none")]
39    temperature: Option<f64>,
40    #[serde(skip_serializing_if = "Option::is_none")]
41    max_output_tokens: Option<u64>,
42    #[serde(skip_serializing_if = "Vec::is_empty")]
43    tools: Vec<Value>,
44    #[serde(skip_serializing_if = "Option::is_none")]
45    tool_choice: Option<ToolChoice>,
46    #[serde(flatten, skip_serializing_if = "Option::is_none")]
47    pub additional_params: Option<serde_json::Value>,
48}
49
50impl TryFrom<(&str, CompletionRequest)> for XAICompletionRequest {
51    type Error = CompletionError;
52
53    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
54        if req.output_schema.is_some() {
55            tracing::warn!("Structured outputs currently not supported for xAI");
56        }
57        let model = req.model.clone().unwrap_or_else(|| model.to_string());
58        let mut input: Vec<Message> = req
59            .preamble
60            .as_ref()
61            .map_or_else(Vec::new, |p| vec![Message::system(p)]);
62
63        if let Some(docs) = req.normalized_documents() {
64            let docs: Vec<Message> = docs.try_into()?;
65            input.extend(docs);
66        }
67
68        let mut additional_params_payload = req.additional_params.unwrap_or(Value::Null);
69
70        for msg in req.chat_history {
71            let msg: Vec<Message> = msg.try_into()?;
72            input.extend(msg);
73        }
74
75        let tool_choice = req.tool_choice.map(ToolChoice::try_from).transpose()?;
76        let mut additional_tools =
77            extract_tools_from_additional_params(&mut additional_params_payload)?;
78        let mut tools = req
79            .tools
80            .into_iter()
81            .map(ToolDefinition::from)
82            .map(serde_json::to_value)
83            .collect::<Result<Vec<_>, _>>()?;
84        tools.append(&mut additional_tools);
85        let additional_params = if additional_params_payload.is_null() {
86            None
87        } else {
88            Some(additional_params_payload)
89        };
90
91        Ok(Self {
92            model: model.to_string(),
93            input,
94            temperature: req.temperature,
95            max_output_tokens: req.max_tokens,
96            tools,
97            tool_choice,
98            additional_params,
99        })
100    }
101}
102
103fn extract_tools_from_additional_params(
104    additional_params: &mut Value,
105) -> Result<Vec<Value>, CompletionError> {
106    if let Some(map) = additional_params.as_object_mut()
107        && let Some(raw_tools) = map.remove("tools")
108    {
109        return serde_json::from_value::<Vec<Value>>(raw_tools).map_err(|err| {
110            CompletionError::RequestError(
111                format!("Invalid xAI `additional_params.tools` payload: {err}").into(),
112            )
113        });
114    }
115
116    Ok(Vec::new())
117}
118
119// ================================================================
120// Response Types
121// ================================================================
122
123#[derive(Debug, Deserialize, Serialize)]
124pub struct CompletionResponse {
125    pub id: String,
126    pub model: String,
127    pub output: Vec<Output>,
128    #[serde(default)]
129    pub created: i64,
130    #[serde(default)]
131    pub object: String,
132    #[serde(default)]
133    pub status: Option<String>,
134    pub usage: Option<ResponsesUsage>,
135}
136
137impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
138    type Error = CompletionError;
139
140    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
141        let content: Vec<completion::AssistantContent> = response
142            .output
143            .iter()
144            .cloned()
145            .flat_map(<Vec<completion::AssistantContent>>::from)
146            .collect();
147
148        let choice = OneOrMany::many(content).map_err(|_| {
149            CompletionError::ResponseError("Response contained no output".to_owned())
150        })?;
151
152        let usage = response
153            .usage
154            .as_ref()
155            .map(|u| completion::Usage {
156                input_tokens: u.input_tokens,
157                output_tokens: u.output_tokens,
158                total_tokens: u.total_tokens,
159                cached_input_tokens: u
160                    .input_tokens_details
161                    .clone()
162                    .map(|x| x.cached_tokens)
163                    .unwrap_or_default(),
164                cache_creation_input_tokens: 0,
165                tool_use_prompt_tokens: 0,
166                reasoning_tokens: 0,
167            })
168            .unwrap_or_default();
169
170        Ok(completion::CompletionResponse {
171            choice,
172            usage,
173            raw_response: response,
174            message_id: None,
175        })
176    }
177}
178
179// ================================================================
180// Completion Model
181// ================================================================
182
183#[derive(Clone)]
184pub struct CompletionModel<T = reqwest::Client> {
185    pub(crate) client: Client<T>,
186    pub model: String,
187}
188
189impl<T> CompletionModel<T> {
190    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
191        Self {
192            client,
193            model: model.into(),
194        }
195    }
196}
197
198impl<T> completion::CompletionModel for CompletionModel<T>
199where
200    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
201{
202    type Response = CompletionResponse;
203    type StreamingResponse = StreamingCompletionResponse;
204
205    type Client = Client<T>;
206
207    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
208        Self::new(client.clone(), model)
209    }
210
211    async fn completion(
212        &self,
213        completion_request: completion::CompletionRequest,
214    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
215        let span = if tracing::Span::current().is_disabled() {
216            info_span!(
217                target: "rig::completions",
218                "chat",
219                gen_ai.operation.name = "chat",
220                gen_ai.provider.name = "xai",
221                gen_ai.request.model = self.model,
222                gen_ai.system_instructions = tracing::field::Empty,
223                gen_ai.response.id = tracing::field::Empty,
224                gen_ai.response.model = tracing::field::Empty,
225                gen_ai.usage.output_tokens = tracing::field::Empty,
226                gen_ai.usage.input_tokens = tracing::field::Empty,
227                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
228            )
229        } else {
230            tracing::Span::current()
231        };
232
233        span.record("gen_ai.system_instructions", &completion_request.preamble);
234
235        let request =
236            XAICompletionRequest::try_from((self.model.to_string().as_ref(), completion_request))?;
237
238        if enabled!(Level::TRACE) {
239            tracing::trace!(target: "rig::completions",
240                "xAI completion request: {}",
241                serde_json::to_string_pretty(&request)?
242            );
243        }
244
245        let body = serde_json::to_vec(&request)?;
246        let req = self
247            .client
248            .post("/v1/responses")?
249            .body(body)
250            .map_err(|e| CompletionError::HttpError(e.into()))?;
251
252        async move {
253            let response = self.client.send::<_, Bytes>(req).await?;
254            let status = response.status();
255            let response_body = response.into_body().into_future().await?.to_vec();
256
257            if status.is_success() {
258                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
259                    ApiResponse::Ok(response) => {
260                        if enabled!(Level::TRACE) {
261                            tracing::trace!(target: "rig::completions",
262                                "xAI completion response: {}",
263                                serde_json::to_string_pretty(&response)?
264                            );
265                        }
266
267                        response.try_into()
268                    }
269                    ApiResponse::Error(error) => {
270                        Err(CompletionError::ProviderError(error.message()))
271                    }
272                }
273            } else {
274                Err(CompletionError::ProviderError(
275                    String::from_utf8_lossy(&response_body).to_string(),
276                ))
277            }
278        }
279        .instrument(span)
280        .await
281    }
282
283    async fn stream(
284        &self,
285        request: CompletionRequest,
286    ) -> Result<BaseStreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
287        self.stream(request).await
288    }
289}
290
#[cfg(test)]
mod tests {
    use super::XAICompletionRequest;
    use crate::OneOrMany;
    use crate::completion::CompletionRequest;
    use crate::completion::request::Document;

    #[test]
    fn xai_request_includes_normalized_documents() {
        let request = CompletionRequest {
            model: None,
            preamble: Some("Use the provided context.".to_string()),
            chat_history: OneOrMany::one("What is glarb-glarb?".into()),
            documents: vec![Document {
                id: "doc_1".to_string(),
                text: "Definition of glarb-glarb: an ancient tool.".to_string(),
                additional_props: Default::default(),
            }],
            tools: vec![],
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            output_schema: None,
        };

        let xai_request = XAICompletionRequest::try_from(("grok-4-0709", request))
            .expect("request conversion should succeed");
        let serialized = serde_json::to_value(xai_request).expect("serialization should succeed");
        let input = serialized["input"]
            .as_array()
            .expect("xAI request input should be an array");

        // The user prompt also mentions "glarb-glarb", so search for text only the
        // document contains. Otherwise the check passes even if documents are dropped.
        let position_of = |needle: &str| {
            input
                .iter()
                .position(|message| message.to_string().contains(needle))
        };

        let document_position = position_of("an ancient tool")
            .expect("normalized documents should be forwarded into xAI input");
        let prompt_position =
            position_of("What is glarb-glarb?").expect("chat history should be forwarded");

        assert!(
            document_position < prompt_position,
            "documents should precede the chat history in xAI input"
        );
    }

    #[test]
    fn xai_request_moves_additional_param_tools_and_flattens_the_rest() {
        let request = CompletionRequest {
            model: None,
            preamble: None,
            chat_history: OneOrMany::one("Search the web for glarb-glarb.".into()),
            documents: vec![],
            tools: vec![],
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: Some(serde_json::json!({
                "tools": [{ "type": "web_search" }],
                "parallel_tool_calls": false
            })),
            output_schema: None,
        };

        let xai_request = XAICompletionRequest::try_from(("grok-4-0709", request))
            .expect("request conversion should succeed");
        let serialized = serde_json::to_value(xai_request).expect("serialization should succeed");

        assert_eq!(serialized["model"], serde_json::json!("grok-4-0709"));
        assert_eq!(
            serialized["tools"],
            serde_json::json!([{ "type": "web_search" }]),
            "tools from additional_params should land in the top-level tools array"
        );
        assert_eq!(
            serialized["parallel_tool_calls"],
            serde_json::json!(false),
            "remaining additional_params should be flattened into the request body"
        );
        assert!(serialized.get("additional_params").is_none());
    }
}