1use serde::{Deserialize, Serialize};
5use strum::{Display, EnumIter, EnumString};
6
7use crate::error::{self, Result};
8use crate::types::api::{
9 ApiKey, ImageModel as ApiImageModel, LanguageModel as ApiLanguageModel, Model, TokenizeResponse,
10};
11use crate::types::chat::{
12 ChatCompletionRequest, ChatCompletionResponse, Choice, DeferredChatCompletionResponse, Message,
13 stream,
14};
15use crate::types::image::{ImageRequest, ImageResponse};
16use futures::StreamExt;
17
18#[derive(Clone, Copy, Debug, Display, PartialEq, EnumIter, EnumString, Serialize, Deserialize)]
19pub enum LanguageModel {
20 #[strum(serialize = "grok-4")]
21 Grok4,
22
23 #[strum(serialize = "grok-code-fast")]
24 GrokCode,
25
26 #[strum(serialize = "grok-3")]
27 Grok3,
28 #[strum(serialize = "grok-3-fast")]
29 Grok3Fast,
30
31 #[strum(serialize = "grok-3-mini")]
32 Grok3Mini,
33 #[strum(serialize = "grok-3-mini-fast")]
34 Grok3MiniFast,
35
36 #[strum(serialize = "grok-2")]
38 Grok2,
39
40 #[strum(serialize = "grok-2-vision")]
42 Grok2Vision,
43}
44
45impl LanguageModel {
46 pub fn err_ivalid_model(model: String) -> String {
47 format!("Invalid language model '{model}'")
48 }
49}
50
51#[derive(Clone, Copy, Debug, Display, PartialEq, EnumIter, EnumString, Serialize, Deserialize)]
52pub enum ImageModel {
53 #[strum(serialize = "grok-2-image")]
54 Grok2Image,
55}
56
57impl ImageModel {
58 pub fn err_ivalid_model(model: String) -> String {
59 format!("Invalid image model '{model}'")
60 }
61}
62
63#[derive(Clone, Copy, Debug, Display, PartialEq, EnumIter, EnumString, Serialize, Deserialize)]
64#[strum(serialize_all = "snake_case")]
65pub enum Role {
66 Assistant,
67 System,
68 Tool,
69
70 User,
71}
72
73pub mod url {
77 pub const HOST: &str = "https://api.x.ai/v1";
78 pub const MANAGEMENT_HOST: &str = "https://management-api.x.ai";
79
80 pub mod api {
81 use super::HOST;
82 use const_format::formatcp;
83
84 pub const GET_KEY: &str = formatcp!("{HOST}/api-key");
85 pub const GET_MODELS: &str = formatcp!("{HOST}/models");
86 pub const GET_LANGUAGE_MODELS: &str = formatcp!("{HOST}/language-models");
87 pub const GET_IMAGE_MODELS: &str = formatcp!("{HOST}/image-generation-models");
88
89 pub const POST_TOKENIZE_TEXT: &str = formatcp!("{HOST}/tokenize-text");
90
91 pub fn get_model(id: String) -> String {
92 format!("{GET_MODELS}/{id}")
93 }
94
95 pub fn get_language_model(id: String) -> String {
96 format!("{GET_LANGUAGE_MODELS}/{id}")
97 }
98
99 pub fn get_image_model(id: String) -> String {
100 format!("{GET_IMAGE_MODELS}/{id}")
101 }
102 }
103
104 pub mod chat {
105 use super::HOST;
106 use const_format::formatcp;
107
108 pub const POST_COMPLETION: &str = formatcp!("{HOST}/chat/completions");
109 pub const GET_DEFERED_COMPLETION: &str = formatcp!("{HOST}/chat/deferred-completion");
110
111 pub fn get_deferred_completion(request_id: String) -> String {
112 format!("{GET_DEFERED_COMPLETION}/{request_id}")
113 }
114 }
115
116 pub mod image {
117 use super::HOST;
118 use const_format::formatcp;
119
120 pub const POST_GENERATE: &str = formatcp!("{HOST}/images/generations");
121 }
122}
123
124#[derive(Debug, Clone)]
128pub struct GrokClient {
129 client: reqwest::Client,
130 api_key: String,
131}
132
133impl GrokClient {
134 pub fn new(api_key: String) -> Self {
136 Self {
137 client: reqwest::Client::new(),
138 api_key,
139 }
140 }
141
142 pub fn with_client(client: reqwest::Client, api_key: String) -> Self {
144 Self { client, api_key }
145 }
146
147 pub fn api_key(&self) -> &str {
149 &self.api_key
150 }
151
152 pub fn client(&self) -> &reqwest::Client {
154 &self.client
155 }
156
157 pub async fn get_api_key(&self) -> Result<ApiKey> {
163 let res = self
164 .client
165 .get(url::api::GET_KEY)
166 .header("Authorization", format!("Bearer {}", self.api_key))
167 .send()
168 .await?;
169
170 Ok(res.json().await?)
171 }
172
173 pub async fn get_model(&self, id: LanguageModel) -> Result<Model> {
175 let res = self
176 .client
177 .get(url::api::get_model(id.to_string()))
178 .header("Authorization", format!("Bearer {}", self.api_key))
179 .send()
180 .await?;
181
182 Ok(res.json().await?)
183 }
184
185 pub async fn get_language_models(&self) -> Result<Vec<ApiLanguageModel>> {
187 let res = self
188 .client
189 .get(url::api::GET_LANGUAGE_MODELS)
190 .header("Authorization", format!("Bearer {}", self.api_key))
191 .send()
192 .await?;
193
194 let res: crate::types::api::LanguageModels = res.json().await?;
195 Ok(res.models)
196 }
197
198 pub async fn get_language_model(&self, id: LanguageModel) -> Result<ApiLanguageModel> {
200 let res = self
201 .client
202 .get(url::api::get_language_model(id.to_string()))
203 .header("Authorization", format!("Bearer {}", self.api_key))
204 .send()
205 .await?;
206
207 Ok(res.json().await?)
208 }
209
210 pub async fn get_image_models(&self) -> Result<Vec<ApiImageModel>> {
212 let res = self
213 .client
214 .get(url::api::GET_IMAGE_MODELS)
215 .header("Authorization", format!("Bearer {}", self.api_key))
216 .send()
217 .await?;
218
219 let res: crate::types::api::ImageModels = res.json().await?;
220 Ok(res.models)
221 }
222
223 pub async fn get_image_model(&self, id: ImageModel) -> Result<ApiImageModel> {
225 let res = self
226 .client
227 .get(url::api::get_image_model(id.to_string()))
228 .header("Authorization", format!("Bearer {}", self.api_key))
229 .send()
230 .await?;
231
232 Ok(res.json().await?)
233 }
234
235 pub async fn tokenize_text(
237 &self,
238 model: LanguageModel,
239 text: String,
240 ) -> Result<TokenizeResponse> {
241 let body = crate::types::api::TokenizeRequest::init(model, text);
242 let res = self
243 .client
244 .post(url::api::POST_TOKENIZE_TEXT)
245 .header("Authorization", format!("Bearer {}", self.api_key))
246 .json(&body)
247 .send()
248 .await?;
249
250 Ok(res.json().await?)
251 }
252
253 pub async fn chat_complete(
259 &self,
260 request: &ChatCompletionRequest,
261 ) -> Result<ChatCompletionResponse> {
262 let mut complete_req = request.clone();
263 complete_req.stream = Some(false);
264 complete_req.deferred = Some(false);
265
266 let res = self
267 .client
268 .post(url::chat::POST_COMPLETION)
269 .header("Authorization", format!("Bearer {}", self.api_key))
270 .header("Content-Type", "application/json")
271 .json(&complete_req)
272 .send()
273 .await?;
274
275 Ok(res.json().await?)
276 }
277
278 pub async fn chat_stream<F1, F2, F3>(
280 &self,
281 request: &ChatCompletionRequest,
282 on_content_token: F1,
283 on_reason_token: Option<F2>,
284 on_done: Option<F3>,
285 ) -> Result<ChatCompletionResponse>
286 where
287 F1: Fn(&str),
288 F2: Fn(&str),
289 F3: Fn(),
290 {
291 let mut complete_req = request.clone();
292 complete_req.stream = Some(true);
293 complete_req.deferred = Some(false);
294
295 let req_builder = self
296 .client
297 .post(url::chat::POST_COMPLETION)
298 .header("Authorization", format!("Bearer {}", self.api_key))
299 .header("Content-Type", "application/json")
300 .json(&complete_req);
301
302 let mut stream = reqwest_eventsource::EventSource::new(req_builder)?;
303
304 let mut buf_reasoning_content = String::new();
305 let mut buf_content = String::new();
306 let mut complete_res = ChatCompletionResponse::new(0);
307 let mut init = true;
308 let mut role: Option<String> = None;
309
310 while let Some(event) = stream.next().await {
311 match event {
312 Ok(reqwest_eventsource::Event::Open) => {}
313 Ok(reqwest_eventsource::Event::Message(message)) => {
314 if message.data == "[DONE]" {
315 if let Some(done) = on_done {
316 done();
317 }
318 stream.close();
319 break;
320 }
321
322 let chunk: stream::ChatCompletionChunk = serde_json::from_str(&message.data)
323 .map_err(|e| error::Error::SerdeJson(e))?;
324
325 if init {
326 init = false;
327 complete_res.id = chunk.id;
328 complete_res.object = "chat.response".to_string();
329 complete_res.created = chunk.created;
330 complete_res.model = chunk.model;
331 complete_res.system_fingerprint = Some(chunk.system_fingerprint);
332 }
333
334 if let Some(choice) = chunk.choices.last()
335 && role.is_none()
336 {
337 if let Some(r) = &choice.delta.role {
338 role = Some(r.clone());
339 }
340 }
341
342 if chunk.usage.is_some() {
343 complete_res.usage = chunk.usage;
344 }
345
346 if chunk.citations.is_some() {
347 complete_res.citations = chunk.citations;
348 }
349
350 if let Some(choice) = chunk.choices.get(0) {
351 if let (Some(cb_reason_token), Some(reason_token)) =
352 (&on_reason_token, &choice.delta.reasoning_content)
353 {
354 cb_reason_token(&reason_token);
355 buf_reasoning_content.push_str(reason_token);
356 }
357
358 if let Some(content_token) = &choice.delta.content {
359 on_content_token(&content_token);
360 buf_content.push_str(content_token);
361 }
362 }
363 }
364 Err(err) => {
365 stream.close();
366 return Err(error::Error::EventSource(err));
367 }
368 }
369 }
370
371 complete_res.choices.push(Choice {
372 index: 0,
373 message: Message {
374 role: role.unwrap_or("unknown".to_string()),
375 content: buf_content,
376 reasoning_content: Some(buf_reasoning_content),
377 refusal: None,
378 tool_calls: None,
379 tool_call_id: None,
380 },
381 finish_reason: "stop".to_string(),
382 });
383
384 Ok(complete_res)
385 }
386
387 pub async fn chat_defer(
389 &self,
390 request: &ChatCompletionRequest,
391 ) -> Result<DeferredChatCompletionResponse> {
392 let mut complete_req = request.clone();
393 complete_req.stream = Some(false);
394 complete_req.deferred = Some(true);
395
396 let res = self
397 .client
398 .post(url::chat::POST_COMPLETION)
399 .header("Authorization", format!("Bearer {}", self.api_key))
400 .header("Content-Type", "application/json")
401 .json(&complete_req)
402 .send()
403 .await?;
404
405 Ok(res.json().await?)
406 }
407
408 pub async fn get_deferred_completion(
410 &self,
411 request_id: String,
412 ) -> Result<ChatCompletionResponse> {
413 let res = self
414 .client
415 .get(url::chat::get_deferred_completion(request_id))
416 .header("Authorization", format!("Bearer {}", self.api_key))
417 .send()
418 .await?;
419
420 Ok(res.json().await?)
421 }
422
423 pub async fn generate_image(&self, request: &ImageRequest) -> Result<ImageResponse> {
429 let res = self
430 .client
431 .post(url::image::POST_GENERATE)
432 .header("Authorization", format!("Bearer {}", self.api_key))
433 .json(request)
434 .send()
435 .await?;
436
437 Ok(res.json().await?)
438 }
439}