1use crate::{
7 completion::{self, CompletionError},
8 http_client::HttpClientExt,
9 providers::openai::Message,
10};
11
12use super::client::{Client, xai_api_types::ApiResponse};
13use crate::completion::CompletionRequest;
14use crate::providers::openai;
15use crate::streaming::StreamingCompletionResponse;
16use bytes::Bytes;
17use serde::{Deserialize, Serialize};
18use tracing::{Instrument, Level, enabled, info_span};
19use xai_api_types::{CompletionResponse, ToolDefinition};
20
/// `grok-2-1212` — Grok 2 text model.
pub const GROK_2_1212: &str = "grok-2-1212";
/// `grok-2-vision-1212` — Grok 2 vision-capable model.
pub const GROK_2_VISION_1212: &str = "grok-2-vision-1212";
/// `grok-3` — Grok 3 flagship model.
pub const GROK_3: &str = "grok-3";
/// `grok-3-fast` — lower-latency Grok 3 variant.
pub const GROK_3_FAST: &str = "grok-3-fast";
/// `grok-3-mini` — smaller Grok 3 variant.
pub const GROK_3_MINI: &str = "grok-3-mini";
/// `grok-3-mini-fast` — smaller, lower-latency Grok 3 variant.
pub const GROK_3_MINI_FAST: &str = "grok-3-mini-fast";
/// `grok-2-image-1212` — Grok 2 image-generation model.
pub const GROK_2_IMAGE_1212: &str = "grok-2-image-1212";
/// `grok-4-0709` — Grok 4 model.
pub const GROK_4: &str = "grok-4-0709";
30
31#[derive(Debug, Serialize, Deserialize)]
32pub(super) struct XAICompletionRequest {
33 model: String,
34 pub messages: Vec<Message>,
35 #[serde(flatten, skip_serializing_if = "Option::is_none")]
36 temperature: Option<f64>,
37 #[serde(skip_serializing_if = "Vec::is_empty")]
38 tools: Vec<ToolDefinition>,
39 #[serde(flatten, skip_serializing_if = "Option::is_none")]
40 tool_choice: Option<crate::providers::openrouter::ToolChoice>,
41 #[serde(flatten, skip_serializing_if = "Option::is_none")]
42 pub additional_params: Option<serde_json::Value>,
43}
44
45impl TryFrom<(&str, CompletionRequest)> for XAICompletionRequest {
46 type Error = CompletionError;
47
48 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
49 let mut full_history: Vec<Message> = match &req.preamble {
50 Some(preamble) => vec![Message::system(preamble)],
51 None => vec![],
52 };
53 if let Some(docs) = req.normalized_documents() {
54 let docs: Vec<Message> = docs.try_into()?;
55 full_history.extend(docs);
56 }
57
58 let chat_history: Vec<Message> = req
59 .chat_history
60 .clone()
61 .into_iter()
62 .map(|message| message.try_into())
63 .collect::<Result<Vec<Vec<Message>>, _>>()?
64 .into_iter()
65 .flatten()
66 .collect();
67
68 full_history.extend(chat_history);
69
70 let tool_choice = req
71 .tool_choice
72 .clone()
73 .map(crate::providers::openrouter::ToolChoice::try_from)
74 .transpose()?;
75
76 Ok(Self {
77 model: model.to_string(),
78 messages: full_history,
79 temperature: req.temperature,
80 tools: req
81 .tools
82 .clone()
83 .into_iter()
84 .map(ToolDefinition::from)
85 .collect::<Vec<_>>(),
86 tool_choice,
87 additional_params: req.additional_params,
88 })
89 }
90}
91
/// An xAI completion model: an API [`Client`] paired with a model name.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    // Shared xAI API client used to issue HTTP requests.
    pub(crate) client: Client<T>,
    /// Model identifier sent with each request (e.g. [`GROK_3`]).
    pub model: String,
}
97
98impl<T> CompletionModel<T> {
99 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
100 Self {
101 client,
102 model: model.into(),
103 }
104 }
105}
106
107impl<T> completion::CompletionModel for CompletionModel<T>
108where
109 T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
110{
111 type Response = CompletionResponse;
112 type StreamingResponse = openai::StreamingCompletionResponse;
113
114 type Client = Client<T>;
115
116 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
117 Self::new(client.clone(), model)
118 }
119
120 #[cfg_attr(feature = "worker", worker::send)]
121 async fn completion(
122 &self,
123 completion_request: completion::CompletionRequest,
124 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
125 let span = if tracing::Span::current().is_disabled() {
126 info_span!(
127 target: "rig::completions",
128 "chat",
129 gen_ai.operation.name = "chat",
130 gen_ai.provider.name = "xai",
131 gen_ai.request.model = self.model,
132 gen_ai.system_instructions = tracing::field::Empty,
133 gen_ai.response.id = tracing::field::Empty,
134 gen_ai.response.model = tracing::field::Empty,
135 gen_ai.usage.output_tokens = tracing::field::Empty,
136 gen_ai.usage.input_tokens = tracing::field::Empty,
137 )
138 } else {
139 tracing::Span::current()
140 };
141
142 span.record("gen_ai.system_instructions", &completion_request.preamble);
143
144 let request =
145 XAICompletionRequest::try_from((self.model.to_string().as_ref(), completion_request))?;
146
147 if enabled!(Level::TRACE) {
148 tracing::trace!(target: "rig::completions",
149 "xAI completion request: {}",
150 serde_json::to_string_pretty(&request)?
151 );
152 }
153
154 let body = serde_json::to_vec(&request)?;
155 let req = self
156 .client
157 .post("/v1/chat/completions")?
158 .body(body)
159 .map_err(|e| CompletionError::HttpError(e.into()))?;
160
161 async move {
162 let response = self.client.send::<_, Bytes>(req).await?;
163 let status = response.status();
164 let response_body = response.into_body().into_future().await?.to_vec();
165
166 if status.is_success() {
167 match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
168 ApiResponse::Ok(response) => {
169 if enabled!(Level::TRACE) {
170 tracing::trace!(target: "rig::completions",
171 "xAI completion response: {}",
172 serde_json::to_string_pretty(&response)?
173 );
174 }
175
176 response.try_into()
177 }
178 ApiResponse::Error(error) => {
179 Err(CompletionError::ProviderError(error.message()))
180 }
181 }
182 } else {
183 Err(CompletionError::ProviderError(
184 String::from_utf8_lossy(&response_body).to_string(),
185 ))
186 }
187 }
188 .instrument(span)
189 .await
190 }
191
192 #[cfg_attr(feature = "worker", worker::send)]
193 async fn stream(
194 &self,
195 request: CompletionRequest,
196 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
197 CompletionModel::stream(self, request).await
198 }
199}
200
201pub mod xai_api_types {
202 use serde::{Deserialize, Serialize};
203
204 use crate::OneOrMany;
205 use crate::completion::{self, CompletionError};
206 use crate::providers::openai::{AssistantContent, Message};
207
208 impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
209 type Error = CompletionError;
210
211 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
212 let choice = response.choices.first().ok_or_else(|| {
213 CompletionError::ResponseError("Response contained no choices".to_owned())
214 })?;
215 let content = match &choice.message {
216 Message::Assistant {
217 content,
218 tool_calls,
219 ..
220 } => {
221 let mut content = content
222 .iter()
223 .map(|c| match c {
224 AssistantContent::Text { text } => {
225 completion::AssistantContent::text(text)
226 }
227 AssistantContent::Refusal { refusal } => {
228 completion::AssistantContent::text(refusal)
229 }
230 })
231 .collect::<Vec<_>>();
232
233 content.extend(
234 tool_calls
235 .iter()
236 .map(|call| {
237 completion::AssistantContent::tool_call(
238 &call.id,
239 &call.function.name,
240 call.function.arguments.clone(),
241 )
242 })
243 .collect::<Vec<_>>(),
244 );
245 Ok(content)
246 }
247 _ => Err(CompletionError::ResponseError(
248 "Response did not contain a valid message or tool call".into(),
249 )),
250 }?;
251
252 let choice = OneOrMany::many(content).map_err(|_| {
253 CompletionError::ResponseError(
254 "Response contained no message or tool call (empty)".to_owned(),
255 )
256 })?;
257
258 let usage = completion::Usage {
259 input_tokens: response.usage.prompt_tokens as u64,
260 output_tokens: response.usage.completion_tokens as u64,
261 total_tokens: response.usage.total_tokens as u64,
262 };
263
264 Ok(completion::CompletionResponse {
265 choice,
266 usage,
267 raw_response: response,
268 })
269 }
270 }
271
272 impl From<completion::ToolDefinition> for ToolDefinition {
273 fn from(tool: completion::ToolDefinition) -> Self {
274 Self {
275 r#type: "function".into(),
276 function: tool,
277 }
278 }
279 }
280
    /// A tool definition in xAI's (OpenAI-compatible) wire format: a type
    /// tag wrapping the generic function specification.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    pub struct ToolDefinition {
        /// Tool kind; always `"function"` as produced by the `From` impl.
        pub r#type: String,
        /// The underlying function/tool specification.
        pub function: completion::ToolDefinition,
    }
286
    /// A function-call payload: the function name plus its arguments as a
    /// raw JSON string.
    ///
    /// NOTE(review): not referenced by the code visible in this file —
    /// presumably kept for deserializing tool-call responses elsewhere;
    /// confirm before removing.
    #[derive(Debug, Deserialize)]
    pub struct Function {
        /// Name of the function/tool being invoked.
        pub name: String,
        /// JSON-encoded arguments for the call.
        pub arguments: String,
    }
292
    /// Top-level success response from xAI's chat completions endpoint.
    #[derive(Debug, Deserialize, Serialize)]
    pub struct CompletionResponse {
        /// Unique identifier of this completion.
        pub id: String,
        /// Model that produced the response.
        pub model: String,
        /// Candidate completions; only the first is consumed downstream.
        pub choices: Vec<Choice>,
        /// Creation time — presumably a Unix timestamp; confirm against xAI docs.
        pub created: i64,
        /// Object type marker returned by the API.
        pub object: String,
        /// Backend configuration fingerprint reported by the API.
        pub system_fingerprint: String,
        /// Token accounting for the request/response pair.
        pub usage: Usage,
    }
303
    /// One candidate completion within a [`CompletionResponse`].
    #[derive(Debug, Deserialize, Serialize)]
    pub struct Choice {
        /// Why generation stopped (e.g. natural stop vs. length limit).
        pub finish_reason: String,
        /// Position of this choice in the returned list.
        pub index: i32,
        /// The generated message (reuses the OpenAI-compatible type).
        pub message: Message,
    }
310
    /// Token usage statistics reported by the API.
    // Converted to `completion::Usage` (u64 fields) via `as u64` casts in the
    // `TryFrom` impl above.
    #[derive(Debug, Deserialize, Serialize)]
    pub struct Usage {
        /// Tokens generated in the completion.
        pub completion_tokens: i32,
        /// Tokens consumed by the prompt/input.
        pub prompt_tokens: i32,
        /// Sum of prompt and completion tokens.
        pub total_tokens: i32,
    }
317}