1use crate::{
7 completion::{self, CompletionError},
8 http_client::HttpClientExt,
9 providers::openai::Message,
10};
11
12use super::client::{Client, xai_api_types::ApiResponse};
13use crate::completion::CompletionRequest;
14use crate::providers::openai;
15use crate::streaming::StreamingCompletionResponse;
16use bytes::Bytes;
17use serde::{Deserialize, Serialize};
18use tracing::{Instrument, info_span};
19use xai_api_types::{CompletionResponse, ToolDefinition};
20
/// Model identifier for xAI's `grok-2-1212`.
pub const GROK_2_1212: &str = "grok-2-1212";
/// Model identifier for xAI's `grok-2-vision-1212` (vision-capable, per the model name — confirm against xAI docs).
pub const GROK_2_VISION_1212: &str = "grok-2-vision-1212";
/// Model identifier for xAI's `grok-3`.
pub const GROK_3: &str = "grok-3";
/// Model identifier for xAI's `grok-3-fast`.
pub const GROK_3_FAST: &str = "grok-3-fast";
/// Model identifier for xAI's `grok-3-mini`.
pub const GROK_3_MINI: &str = "grok-3-mini";
/// Model identifier for xAI's `grok-3-mini-fast`.
pub const GROK_3_MINI_FAST: &str = "grok-3-mini-fast";
/// Model identifier for xAI's `grok-2-image-1212` (image model, per the model name — confirm against xAI docs).
pub const GROK_2_IMAGE_1212: &str = "grok-2-image-1212";
/// Model identifier for xAI's `grok-4-0709`.
pub const GROK_4: &str = "grok-4-0709";
30
31#[derive(Debug, Serialize, Deserialize)]
32pub(super) struct XAICompletionRequest {
33 model: String,
34 pub messages: Vec<Message>,
35 #[serde(flatten, skip_serializing_if = "Option::is_none")]
36 temperature: Option<f64>,
37 #[serde(skip_serializing_if = "Vec::is_empty")]
38 tools: Vec<ToolDefinition>,
39 #[serde(flatten, skip_serializing_if = "Option::is_none")]
40 tool_choice: Option<crate::providers::openrouter::ToolChoice>,
41 #[serde(flatten, skip_serializing_if = "Option::is_none")]
42 pub additional_params: Option<serde_json::Value>,
43}
44
45impl TryFrom<(&str, CompletionRequest)> for XAICompletionRequest {
46 type Error = CompletionError;
47
48 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
49 let mut full_history: Vec<Message> = match &req.preamble {
50 Some(preamble) => vec![Message::system(preamble)],
51 None => vec![],
52 };
53 if let Some(docs) = req.normalized_documents() {
54 let docs: Vec<Message> = docs.try_into()?;
55 full_history.extend(docs);
56 }
57
58 let chat_history: Vec<Message> = req
59 .chat_history
60 .clone()
61 .into_iter()
62 .map(|message| message.try_into())
63 .collect::<Result<Vec<Vec<Message>>, _>>()?
64 .into_iter()
65 .flatten()
66 .collect();
67
68 full_history.extend(chat_history);
69
70 let tool_choice = req
71 .tool_choice
72 .clone()
73 .map(crate::providers::openrouter::ToolChoice::try_from)
74 .transpose()?;
75
76 Ok(Self {
77 model: model.to_string(),
78 messages: full_history,
79 temperature: req.temperature,
80 tools: req
81 .tools
82 .clone()
83 .into_iter()
84 .map(ToolDefinition::from)
85 .collect::<Vec<_>>(),
86 tool_choice,
87 additional_params: req.additional_params,
88 })
89 }
90}
91
/// An xAI completion model, generic over the HTTP client implementation
/// (defaults to `reqwest::Client`).
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    // Shared provider client used to issue requests.
    pub(crate) client: Client<T>,
    // Model identifier, e.g. "grok-3".
    pub model: String,
}
97
98impl<T> CompletionModel<T> {
99 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
100 Self {
101 client,
102 model: model.into(),
103 }
104 }
105}
106
107impl<T> completion::CompletionModel for CompletionModel<T>
108where
109 T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
110{
111 type Response = CompletionResponse;
112 type StreamingResponse = openai::StreamingCompletionResponse;
113
114 type Client = Client<T>;
115
116 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
117 Self::new(client.clone(), model)
118 }
119
120 #[cfg_attr(feature = "worker", worker::send)]
121 async fn completion(
122 &self,
123 completion_request: completion::CompletionRequest,
124 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
125 let preamble = completion_request.preamble.clone();
126 let request =
127 XAICompletionRequest::try_from((self.model.to_string().as_ref(), completion_request))?;
128 let request_messages_json_str = serde_json::to_string(&request.messages).unwrap();
129
130 let span = if tracing::Span::current().is_disabled() {
131 info_span!(
132 target: "rig::completions",
133 "chat",
134 gen_ai.operation.name = "chat",
135 gen_ai.provider.name = "xai",
136 gen_ai.request.model = self.model,
137 gen_ai.system_instructions = preamble,
138 gen_ai.response.id = tracing::field::Empty,
139 gen_ai.response.model = tracing::field::Empty,
140 gen_ai.usage.output_tokens = tracing::field::Empty,
141 gen_ai.usage.input_tokens = tracing::field::Empty,
142 gen_ai.input.messages = &request_messages_json_str,
143 gen_ai.output.messages = tracing::field::Empty,
144 )
145 } else {
146 tracing::Span::current()
147 };
148
149 tracing::debug!("xAI completion request: {request_messages_json_str}");
150
151 let body = serde_json::to_vec(&request)?;
152 let req = self
153 .client
154 .post("/v1/chat/completions")?
155 .body(body)
156 .map_err(|e| CompletionError::HttpError(e.into()))?;
157
158 async move {
159 let response = self.client.http_client().send::<_, Bytes>(req).await?;
160 let status = response.status();
161 let response_body = response.into_body().into_future().await?.to_vec();
162
163 if status.is_success() {
164 match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
165 ApiResponse::Ok(completion) => completion.try_into(),
166 ApiResponse::Error(error) => {
167 Err(CompletionError::ProviderError(error.message()))
168 }
169 }
170 } else {
171 Err(CompletionError::ProviderError(
172 String::from_utf8_lossy(&response_body).to_string(),
173 ))
174 }
175 }
176 .instrument(span)
177 .await
178 }
179
180 #[cfg_attr(feature = "worker", worker::send)]
181 async fn stream(
182 &self,
183 request: CompletionRequest,
184 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
185 CompletionModel::stream(self, request).await
186 }
187}
188
189pub mod xai_api_types {
190 use serde::{Deserialize, Serialize};
191
192 use crate::OneOrMany;
193 use crate::completion::{self, CompletionError};
194 use crate::providers::openai::{AssistantContent, Message};
195
196 impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
197 type Error = CompletionError;
198
199 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
200 let choice = response.choices.first().ok_or_else(|| {
201 CompletionError::ResponseError("Response contained no choices".to_owned())
202 })?;
203 let content = match &choice.message {
204 Message::Assistant {
205 content,
206 tool_calls,
207 ..
208 } => {
209 let mut content = content
210 .iter()
211 .map(|c| match c {
212 AssistantContent::Text { text } => {
213 completion::AssistantContent::text(text)
214 }
215 AssistantContent::Refusal { refusal } => {
216 completion::AssistantContent::text(refusal)
217 }
218 })
219 .collect::<Vec<_>>();
220
221 content.extend(
222 tool_calls
223 .iter()
224 .map(|call| {
225 completion::AssistantContent::tool_call(
226 &call.id,
227 &call.function.name,
228 call.function.arguments.clone(),
229 )
230 })
231 .collect::<Vec<_>>(),
232 );
233 Ok(content)
234 }
235 _ => Err(CompletionError::ResponseError(
236 "Response did not contain a valid message or tool call".into(),
237 )),
238 }?;
239
240 let choice = OneOrMany::many(content).map_err(|_| {
241 CompletionError::ResponseError(
242 "Response contained no message or tool call (empty)".to_owned(),
243 )
244 })?;
245
246 let usage = completion::Usage {
247 input_tokens: response.usage.prompt_tokens as u64,
248 output_tokens: response.usage.completion_tokens as u64,
249 total_tokens: response.usage.total_tokens as u64,
250 };
251
252 Ok(completion::CompletionResponse {
253 choice,
254 usage,
255 raw_response: response,
256 })
257 }
258 }
259
260 impl From<completion::ToolDefinition> for ToolDefinition {
261 fn from(tool: completion::ToolDefinition) -> Self {
262 Self {
263 r#type: "function".into(),
264 function: tool,
265 }
266 }
267 }
268
269 #[derive(Clone, Debug, Deserialize, Serialize)]
270 pub struct ToolDefinition {
271 pub r#type: String,
272 pub function: completion::ToolDefinition,
273 }
274
275 #[derive(Debug, Deserialize)]
276 pub struct Function {
277 pub name: String,
278 pub arguments: String,
279 }
280
281 #[derive(Debug, Deserialize, Serialize)]
282 pub struct CompletionResponse {
283 pub id: String,
284 pub model: String,
285 pub choices: Vec<Choice>,
286 pub created: i64,
287 pub object: String,
288 pub system_fingerprint: String,
289 pub usage: Usage,
290 }
291
292 #[derive(Debug, Deserialize, Serialize)]
293 pub struct Choice {
294 pub finish_reason: String,
295 pub index: i32,
296 pub message: Message,
297 }
298
299 #[derive(Debug, Deserialize, Serialize)]
300 pub struct Usage {
301 pub completion_tokens: i32,
302 pub prompt_tokens: i32,
303 pub total_tokens: i32,
304 }
305}