rig/providers/xai/
completion.rs

1// ================================================================
2//! xAI Completion Integration
3//! From [xAI Reference](https://docs.x.ai/docs/api-reference#chat-completions)
4// ================================================================
5
6use crate::{
7    completion::{self, CompletionError},
8    http_client::HttpClientExt,
9    json_utils,
10    providers::openai::Message,
11};
12
13use super::client::{Client, xai_api_types::ApiResponse};
14use crate::completion::CompletionRequest;
15use crate::providers::openai;
16use crate::streaming::StreamingCompletionResponse;
17use bytes::Bytes;
18use serde_json::{Value, json};
19use tracing::{Instrument, info_span};
20use xai_api_types::{CompletionResponse, ToolDefinition};
21
// xAI model identifiers, listed as of 2025-06-04 and grouped by model family.
// NOTE(review): `grok-4-0709` looks newer than that date — confirm the list date.

// --- Grok 2 family ---

/// Model identifier `grok-2-1212`.
pub const GROK_2_1212: &str = "grok-2-1212";
/// Model identifier `grok-2-vision-1212`.
pub const GROK_2_VISION_1212: &str = "grok-2-vision-1212";
/// Model identifier `grok-2-image-1212`.
pub const GROK_2_IMAGE_1212: &str = "grok-2-image-1212";

// --- Grok 3 family ---

/// Model identifier `grok-3`.
pub const GROK_3: &str = "grok-3";
/// Model identifier `grok-3-fast`.
pub const GROK_3_FAST: &str = "grok-3-fast";
/// Model identifier `grok-3-mini`.
pub const GROK_3_MINI: &str = "grok-3-mini";
/// Model identifier `grok-3-mini-fast`.
pub const GROK_3_MINI_FAST: &str = "grok-3-mini-fast";

// --- Grok 4 family ---

/// Model identifier `grok-4-0709`.
pub const GROK_4: &str = "grok-4-0709";
31
32// =================================================================
33// Rig Implementation Types
34// =================================================================
35
36#[derive(Clone)]
37pub struct CompletionModel<T = reqwest::Client> {
38    pub(crate) client: Client<T>,
39    pub model: String,
40}
41
42impl<T> CompletionModel<T> {
43    pub(crate) fn create_completion_request(
44        &self,
45        completion_request: completion::CompletionRequest,
46    ) -> Result<Value, CompletionError> {
47        // Convert documents into user message
48        let docs: Option<Vec<Message>> = completion_request
49            .normalized_documents()
50            .map(|docs| docs.try_into())
51            .transpose()?;
52
53        // Convert existing chat history
54        let chat_history: Vec<Message> = completion_request
55            .chat_history
56            .into_iter()
57            .map(|message| message.try_into())
58            .collect::<Result<Vec<Vec<Message>>, _>>()?
59            .into_iter()
60            .flatten()
61            .collect();
62
63        // Init full history with preamble (or empty if non-existent)
64        let mut full_history: Vec<Message> = match &completion_request.preamble {
65            Some(preamble) => vec![Message::system(preamble)],
66            None => vec![],
67        };
68
69        // Docs appear right after preamble, if they exist
70        if let Some(docs) = docs {
71            full_history.extend(docs)
72        }
73
74        // Chat history and prompt appear in the order they were provided
75        full_history.extend(chat_history);
76
77        let tool_choice = completion_request
78            .tool_choice
79            .map(crate::providers::openrouter::ToolChoice::try_from)
80            .transpose()?;
81
82        let mut request = if completion_request.tools.is_empty() {
83            json!({
84                "model": self.model,
85                "messages": full_history,
86                "temperature": completion_request.temperature,
87            })
88        } else {
89            json!({
90                "model": self.model,
91                "messages": full_history,
92                "temperature": completion_request.temperature,
93                "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
94                "tool_choice": tool_choice,
95            })
96        };
97
98        request = if let Some(params) = completion_request.additional_params {
99            json_utils::merge(request, params)
100        } else {
101            request
102        };
103
104        Ok(request)
105    }
106
107    pub fn new(client: Client<T>, model: &str) -> Self {
108        Self {
109            client,
110            model: model.to_string(),
111        }
112    }
113}
114
115impl<T> completion::CompletionModel for CompletionModel<T>
116where
117    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
118{
119    type Response = CompletionResponse;
120    type StreamingResponse = openai::StreamingCompletionResponse;
121
122    #[cfg_attr(feature = "worker", worker::send)]
123    async fn completion(
124        &self,
125        completion_request: completion::CompletionRequest,
126    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
127        let preamble = completion_request.preamble.clone();
128        let request = self.create_completion_request(completion_request)?;
129        let request_messages_json_str =
130            serde_json::to_string(&request.get("messages").unwrap()).unwrap();
131
132        let span = if tracing::Span::current().is_disabled() {
133            info_span!(
134                target: "rig::completions",
135                "chat",
136                gen_ai.operation.name = "chat",
137                gen_ai.provider.name = "xai",
138                gen_ai.request.model = self.model,
139                gen_ai.system_instructions = preamble,
140                gen_ai.response.id = tracing::field::Empty,
141                gen_ai.response.model = tracing::field::Empty,
142                gen_ai.usage.output_tokens = tracing::field::Empty,
143                gen_ai.usage.input_tokens = tracing::field::Empty,
144                gen_ai.input.messages = &request_messages_json_str,
145                gen_ai.output.messages = tracing::field::Empty,
146            )
147        } else {
148            tracing::Span::current()
149        };
150
151        tracing::debug!("xAI completion request: {request_messages_json_str}");
152
153        let body = serde_json::to_vec(&request)?;
154        let req = self
155            .client
156            .post("/v1/chat/completions")?
157            .header("Content-Type", "application/json")
158            .body(body)
159            .map_err(|e| CompletionError::HttpError(e.into()))?;
160
161        async move {
162            let response = self.client.http_client.send::<_, Bytes>(req).await?;
163            let status = response.status();
164            let response_body = response.into_body().into_future().await?.to_vec();
165
166            if status.is_success() {
167                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
168                    ApiResponse::Ok(completion) => completion.try_into(),
169                    ApiResponse::Error(error) => {
170                        Err(CompletionError::ProviderError(error.message()))
171                    }
172                }
173            } else {
174                Err(CompletionError::ProviderError(
175                    String::from_utf8_lossy(&response_body).to_string(),
176                ))
177            }
178        }
179        .instrument(span)
180        .await
181    }
182
183    #[cfg_attr(feature = "worker", worker::send)]
184    async fn stream(
185        &self,
186        request: CompletionRequest,
187    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
188        CompletionModel::stream(self, request).await
189    }
190}
191
192pub mod xai_api_types {
193    use serde::{Deserialize, Serialize};
194
195    use crate::OneOrMany;
196    use crate::completion::{self, CompletionError};
197    use crate::providers::openai::{AssistantContent, Message};
198
199    impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
200        type Error = CompletionError;
201
202        fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
203            let choice = response.choices.first().ok_or_else(|| {
204                CompletionError::ResponseError("Response contained no choices".to_owned())
205            })?;
206            let content = match &choice.message {
207                Message::Assistant {
208                    content,
209                    tool_calls,
210                    ..
211                } => {
212                    let mut content = content
213                        .iter()
214                        .map(|c| match c {
215                            AssistantContent::Text { text } => {
216                                completion::AssistantContent::text(text)
217                            }
218                            AssistantContent::Refusal { refusal } => {
219                                completion::AssistantContent::text(refusal)
220                            }
221                        })
222                        .collect::<Vec<_>>();
223
224                    content.extend(
225                        tool_calls
226                            .iter()
227                            .map(|call| {
228                                completion::AssistantContent::tool_call(
229                                    &call.id,
230                                    &call.function.name,
231                                    call.function.arguments.clone(),
232                                )
233                            })
234                            .collect::<Vec<_>>(),
235                    );
236                    Ok(content)
237                }
238                _ => Err(CompletionError::ResponseError(
239                    "Response did not contain a valid message or tool call".into(),
240                )),
241            }?;
242
243            let choice = OneOrMany::many(content).map_err(|_| {
244                CompletionError::ResponseError(
245                    "Response contained no message or tool call (empty)".to_owned(),
246                )
247            })?;
248
249            let usage = completion::Usage {
250                input_tokens: response.usage.prompt_tokens as u64,
251                output_tokens: response.usage.completion_tokens as u64,
252                total_tokens: response.usage.total_tokens as u64,
253            };
254
255            Ok(completion::CompletionResponse {
256                choice,
257                usage,
258                raw_response: response,
259            })
260        }
261    }
262
263    impl From<completion::ToolDefinition> for ToolDefinition {
264        fn from(tool: completion::ToolDefinition) -> Self {
265            Self {
266                r#type: "function".into(),
267                function: tool,
268            }
269        }
270    }
271
272    #[derive(Clone, Debug, Deserialize, Serialize)]
273    pub struct ToolDefinition {
274        pub r#type: String,
275        pub function: completion::ToolDefinition,
276    }
277
278    #[derive(Debug, Deserialize)]
279    pub struct Function {
280        pub name: String,
281        pub arguments: String,
282    }
283
284    #[derive(Debug, Deserialize, Serialize)]
285    pub struct CompletionResponse {
286        pub id: String,
287        pub model: String,
288        pub choices: Vec<Choice>,
289        pub created: i64,
290        pub object: String,
291        pub system_fingerprint: String,
292        pub usage: Usage,
293    }
294
295    #[derive(Debug, Deserialize, Serialize)]
296    pub struct Choice {
297        pub finish_reason: String,
298        pub index: i32,
299        pub message: Message,
300    }
301
302    #[derive(Debug, Deserialize, Serialize)]
303    pub struct Usage {
304        pub completion_tokens: i32,
305        pub prompt_tokens: i32,
306        pub total_tokens: i32,
307    }
308}