1use crate::{
7 completion::{self, CompletionError},
8 json_utils,
9 providers::openai::Message,
10};
11
12use super::client::{Client, xai_api_types::ApiResponse};
13use crate::completion::CompletionRequest;
14use crate::providers::openai;
15use crate::streaming::StreamingCompletionResponse;
16use serde_json::{Value, json};
17use xai_api_types::{CompletionResponse, ToolDefinition};
18
// ---- Grok 2 family ----

/// xAI model identifier `grok-2-1212`.
pub const GROK_2_1212: &str = "grok-2-1212";
/// xAI model identifier `grok-2-vision-1212`.
pub const GROK_2_VISION_1212: &str = "grok-2-vision-1212";
/// xAI model identifier `grok-2-image-1212`.
pub const GROK_2_IMAGE_1212: &str = "grok-2-image-1212";

// ---- Grok 3 family ----

/// xAI model identifier `grok-3`.
pub const GROK_3: &str = "grok-3";
/// xAI model identifier `grok-3-fast`.
pub const GROK_3_FAST: &str = "grok-3-fast";
/// xAI model identifier `grok-3-mini`.
pub const GROK_3_MINI: &str = "grok-3-mini";
/// xAI model identifier `grok-3-mini-fast`.
pub const GROK_3_MINI_FAST: &str = "grok-3-mini-fast";

// ---- Grok 4 family ----

/// xAI model identifier for Grok 4, pinned to the `grok-4-0709` snapshot.
pub const GROK_4: &str = "grok-4-0709";
28
29#[derive(Clone)]
34pub struct CompletionModel {
35 pub(crate) client: Client,
36 pub model: String,
37}
38
39impl CompletionModel {
40 pub(crate) fn create_completion_request(
41 &self,
42 completion_request: completion::CompletionRequest,
43 ) -> Result<Value, CompletionError> {
44 let docs: Option<Vec<Message>> = completion_request
46 .normalized_documents()
47 .map(|docs| docs.try_into())
48 .transpose()?;
49
50 let chat_history: Vec<Message> = completion_request
52 .chat_history
53 .into_iter()
54 .map(|message| message.try_into())
55 .collect::<Result<Vec<Vec<Message>>, _>>()?
56 .into_iter()
57 .flatten()
58 .collect();
59
60 let mut full_history: Vec<Message> = match &completion_request.preamble {
62 Some(preamble) => vec![Message::system(preamble)],
63 None => vec![],
64 };
65
66 if let Some(docs) = docs {
68 full_history.extend(docs)
69 }
70
71 full_history.extend(chat_history);
73
74 let mut request = if completion_request.tools.is_empty() {
75 json!({
76 "model": self.model,
77 "messages": full_history,
78 "temperature": completion_request.temperature,
79 })
80 } else {
81 json!({
82 "model": self.model,
83 "messages": full_history,
84 "temperature": completion_request.temperature,
85 "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
86 "tool_choice": "auto",
87 })
88 };
89
90 request = if let Some(params) = completion_request.additional_params {
91 json_utils::merge(request, params)
92 } else {
93 request
94 };
95
96 Ok(request)
97 }
98 pub fn new(client: Client, model: &str) -> Self {
99 Self {
100 client,
101 model: model.to_string(),
102 }
103 }
104}
105
106impl completion::CompletionModel for CompletionModel {
107 type Response = CompletionResponse;
108 type StreamingResponse = openai::StreamingCompletionResponse;
109
110 #[cfg_attr(feature = "worker", worker::send)]
111 async fn completion(
112 &self,
113 completion_request: completion::CompletionRequest,
114 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
115 let request = self.create_completion_request(completion_request)?;
116
117 let response = self
118 .client
119 .post("/v1/chat/completions")
120 .json(&request)
121 .send()
122 .await?;
123
124 if response.status().is_success() {
125 match response.json::<ApiResponse<CompletionResponse>>().await? {
126 ApiResponse::Ok(completion) => completion.try_into(),
127 ApiResponse::Error(error) => Err(CompletionError::ProviderError(error.message())),
128 }
129 } else {
130 Err(CompletionError::ProviderError(response.text().await?))
131 }
132 }
133
134 #[cfg_attr(feature = "worker", worker::send)]
135 async fn stream(
136 &self,
137 request: CompletionRequest,
138 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
139 CompletionModel::stream(self, request).await
140 }
141}
142
143pub mod xai_api_types {
144 use serde::{Deserialize, Serialize};
145
146 use crate::OneOrMany;
147 use crate::completion::{self, CompletionError};
148 use crate::providers::openai::{AssistantContent, Message};
149
150 impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
151 type Error = CompletionError;
152
153 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
154 let choice = response.choices.first().ok_or_else(|| {
155 CompletionError::ResponseError("Response contained no choices".to_owned())
156 })?;
157 let content = match &choice.message {
158 Message::Assistant {
159 content,
160 tool_calls,
161 ..
162 } => {
163 let mut content = content
164 .iter()
165 .map(|c| match c {
166 AssistantContent::Text { text } => {
167 completion::AssistantContent::text(text)
168 }
169 AssistantContent::Refusal { refusal } => {
170 completion::AssistantContent::text(refusal)
171 }
172 })
173 .collect::<Vec<_>>();
174
175 content.extend(
176 tool_calls
177 .iter()
178 .map(|call| {
179 completion::AssistantContent::tool_call(
180 &call.id,
181 &call.function.name,
182 call.function.arguments.clone(),
183 )
184 })
185 .collect::<Vec<_>>(),
186 );
187 Ok(content)
188 }
189 _ => Err(CompletionError::ResponseError(
190 "Response did not contain a valid message or tool call".into(),
191 )),
192 }?;
193
194 let choice = OneOrMany::many(content).map_err(|_| {
195 CompletionError::ResponseError(
196 "Response contained no message or tool call (empty)".to_owned(),
197 )
198 })?;
199
200 let usage = completion::Usage {
201 input_tokens: response.usage.prompt_tokens as u64,
202 output_tokens: response.usage.completion_tokens as u64,
203 total_tokens: response.usage.total_tokens as u64,
204 };
205
206 Ok(completion::CompletionResponse {
207 choice,
208 usage,
209 raw_response: response,
210 })
211 }
212 }
213
214 impl From<completion::ToolDefinition> for ToolDefinition {
215 fn from(tool: completion::ToolDefinition) -> Self {
216 Self {
217 r#type: "function".into(),
218 function: tool,
219 }
220 }
221 }
222
223 #[derive(Clone, Debug, Deserialize, Serialize)]
224 pub struct ToolDefinition {
225 pub r#type: String,
226 pub function: completion::ToolDefinition,
227 }
228
229 #[derive(Debug, Deserialize)]
230 pub struct Function {
231 pub name: String,
232 pub arguments: String,
233 }
234
235 #[derive(Debug, Deserialize, Serialize)]
236 pub struct CompletionResponse {
237 pub id: String,
238 pub model: String,
239 pub choices: Vec<Choice>,
240 pub created: i64,
241 pub object: String,
242 pub system_fingerprint: String,
243 pub usage: Usage,
244 }
245
246 #[derive(Debug, Deserialize, Serialize)]
247 pub struct Choice {
248 pub finish_reason: String,
249 pub index: i32,
250 pub message: Message,
251 }
252
253 #[derive(Debug, Deserialize, Serialize)]
254 pub struct Usage {
255 pub completion_tokens: i32,
256 pub prompt_tokens: i32,
257 pub total_tokens: i32,
258 }
259}