use async_trait::async_trait;
use serde::de::DeserializeOwned;
use crate::backend::usage::{GenerateResult, MaterializeResult};
use crate::backend::{LLMClient, MediaFile, ModelInfo};
use crate::error::{ApiErrorKind, RStructorError, Result};
use crate::model::Instructor;
#[cfg(feature = "anthropic")]
use crate::backend::anthropic::AnthropicClient;
#[cfg(feature = "gemini")]
use crate::backend::gemini::GeminiClient;
#[cfg(feature = "grok")]
use crate::backend::grok::GrokClient;
#[cfg(feature = "openai")]
use crate::backend::openai::OpenAIClient;
/// Identifies which LLM backend an [`AnyClient`] wraps or should be built for.
///
/// Each variant only exists when the Cargo feature of the same (lowercase) name
/// is enabled: `openai`, `anthropic`, `grok`, `gemini`.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, Hash)]
pub enum Provider {
    /// OpenAI backend (feature `openai`).
    #[cfg(feature = "openai")]
    OpenAI,
    /// Anthropic backend (feature `anthropic`).
    #[cfg(feature = "anthropic")]
    Anthropic,
    /// Grok backend (feature `grok`).
    #[cfg(feature = "grok")]
    Grok,
    /// Gemini backend (feature `gemini`).
    #[cfg(feature = "gemini")]
    Gemini,
}
/// A single client type that can hold any of the feature-enabled backend clients.
///
/// Use [`From`] conversions or `AnyClient::from_env_for` to construct one; every
/// `LLMClient` call on it is forwarded to the wrapped concrete client.
#[derive(Clone)]
pub enum AnyClient {
    /// Wraps an `OpenAIClient` (feature `openai`).
    #[cfg(feature = "openai")]
    OpenAI(OpenAIClient),
    /// Wraps an `AnthropicClient` (feature `anthropic`).
    #[cfg(feature = "anthropic")]
    Anthropic(AnthropicClient),
    /// Wraps a `GrokClient` (feature `grok`).
    #[cfg(feature = "grok")]
    Grok(GrokClient),
    /// Wraps a `GeminiClient` (feature `gemini`).
    #[cfg(feature = "gemini")]
    Gemini(GeminiClient),
}
impl AnyClient {
pub fn from_env_for(provider: Provider) -> Result<Self> {
match provider {
#[cfg(feature = "openai")]
Provider::OpenAI => Ok(Self::OpenAI(OpenAIClient::from_env()?)),
#[cfg(feature = "anthropic")]
Provider::Anthropic => Ok(Self::Anthropic(AnthropicClient::from_env()?)),
#[cfg(feature = "grok")]
Provider::Grok => Ok(Self::Grok(GrokClient::from_env()?)),
#[cfg(feature = "gemini")]
Provider::Gemini => Ok(Self::Gemini(GeminiClient::from_env()?)),
}
}
#[must_use]
pub fn provider(&self) -> Provider {
match self {
#[cfg(feature = "openai")]
Self::OpenAI(_) => Provider::OpenAI,
#[cfg(feature = "anthropic")]
Self::Anthropic(_) => Provider::Anthropic,
#[cfg(feature = "grok")]
Self::Grok(_) => Provider::Grok,
#[cfg(feature = "gemini")]
Self::Gemini(_) => Provider::Gemini,
}
}
}
#[cfg(feature = "openai")]
impl From<OpenAIClient> for AnyClient {
    /// Wraps a concrete OpenAI client so it can be used through [`AnyClient`].
    fn from(inner: OpenAIClient) -> Self {
        AnyClient::OpenAI(inner)
    }
}
#[cfg(feature = "anthropic")]
impl From<AnthropicClient> for AnyClient {
    /// Wraps a concrete Anthropic client so it can be used through [`AnyClient`].
    fn from(inner: AnthropicClient) -> Self {
        AnyClient::Anthropic(inner)
    }
}
#[cfg(feature = "grok")]
impl From<GrokClient> for AnyClient {
    /// Wraps a concrete Grok client so it can be used through [`AnyClient`].
    fn from(inner: GrokClient) -> Self {
        AnyClient::Grok(inner)
    }
}
#[cfg(feature = "gemini")]
impl From<GeminiClient> for AnyClient {
    /// Wraps a concrete Gemini client so it can be used through [`AnyClient`].
    fn from(inner: GeminiClient) -> Self {
        AnyClient::Gemini(inner)
    }
}
/// Forwards an expression to whichever concrete client an `AnyClient` value holds.
///
/// `dispatch!(target, binding => body)` matches `target` against every
/// feature-enabled variant, binds the wrapped client to `binding`, and evaluates
/// `body` in each arm. The patterns are spelled `Self::Variant`, so the macro must
/// be invoked inside an `impl` whose `Self` is `AnyClient`.
macro_rules! dispatch {
    ($target:expr, $inner:ident => $body:expr) => {
        match $target {
            #[cfg(feature = "openai")]
            Self::OpenAI($inner) => $body,
            #[cfg(feature = "anthropic")]
            Self::Anthropic($inner) => $body,
            #[cfg(feature = "grok")]
            Self::Grok($inner) => $body,
            #[cfg(feature = "gemini")]
            Self::Gemini($inner) => $body,
        }
    };
}
// Every `LLMClient` operation is forwarded to the wrapped backend client via `dispatch!`.
#[async_trait]
impl LLMClient for AnyClient {
    /// Forwards `materialize` to the wrapped client.
    async fn materialize<T>(&self, prompt: &str) -> Result<T>
    where
        T: Instructor + DeserializeOwned + Send + 'static,
    {
        dispatch!(self, c => c.materialize(prompt).await)
    }

    /// Forwards `materialize_with_media` (prompt plus attached media) to the wrapped client.
    async fn materialize_with_media<T>(&self, prompt: &str, media: &[MediaFile]) -> Result<T>
    where
        T: Instructor + DeserializeOwned + Send + 'static,
    {
        dispatch!(self, c => c.materialize_with_media(prompt, media).await)
    }

    /// Forwards `materialize_with_metadata` to the wrapped client.
    async fn materialize_with_metadata<T>(&self, prompt: &str) -> Result<MaterializeResult<T>>
    where
        T: Instructor + DeserializeOwned + Send + 'static,
    {
        dispatch!(self, c => c.materialize_with_metadata(prompt).await)
    }

    /// Forwards `generate` to the wrapped client.
    async fn generate(&self, prompt: &str) -> Result<String> {
        dispatch!(self, c => c.generate(prompt).await)
    }

    /// Forwards `generate_with_metadata` to the wrapped client.
    async fn generate_with_metadata(&self, prompt: &str) -> Result<GenerateResult> {
        dispatch!(self, c => c.generate_with_metadata(prompt).await)
    }

    /// Builds a client for the first provider whose API-key variable is configured.
    ///
    /// Providers are probed in a fixed priority order, skipping any whose Cargo
    /// feature is disabled: OpenAI (`OPENAI_API_KEY`), Anthropic (`ANTHROPIC_API_KEY`),
    /// Grok (`XAI_API_KEY`), Gemini (`GEMINI_API_KEY`). A variable counts as
    /// configured only if it is set to a non-blank value. The chosen provider is
    /// then constructed via `AnyClient::from_env_for`.
    ///
    /// # Errors
    ///
    /// Returns an `ApiErrorKind::AuthenticationFailed` API error when no enabled
    /// provider has a non-blank key. It also propagates any error from the
    /// selected provider's own `from_env`.
    fn from_env() -> Result<Self> {
        // An exported-but-empty variable (e.g. from a `.env` template or CI placeholder)
        // must not pre-empt a provider that is actually configured, so blank values
        // are treated the same as unset ones.
        // `allow(dead_code)`: this helper is unused when no provider feature is enabled.
        #[allow(dead_code)]
        fn has_key(var: &str) -> bool {
            matches!(std::env::var(var), Ok(value) if !value.trim().is_empty())
        }

        #[cfg(feature = "openai")]
        if has_key("OPENAI_API_KEY") {
            return Self::from_env_for(Provider::OpenAI);
        }
        #[cfg(feature = "anthropic")]
        if has_key("ANTHROPIC_API_KEY") {
            return Self::from_env_for(Provider::Anthropic);
        }
        #[cfg(feature = "grok")]
        if has_key("XAI_API_KEY") {
            return Self::from_env_for(Provider::Grok);
        }
        #[cfg(feature = "gemini")]
        if has_key("GEMINI_API_KEY") {
            return Self::from_env_for(Provider::Gemini);
        }
        Err(RStructorError::api_error(
            "AnyClient",
            ApiErrorKind::AuthenticationFailed,
        ))
    }

    /// Forwards `list_models` to the wrapped client.
    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
        dispatch!(self, c => c.list_models().await)
    }
}