rig/providers/xai/
completion.rs

1// ================================================================
2//! xAI Completion Integration
3//! From [xAI Reference](https://docs.x.ai/docs/api-reference#chat-completions)
4// ================================================================
5
6use crate::{
7    completion::{self, CompletionError},
8    json_utils,
9    providers::openai::Message,
10};
11
12use super::client::{Client, xai_api_types::ApiResponse};
13use crate::completion::CompletionRequest;
14use crate::providers::openai;
15use crate::streaming::StreamingCompletionResponse;
16use serde_json::{Value, json};
17use xai_api_types::{CompletionResponse, ToolDefinition};
18
// xAI completion model identifiers (list compiled as of 2025-06-04; may not be exhaustive).

/// Model identifier `grok-2-1212`.
pub const GROK_2_1212: &str = "grok-2-1212";
/// Model identifier `grok-2-vision-1212`.
pub const GROK_2_VISION_1212: &str = "grok-2-vision-1212";
/// Model identifier `grok-3`.
pub const GROK_3: &str = "grok-3";
/// Model identifier `grok-3-fast`.
pub const GROK_3_FAST: &str = "grok-3-fast";
/// Model identifier `grok-3-mini`.
pub const GROK_3_MINI: &str = "grok-3-mini";
/// Model identifier `grok-3-mini-fast`.
pub const GROK_3_MINI_FAST: &str = "grok-3-mini-fast";
/// Model identifier `grok-2-image-1212`.
///
/// NOTE(review): the name suggests an image model; confirm it is accepted by the
/// chat-completions endpoint before using it for text completion.
pub const GROK_2_IMAGE_1212: &str = "grok-2-image-1212";
/// Grok 4, pinned to the dated snapshot identifier `grok-4-0709`.
pub const GROK_4: &str = "grok-4-0709";
28
29// =================================================================
30// Rig Implementation Types
31// =================================================================
32
33#[derive(Clone)]
34pub struct CompletionModel {
35    pub(crate) client: Client,
36    pub model: String,
37}
38
39impl CompletionModel {
40    pub(crate) fn create_completion_request(
41        &self,
42        completion_request: completion::CompletionRequest,
43    ) -> Result<Value, CompletionError> {
44        // Convert documents into user message
45        let docs: Option<Vec<Message>> = completion_request
46            .normalized_documents()
47            .map(|docs| docs.try_into())
48            .transpose()?;
49
50        // Convert existing chat history
51        let chat_history: Vec<Message> = completion_request
52            .chat_history
53            .into_iter()
54            .map(|message| message.try_into())
55            .collect::<Result<Vec<Vec<Message>>, _>>()?
56            .into_iter()
57            .flatten()
58            .collect();
59
60        // Init full history with preamble (or empty if non-existent)
61        let mut full_history: Vec<Message> = match &completion_request.preamble {
62            Some(preamble) => vec![Message::system(preamble)],
63            None => vec![],
64        };
65
66        // Docs appear right after preamble, if they exist
67        if let Some(docs) = docs {
68            full_history.extend(docs)
69        }
70
71        // Chat history and prompt appear in the order they were provided
72        full_history.extend(chat_history);
73
74        let mut request = if completion_request.tools.is_empty() {
75            json!({
76                "model": self.model,
77                "messages": full_history,
78                "temperature": completion_request.temperature,
79            })
80        } else {
81            json!({
82                "model": self.model,
83                "messages": full_history,
84                "temperature": completion_request.temperature,
85                "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
86                "tool_choice": "auto",
87            })
88        };
89
90        request = if let Some(params) = completion_request.additional_params {
91            json_utils::merge(request, params)
92        } else {
93            request
94        };
95
96        Ok(request)
97    }
98    pub fn new(client: Client, model: &str) -> Self {
99        Self {
100            client,
101            model: model.to_string(),
102        }
103    }
104}
105
106impl completion::CompletionModel for CompletionModel {
107    type Response = CompletionResponse;
108    type StreamingResponse = openai::StreamingCompletionResponse;
109
110    #[cfg_attr(feature = "worker", worker::send)]
111    async fn completion(
112        &self,
113        completion_request: completion::CompletionRequest,
114    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
115        let request = self.create_completion_request(completion_request)?;
116
117        let response = self
118            .client
119            .post("/v1/chat/completions")
120            .json(&request)
121            .send()
122            .await?;
123
124        if response.status().is_success() {
125            match response.json::<ApiResponse<CompletionResponse>>().await? {
126                ApiResponse::Ok(completion) => completion.try_into(),
127                ApiResponse::Error(error) => Err(CompletionError::ProviderError(error.message())),
128            }
129        } else {
130            Err(CompletionError::ProviderError(response.text().await?))
131        }
132    }
133
134    #[cfg_attr(feature = "worker", worker::send)]
135    async fn stream(
136        &self,
137        request: CompletionRequest,
138    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
139        CompletionModel::stream(self, request).await
140    }
141}
142
143pub mod xai_api_types {
144    use serde::{Deserialize, Serialize};
145
146    use crate::OneOrMany;
147    use crate::completion::{self, CompletionError};
148    use crate::providers::openai::{AssistantContent, Message};
149
150    impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
151        type Error = CompletionError;
152
153        fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
154            let choice = response.choices.first().ok_or_else(|| {
155                CompletionError::ResponseError("Response contained no choices".to_owned())
156            })?;
157            let content = match &choice.message {
158                Message::Assistant {
159                    content,
160                    tool_calls,
161                    ..
162                } => {
163                    let mut content = content
164                        .iter()
165                        .map(|c| match c {
166                            AssistantContent::Text { text } => {
167                                completion::AssistantContent::text(text)
168                            }
169                            AssistantContent::Refusal { refusal } => {
170                                completion::AssistantContent::text(refusal)
171                            }
172                        })
173                        .collect::<Vec<_>>();
174
175                    content.extend(
176                        tool_calls
177                            .iter()
178                            .map(|call| {
179                                completion::AssistantContent::tool_call(
180                                    &call.id,
181                                    &call.function.name,
182                                    call.function.arguments.clone(),
183                                )
184                            })
185                            .collect::<Vec<_>>(),
186                    );
187                    Ok(content)
188                }
189                _ => Err(CompletionError::ResponseError(
190                    "Response did not contain a valid message or tool call".into(),
191                )),
192            }?;
193
194            let choice = OneOrMany::many(content).map_err(|_| {
195                CompletionError::ResponseError(
196                    "Response contained no message or tool call (empty)".to_owned(),
197                )
198            })?;
199
200            let usage = completion::Usage {
201                input_tokens: response.usage.prompt_tokens as u64,
202                output_tokens: response.usage.completion_tokens as u64,
203                total_tokens: response.usage.total_tokens as u64,
204            };
205
206            Ok(completion::CompletionResponse {
207                choice,
208                usage,
209                raw_response: response,
210            })
211        }
212    }
213
214    impl From<completion::ToolDefinition> for ToolDefinition {
215        fn from(tool: completion::ToolDefinition) -> Self {
216            Self {
217                r#type: "function".into(),
218                function: tool,
219            }
220        }
221    }
222
223    #[derive(Clone, Debug, Deserialize, Serialize)]
224    pub struct ToolDefinition {
225        pub r#type: String,
226        pub function: completion::ToolDefinition,
227    }
228
229    #[derive(Debug, Deserialize)]
230    pub struct Function {
231        pub name: String,
232        pub arguments: String,
233    }
234
235    #[derive(Debug, Deserialize, Serialize)]
236    pub struct CompletionResponse {
237        pub id: String,
238        pub model: String,
239        pub choices: Vec<Choice>,
240        pub created: i64,
241        pub object: String,
242        pub system_fingerprint: String,
243        pub usage: Usage,
244    }
245
246    #[derive(Debug, Deserialize, Serialize)]
247    pub struct Choice {
248        pub finish_reason: String,
249        pub index: i32,
250        pub message: Message,
251    }
252
253    #[derive(Debug, Deserialize, Serialize)]
254    pub struct Usage {
255        pub completion_tokens: i32,
256        pub prompt_tokens: i32,
257        pub total_tokens: i32,
258    }
259}