Skip to main content

rig/providers/xai/
completion.rs

//! xAI Completion Integration
//!
//! Uses the xAI Responses API: <https://docs.x.ai/docs/guides/chat>
4
5use bytes::Bytes;
6use serde::{Deserialize, Serialize};
7use tracing::{Instrument, Level, enabled, info_span};
8
9use super::api::{ApiResponse, Message, ToolDefinition};
10use super::client::Client;
11use crate::OneOrMany;
12use crate::completion::{self, CompletionError, CompletionRequest};
13use crate::http_client::HttpClientExt;
14use crate::providers::openai::completion::ToolChoice;
15use crate::providers::openai::responses_api::streaming::StreamingCompletionResponse;
16use crate::providers::openai::responses_api::{Output, ResponsesUsage};
17use crate::streaming::StreamingCompletionResponse as BaseStreamingCompletionResponse;
18
// xAI completion models as of 2025-06-04.
//
// Each constant holds the model name string that is sent in the `model` field of a request.

/// Model name `grok-2-1212`.
pub const GROK_2_1212: &str = "grok-2-1212";
/// Model name `grok-2-vision-1212`.
pub const GROK_2_VISION_1212: &str = "grok-2-vision-1212";
/// Model name `grok-3`.
pub const GROK_3: &str = "grok-3";
/// Model name `grok-3-fast`.
pub const GROK_3_FAST: &str = "grok-3-fast";
/// Model name `grok-3-mini`.
pub const GROK_3_MINI: &str = "grok-3-mini";
/// Model name `grok-3-mini-fast`.
pub const GROK_3_MINI_FAST: &str = "grok-3-mini-fast";
/// Model name `grok-2-image-1212`.
///
/// NOTE(review): the name suggests an image model; confirm it belongs in this list.
pub const GROK_2_IMAGE_1212: &str = "grok-2-image-1212";
/// Model name `grok-4-0709`.
///
/// NOTE(review): the `0709` suffix suggests this entry postdates the 2025-06-04 header; confirm.
pub const GROK_4: &str = "grok-4-0709";
28
// ================================================================
// Request Types
// ================================================================
32
33#[derive(Debug, Serialize, Deserialize)]
34pub(super) struct XAICompletionRequest {
35    model: String,
36    pub input: Vec<Message>,
37    #[serde(skip_serializing_if = "Option::is_none")]
38    temperature: Option<f64>,
39    #[serde(skip_serializing_if = "Option::is_none")]
40    max_output_tokens: Option<u64>,
41    #[serde(skip_serializing_if = "Vec::is_empty")]
42    tools: Vec<ToolDefinition>,
43    #[serde(skip_serializing_if = "Option::is_none")]
44    tool_choice: Option<ToolChoice>,
45    #[serde(flatten, skip_serializing_if = "Option::is_none")]
46    pub additional_params: Option<serde_json::Value>,
47}
48
49impl TryFrom<(&str, CompletionRequest)> for XAICompletionRequest {
50    type Error = CompletionError;
51
52    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
53        let mut input: Vec<Message> = req
54            .preamble
55            .as_ref()
56            .map_or_else(Vec::new, |p| vec![Message::system(p)]);
57
58        for msg in req.chat_history {
59            let msg: Vec<Message> = msg.try_into()?;
60            input.extend(msg);
61        }
62
63        let tool_choice = req.tool_choice.map(ToolChoice::try_from).transpose()?;
64        let tools = req.tools.into_iter().map(ToolDefinition::from).collect();
65
66        Ok(Self {
67            model: model.to_string(),
68            input,
69            temperature: req.temperature,
70            max_output_tokens: req.max_tokens,
71            tools,
72            tool_choice,
73            additional_params: req.additional_params,
74        })
75    }
76}
77
// ================================================================
// Response Types
// ================================================================
81
82#[derive(Debug, Deserialize, Serialize)]
83pub struct CompletionResponse {
84    pub id: String,
85    pub model: String,
86    pub output: Vec<Output>,
87    #[serde(default)]
88    pub created: i64,
89    #[serde(default)]
90    pub object: String,
91    #[serde(default)]
92    pub status: Option<String>,
93    pub usage: Option<ResponsesUsage>,
94}
95
96impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
97    type Error = CompletionError;
98
99    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
100        let content: Vec<completion::AssistantContent> = response
101            .output
102            .iter()
103            .cloned()
104            .flat_map(<Vec<completion::AssistantContent>>::from)
105            .collect();
106
107        let choice = OneOrMany::many(content).map_err(|_| {
108            CompletionError::ResponseError("Response contained no output".to_owned())
109        })?;
110
111        let usage = response
112            .usage
113            .as_ref()
114            .map(|u| completion::Usage {
115                input_tokens: u.input_tokens,
116                output_tokens: u.output_tokens,
117                total_tokens: u.total_tokens,
118                cached_input_tokens: u
119                    .input_tokens_details
120                    .clone()
121                    .map(|x| x.cached_tokens)
122                    .unwrap_or_default(),
123            })
124            .unwrap_or_default();
125
126        Ok(completion::CompletionResponse {
127            choice,
128            usage,
129            raw_response: response,
130        })
131    }
132}
133
// ================================================================
// Completion Model
// ================================================================
137
138#[derive(Clone)]
139pub struct CompletionModel<T = reqwest::Client> {
140    pub(crate) client: Client<T>,
141    pub model: String,
142}
143
144impl<T> CompletionModel<T> {
145    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
146        Self {
147            client,
148            model: model.into(),
149        }
150    }
151}
152
153impl<T> completion::CompletionModel for CompletionModel<T>
154where
155    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
156{
157    type Response = CompletionResponse;
158    type StreamingResponse = StreamingCompletionResponse;
159
160    type Client = Client<T>;
161
162    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
163        Self::new(client.clone(), model)
164    }
165
166    async fn completion(
167        &self,
168        completion_request: completion::CompletionRequest,
169    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
170        let span = if tracing::Span::current().is_disabled() {
171            info_span!(
172                target: "rig::completions",
173                "chat",
174                gen_ai.operation.name = "chat",
175                gen_ai.provider.name = "xai",
176                gen_ai.request.model = self.model,
177                gen_ai.system_instructions = tracing::field::Empty,
178                gen_ai.response.id = tracing::field::Empty,
179                gen_ai.response.model = tracing::field::Empty,
180                gen_ai.usage.output_tokens = tracing::field::Empty,
181                gen_ai.usage.input_tokens = tracing::field::Empty,
182            )
183        } else {
184            tracing::Span::current()
185        };
186
187        span.record("gen_ai.system_instructions", &completion_request.preamble);
188
189        let request =
190            XAICompletionRequest::try_from((self.model.to_string().as_ref(), completion_request))?;
191
192        if enabled!(Level::TRACE) {
193            tracing::trace!(target: "rig::completions",
194                "xAI completion request: {}",
195                serde_json::to_string_pretty(&request)?
196            );
197        }
198
199        let body = serde_json::to_vec(&request)?;
200        let req = self
201            .client
202            .post("/v1/responses")?
203            .body(body)
204            .map_err(|e| CompletionError::HttpError(e.into()))?;
205
206        async move {
207            let response = self.client.send::<_, Bytes>(req).await?;
208            let status = response.status();
209            let response_body = response.into_body().into_future().await?.to_vec();
210
211            if status.is_success() {
212                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
213                    ApiResponse::Ok(response) => {
214                        if enabled!(Level::TRACE) {
215                            tracing::trace!(target: "rig::completions",
216                                "xAI completion response: {}",
217                                serde_json::to_string_pretty(&response)?
218                            );
219                        }
220
221                        response.try_into()
222                    }
223                    ApiResponse::Error(error) => {
224                        Err(CompletionError::ProviderError(error.message()))
225                    }
226                }
227            } else {
228                Err(CompletionError::ProviderError(
229                    String::from_utf8_lossy(&response_body).to_string(),
230                ))
231            }
232        }
233        .instrument(span)
234        .await
235    }
236
237    async fn stream(
238        &self,
239        request: CompletionRequest,
240    ) -> Result<BaseStreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
241        self.stream(request).await
242    }
243}