1use serde::{Deserialize, Serialize};
5use strum::{Display, EnumIter, EnumString};
6
7use crate::error::{self, Result};
8use crate::types::api::{
9 ApiKey, ImageModel as ApiImageModel, LanguageModel as ApiLanguageModel, Model, TokenizeResponse,
10};
11use crate::types::chat::{
12 ChatCompletionRequest, ChatCompletionResponse, Choice, DeferredChatCompletionResponse, Message,
13 stream,
14};
15use crate::types::image::{ImageRequest, ImageResponse};
16use futures::StreamExt;
17
18#[derive(Clone, Copy, Debug, Display, PartialEq, EnumIter, EnumString, Serialize, Deserialize)]
19pub enum LanguageModel {
20 #[strum(serialize = "grok-4")]
21 Grok4,
22
23 #[strum(serialize = "grok-code-fast")]
24 GrokCode,
25
26 #[strum(serialize = "grok-3")]
27 Grok3,
28 #[strum(serialize = "grok-3-fast")]
29 Grok3Fast,
30
31 #[strum(serialize = "grok-3-mini")]
32 Grok3Mini,
33 #[strum(serialize = "grok-3-mini-fast")]
34 Grok3MiniFast,
35
36 #[strum(serialize = "grok-2")]
38 Grok2,
39
40 #[strum(serialize = "grok-2-vision")]
42 Grok2Vision,
43}
44
45impl LanguageModel {
46 pub fn err_ivalid_model(model: String) -> String {
47 format!("Invalid language model '{model}'")
48 }
49}
50
51#[derive(Clone, Copy, Debug, Display, PartialEq, EnumIter, EnumString, Serialize, Deserialize)]
52pub enum ImageModel {
53 #[strum(serialize = "grok-2-image")]
54 Grok2Image,
55}
56
57impl ImageModel {
58 pub fn err_ivalid_model(model: String) -> String {
59 format!("Invalid image model '{model}'")
60 }
61}
62
63#[derive(Clone, Copy, Debug, Display, PartialEq, EnumIter, EnumString, Serialize, Deserialize)]
64#[strum(serialize_all = "snake_case")]
65pub enum Role {
66 Assistant,
67 System,
68 Tool,
69
70 User,
71}
72
73pub mod url {
77 pub const HOST: &str = "https://api.x.ai/v1";
78 pub const MANAGEMENT_HOST: &str = "https://management-api.x.ai";
79
80 pub mod api {
81 use super::HOST;
82 use const_format::formatcp;
83
84 pub const GET_KEY: &str = formatcp!("{HOST}/api-key");
85 pub const GET_MODELS: &str = formatcp!("{HOST}/models");
86 pub const GET_LANGUAGE_MODELS: &str = formatcp!("{HOST}/language-models");
87 pub const GET_IMAGE_MODELS: &str = formatcp!("{HOST}/image-generation-models");
88
89 pub const POST_TOKENIZE_TEXT: &str = formatcp!("{HOST}/tokenize-text");
90
91 pub fn get_model(id: String) -> String {
92 format!("{GET_MODELS}/{id}")
93 }
94
95 pub fn get_language_model(id: String) -> String {
96 format!("{GET_LANGUAGE_MODELS}/{id}")
97 }
98
99 pub fn get_image_model(id: String) -> String {
100 format!("{GET_IMAGE_MODELS}/{id}")
101 }
102 }
103
104 pub mod chat {
105 use super::HOST;
106 use const_format::formatcp;
107
108 pub const POST_COMPLETION: &str = formatcp!("{HOST}/chat/completions");
109 pub const GET_DEFERED_COMPLETION: &str = formatcp!("{HOST}/chat/deferred-completion");
110
111 pub fn get_deferred_completion(request_id: String) -> String {
112 format!("{GET_DEFERED_COMPLETION}/{request_id}")
113 }
114 }
115
116 pub mod image {
117 use super::HOST;
118 use const_format::formatcp;
119
120 pub const POST_GENERATE: &str = formatcp!("{HOST}/images/generations");
121 }
122}
123
124#[derive(Debug, Clone)]
128pub struct GrokClient {
129 client: reqwest::Client,
130 api_key: String,
131}
132
133impl GrokClient {
134 pub fn new(api_key: String) -> Self {
136 Self {
137 client: reqwest::Client::new(),
138 api_key,
139 }
140 }
141
142 pub fn with_client(client: reqwest::Client, api_key: String) -> Self {
144 Self { client, api_key }
145 }
146
147 pub fn api_key(&self) -> &str {
149 &self.api_key
150 }
151
152 pub fn client(&self) -> &reqwest::Client {
154 &self.client
155 }
156
157 pub async fn get_api_key(&self) -> Result<ApiKey> {
163 let res = self
164 .client
165 .get(url::api::GET_KEY)
166 .header("Authorization", format!("Bearer {}", self.api_key))
167 .send()
168 .await?;
169
170 Ok(res.json().await?)
171 }
172
173 pub async fn get_model(&self, id: LanguageModel) -> Result<Model> {
175 let res = self
176 .client
177 .get(url::api::get_model(id.to_string()))
178 .header("Authorization", format!("Bearer {}", self.api_key))
179 .send()
180 .await?;
181
182 Ok(res.json().await?)
183 }
184
185 pub async fn get_language_models(&self) -> Result<Vec<ApiLanguageModel>> {
187 let res = self
188 .client
189 .get(url::api::GET_LANGUAGE_MODELS)
190 .header("Authorization", format!("Bearer {}", self.api_key))
191 .send()
192 .await?;
193
194 let res: crate::types::api::LanguageModels = res.json().await?;
195 Ok(res.models)
196 }
197
198 pub async fn get_language_model(&self, id: LanguageModel) -> Result<ApiLanguageModel> {
200 let res = self
201 .client
202 .get(url::api::get_language_model(id.to_string()))
203 .header("Authorization", format!("Bearer {}", self.api_key))
204 .send()
205 .await?;
206
207 Ok(res.json().await?)
208 }
209
210 pub async fn get_image_models(&self) -> Result<Vec<ApiImageModel>> {
212 let res = self
213 .client
214 .get(url::api::GET_IMAGE_MODELS)
215 .header("Authorization", format!("Bearer {}", self.api_key))
216 .send()
217 .await?;
218
219 let res: crate::types::api::ImageModels = res.json().await?;
220 Ok(res.models)
221 }
222
223 pub async fn get_image_model(&self, id: ImageModel) -> Result<ApiImageModel> {
225 let res = self
226 .client
227 .get(url::api::get_image_model(id.to_string()))
228 .header("Authorization", format!("Bearer {}", self.api_key))
229 .send()
230 .await?;
231
232 Ok(res.json().await?)
233 }
234
235 pub async fn tokenize_text(
237 &self,
238 model: LanguageModel,
239 text: String,
240 ) -> Result<TokenizeResponse> {
241 let body = crate::types::api::TokenizeRequest::init(model, text);
242 let res = self
243 .client
244 .post(url::api::POST_TOKENIZE_TEXT)
245 .header("Authorization", format!("Bearer {}", self.api_key))
246 .json(&body)
247 .send()
248 .await?;
249
250 Ok(res.json().await?)
251 }
252
253 pub async fn chat_complete(
259 &self,
260 request: &ChatCompletionRequest,
261 ) -> Result<ChatCompletionResponse> {
262 let mut complete_req = request.clone();
263 complete_req.stream = Some(false);
264 complete_req.deferred = Some(false);
265
266 let res = self
267 .client
268 .post(url::chat::POST_COMPLETION)
269 .header("Authorization", format!("Bearer {}", self.api_key))
270 .header("Content-Type", "application/json")
271 .json(&complete_req)
272 .send()
273 .await?;
274
275 Ok(res.json().await?)
276 }
277
278 pub async fn chat_stream<F1, F2>(
280 &self,
281 request: &ChatCompletionRequest,
282 on_content_token: F1,
283 on_reason_token: Option<F2>,
284 ) -> Result<ChatCompletionResponse>
285 where
286 F1: Fn(&str),
287 F2: Fn(&str),
288 {
289 let mut complete_req = request.clone();
290 complete_req.stream = Some(true);
291 complete_req.deferred = Some(false);
292
293 let req_builder = self
294 .client
295 .post(url::chat::POST_COMPLETION)
296 .header("Authorization", format!("Bearer {}", self.api_key))
297 .header("Content-Type", "application/json")
298 .json(&complete_req);
299
300 let mut stream = reqwest_eventsource::EventSource::new(req_builder)?;
301
302 let mut buf_reasoning_content = String::new();
303 let mut buf_content = String::new();
304 let mut complete_res = ChatCompletionResponse::new(0);
305 let mut init = true;
306 let mut role: Option<String> = None;
307
308 while let Some(event) = stream.next().await {
309 match event {
310 Ok(reqwest_eventsource::Event::Open) => {}
311 Ok(reqwest_eventsource::Event::Message(message)) => {
312 if message.data == "[DONE]" {
313 stream.close();
314 break;
315 }
316
317 let chunk: stream::ChatCompletionChunk = serde_json::from_str(&message.data)
318 .map_err(|e| error::Error::SerdeJson(e))?;
319
320 if init {
321 init = false;
322 complete_res.id = chunk.id;
323 complete_res.object = "chat.response".to_string();
324 complete_res.created = chunk.created;
325 complete_res.model = chunk.model;
326 complete_res.system_fingerprint = Some(chunk.system_fingerprint);
327 }
328
329 if let Some(choice) = chunk.choices.last()
330 && role.is_none()
331 {
332 if let Some(r) = &choice.delta.role {
333 role = Some(r.clone());
334 }
335 }
336
337 if chunk.usage.is_some() {
338 complete_res.usage = chunk.usage;
339 }
340
341 if chunk.citations.is_some() {
342 complete_res.citations = chunk.citations;
343 }
344
345 if let Some(choice) = chunk.choices.get(0) {
346 if let (Some(cb_reason_token), Some(reason_token)) =
347 (&on_reason_token, &choice.delta.reasoning_content)
348 {
349 cb_reason_token(&reason_token);
350 buf_reasoning_content.push_str(reason_token);
351 }
352
353 if let Some(content_token) = &choice.delta.content {
354 on_content_token(&content_token);
355 buf_content.push_str(content_token);
356 }
357 }
358 }
359 Err(err) => {
360 stream.close();
361 return Err(error::Error::EventSource(err));
362 }
363 }
364 }
365
366 complete_res.choices.push(Choice {
367 index: 0,
368 message: Message {
369 role: role.unwrap_or("unknown".to_string()),
370 content: buf_content,
371 reasoning_content: Some(buf_reasoning_content),
372 refusal: None,
373 tool_calls: None,
374 tool_call_id: None,
375 },
376 finish_reason: "stop".to_string(),
377 });
378
379 Ok(complete_res)
380 }
381
382 pub async fn chat_defer(
384 &self,
385 request: &ChatCompletionRequest,
386 ) -> Result<DeferredChatCompletionResponse> {
387 let mut complete_req = request.clone();
388 complete_req.stream = Some(false);
389 complete_req.deferred = Some(true);
390
391 let res = self
392 .client
393 .post(url::chat::POST_COMPLETION)
394 .header("Authorization", format!("Bearer {}", self.api_key))
395 .header("Content-Type", "application/json")
396 .json(&complete_req)
397 .send()
398 .await?;
399
400 Ok(res.json().await?)
401 }
402
403 pub async fn get_deferred_completion(
405 &self,
406 request_id: String,
407 ) -> Result<ChatCompletionResponse> {
408 let res = self
409 .client
410 .get(url::chat::get_deferred_completion(request_id))
411 .header("Authorization", format!("Bearer {}", self.api_key))
412 .send()
413 .await?;
414
415 Ok(res.json().await?)
416 }
417
418 pub async fn generate_image(&self, request: &ImageRequest) -> Result<ImageResponse> {
424 let res = self
425 .client
426 .post(url::image::POST_GENERATE)
427 .header("Authorization", format!("Bearer {}", self.api_key))
428 .json(request)
429 .send()
430 .await?;
431
432 Ok(res.json().await?)
433 }
434}