1use crate::{
7 completion::{self, CompletionError},
8 http_client, json_utils,
9 providers::openai::Message,
10};
11
12use super::client::{Client, xai_api_types::ApiResponse};
13use crate::completion::CompletionRequest;
14use crate::providers::openai;
15use crate::streaming::StreamingCompletionResponse;
16use serde_json::{Value, json};
17use tracing::{Instrument, info_span};
18use xai_api_types::{CompletionResponse, ToolDefinition};
19
// Well-known xAI model identifiers; used as the `model` field of a
// completion request (see `CompletionModel::create_completion_request`).
pub const GROK_2_1212: &str = "grok-2-1212";
pub const GROK_2_VISION_1212: &str = "grok-2-vision-1212";
pub const GROK_3: &str = "grok-3";
pub const GROK_3_FAST: &str = "grok-3-fast";
pub const GROK_3_MINI: &str = "grok-3-mini";
pub const GROK_3_MINI_FAST: &str = "grok-3-mini-fast";
pub const GROK_2_IMAGE_1212: &str = "grok-2-image-1212";
// Note: `GROK_4` pins a dated snapshot (`-0709`) rather than a bare alias.
pub const GROK_4: &str = "grok-4-0709";
29
/// An xAI completion model handle, generic over the underlying HTTP client
/// `T` (defaults to [`reqwest::Client`]).
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    // Shared provider client carrying transport and credentials.
    pub(crate) client: Client<T>,
    // xAI model identifier sent as the `model` field (e.g. `grok-3`).
    pub model: String,
}
39
40impl<T> CompletionModel<T> {
41 pub(crate) fn create_completion_request(
42 &self,
43 completion_request: completion::CompletionRequest,
44 ) -> Result<Value, CompletionError> {
45 let docs: Option<Vec<Message>> = completion_request
47 .normalized_documents()
48 .map(|docs| docs.try_into())
49 .transpose()?;
50
51 let chat_history: Vec<Message> = completion_request
53 .chat_history
54 .into_iter()
55 .map(|message| message.try_into())
56 .collect::<Result<Vec<Vec<Message>>, _>>()?
57 .into_iter()
58 .flatten()
59 .collect();
60
61 let mut full_history: Vec<Message> = match &completion_request.preamble {
63 Some(preamble) => vec![Message::system(preamble)],
64 None => vec![],
65 };
66
67 if let Some(docs) = docs {
69 full_history.extend(docs)
70 }
71
72 full_history.extend(chat_history);
74
75 let tool_choice = completion_request
76 .tool_choice
77 .map(crate::providers::openrouter::ToolChoice::try_from)
78 .transpose()?;
79
80 let mut request = if completion_request.tools.is_empty() {
81 json!({
82 "model": self.model,
83 "messages": full_history,
84 "temperature": completion_request.temperature,
85 })
86 } else {
87 json!({
88 "model": self.model,
89 "messages": full_history,
90 "temperature": completion_request.temperature,
91 "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
92 "tool_choice": tool_choice,
93 })
94 };
95
96 request = if let Some(params) = completion_request.additional_params {
97 json_utils::merge(request, params)
98 } else {
99 request
100 };
101
102 Ok(request)
103 }
104
105 pub fn new(client: Client<T>, model: &str) -> Self {
106 Self {
107 client,
108 model: model.to_string(),
109 }
110 }
111}
112
impl completion::CompletionModel for CompletionModel<reqwest::Client> {
    type Response = CompletionResponse;
    type StreamingResponse = openai::StreamingCompletionResponse;

    /// Performs a non-streaming chat completion against xAI's
    /// `/v1/chat/completions` endpoint, instrumented with GenAI tracing
    /// fields.
    ///
    /// # Errors
    /// Returns `CompletionError::HttpError` on transport or body-decoding
    /// failures and `CompletionError::ProviderError` when the API reports an
    /// error payload or a non-success HTTP status.
    #[cfg_attr(feature = "worker", worker::send)]
    async fn completion(
        &self,
        completion_request: completion::CompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        // Keep the preamble for the tracing span before the request struct is
        // consumed by `create_completion_request` below.
        let preamble = completion_request.preamble.clone();
        let request = self.create_completion_request(completion_request)?;
        // `create_completion_request` always inserts a "messages" key, so
        // these unwraps cannot fail for a request it produced.
        let request_messages_json_str =
            serde_json::to_string(&request.get("messages").unwrap()).unwrap();

        // Reuse the caller's active span if there is one; otherwise open a
        // fresh span following GenAI semantic conventions. Response fields
        // start Empty and are presumably recorded once the reply arrives —
        // NOTE(review): the recording happens outside this block; confirm.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "xai",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.input.messages = &request_messages_json_str,
                gen_ai.output.messages = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        tracing::debug!("xAI completion request: {request_messages_json_str}");

        async move {
            let response = self
                .client
                .reqwest_post("/v1/chat/completions")
                .json(&request)
                .send()
                .await
                .map_err(|e| CompletionError::HttpError(http_client::Error::Instance(e.into())))?;

            if response.status().is_success() {
                // A success status can still carry an API-level error payload,
                // hence decoding through the `ApiResponse` envelope.
                match response
                    .json::<ApiResponse<CompletionResponse>>()
                    .await
                    .map_err(|e| {
                        CompletionError::HttpError(http_client::Error::Instance(e.into()))
                    })? {
                    ApiResponse::Ok(completion) => completion.try_into(),
                    ApiResponse::Error(error) => {
                        Err(CompletionError::ProviderError(error.message()))
                    }
                }
            } else {
                // Non-success status: surface the raw body as a provider error.
                Err(CompletionError::ProviderError(
                    response.text().await.map_err(|e| {
                        CompletionError::HttpError(http_client::Error::Instance(e.into()))
                    })?,
                ))
            }
        }
        .instrument(span)
        .await
    }

    /// Streams a completion. Delegates to `CompletionModel::stream` —
    /// NOTE(review): this assumes an inherent `stream` method exists on
    /// `CompletionModel` elsewhere in this provider; if it resolved to this
    /// trait method instead, the call would recurse. Confirm.
    #[cfg_attr(feature = "worker", worker::send)]
    async fn stream(
        &self,
        request: CompletionRequest,
    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
        CompletionModel::stream(self, request).await
    }
}
189
190pub mod xai_api_types {
191 use serde::{Deserialize, Serialize};
192
193 use crate::OneOrMany;
194 use crate::completion::{self, CompletionError};
195 use crate::providers::openai::{AssistantContent, Message};
196
197 impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
198 type Error = CompletionError;
199
200 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
201 let choice = response.choices.first().ok_or_else(|| {
202 CompletionError::ResponseError("Response contained no choices".to_owned())
203 })?;
204 let content = match &choice.message {
205 Message::Assistant {
206 content,
207 tool_calls,
208 ..
209 } => {
210 let mut content = content
211 .iter()
212 .map(|c| match c {
213 AssistantContent::Text { text } => {
214 completion::AssistantContent::text(text)
215 }
216 AssistantContent::Refusal { refusal } => {
217 completion::AssistantContent::text(refusal)
218 }
219 })
220 .collect::<Vec<_>>();
221
222 content.extend(
223 tool_calls
224 .iter()
225 .map(|call| {
226 completion::AssistantContent::tool_call(
227 &call.id,
228 &call.function.name,
229 call.function.arguments.clone(),
230 )
231 })
232 .collect::<Vec<_>>(),
233 );
234 Ok(content)
235 }
236 _ => Err(CompletionError::ResponseError(
237 "Response did not contain a valid message or tool call".into(),
238 )),
239 }?;
240
241 let choice = OneOrMany::many(content).map_err(|_| {
242 CompletionError::ResponseError(
243 "Response contained no message or tool call (empty)".to_owned(),
244 )
245 })?;
246
247 let usage = completion::Usage {
248 input_tokens: response.usage.prompt_tokens as u64,
249 output_tokens: response.usage.completion_tokens as u64,
250 total_tokens: response.usage.total_tokens as u64,
251 };
252
253 Ok(completion::CompletionResponse {
254 choice,
255 usage,
256 raw_response: response,
257 })
258 }
259 }
260
261 impl From<completion::ToolDefinition> for ToolDefinition {
262 fn from(tool: completion::ToolDefinition) -> Self {
263 Self {
264 r#type: "function".into(),
265 function: tool,
266 }
267 }
268 }
269
270 #[derive(Clone, Debug, Deserialize, Serialize)]
271 pub struct ToolDefinition {
272 pub r#type: String,
273 pub function: completion::ToolDefinition,
274 }
275
276 #[derive(Debug, Deserialize)]
277 pub struct Function {
278 pub name: String,
279 pub arguments: String,
280 }
281
282 #[derive(Debug, Deserialize, Serialize)]
283 pub struct CompletionResponse {
284 pub id: String,
285 pub model: String,
286 pub choices: Vec<Choice>,
287 pub created: i64,
288 pub object: String,
289 pub system_fingerprint: String,
290 pub usage: Usage,
291 }
292
293 #[derive(Debug, Deserialize, Serialize)]
294 pub struct Choice {
295 pub finish_reason: String,
296 pub index: i32,
297 pub message: Message,
298 }
299
300 #[derive(Debug, Deserialize, Serialize)]
301 pub struct Usage {
302 pub completion_tokens: i32,
303 pub prompt_tokens: i32,
304 pub total_tokens: i32,
305 }
306}