rig/providers/xai/
completion.rs

// ================================================================
//! xAI Completion Integration
//! From [xAI Reference](https://docs.x.ai/docs/api-reference#chat-completions)
// ================================================================
5
6use crate::{
7    completion::{self, CompletionError},
8    http_client, json_utils,
9    providers::openai::Message,
10};
11
12use super::client::{Client, xai_api_types::ApiResponse};
13use crate::completion::CompletionRequest;
14use crate::providers::openai;
15use crate::streaming::StreamingCompletionResponse;
16use serde_json::{Value, json};
17use tracing::{Instrument, info_span};
18use xai_api_types::{CompletionResponse, ToolDefinition};
19
// xAI completion models as of 2025-06-04.
// Each constant is the model name string sent in the `model` field of a request.

/// Model name `grok-2-1212`.
pub const GROK_2_1212: &str = "grok-2-1212";
/// Model name `grok-2-vision-1212`.
pub const GROK_2_VISION_1212: &str = "grok-2-vision-1212";
/// Model name `grok-3`.
pub const GROK_3: &str = "grok-3";
/// Model name `grok-3-fast`.
pub const GROK_3_FAST: &str = "grok-3-fast";
/// Model name `grok-3-mini`.
pub const GROK_3_MINI: &str = "grok-3-mini";
/// Model name `grok-3-mini-fast`.
pub const GROK_3_MINI_FAST: &str = "grok-3-mini-fast";
/// Model name `grok-2-image-1212`.
pub const GROK_2_IMAGE_1212: &str = "grok-2-image-1212";
/// Model name `grok-4-0709`.
pub const GROK_4: &str = "grok-4-0709";
29
// =================================================================
// Rig Implementation Types
// =================================================================
33
34#[derive(Clone)]
35pub struct CompletionModel<T = reqwest::Client> {
36    pub(crate) client: Client<T>,
37    pub model: String,
38}
39
40impl<T> CompletionModel<T> {
41    pub(crate) fn create_completion_request(
42        &self,
43        completion_request: completion::CompletionRequest,
44    ) -> Result<Value, CompletionError> {
45        // Convert documents into user message
46        let docs: Option<Vec<Message>> = completion_request
47            .normalized_documents()
48            .map(|docs| docs.try_into())
49            .transpose()?;
50
51        // Convert existing chat history
52        let chat_history: Vec<Message> = completion_request
53            .chat_history
54            .into_iter()
55            .map(|message| message.try_into())
56            .collect::<Result<Vec<Vec<Message>>, _>>()?
57            .into_iter()
58            .flatten()
59            .collect();
60
61        // Init full history with preamble (or empty if non-existent)
62        let mut full_history: Vec<Message> = match &completion_request.preamble {
63            Some(preamble) => vec![Message::system(preamble)],
64            None => vec![],
65        };
66
67        // Docs appear right after preamble, if they exist
68        if let Some(docs) = docs {
69            full_history.extend(docs)
70        }
71
72        // Chat history and prompt appear in the order they were provided
73        full_history.extend(chat_history);
74
75        let tool_choice = completion_request
76            .tool_choice
77            .map(crate::providers::openrouter::ToolChoice::try_from)
78            .transpose()?;
79
80        let mut request = if completion_request.tools.is_empty() {
81            json!({
82                "model": self.model,
83                "messages": full_history,
84                "temperature": completion_request.temperature,
85            })
86        } else {
87            json!({
88                "model": self.model,
89                "messages": full_history,
90                "temperature": completion_request.temperature,
91                "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
92                "tool_choice": tool_choice,
93            })
94        };
95
96        request = if let Some(params) = completion_request.additional_params {
97            json_utils::merge(request, params)
98        } else {
99            request
100        };
101
102        Ok(request)
103    }
104
105    pub fn new(client: Client<T>, model: &str) -> Self {
106        Self {
107            client,
108            model: model.to_string(),
109        }
110    }
111}
112
113impl completion::CompletionModel for CompletionModel<reqwest::Client> {
114    type Response = CompletionResponse;
115    type StreamingResponse = openai::StreamingCompletionResponse;
116
117    #[cfg_attr(feature = "worker", worker::send)]
118    async fn completion(
119        &self,
120        completion_request: completion::CompletionRequest,
121    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
122        let preamble = completion_request.preamble.clone();
123        let request = self.create_completion_request(completion_request)?;
124        let request_messages_json_str =
125            serde_json::to_string(&request.get("messages").unwrap()).unwrap();
126
127        let span = if tracing::Span::current().is_disabled() {
128            info_span!(
129                target: "rig::completions",
130                "chat",
131                gen_ai.operation.name = "chat",
132                gen_ai.provider.name = "xai",
133                gen_ai.request.model = self.model,
134                gen_ai.system_instructions = preamble,
135                gen_ai.response.id = tracing::field::Empty,
136                gen_ai.response.model = tracing::field::Empty,
137                gen_ai.usage.output_tokens = tracing::field::Empty,
138                gen_ai.usage.input_tokens = tracing::field::Empty,
139                gen_ai.input.messages = &request_messages_json_str,
140                gen_ai.output.messages = tracing::field::Empty,
141            )
142        } else {
143            tracing::Span::current()
144        };
145
146        tracing::debug!("xAI completion request: {request_messages_json_str}");
147
148        async move {
149            let response = self
150                .client
151                .reqwest_post("/v1/chat/completions")
152                .json(&request)
153                .send()
154                .await
155                .map_err(|e| CompletionError::HttpError(http_client::Error::Instance(e.into())))?;
156
157            if response.status().is_success() {
158                match response
159                    .json::<ApiResponse<CompletionResponse>>()
160                    .await
161                    .map_err(|e| {
162                        CompletionError::HttpError(http_client::Error::Instance(e.into()))
163                    })? {
164                    ApiResponse::Ok(completion) => completion.try_into(),
165                    ApiResponse::Error(error) => {
166                        Err(CompletionError::ProviderError(error.message()))
167                    }
168                }
169            } else {
170                Err(CompletionError::ProviderError(
171                    response.text().await.map_err(|e| {
172                        CompletionError::HttpError(http_client::Error::Instance(e.into()))
173                    })?,
174                ))
175            }
176        }
177        .instrument(span)
178        .await
179    }
180
181    #[cfg_attr(feature = "worker", worker::send)]
182    async fn stream(
183        &self,
184        request: CompletionRequest,
185    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
186        CompletionModel::stream(self, request).await
187    }
188}
189
190pub mod xai_api_types {
191    use serde::{Deserialize, Serialize};
192
193    use crate::OneOrMany;
194    use crate::completion::{self, CompletionError};
195    use crate::providers::openai::{AssistantContent, Message};
196
197    impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
198        type Error = CompletionError;
199
200        fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
201            let choice = response.choices.first().ok_or_else(|| {
202                CompletionError::ResponseError("Response contained no choices".to_owned())
203            })?;
204            let content = match &choice.message {
205                Message::Assistant {
206                    content,
207                    tool_calls,
208                    ..
209                } => {
210                    let mut content = content
211                        .iter()
212                        .map(|c| match c {
213                            AssistantContent::Text { text } => {
214                                completion::AssistantContent::text(text)
215                            }
216                            AssistantContent::Refusal { refusal } => {
217                                completion::AssistantContent::text(refusal)
218                            }
219                        })
220                        .collect::<Vec<_>>();
221
222                    content.extend(
223                        tool_calls
224                            .iter()
225                            .map(|call| {
226                                completion::AssistantContent::tool_call(
227                                    &call.id,
228                                    &call.function.name,
229                                    call.function.arguments.clone(),
230                                )
231                            })
232                            .collect::<Vec<_>>(),
233                    );
234                    Ok(content)
235                }
236                _ => Err(CompletionError::ResponseError(
237                    "Response did not contain a valid message or tool call".into(),
238                )),
239            }?;
240
241            let choice = OneOrMany::many(content).map_err(|_| {
242                CompletionError::ResponseError(
243                    "Response contained no message or tool call (empty)".to_owned(),
244                )
245            })?;
246
247            let usage = completion::Usage {
248                input_tokens: response.usage.prompt_tokens as u64,
249                output_tokens: response.usage.completion_tokens as u64,
250                total_tokens: response.usage.total_tokens as u64,
251            };
252
253            Ok(completion::CompletionResponse {
254                choice,
255                usage,
256                raw_response: response,
257            })
258        }
259    }
260
261    impl From<completion::ToolDefinition> for ToolDefinition {
262        fn from(tool: completion::ToolDefinition) -> Self {
263            Self {
264                r#type: "function".into(),
265                function: tool,
266            }
267        }
268    }
269
270    #[derive(Clone, Debug, Deserialize, Serialize)]
271    pub struct ToolDefinition {
272        pub r#type: String,
273        pub function: completion::ToolDefinition,
274    }
275
276    #[derive(Debug, Deserialize)]
277    pub struct Function {
278        pub name: String,
279        pub arguments: String,
280    }
281
282    #[derive(Debug, Deserialize, Serialize)]
283    pub struct CompletionResponse {
284        pub id: String,
285        pub model: String,
286        pub choices: Vec<Choice>,
287        pub created: i64,
288        pub object: String,
289        pub system_fingerprint: String,
290        pub usage: Usage,
291    }
292
293    #[derive(Debug, Deserialize, Serialize)]
294    pub struct Choice {
295        pub finish_reason: String,
296        pub index: i32,
297        pub message: Message,
298    }
299
300    #[derive(Debug, Deserialize, Serialize)]
301    pub struct Usage {
302        pub completion_tokens: i32,
303        pub prompt_tokens: i32,
304        pub total_tokens: i32,
305    }
306}