rig/providers/xai/
completion.rs

1// ================================================================
2//! xAI Completion Integration
3//! From [xAI Reference](https://docs.x.ai/api/endpoints#chat-completions)
4// ================================================================
5
6use crate::{
7    completion::{self, CompletionError},
8    json_utils,
9    providers::openai::Message,
10};
11
12use serde_json::json;
13use xai_api_types::{CompletionResponse, ToolDefinition};
14
15use super::client::{xai_api_types::ApiResponse, Client};
16
/// Model name of the `grok-beta` completion model, suitable as the `model`
/// argument of [`CompletionModel::new`].
pub const GROK_BETA: &str = "grok-beta";
19
20// =================================================================
21// Rig Implementation Types
22// =================================================================
23
24#[derive(Clone)]
25pub struct CompletionModel {
26    client: Client,
27    pub model: String,
28}
29
30impl CompletionModel {
31    pub fn new(client: Client, model: &str) -> Self {
32        Self {
33            client,
34            model: model.to_string(),
35        }
36    }
37}
38
39impl completion::CompletionModel for CompletionModel {
40    type Response = CompletionResponse;
41
42    #[cfg_attr(feature = "worker", worker::send)]
43    async fn completion(
44        &self,
45        completion_request: completion::CompletionRequest,
46    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
47        // Add preamble to chat history (if available)
48        let mut full_history: Vec<Message> = match &completion_request.preamble {
49            Some(preamble) => {
50                if preamble.is_empty() {
51                    vec![]
52                } else {
53                    vec![Message::system(preamble)]
54                }
55            }
56            None => vec![],
57        };
58
59        // Convert prompt to user message
60        let prompt: Vec<Message> = completion_request.prompt_with_context().try_into()?;
61
62        // Convert existing chat history
63        let chat_history: Vec<Message> = completion_request
64            .chat_history
65            .into_iter()
66            .map(|message| message.try_into())
67            .collect::<Result<Vec<Vec<Message>>, _>>()?
68            .into_iter()
69            .flatten()
70            .collect();
71
72        // Combine all messages into a single history
73        full_history.extend(chat_history);
74        full_history.extend(prompt);
75
76        let mut request = if completion_request.tools.is_empty() {
77            json!({
78                "model": self.model,
79                "messages": full_history,
80                "temperature": completion_request.temperature,
81            })
82        } else {
83            json!({
84                "model": self.model,
85                "messages": full_history,
86                "temperature": completion_request.temperature,
87                "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
88                "tool_choice": "auto",
89            })
90        };
91
92        request = if let Some(params) = completion_request.additional_params {
93            json_utils::merge(request, params)
94        } else {
95            request
96        };
97
98        let response = self
99            .client
100            .post("/v1/chat/completions")
101            .json(&request)
102            .send()
103            .await?;
104
105        if response.status().is_success() {
106            match response.json::<ApiResponse<CompletionResponse>>().await? {
107                ApiResponse::Ok(completion) => completion.try_into(),
108                ApiResponse::Error(error) => Err(CompletionError::ProviderError(error.message())),
109            }
110        } else {
111            Err(CompletionError::ProviderError(response.text().await?))
112        }
113    }
114}
115
116pub mod xai_api_types {
117    use serde::{Deserialize, Serialize};
118
119    use crate::completion::{self, CompletionError};
120    use crate::providers::openai::{AssistantContent, Message};
121    use crate::OneOrMany;
122
123    impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
124        type Error = CompletionError;
125
126        fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
127            let choice = response.choices.first().ok_or_else(|| {
128                CompletionError::ResponseError("Response contained no choices".to_owned())
129            })?;
130            let content = match &choice.message {
131                Message::Assistant {
132                    content,
133                    tool_calls,
134                    ..
135                } => {
136                    let mut content = content
137                        .iter()
138                        .map(|c| match c {
139                            AssistantContent::Text { text } => {
140                                completion::AssistantContent::text(text)
141                            }
142                            AssistantContent::Refusal { refusal } => {
143                                completion::AssistantContent::text(refusal)
144                            }
145                        })
146                        .collect::<Vec<_>>();
147
148                    content.extend(
149                        tool_calls
150                            .iter()
151                            .map(|call| {
152                                completion::AssistantContent::tool_call(
153                                    &call.function.name,
154                                    &call.function.name,
155                                    call.function.arguments.clone(),
156                                )
157                            })
158                            .collect::<Vec<_>>(),
159                    );
160                    Ok(content)
161                }
162                _ => Err(CompletionError::ResponseError(
163                    "Response did not contain a valid message or tool call".into(),
164                )),
165            }?;
166
167            let choice = OneOrMany::many(content).map_err(|_| {
168                CompletionError::ResponseError(
169                    "Response contained no message or tool call (empty)".to_owned(),
170                )
171            })?;
172
173            Ok(completion::CompletionResponse {
174                choice,
175                raw_response: response,
176            })
177        }
178    }
179
180    impl From<completion::ToolDefinition> for ToolDefinition {
181        fn from(tool: completion::ToolDefinition) -> Self {
182            Self {
183                r#type: "function".into(),
184                function: tool,
185            }
186        }
187    }
188
189    #[derive(Clone, Debug, Deserialize, Serialize)]
190    pub struct ToolDefinition {
191        pub r#type: String,
192        pub function: completion::ToolDefinition,
193    }
194
195    #[derive(Debug, Deserialize)]
196    pub struct Function {
197        pub name: String,
198        pub arguments: String,
199    }
200
201    #[derive(Debug, Deserialize)]
202    pub struct CompletionResponse {
203        pub id: String,
204        pub model: String,
205        pub choices: Vec<Choice>,
206        pub created: i64,
207        pub object: String,
208        pub system_fingerprint: String,
209        pub usage: Usage,
210    }
211
212    #[derive(Debug, Deserialize)]
213    pub struct Choice {
214        pub finish_reason: String,
215        pub index: i32,
216        pub message: Message,
217    }
218
219    #[derive(Debug, Deserialize)]
220    pub struct Usage {
221        pub completion_tokens: i32,
222        pub prompt_tokens: i32,
223        pub total_tokens: i32,
224    }
225}