pub mod error;
pub mod image;
pub mod responses;
pub mod tools;
pub mod types;
use crate::core::types::{GenerateOptions, GenerateResult, Prompt, ProviderSettings, StreamPart};
use crate::openai::OpenAIModel;
use async_trait::async_trait;
use futures::stream::BoxStream;
use reqwest::Client;
/// Language model backed by xAI.
///
/// Every request is delegated to the wrapped [`OpenAIModel`], which is
/// configured to point at xAI's endpoint (`https://api.x.ai/v1` when built via
/// [`XAIModel::new`]).
pub struct XAIModel {
    /// OpenAI-client model that `generate` / `generate_stream` delegate to.
    pub inner: OpenAIModel,
}

impl XAIModel {
    /// Creates a model that sends requests to `https://api.x.ai/v1` using `api_key`.
    ///
    /// A new `reqwest::Client` is constructed for every model instance.
    #[must_use]
    pub fn new(api_key: String) -> Self {
        let base_url = String::from("https://api.x.ai/v1");
        let client = Client::new();
        let inner = OpenAIModel {
            api_key,
            base_url,
            client,
        };
        Self { inner }
    }
}
/// Forwards both generation entry points to the wrapped [`OpenAIModel`].
///
/// Both methods carry identical `tracing` instrumentation (prompt skipped,
/// model id recorded as `model`), so streaming and non-streaming calls are
/// equally observable.
#[async_trait]
impl crate::core::LanguageModel for XAIModel {
    /// Runs a single, non-streaming generation through the inner model.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the inner model's `generate`.
    #[tracing::instrument(skip(self, prompt), fields(model = options.model_id))]
    async fn generate(
        &self,
        prompt: Prompt,
        options: GenerateOptions,
    ) -> crate::core::Result<GenerateResult> {
        self.inner.generate(prompt, options).await
    }

    /// Starts a streaming generation through the inner model.
    ///
    /// The span covers only the call that produces the stream. It closes once
    /// the `BoxStream` is returned, so later polling of the stream happens
    /// outside it.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the inner model's `generate_stream`.
    #[tracing::instrument(skip(self, prompt), fields(model = options.model_id))]
    async fn generate_stream(
        &self,
        prompt: Prompt,
        options: GenerateOptions,
    ) -> crate::core::Result<BoxStream<'static, StreamPart>> {
        self.inner.generate_stream(prompt, options).await
    }
}
/// Factory for xAI models that share one set of [`ProviderSettings`].
///
/// The API key and base URL are resolved from the settings each time a model
/// is built. Values missing from the settings fall back to defaults (see the
/// `resolve_*` helpers).
pub struct XAIProvider {
    settings: ProviderSettings,
}

impl XAIProvider {
    /// Endpoint used when `settings.base_url` is not set.
    const DEFAULT_BASE_URL: &'static str = "https://api.x.ai/v1";

    /// Environment variable consulted when `settings.api_key` is not set.
    const API_KEY_ENV_VAR: &'static str = "XAI_API_KEY";

    /// Resolves the API key from three sources, in order.
    ///
    /// The order is: the explicit setting, then the `XAI_API_KEY` environment
    /// variable, then an empty string. No error is raised here when no key is
    /// found.
    fn resolve_api_key(&self) -> String {
        self.settings
            .api_key
            .clone()
            .or_else(|| std::env::var(Self::API_KEY_ENV_VAR).ok())
            .unwrap_or_default()
    }

    /// Resolves the base URL: explicit setting, else [`Self::DEFAULT_BASE_URL`].
    fn resolve_base_url(&self) -> String {
        self.settings
            .base_url
            .clone()
            .unwrap_or_else(|| Self::DEFAULT_BASE_URL.to_string())
    }

    /// Builds a chat model using the resolved API key and base URL.
    ///
    /// NOTE(review): `_model_id` is not stored anywhere. The per-request model
    /// appears to come from `GenerateOptions::model_id` instead; confirm with
    /// callers.
    #[must_use]
    pub fn chat(&self, _model_id: &str) -> XAIModel {
        XAIModel {
            inner: OpenAIModel {
                api_key: self.resolve_api_key(),
                base_url: self.resolve_base_url(),
                // A fresh HTTP client per model, as before.
                client: Client::new(),
            },
        }
    }

    /// Alias for [`XAIProvider::chat`].
    #[must_use]
    pub fn language_model(&self, model_id: &str) -> XAIModel {
        self.chat(model_id)
    }

    /// Builds an image model using the resolved API key and base URL.
    ///
    /// NOTE(review): `_model_id` is ignored here as well.
    #[must_use]
    pub fn image(&self, _model_id: &str) -> image::XaiImageModel {
        image::XaiImageModel {
            api_key: self.resolve_api_key(),
            base_url: self.resolve_base_url(),
            client: Client::new(),
        }
    }

    /// Builds a responses-API model using the resolved API key.
    ///
    /// NOTE(review): unlike `chat`/`image`, `settings.base_url` is not applied
    /// here, because `XaiResponsesModel::new` only accepts the API key. Confirm
    /// this is intended.
    #[must_use]
    pub fn responses(&self, _model_id: &str) -> responses::XaiResponsesModel {
        responses::XaiResponsesModel::new(self.resolve_api_key())
    }
}
/// Creates an [`XAIProvider`] that builds xAI models from `settings`.
///
/// No credentials or URLs are resolved here. Missing `api_key` / `base_url`
/// values are filled in by the provider's model constructors when each model
/// is built.
#[must_use]
pub fn create_xai(settings: ProviderSettings) -> XAIProvider {
    XAIProvider { settings }
}