// rig/providers/xai/completion.rs
//! xAI Completion Integration
//!
//! Uses the xAI Responses API: <https://docs.x.ai/docs/guides/chat>

5use bytes::Bytes;
6use serde::{Deserialize, Serialize};
7use tracing::{Instrument, Level, enabled, info_span};
8
9use super::api::{ApiResponse, Message, ToolDefinition};
10use super::client::Client;
11use crate::OneOrMany;
12use crate::completion::{self, CompletionError, CompletionRequest};
13use crate::http_client::HttpClientExt;
14use crate::providers::openai::completion::ToolChoice;
15use crate::providers::openai::responses_api::streaming::StreamingCompletionResponse;
16use crate::providers::openai::responses_api::{Output, ResponsesUsage};
17use crate::streaming::StreamingCompletionResponse as BaseStreamingCompletionResponse;
18
// xAI completion model identifiers, current as of 2025-06-04.
// Values are the model names accepted by the xAI API; consts are public API
// and must not be renamed or revalued.

/// Grok 2 (December 2024 snapshot).
pub const GROK_2_1212: &str = "grok-2-1212";
/// Grok 2 with vision/image understanding (December 2024 snapshot).
pub const GROK_2_VISION_1212: &str = "grok-2-vision-1212";
/// Grok 3.
pub const GROK_3: &str = "grok-3";
/// Grok 3, lower-latency variant.
pub const GROK_3_FAST: &str = "grok-3-fast";
/// Grok 3 mini.
pub const GROK_3_MINI: &str = "grok-3-mini";
/// Grok 3 mini, lower-latency variant.
pub const GROK_3_MINI_FAST: &str = "grok-3-mini-fast";
/// Grok 2 image-generation model (December 2024 snapshot).
pub const GROK_2_IMAGE_1212: &str = "grok-2-image-1212";
/// Grok 4 (July 2025 snapshot).
pub const GROK_4: &str = "grok-4-0709";

29// ================================================================
30// Request Types
31// ================================================================
32
/// Request body sent to the xAI Responses API (`/v1/responses`).
///
/// Field names mirror the wire format; optional fields are omitted from the
/// serialized JSON rather than sent as `null`.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct XAICompletionRequest {
    // Model identifier, e.g. one of the `GROK_*` constants.
    model: String,
    // Conversation input: optional system message followed by chat history.
    pub input: Vec<Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_output_tokens: Option<u64>,
    // Omitted entirely when no tools are registered.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<ToolChoice>,
    // Arbitrary extra parameters merged into the top-level JSON object
    // via `flatten`, letting callers pass provider-specific options.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<serde_json::Value>,
}
48
49impl TryFrom<(&str, CompletionRequest)> for XAICompletionRequest {
50    type Error = CompletionError;
51
52    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
53        if req.output_schema.is_some() {
54            tracing::warn!("Structured outputs currently not supported for xAI");
55        }
56        let model = req.model.clone().unwrap_or_else(|| model.to_string());
57        let mut input: Vec<Message> = req
58            .preamble
59            .as_ref()
60            .map_or_else(Vec::new, |p| vec![Message::system(p)]);
61
62        for msg in req.chat_history {
63            let msg: Vec<Message> = msg.try_into()?;
64            input.extend(msg);
65        }
66
67        let tool_choice = req.tool_choice.map(ToolChoice::try_from).transpose()?;
68        let tools = req.tools.into_iter().map(ToolDefinition::from).collect();
69
70        Ok(Self {
71            model: model.to_string(),
72            input,
73            temperature: req.temperature,
74            max_output_tokens: req.max_tokens,
75            tools,
76            tool_choice,
77            additional_params: req.additional_params,
78        })
79    }
80}
81
82// ================================================================
83// Response Types
84// ================================================================
85
/// Raw completion response returned by the xAI Responses API.
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    /// Provider-assigned response identifier.
    pub id: String,
    /// Model that actually produced the response.
    pub model: String,
    /// Output items (text, tool calls, ...) in provider format.
    pub output: Vec<Output>,
    // `default` keeps deserialization working if the provider omits the field.
    #[serde(default)]
    pub created: i64,
    #[serde(default)]
    pub object: String,
    #[serde(default)]
    pub status: Option<String>,
    /// Token accounting; absent on some responses.
    pub usage: Option<ResponsesUsage>,
}
99
100impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
101    type Error = CompletionError;
102
103    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
104        let content: Vec<completion::AssistantContent> = response
105            .output
106            .iter()
107            .cloned()
108            .flat_map(<Vec<completion::AssistantContent>>::from)
109            .collect();
110
111        let choice = OneOrMany::many(content).map_err(|_| {
112            CompletionError::ResponseError("Response contained no output".to_owned())
113        })?;
114
115        let usage = response
116            .usage
117            .as_ref()
118            .map(|u| completion::Usage {
119                input_tokens: u.input_tokens,
120                output_tokens: u.output_tokens,
121                total_tokens: u.total_tokens,
122                cached_input_tokens: u
123                    .input_tokens_details
124                    .clone()
125                    .map(|x| x.cached_tokens)
126                    .unwrap_or_default(),
127            })
128            .unwrap_or_default();
129
130        Ok(completion::CompletionResponse {
131            choice,
132            usage,
133            raw_response: response,
134            message_id: None,
135        })
136    }
137}
138
139// ================================================================
140// Completion Model
141// ================================================================
142
/// An xAI completion model handle: a provider client plus a default model id.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    // Shared client used to issue HTTP requests to the xAI API.
    pub(crate) client: Client<T>,
    /// Default model identifier (e.g. [`GROK_4`]) used when a request does
    /// not specify its own.
    pub model: String,
}
148
149impl<T> CompletionModel<T> {
150    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
151        Self {
152            client,
153            model: model.into(),
154        }
155    }
156}
157
158impl<T> completion::CompletionModel for CompletionModel<T>
159where
160    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
161{
162    type Response = CompletionResponse;
163    type StreamingResponse = StreamingCompletionResponse;
164
165    type Client = Client<T>;
166
167    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
168        Self::new(client.clone(), model)
169    }
170
171    async fn completion(
172        &self,
173        completion_request: completion::CompletionRequest,
174    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
175        let span = if tracing::Span::current().is_disabled() {
176            info_span!(
177                target: "rig::completions",
178                "chat",
179                gen_ai.operation.name = "chat",
180                gen_ai.provider.name = "xai",
181                gen_ai.request.model = self.model,
182                gen_ai.system_instructions = tracing::field::Empty,
183                gen_ai.response.id = tracing::field::Empty,
184                gen_ai.response.model = tracing::field::Empty,
185                gen_ai.usage.output_tokens = tracing::field::Empty,
186                gen_ai.usage.input_tokens = tracing::field::Empty,
187            )
188        } else {
189            tracing::Span::current()
190        };
191
192        span.record("gen_ai.system_instructions", &completion_request.preamble);
193
194        let request =
195            XAICompletionRequest::try_from((self.model.to_string().as_ref(), completion_request))?;
196
197        if enabled!(Level::TRACE) {
198            tracing::trace!(target: "rig::completions",
199                "xAI completion request: {}",
200                serde_json::to_string_pretty(&request)?
201            );
202        }
203
204        let body = serde_json::to_vec(&request)?;
205        let req = self
206            .client
207            .post("/v1/responses")?
208            .body(body)
209            .map_err(|e| CompletionError::HttpError(e.into()))?;
210
211        async move {
212            let response = self.client.send::<_, Bytes>(req).await?;
213            let status = response.status();
214            let response_body = response.into_body().into_future().await?.to_vec();
215
216            if status.is_success() {
217                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
218                    ApiResponse::Ok(response) => {
219                        if enabled!(Level::TRACE) {
220                            tracing::trace!(target: "rig::completions",
221                                "xAI completion response: {}",
222                                serde_json::to_string_pretty(&response)?
223                            );
224                        }
225
226                        response.try_into()
227                    }
228                    ApiResponse::Error(error) => {
229                        Err(CompletionError::ProviderError(error.message()))
230                    }
231                }
232            } else {
233                Err(CompletionError::ProviderError(
234                    String::from_utf8_lossy(&response_body).to_string(),
235                ))
236            }
237        }
238        .instrument(span)
239        .await
240    }
241
242    async fn stream(
243        &self,
244        request: CompletionRequest,
245    ) -> Result<BaseStreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
246        self.stream(request).await
247    }
248}