1use bytes::Bytes;
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8use tracing::{Instrument, Level, enabled, info_span};
9
10use super::api::{ApiResponse, Message, ToolDefinition};
11use super::client::Client;
12use crate::OneOrMany;
13use crate::completion::{self, CompletionError, CompletionRequest};
14use crate::http_client::HttpClientExt;
15use crate::providers::openai::completion::ToolChoice;
16use crate::providers::openai::responses_api::streaming::StreamingCompletionResponse;
17use crate::providers::openai::responses_api::{Output, ResponsesUsage};
18use crate::streaming::StreamingCompletionResponse as BaseStreamingCompletionResponse;
19
/// `grok-2-1212` model identifier.
pub const GROK_2_1212: &str = "grok-2-1212";
/// `grok-2-vision-1212` model identifier.
pub const GROK_2_VISION_1212: &str = "grok-2-vision-1212";
/// `grok-3` model identifier.
pub const GROK_3: &str = "grok-3";
/// `grok-3-fast` model identifier.
pub const GROK_3_FAST: &str = "grok-3-fast";
/// `grok-3-mini` model identifier.
pub const GROK_3_MINI: &str = "grok-3-mini";
/// `grok-3-mini-fast` model identifier.
pub const GROK_3_MINI_FAST: &str = "grok-3-mini-fast";
/// `grok-2-image-1212` model identifier.
pub const GROK_2_IMAGE_1212: &str = "grok-2-image-1212";
/// `grok-4-0709` model identifier.
pub const GROK_4: &str = "grok-4-0709";
29
30#[derive(Debug, Serialize, Deserialize)]
35pub(super) struct XAICompletionRequest {
36 model: String,
37 pub input: Vec<Message>,
38 #[serde(skip_serializing_if = "Option::is_none")]
39 temperature: Option<f64>,
40 #[serde(skip_serializing_if = "Option::is_none")]
41 max_output_tokens: Option<u64>,
42 #[serde(skip_serializing_if = "Vec::is_empty")]
43 tools: Vec<Value>,
44 #[serde(skip_serializing_if = "Option::is_none")]
45 tool_choice: Option<ToolChoice>,
46 #[serde(flatten, skip_serializing_if = "Option::is_none")]
47 pub additional_params: Option<serde_json::Value>,
48}
49
50impl TryFrom<(&str, CompletionRequest)> for XAICompletionRequest {
51 type Error = CompletionError;
52
53 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
54 let chat_history = req.chat_history_with_documents();
55 if req.output_schema.is_some() {
56 tracing::warn!("Structured outputs currently not supported for xAI");
57 }
58 let model = req.model.clone().unwrap_or_else(|| model.to_string());
59 let mut input: Vec<Message> = req
60 .preamble
61 .as_ref()
62 .map_or_else(Vec::new, |p| vec![Message::system(p)]);
63
64 let mut additional_params_payload = req.additional_params.unwrap_or(Value::Null);
65
66 for msg in chat_history {
67 let msg: Vec<Message> = msg.try_into()?;
68 input.extend(msg);
69 }
70
71 let tool_choice = req.tool_choice.map(ToolChoice::try_from).transpose()?;
72 let mut additional_tools =
73 extract_tools_from_additional_params(&mut additional_params_payload)?;
74 let mut tools = req
75 .tools
76 .into_iter()
77 .map(ToolDefinition::from)
78 .map(serde_json::to_value)
79 .collect::<Result<Vec<_>, _>>()?;
80 tools.append(&mut additional_tools);
81 let additional_params = if additional_params_payload.is_null() {
82 None
83 } else {
84 Some(additional_params_payload)
85 };
86
87 Ok(Self {
88 model: model.to_string(),
89 input,
90 temperature: req.temperature,
91 max_output_tokens: req.max_tokens,
92 tools,
93 tool_choice,
94 additional_params,
95 })
96 }
97}
98
99fn extract_tools_from_additional_params(
100 additional_params: &mut Value,
101) -> Result<Vec<Value>, CompletionError> {
102 if let Some(map) = additional_params.as_object_mut()
103 && let Some(raw_tools) = map.remove("tools")
104 {
105 return serde_json::from_value::<Vec<Value>>(raw_tools).map_err(|err| {
106 CompletionError::RequestError(
107 format!("Invalid xAI `additional_params.tools` payload: {err}").into(),
108 )
109 });
110 }
111
112 Ok(Vec::new())
113}
114
/// Response body of a completion call, as parsed from the provider's JSON.
///
/// `id`, `model` and `output` must be present in the JSON; the remaining
/// fields fall back to a default value when absent.
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    /// Identifier of the response.
    pub id: String,
    /// Name of the model reported in the response.
    pub model: String,
    /// Output items; each is converted into assistant content by the `TryFrom` impl.
    pub output: Vec<Output>,
    /// `created` value from the payload; `0` when the field is absent.
    /// NOTE(review): presumably a Unix timestamp — confirm against the API docs.
    #[serde(default)]
    pub created: i64,
    /// `object` value from the payload; empty when the field is absent.
    #[serde(default)]
    pub object: String,
    /// `status` value from the payload, if any.
    #[serde(default)]
    pub status: Option<String>,
    /// Token usage reported with the response, if any.
    pub usage: Option<ResponsesUsage>,
}
132
133impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
134 type Error = CompletionError;
135
136 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
137 let content: Vec<completion::AssistantContent> = response
138 .output
139 .iter()
140 .cloned()
141 .flat_map(<Vec<completion::AssistantContent>>::from)
142 .collect();
143
144 let choice = OneOrMany::many(content).map_err(|_| {
145 CompletionError::ResponseError("Response contained no output".to_owned())
146 })?;
147
148 let usage = response
149 .usage
150 .as_ref()
151 .map(|u| completion::Usage {
152 input_tokens: u.input_tokens,
153 output_tokens: u.output_tokens,
154 total_tokens: u.total_tokens,
155 cached_input_tokens: u
156 .input_tokens_details
157 .clone()
158 .map(|x| x.cached_tokens)
159 .unwrap_or_default(),
160 cache_creation_input_tokens: 0,
161 tool_use_prompt_tokens: 0,
162 reasoning_tokens: 0,
163 })
164 .unwrap_or_default();
165
166 Ok(completion::CompletionResponse {
167 choice,
168 usage,
169 raw_response: response,
170 message_id: None,
171 })
172 }
173}
174
/// Completion model handle pairing a provider [`Client`] with a model name.
///
/// `T` is the underlying HTTP client type and defaults to `reqwest::Client`.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    /// Client used to build and send requests.
    pub(crate) client: Client<T>,
    /// Model name used for requests that do not name a model themselves.
    pub model: String,
}
184
185impl<T> CompletionModel<T> {
186 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
187 Self {
188 client,
189 model: model.into(),
190 }
191 }
192}
193
194impl<T> completion::CompletionModel for CompletionModel<T>
195where
196 T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
197{
198 type Response = CompletionResponse;
199 type StreamingResponse = StreamingCompletionResponse;
200
201 type Client = Client<T>;
202
203 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
204 Self::new(client.clone(), model)
205 }
206
207 async fn completion(
208 &self,
209 completion_request: completion::CompletionRequest,
210 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
211 let span = if tracing::Span::current().is_disabled() {
212 info_span!(
213 target: "rig::completions",
214 "chat",
215 gen_ai.operation.name = "chat",
216 gen_ai.provider.name = "xai",
217 gen_ai.request.model = self.model,
218 gen_ai.system_instructions = tracing::field::Empty,
219 gen_ai.response.id = tracing::field::Empty,
220 gen_ai.response.model = tracing::field::Empty,
221 gen_ai.usage.output_tokens = tracing::field::Empty,
222 gen_ai.usage.input_tokens = tracing::field::Empty,
223 gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
224 )
225 } else {
226 tracing::Span::current()
227 };
228
229 span.record("gen_ai.system_instructions", &completion_request.preamble);
230
231 let request =
232 XAICompletionRequest::try_from((self.model.to_string().as_ref(), completion_request))?;
233
234 if enabled!(Level::TRACE) {
235 tracing::trace!(target: "rig::completions",
236 "xAI completion request: {}",
237 serde_json::to_string_pretty(&request)?
238 );
239 }
240
241 let body = serde_json::to_vec(&request)?;
242 let req = self
243 .client
244 .post("/v1/responses")?
245 .body(body)
246 .map_err(|e| CompletionError::HttpError(e.into()))?;
247
248 async move {
249 let response = self.client.send::<_, Bytes>(req).await?;
250 let status = response.status();
251 let response_body = response.into_body().into_future().await?.to_vec();
252
253 if status.is_success() {
254 match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
255 ApiResponse::Ok(response) => {
256 if enabled!(Level::TRACE) {
257 tracing::trace!(target: "rig::completions",
258 "xAI completion response: {}",
259 serde_json::to_string_pretty(&response)?
260 );
261 }
262
263 response.try_into()
264 }
265 ApiResponse::Error(error) => {
266 Err(CompletionError::ProviderError(error.message()))
267 }
268 }
269 } else {
270 Err(CompletionError::ProviderError(
271 String::from_utf8_lossy(&response_body).to_string(),
272 ))
273 }
274 }
275 .instrument(span)
276 .await
277 }
278
279 async fn stream(
280 &self,
281 request: CompletionRequest,
282 ) -> Result<BaseStreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
283 self.stream(request).await
284 }
285}
286
#[cfg(test)]
mod tests {
    use super::XAICompletionRequest;
    use crate::OneOrMany;
    use crate::completion::request::Document;
    use crate::completion::{CompletionRequest, CompletionRequestBuilder, Message};
    use crate::test_utils::MockCompletionModel;

    /// Document fixture shared by the tests below.
    fn glarb_document() -> Document {
        Document {
            id: "doc_1".to_string(),
            text: "Definition of glarb-glarb: an ancient tool.".to_string(),
            additional_props: Default::default(),
        }
    }

    /// Converts `request` into the xAI payload and returns its serialized `input` array.
    fn serialized_input(request: CompletionRequest) -> Vec<serde_json::Value> {
        let xai_request = XAICompletionRequest::try_from(("grok-4-0709", request))
            .expect("request conversion should succeed");
        let serialized = serde_json::to_value(xai_request).expect("serialization should succeed");
        serialized["input"]
            .as_array()
            .expect("xAI request input should be an array")
            .clone()
    }

    /// Documents attached through the builder must show up in the serialized input.
    #[test]
    fn xai_request_includes_normalized_documents() {
        let request =
            CompletionRequestBuilder::new(MockCompletionModel::default(), "What is glarb-glarb?")
                .message(Message::system("Use the provided context."))
                .document(glarb_document())
                .build();

        let input = serialized_input(request);
        let mentions_document = input
            .iter()
            .any(|message| message.to_string().contains("glarb-glarb"));

        assert!(
            mentions_document,
            "normalized documents should be forwarded into xAI input"
        );
    }

    /// With a hand-built request, the document is placed once, right after the
    /// leading system message, and the remaining history keeps its order.
    #[test]
    fn xai_direct_request_keeps_documents_after_system_messages() {
        let history = vec![
            Message::system("System prompt"),
            Message::assistant("Earlier assistant turn"),
            Message::system("Mid-conversation instruction"),
            Message::user("What is glarb-glarb?"),
        ];
        let request = CompletionRequest {
            model: None,
            preamble: None,
            chat_history: OneOrMany::many(history).unwrap(),
            documents: vec![glarb_document()],
            tools: vec![],
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            output_schema: None,
        };

        let input = serialized_input(request);

        let expected_roles = ["system", "user", "assistant", "system", "user"];
        assert_eq!(input.len(), 5);
        for (message, role) in input.iter().zip(expected_roles) {
            assert_eq!(message["role"], role);
        }

        assert!(
            input[1].to_string().contains("<file id: doc_1>"),
            "document input should follow leading system input: {input:?}"
        );

        let document_occurrences = input
            .iter()
            .filter(|message| message.to_string().contains("<file id: doc_1>"))
            .count();
        assert_eq!(
            document_occurrences, 1,
            "document input should appear exactly once: {input:?}"
        );
    }
}