use std::{collections::HashMap, sync::Arc};
use crate::{error::BackendConstructError, provider::LlmBackend};
/// Constructor signature stored for each registered backend family.
///
/// Receives the HTTP client builder, the API key and an optional base URL,
/// and yields the constructed backend behind an `Arc<dyn LlmBackend>` (or a
/// `BackendConstructError` if construction fails).
type BackendBuilder = fn(
    http: reqwest::ClientBuilder,
    api_key: &str,
    base_url: Option<&str>,
) -> Result<Arc<dyn LlmBackend>, BackendConstructError>;
/// Registry that maps backend family names to the constructor that builds them.
///
/// Populate it with [`BackendFactory::register`], or start from
/// [`BackendFactory::new`], which pre-registers the backends enabled through
/// Cargo features. Then build a backend by name with [`BackendFactory::create`].
pub struct BackendFactory {
    /// Constructors keyed by family name; a name maps to at most one constructor.
    builders: HashMap<&'static str, BackendBuilder>,
}
impl BackendFactory {
    /// Creates a factory with every feature-enabled backend already registered.
    ///
    /// `DeepSeekBackend` is registered when the `deepseek` feature is on, and
    /// `OpenAiCompatBackend` when `openai-compat` is on. With neither feature
    /// enabled, the result is equivalent to [`BackendFactory::empty`].
    pub fn new() -> Self {
        // `mut` goes unused when every registration below is compiled out.
        // Allowing that here keeps the feature list in exactly one place,
        // instead of mirroring it in a `cfg(not(any(...)))` fallback that
        // must be kept in sync by hand.
        #[allow(unused_mut)]
        let mut factory = Self::empty();
        #[cfg(feature = "deepseek")]
        factory.register::<crate::provider::DeepSeekBackend>();
        #[cfg(feature = "openai-compat")]
        factory.register::<crate::provider::OpenAiCompatBackend>();
        factory
    }

    /// Creates a factory with no registered backends.
    pub fn empty() -> Self {
        Self {
            builders: HashMap::new(),
        }
    }

    /// Registers backend type `C` under the name returned by its
    /// `LlmBackend::family()`, using `C::new` as the constructor.
    ///
    /// If a backend is already registered under the same family name, it is
    /// replaced. Returns `&mut Self` so registrations can be chained.
    pub fn register<C: LlmBackend>(&mut self) -> &mut Self {
        self.builders.insert(<C as LlmBackend>::family(), C::new);
        self
    }

    /// Builds a backend of the given `family`.
    ///
    /// `http`, `api_key` and `base_url` are passed unchanged to the
    /// constructor registered for `family`.
    ///
    /// # Errors
    ///
    /// Returns the error from `BackendConstructError::unknown_family` if no
    /// constructor is registered under `family` (the lookup is an exact
    /// string match). Otherwise, returns whatever the registered constructor
    /// returns, including its errors.
    pub fn create(
        &self,
        family: &str,
        http: reqwest::ClientBuilder,
        api_key: &str,
        base_url: Option<&str>,
    ) -> Result<Arc<dyn LlmBackend>, BackendConstructError> {
        let build = self
            .builders
            .get(family)
            .copied()
            .ok_or_else(|| BackendConstructError::unknown_family(family.to_owned()))?;
        build(http, api_key, base_url)
    }

    /// Returns the names of all registered backend families in ascending
    /// lexicographic order.
    pub fn families(&self) -> impl Iterator<Item = &str> {
        // `HashMap` iteration order is randomized per instance. Sort the names
        // so that anything built from this listing is stable across runs.
        let mut names: Vec<&str> = self.builders.keys().copied().collect();
        names.sort_unstable();
        names.into_iter()
    }
}
impl Default for BackendFactory {
fn default() -> Self {
Self::new()
}
}