rig/providers/xai/
completion.rs

1// ================================================================
2//! xAI Completion Integration
3//! From [xAI Reference](https://docs.x.ai/docs/api-reference#chat-completions)
4// ================================================================
5
6use crate::{
7    completion::{self, CompletionError},
8    json_utils,
9    providers::openai::Message,
10};
11
12use super::client::{Client, xai_api_types::ApiResponse};
13use crate::completion::CompletionRequest;
14use crate::providers::openai;
15use crate::streaming::StreamingCompletionResponse;
16use serde_json::{Value, json};
17use xai_api_types::{CompletionResponse, ToolDefinition};
18
// xAI completion models as of 2025-06-04.
// Each constant is a model name string with the same type (`&str`) as the
// `model` argument of `CompletionModel::new`.

/// Model name `grok-2-1212`.
pub const GROK_2_1212: &str = "grok-2-1212";
/// Model name `grok-2-vision-1212`.
pub const GROK_2_VISION_1212: &str = "grok-2-vision-1212";
/// Model name `grok-3`.
pub const GROK_3: &str = "grok-3";
/// Model name `grok-3-fast`.
pub const GROK_3_FAST: &str = "grok-3-fast";
/// Model name `grok-3-mini`.
pub const GROK_3_MINI: &str = "grok-3-mini";
/// Model name `grok-3-mini-fast`.
pub const GROK_3_MINI_FAST: &str = "grok-3-mini-fast";
/// Model name `grok-2-image-1212`.
// NOTE(review): the name suggests an image model rather than a chat-completion
// model; confirm it belongs in this list.
pub const GROK_2_IMAGE_1212: &str = "grok-2-image-1212";
27
28// =================================================================
29// Rig Implementation Types
30// =================================================================
31
/// Completion model handle for the xAI provider.
///
/// Pairs a [`Client`] with the model name that is written into the `"model"`
/// field of every request body built by `create_completion_request`.
#[derive(Clone)]
pub struct CompletionModel {
    /// Client used to issue the `POST /v1/chat/completions` request in `completion`.
    pub(crate) client: Client,
    /// Model name sent as the `"model"` field of the request body.
    pub model: String,
}
37
38impl CompletionModel {
39    pub(crate) fn create_completion_request(
40        &self,
41        completion_request: completion::CompletionRequest,
42    ) -> Result<Value, CompletionError> {
43        // Convert documents into user message
44        let docs: Option<Vec<Message>> = completion_request
45            .normalized_documents()
46            .map(|docs| docs.try_into())
47            .transpose()?;
48
49        // Convert existing chat history
50        let chat_history: Vec<Message> = completion_request
51            .chat_history
52            .into_iter()
53            .map(|message| message.try_into())
54            .collect::<Result<Vec<Vec<Message>>, _>>()?
55            .into_iter()
56            .flatten()
57            .collect();
58
59        // Init full history with preamble (or empty if non-existent)
60        let mut full_history: Vec<Message> = match &completion_request.preamble {
61            Some(preamble) => vec![Message::system(preamble)],
62            None => vec![],
63        };
64
65        // Docs appear right after preamble, if they exist
66        if let Some(docs) = docs {
67            full_history.extend(docs)
68        }
69
70        // Chat history and prompt appear in the order they were provided
71        full_history.extend(chat_history);
72
73        let mut request = if completion_request.tools.is_empty() {
74            json!({
75                "model": self.model,
76                "messages": full_history,
77                "temperature": completion_request.temperature,
78            })
79        } else {
80            json!({
81                "model": self.model,
82                "messages": full_history,
83                "temperature": completion_request.temperature,
84                "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
85                "tool_choice": "auto",
86            })
87        };
88
89        request = if let Some(params) = completion_request.additional_params {
90            json_utils::merge(request, params)
91        } else {
92            request
93        };
94
95        Ok(request)
96    }
97    pub fn new(client: Client, model: &str) -> Self {
98        Self {
99            client,
100            model: model.to_string(),
101        }
102    }
103}
104
105impl completion::CompletionModel for CompletionModel {
106    type Response = CompletionResponse;
107    type StreamingResponse = openai::StreamingCompletionResponse;
108
109    #[cfg_attr(feature = "worker", worker::send)]
110    async fn completion(
111        &self,
112        completion_request: completion::CompletionRequest,
113    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
114        let request = self.create_completion_request(completion_request)?;
115
116        let response = self
117            .client
118            .post("/v1/chat/completions")
119            .json(&request)
120            .send()
121            .await?;
122
123        if response.status().is_success() {
124            match response.json::<ApiResponse<CompletionResponse>>().await? {
125                ApiResponse::Ok(completion) => completion.try_into(),
126                ApiResponse::Error(error) => Err(CompletionError::ProviderError(error.message())),
127            }
128        } else {
129            Err(CompletionError::ProviderError(response.text().await?))
130        }
131    }
132
133    #[cfg_attr(feature = "worker", worker::send)]
134    async fn stream(
135        &self,
136        request: CompletionRequest,
137    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
138        CompletionModel::stream(self, request).await
139    }
140}
141
142pub mod xai_api_types {
143    use serde::{Deserialize, Serialize};
144
145    use crate::OneOrMany;
146    use crate::completion::{self, CompletionError};
147    use crate::providers::openai::{AssistantContent, Message};
148
149    impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
150        type Error = CompletionError;
151
152        fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
153            let choice = response.choices.first().ok_or_else(|| {
154                CompletionError::ResponseError("Response contained no choices".to_owned())
155            })?;
156            let content = match &choice.message {
157                Message::Assistant {
158                    content,
159                    tool_calls,
160                    ..
161                } => {
162                    let mut content = content
163                        .iter()
164                        .map(|c| match c {
165                            AssistantContent::Text { text } => {
166                                completion::AssistantContent::text(text)
167                            }
168                            AssistantContent::Refusal { refusal } => {
169                                completion::AssistantContent::text(refusal)
170                            }
171                        })
172                        .collect::<Vec<_>>();
173
174                    content.extend(
175                        tool_calls
176                            .iter()
177                            .map(|call| {
178                                completion::AssistantContent::tool_call(
179                                    &call.id,
180                                    &call.function.name,
181                                    call.function.arguments.clone(),
182                                )
183                            })
184                            .collect::<Vec<_>>(),
185                    );
186                    Ok(content)
187                }
188                _ => Err(CompletionError::ResponseError(
189                    "Response did not contain a valid message or tool call".into(),
190                )),
191            }?;
192
193            let choice = OneOrMany::many(content).map_err(|_| {
194                CompletionError::ResponseError(
195                    "Response contained no message or tool call (empty)".to_owned(),
196                )
197            })?;
198
199            Ok(completion::CompletionResponse {
200                choice,
201                raw_response: response,
202            })
203        }
204    }
205
206    impl From<completion::ToolDefinition> for ToolDefinition {
207        fn from(tool: completion::ToolDefinition) -> Self {
208            Self {
209                r#type: "function".into(),
210                function: tool,
211            }
212        }
213    }
214
215    #[derive(Clone, Debug, Deserialize, Serialize)]
216    pub struct ToolDefinition {
217        pub r#type: String,
218        pub function: completion::ToolDefinition,
219    }
220
221    #[derive(Debug, Deserialize)]
222    pub struct Function {
223        pub name: String,
224        pub arguments: String,
225    }
226
227    #[derive(Debug, Deserialize)]
228    pub struct CompletionResponse {
229        pub id: String,
230        pub model: String,
231        pub choices: Vec<Choice>,
232        pub created: i64,
233        pub object: String,
234        pub system_fingerprint: String,
235        pub usage: Usage,
236    }
237
238    #[derive(Debug, Deserialize)]
239    pub struct Choice {
240        pub finish_reason: String,
241        pub index: i32,
242        pub message: Message,
243    }
244
245    #[derive(Debug, Deserialize)]
246    pub struct Usage {
247        pub completion_tokens: i32,
248        pub prompt_tokens: i32,
249        pub total_tokens: i32,
250    }
251}