use bytes::Bytes;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tracing::{Instrument, Level, enabled, info_span};

use super::api::{ApiResponse, Message, ToolDefinition};
use super::client::Client;
use crate::OneOrMany;
use crate::completion::{self, CompletionError, CompletionRequest};
use crate::http_client::HttpClientExt;
use crate::providers::openai::completion::ToolChoice;
use crate::providers::openai::responses_api::streaming::StreamingCompletionResponse;
use crate::providers::openai::responses_api::{Output, ResponsesUsage};
use crate::streaming::StreamingCompletionResponse as BaseStreamingCompletionResponse;

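// Model identifiers accepted by the xAI API; `GROK_4` points at the dated
// `grok-4-0709` release string.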
pub const GROK_2_1212: &str = "grok-2-1212";
pub const GROK_2_VISION_1212: &str = "grok-2-vision-1212";
pub const GROK_3: &str = "grok-3";
pub const GROK_3_FAST: &str = "grok-3-fast";
pub const GROK_3_MINI: &str = "grok-3-mini";
pub const GROK_3_MINI_FAST: &str = "grok-3-mini-fast";
pub const GROK_2_IMAGE_1212: &str = "grok-2-image-1212";
pub const GROK_4: &str = "grok-4-0709";

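/// Request body for xAI's Responses API (`POST /v1/responses`).
///
/// Built from a generic [`CompletionRequest`] by the `TryFrom` impl below.
/// `additional_params` is flattened into the top-level JSON object, letting
/// callers pass provider-specific fields through unchanged.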
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct XAICompletionRequest {
    model: String,
    pub input: Vec<Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_output_tokens: Option<u64>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<ToolChoice>,
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<serde_json::Value>,
}

impl TryFrom<(&str, CompletionRequest)> for XAICompletionRequest {
    type Error = CompletionError;

    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
        if req.output_schema.is_some() {
            tracing::warn!("Structured outputs are not currently supported for xAI; `output_schema` will be ignored");
        }
        // A model set on the request overrides the model this instance was built with.
        let model = req.model.clone().unwrap_or_else(|| model.to_string());
        let mut additional_params_payload = req.additional_params.unwrap_or(Value::Null);
        // The system preamble (if any) leads the input, followed by the chat history.
        let mut input: Vec<Message> = req
            .preamble
            .as_ref()
            .map_or_else(Vec::new, |p| vec![Message::system(p)]);

        for msg in req.chat_history {
            let msg: Vec<Message> = msg.try_into()?;
            input.extend(msg);
        }

        let tool_choice = req.tool_choice.map(ToolChoice::try_from).transpose()?;
        // Tools passed in through `additional_params` are merged with the
        // request's own tool definitions rather than replacing them.
        let mut additional_tools =
            extract_tools_from_additional_params(&mut additional_params_payload)?;
        let mut tools = req
            .tools
            .into_iter()
            .map(ToolDefinition::from)
            .map(serde_json::to_value)
            .collect::<Result<Vec<_>, _>>()?;
        tools.append(&mut additional_tools);
        // Flatten any remaining extra params into the body; `None` when the
        // caller supplied none.
        let additional_params = if additional_params_payload.is_null() {
            None
        } else {
            Some(additional_params_payload)
        };
        Ok(Self {
            model,
            input,
            temperature: req.temperature,
            max_output_tokens: req.max_tokens,
            tools,
            tool_choice,
            additional_params,
        })
    }
}

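/// Removes the `tools` array, if any, from `additional_params` so it can be
/// merged with the tool definitions on the request itself. Returns an empty
/// `Vec` when the key is absent and an error when its value is not a valid
/// JSON array.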
fn extract_tools_from_additional_params(
    additional_params: &mut Value,
) -> Result<Vec<Value>, CompletionError> {
    if let Some(map) = additional_params.as_object_mut()
        && let Some(raw_tools) = map.remove("tools")
    {
        return serde_json::from_value::<Vec<Value>>(raw_tools).map_err(|err| {
            CompletionError::RequestError(
                format!("Invalid xAI `additional_params.tools` payload: {err}").into(),
            )
        });
    }

    Ok(Vec::new())
}

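/// Raw response payload from the Responses API: `output` carries the
/// assistant items and `usage` the token accounting, when present.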
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    pub id: String,
    pub model: String,
    pub output: Vec<Output>,
    #[serde(default)]
    pub created: i64,
    #[serde(default)]
    pub object: String,
    #[serde(default)]
    pub status: Option<String>,
    pub usage: Option<ResponsesUsage>,
}

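/// Normalizes the provider payload into Rig's generic response type. Fails
/// with a `ResponseError` when the output contains no assistant content.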
impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
    type Error = CompletionError;

    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
        let content: Vec<completion::AssistantContent> = response
            .output
            .iter()
            .cloned()
            .flat_map(<Vec<completion::AssistantContent>>::from)
            .collect();

        let choice = OneOrMany::many(content).map_err(|_| {
            CompletionError::ResponseError("Response contained no output".to_owned())
        })?;

        let usage = response
            .usage
            .as_ref()
            .map(|u| completion::Usage {
                input_tokens: u.input_tokens,
                output_tokens: u.output_tokens,
                total_tokens: u.total_tokens,
                cached_input_tokens: u
                    .input_tokens_details
                    .clone()
                    .map(|x| x.cached_tokens)
                    .unwrap_or_default(),
            })
            .unwrap_or_default();

        Ok(completion::CompletionResponse {
            choice,
            usage,
            raw_response: response,
            message_id: None,
        })
    }
}

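/// An xAI completion model bound to a [`Client`]. `T` is the underlying HTTP
/// client type and defaults to [`reqwest::Client`].
///
/// A minimal construction sketch, assuming an already-configured client:
///
/// ```ignore
/// let model = CompletionModel::new(client, GROK_4);
/// ```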
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    pub(crate) client: Client<T>,
    pub model: String,
}

impl<T> CompletionModel<T> {
    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
        Self {
            client,
            model: model.into(),
        }
    }
}

impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
{
    type Response = CompletionResponse;
    type StreamingResponse = StreamingCompletionResponse;

    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model)
    }

    async fn completion(
        &self,
        completion_request: completion::CompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
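        // Reuse the caller's active span when there is one; otherwise open a
        // fresh `chat` span pre-declaring the GenAI semantic-convention
        // fields so they can be recorded later.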
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "xai",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cached_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        span.record("gen_ai.system_instructions", &completion_request.preamble);

        let request =
            XAICompletionRequest::try_from((self.model.as_str(), completion_request))?;

        if enabled!(Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "xAI completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;
        let req = self
            .client
            .post("/v1/responses")?
            .body(body)
            .map_err(|e| CompletionError::HttpError(e.into()))?;

        async move {
            let response = self.client.send::<_, Bytes>(req).await?;
            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if status.is_success() {
                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
                    ApiResponse::Ok(response) => {
                        if enabled!(Level::TRACE) {
                            tracing::trace!(target: "rig::completions",
                                "xAI completion response: {}",
                                serde_json::to_string_pretty(&response)?
                            );
                        }

                        response.try_into()
                    }
                    ApiResponse::Error(error) => {
                        Err(CompletionError::ProviderError(error.message()))
                    }
                }
            } else {
                Err(CompletionError::ProviderError(
                    String::from_utf8_lossy(&response_body).to_string(),
                ))
            }
        }
        .instrument(span)
        .await
    }

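    // Delegates to the inherent `stream` method (implemented alongside
    // `StreamingCompletionResponse` in the streaming module). Inherent methods
    // take precedence in method resolution, so this is not a recursive call.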
    async fn stream(
        &self,
        request: CompletionRequest,
    ) -> Result<BaseStreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
        self.stream(request).await
    }
}