//! rig/providers/xai/completion.rs
use bytes::Bytes;
6use serde::{Deserialize, Serialize};
7use tracing::{Instrument, Level, enabled, info_span};
8
9use super::api::{ApiResponse, Message, ToolDefinition};
10use super::client::Client;
11use crate::OneOrMany;
12use crate::completion::{self, CompletionError, CompletionRequest};
13use crate::http_client::HttpClientExt;
14use crate::providers::openai::completion::ToolChoice;
15use crate::providers::openai::responses_api::streaming::StreamingCompletionResponse;
16use crate::providers::openai::responses_api::{Output, ResponsesUsage};
17use crate::streaming::StreamingCompletionResponse as BaseStreamingCompletionResponse;
18
// Model identifiers accepted by the xAI API (values are the wire-format names).
/// `grok-2-1212` text model.
pub const GROK_2_1212: &str = "grok-2-1212";
/// `grok-2-vision-1212` — name indicates vision support; confirm against xAI docs.
pub const GROK_2_VISION_1212: &str = "grok-2-vision-1212";
/// `grok-3` text model.
pub const GROK_3: &str = "grok-3";
/// `grok-3-fast` — lower-latency Grok 3 variant.
pub const GROK_3_FAST: &str = "grok-3-fast";
/// `grok-3-mini` — smaller Grok 3 variant.
pub const GROK_3_MINI: &str = "grok-3-mini";
/// `grok-3-mini-fast` — smaller, lower-latency Grok 3 variant.
pub const GROK_3_MINI_FAST: &str = "grok-3-mini-fast";
/// `grok-2-image-1212` — name indicates image generation; confirm against xAI docs.
pub const GROK_2_IMAGE_1212: &str = "grok-2-image-1212";
/// Grok 4; note the constant maps to the dated snapshot `grok-4-0709`.
pub const GROK_4: &str = "grok-4-0709";
28
/// Request body serialized to xAI's `/v1/responses` endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct XAICompletionRequest {
    // Target model name (e.g. [`GROK_4`]).
    model: String,
    // Conversation input: optional system message followed by the converted chat history.
    pub input: Vec<Message>,
    // Sampling temperature; omitted from the JSON payload when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    // Maximum tokens to generate; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    max_output_tokens: Option<u64>,
    // Tool definitions; the field is omitted entirely when empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    // Tool-choice strategy (reuses the OpenAI-compatible type); omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<ToolChoice>,
    // Extra provider-specific fields, flattened into the top-level JSON object.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<serde_json::Value>,
}
48
49impl TryFrom<(&str, CompletionRequest)> for XAICompletionRequest {
50 type Error = CompletionError;
51
52 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
53 if req.output_schema.is_some() {
54 tracing::warn!("Structured outputs currently not supported for xAI");
55 }
56 let model = req.model.clone().unwrap_or_else(|| model.to_string());
57 let mut input: Vec<Message> = req
58 .preamble
59 .as_ref()
60 .map_or_else(Vec::new, |p| vec![Message::system(p)]);
61
62 for msg in req.chat_history {
63 let msg: Vec<Message> = msg.try_into()?;
64 input.extend(msg);
65 }
66
67 let tool_choice = req.tool_choice.map(ToolChoice::try_from).transpose()?;
68 let tools = req.tools.into_iter().map(ToolDefinition::from).collect();
69
70 Ok(Self {
71 model: model.to_string(),
72 input,
73 temperature: req.temperature,
74 max_output_tokens: req.max_tokens,
75 tools,
76 tool_choice,
77 additional_params: req.additional_params,
78 })
79 }
80}
81
/// Raw response payload returned by xAI's `/v1/responses` endpoint.
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    // Provider-assigned response identifier.
    pub id: String,
    // Model that actually produced the response.
    pub model: String,
    // Generated output items (flattened into assistant content on conversion).
    pub output: Vec<Output>,
    // Creation timestamp; defaults to 0 when absent. Presumably Unix seconds — confirm.
    #[serde(default)]
    pub created: i64,
    // Object type tag from the API; defaults to "" when absent.
    #[serde(default)]
    pub object: String,
    // Completion status reported by the provider, if any.
    #[serde(default)]
    pub status: Option<String>,
    // Token accounting; absent responses yield a default (zeroed) usage.
    pub usage: Option<ResponsesUsage>,
}
99
100impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
101 type Error = CompletionError;
102
103 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
104 let content: Vec<completion::AssistantContent> = response
105 .output
106 .iter()
107 .cloned()
108 .flat_map(<Vec<completion::AssistantContent>>::from)
109 .collect();
110
111 let choice = OneOrMany::many(content).map_err(|_| {
112 CompletionError::ResponseError("Response contained no output".to_owned())
113 })?;
114
115 let usage = response
116 .usage
117 .as_ref()
118 .map(|u| completion::Usage {
119 input_tokens: u.input_tokens,
120 output_tokens: u.output_tokens,
121 total_tokens: u.total_tokens,
122 cached_input_tokens: u
123 .input_tokens_details
124 .clone()
125 .map(|x| x.cached_tokens)
126 .unwrap_or_default(),
127 })
128 .unwrap_or_default();
129
130 Ok(completion::CompletionResponse {
131 choice,
132 usage,
133 raw_response: response,
134 message_id: None,
135 })
136 }
137}
138
/// Handle for running completions against a specific xAI model.
///
/// `T` is the underlying HTTP client implementation (defaults to `reqwest::Client`).
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    // Shared xAI API client used to send requests.
    pub(crate) client: Client<T>,
    // Model identifier sent with each request (e.g. [`GROK_4`]).
    pub model: String,
}
148
149impl<T> CompletionModel<T> {
150 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
151 Self {
152 client,
153 model: model.into(),
154 }
155 }
156}
157
impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
{
    type Response = CompletionResponse;
    type StreamingResponse = StreamingCompletionResponse;

    type Client = Client<T>;

    // Factory used by the generic client plumbing to build a model handle.
    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model)
    }

    /// Sends a single (non-streaming) completion request to xAI's
    /// `/v1/responses` endpoint and converts the reply into rig's
    /// provider-agnostic response type.
    ///
    /// Emits a `gen_ai`-convention tracing span (created here only when the
    /// caller has not already opened one) and records the system
    /// instructions on it before dispatching.
    ///
    /// # Errors
    /// Returns [`CompletionError`] when the request cannot be built or
    /// serialized, the HTTP call fails, the provider returns a non-success
    /// status or an API-level error, or the response body cannot be parsed.
    async fn completion(
        &self,
        completion_request: completion::CompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        // Reuse the caller's active span if there is one; otherwise open a
        // fresh span so telemetry is emitted either way.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "xai",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        span.record("gen_ai.system_instructions", &completion_request.preamble);

        let request =
            XAICompletionRequest::try_from((self.model.to_string().as_ref(), completion_request))?;

        // Guard so the pretty-printed payload is only built when TRACE is on.
        if enabled!(Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "xAI completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;
        let req = self
            .client
            .post("/v1/responses")?
            .body(body)
            .map_err(|e| CompletionError::HttpError(e.into()))?;

        // The network round-trip and response handling run inside the span.
        async move {
            let response = self.client.send::<_, Bytes>(req).await?;
            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if status.is_success() {
                // A 2xx status can still carry an API-level error object.
                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
                    ApiResponse::Ok(response) => {
                        if enabled!(Level::TRACE) {
                            tracing::trace!(target: "rig::completions",
                                "xAI completion response: {}",
                                serde_json::to_string_pretty(&response)?
                            );
                        }

                        response.try_into()
                    }
                    ApiResponse::Error(error) => {
                        Err(CompletionError::ProviderError(error.message()))
                    }
                }
            } else {
                // Non-success status: surface the raw body as the provider error.
                Err(CompletionError::ProviderError(
                    String::from_utf8_lossy(&response_body).to_string(),
                ))
            }
        }
        .instrument(span)
        .await
    }

    /// Streams a completion.
    ///
    /// NOTE(review): this delegates to an inherent `stream` method that is
    /// presumably defined in a sibling streaming module (inherent methods
    /// take precedence over this trait method during resolution). If no such
    /// inherent method exists, this call would recurse infinitely — confirm.
    async fn stream(
        &self,
        request: CompletionRequest,
    ) -> Result<BaseStreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
        self.stream(request).await
    }
}