rig/providers/xai/
completion.rs

1// ================================================================
2//! xAI Completion Integration
3//! From [xAI Reference](https://docs.x.ai/docs/api-reference#chat-completions)
4// ================================================================
5
6use crate::{
7    completion::{self, CompletionError},
8    json_utils,
9    providers::openai::Message,
10};
11
12use super::client::{xai_api_types::ApiResponse, Client};
13use serde_json::{json, Value};
14use xai_api_types::{CompletionResponse, ToolDefinition};
15
16/// `grok-beta` completion model
17pub const GROK_BETA: &str = "grok-beta";
18
19// =================================================================
20// Rig Implementation Types
21// =================================================================
22
23#[derive(Clone)]
24pub struct CompletionModel {
25    pub(crate) client: Client,
26    pub model: String,
27}
28
29impl CompletionModel {
30    pub(crate) fn create_completion_request(
31        &self,
32        completion_request: completion::CompletionRequest,
33    ) -> Result<Value, CompletionError> {
34        // Convert documents into user message
35        let docs: Option<Vec<Message>> = completion_request
36            .normalized_documents()
37            .map(|docs| docs.try_into())
38            .transpose()?;
39
40        // Convert existing chat history
41        let chat_history: Vec<Message> = completion_request
42            .chat_history
43            .into_iter()
44            .map(|message| message.try_into())
45            .collect::<Result<Vec<Vec<Message>>, _>>()?
46            .into_iter()
47            .flatten()
48            .collect();
49
50        // Init full history with preamble (or empty if non-existant)
51        let mut full_history: Vec<Message> = match &completion_request.preamble {
52            Some(preamble) => vec![Message::system(preamble)],
53            None => vec![],
54        };
55
56        // Docs appear right after preamble, if they exist
57        if let Some(docs) = docs {
58            full_history.extend(docs)
59        }
60
61        // Chat history and prompt appear in the order they were provided
62        full_history.extend(chat_history);
63
64        let mut request = if completion_request.tools.is_empty() {
65            json!({
66                "model": self.model,
67                "messages": full_history,
68                "temperature": completion_request.temperature,
69            })
70        } else {
71            json!({
72                "model": self.model,
73                "messages": full_history,
74                "temperature": completion_request.temperature,
75                "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
76                "tool_choice": "auto",
77            })
78        };
79
80        request = if let Some(params) = completion_request.additional_params {
81            json_utils::merge(request, params)
82        } else {
83            request
84        };
85
86        Ok(request)
87    }
88    pub fn new(client: Client, model: &str) -> Self {
89        Self {
90            client,
91            model: model.to_string(),
92        }
93    }
94}
95
96impl completion::CompletionModel for CompletionModel {
97    type Response = CompletionResponse;
98
99    #[cfg_attr(feature = "worker", worker::send)]
100    async fn completion(
101        &self,
102        completion_request: completion::CompletionRequest,
103    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
104        let request = self.create_completion_request(completion_request)?;
105
106        let response = self
107            .client
108            .post("/v1/chat/completions")
109            .json(&request)
110            .send()
111            .await?;
112
113        if response.status().is_success() {
114            match response.json::<ApiResponse<CompletionResponse>>().await? {
115                ApiResponse::Ok(completion) => completion.try_into(),
116                ApiResponse::Error(error) => Err(CompletionError::ProviderError(error.message())),
117            }
118        } else {
119            Err(CompletionError::ProviderError(response.text().await?))
120        }
121    }
122}
123
124pub mod xai_api_types {
125    use serde::{Deserialize, Serialize};
126
127    use crate::completion::{self, CompletionError};
128    use crate::providers::openai::{AssistantContent, Message};
129    use crate::OneOrMany;
130
131    impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
132        type Error = CompletionError;
133
134        fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
135            let choice = response.choices.first().ok_or_else(|| {
136                CompletionError::ResponseError("Response contained no choices".to_owned())
137            })?;
138            let content = match &choice.message {
139                Message::Assistant {
140                    content,
141                    tool_calls,
142                    ..
143                } => {
144                    let mut content = content
145                        .iter()
146                        .map(|c| match c {
147                            AssistantContent::Text { text } => {
148                                completion::AssistantContent::text(text)
149                            }
150                            AssistantContent::Refusal { refusal } => {
151                                completion::AssistantContent::text(refusal)
152                            }
153                        })
154                        .collect::<Vec<_>>();
155
156                    content.extend(
157                        tool_calls
158                            .iter()
159                            .map(|call| {
160                                completion::AssistantContent::tool_call(
161                                    &call.id,
162                                    &call.function.name,
163                                    call.function.arguments.clone(),
164                                )
165                            })
166                            .collect::<Vec<_>>(),
167                    );
168                    Ok(content)
169                }
170                _ => Err(CompletionError::ResponseError(
171                    "Response did not contain a valid message or tool call".into(),
172                )),
173            }?;
174
175            let choice = OneOrMany::many(content).map_err(|_| {
176                CompletionError::ResponseError(
177                    "Response contained no message or tool call (empty)".to_owned(),
178                )
179            })?;
180
181            Ok(completion::CompletionResponse {
182                choice,
183                raw_response: response,
184            })
185        }
186    }
187
188    impl From<completion::ToolDefinition> for ToolDefinition {
189        fn from(tool: completion::ToolDefinition) -> Self {
190            Self {
191                r#type: "function".into(),
192                function: tool,
193            }
194        }
195    }
196
197    #[derive(Clone, Debug, Deserialize, Serialize)]
198    pub struct ToolDefinition {
199        pub r#type: String,
200        pub function: completion::ToolDefinition,
201    }
202
203    #[derive(Debug, Deserialize)]
204    pub struct Function {
205        pub name: String,
206        pub arguments: String,
207    }
208
209    #[derive(Debug, Deserialize)]
210    pub struct CompletionResponse {
211        pub id: String,
212        pub model: String,
213        pub choices: Vec<Choice>,
214        pub created: i64,
215        pub object: String,
216        pub system_fingerprint: String,
217        pub usage: Usage,
218    }
219
220    #[derive(Debug, Deserialize)]
221    pub struct Choice {
222        pub finish_reason: String,
223        pub index: i32,
224        pub message: Message,
225    }
226
227    #[derive(Debug, Deserialize)]
228    pub struct Usage {
229        pub completion_tokens: i32,
230        pub prompt_tokens: i32,
231        pub total_tokens: i32,
232    }
233}