grok_client/
ai.rs

1// ####################
2// AI MODULE
3// ####################
4use serde::{Deserialize, Serialize};
5use strum::{Display, EnumIter, EnumString};
6
7use crate::error::{self, Result};
8use crate::types::api::{
9    ApiKey, ImageModel as ApiImageModel, LanguageModel as ApiLanguageModel, Model, TokenizeResponse,
10};
11use crate::types::chat::{
12    ChatCompletionRequest, ChatCompletionResponse, Choice, DeferredChatCompletionResponse, Message,
13    stream,
14};
15use crate::types::image::{ImageRequest, ImageResponse};
16use futures::StreamExt;
17
18#[derive(Clone, Copy, Debug, Display, PartialEq, EnumIter, EnumString, Serialize, Deserialize)]
19pub enum LanguageModel {
20    #[strum(serialize = "grok-4")]
21    Grok4,
22
23    #[strum(serialize = "grok-code-fast")]
24    GrokCode,
25
26    #[strum(serialize = "grok-3")]
27    Grok3,
28    #[strum(serialize = "grok-3-fast")]
29    Grok3Fast,
30
31    #[strum(serialize = "grok-3-mini")]
32    Grok3Mini,
33    #[strum(serialize = "grok-3-mini-fast")]
34    Grok3MiniFast,
35
36    // Deprecated
37    #[strum(serialize = "grok-2")]
38    Grok2,
39
40    // Deprecated
41    #[strum(serialize = "grok-2-vision")]
42    Grok2Vision,
43}
44
45impl LanguageModel {
46    pub fn err_ivalid_model(model: String) -> String {
47        format!("Invalid language model '{model}'")
48    }
49}
50
51#[derive(Clone, Copy, Debug, Display, PartialEq, EnumIter, EnumString, Serialize, Deserialize)]
52pub enum ImageModel {
53    #[strum(serialize = "grok-2-image")]
54    Grok2Image,
55}
56
57impl ImageModel {
58    pub fn err_ivalid_model(model: String) -> String {
59        format!("Invalid image model '{model}'")
60    }
61}
62
63#[derive(Clone, Copy, Debug, Display, PartialEq, EnumIter, EnumString, Serialize, Deserialize)]
64#[strum(serialize_all = "snake_case")]
65pub enum Role {
66    Assistant,
67    System,
68    Tool,
69
70    User,
71}
72
73// ####################
74// AI API URLs
75// ####################
76pub mod url {
77    pub const HOST: &str = "https://api.x.ai/v1";
78    pub const MANAGEMENT_HOST: &str = "https://management-api.x.ai";
79
80    pub mod api {
81        use super::HOST;
82        use const_format::formatcp;
83
84        pub const GET_KEY: &str = formatcp!("{HOST}/api-key");
85        pub const GET_MODELS: &str = formatcp!("{HOST}/models");
86        pub const GET_LANGUAGE_MODELS: &str = formatcp!("{HOST}/language-models");
87        pub const GET_IMAGE_MODELS: &str = formatcp!("{HOST}/image-generation-models");
88
89        pub const POST_TOKENIZE_TEXT: &str = formatcp!("{HOST}/tokenize-text");
90
91        pub fn get_model(id: String) -> String {
92            format!("{GET_MODELS}/{id}")
93        }
94
95        pub fn get_language_model(id: String) -> String {
96            format!("{GET_LANGUAGE_MODELS}/{id}")
97        }
98
99        pub fn get_image_model(id: String) -> String {
100            format!("{GET_IMAGE_MODELS}/{id}")
101        }
102    }
103
104    pub mod chat {
105        use super::HOST;
106        use const_format::formatcp;
107
108        pub const POST_COMPLETION: &str = formatcp!("{HOST}/chat/completions");
109        pub const GET_DEFERED_COMPLETION: &str = formatcp!("{HOST}/chat/deferred-completion");
110
111        pub fn get_deferred_completion(request_id: String) -> String {
112            format!("{GET_DEFERED_COMPLETION}/{request_id}")
113        }
114    }
115
116    pub mod image {
117        use super::HOST;
118        use const_format::formatcp;
119
120        pub const POST_GENERATE: &str = formatcp!("{HOST}/images/generations");
121    }
122}
123
124// ####################
125// GROK CLIENT
126// ####################
127#[derive(Debug, Clone)]
128pub struct GrokClient {
129    client: reqwest::Client,
130    api_key: String,
131}
132
133impl GrokClient {
134    /// Create a new GrokClient with the provided API key
135    pub fn new(api_key: String) -> Self {
136        Self {
137            client: reqwest::Client::new(),
138            api_key,
139        }
140    }
141
142    /// Create a new GrokClient with a custom HTTP client and API key
143    pub fn with_client(client: reqwest::Client, api_key: String) -> Self {
144        Self { client, api_key }
145    }
146
147    /// Get the API key (for debugging or logging purposes)
148    pub fn api_key(&self) -> &str {
149        &self.api_key
150    }
151
152    /// Get a reference to the underlying HTTP client
153    pub fn client(&self) -> &reqwest::Client {
154        &self.client
155    }
156
157    // ####################
158    // API MANAGEMENT METHODS
159    // ####################
160
161    /// Get API key information
162    pub async fn get_api_key(&self) -> Result<ApiKey> {
163        let res = self
164            .client
165            .get(url::api::GET_KEY)
166            .header("Authorization", format!("Bearer {}", self.api_key))
167            .send()
168            .await?;
169
170        Ok(res.json().await?)
171    }
172
173    /// Get a specific model by ID
174    pub async fn get_model(&self, id: LanguageModel) -> Result<Model> {
175        let res = self
176            .client
177            .get(url::api::get_model(id.to_string()))
178            .header("Authorization", format!("Bearer {}", self.api_key))
179            .send()
180            .await?;
181
182        Ok(res.json().await?)
183    }
184
185    /// Get all available language models
186    pub async fn get_language_models(&self) -> Result<Vec<ApiLanguageModel>> {
187        let res = self
188            .client
189            .get(url::api::GET_LANGUAGE_MODELS)
190            .header("Authorization", format!("Bearer {}", self.api_key))
191            .send()
192            .await?;
193
194        let res: crate::types::api::LanguageModels = res.json().await?;
195        Ok(res.models)
196    }
197
198    /// Get a specific language model by ID
199    pub async fn get_language_model(&self, id: LanguageModel) -> Result<ApiLanguageModel> {
200        let res = self
201            .client
202            .get(url::api::get_language_model(id.to_string()))
203            .header("Authorization", format!("Bearer {}", self.api_key))
204            .send()
205            .await?;
206
207        Ok(res.json().await?)
208    }
209
210    /// Get all available image models
211    pub async fn get_image_models(&self) -> Result<Vec<ApiImageModel>> {
212        let res = self
213            .client
214            .get(url::api::GET_IMAGE_MODELS)
215            .header("Authorization", format!("Bearer {}", self.api_key))
216            .send()
217            .await?;
218
219        let res: crate::types::api::ImageModels = res.json().await?;
220        Ok(res.models)
221    }
222
223    /// Get a specific image model by ID
224    pub async fn get_image_model(&self, id: ImageModel) -> Result<ApiImageModel> {
225        let res = self
226            .client
227            .get(url::api::get_image_model(id.to_string()))
228            .header("Authorization", format!("Bearer {}", self.api_key))
229            .send()
230            .await?;
231
232        Ok(res.json().await?)
233    }
234
235    /// Tokenize text using a specific model
236    pub async fn tokenize_text(
237        &self,
238        model: LanguageModel,
239        text: String,
240    ) -> Result<TokenizeResponse> {
241        let body = crate::types::api::TokenizeRequest::init(model, text);
242        let res = self
243            .client
244            .post(url::api::POST_TOKENIZE_TEXT)
245            .header("Authorization", format!("Bearer {}", self.api_key))
246            .json(&body)
247            .send()
248            .await?;
249
250        Ok(res.json().await?)
251    }
252
253    // ####################
254    // CHAT METHODS
255    // ####################
256
257    /// Send a chat completion request
258    pub async fn chat_complete(
259        &self,
260        request: &ChatCompletionRequest,
261    ) -> Result<ChatCompletionResponse> {
262        let mut complete_req = request.clone();
263        complete_req.stream = Some(false);
264        complete_req.deferred = Some(false);
265
266        let res = self
267            .client
268            .post(url::chat::POST_COMPLETION)
269            .header("Authorization", format!("Bearer {}", self.api_key))
270            .header("Content-Type", "application/json")
271            .json(&complete_req)
272            .send()
273            .await?;
274
275        Ok(res.json().await?)
276    }
277
278    /// Send a streaming chat completion request
279    pub async fn chat_stream<F1, F2>(
280        &self,
281        request: &ChatCompletionRequest,
282        on_content_token: F1,
283        on_reason_token: Option<F2>,
284    ) -> Result<ChatCompletionResponse>
285    where
286        F1: Fn(&str),
287        F2: Fn(&str),
288    {
289        let mut complete_req = request.clone();
290        complete_req.stream = Some(true);
291        complete_req.deferred = Some(false);
292
293        let req_builder = self
294            .client
295            .post(url::chat::POST_COMPLETION)
296            .header("Authorization", format!("Bearer {}", self.api_key))
297            .header("Content-Type", "application/json")
298            .json(&complete_req);
299
300        let mut stream = reqwest_eventsource::EventSource::new(req_builder)?;
301
302        let mut buf_reasoning_content = String::new();
303        let mut buf_content = String::new();
304        let mut complete_res = ChatCompletionResponse::new(0);
305        let mut init = true;
306        let mut role: Option<String> = None;
307
308        while let Some(event) = stream.next().await {
309            match event {
310                Ok(reqwest_eventsource::Event::Open) => {}
311                Ok(reqwest_eventsource::Event::Message(message)) => {
312                    if message.data == "[DONE]" {
313                        stream.close();
314                        break;
315                    }
316
317                    let chunk: stream::ChatCompletionChunk = serde_json::from_str(&message.data)
318                        .map_err(|e| error::Error::SerdeJson(e))?;
319
320                    if init {
321                        init = false;
322                        complete_res.id = chunk.id;
323                        complete_res.object = "chat.response".to_string();
324                        complete_res.created = chunk.created;
325                        complete_res.model = chunk.model;
326                        complete_res.system_fingerprint = Some(chunk.system_fingerprint);
327                    }
328
329                    if let Some(choice) = chunk.choices.last()
330                        && role.is_none()
331                    {
332                        if let Some(r) = &choice.delta.role {
333                            role = Some(r.clone());
334                        }
335                    }
336
337                    if chunk.usage.is_some() {
338                        complete_res.usage = chunk.usage;
339                    }
340
341                    if chunk.citations.is_some() {
342                        complete_res.citations = chunk.citations;
343                    }
344
345                    if let Some(choice) = chunk.choices.get(0) {
346                        if let (Some(cb_reason_token), Some(reason_token)) =
347                            (&on_reason_token, &choice.delta.reasoning_content)
348                        {
349                            cb_reason_token(&reason_token);
350                            buf_reasoning_content.push_str(reason_token);
351                        }
352
353                        if let Some(content_token) = &choice.delta.content {
354                            on_content_token(&content_token);
355                            buf_content.push_str(content_token);
356                        }
357                    }
358                }
359                Err(err) => {
360                    stream.close();
361                    return Err(error::Error::EventSource(err));
362                }
363            }
364        }
365
366        complete_res.choices.push(Choice {
367            index: 0,
368            message: Message {
369                role: role.unwrap_or("unknown".to_string()),
370                content: buf_content,
371                reasoning_content: Some(buf_reasoning_content),
372                refusal: None,
373                tool_calls: None,
374                tool_call_id: None,
375            },
376            finish_reason: "stop".to_string(),
377        });
378
379        Ok(complete_res)
380    }
381
382    /// Send a deferred chat completion request
383    pub async fn chat_defer(
384        &self,
385        request: &ChatCompletionRequest,
386    ) -> Result<DeferredChatCompletionResponse> {
387        let mut complete_req = request.clone();
388        complete_req.stream = Some(false);
389        complete_req.deferred = Some(true);
390
391        let res = self
392            .client
393            .post(url::chat::POST_COMPLETION)
394            .header("Authorization", format!("Bearer {}", self.api_key))
395            .header("Content-Type", "application/json")
396            .json(&complete_req)
397            .send()
398            .await?;
399
400        Ok(res.json().await?)
401    }
402
403    /// Get the result of a deferred chat completion
404    pub async fn get_deferred_completion(
405        &self,
406        request_id: String,
407    ) -> Result<ChatCompletionResponse> {
408        let res = self
409            .client
410            .get(url::chat::get_deferred_completion(request_id))
411            .header("Authorization", format!("Bearer {}", self.api_key))
412            .send()
413            .await?;
414
415        Ok(res.json().await?)
416    }
417
418    // ####################
419    // IMAGE METHODS
420    // ####################
421
422    /// Generate images using the specified request
423    pub async fn generate_image(&self, request: &ImageRequest) -> Result<ImageResponse> {
424        let res = self
425            .client
426            .post(url::image::POST_GENERATE)
427            .header("Authorization", format!("Bearer {}", self.api_key))
428            .json(request)
429            .send()
430            .await?;
431
432        Ok(res.json().await?)
433    }
434}