1use bytes::Bytes;
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8use tracing::{Instrument, Level, enabled, info_span};
9
10use super::api::{ApiResponse, Message, ToolDefinition};
11use super::client::Client;
12use crate::OneOrMany;
13use crate::completion::{self, CompletionError, CompletionRequest};
14use crate::http_client::HttpClientExt;
15use crate::providers::openai::completion::ToolChoice;
16use crate::providers::openai::responses_api::streaming::StreamingCompletionResponse;
17use crate::providers::openai::responses_api::{Output, ResponsesUsage};
18use crate::streaming::StreamingCompletionResponse as BaseStreamingCompletionResponse;
19
// Model identifier strings. Presumably these are the names the xAI API accepts
// in the `model` field of a request — confirm against xAI's published model list.

/// Identifier string `grok-2-1212`.
pub const GROK_2_1212: &str = "grok-2-1212";
/// Identifier string `grok-2-vision-1212`.
pub const GROK_2_VISION_1212: &str = "grok-2-vision-1212";
/// Identifier string `grok-3`.
pub const GROK_3: &str = "grok-3";
/// Identifier string `grok-3-fast`.
pub const GROK_3_FAST: &str = "grok-3-fast";
/// Identifier string `grok-3-mini`.
pub const GROK_3_MINI: &str = "grok-3-mini";
/// Identifier string `grok-3-mini-fast`.
pub const GROK_3_MINI_FAST: &str = "grok-3-mini-fast";
/// Identifier string `grok-2-image-1212`.
pub const GROK_2_IMAGE_1212: &str = "grok-2-image-1212";
/// Identifier string `grok-4-0709`.
pub const GROK_4: &str = "grok-4-0709";
29
30#[derive(Debug, Serialize, Deserialize)]
35pub(super) struct XAICompletionRequest {
36 model: String,
37 pub input: Vec<Message>,
38 #[serde(skip_serializing_if = "Option::is_none")]
39 temperature: Option<f64>,
40 #[serde(skip_serializing_if = "Option::is_none")]
41 max_output_tokens: Option<u64>,
42 #[serde(skip_serializing_if = "Vec::is_empty")]
43 tools: Vec<Value>,
44 #[serde(skip_serializing_if = "Option::is_none")]
45 tool_choice: Option<ToolChoice>,
46 #[serde(flatten, skip_serializing_if = "Option::is_none")]
47 pub additional_params: Option<serde_json::Value>,
48}
49
50impl TryFrom<(&str, CompletionRequest)> for XAICompletionRequest {
51 type Error = CompletionError;
52
53 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
54 if req.output_schema.is_some() {
55 tracing::warn!("Structured outputs currently not supported for xAI");
56 }
57 let model = req.model.clone().unwrap_or_else(|| model.to_string());
58 let mut input: Vec<Message> = req
59 .preamble
60 .as_ref()
61 .map_or_else(Vec::new, |p| vec![Message::system(p)]);
62
63 if let Some(docs) = req.normalized_documents() {
64 let docs: Vec<Message> = docs.try_into()?;
65 input.extend(docs);
66 }
67
68 let mut additional_params_payload = req.additional_params.unwrap_or(Value::Null);
69
70 for msg in req.chat_history {
71 let msg: Vec<Message> = msg.try_into()?;
72 input.extend(msg);
73 }
74
75 let tool_choice = req.tool_choice.map(ToolChoice::try_from).transpose()?;
76 let mut additional_tools =
77 extract_tools_from_additional_params(&mut additional_params_payload)?;
78 let mut tools = req
79 .tools
80 .into_iter()
81 .map(ToolDefinition::from)
82 .map(serde_json::to_value)
83 .collect::<Result<Vec<_>, _>>()?;
84 tools.append(&mut additional_tools);
85 let additional_params = if additional_params_payload.is_null() {
86 None
87 } else {
88 Some(additional_params_payload)
89 };
90
91 Ok(Self {
92 model: model.to_string(),
93 input,
94 temperature: req.temperature,
95 max_output_tokens: req.max_tokens,
96 tools,
97 tool_choice,
98 additional_params,
99 })
100 }
101}
102
103fn extract_tools_from_additional_params(
104 additional_params: &mut Value,
105) -> Result<Vec<Value>, CompletionError> {
106 if let Some(map) = additional_params.as_object_mut()
107 && let Some(raw_tools) = map.remove("tools")
108 {
109 return serde_json::from_value::<Vec<Value>>(raw_tools).map_err(|err| {
110 CompletionError::RequestError(
111 format!("Invalid xAI `additional_params.tools` payload: {err}").into(),
112 )
113 });
114 }
115
116 Ok(Vec::new())
117}
118
/// Successful response body parsed from the xAI `/v1/responses` endpoint.
///
/// `id`, `model` and `output` must be present in the JSON; the remaining
/// fields fall back to a default value when absent.
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    /// Identifier of this response.
    pub id: String,
    /// Model name reported in the response.
    pub model: String,
    /// Output items; each one is converted into assistant content when the
    /// generic completion response is built.
    pub output: Vec<Output>,
    /// Creation timestamp; `0` when the field is absent.
    /// NOTE(review): presumably Unix seconds — confirm against the xAI API docs.
    #[serde(default)]
    pub created: i64,
    /// Object type tag; empty string when the field is absent.
    #[serde(default)]
    pub object: String,
    /// Status string, if the response carries one.
    #[serde(default)]
    pub status: Option<String>,
    /// Token usage, if the response carries it.
    pub usage: Option<ResponsesUsage>,
}
136
137impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
138 type Error = CompletionError;
139
140 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
141 let content: Vec<completion::AssistantContent> = response
142 .output
143 .iter()
144 .cloned()
145 .flat_map(<Vec<completion::AssistantContent>>::from)
146 .collect();
147
148 let choice = OneOrMany::many(content).map_err(|_| {
149 CompletionError::ResponseError("Response contained no output".to_owned())
150 })?;
151
152 let usage = response
153 .usage
154 .as_ref()
155 .map(|u| completion::Usage {
156 input_tokens: u.input_tokens,
157 output_tokens: u.output_tokens,
158 total_tokens: u.total_tokens,
159 cached_input_tokens: u
160 .input_tokens_details
161 .clone()
162 .map(|x| x.cached_tokens)
163 .unwrap_or_default(),
164 cache_creation_input_tokens: 0,
165 reasoning_tokens: 0,
166 })
167 .unwrap_or_default();
168
169 Ok(completion::CompletionResponse {
170 choice,
171 usage,
172 raw_response: response,
173 message_id: None,
174 })
175 }
176}
177
/// Completion model handle pairing a provider [`Client`] with a model name.
///
/// `T` is the HTTP client type carried by [`Client`]; it defaults to
/// `reqwest::Client`.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    /// Client used to build and send requests.
    pub(crate) client: Client<T>,
    /// Model name used for requests that do not name a model themselves.
    pub model: String,
}
187
188impl<T> CompletionModel<T> {
189 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
190 Self {
191 client,
192 model: model.into(),
193 }
194 }
195}
196
197impl<T> completion::CompletionModel for CompletionModel<T>
198where
199 T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
200{
201 type Response = CompletionResponse;
202 type StreamingResponse = StreamingCompletionResponse;
203
204 type Client = Client<T>;
205
206 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
207 Self::new(client.clone(), model)
208 }
209
210 async fn completion(
211 &self,
212 completion_request: completion::CompletionRequest,
213 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
214 let span = if tracing::Span::current().is_disabled() {
215 info_span!(
216 target: "rig::completions",
217 "chat",
218 gen_ai.operation.name = "chat",
219 gen_ai.provider.name = "xai",
220 gen_ai.request.model = self.model,
221 gen_ai.system_instructions = tracing::field::Empty,
222 gen_ai.response.id = tracing::field::Empty,
223 gen_ai.response.model = tracing::field::Empty,
224 gen_ai.usage.output_tokens = tracing::field::Empty,
225 gen_ai.usage.input_tokens = tracing::field::Empty,
226 gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
227 )
228 } else {
229 tracing::Span::current()
230 };
231
232 span.record("gen_ai.system_instructions", &completion_request.preamble);
233
234 let request =
235 XAICompletionRequest::try_from((self.model.to_string().as_ref(), completion_request))?;
236
237 if enabled!(Level::TRACE) {
238 tracing::trace!(target: "rig::completions",
239 "xAI completion request: {}",
240 serde_json::to_string_pretty(&request)?
241 );
242 }
243
244 let body = serde_json::to_vec(&request)?;
245 let req = self
246 .client
247 .post("/v1/responses")?
248 .body(body)
249 .map_err(|e| CompletionError::HttpError(e.into()))?;
250
251 async move {
252 let response = self.client.send::<_, Bytes>(req).await?;
253 let status = response.status();
254 let response_body = response.into_body().into_future().await?.to_vec();
255
256 if status.is_success() {
257 match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
258 ApiResponse::Ok(response) => {
259 if enabled!(Level::TRACE) {
260 tracing::trace!(target: "rig::completions",
261 "xAI completion response: {}",
262 serde_json::to_string_pretty(&response)?
263 );
264 }
265
266 response.try_into()
267 }
268 ApiResponse::Error(error) => {
269 Err(CompletionError::ProviderError(error.message()))
270 }
271 }
272 } else {
273 Err(CompletionError::ProviderError(
274 String::from_utf8_lossy(&response_body).to_string(),
275 ))
276 }
277 }
278 .instrument(span)
279 .await
280 }
281
282 async fn stream(
283 &self,
284 request: CompletionRequest,
285 ) -> Result<BaseStreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
286 self.stream(request).await
287 }
288}
289
#[cfg(test)]
mod tests {
    use super::XAICompletionRequest;
    use crate::OneOrMany;
    use crate::completion::CompletionRequest;
    use crate::completion::request::Document;

    /// Documents attached to a request must appear in the serialized `input`.
    #[test]
    fn xai_request_includes_normalized_documents() {
        let request = CompletionRequest {
            model: None,
            preamble: Some("Use the provided context.".to_string()),
            chat_history: OneOrMany::one("What is glarb-glarb?".into()),
            documents: vec![Document {
                id: "doc_1".to_string(),
                text: "Definition of glarb-glarb: an ancient tool.".to_string(),
                additional_props: Default::default(),
            }],
            tools: vec![],
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            output_schema: None,
        };

        let xai_request = XAICompletionRequest::try_from(("grok-4-0709", request))
            .expect("request conversion should succeed");
        let serialized = serde_json::to_value(xai_request).expect("serialization should succeed");
        let input = serialized["input"]
            .as_array()
            .expect("xAI request input should be an array");

        // Match on text that occurs only in the document. The chat message
        // itself contains "glarb-glarb", so searching for that would pass even
        // if the documents were dropped.
        let document_forwarded = input
            .iter()
            .any(|message| message.to_string().contains("an ancient tool"));

        assert!(
            document_forwarded,
            "normalized documents should be forwarded into xAI input"
        );
    }
}