1use crate::{
7 completion::{self, CompletionError},
8 http_client::HttpClientExt,
9 providers::openai::Message,
10};
11
12use super::client::{Client, xai_api_types::ApiResponse};
13use crate::completion::CompletionRequest;
14use crate::providers::openai;
15use crate::streaming::StreamingCompletionResponse;
16use bytes::Bytes;
17use serde::{Deserialize, Serialize};
18use tracing::{Instrument, Level, enabled, info_span};
19use xai_api_types::{CompletionResponse, ToolDefinition};
20
// Model-name string constants.
// NOTE(review): presumably intended as the `model` argument of
// `CompletionModel::new` / `make` — confirm against callers.

/// Model identifier string `"grok-2-1212"`.
pub const GROK_2_1212: &str = "grok-2-1212";
/// Model identifier string `"grok-2-vision-1212"`.
pub const GROK_2_VISION_1212: &str = "grok-2-vision-1212";
/// Model identifier string `"grok-3"`.
pub const GROK_3: &str = "grok-3";
/// Model identifier string `"grok-3-fast"`.
pub const GROK_3_FAST: &str = "grok-3-fast";
/// Model identifier string `"grok-3-mini"`.
pub const GROK_3_MINI: &str = "grok-3-mini";
/// Model identifier string `"grok-3-mini-fast"`.
pub const GROK_3_MINI_FAST: &str = "grok-3-mini-fast";
/// Model identifier string `"grok-2-image-1212"`.
pub const GROK_2_IMAGE_1212: &str = "grok-2-image-1212";
/// Model identifier string `"grok-4-0709"` (the constant name omits the `-0709` suffix).
pub const GROK_4: &str = "grok-4-0709";
30
31#[derive(Debug, Serialize, Deserialize)]
32pub(super) struct XAICompletionRequest {
33 model: String,
34 pub messages: Vec<Message>,
35 #[serde(skip_serializing_if = "Option::is_none")]
36 temperature: Option<f64>,
37 #[serde(skip_serializing_if = "Vec::is_empty")]
38 tools: Vec<ToolDefinition>,
39 #[serde(skip_serializing_if = "Option::is_none")]
40 tool_choice: Option<crate::providers::openrouter::ToolChoice>,
41 #[serde(flatten, skip_serializing_if = "Option::is_none")]
42 pub additional_params: Option<serde_json::Value>,
43}
44
45impl TryFrom<(&str, CompletionRequest)> for XAICompletionRequest {
46 type Error = CompletionError;
47
48 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
49 let mut full_history: Vec<Message> = match &req.preamble {
50 Some(preamble) => vec![Message::system(preamble)],
51 None => vec![],
52 };
53 if let Some(docs) = req.normalized_documents() {
54 let docs: Vec<Message> = docs.try_into()?;
55 full_history.extend(docs);
56 }
57
58 let chat_history: Vec<Message> = req
59 .chat_history
60 .clone()
61 .into_iter()
62 .map(|message| message.try_into())
63 .collect::<Result<Vec<Vec<Message>>, _>>()?
64 .into_iter()
65 .flatten()
66 .collect();
67
68 full_history.extend(chat_history);
69
70 let tool_choice = req
71 .tool_choice
72 .clone()
73 .map(crate::providers::openrouter::ToolChoice::try_from)
74 .transpose()?;
75
76 Ok(Self {
77 model: model.to_string(),
78 messages: full_history,
79 temperature: req.temperature,
80 tools: req
81 .tools
82 .clone()
83 .into_iter()
84 .map(ToolDefinition::from)
85 .collect::<Vec<_>>(),
86 tool_choice,
87 additional_params: req.additional_params,
88 })
89 }
90}
91
/// A completion model handle: a provider [`Client`] paired with a model name.
///
/// `T` is the underlying HTTP client type and defaults to `reqwest::Client`.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    /// Client used to build and send requests.
    pub(crate) client: Client<T>,
    /// Model name placed in the `model` field of every request.
    pub model: String,
}
97
98impl<T> CompletionModel<T> {
99 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
100 Self {
101 client,
102 model: model.into(),
103 }
104 }
105}
106
107impl<T> completion::CompletionModel for CompletionModel<T>
108where
109 T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
110{
111 type Response = CompletionResponse;
112 type StreamingResponse = openai::StreamingCompletionResponse;
113
114 type Client = Client<T>;
115
116 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
117 Self::new(client.clone(), model)
118 }
119
120 async fn completion(
121 &self,
122 completion_request: completion::CompletionRequest,
123 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
124 let span = if tracing::Span::current().is_disabled() {
125 info_span!(
126 target: "rig::completions",
127 "chat",
128 gen_ai.operation.name = "chat",
129 gen_ai.provider.name = "xai",
130 gen_ai.request.model = self.model,
131 gen_ai.system_instructions = tracing::field::Empty,
132 gen_ai.response.id = tracing::field::Empty,
133 gen_ai.response.model = tracing::field::Empty,
134 gen_ai.usage.output_tokens = tracing::field::Empty,
135 gen_ai.usage.input_tokens = tracing::field::Empty,
136 )
137 } else {
138 tracing::Span::current()
139 };
140
141 span.record("gen_ai.system_instructions", &completion_request.preamble);
142
143 let request =
144 XAICompletionRequest::try_from((self.model.to_string().as_ref(), completion_request))?;
145
146 if enabled!(Level::TRACE) {
147 tracing::trace!(target: "rig::completions",
148 "xAI completion request: {}",
149 serde_json::to_string_pretty(&request)?
150 );
151 }
152
153 let body = serde_json::to_vec(&request)?;
154 let req = self
155 .client
156 .post("/v1/chat/completions")?
157 .body(body)
158 .map_err(|e| CompletionError::HttpError(e.into()))?;
159
160 async move {
161 let response = self.client.send::<_, Bytes>(req).await?;
162 let status = response.status();
163 let response_body = response.into_body().into_future().await?.to_vec();
164
165 if status.is_success() {
166 match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
167 ApiResponse::Ok(response) => {
168 if enabled!(Level::TRACE) {
169 tracing::trace!(target: "rig::completions",
170 "xAI completion response: {}",
171 serde_json::to_string_pretty(&response)?
172 );
173 }
174
175 response.try_into()
176 }
177 ApiResponse::Error(error) => {
178 Err(CompletionError::ProviderError(error.message()))
179 }
180 }
181 } else {
182 Err(CompletionError::ProviderError(
183 String::from_utf8_lossy(&response_body).to_string(),
184 ))
185 }
186 }
187 .instrument(span)
188 .await
189 }
190
191 async fn stream(
192 &self,
193 request: CompletionRequest,
194 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
195 CompletionModel::stream(self, request).await
196 }
197}
198
199pub mod xai_api_types {
200 use serde::{Deserialize, Serialize};
201
202 use crate::OneOrMany;
203 use crate::completion::{self, CompletionError};
204 use crate::providers::openai::{AssistantContent, Message};
205
206 impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
207 type Error = CompletionError;
208
209 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
210 let choice = response.choices.first().ok_or_else(|| {
211 CompletionError::ResponseError("Response contained no choices".to_owned())
212 })?;
213 let content = match &choice.message {
214 Message::Assistant {
215 content,
216 tool_calls,
217 ..
218 } => {
219 let mut content = content
220 .iter()
221 .map(|c| match c {
222 AssistantContent::Text { text } => {
223 completion::AssistantContent::text(text)
224 }
225 AssistantContent::Refusal { refusal } => {
226 completion::AssistantContent::text(refusal)
227 }
228 })
229 .collect::<Vec<_>>();
230
231 content.extend(
232 tool_calls
233 .iter()
234 .map(|call| {
235 completion::AssistantContent::tool_call(
236 &call.id,
237 &call.function.name,
238 call.function.arguments.clone(),
239 )
240 })
241 .collect::<Vec<_>>(),
242 );
243 Ok(content)
244 }
245 _ => Err(CompletionError::ResponseError(
246 "Response did not contain a valid message or tool call".into(),
247 )),
248 }?;
249
250 let choice = OneOrMany::many(content).map_err(|_| {
251 CompletionError::ResponseError(
252 "Response contained no message or tool call (empty)".to_owned(),
253 )
254 })?;
255
256 let usage = completion::Usage {
257 input_tokens: response.usage.prompt_tokens as u64,
258 output_tokens: response.usage.completion_tokens as u64,
259 total_tokens: response.usage.total_tokens as u64,
260 };
261
262 Ok(completion::CompletionResponse {
263 choice,
264 usage,
265 raw_response: response,
266 })
267 }
268 }
269
270 impl From<completion::ToolDefinition> for ToolDefinition {
271 fn from(tool: completion::ToolDefinition) -> Self {
272 Self {
273 r#type: "function".into(),
274 function: tool,
275 }
276 }
277 }
278
279 #[derive(Clone, Debug, Deserialize, Serialize)]
280 pub struct ToolDefinition {
281 pub r#type: String,
282 pub function: completion::ToolDefinition,
283 }
284
285 #[derive(Debug, Deserialize)]
286 pub struct Function {
287 pub name: String,
288 pub arguments: String,
289 }
290
291 #[derive(Debug, Deserialize, Serialize)]
292 pub struct CompletionResponse {
293 pub id: String,
294 pub model: String,
295 pub choices: Vec<Choice>,
296 pub created: i64,
297 pub object: String,
298 pub system_fingerprint: String,
299 pub usage: Usage,
300 }
301
302 #[derive(Debug, Deserialize, Serialize)]
303 pub struct Choice {
304 pub finish_reason: String,
305 pub index: i32,
306 pub message: Message,
307 }
308
309 #[derive(Debug, Deserialize, Serialize)]
310 pub struct Usage {
311 pub completion_tokens: i32,
312 pub prompt_tokens: i32,
313 pub total_tokens: i32,
314 }
315}