use serde::{Deserialize, Serialize};
use strum::{Display, EnumIter, EnumString};
use crate::error::{self, Result};
use crate::types::api::{
ApiKey, ImageModel as ApiImageModel, LanguageModel as ApiLanguageModel, Model, TokenizeResponse,
};
use crate::types::chat::{
ChatCompletionRequest, ChatCompletionResponse, Choice, DeferredChatCompletionResponse, Message,
stream,
};
use crate::types::image::{ImageRequest, ImageResponse};
use futures::StreamExt;
/// Language models known to this client.
///
/// `Display` and `FromStr` (derived via strum) use the API model id given in each
/// `#[strum(serialize = ...)]`, e.g. `LanguageModel::Grok4.to_string() == "grok-4"`.
///
/// NOTE(review): there are no `#[serde(rename = ...)]` attributes, so the serde derives
/// (de)serialize the Rust variant names (e.g. `"Grok4"`), not the strum ids — confirm that
/// no request body serializes this enum directly.
#[derive(
    Clone, Copy, Debug, Display, PartialEq, Eq, Hash, EnumIter, EnumString, Serialize, Deserialize,
)]
pub enum LanguageModel {
    #[strum(serialize = "grok-4")]
    Grok4,
    #[strum(serialize = "grok-code-fast")]
    GrokCode,
    #[strum(serialize = "grok-3")]
    Grok3,
    #[strum(serialize = "grok-3-fast")]
    Grok3Fast,
    #[strum(serialize = "grok-3-mini")]
    Grok3Mini,
    #[strum(serialize = "grok-3-mini-fast")]
    Grok3MiniFast,
    #[strum(serialize = "grok-2")]
    Grok2,
    #[strum(serialize = "grok-2-vision")]
    Grok2Vision,
}
impl LanguageModel {
    /// Builds the error message used when `model` is not a recognised language model.
    pub fn err_invalid_model(model: impl std::fmt::Display) -> String {
        format!("Invalid language model '{model}'")
    }

    /// Misspelled name of [`Self::err_invalid_model`], kept so existing callers still compile.
    pub fn err_ivalid_model(model: String) -> String {
        Self::err_invalid_model(model)
    }
}
/// Image-generation models known to this client.
///
/// `Display` and `FromStr` (derived via strum) use the API model id, e.g. `"grok-2-image"`.
/// NOTE(review): the serde derives have no rename attributes, so they use the variant name
/// (`"Grok2Image"`) instead — confirm this is intended.
#[derive(
    Clone, Copy, Debug, Display, PartialEq, Eq, Hash, EnumIter, EnumString, Serialize, Deserialize,
)]
pub enum ImageModel {
    #[strum(serialize = "grok-2-image")]
    Grok2Image,
}
impl ImageModel {
    /// Builds the error message used when `model` is not a recognised image model.
    pub fn err_invalid_model(model: impl std::fmt::Display) -> String {
        format!("Invalid image model '{model}'")
    }

    /// Misspelled name of [`Self::err_invalid_model`], kept so existing callers still compile.
    pub fn err_ivalid_model(model: String) -> String {
        Self::err_invalid_model(model)
    }
}
/// Author role of a chat message.
///
/// strum's `serialize_all = "snake_case"` makes `Display`/`FromStr` use lowercase names
/// (`"assistant"`, `"system"`, `"tool"`, `"user"`). The serde derives carry no rename
/// attribute and therefore use the variant names as written.
#[derive(
    Clone, Copy, Debug, Display, PartialEq, Eq, Hash, EnumIter, EnumString, Serialize, Deserialize,
)]
#[strum(serialize_all = "snake_case")]
pub enum Role {
    Assistant,
    System,
    Tool,
    User,
}
/// Endpoint URLs used by [`GrokClient`].
pub mod url {
    /// Base URL of the `/v1` API that every endpoint below is built on.
    pub const HOST: &str = "https://api.x.ai/v1";
    /// Base URL of the management API.
    pub const MANAGEMENT_HOST: &str = "https://management-api.x.ai";

    /// API-key, model-catalog and tokenizer endpoints.
    pub mod api {
        use super::HOST;
        use const_format::formatcp;

        pub const GET_KEY: &str = formatcp!("{HOST}/api-key");
        pub const GET_MODELS: &str = formatcp!("{HOST}/models");
        pub const GET_LANGUAGE_MODELS: &str = formatcp!("{HOST}/language-models");
        pub const GET_IMAGE_MODELS: &str = formatcp!("{HOST}/image-generation-models");
        pub const POST_TOKENIZE_TEXT: &str = formatcp!("{HOST}/tokenize-text");

        /// URL of a single model. `id` may be a `String`, `&str`, or anything `Display`
        /// (such as a model enum, whose `Display` yields the API id).
        pub fn get_model(id: impl std::fmt::Display) -> String {
            format!("{GET_MODELS}/{id}")
        }

        /// URL of a single language model; `id` is formatted with `Display`.
        pub fn get_language_model(id: impl std::fmt::Display) -> String {
            format!("{GET_LANGUAGE_MODELS}/{id}")
        }

        /// URL of a single image-generation model; `id` is formatted with `Display`.
        pub fn get_image_model(id: impl std::fmt::Display) -> String {
            format!("{GET_IMAGE_MODELS}/{id}")
        }
    }

    /// Chat-completion endpoints.
    pub mod chat {
        use super::HOST;
        use const_format::formatcp;

        pub const POST_COMPLETION: &str = formatcp!("{HOST}/chat/completions");
        pub const GET_DEFERRED_COMPLETION: &str = formatcp!("{HOST}/chat/deferred-completion");
        /// Misspelled name of [`GET_DEFERRED_COMPLETION`], kept for existing callers.
        pub const GET_DEFERED_COMPLETION: &str = GET_DEFERRED_COMPLETION;

        /// URL used to poll the result of a deferred completion identified by `request_id`.
        pub fn get_deferred_completion(request_id: impl std::fmt::Display) -> String {
            format!("{GET_DEFERRED_COMPLETION}/{request_id}")
        }
    }

    /// Image-generation endpoints.
    pub mod image {
        use super::HOST;
        use const_format::formatcp;

        pub const POST_GENERATE: &str = formatcp!("{HOST}/images/generations");
    }
}
/// Async client for the xAI (Grok) REST API.
///
/// Holds the HTTP client and the API key sent as a bearer token on every request.
#[derive(Clone)]
pub struct GrokClient {
    client: reqwest::Client,
    api_key: String,
}

// `Debug` is written by hand instead of derived so that `{:?}` (e.g. in logs or panic
// messages) never prints the secret API key.
impl std::fmt::Debug for GrokClient {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GrokClient")
            .field("client", &self.client)
            .field("api_key", &"<redacted>")
            .finish()
    }
}
impl GrokClient {
pub fn new(api_key: String) -> Self {
Self {
client: reqwest::Client::new(),
api_key,
}
}
pub fn with_client(client: reqwest::Client, api_key: String) -> Self {
Self { client, api_key }
}
pub fn api_key(&self) -> &str {
&self.api_key
}
pub fn client(&self) -> &reqwest::Client {
&self.client
}
pub async fn get_api_key(&self) -> Result<ApiKey> {
let res = self
.client
.get(url::api::GET_KEY)
.header("Authorization", format!("Bearer {}", self.api_key))
.send()
.await?;
Ok(res.json().await?)
}
pub async fn get_model(&self, id: LanguageModel) -> Result<Model> {
let res = self
.client
.get(url::api::get_model(id.to_string()))
.header("Authorization", format!("Bearer {}", self.api_key))
.send()
.await?;
Ok(res.json().await?)
}
pub async fn get_language_models(&self) -> Result<Vec<ApiLanguageModel>> {
let res = self
.client
.get(url::api::GET_LANGUAGE_MODELS)
.header("Authorization", format!("Bearer {}", self.api_key))
.send()
.await?;
let res: crate::types::api::LanguageModels = res.json().await?;
Ok(res.models)
}
pub async fn get_language_model(&self, id: LanguageModel) -> Result<ApiLanguageModel> {
let res = self
.client
.get(url::api::get_language_model(id.to_string()))
.header("Authorization", format!("Bearer {}", self.api_key))
.send()
.await?;
Ok(res.json().await?)
}
pub async fn get_image_models(&self) -> Result<Vec<ApiImageModel>> {
let res = self
.client
.get(url::api::GET_IMAGE_MODELS)
.header("Authorization", format!("Bearer {}", self.api_key))
.send()
.await?;
let res: crate::types::api::ImageModels = res.json().await?;
Ok(res.models)
}
pub async fn get_image_model(&self, id: ImageModel) -> Result<ApiImageModel> {
let res = self
.client
.get(url::api::get_image_model(id.to_string()))
.header("Authorization", format!("Bearer {}", self.api_key))
.send()
.await?;
Ok(res.json().await?)
}
pub async fn tokenize_text(
&self,
model: LanguageModel,
text: String,
) -> Result<TokenizeResponse> {
let body = crate::types::api::TokenizeRequest::init(model, text);
let res = self
.client
.post(url::api::POST_TOKENIZE_TEXT)
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&body)
.send()
.await?;
Ok(res.json().await?)
}
pub async fn chat_complete(
&self,
request: &ChatCompletionRequest,
) -> Result<ChatCompletionResponse> {
let mut complete_req = request.clone();
complete_req.stream = Some(false);
complete_req.deferred = Some(false);
let res = self
.client
.post(url::chat::POST_COMPLETION)
.header("Authorization", format!("Bearer {}", self.api_key))
.header("Content-Type", "application/json")
.json(&complete_req)
.send()
.await?;
Ok(res.json().await?)
}
pub async fn chat_stream<F>(
&self,
request: &ChatCompletionRequest,
on_content_token: F,
on_reason_token: Option<F>,
) -> Result<ChatCompletionResponse>
where
F: Fn(&str), {
let mut complete_req = request.clone();
complete_req.stream = Some(true);
complete_req.deferred = Some(false);
let req_builder = self
.client
.post(url::chat::POST_COMPLETION)
.header("Authorization", format!("Bearer {}", self.api_key))
.header("Content-Type", "application/json")
.json(&complete_req);
let mut stream = reqwest_eventsource::EventSource::new(req_builder)?;
let mut buf_reasoning_content = String::new();
let mut buf_content = String::new();
let mut complete_res = ChatCompletionResponse::new(0);
let mut init = true;
let mut role: Option<String> = None;
while let Some(event) = stream.next().await {
match event {
Ok(reqwest_eventsource::Event::Open) => {}
Ok(reqwest_eventsource::Event::Message(message)) => {
if message.data == "[DONE]" {
stream.close();
break;
}
let chunk: stream::ChatCompletionChunk = serde_json::from_str(&message.data)
.map_err(|e| error::Error::SerdeJson(e))?;
if init {
init = false;
complete_res.id = chunk.id;
complete_res.object = "chat.response".to_string();
complete_res.created = chunk.created;
complete_res.model = chunk.model;
complete_res.system_fingerprint = Some(chunk.system_fingerprint);
}
if let Some(choice) = chunk.choices.last()
&& role.is_none()
{
if let Some(r) = &choice.delta.role {
role = Some(r.clone());
}
}
if chunk.usage.is_some() {
complete_res.usage = chunk.usage;
}
if chunk.citations.is_some() {
complete_res.citations = chunk.citations;
}
if let Some(choice) = chunk.choices.get(0) {
if let (Some(cb_reason_token), Some(reason_token)) =
(&on_reason_token, &choice.delta.reasoning_content)
{
cb_reason_token(&reason_token);
buf_reasoning_content.push_str(reason_token);
}
if let Some(content_token) = &choice.delta.content {
on_content_token(&content_token);
buf_content.push_str(content_token);
}
}
}
Err(err) => {
stream.close();
return Err(error::Error::EventSource(err));
}
}
}
complete_res.choices.push(Choice {
index: 0,
message: Message {
role: role.unwrap_or("unknown".to_string()),
content: buf_content,
reasoning_content: Some(buf_reasoning_content),
refusal: None,
tool_calls: None,
tool_call_id: None,
},
finish_reason: "stop".to_string(),
});
Ok(complete_res)
}
pub async fn chat_defer(
&self,
request: &ChatCompletionRequest,
) -> Result<DeferredChatCompletionResponse> {
let mut complete_req = request.clone();
complete_req.stream = Some(false);
complete_req.deferred = Some(true);
let res = self
.client
.post(url::chat::POST_COMPLETION)
.header("Authorization", format!("Bearer {}", self.api_key))
.header("Content-Type", "application/json")
.json(&complete_req)
.send()
.await?;
Ok(res.json().await?)
}
pub async fn get_deferred_completion(
&self,
request_id: String,
) -> Result<ChatCompletionResponse> {
let res = self
.client
.get(url::chat::get_deferred_completion(request_id))
.header("Authorization", format!("Bearer {}", self.api_key))
.send()
.await?;
Ok(res.json().await?)
}
pub async fn generate_image(&self, request: &ImageRequest) -> Result<ImageResponse> {
let res = self
.client
.post(url::image::POST_GENERATE)
.header("Authorization", format!("Bearer {}", self.api_key))
.json(request)
.send()
.await?;
Ok(res.json().await?)
}
}