use std::collections::HashMap;
use std::sync::Arc;
use llmsdk_provider::ProviderError;
use llmsdk_provider_utils::http::HttpClient;
use crate::chat::XaiChatModel;
use crate::files::XaiFiles;
use crate::image::XaiImageModel;
use crate::responses::XaiResponsesLanguageModel;
use crate::video::XaiVideoModel;
use crate::{API_KEY_ENV_VAR, DEFAULT_BASE_URL};
/// Entry point for the xAI provider.
///
/// Construct one with [`Xai::builder`] or [`Xai::from_env`], then obtain model handles
/// (chat, responses, image, video) or the files API from it. Cloning is cheap: every
/// clone shares the same configuration through an `Arc`.
#[derive(Debug, Clone)]
pub struct Xai {
// Shared configuration (base URL, headers, HTTP client) handed to every model handle.
inner: Arc<Inner>,
}
/// Resolved provider configuration shared (via `Arc`) by [`Xai`] and every model handle
/// created from it.
pub(crate) struct Inner {
    /// Root URL for API requests: the builder's value, or `DEFAULT_BASE_URL` when unset.
    pub(crate) base_url: String,
    /// Request headers, including the `authorization: Bearer <api key>` entry inserted
    /// by the builder. A `None` value is stored as-is (presumably meaning "omit this
    /// header" — confirm in the request layer).
    pub(crate) headers: HashMap<String, Option<String>>,
    /// HTTP client used to send requests.
    pub(crate) http: HttpClient,
}

impl std::fmt::Debug for Inner {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Header values carry the bearer token, so only the header *names* are printed.
        // Otherwise `{:?}` on an `Xai` client would write the API key into logs.
        f.debug_struct("Inner")
            .field("base_url", &self.base_url)
            .field("header_names", &self.headers.keys())
            .field("http", &self.http)
            .finish()
    }
}
impl Xai {
#[must_use]
pub fn builder() -> XaiBuilder {
XaiBuilder::default()
}
pub fn from_env() -> Result<Self, ProviderError> {
Self::builder().build()
}
#[must_use]
pub fn chat(&self, model_id: impl Into<String>) -> XaiChatModel {
XaiChatModel::new(Arc::clone(&self.inner), model_id.into())
}
#[must_use]
pub fn language_model(&self, model_id: impl Into<String>) -> XaiChatModel {
self.chat(model_id)
}
#[must_use]
pub fn files(&self) -> XaiFiles {
XaiFiles::new(Arc::clone(&self.inner))
}
#[must_use]
pub fn image(&self, model_id: impl Into<String>) -> XaiImageModel {
XaiImageModel::new(Arc::clone(&self.inner), model_id.into())
}
#[must_use]
pub fn image_model(&self, model_id: impl Into<String>) -> XaiImageModel {
self.image(model_id)
}
#[must_use]
pub fn video(&self, model_id: impl Into<String>) -> XaiVideoModel {
XaiVideoModel::new(Arc::clone(&self.inner), model_id.into())
}
#[must_use]
pub fn video_model(&self, model_id: impl Into<String>) -> XaiVideoModel {
self.video(model_id)
}
#[must_use]
pub fn responses(&self, model_id: impl Into<String>) -> XaiResponsesLanguageModel {
XaiResponsesLanguageModel::new(Arc::clone(&self.inner), model_id.into())
}
}
/// Builder for [`Xai`] clients. Every option is optional; call [`XaiBuilder::build`] to finish.
#[derive(Default, Clone)]
pub struct XaiBuilder {
    // Explicit API key; when `None`, `build` leaves resolution to `load_api_key`.
    api_key: Option<String>,
    // Custom API root; when `None`, `build` uses `DEFAULT_BASE_URL`.
    base_url: Option<String>,
    // Extra headers merged into the client's header map by `build`.
    extra_headers: HashMap<String, Option<String>>,
    // Pre-built HTTP client; when `None`, `build` calls `HttpClient::new()`.
    http: Option<HttpClient>,
}

impl std::fmt::Debug for XaiBuilder {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Never print secrets: the API key is redacted, and header values (which may
        // hold credentials) are omitted. Only the header names are shown.
        f.debug_struct("XaiBuilder")
            .field("api_key", &self.api_key.as_ref().map(|_| "<redacted>"))
            .field("base_url", &self.base_url)
            .field("extra_header_names", &self.extra_headers.keys())
            .field("http", &self.http)
            .finish()
    }
}
impl XaiBuilder {
    /// Supplies the API key passed to `load_api_key` by [`XaiBuilder::build`].
    ///
    /// When this is never called, `build` still calls `load_api_key` with the
    /// `API_KEY_ENV_VAR` environment-variable name (presumably used as the fallback source).
    #[must_use]
    pub fn api_key(mut self, key: impl Into<String>) -> Self {
        self.api_key = Some(key.into());
        self
    }

    /// Overrides the API root URL (default: `DEFAULT_BASE_URL`).
    ///
    /// `build` strips any trailing `/` characters from this value.
    #[must_use]
    pub fn base_url(mut self, url: impl Into<String>) -> Self {
        self.base_url = Some(url.into());
        self
    }

    /// Adds or replaces an extra header entry.
    ///
    /// An `authorization` header in any letter case is discarded by `build` in favour of
    /// the `Bearer` header derived from the API key.
    #[must_use]
    pub fn header(mut self, name: impl Into<String>, value: Option<String>) -> Self {
        self.extra_headers.insert(name.into(), value);
        self
    }

    /// Uses `client` instead of creating one with `HttpClient::new()`.
    #[must_use]
    pub fn http_client(mut self, client: HttpClient) -> Self {
        self.http = Some(client);
        self
    }

    /// Resolves the configuration into an [`Xai`] client.
    ///
    /// # Errors
    ///
    /// Returns an error if `load_api_key` cannot resolve an API key, or if no HTTP
    /// client was supplied and `HttpClient::new()` fails.
    pub fn build(self) -> Result<Xai, ProviderError> {
        let api_key = llmsdk_provider_utils::api_key::load_api_key(
            &llmsdk_provider_utils::api_key::LoadApiKey {
                api_key: self.api_key.as_deref(),
                env_var: API_KEY_ENV_VAR,
                description: "xAI",
                parameter_name: Some("api_key"),
            },
        )?;

        // Keep custom roots in the same slash-free form as the default, so that appending
        // a `/path` (presumably how request URLs are formed) never yields `//path`.
        let base_url = match self.base_url {
            Some(url) => url.trim_end_matches('/').to_owned(),
            None => DEFAULT_BASE_URL.to_owned(),
        };

        // Header names are case-insensitive in HTTP, but map keys are not. Drop every
        // caller spelling of `authorization` so the map holds exactly one auth entry, and
        // the one derived from the resolved API key always wins.
        let mut headers = self.extra_headers;
        headers.retain(|name, _| !name.eq_ignore_ascii_case("authorization"));
        headers.insert("authorization".into(), Some(format!("Bearer {api_key}")));

        let http = match self.http {
            Some(client) => client,
            None => HttpClient::new()?,
        };

        Ok(Xai {
            inner: Arc::new(Inner {
                base_url,
                headers,
                http,
            }),
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// An explicit key yields the default base URL and a `Bearer` authorization header.
    #[test]
    fn builder_with_explicit_key_succeeds() {
        let client = Xai::builder().api_key("xai-test-key").build().expect("ok");
        assert_eq!(client.inner.base_url, DEFAULT_BASE_URL);

        let auth = client
            .inner
            .headers
            .get("authorization")
            .and_then(Option::as_deref)
            .unwrap();
        assert!(auth.starts_with("Bearer "));
    }

    /// A caller-supplied base URL is stored on the client.
    #[test]
    fn builder_custom_base_url() {
        const PROXY: &str = "https://proxy.example.com/v1";
        let client = Xai::builder()
            .api_key("k")
            .base_url(PROXY)
            .build()
            .expect("ok");
        assert_eq!(client.inner.base_url, PROXY);
    }
}