// rig/providers/xai/completion.rs

1//! xAI Completion Integration
2//!
3//! Uses the xAI Responses API: <https://docs.x.ai/docs/guides/chat>
4
5use bytes::Bytes;
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8use tracing::{Instrument, Level, enabled, info_span};
9
10use super::api::{ApiResponse, Message, ToolDefinition};
11use super::client::Client;
12use crate::OneOrMany;
13use crate::completion::{self, CompletionError, CompletionRequest};
14use crate::http_client::HttpClientExt;
15use crate::providers::openai::completion::ToolChoice;
16use crate::providers::openai::responses_api::streaming::StreamingCompletionResponse;
17use crate::providers::openai::responses_api::{Output, ResponsesUsage};
18use crate::streaming::StreamingCompletionResponse as BaseStreamingCompletionResponse;
19
/// xAI completion models as of 2025-06-04
pub const GROK_2_1212: &str = "grok-2-1212";
pub const GROK_2_VISION_1212: &str = "grok-2-vision-1212";
pub const GROK_3: &str = "grok-3";
pub const GROK_3_FAST: &str = "grok-3-fast";
pub const GROK_3_MINI: &str = "grok-3-mini";
pub const GROK_3_MINI_FAST: &str = "grok-3-mini-fast";
pub const GROK_2_IMAGE_1212: &str = "grok-2-image-1212";
// NOTE: the `GROK_4` alias points at the dated `grok-4-0709` release.
pub const GROK_4: &str = "grok-4-0709";
29
30// ================================================================
31// Request Types
32// ================================================================
33
/// Request body sent to the xAI Responses API.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct XAICompletionRequest {
    /// Model identifier (e.g. `grok-3`).
    model: String,
    /// Conversation input: optional system message followed by chat history.
    pub input: Vec<Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_output_tokens: Option<u64>,
    /// Tool definitions serialized to raw JSON; omitted entirely when empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<ToolChoice>,
    /// Extra provider parameters merged into the top-level request object
    /// via `serde(flatten)`.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<serde_json::Value>,
}
49
50impl TryFrom<(&str, CompletionRequest)> for XAICompletionRequest {
51    type Error = CompletionError;
52
53    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
54        if req.output_schema.is_some() {
55            tracing::warn!("Structured outputs currently not supported for xAI");
56        }
57        let model = req.model.clone().unwrap_or_else(|| model.to_string());
58        let mut additional_params_payload = req.additional_params.unwrap_or(Value::Null);
59        let mut input: Vec<Message> = req
60            .preamble
61            .as_ref()
62            .map_or_else(Vec::new, |p| vec![Message::system(p)]);
63
64        for msg in req.chat_history {
65            let msg: Vec<Message> = msg.try_into()?;
66            input.extend(msg);
67        }
68
69        let tool_choice = req.tool_choice.map(ToolChoice::try_from).transpose()?;
70        let mut additional_tools =
71            extract_tools_from_additional_params(&mut additional_params_payload)?;
72        let mut tools = req
73            .tools
74            .into_iter()
75            .map(ToolDefinition::from)
76            .map(serde_json::to_value)
77            .collect::<Result<Vec<_>, _>>()?;
78        tools.append(&mut additional_tools);
79        let additional_params = if additional_params_payload.is_null() {
80            None
81        } else {
82            Some(additional_params_payload)
83        };
84
85        Ok(Self {
86            model: model.to_string(),
87            input,
88            temperature: req.temperature,
89            max_output_tokens: req.max_tokens,
90            tools,
91            tool_choice,
92            additional_params,
93        })
94    }
95}
96
97fn extract_tools_from_additional_params(
98    additional_params: &mut Value,
99) -> Result<Vec<Value>, CompletionError> {
100    if let Some(map) = additional_params.as_object_mut()
101        && let Some(raw_tools) = map.remove("tools")
102    {
103        return serde_json::from_value::<Vec<Value>>(raw_tools).map_err(|err| {
104            CompletionError::RequestError(
105                format!("Invalid xAI `additional_params.tools` payload: {err}").into(),
106            )
107        });
108    }
109
110    Ok(Vec::new())
111}
112
113// ================================================================
114// Response Types
115// ================================================================
116
/// Raw (deserialized) response payload from the xAI Responses API.
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    /// Response identifier assigned by xAI.
    pub id: String,
    /// Model that produced this response.
    pub model: String,
    /// Output items produced by the model.
    pub output: Vec<Output>,
    /// Creation timestamp; defaults to 0 when the field is absent.
    #[serde(default)]
    pub created: i64,
    /// Object type tag; defaults to an empty string when absent.
    #[serde(default)]
    pub object: String,
    /// Completion status reported by the API, when present.
    #[serde(default)]
    pub status: Option<String>,
    /// Token usage accounting; optional in the API response.
    pub usage: Option<ResponsesUsage>,
}
130
131impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
132    type Error = CompletionError;
133
134    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
135        let content: Vec<completion::AssistantContent> = response
136            .output
137            .iter()
138            .cloned()
139            .flat_map(<Vec<completion::AssistantContent>>::from)
140            .collect();
141
142        let choice = OneOrMany::many(content).map_err(|_| {
143            CompletionError::ResponseError("Response contained no output".to_owned())
144        })?;
145
146        let usage = response
147            .usage
148            .as_ref()
149            .map(|u| completion::Usage {
150                input_tokens: u.input_tokens,
151                output_tokens: u.output_tokens,
152                total_tokens: u.total_tokens,
153                cached_input_tokens: u
154                    .input_tokens_details
155                    .clone()
156                    .map(|x| x.cached_tokens)
157                    .unwrap_or_default(),
158                cache_creation_input_tokens: 0,
159            })
160            .unwrap_or_default();
161
162        Ok(completion::CompletionResponse {
163            choice,
164            usage,
165            raw_response: response,
166            message_id: None,
167        })
168    }
169}
170
171// ================================================================
172// Completion Model
173// ================================================================
174
/// An xAI completion model bound to a client and a model name.
///
/// `T` is the underlying HTTP client type (defaults to `reqwest::Client`).
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    /// Shared provider client used to issue API requests.
    pub(crate) client: Client<T>,
    /// Default model identifier for requests made through this instance.
    pub model: String,
}
180
181impl<T> CompletionModel<T> {
182    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
183        Self {
184            client,
185            model: model.into(),
186        }
187    }
188}
189
impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
{
    type Response = CompletionResponse;
    type StreamingResponse = StreamingCompletionResponse;

    type Client = Client<T>;

    // Constructs a model from a shared client reference (clones the client).
    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model)
    }

    /// Sends a non-streaming completion request to the xAI Responses API
    /// (`POST /v1/responses`) and converts the reply into the
    /// provider-agnostic response type.
    ///
    /// # Errors
    /// Returns `CompletionError` on serialization failures, HTTP errors, a
    /// non-success status code, or an error payload from the API.
    async fn completion(
        &self,
        completion_request: completion::CompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        // Reuse the caller's span when one is active; otherwise open a fresh
        // GenAI-convention span for this chat call. Several fields are left
        // Empty here so they can be recorded later (e.g. by response handling).
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "xai",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cached_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        span.record("gen_ai.system_instructions", &completion_request.preamble);

        // Build the provider request; the instance's model is the default,
        // a per-request override inside `completion_request` takes precedence.
        let request =
            XAICompletionRequest::try_from((self.model.to_string().as_ref(), completion_request))?;

        // Pretty-printing is gated so serialization cost is only paid at TRACE.
        if enabled!(Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "xAI completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;
        let req = self
            .client
            .post("/v1/responses")?
            .body(body)
            .map_err(|e| CompletionError::HttpError(e.into()))?;

        // The request/response round-trip is instrumented with the span above.
        async move {
            let response = self.client.send::<_, Bytes>(req).await?;
            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if status.is_success() {
                // A 2xx body may still carry a structured API error.
                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
                    ApiResponse::Ok(response) => {
                        if enabled!(Level::TRACE) {
                            tracing::trace!(target: "rig::completions",
                                "xAI completion response: {}",
                                serde_json::to_string_pretty(&response)?
                            );
                        }

                        response.try_into()
                    }
                    ApiResponse::Error(error) => {
                        Err(CompletionError::ProviderError(error.message()))
                    }
                }
            } else {
                // Non-success: surface the raw body as the provider error text.
                Err(CompletionError::ProviderError(
                    String::from_utf8_lossy(&response_body).to_string(),
                ))
            }
        }
        .instrument(span)
        .await
    }

    async fn stream(
        &self,
        request: CompletionRequest,
    ) -> Result<BaseStreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
        // NOTE(review): `self.stream(request)` only terminates if an inherent
        // `stream` method (defined elsewhere in this file, outside this view)
        // takes precedence over this trait method during resolution. If no
        // such inherent method exists, this recurses infinitely — confirm.
        self.stream(request).await
    }
}