rig/providers/xai/
completion.rs1use crate::{
7 completion::{self, CompletionError},
8 json_utils,
9 providers::openai::Message,
10};
11
12use super::client::{Client, xai_api_types::ApiResponse};
13use crate::completion::CompletionRequest;
14use crate::providers::openai;
15use crate::streaming::StreamingCompletionResponse;
16use serde_json::{Value, json};
17use xai_api_types::{CompletionResponse, ToolDefinition};
18
// Model name strings, grouped by model family.

/// Model name `grok-2-1212`.
pub const GROK_2_1212: &str = "grok-2-1212";
/// Model name `grok-2-vision-1212`.
pub const GROK_2_VISION_1212: &str = "grok-2-vision-1212";
/// Model name `grok-2-image-1212`.
pub const GROK_2_IMAGE_1212: &str = "grok-2-image-1212";

/// Model name `grok-3`.
pub const GROK_3: &str = "grok-3";
/// Model name `grok-3-fast`.
pub const GROK_3_FAST: &str = "grok-3-fast";
/// Model name `grok-3-mini`.
pub const GROK_3_MINI: &str = "grok-3-mini";
/// Model name `grok-3-mini-fast`.
pub const GROK_3_MINI_FAST: &str = "grok-3-mini-fast";
27
28#[derive(Clone)]
33pub struct CompletionModel {
34 pub(crate) client: Client,
35 pub model: String,
36}
37
38impl CompletionModel {
39 pub(crate) fn create_completion_request(
40 &self,
41 completion_request: completion::CompletionRequest,
42 ) -> Result<Value, CompletionError> {
43 let docs: Option<Vec<Message>> = completion_request
45 .normalized_documents()
46 .map(|docs| docs.try_into())
47 .transpose()?;
48
49 let chat_history: Vec<Message> = completion_request
51 .chat_history
52 .into_iter()
53 .map(|message| message.try_into())
54 .collect::<Result<Vec<Vec<Message>>, _>>()?
55 .into_iter()
56 .flatten()
57 .collect();
58
59 let mut full_history: Vec<Message> = match &completion_request.preamble {
61 Some(preamble) => vec![Message::system(preamble)],
62 None => vec![],
63 };
64
65 if let Some(docs) = docs {
67 full_history.extend(docs)
68 }
69
70 full_history.extend(chat_history);
72
73 let mut request = if completion_request.tools.is_empty() {
74 json!({
75 "model": self.model,
76 "messages": full_history,
77 "temperature": completion_request.temperature,
78 })
79 } else {
80 json!({
81 "model": self.model,
82 "messages": full_history,
83 "temperature": completion_request.temperature,
84 "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
85 "tool_choice": "auto",
86 })
87 };
88
89 request = if let Some(params) = completion_request.additional_params {
90 json_utils::merge(request, params)
91 } else {
92 request
93 };
94
95 Ok(request)
96 }
97 pub fn new(client: Client, model: &str) -> Self {
98 Self {
99 client,
100 model: model.to_string(),
101 }
102 }
103}
104
105impl completion::CompletionModel for CompletionModel {
106 type Response = CompletionResponse;
107 type StreamingResponse = openai::StreamingCompletionResponse;
108
109 #[cfg_attr(feature = "worker", worker::send)]
110 async fn completion(
111 &self,
112 completion_request: completion::CompletionRequest,
113 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
114 let request = self.create_completion_request(completion_request)?;
115
116 let response = self
117 .client
118 .post("/v1/chat/completions")
119 .json(&request)
120 .send()
121 .await?;
122
123 if response.status().is_success() {
124 match response.json::<ApiResponse<CompletionResponse>>().await? {
125 ApiResponse::Ok(completion) => completion.try_into(),
126 ApiResponse::Error(error) => Err(CompletionError::ProviderError(error.message())),
127 }
128 } else {
129 Err(CompletionError::ProviderError(response.text().await?))
130 }
131 }
132
133 #[cfg_attr(feature = "worker", worker::send)]
134 async fn stream(
135 &self,
136 request: CompletionRequest,
137 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
138 CompletionModel::stream(self, request).await
139 }
140}
141
142pub mod xai_api_types {
143 use serde::{Deserialize, Serialize};
144
145 use crate::OneOrMany;
146 use crate::completion::{self, CompletionError};
147 use crate::providers::openai::{AssistantContent, Message};
148
149 impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
150 type Error = CompletionError;
151
152 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
153 let choice = response.choices.first().ok_or_else(|| {
154 CompletionError::ResponseError("Response contained no choices".to_owned())
155 })?;
156 let content = match &choice.message {
157 Message::Assistant {
158 content,
159 tool_calls,
160 ..
161 } => {
162 let mut content = content
163 .iter()
164 .map(|c| match c {
165 AssistantContent::Text { text } => {
166 completion::AssistantContent::text(text)
167 }
168 AssistantContent::Refusal { refusal } => {
169 completion::AssistantContent::text(refusal)
170 }
171 })
172 .collect::<Vec<_>>();
173
174 content.extend(
175 tool_calls
176 .iter()
177 .map(|call| {
178 completion::AssistantContent::tool_call(
179 &call.id,
180 &call.function.name,
181 call.function.arguments.clone(),
182 )
183 })
184 .collect::<Vec<_>>(),
185 );
186 Ok(content)
187 }
188 _ => Err(CompletionError::ResponseError(
189 "Response did not contain a valid message or tool call".into(),
190 )),
191 }?;
192
193 let choice = OneOrMany::many(content).map_err(|_| {
194 CompletionError::ResponseError(
195 "Response contained no message or tool call (empty)".to_owned(),
196 )
197 })?;
198
199 Ok(completion::CompletionResponse {
200 choice,
201 raw_response: response,
202 })
203 }
204 }
205
206 impl From<completion::ToolDefinition> for ToolDefinition {
207 fn from(tool: completion::ToolDefinition) -> Self {
208 Self {
209 r#type: "function".into(),
210 function: tool,
211 }
212 }
213 }
214
215 #[derive(Clone, Debug, Deserialize, Serialize)]
216 pub struct ToolDefinition {
217 pub r#type: String,
218 pub function: completion::ToolDefinition,
219 }
220
221 #[derive(Debug, Deserialize)]
222 pub struct Function {
223 pub name: String,
224 pub arguments: String,
225 }
226
227 #[derive(Debug, Deserialize)]
228 pub struct CompletionResponse {
229 pub id: String,
230 pub model: String,
231 pub choices: Vec<Choice>,
232 pub created: i64,
233 pub object: String,
234 pub system_fingerprint: String,
235 pub usage: Usage,
236 }
237
238 #[derive(Debug, Deserialize)]
239 pub struct Choice {
240 pub finish_reason: String,
241 pub index: i32,
242 pub message: Message,
243 }
244
245 #[derive(Debug, Deserialize)]
246 pub struct Usage {
247 pub completion_tokens: i32,
248 pub prompt_tokens: i32,
249 pub total_tokens: i32,
250 }
251}