use crate::client::LlmClient;
use crate::error::LlmError;
use crate::retry_api::RetryOptions;
use crate::stream::ChatStream;
use crate::traits::*;
use crate::types::*;
use std::collections::HashMap;
use std::time::Duration;
/// Provider-agnostic client wrapper: a boxed provider client together with
/// metadata describing it and optional retry settings.
pub struct Siumai {
/// The wrapped provider client that requests are delegated to.
client: Box<dyn LlmClient>,
/// Provider type, name, supported models and capabilities of the wrapped client.
metadata: ProviderMetadata,
/// Retry settings; `None` means no retry wrapper is configured.
retry_options: Option<RetryOptions>,
}
impl Clone for Siumai {
    /// Builds an independent copy: the boxed client is duplicated through
    /// `clone_box`, and the metadata and retry options are cloned as-is.
    fn clone(&self) -> Self {
        Self {
            client: self.client.clone_box(),
            metadata: self.metadata.clone(),
            retry_options: self.retry_options.clone(),
        }
    }
}
impl std::fmt::Debug for Siumai {
    /// Prints a summary of the provider metadata. The boxed client and the
    /// retry options are left out, and the model list is reduced to its length.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let meta = &self.metadata;
        let model_count = meta.supported_models.len();

        let mut out = f.debug_struct("Siumai");
        out.field("provider_type", &meta.provider_type);
        out.field("provider_name", &meta.provider_name);
        out.field("supported_models_count", &model_count);
        out.field("capabilities", &meta.capabilities);
        out.finish()
    }
}
/// Descriptive information about a provider client.
#[derive(Debug, Clone)]
pub struct ProviderMetadata {
/// Which provider family the client belongs to.
pub provider_type: ProviderType,
/// Provider name as an owned string.
pub provider_name: String,
/// Model identifiers the provider lists as supported.
pub supported_models: Vec<String>,
/// Capability flags for the provider.
pub capabilities: ProviderCapabilities,
}
impl Siumai {
pub fn new(client: Box<dyn LlmClient>) -> Self {
let metadata = ProviderMetadata {
provider_type: client.provider_type(),
provider_name: client.provider_name().to_string(),
supported_models: client.supported_models(),
capabilities: client.capabilities(),
};
Self {
client,
metadata,
retry_options: None,
}
}
pub fn with_retry_options(mut self, options: Option<RetryOptions>) -> Self {
self.retry_options = options;
self
}
pub fn supports(&self, capability: &str) -> bool {
self.metadata.capabilities.supports(capability)
}
pub const fn metadata(&self) -> &ProviderMetadata {
&self.metadata
}
pub fn client(&self) -> &dyn LlmClient {
self.client.as_ref()
}
pub fn audio_capability(&self) -> AudioCapabilityProxy<'_> {
AudioCapabilityProxy::new(self, self.supports("audio"))
}
pub fn embedding_capability(&self) -> EmbeddingCapabilityProxy<'_> {
EmbeddingCapabilityProxy::new(self, self.supports("embedding"))
}
pub fn vision_capability(&self) -> VisionCapabilityProxy<'_> {
VisionCapabilityProxy::new(self, self.supports("vision"))
}
pub async fn embed(&self, texts: Vec<String>) -> Result<EmbeddingResponse, LlmError> {
EmbeddingCapability::embed(self, texts).await
}
}
#[async_trait::async_trait]
impl ChatCapability for Siumai {
    /// Sends a chat request to the wrapped client.
    ///
    /// With retry options configured, the call goes through
    /// `crate::retry_api::retry_with`, and each attempt receives fresh clones
    /// of `messages` and `tools`. Without them, the request is forwarded once
    /// and the arguments are moved into it.
    async fn chat_with_tools(
        &self,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<Tool>>,
    ) -> Result<ChatResponse, LlmError> {
        match &self.retry_options {
            None => self.client.chat_with_tools(messages, tools).await,
            Some(retry) => {
                crate::retry_api::retry_with(
                    || {
                        let attempt_messages = messages.clone();
                        let attempt_tools = tools.clone();
                        async move {
                            self.client
                                .chat_with_tools(attempt_messages, attempt_tools)
                                .await
                        }
                    },
                    retry.clone(),
                )
                .await
            }
        }
    }

    /// Forwards a streaming chat request straight to the wrapped client; the
    /// retry options are not consulted here.
    async fn chat_stream(
        &self,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<Tool>>,
    ) -> Result<ChatStream, LlmError> {
        let inner = &self.client;
        inner.chat_stream(messages, tools).await
    }
}
#[async_trait::async_trait]
impl EmbeddingCapability for Siumai {
    /// Embeds `texts` through the wrapped client when it exposes an embedding
    /// capability; otherwise fails with `LlmError::UnsupportedOperation`.
    async fn embed(&self, texts: Vec<String>) -> Result<EmbeddingResponse, LlmError> {
        match self.client.as_embedding_capability() {
            Some(inner) => inner.embed(texts).await,
            None => Err(LlmError::UnsupportedOperation(format!(
                "Provider {} does not support embedding functionality. Consider using OpenAI, Gemini, or Ollama for embeddings.",
                self.client.provider_name()
            ))),
        }
    }

    /// Embedding dimension reported by the wrapped client, or a per-provider
    /// fallback (1536 unless the provider is "ollama" or "gemini") when the
    /// client exposes no embedding capability.
    fn embedding_dimension(&self) -> usize {
        match self.client.as_embedding_capability() {
            Some(inner) => inner.embedding_dimension(),
            None => match self.client.provider_name() {
                "ollama" => 384,
                "gemini" => 768,
                // "openai" and every other provider share this fallback.
                _ => 1536,
            },
        }
    }

    /// Token limit per embedding reported by the wrapped client, or a
    /// per-provider fallback (8192 unless the provider is "gemini").
    fn max_tokens_per_embedding(&self) -> usize {
        match self.client.as_embedding_capability() {
            Some(inner) => inner.max_tokens_per_embedding(),
            None => match self.client.provider_name() {
                "gemini" => 2048,
                // "openai", "ollama" and every other provider share this fallback.
                _ => 8192,
            },
        }
    }

    /// Embedding model names reported by the wrapped client, or a built-in
    /// list keyed by provider name (empty for unknown providers).
    fn supported_embedding_models(&self) -> Vec<String> {
        if let Some(inner) = self.client.as_embedding_capability() {
            return inner.supported_embedding_models();
        }

        let fallback: &[&str] = match self.client.provider_name() {
            "openai" => &[
                "text-embedding-3-small",
                "text-embedding-3-large",
                "text-embedding-ada-002",
            ],
            "ollama" => &["nomic-embed-text", "mxbai-embed-large"],
            "gemini" => &["embedding-001", "text-embedding-004"],
            _ => &[],
        };
        fallback.iter().map(|model| (*model).to_string()).collect()
    }
}
#[async_trait::async_trait]
impl EmbeddingExtensions for Siumai {
    /// Embeds the inputs carried by `request`.
    ///
    /// Only `request.input` is handed to the wrapped client's `embed`; no
    /// other field of the request is read here. Fails with
    /// `LlmError::UnsupportedOperation` when the client exposes no embedding
    /// capability.
    async fn embed_with_config(
        &self,
        request: EmbeddingRequest,
    ) -> Result<EmbeddingResponse, LlmError> {
        match self.client.as_embedding_capability() {
            Some(inner) => inner.embed(request.input).await,
            None => Err(LlmError::UnsupportedOperation(format!(
                "Provider {} does not support embedding functionality. Consider using OpenAI, Gemini, or Ollama for embeddings.",
                self.client.provider_name()
            ))),
        }
    }

    /// Describes every supported embedding model.
    ///
    /// Each entry uses the model name for both leading constructor arguments
    /// and this wrapper's dimension and token limit. "gemini" entries get the
    /// full set of task types; "openai" entries are flagged as supporting
    /// custom dimensions. Fails with `LlmError::UnsupportedOperation` when the
    /// client exposes no embedding capability.
    async fn list_embedding_models(&self) -> Result<Vec<EmbeddingModelInfo>, LlmError> {
        if self.client.as_embedding_capability().is_none() {
            return Err(LlmError::UnsupportedOperation(format!(
                "Provider {} does not support embedding functionality.",
                self.client.provider_name()
            )));
        }

        let infos: Vec<EmbeddingModelInfo> = self
            .supported_embedding_models()
            .into_iter()
            .map(|id| {
                let name = id.clone();
                let dimension = self.embedding_dimension();
                let token_limit = self.max_tokens_per_embedding();
                let info = EmbeddingModelInfo::new(name, id, dimension, token_limit);

                match self.client.provider_name() {
                    "gemini" => [
                        EmbeddingTaskType::RetrievalQuery,
                        EmbeddingTaskType::RetrievalDocument,
                        EmbeddingTaskType::SemanticSimilarity,
                        EmbeddingTaskType::Classification,
                        EmbeddingTaskType::Clustering,
                        EmbeddingTaskType::QuestionAnswering,
                        EmbeddingTaskType::FactVerification,
                    ]
                    .into_iter()
                    .fold(info, |acc, task| acc.with_task(task)),
                    "openai" => info.with_custom_dimensions(),
                    _ => info,
                }
            })
            .collect();
        Ok(infos)
    }
}
impl LlmClient for Siumai {
    /// Static name derived from the stored provider type. Every
    /// `ProviderType::Custom` value maps to "custom", whatever its inner name.
    fn provider_name(&self) -> &'static str {
        let kind = &self.metadata.provider_type;
        match kind {
            ProviderType::Custom(_) => "custom",
            ProviderType::OpenAi => "openai",
            ProviderType::Anthropic => "anthropic",
            ProviderType::Gemini => "gemini",
            ProviderType::XAI => "xai",
            ProviderType::Ollama => "ollama",
            ProviderType::Groq => "groq",
        }
    }

    /// Copy of the model list stored in the metadata.
    fn supported_models(&self) -> Vec<String> {
        let models = &self.metadata.supported_models;
        models.clone()
    }

    /// Copy of the capability set stored in the metadata.
    fn capabilities(&self) -> ProviderCapabilities {
        let capabilities = &self.metadata.capabilities;
        capabilities.clone()
    }

    /// Exposes the wrapper itself for downcasting.
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    /// Boxes a clone of this wrapper as a trait object.
    fn clone_box(&self) -> Box<dyn LlmClient> {
        let copy: Siumai = self.clone();
        Box::new(copy)
    }
}
/// Builder that collects provider selection, credentials, request parameters,
/// HTTP settings, tracing and retry configuration for a `Siumai` client.
pub struct SiumaiBuilder {
/// Selected provider type, if any.
pub(crate) provider_type: Option<ProviderType>,
/// Provider name as given by the caller, if any.
pub(crate) provider_name: Option<String>,
/// API key, if one was supplied.
api_key: Option<String>,
/// Base URL override, if one was supplied.
base_url: Option<String>,
/// Capability names requested through the `with_*` capability methods.
capabilities: Vec<String>,
/// Common request parameters (model, temperature, max tokens, top-p, seed, stop sequences).
common_params: CommonParams,
/// HTTP settings (timeouts, user agent, proxy, headers).
http_config: HttpConfig,
/// Organization identifier, if one was supplied.
organization: Option<String>,
/// Project identifier, if one was supplied.
project: Option<String>,
/// Tracing configuration, if one was supplied.
tracing_config: Option<crate::tracing::TracingConfig>,
/// Explicit reasoning on/off choice, if one was made.
reasoning_enabled: Option<bool>,
/// Reasoning budget, if one was supplied.
reasoning_budget: Option<i32>,
/// Retry options handed to the built `Siumai`, if any.
retry_options: Option<RetryOptions>,
}
impl SiumaiBuilder {
/// Creates a builder with nothing selected: every optional setting is `None`,
/// the capability list is empty, and the common and HTTP parameters take
/// their `Default` values.
pub fn new() -> Self {
    let common_params = CommonParams::default();
    let http_config = HttpConfig::default();

    Self {
        provider_type: None,
        provider_name: None,
        api_key: None,
        base_url: None,
        capabilities: Vec::new(),
        common_params,
        http_config,
        organization: None,
        project: None,
        tracing_config: None,
        reasoning_enabled: None,
        reasoning_budget: None,
        retry_options: None,
    }
}
/// Selects the provider type directly; the stored provider name is left as is.
pub fn provider(mut self, provider_type: ProviderType) -> Self {
    let selected = Some(provider_type);
    self.provider_type = selected;
    self
}
/// Selects the provider by name, storing both the name and the matching
/// provider type.
///
/// The six built-in names map to their dedicated variants. Any other name,
/// including "siliconflow", "deepseek" and "openrouter", becomes
/// `ProviderType::Custom` carrying that name.
pub fn provider_name<S: Into<String>>(mut self, name: S) -> Self {
    let name: String = name.into();
    let resolved = match name.as_str() {
        "openai" => ProviderType::OpenAi,
        "anthropic" => ProviderType::Anthropic,
        "gemini" => ProviderType::Gemini,
        "ollama" => ProviderType::Ollama,
        "xai" => ProviderType::XAI,
        "groq" => ProviderType::Groq,
        other => ProviderType::Custom(other.to_string()),
    };
    self.provider_type = Some(resolved);
    self.provider_name = Some(name);
    self
}
/// Stores the API key.
pub fn api_key<S: Into<String>>(mut self, api_key: S) -> Self {
    let key: String = api_key.into();
    self.api_key = Some(key);
    self
}
/// Stores a base URL override.
pub fn base_url<S: Into<String>>(mut self, base_url: S) -> Self {
    let url: String = base_url.into();
    self.base_url = Some(url);
    self
}
/// Sets the model name in the common parameters.
pub fn model<S: Into<String>>(mut self, model: S) -> Self {
    let model_name: String = model.into();
    self.common_params.model = model_name;
    self
}
/// Stores `temperature` in the common parameters.
pub const fn temperature(mut self, temperature: f32) -> Self {
    let value = Some(temperature);
    self.common_params.temperature = value;
    self
}
/// Stores `max_tokens` in the common parameters.
pub const fn max_tokens(mut self, max_tokens: u32) -> Self {
    let value = Some(max_tokens);
    self.common_params.max_tokens = value;
    self
}
/// Stores `top_p` in the common parameters.
pub const fn top_p(mut self, top_p: f32) -> Self {
    let value = Some(top_p);
    self.common_params.top_p = value;
    self
}
/// Stores `seed` in the common parameters.
pub const fn seed(mut self, seed: u64) -> Self {
    let value = Some(seed);
    self.common_params.seed = value;
    self
}
/// Stores the stop sequences in the common parameters, replacing any
/// previously stored list.
pub fn stop_sequences(mut self, sequences: Vec<String>) -> Self {
    let stops = Some(sequences);
    self.common_params.stop_sequences = stops;
    self
}
/// Records an explicit reasoning on/off choice.
pub const fn reasoning(mut self, enabled: bool) -> Self {
    let choice = Some(enabled);
    self.reasoning_enabled = choice;
    self
}
/// Stores the reasoning budget and derives the on/off flag from it.
///
/// A positive budget turns reasoning on, zero turns it off, and a negative
/// budget leaves the current flag untouched.
pub const fn reasoning_budget(mut self, budget: i32) -> Self {
    self.reasoning_budget = Some(budget);
    if budget >= 0 {
        // Non-negative budgets decide the flag: on only when strictly positive.
        self.reasoning_enabled = Some(budget > 0);
    }
    self
}
/// Stores the organization identifier.
pub fn organization<S: Into<String>>(mut self, organization: S) -> Self {
    let org: String = organization.into();
    self.organization = Some(org);
    self
}
/// Stores the project identifier.
pub fn project<S: Into<String>>(mut self, project: S) -> Self {
    let project_id: String = project.into();
    self.project = Some(project_id);
    self
}
/// Appends a capability name to the requested-capabilities list; duplicates
/// are not filtered.
pub fn with_capability<S: Into<String>>(mut self, capability: S) -> Self {
    let requested: String = capability.into();
    self.capabilities.push(requested);
    self
}
pub fn with_audio(self) -> Self {
self.with_capability("audio")
}
pub fn with_vision(self) -> Self {
self.with_capability("vision")
}
pub fn with_embedding(self) -> Self {
self.with_capability("embedding")
}
pub fn with_image_generation(self) -> Self {
self.with_capability("image_generation")
}
/// Stores the HTTP timeout in the HTTP configuration.
pub fn http_timeout(mut self, timeout: Duration) -> Self {
    let limit = Some(timeout);
    self.http_config.timeout = limit;
    self
}
/// Stores the connect timeout in the HTTP configuration.
pub fn http_connect_timeout(mut self, timeout: Duration) -> Self {
    let limit = Some(timeout);
    self.http_config.connect_timeout = limit;
    self
}
/// Stores the user-agent string in the HTTP configuration.
pub fn http_user_agent<S: Into<String>>(mut self, user_agent: S) -> Self {
    let agent: String = user_agent.into();
    self.http_config.user_agent = Some(agent);
    self
}
/// Stores the proxy URL in the HTTP configuration. The URL is kept as a plain
/// string; nothing is validated here.
pub fn http_proxy<S: Into<String>>(mut self, proxy_url: S) -> Self {
    let proxy: String = proxy_url.into();
    self.http_config.proxy = Some(proxy);
    self
}
/// Adds one header to the HTTP configuration, replacing any existing value
/// stored under the same key.
pub fn http_header<K: Into<String>, V: Into<String>>(mut self, key: K, value: V) -> Self {
    let header_name: String = key.into();
    let header_value: String = value.into();
    // The previous value for this key, if any, is intentionally discarded.
    let _ = self.http_config.headers.insert(header_name, header_value);
    self
}
/// Merges `headers` into the HTTP configuration; entries whose key is already
/// present overwrite the stored value.
pub fn http_headers(mut self, headers: HashMap<String, String>) -> Self {
    for (header_name, header_value) in headers {
        self.http_config.headers.insert(header_name, header_value);
    }
    self
}
/// Stores the tracing configuration, replacing any earlier one.
pub fn tracing(mut self, config: crate::tracing::TracingConfig) -> Self {
    let selected = Some(config);
    self.tracing_config = selected;
    self
}
pub fn debug_tracing(self) -> Self {
self.tracing(crate::tracing::TracingConfig::development())
}
pub fn minimal_tracing(self) -> Self {
self.tracing(crate::tracing::TracingConfig::minimal())
}
pub fn json_tracing(self) -> Self {
self.tracing(crate::tracing::TracingConfig::json_production())
}
pub fn enable_tracing(self) -> Self {
self.debug_tracing()
}
pub fn disable_tracing(self) -> Self {
self.tracing(crate::tracing::TracingConfig::disabled())
}
/// Stores the retry options that `build` hands to the resulting `Siumai`.
pub fn with_retry(mut self, options: RetryOptions) -> Self {
    let retry = Some(options);
    self.retry_options = retry;
    self
}
/// Consumes the builder and constructs a `Siumai` for the selected provider.
///
/// Steps, in order: resolve the provider type, resolve the API key, build a
/// `reqwest::Client` from the HTTP configuration, pick a default model when
/// none was set, assemble provider-specific parameters, validate them through
/// the request-builder factory, construct the provider client, and wrap it in
/// `Siumai` with the builder's retry options.
///
/// # Errors
///
/// * `LlmError::ConfigurationError` when no provider type was chosen, when an
///   API key is required but missing, when the proxy URL or a header is
///   invalid, when the HTTP client cannot be built, when tracing fails to
///   initialise, or when the Gemini/xAI sub-builder fails.
/// * `LlmError::UnsupportedOperation` when the provider's cargo feature is
///   disabled or the custom provider name is not handled.
/// * Whatever error `create_and_validate_builder`, `get_provider_adapter` or
///   `OpenAiCompatibleClient::with_http_client` returns is propagated as-is.
///
/// NOTE(review): `self.capabilities` is never read in this method; confirm
/// whether the requested capabilities are meant to influence the build.
pub async fn build(self) -> Result<Siumai, LlmError> {
// Local helper: turns the builder's `HttpConfig` into a `reqwest::Client`.
fn build_http_client_from_config(cfg: &HttpConfig) -> Result<reqwest::Client, LlmError> {
let mut builder = reqwest::Client::builder();
if let Some(timeout) = cfg.timeout {
builder = builder.timeout(timeout);
}
if let Some(connect_timeout) = cfg.connect_timeout {
builder = builder.connect_timeout(connect_timeout);
}
// The single proxy URL is applied to all traffic via `Proxy::all`.
if let Some(proxy_url) = &cfg.proxy {
let proxy = reqwest::Proxy::all(proxy_url)
.map_err(|e| LlmError::ConfigurationError(format!("Invalid proxy URL: {e}")))?;
builder = builder.proxy(proxy);
}
if let Some(user_agent) = &cfg.user_agent {
builder = builder.user_agent(user_agent);
}
// Custom headers become default headers; an invalid name or value aborts the build.
if !cfg.headers.is_empty() {
let mut headers = reqwest::header::HeaderMap::new();
for (k, v) in &cfg.headers {
let name =
reqwest::header::HeaderName::from_bytes(k.as_bytes()).map_err(|e| {
LlmError::ConfigurationError(format!("Invalid header name '{k}': {e}"))
})?;
let value = reqwest::header::HeaderValue::from_str(v).map_err(|e| {
LlmError::ConfigurationError(format!("Invalid header value for '{k}': {e}"))
})?;
headers.insert(name, value);
}
builder = builder.default_headers(headers);
}
builder.build().map_err(|e| {
LlmError::ConfigurationError(format!("Failed to build HTTP client: {e}"))
})
}
// A provider must have been selected via `provider` or `provider_name`.
let provider_type = self.provider_type.clone().ok_or_else(|| {
LlmError::ConfigurationError("Provider type not specified".to_string())
})?;
// Ollama is the only provider that may be built without an API key.
let requires_api_key = match provider_type {
ProviderType::Ollama => false, _ => true, };
let api_key = if requires_api_key {
self.api_key
.clone()
.ok_or_else(|| LlmError::ConfigurationError("API key not specified".to_string()))?
} else {
// For Ollama a missing key becomes the empty string.
self.api_key.clone().unwrap_or_default()
};
let base_url = self.base_url.clone();
let organization = self.organization.clone();
let project = self.project.clone();
let reasoning_enabled = self.reasoning_enabled;
let reasoning_budget = self.reasoning_budget;
let http_config = self.http_config.clone();
let built_http_client = build_http_client_from_config(&http_config)?;
let mut common_params = self.common_params.clone();
// Fill in a per-provider default model only when the caller did not set one.
if common_params.model.is_empty() {
#[cfg(any(feature = "openai", feature = "anthropic", feature = "google"))]
use crate::types::models::model_constants as models;
common_params.model = match provider_type {
#[cfg(feature = "openai")]
ProviderType::OpenAi => models::openai::GPT_4O.to_string(),
#[cfg(feature = "anthropic")]
ProviderType::Anthropic => models::anthropic::CLAUDE_SONNET_3_5.to_string(),
#[cfg(feature = "google")]
ProviderType::Gemini => models::gemini::GEMINI_2_5_FLASH.to_string(),
#[cfg(feature = "ollama")]
ProviderType::Ollama => "llama3.2".to_string(),
#[cfg(feature = "xai")]
ProviderType::XAI => "grok-beta".to_string(),
#[cfg(feature = "groq")]
ProviderType::Groq => "llama-3.1-70b-versatile".to_string(),
// Custom providers: known OpenAI-compatible names get a specific default,
// everything else falls back to the literal "default-model".
ProviderType::Custom(ref name) => match name.as_str() {
#[cfg(feature = "openai")]
"siliconflow" => {
models::openai_compatible::siliconflow::DEEPSEEK_V3_1.to_string()
}
#[cfg(feature = "openai")]
"deepseek" => models::openai_compatible::deepseek::CHAT.to_string(),
#[cfg(feature = "openai")]
"openrouter" => models::openai_compatible::openrouter::GPT_4O.to_string(),
_ => "default-model".to_string(),
},
// The arms below only exist when the matching feature is disabled; they
// keep the match exhaustive and report the missing feature.
#[cfg(not(feature = "openai"))]
ProviderType::OpenAi => {
return Err(LlmError::UnsupportedOperation(
"OpenAI feature not enabled".to_string(),
));
}
#[cfg(not(feature = "anthropic"))]
ProviderType::Anthropic => {
return Err(LlmError::UnsupportedOperation(
"Anthropic feature not enabled".to_string(),
));
}
#[cfg(not(feature = "google"))]
ProviderType::Gemini => {
return Err(LlmError::UnsupportedOperation(
"Google feature not enabled".to_string(),
));
}
#[cfg(not(feature = "ollama"))]
ProviderType::Ollama => {
return Err(LlmError::UnsupportedOperation(
"Ollama feature not enabled".to_string(),
));
}
#[cfg(not(feature = "xai"))]
ProviderType::XAI => {
return Err(LlmError::UnsupportedOperation(
"xAI feature not enabled".to_string(),
));
}
#[cfg(not(feature = "groq"))]
ProviderType::Groq => {
return Err(LlmError::UnsupportedOperation(
"Groq feature not enabled".to_string(),
));
}
};
}
// Translate the reasoning settings into provider-specific parameters.
// NOTE(review): `budget as u32` reinterprets a negative `i32` as a very large
// `u32`. The Gemini arm further down casts back with `as i32`, which restores
// the original value, but the Anthropic arm stores the `u32` as-is. Confirm
// negative budgets are either rejected earlier or intended.
let provider_params = match provider_type {
ProviderType::Anthropic => {
let mut params = ProviderParams::anthropic();
if let Some(budget) = reasoning_budget {
params = params.with_param("thinking_budget", budget as u32);
}
Some(params)
}
ProviderType::Gemini => {
let mut params = ProviderParams::gemini();
if let Some(budget) = reasoning_budget {
params = params.with_param("thinking_budget", budget as u32);
}
Some(params)
}
ProviderType::Ollama => {
let mut params = ProviderParams::new();
// Only an explicit `true` sets "think"; `None` and `false` leave it unset.
if reasoning_enabled.unwrap_or(false) {
params = params.with_param("think", true);
}
Some(params)
}
_ => {
None
}
};
// The returned request builder is discarded: this call is made only for its
// validation result, which aborts the build via `?` on failure.
let _request_builder =
crate::request_factory::RequestBuilderFactory::create_and_validate_builder(
&provider_type,
common_params.clone(),
provider_params.clone(),
)?;
// Construct the concrete provider client.
let client: Box<dyn LlmClient> = match provider_type {
// OpenAI: only model, temperature, max_tokens, organization and project are
// set on the config in this arm.
#[cfg(feature = "openai")]
ProviderType::OpenAi => {
let mut config = crate::providers::openai::OpenAiConfig::new(api_key)
.with_base_url(
base_url.unwrap_or_else(|| "https://api.openai.com/v1".to_string()),
)
.with_model(common_params.model.clone());
if let Some(temp) = common_params.temperature {
config = config.with_temperature(temp);
}
if let Some(max_tokens) = common_params.max_tokens {
config = config.with_max_tokens(max_tokens);
}
if let Some(org) = organization {
config = config.with_organization(org);
}
if let Some(proj) = project {
config = config.with_project(proj);
}
let mut client =
crate::providers::openai::OpenAiClient::new(config, built_http_client.clone());
// NOTE(review): unlike the Anthropic, Ollama and Groq arms, this arm stores
// only the tracing guard and does not call `set_tracing_config`.
if let Some(tc) = self.tracing_config.clone() {
let guard = crate::tracing::init_tracing(tc).map_err(|e| {
LlmError::ConfigurationError(format!("Failed to init tracing: {e}"))
})?;
client.set_tracing_guard(guard);
}
Box::new(client)
}
// Anthropic: receives the full common parameters plus the thinking budget.
#[cfg(feature = "anthropic")]
ProviderType::Anthropic => {
let anthropic_base_url =
base_url.unwrap_or_else(|| "https://api.anthropic.com".to_string());
let mut anthropic_params = crate::params::AnthropicParams::default();
if let Some(ref params) = provider_params
&& let Some(budget) = params.get::<u32>("thinking_budget")
{
anthropic_params.thinking_budget = Some(budget);
}
let mut client = crate::providers::anthropic::AnthropicClient::new(
api_key,
anthropic_base_url,
built_http_client.clone(),
common_params.clone(),
anthropic_params,
http_config.clone(),
);
if let Some(tc) = self.tracing_config.clone() {
let guard = crate::tracing::init_tracing(tc.clone()).map_err(|e| {
LlmError::ConfigurationError(format!("Failed to init tracing: {e}"))
})?;
client.set_tracing_guard(guard);
client.set_tracing_config(self.tracing_config.clone());
}
Box::new(client)
}
// Gemini: built through `LlmBuilder`, so the HTTP settings are replayed onto
// that builder; `built_http_client` is not passed in this arm. `base_url` is
// not read here either.
#[cfg(feature = "google")]
ProviderType::Gemini => {
let mut base_builder = crate::builder::LlmBuilder::new();
if let Some(t) = http_config.timeout {
base_builder = base_builder.with_timeout(t);
}
if let Some(ct) = http_config.connect_timeout {
base_builder = base_builder.with_connect_timeout(ct);
}
if let Some(ref ua) = http_config.user_agent {
base_builder = base_builder.with_user_agent(ua);
}
if let Some(ref proxy) = http_config.proxy {
base_builder = base_builder.with_proxy(proxy.clone());
}
for (k, v) in &http_config.headers {
base_builder = base_builder.with_header(k.clone(), v.clone());
}
let mut builder = base_builder
.gemini()
.api_key(api_key)
.model(&common_params.model);
if let Some(temp) = common_params.temperature {
builder = builder.temperature(temp);
}
if let Some(max_tokens) = common_params.max_tokens {
builder = builder.max_tokens(max_tokens as i32);
}
if let Some(top_p) = common_params.top_p {
builder = builder.top_p(top_p);
}
if let Some(ref params) = provider_params
&& let Some(budget) = params.get::<u32>("thinking_budget")
{
builder = builder.thinking_budget(budget as i32);
}
if let Some(tc) = self.tracing_config.clone() {
builder = builder.tracing(tc);
}
Box::new(builder.build().await.map_err(|e| {
LlmError::ConfigurationError(format!("Failed to build Gemini client: {e}"))
})?)
}
// xAI: same `LlmBuilder` route as Gemini, without a thinking budget.
#[cfg(feature = "xai")]
ProviderType::XAI => {
let mut base_builder = crate::builder::LlmBuilder::new();
if let Some(t) = http_config.timeout {
base_builder = base_builder.with_timeout(t);
}
if let Some(ct) = http_config.connect_timeout {
base_builder = base_builder.with_connect_timeout(ct);
}
if let Some(ref ua) = http_config.user_agent {
base_builder = base_builder.with_user_agent(ua);
}
if let Some(ref proxy) = http_config.proxy {
base_builder = base_builder.with_proxy(proxy.clone());
}
for (k, v) in &http_config.headers {
base_builder = base_builder.with_header(k.clone(), v.clone());
}
let mut builder = base_builder
.xai()
.api_key(api_key)
.model(&common_params.model);
if let Some(temp) = common_params.temperature {
builder = builder.temperature(temp);
}
if let Some(max_tokens) = common_params.max_tokens {
builder = builder.max_tokens(max_tokens);
}
if let Some(top_p) = common_params.top_p {
builder = builder.top_p(top_p);
}
if let Some(tc) = self.tracing_config.clone() {
builder = builder.tracing(tc);
}
Box::new(builder.build().await.map_err(|e| {
LlmError::ConfigurationError(format!("Failed to build xAI client: {e}"))
})?)
}
// Ollama: the API key resolved above is not used in this arm.
#[cfg(feature = "ollama")]
ProviderType::Ollama => {
let ollama_base_url =
base_url.unwrap_or_else(|| "http://localhost:11434".to_string());
let mut ollama_params = crate::providers::ollama::config::OllamaParams::default();
if let Some(ref params) = provider_params
&& let Some(think) = params.get::<bool>("think")
{
ollama_params.think = Some(think);
}
let config = crate::providers::ollama::config::OllamaConfig {
base_url: ollama_base_url,
model: Some(common_params.model.clone()),
common_params: common_params.clone(),
ollama_params,
http_config,
};
let mut client =
crate::providers::ollama::OllamaClient::new(config, built_http_client.clone());
if let Some(tc) = self.tracing_config.clone() {
let guard = crate::tracing::init_tracing(tc.clone()).map_err(|e| {
LlmError::ConfigurationError(format!("Failed to init tracing: {e}"))
})?;
client.set_tracing_guard(guard);
client.set_tracing_config(self.tracing_config.clone());
}
Box::new(client)
}
// Groq: model, temperature and max_tokens are set on the config in this arm.
#[cfg(feature = "groq")]
ProviderType::Groq => {
let groq_base_url =
base_url.unwrap_or_else(|| "https://api.groq.com/openai/v1".to_string());
let mut config = crate::providers::groq::GroqConfig::new(api_key)
.with_base_url(groq_base_url)
.with_model(common_params.model.clone());
if let Some(temp) = common_params.temperature {
config = config.with_temperature(temp);
}
if let Some(max_tokens) = common_params.max_tokens {
config = config.with_max_tokens(max_tokens);
}
let mut client =
crate::providers::groq::GroqClient::new(config, built_http_client.clone());
if let Some(tc) = self.tracing_config.clone() {
let guard = crate::tracing::init_tracing(tc.clone()).map_err(|e| {
LlmError::ConfigurationError(format!("Failed to init tracing: {e}"))
})?;
client.set_tracing_guard(guard);
client.set_tracing_config(self.tracing_config.clone());
}
Box::new(client)
}
// Custom providers: three OpenAI-compatible services are handled (only when
// the "openai" feature is on); any other name is rejected.
// NOTE(review): none of the custom arms reads `self.tracing_config`.
ProviderType::Custom(name) => {
match name.as_str() {
#[cfg(feature = "openai")]
"deepseek" => {
let adapter =
crate::providers::openai_compatible::get_provider_adapter("deepseek")?;
let base_url =
base_url.unwrap_or_else(|| "https://api.deepseek.com/v1".to_string());
let mut config =
crate::providers::openai_compatible::OpenAiCompatibleConfig::new(
"deepseek", &api_key, &base_url, adapter,
);
config = config.with_model(&common_params.model);
if let Some(temp) = common_params.temperature {
config.common_params.temperature = Some(temp);
}
if let Some(max_tokens) = common_params.max_tokens {
config.common_params.max_tokens = Some(max_tokens);
}
let config = config.with_http_config(http_config.clone());
let client = crate::providers::openai_compatible::OpenAiCompatibleClient::with_http_client(
config,
built_http_client.clone(),
)
.await?;
Box::new(client)
}
#[cfg(feature = "openai")]
"siliconflow" => {
let adapter = crate::providers::openai_compatible::get_provider_adapter(
"siliconflow",
)?;
let base_url =
base_url.unwrap_or_else(|| "https://api.siliconflow.cn/v1".to_string());
let mut config =
crate::providers::openai_compatible::OpenAiCompatibleConfig::new(
"siliconflow",
&api_key,
&base_url,
adapter,
);
config = config.with_model(&common_params.model);
if let Some(temp) = common_params.temperature {
config.common_params.temperature = Some(temp);
}
if let Some(max_tokens) = common_params.max_tokens {
config.common_params.max_tokens = Some(max_tokens);
}
let config = config.with_http_config(http_config.clone());
let client = crate::providers::openai_compatible::OpenAiCompatibleClient::with_http_client(
config,
built_http_client.clone(),
)
.await?;
Box::new(client)
}
#[cfg(feature = "openai")]
"openrouter" => {
let adapter = crate::providers::openai_compatible::get_provider_adapter(
"openrouter",
)?;
let base_url =
base_url.unwrap_or_else(|| "https://openrouter.ai/api/v1".to_string());
let mut config =
crate::providers::openai_compatible::OpenAiCompatibleConfig::new(
"openrouter",
&api_key,
&base_url,
adapter,
);
config = config.with_model(&common_params.model);
if let Some(temp) = common_params.temperature {
config.common_params.temperature = Some(temp);
}
if let Some(max_tokens) = common_params.max_tokens {
config.common_params.max_tokens = Some(max_tokens);
}
let config = config.with_http_config(http_config.clone());
let client = crate::providers::openai_compatible::OpenAiCompatibleClient::with_http_client(
config,
built_http_client.clone(),
)
.await?;
Box::new(client)
}
_ => {
return Err(LlmError::UnsupportedOperation(format!(
"Custom provider '{name}' not yet implemented"
)));
}
}
}
// Feature-disabled fallbacks: keep the match exhaustive and name the cargo
// feature that has to be enabled.
#[cfg(not(feature = "openai"))]
ProviderType::OpenAi => {
return Err(LlmError::UnsupportedOperation(
"OpenAI provider requires the 'openai' feature to be enabled".to_string(),
));
}
#[cfg(not(feature = "anthropic"))]
ProviderType::Anthropic => {
return Err(LlmError::UnsupportedOperation(
"Anthropic provider requires the 'anthropic' feature to be enabled".to_string(),
));
}
#[cfg(not(feature = "google"))]
ProviderType::Gemini => {
return Err(LlmError::UnsupportedOperation(
"Gemini provider requires the 'google' feature to be enabled".to_string(),
));
}
#[cfg(not(feature = "ollama"))]
ProviderType::Ollama => {
return Err(LlmError::UnsupportedOperation(
"Ollama provider requires the 'ollama' feature to be enabled".to_string(),
));
}
#[cfg(not(feature = "xai"))]
ProviderType::XAI => {
return Err(LlmError::UnsupportedOperation(
"xAI provider requires the 'xai' feature to be enabled".to_string(),
));
}
#[cfg(not(feature = "groq"))]
ProviderType::Groq => {
return Err(LlmError::UnsupportedOperation(
"Groq provider requires the 'groq' feature to be enabled".to_string(),
));
}
};
// Wrap the provider client and carry over the builder's retry options.
let siumai = Siumai::new(client).with_retry_options(self.retry_options.clone());
Ok(siumai)
}
}
/// Borrowed view of a `Siumai` for audio features, paired with a flag saying
/// whether audio support was reported.
pub struct AudioCapabilityProxy<'a> {
/// The wrapper this proxy refers to.
provider: &'a Siumai,
/// Support flag supplied when the proxy was created.
reported_support: bool,
}
impl<'a> AudioCapabilityProxy<'a> {
pub const fn new(provider: &'a Siumai, reported_support: bool) -> Self {
Self {
provider,
reported_support,
}
}
pub const fn is_reported_as_supported(&self) -> bool {
self.reported_support
}
pub fn provider_name(&self) -> &'static str {
self.provider.provider_name()
}
pub fn support_status_message(&self) -> String {
if self.reported_support {
format!("Provider {} reports audio support", self.provider_name())
} else {
format!(
"Provider {} does not report audio support, but this may still work depending on the model",
self.provider_name()
)
}
}
pub async fn placeholder_operation(&self) -> Result<String, LlmError> {
Err(LlmError::UnsupportedOperation(
"Audio operations not yet implemented. Use provider-specific client.".to_string(),
))
}
}
/// Borrowed view of a `Siumai` for embedding features, paired with a flag
/// saying whether embedding support was reported.
pub struct EmbeddingCapabilityProxy<'a> {
/// The wrapper this proxy forwards to.
provider: &'a Siumai,
/// Support flag supplied when the proxy was created.
reported_support: bool,
}
impl<'a> EmbeddingCapabilityProxy<'a> {
pub const fn new(provider: &'a Siumai, reported_support: bool) -> Self {
Self {
provider,
reported_support,
}
}
pub const fn is_reported_as_supported(&self) -> bool {
self.reported_support
}
pub fn provider_name(&self) -> &'static str {
self.provider.provider_name()
}
pub fn support_status_message(&self) -> String {
if self.reported_support {
format!(
"Provider {} reports embedding support",
self.provider_name()
)
} else {
format!(
"Provider {} does not report embedding support, but this may still work depending on the model",
self.provider_name()
)
}
}
pub async fn embed(&self, texts: Vec<String>) -> Result<EmbeddingResponse, LlmError> {
self.provider.embed(texts).await
}
pub fn embedding_dimension(&self) -> usize {
self.provider.embedding_dimension()
}
pub fn max_tokens_per_embedding(&self) -> usize {
self.provider.max_tokens_per_embedding()
}
pub fn supported_embedding_models(&self) -> Vec<String> {
self.provider.supported_embedding_models()
}
#[deprecated(note = "Use embed() method instead")]
pub async fn placeholder_operation(&self) -> Result<String, LlmError> {
Err(LlmError::UnsupportedOperation(
"Use embed() method instead of placeholder_operation()".to_string(),
))
}
}
/// Borrowed view of a `Siumai` for vision features, paired with a flag saying
/// whether vision support was reported.
pub struct VisionCapabilityProxy<'a> {
/// The wrapper this proxy refers to.
provider: &'a Siumai,
/// Support flag supplied when the proxy was created.
reported_support: bool,
}
impl<'a> VisionCapabilityProxy<'a> {
pub const fn new(provider: &'a Siumai, reported_support: bool) -> Self {
Self {
provider,
reported_support,
}
}
pub const fn is_reported_as_supported(&self) -> bool {
self.reported_support
}
pub fn provider_name(&self) -> &'static str {
self.provider.provider_name()
}
pub fn support_status_message(&self) -> String {
if self.reported_support {
format!("Provider {} reports vision support", self.provider_name())
} else {
format!(
"Provider {} does not report vision support, but this may still work depending on the model",
self.provider_name()
)
}
}
pub async fn placeholder_operation(&self) -> Result<String, LlmError> {
Err(LlmError::UnsupportedOperation(
"Vision operations not yet implemented. Use provider-specific client.".to_string(),
))
}
}
impl Default for SiumaiBuilder {
fn default() -> Self {
Self::new()
}
}
impl std::fmt::Debug for SiumaiBuilder {
    /// Prints the builder without exposing secrets or identifiers.
    ///
    /// The API key, organization and project values are never printed; each is
    /// represented by a `has_*: true` field that appears only when the value
    /// is set. Capabilities are reduced to a count and tracing to a flag.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let params = &self.common_params;
        let capability_count = self.capabilities.len();
        let tracing_set = self.tracing_config.is_some();

        let mut out = f.debug_struct("SiumaiBuilder");
        out.field("provider_type", &self.provider_type);
        out.field("provider_name", &self.provider_name);
        out.field("base_url", &self.base_url);
        out.field("model", &params.model);
        out.field("temperature", &params.temperature);
        out.field("max_tokens", &params.max_tokens);
        out.field("top_p", &params.top_p);
        out.field("seed", &params.seed);
        out.field("capabilities_count", &capability_count);
        out.field("reasoning_enabled", &self.reasoning_enabled);
        out.field("reasoning_budget", &self.reasoning_budget);
        out.field("has_tracing", &tracing_set);
        out.field("timeout", &self.http_config.timeout);

        // Presence markers only; the underlying values stay out of the output.
        let presence_markers = [
            ("has_api_key", self.api_key.is_some()),
            ("has_organization", self.organization.is_some()),
            ("has_project", self.project.is_some()),
        ];
        for (label, present) in presence_markers {
            if present {
                out.field(label, &true);
            }
        }
        out.finish()
    }
}
/// Registry mapping provider names to the factories that create their clients.
pub struct ProviderRegistry {
/// Registered factories, keyed by provider name.
factories: HashMap<String, Box<dyn ProviderFactory>>,
}
/// Factory for provider clients that can be stored in a `ProviderRegistry`.
pub trait ProviderFactory: Send + Sync {
/// Creates a boxed provider client from `config`, or returns an error.
fn create_provider(&self, config: ProviderConfig) -> Result<Box<dyn LlmClient>, LlmError>;
/// Lists capability names for the providers this factory creates.
fn supported_capabilities(&self) -> Vec<String>;
}
/// Configuration handed to a `ProviderFactory` when creating a client.
#[derive(Debug, Clone)]
pub struct ProviderConfig {
/// API key for the provider.
pub api_key: String,
/// Optional base URL.
pub base_url: Option<String>,
/// Optional model name.
pub model: Option<String>,
/// Capability names associated with this configuration.
pub capabilities: Vec<String>,
}
impl ProviderRegistry {
    /// Creates a registry with no factories.
    pub fn new() -> Self {
        let factories = HashMap::new();
        Self { factories }
    }

    /// Registers `factory` under `name`, replacing any factory already stored
    /// under that name.
    pub fn register<S: Into<String>>(&mut self, name: S, factory: Box<dyn ProviderFactory>) {
        let key: String = name.into();
        // A previously registered factory for this name is dropped.
        let _replaced = self.factories.insert(key, factory);
    }

    /// Creates a `Siumai` through the factory registered as `name`.
    ///
    /// Fails with `LlmError::ConfigurationError` when no factory is registered
    /// under `name`; factory errors are propagated unchanged. No retry options
    /// are set on the result.
    pub fn create_provider(&self, name: &str, config: ProviderConfig) -> Result<Siumai, LlmError> {
        let Some(factory) = self.factories.get(name) else {
            return Err(LlmError::ConfigurationError(format!(
                "Unknown provider: {name}"
            )));
        };
        factory.create_provider(config).map(Siumai::new)
    }

    /// Names of all registered providers, in the map's iteration order.
    pub fn supported_providers(&self) -> Vec<String> {
        self.factories.keys().map(|name| name.to_owned()).collect()
    }
}
impl Default for ProviderRegistry {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;

    /// Chat-only stand-in provider with no embedding capability.
    #[derive(Debug)]
    struct MockProvider;

    #[async_trait]
    impl ChatCapability for MockProvider {
        /// Always succeeds with a fixed text response.
        async fn chat_with_tools(
            &self,
            _messages: Vec<ChatMessage>,
            _tools: Option<Vec<Tool>>,
        ) -> Result<ChatResponse, LlmError> {
            let response = ChatResponse {
                id: Some("mock-123".to_string()),
                content: MessageContent::Text("Mock response".to_string()),
                model: Some("mock-model".to_string()),
                usage: None,
                finish_reason: Some(FinishReason::Stop),
                tool_calls: None,
                thinking: None,
                metadata: std::collections::HashMap::new(),
            };
            Ok(response)
        }

        /// Streaming is not available on the mock.
        async fn chat_stream(
            &self,
            _messages: Vec<ChatMessage>,
            _tools: Option<Vec<Tool>>,
        ) -> Result<ChatStream, LlmError> {
            let message = "Streaming not supported in mock".to_string();
            Err(LlmError::UnsupportedOperation(message))
        }
    }

    impl LlmClient for MockProvider {
        fn provider_name(&self) -> &'static str {
            "mock"
        }

        fn supported_models(&self) -> Vec<String> {
            vec!["mock-model".to_string()]
        }

        fn capabilities(&self) -> ProviderCapabilities {
            ProviderCapabilities::new().with_chat()
        }

        fn as_any(&self) -> &dyn std::any::Any {
            self
        }

        fn clone_box(&self) -> Box<dyn LlmClient> {
            Box::new(MockProvider)
        }
    }

    /// Embedding through a provider without embedding support must fail with
    /// `UnsupportedOperation`.
    #[tokio::test]
    async fn test_siumai_embedding_unsupported_provider() {
        let siumai = Siumai::new(Box::new(MockProvider));

        let outcome = siumai.embed(vec!["test".to_string()]).await;

        assert!(outcome.is_err());
        match outcome {
            Err(LlmError::UnsupportedOperation(msg)) => {
                assert!(msg.contains("does not support embedding functionality"));
            }
            _ => panic!("Expected UnsupportedOperation error"),
        }
    }

    /// The proxy reports the wrapper's provider name and the support flag.
    #[test]
    fn test_embedding_capability_proxy() {
        let siumai = Siumai::new(Box::new(MockProvider));

        let proxy = siumai.embedding_capability();

        assert_eq!(proxy.provider_name(), "custom");
        assert!(!proxy.is_reported_as_supported());
    }

    /// Embedding through the proxy fails the same way as on the wrapper.
    #[tokio::test]
    async fn test_embedding_capability_proxy_embed() {
        let siumai = Siumai::new(Box::new(MockProvider));
        let proxy = siumai.embedding_capability();

        let outcome = proxy.embed(vec!["test".to_string()]).await;

        assert!(outcome.is_err());
        match outcome {
            Err(LlmError::UnsupportedOperation(msg)) => {
                assert!(msg.contains("does not support embedding functionality"));
            }
            _ => panic!("Expected UnsupportedOperation error"),
        }
    }

    /// Building an Ollama client without an API key must not fail because of
    /// the missing key; success and any other failure are both accepted.
    #[tokio::test]
    async fn test_ollama_build_without_api_key() {
        let outcome = SiumaiBuilder::new()
            .ollama()
            .model("llama3.2")
            .build()
            .await;

        if let Err(LlmError::ConfigurationError(msg)) = outcome {
            assert!(
                !msg.contains("API key not specified"),
                "Ollama should not require API key, but got: {}",
                msg
            );
        }
    }

    /// Building an OpenAI client without an API key must fail with a
    /// configuration error naming the missing key.
    #[tokio::test]
    async fn test_openai_requires_api_key() {
        let outcome = SiumaiBuilder::new().openai().model("gpt-4o").build().await;

        assert!(outcome.is_err());
        match outcome {
            Err(LlmError::ConfigurationError(msg)) => {
                assert!(msg.contains("API key not specified"));
            }
            _ => panic!("Expected ConfigurationError for missing API key"),
        }
    }
}