pub mod builtins;
pub mod cli_base;
use std::collections::HashMap;
use std::sync::Arc;
use anyhow::Result;
use tracing::warn;
use crate::agents::NsedAgent;
use crate::agents::config::AgentConfig;
use crate::config::ProviderEntry;
/// Builds agents for a single kind of provider.
///
/// Factories are stored in a [`ProviderRegistry`], keyed by the string that
/// [`ProviderFactory::provider_type`] returns. The `Send + Sync` bound lets a
/// factory be shared behind an `Arc`.
pub trait ProviderFactory: Send + Sync {
    /// Name of the provider kind this factory handles.
    ///
    /// [`ProviderRegistry::register`] uses this value as the lookup key.
    fn provider_type(&self) -> &str;

    /// Whether this provider needs an API key.
    ///
    /// Defaults to `false`. For a registered factory,
    /// [`ProviderRegistry::is_local`] reports the negation of this value.
    fn requires_api_key(&self) -> bool {
        false
    }

    /// Attempts to construct an agent from `agent_config` and `provider`.
    ///
    /// Returns `Ok(Some(agent))` when an agent was built. The meaning of
    /// `Ok(None)` is up to the implementation; presumably it signals that the
    /// agent should be skipped, as the registry does for unknown provider
    /// types (confirm against the implementations in `builtins`).
    ///
    /// # Errors
    ///
    /// Returns whatever error the implementation reports.
    fn build_agent(
        &self,
        agent_config: &AgentConfig,
        provider: &ProviderEntry,
    ) -> Result<Option<Arc<dyn NsedAgent>>>;
}
/// Lookup table from provider-type name to the factory that builds agents
/// for that provider.
pub struct ProviderRegistry {
    /// Registered factories, keyed by the value each one returns from
    /// [`ProviderFactory::provider_type`].
    factories: HashMap<String, Arc<dyn ProviderFactory>>,
}
impl ProviderRegistry {
    /// Creates a registry with no factories registered.
    pub fn empty() -> Self {
        let factories = HashMap::new();
        Self { factories }
    }

    /// Creates a registry pre-populated with the factories from [`builtins`].
    ///
    /// Registers `ExecFactory`, `McpFactory`, `ClaudeFactory`, and three
    /// `OpenAiCompatibleFactory` instances named `"openai"`, `"ollama"` and
    /// `"simulated"`.
    pub fn with_builtins() -> Self {
        // (name, flag) pairs handed to `OpenAiCompatibleFactory::new`, in
        // registration order.
        let openai_compatible = [("openai", true), ("ollama", false), ("simulated", false)];

        let mut this = Self::empty();
        this.register(Arc::new(builtins::ExecFactory));
        this.register(Arc::new(builtins::McpFactory));
        this.register(Arc::new(builtins::ClaudeFactory));
        for (name, flag) in openai_compatible {
            this.register(Arc::new(builtins::OpenAiCompatibleFactory::new(name, flag)));
        }
        this
    }

    /// Adds `factory` under the key returned by its `provider_type()`.
    ///
    /// A factory already registered under the same key is replaced.
    pub fn register(&mut self, factory: Arc<dyn ProviderFactory>) {
        let key = factory.provider_type().to_string();
        self.factories.insert(key, factory);
    }

    /// Returns the factory registered for `provider_type`, if any.
    pub fn get(&self, provider_type: &str) -> Option<&Arc<dyn ProviderFactory>> {
        let Self { factories } = self;
        factories.get(provider_type)
    }

    /// Returns the names of all registered provider types, sorted ascending.
    pub fn provider_types(&self) -> Vec<String> {
        let mut names = Vec::with_capacity(self.factories.len());
        names.extend(self.factories.keys().cloned());
        // The map iterates in arbitrary order; sort for stable output.
        names.sort();
        names
    }

    /// Reports whether `provider_type` is registered and its factory does
    /// not require an API key.
    ///
    /// Unregistered provider types yield `false`.
    pub fn is_local(&self, provider_type: &str) -> bool {
        match self.factories.get(provider_type) {
            Some(factory) => !factory.requires_api_key(),
            None => false,
        }
    }

    /// Builds an agent by delegating to the factory registered for
    /// `provider_type`.
    ///
    /// When no factory is registered under that name, a warning listing the
    /// supported types is logged and `Ok(None)` is returned.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the selected factory.
    pub fn build_agent(
        &self,
        provider_type: &str,
        agent_config: &AgentConfig,
        provider: &ProviderEntry,
    ) -> Result<Option<Arc<dyn NsedAgent>>> {
        if let Some(factory) = self.factories.get(provider_type) {
            return factory.build_agent(agent_config, provider);
        }

        warn!(
            agent = %agent_config.name,
            provider_type = %provider_type,
            supported = ?self.provider_types(),
            "unknown provider_type — skipping"
        );
        Ok(None)
    }
}
impl Default for ProviderRegistry {
fn default() -> Self {
Self::with_builtins()
}
}
/// Debug output shows only the sorted provider-type names; the factory
/// trait objects themselves are not printed.
impl std::fmt::Debug for ProviderRegistry {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let provider_types = self.provider_types();
        let mut out = f.debug_struct("ProviderRegistry");
        out.field("provider_types", &provider_types);
        out.finish()
    }
}
// Unit tests for `ProviderRegistry`: built-in registration, the `is_local`
// classification, unknown-type handling, factory overriding, and dispatch to
// a factory defined outside `builtins`.
#[cfg(test)]
mod tests {
use super::*;
// Test helper: parses `yaml` as an `AgentFleetConfig` and resolves the agent
// named `agent` through `load_agent_from_config`.
// Panics (via `expect`) if the YAML does not parse or the agent cannot be
// resolved.
fn resolve(yaml: &str, agent: &str) -> (AgentConfig, ProviderEntry) {
let fleet: crate::config::AgentFleetConfig =
serde_yaml::from_str(yaml).expect("fleet yaml must parse");
crate::config::load_agent_from_config(&fleet, agent).expect("agent must resolve")
}
// `with_builtins` must register exactly these six provider types.
// `provider_types()` returns them sorted, hence the alphabetical expectation.
#[test]
fn builtins_registers_all_provider_types() {
let registry = ProviderRegistry::with_builtins();
assert_eq!(
registry.provider_types(),
vec!["claude", "exec", "mcp", "ollama", "openai", "simulated"]
);
}
// Every built-in except "openai" must be reported as local, and a name that
// was never registered must not be.
// NOTE(review): the test name refers to an "old dispatch" that is not visible
// in this file — presumably an earlier hard-coded list; confirm.
#[test]
fn is_local_matches_old_dispatch_exemptions() {
let registry = ProviderRegistry::with_builtins();
for local in ["exec", "mcp", "claude", "ollama", "simulated"] {
assert!(registry.is_local(local), "{local} must be local");
}
assert!(!registry.is_local("openai"));
assert!(!registry.is_local("definitely-not-a-provider"));
}
// Checks `requires_api_key` directly on the factories: true only for
// "openai", false for the other five built-ins.
#[test]
fn requires_api_key_only_for_openai() {
let registry = ProviderRegistry::with_builtins();
assert!(registry.get("openai").unwrap().requires_api_key());
for local in ["exec", "mcp", "claude", "ollama", "simulated"] {
assert!(!registry.get(local).unwrap().requires_api_key());
}
}
// Building an agent for a provider type with no registered factory must
// return `Ok(None)` rather than an error.
#[test]
fn unknown_provider_type_skips_cleanly() {
let registry = ProviderRegistry::with_builtins();
let (cfg, provider) = resolve(
r#"
providers:
bogus:
type: not-a-real-provider
api_key: "sk-real-key"
agents:
- name: ghost
provider_id: bogus
model_name: custom
"#,
"ghost",
);
let result = registry
.build_agent("not-a-real-provider", &cfg, &provider)
.expect("unknown type must not error");
assert!(result.is_none(), "unknown provider_type must skip (None)");
}
// Registering a factory under a key that is already taken replaces the
// previous entry.
#[test]
fn register_overrides_existing_factory() {
// Stand-in factory that claims the built-in "exec" key but, unlike what the
// first assertion below shows for the built-in, requires an API key.
struct Fake;
impl ProviderFactory for Fake {
fn provider_type(&self) -> &str {
"exec"
}
fn requires_api_key(&self) -> bool {
true
}
fn build_agent(
&self,
_agent_config: &AgentConfig,
_provider: &ProviderEntry,
) -> Result<Option<Arc<dyn NsedAgent>>> {
Ok(None)
}
}
let mut registry = ProviderRegistry::with_builtins();
// Before the override the built-in "exec" factory counts as local.
assert!(registry.is_local("exec"));
registry.register(Arc::new(Fake));
// After the override `is_local` reflects the replacement's `requires_api_key`.
assert!(!registry.is_local("exec"));
}
// A factory that is not part of `builtins` can be registered on an empty
// registry and is then reachable through `build_agent`.
#[test]
fn third_party_factory_is_dispatchable() {
// Minimal agent; only `name()` is called by this test.
#[derive(Debug, Clone)]
struct CustomAgent;
#[async_trait::async_trait]
impl NsedAgent for CustomAgent {
async fn propose(
&self,
_context: &crate::agents::AgentContext,
) -> Result<crate::agents::Proposal> {
unreachable!("not exercised in this test")
}
async fn evaluate(
&self,
_context: &crate::agents::AgentContext,
) -> Result<Vec<(String, crate::agents::Evaluation)>> {
unreachable!("not exercised in this test")
}
fn name(&self) -> String {
"custom-agent".to_string()
}
}
// Factory for the "custom" provider type; relies on the trait's default
// `requires_api_key` and always yields a `CustomAgent`.
struct CustomFactory;
impl ProviderFactory for CustomFactory {
fn provider_type(&self) -> &str {
"custom"
}
fn build_agent(
&self,
_agent_config: &AgentConfig,
_provider: &ProviderEntry,
) -> Result<Option<Arc<dyn NsedAgent>>> {
Ok(Some(Arc::new(CustomAgent)))
}
}
// Start from an empty registry so only the custom factory is present.
let mut registry = ProviderRegistry::empty();
registry.register(Arc::new(CustomFactory));
let (cfg, provider) = resolve(
r#"
providers:
mine:
type: custom
api_key: "sk-real-key"
agents:
- name: my-agent
provider_id: mine
model_name: custom
"#,
"my-agent",
);
// First `expect` unwraps the `Result`, the second unwraps the `Option`.
let agent = registry
.build_agent("custom", &cfg, &provider)
.expect("custom factory must build")
.expect("custom factory must yield an agent");
assert_eq!(agent.name(), "custom-agent");
}
}