rig/providers/xai/
completion.rs

1// ================================================================
2//! xAI Completion Integration
3//! From [xAI Reference](https://docs.x.ai/api/endpoints#chat-completions)
4// ================================================================
5
6use crate::{
7    completion::{self, CompletionError},
8    json_utils,
9    providers::openai::Message,
10};
11
12use super::client::{xai_api_types::ApiResponse, Client};
13use crate::completion::CompletionRequest;
14use serde_json::{json, Value};
15use xai_api_types::{CompletionResponse, ToolDefinition};
16
/// Model identifier string for xAI's `grok-beta` completion model.
pub const GROK_BETA: &str = "grok-beta";
19
20// =================================================================
21// Rig Implementation Types
22// =================================================================
23
24#[derive(Clone)]
25pub struct CompletionModel {
26    pub(crate) client: Client,
27    pub model: String,
28}
29
30impl CompletionModel {
31    pub(crate) fn create_completion_request(
32        &self,
33        completion_request: CompletionRequest,
34    ) -> Result<Value, CompletionError> {
35        // Add preamble to chat history (if available)
36        let mut full_history: Vec<Message> = match &completion_request.preamble {
37            Some(preamble) => {
38                if preamble.is_empty() {
39                    vec![]
40                } else {
41                    vec![Message::system(preamble)]
42                }
43            }
44            None => vec![],
45        };
46
47        // Convert prompt to user message
48        let prompt: Vec<Message> = completion_request.prompt_with_context().try_into()?;
49
50        // Convert existing chat history
51        let chat_history: Vec<Message> = completion_request
52            .chat_history
53            .into_iter()
54            .map(|message| message.try_into())
55            .collect::<Result<Vec<Vec<Message>>, _>>()?
56            .into_iter()
57            .flatten()
58            .collect();
59
60        // Combine all messages into a single history
61        full_history.extend(chat_history);
62        full_history.extend(prompt);
63
64        let mut request = if completion_request.tools.is_empty() {
65            json!({
66                "model": self.model,
67                "messages": full_history,
68                "temperature": completion_request.temperature,
69            })
70        } else {
71            json!({
72                "model": self.model,
73                "messages": full_history,
74                "temperature": completion_request.temperature,
75                "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
76                "tool_choice": "auto",
77            })
78        };
79
80        request = if let Some(params) = completion_request.additional_params {
81            json_utils::merge(request, params)
82        } else {
83            request
84        };
85
86        Ok(request)
87    }
88    pub fn new(client: Client, model: &str) -> Self {
89        Self {
90            client,
91            model: model.to_string(),
92        }
93    }
94}
95
96impl completion::CompletionModel for CompletionModel {
97    type Response = CompletionResponse;
98
99    #[cfg_attr(feature = "worker", worker::send)]
100    async fn completion(
101        &self,
102        completion_request: completion::CompletionRequest,
103    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
104        let request = self.create_completion_request(completion_request)?;
105
106        let response = self
107            .client
108            .post("/v1/chat/completions")
109            .json(&request)
110            .send()
111            .await?;
112
113        if response.status().is_success() {
114            match response.json::<ApiResponse<CompletionResponse>>().await? {
115                ApiResponse::Ok(completion) => completion.try_into(),
116                ApiResponse::Error(error) => Err(CompletionError::ProviderError(error.message())),
117            }
118        } else {
119            Err(CompletionError::ProviderError(response.text().await?))
120        }
121    }
122}
123
124pub mod xai_api_types {
125    use serde::{Deserialize, Serialize};
126
127    use crate::completion::{self, CompletionError};
128    use crate::providers::openai::{AssistantContent, Message};
129    use crate::OneOrMany;
130
131    impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
132        type Error = CompletionError;
133
134        fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
135            let choice = response.choices.first().ok_or_else(|| {
136                CompletionError::ResponseError("Response contained no choices".to_owned())
137            })?;
138            let content = match &choice.message {
139                Message::Assistant {
140                    content,
141                    tool_calls,
142                    ..
143                } => {
144                    let mut content = content
145                        .iter()
146                        .map(|c| match c {
147                            AssistantContent::Text { text } => {
148                                completion::AssistantContent::text(text)
149                            }
150                            AssistantContent::Refusal { refusal } => {
151                                completion::AssistantContent::text(refusal)
152                            }
153                        })
154                        .collect::<Vec<_>>();
155
156                    content.extend(
157                        tool_calls
158                            .iter()
159                            .map(|call| {
160                                completion::AssistantContent::tool_call(
161                                    &call.function.name,
162                                    &call.function.name,
163                                    call.function.arguments.clone(),
164                                )
165                            })
166                            .collect::<Vec<_>>(),
167                    );
168                    Ok(content)
169                }
170                _ => Err(CompletionError::ResponseError(
171                    "Response did not contain a valid message or tool call".into(),
172                )),
173            }?;
174
175            let choice = OneOrMany::many(content).map_err(|_| {
176                CompletionError::ResponseError(
177                    "Response contained no message or tool call (empty)".to_owned(),
178                )
179            })?;
180
181            Ok(completion::CompletionResponse {
182                choice,
183                raw_response: response,
184            })
185        }
186    }
187
188    impl From<completion::ToolDefinition> for ToolDefinition {
189        fn from(tool: completion::ToolDefinition) -> Self {
190            Self {
191                r#type: "function".into(),
192                function: tool,
193            }
194        }
195    }
196
197    #[derive(Clone, Debug, Deserialize, Serialize)]
198    pub struct ToolDefinition {
199        pub r#type: String,
200        pub function: completion::ToolDefinition,
201    }
202
203    #[derive(Debug, Deserialize)]
204    pub struct Function {
205        pub name: String,
206        pub arguments: String,
207    }
208
209    #[derive(Debug, Deserialize)]
210    pub struct CompletionResponse {
211        pub id: String,
212        pub model: String,
213        pub choices: Vec<Choice>,
214        pub created: i64,
215        pub object: String,
216        pub system_fingerprint: String,
217        pub usage: Usage,
218    }
219
220    #[derive(Debug, Deserialize)]
221    pub struct Choice {
222        pub finish_reason: String,
223        pub index: i32,
224        pub message: Message,
225    }
226
227    #[derive(Debug, Deserialize)]
228    pub struct Usage {
229        pub completion_tokens: i32,
230        pub prompt_tokens: i32,
231        pub total_tokens: i32,
232    }
233}