rig/providers/xai/
completion.rs1use crate::{
7 completion::{self, CompletionError},
8 json_utils,
9 providers::openai::Message,
10};
11
12use super::client::{xai_api_types::ApiResponse, Client};
13use crate::completion::CompletionRequest;
14use serde_json::{json, Value};
15use xai_api_types::{CompletionResponse, ToolDefinition};
16
/// Model identifier for xAI's `grok-beta` chat model.
pub const GROK_BETA: &str = "grok-beta";

20#[derive(Clone)]
25pub struct CompletionModel {
26 pub(crate) client: Client,
27 pub model: String,
28}
29
30impl CompletionModel {
31 pub(crate) fn create_completion_request(
32 &self,
33 completion_request: CompletionRequest,
34 ) -> Result<Value, CompletionError> {
35 let mut full_history: Vec<Message> = match &completion_request.preamble {
37 Some(preamble) => {
38 if preamble.is_empty() {
39 vec![]
40 } else {
41 vec![Message::system(preamble)]
42 }
43 }
44 None => vec![],
45 };
46
47 let prompt: Vec<Message> = completion_request.prompt_with_context().try_into()?;
49
50 let chat_history: Vec<Message> = completion_request
52 .chat_history
53 .into_iter()
54 .map(|message| message.try_into())
55 .collect::<Result<Vec<Vec<Message>>, _>>()?
56 .into_iter()
57 .flatten()
58 .collect();
59
60 full_history.extend(chat_history);
62 full_history.extend(prompt);
63
64 let mut request = if completion_request.tools.is_empty() {
65 json!({
66 "model": self.model,
67 "messages": full_history,
68 "temperature": completion_request.temperature,
69 })
70 } else {
71 json!({
72 "model": self.model,
73 "messages": full_history,
74 "temperature": completion_request.temperature,
75 "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
76 "tool_choice": "auto",
77 })
78 };
79
80 request = if let Some(params) = completion_request.additional_params {
81 json_utils::merge(request, params)
82 } else {
83 request
84 };
85
86 Ok(request)
87 }
88 pub fn new(client: Client, model: &str) -> Self {
89 Self {
90 client,
91 model: model.to_string(),
92 }
93 }
94}
95
96impl completion::CompletionModel for CompletionModel {
97 type Response = CompletionResponse;
98
99 #[cfg_attr(feature = "worker", worker::send)]
100 async fn completion(
101 &self,
102 completion_request: completion::CompletionRequest,
103 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
104 let request = self.create_completion_request(completion_request)?;
105
106 let response = self
107 .client
108 .post("/v1/chat/completions")
109 .json(&request)
110 .send()
111 .await?;
112
113 if response.status().is_success() {
114 match response.json::<ApiResponse<CompletionResponse>>().await? {
115 ApiResponse::Ok(completion) => completion.try_into(),
116 ApiResponse::Error(error) => Err(CompletionError::ProviderError(error.message())),
117 }
118 } else {
119 Err(CompletionError::ProviderError(response.text().await?))
120 }
121 }
122}
123
124pub mod xai_api_types {
125 use serde::{Deserialize, Serialize};
126
127 use crate::completion::{self, CompletionError};
128 use crate::providers::openai::{AssistantContent, Message};
129 use crate::OneOrMany;
130
131 impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
132 type Error = CompletionError;
133
134 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
135 let choice = response.choices.first().ok_or_else(|| {
136 CompletionError::ResponseError("Response contained no choices".to_owned())
137 })?;
138 let content = match &choice.message {
139 Message::Assistant {
140 content,
141 tool_calls,
142 ..
143 } => {
144 let mut content = content
145 .iter()
146 .map(|c| match c {
147 AssistantContent::Text { text } => {
148 completion::AssistantContent::text(text)
149 }
150 AssistantContent::Refusal { refusal } => {
151 completion::AssistantContent::text(refusal)
152 }
153 })
154 .collect::<Vec<_>>();
155
156 content.extend(
157 tool_calls
158 .iter()
159 .map(|call| {
160 completion::AssistantContent::tool_call(
161 &call.function.name,
162 &call.function.name,
163 call.function.arguments.clone(),
164 )
165 })
166 .collect::<Vec<_>>(),
167 );
168 Ok(content)
169 }
170 _ => Err(CompletionError::ResponseError(
171 "Response did not contain a valid message or tool call".into(),
172 )),
173 }?;
174
175 let choice = OneOrMany::many(content).map_err(|_| {
176 CompletionError::ResponseError(
177 "Response contained no message or tool call (empty)".to_owned(),
178 )
179 })?;
180
181 Ok(completion::CompletionResponse {
182 choice,
183 raw_response: response,
184 })
185 }
186 }
187
188 impl From<completion::ToolDefinition> for ToolDefinition {
189 fn from(tool: completion::ToolDefinition) -> Self {
190 Self {
191 r#type: "function".into(),
192 function: tool,
193 }
194 }
195 }
196
197 #[derive(Clone, Debug, Deserialize, Serialize)]
198 pub struct ToolDefinition {
199 pub r#type: String,
200 pub function: completion::ToolDefinition,
201 }
202
203 #[derive(Debug, Deserialize)]
204 pub struct Function {
205 pub name: String,
206 pub arguments: String,
207 }
208
209 #[derive(Debug, Deserialize)]
210 pub struct CompletionResponse {
211 pub id: String,
212 pub model: String,
213 pub choices: Vec<Choice>,
214 pub created: i64,
215 pub object: String,
216 pub system_fingerprint: String,
217 pub usage: Usage,
218 }
219
220 #[derive(Debug, Deserialize)]
221 pub struct Choice {
222 pub finish_reason: String,
223 pub index: i32,
224 pub message: Message,
225 }
226
227 #[derive(Debug, Deserialize)]
228 pub struct Usage {
229 pub completion_tokens: i32,
230 pub prompt_tokens: i32,
231 pub total_tokens: i32,
232 }
233}