1use crate::{
7 completion::{self, CompletionError},
8 http_client::HttpClientExt,
9 json_utils,
10 providers::openai::Message,
11};
12
13use super::client::{Client, xai_api_types::ApiResponse};
14use crate::completion::CompletionRequest;
15use crate::providers::openai;
16use crate::streaming::StreamingCompletionResponse;
17use bytes::Bytes;
18use serde_json::{Value, json};
19use tracing::{Instrument, info_span};
20use xai_api_types::{CompletionResponse, ToolDefinition};
21
/// Grok 2 (December 2024 snapshot).
pub const GROK_2_1212: &str = "grok-2-1212";
/// Grok 2 with vision support (December 2024 snapshot).
pub const GROK_2_VISION_1212: &str = "grok-2-vision-1212";
/// Grok 3 flagship model.
pub const GROK_3: &str = "grok-3";
/// Grok 3, latency-optimized variant.
pub const GROK_3_FAST: &str = "grok-3-fast";
/// Grok 3 mini (smaller, cheaper variant).
pub const GROK_3_MINI: &str = "grok-3-mini";
/// Grok 3 mini, latency-optimized variant.
pub const GROK_3_MINI_FAST: &str = "grok-3-mini-fast";
/// Grok 2 image-generation model (December 2024 snapshot).
pub const GROK_2_IMAGE_1212: &str = "grok-2-image-1212";
/// Grok 4 (July 2025 snapshot).
pub const GROK_4: &str = "grok-4-0709";
31
/// A completion model served by the xAI API.
///
/// `T` is the underlying HTTP client type; it defaults to [`reqwest::Client`].
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    // Configured xAI client used to send requests.
    pub(crate) client: Client<T>,
    // Name of the xAI model to call (e.g. one of the `GROK_*` constants above).
    pub model: String,
}
41
42impl<T> CompletionModel<T> {
43 pub(crate) fn create_completion_request(
44 &self,
45 completion_request: completion::CompletionRequest,
46 ) -> Result<Value, CompletionError> {
47 let docs: Option<Vec<Message>> = completion_request
49 .normalized_documents()
50 .map(|docs| docs.try_into())
51 .transpose()?;
52
53 let chat_history: Vec<Message> = completion_request
55 .chat_history
56 .into_iter()
57 .map(|message| message.try_into())
58 .collect::<Result<Vec<Vec<Message>>, _>>()?
59 .into_iter()
60 .flatten()
61 .collect();
62
63 let mut full_history: Vec<Message> = match &completion_request.preamble {
65 Some(preamble) => vec![Message::system(preamble)],
66 None => vec![],
67 };
68
69 if let Some(docs) = docs {
71 full_history.extend(docs)
72 }
73
74 full_history.extend(chat_history);
76
77 let tool_choice = completion_request
78 .tool_choice
79 .map(crate::providers::openrouter::ToolChoice::try_from)
80 .transpose()?;
81
82 let mut request = if completion_request.tools.is_empty() {
83 json!({
84 "model": self.model,
85 "messages": full_history,
86 "temperature": completion_request.temperature,
87 })
88 } else {
89 json!({
90 "model": self.model,
91 "messages": full_history,
92 "temperature": completion_request.temperature,
93 "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
94 "tool_choice": tool_choice,
95 })
96 };
97
98 request = if let Some(params) = completion_request.additional_params {
99 json_utils::merge(request, params)
100 } else {
101 request
102 };
103
104 Ok(request)
105 }
106
107 pub fn new(client: Client<T>, model: &str) -> Self {
108 Self {
109 client,
110 model: model.to_string(),
111 }
112 }
113}
114
impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
{
    type Response = CompletionResponse;
    // Streaming reuses the OpenAI-compatible streaming response shape.
    type StreamingResponse = openai::StreamingCompletionResponse;

    #[cfg_attr(feature = "worker", worker::send)]
    async fn completion(
        &self,
        completion_request: completion::CompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        // Clone the preamble for the tracing span before the request is consumed.
        let preamble = completion_request.preamble.clone();
        let request = self.create_completion_request(completion_request)?;
        // `create_completion_request` always inserts "messages", so these
        // unwraps cannot fail on a well-formed request body.
        let request_messages_json_str =
            serde_json::to_string(&request.get("messages").unwrap()).unwrap();

        // Open a new GenAI-semconv span only when no caller span is active;
        // otherwise record into the caller's span.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "xai",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.input.messages = &request_messages_json_str,
                gen_ai.output.messages = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        tracing::debug!("xAI completion request: {request_messages_json_str}");

        let body = serde_json::to_vec(&request)?;
        let req = self
            .client
            .post("/v1/chat/completions")?
            .header("Content-Type", "application/json")
            .body(body)
            .map_err(|e| CompletionError::HttpError(e.into()))?;

        async move {
            let response = self.client.http_client.send::<_, Bytes>(req).await?;
            let status = response.status();
            // Buffer the full body before parsing; error bodies are surfaced verbatim.
            let response_body = response.into_body().into_future().await?.to_vec();

            if status.is_success() {
                // A 2xx body may still carry an API-level error envelope.
                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
                    ApiResponse::Ok(completion) => completion.try_into(),
                    ApiResponse::Error(error) => {
                        Err(CompletionError::ProviderError(error.message()))
                    }
                }
            } else {
                Err(CompletionError::ProviderError(
                    String::from_utf8_lossy(&response_body).to_string(),
                ))
            }
        }
        .instrument(span)
        .await
    }

    #[cfg_attr(feature = "worker", worker::send)]
    async fn stream(
        &self,
        request: CompletionRequest,
    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
        // Delegates to the inherent `stream` method defined elsewhere for this type
        // (not a recursive call to this trait method).
        CompletionModel::stream(self, request).await
    }
}
191
192pub mod xai_api_types {
193 use serde::{Deserialize, Serialize};
194
195 use crate::OneOrMany;
196 use crate::completion::{self, CompletionError};
197 use crate::providers::openai::{AssistantContent, Message};
198
199 impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
200 type Error = CompletionError;
201
202 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
203 let choice = response.choices.first().ok_or_else(|| {
204 CompletionError::ResponseError("Response contained no choices".to_owned())
205 })?;
206 let content = match &choice.message {
207 Message::Assistant {
208 content,
209 tool_calls,
210 ..
211 } => {
212 let mut content = content
213 .iter()
214 .map(|c| match c {
215 AssistantContent::Text { text } => {
216 completion::AssistantContent::text(text)
217 }
218 AssistantContent::Refusal { refusal } => {
219 completion::AssistantContent::text(refusal)
220 }
221 })
222 .collect::<Vec<_>>();
223
224 content.extend(
225 tool_calls
226 .iter()
227 .map(|call| {
228 completion::AssistantContent::tool_call(
229 &call.id,
230 &call.function.name,
231 call.function.arguments.clone(),
232 )
233 })
234 .collect::<Vec<_>>(),
235 );
236 Ok(content)
237 }
238 _ => Err(CompletionError::ResponseError(
239 "Response did not contain a valid message or tool call".into(),
240 )),
241 }?;
242
243 let choice = OneOrMany::many(content).map_err(|_| {
244 CompletionError::ResponseError(
245 "Response contained no message or tool call (empty)".to_owned(),
246 )
247 })?;
248
249 let usage = completion::Usage {
250 input_tokens: response.usage.prompt_tokens as u64,
251 output_tokens: response.usage.completion_tokens as u64,
252 total_tokens: response.usage.total_tokens as u64,
253 };
254
255 Ok(completion::CompletionResponse {
256 choice,
257 usage,
258 raw_response: response,
259 })
260 }
261 }
262
263 impl From<completion::ToolDefinition> for ToolDefinition {
264 fn from(tool: completion::ToolDefinition) -> Self {
265 Self {
266 r#type: "function".into(),
267 function: tool,
268 }
269 }
270 }
271
272 #[derive(Clone, Debug, Deserialize, Serialize)]
273 pub struct ToolDefinition {
274 pub r#type: String,
275 pub function: completion::ToolDefinition,
276 }
277
278 #[derive(Debug, Deserialize)]
279 pub struct Function {
280 pub name: String,
281 pub arguments: String,
282 }
283
284 #[derive(Debug, Deserialize, Serialize)]
285 pub struct CompletionResponse {
286 pub id: String,
287 pub model: String,
288 pub choices: Vec<Choice>,
289 pub created: i64,
290 pub object: String,
291 pub system_fingerprint: String,
292 pub usage: Usage,
293 }
294
295 #[derive(Debug, Deserialize, Serialize)]
296 pub struct Choice {
297 pub finish_reason: String,
298 pub index: i32,
299 pub message: Message,
300 }
301
302 #[derive(Debug, Deserialize, Serialize)]
303 pub struct Usage {
304 pub completion_tokens: i32,
305 pub prompt_tokens: i32,
306 pub total_tokens: i32,
307 }
308}