use crate::types::RoutingError;
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use terraphim_types::capability::{Capability, CostLevel, Latency, Provider, ProviderType};
use tokio::sync::broadcast;
#[cfg(feature = "persistence")]
use async_trait::async_trait;
#[cfg(feature = "persistence")]
use serde::{Deserialize, Serialize};
#[cfg(feature = "persistence")]
use terraphim_persistence::Persistable;
/// Change notification published by a [`ProviderRegistry`] to its broadcast subscribers.
#[derive(Debug, Clone)]
pub enum RegistryEvent {
/// Sent by `add_provider` after a provider has been inserted; carries that provider's id.
ProviderAdded { provider_id: String },
/// Sent by `remove_provider` when a provider was actually present and removed; carries its id.
ProviderRemoved { provider_id: String },
/// Sent by `load_from_dir` after a directory scan that registered `count` (> 0) providers.
ProvidersReloaded { count: usize },
}
/// In-memory collection of routing providers keyed by provider id, with optional
/// broadcast notifications for changes.
#[derive(Debug, Clone, Default)]
pub struct ProviderRegistry {
/// Registered providers, keyed by `Provider::id`.
providers: HashMap<String, Provider>,
/// Path recorded by `with_path`; every other constructor in this file leaves it `None`.
source_path: Option<PathBuf>,
/// Sender used to publish [`RegistryEvent`]s; `None` until `with_change_notifications` is called.
change_sender: Option<broadcast::Sender<RegistryEvent>>,
}
/// A provider definition file split into its parsed YAML frontmatter and markdown body.
#[derive(Debug, Clone)]
pub struct MarkdownProvider {
/// Deserialized YAML frontmatter.
pub frontmatter: ProviderFrontmatter,
/// Text following the closing frontmatter delimiter, with surrounding whitespace trimmed.
pub body: String,
/// Path the document was associated with when it was parsed.
pub path: PathBuf,
}
/// YAML frontmatter schema of a provider markdown file.
///
/// `capabilities` and `keywords` carry no `#[serde(default)]`, so they must be present in
/// the YAML; the `Option` fields may be omitted, and `cost` / `latency` default to an empty
/// string.
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
pub struct ProviderFrontmatter {
/// Identifier; becomes `Provider::id`.
pub id: String,
/// Display name; becomes `Provider::name`.
pub name: String,
/// Provider kind, stored under the YAML key `type`. `provider_from_markdown` accepts
/// `"llm"` and `"agent"` and rejects anything else.
#[serde(rename = "type")]
pub provider_type: String,
/// Model identifier; required when the type is `"llm"`.
pub model_id: Option<String>,
/// API endpoint for `"llm"` providers; `provider_from_markdown` falls back to
/// `https://api.openai.com/v1` when absent.
pub api_endpoint: Option<String>,
/// Agent identifier; required when the type is `"agent"`.
pub agent_id: Option<String>,
/// Command line for the agent; required when the type is `"agent"`.
pub cli_command: Option<String>,
/// Working directory for `"agent"` providers; `provider_from_markdown` falls back to `/tmp`.
pub working_dir: Option<PathBuf>,
/// Capability names; each is mapped through `parse_capability`, unknown names are skipped.
pub capabilities: Vec<String>,
/// Cost label: `cheap`/`low` or `expensive`/`high` (case-insensitive); anything else,
/// including the empty default, maps to `CostLevel::Moderate`.
#[serde(default)]
pub cost: String,
/// Latency label: `fast`/`quick` or `slow` (case-insensitive); anything else, including
/// the empty default, maps to `Latency::Medium`.
#[serde(default)]
pub latency: String,
/// Keywords copied verbatim into `Provider::keywords`.
pub keywords: Vec<String>,
}
impl ProviderRegistry {
pub fn new() -> Self {
Self {
providers: HashMap::new(),
source_path: None,
change_sender: None,
}
}
pub fn with_path(path: impl Into<PathBuf>) -> Self {
Self {
providers: HashMap::new(),
source_path: Some(path.into()),
change_sender: None,
}
}
/// Enables change notifications by attaching a broadcast channel of the given `capacity`.
///
/// Only the sending half is kept; subscribers obtain receivers via `subscribe_changes`.
/// Calling this again replaces the previous sender.
///
/// NOTE(review): tokio documents `broadcast::channel` as panicking for a zero capacity —
/// confirm callers never pass `0`.
pub fn with_change_notifications(mut self, capacity: usize) -> Self {
    // The receiver created alongside the sender is discarded immediately.
    self.change_sender = Some(broadcast::channel(capacity).0);
    self
}
/// Returns a new receiver for registry events, or `None` when change notifications
/// have not been enabled on this registry.
pub fn subscribe_changes(&self) -> Option<broadcast::Receiver<RegistryEvent>> {
    let sender = self.change_sender.as_ref()?;
    Some(sender.subscribe())
}
/// Registers `provider` under its id, replacing any provider already stored under that id.
///
/// When change notifications are enabled a `RegistryEvent::ProviderAdded` is published,
/// whether or not an existing entry was replaced.
pub fn add_provider(&mut self, provider: Provider) {
    let event = RegistryEvent::ProviderAdded {
        provider_id: provider.id.clone(),
    };
    self.providers.insert(provider.id.clone(), provider);

    if let Some(sender) = self.change_sender.as_ref() {
        // The send result is deliberately ignored: notification is best-effort.
        let _ = sender.send(event);
    }
}
/// Looks up a provider by id; returns `None` when nothing is registered under `id`.
pub fn get(&self, id: &str) -> Option<&Provider> {
self.providers.get(id)
}
/// Returns the source path recorded by `with_path`, if any.
pub fn source_path(&self) -> Option<&Path> {
self.source_path.as_deref()
}
/// Removes and returns the provider registered under `id`.
///
/// Returns `None` — and publishes nothing — when no such provider exists. Otherwise a
/// `RegistryEvent::ProviderRemoved` is published if change notifications are enabled.
pub fn remove_provider(&mut self, id: &str) -> Option<Provider> {
    // Bail out early when there was nothing to remove; no event in that case.
    let removed = self.providers.remove(id)?;

    if let Some(sender) = self.change_sender.as_ref() {
        let event = RegistryEvent::ProviderRemoved {
            provider_id: id.to_string(),
        };
        // Best-effort notification; the send result is intentionally ignored.
        let _ = sender.send(event);
    }
    Some(removed)
}
/// Returns references to every registered provider, in the map's iteration order
/// (which `HashMap` does not guarantee to be stable).
pub fn all(&self) -> Vec<&Provider> {
self.providers.values().collect()
}
/// Returns every provider for which `Provider::has_capability(capability)` is true.
pub fn find_by_capability(&self, capability: &Capability) -> Vec<&Provider> {
    let mut matching = Vec::new();
    for provider in self.providers.values() {
        if provider.has_capability(capability) {
            matching.push(provider);
        }
    }
    matching
}
/// Returns every provider that offers at least one of `capabilities`.
///
/// An empty `capabilities` slice therefore yields an empty result.
pub fn find_by_capabilities(&self, capabilities: &[Capability]) -> Vec<&Provider> {
    self.providers
        .values()
        .filter(|provider| {
            // "Any of" semantics: stop at the first capability the provider offers.
            for wanted in capabilities {
                if provider.has_capability(wanted) {
                    return true;
                }
            }
            false
        })
        .collect()
}
/// Loads every `*.md` provider definition found directly inside `dir`.
///
/// Files that cannot be read or parsed are logged with `tracing::warn!` and skipped; they
/// do not abort the scan. Each loaded provider goes through `add_provider` (and so emits
/// its own `ProviderAdded` event). If at least one provider was loaded, a final
/// `ProvidersReloaded` event is published.
///
/// # Errors
/// Returns `RoutingError::Io` when the directory cannot be opened or iterated.
///
/// Returns the number of providers registered by this call.
pub async fn load_from_dir(&mut self, dir: impl AsRef<Path>) -> Result<usize, RoutingError> {
    let dir = dir.as_ref();
    tracing::info!(directory = ?dir, "Loading providers from directory");

    let mut reader = match tokio::fs::read_dir(dir).await {
        Ok(reader) => reader,
        Err(e) => return Err(RoutingError::Io(e.to_string())),
    };

    let mut loaded = 0;
    loop {
        let entry = match reader.next_entry().await {
            Ok(Some(entry)) => entry,
            Ok(None) => break,
            Err(e) => return Err(RoutingError::Io(e.to_string())),
        };

        // Only files with an `md` extension are considered provider definitions.
        let path = entry.path();
        if !matches!(path.extension(), Some(ext) if ext == "md") {
            continue;
        }

        let markdown = match Self::load_markdown_file(&path).await {
            Ok(markdown) => markdown,
            Err(e) => {
                tracing::warn!(
                    file_path = ?path,
                    error = %e,
                    "Failed to load markdown file"
                );
                continue;
            }
        };

        match Self::provider_from_markdown(markdown) {
            Ok(provider) => {
                self.add_provider(provider);
                loaded += 1;
            }
            Err(e) => {
                tracing::warn!(
                    file_path = ?path,
                    error = %e,
                    "Failed to parse provider from markdown file"
                );
            }
        }
    }

    // Summarise the scan for subscribers, but only if something was actually loaded.
    if loaded > 0 {
        if let Some(sender) = self.change_sender.as_ref() {
            let _ = sender.send(RegistryEvent::ProvidersReloaded { count: loaded });
        }
    }
    Ok(loaded)
}
/// Reads the file at `path` as UTF-8 text and parses it with `parse_markdown`.
///
/// # Errors
/// Returns `RoutingError::Io` when the file cannot be read, or whatever error
/// `parse_markdown` produces for the content.
async fn load_markdown_file(path: &Path) -> Result<MarkdownProvider, RoutingError> {
    match tokio::fs::read_to_string(path).await {
        Ok(text) => Self::parse_markdown(text, path.to_path_buf()),
        Err(io_err) => Err(RoutingError::Io(io_err.to_string())),
    }
}
/// Splits a provider document into YAML frontmatter and markdown body, then deserializes
/// the frontmatter.
///
/// The document must open with a `---` delimiter line; the frontmatter ends at the next
/// line consisting solely of `---`. Trailing whitespace (including `\r`) on delimiter
/// lines is tolerated. Delimiters are matched as whole lines so that a `---` embedded in
/// a YAML value (for example `name: "a --- b"`) does not terminate the frontmatter early.
///
/// # Errors
/// * `RoutingError::RegistryError` when the opening delimiter line is missing or the
///   frontmatter is never closed.
/// * `RoutingError::Serialization` when the YAML does not deserialize into
///   `ProviderFrontmatter`.
fn parse_markdown(content: String, path: PathBuf) -> Result<MarkdownProvider, RoutingError> {
    let is_delimiter = |line: &str| line.trim_end() == "---";

    // `split_inclusive` keeps line terminators, so summed line lengths are exact byte
    // offsets into `content` and always fall on character boundaries.
    let mut lines = content.split_inclusive('\n');
    let yaml_start = match lines.next() {
        Some(first) if is_delimiter(first) => first.len(),
        _ => {
            return Err(RoutingError::RegistryError(format!(
                "No YAML frontmatter found in {:?}",
                path
            )));
        }
    };

    // Find the closing delimiter line, remembering where it starts and where it ends.
    let mut offset = yaml_start;
    let mut closing = None;
    for line in lines {
        let line_end = offset + line.len();
        if is_delimiter(line) {
            closing = Some((offset, line_end));
            break;
        }
        offset = line_end;
    }
    let Some((yaml_end, body_start)) = closing else {
        return Err(RoutingError::RegistryError(format!(
            "Unclosed YAML frontmatter in {:?}",
            path
        )));
    };

    let yaml_content = &content[yaml_start..yaml_end];
    let body = content[body_start..].trim().to_string();
    let frontmatter: ProviderFrontmatter = serde_yaml::from_str(yaml_content).map_err(|e| {
        RoutingError::Serialization(format!("Failed to parse YAML frontmatter: {}", e))
    })?;

    Ok(MarkdownProvider {
        frontmatter,
        body,
        path,
    })
}
fn provider_from_markdown(markdown: MarkdownProvider) -> Result<Provider, RoutingError> {
let fm = markdown.frontmatter;
let provider_type = match fm.provider_type.as_str() {
"llm" => {
let model_id = fm.model_id.ok_or_else(|| {
RoutingError::RegistryError("LLM provider missing model_id".to_string())
})?;
let api_endpoint = fm
.api_endpoint
.unwrap_or_else(|| "https://api.openai.com/v1".to_string());
ProviderType::Llm {
model_id,
api_endpoint,
}
}
"agent" => {
let agent_id = fm.agent_id.clone().ok_or_else(|| {
RoutingError::RegistryError("Agent provider missing agent_id".to_string())
})?;
let cli_command = fm.cli_command.ok_or_else(|| {
RoutingError::RegistryError("Agent provider missing cli_command".to_string())
})?;
let working_dir = fm.working_dir.unwrap_or_else(|| PathBuf::from("/tmp"));
ProviderType::Agent {
agent_id,
cli_command,
working_dir,
}
}
other => {
return Err(RoutingError::RegistryError(format!(
"Unknown provider type: {}",
other
)));
}
};
let capabilities = fm
.capabilities
.iter()
.filter_map(|c| Self::parse_capability(c))
.collect();
let cost_level = match fm.cost.to_lowercase().as_str() {
"cheap" | "low" => CostLevel::Cheap,
"expensive" | "high" => CostLevel::Expensive,
_ => CostLevel::Moderate,
};
let latency = match fm.latency.to_lowercase().as_str() {
"fast" | "quick" => Latency::Fast,
"slow" => Latency::Slow,
_ => Latency::Medium,
};
Ok(Provider {
id: fm.id,
name: fm.name,
provider_type,
capabilities,
cost_level,
latency,
keywords: fm.keywords,
})
}
/// Maps a capability name from frontmatter to a `Capability`.
///
/// Matching is case-insensitive and treats `-` like `_`; multi-word names are also
/// accepted without any separator (e.g. `codereview`). Unknown names are logged with
/// `tracing::warn!` and yield `None`.
fn parse_capability(s: &str) -> Option<Capability> {
    let normalized = s.to_lowercase().replace('-', "_");
    let capability = match normalized.as_str() {
        "deep_thinking" | "deepthinking" => Capability::DeepThinking,
        "fast_thinking" | "fastthinking" => Capability::FastThinking,
        "code_generation" | "codegeneration" => Capability::CodeGeneration,
        "code_review" | "codereview" => Capability::CodeReview,
        "architecture" => Capability::Architecture,
        "testing" => Capability::Testing,
        "refactoring" => Capability::Refactoring,
        "documentation" => Capability::Documentation,
        "explanation" => Capability::Explanation,
        "security_audit" | "securityaudit" => Capability::SecurityAudit,
        "performance" => Capability::Performance,
        _ => {
            // Log the caller's original spelling, not the normalized form.
            tracing::warn!(capability_string = s, "Unknown capability string, skipping");
            return None;
        }
    };
    Some(capability)
}
/// Creates a registry pre-populated from persistent storage under `registry_id`.
///
/// Any load failure — including "nothing stored yet" — results in an empty registry.
/// The underlying error is attached to the debug log so that a genuine storage failure
/// can be told apart from a first run. The returned registry has no source path and
/// change notifications disabled.
///
/// # Errors
/// The current implementation always returns `Ok`; the `Result` is kept for interface
/// stability.
#[cfg(feature = "persistence")]
pub async fn with_persistence(registry_id: &str) -> Result<Self, RoutingError> {
    let mut persisted = PersistedProviderRegistry::new(registry_id.to_string());
    let providers = match persisted.load().await {
        Ok(loaded) => {
            tracing::info!(
                registry_id = registry_id,
                count = loaded.providers.len(),
                "Loaded providers from persistent storage"
            );
            loaded.providers
        }
        Err(e) => {
            // Best-effort: fall back to an empty registry, but keep the cause visible.
            tracing::debug!(
                registry_id = registry_id,
                error = %e,
                "No persisted providers found, starting empty"
            );
            HashMap::new()
        }
    };
    Ok(Self {
        providers,
        source_path: None,
        change_sender: None,
    })
}
/// Registers `provider` in memory and then persists the whole registry under `registry_id`.
///
/// The insertion goes through `add_provider`, so subscribers receive the same
/// `RegistryEvent::ProviderAdded` notification as for a non-persisted add.
///
/// # Errors
/// Returns `RoutingError::RegistryError` when saving fails. The in-memory insertion is
/// not rolled back in that case.
#[cfg(feature = "persistence")]
pub async fn add_provider_persisted(
    &mut self,
    provider: Provider,
    registry_id: &str,
) -> Result<(), RoutingError> {
    self.add_provider(provider);
    self.persist_all(registry_id).await
}
/// Removes the provider `provider_id` from memory and then persists the whole registry
/// under `registry_id`.
///
/// The removal goes through `remove_provider`, so subscribers receive the same
/// `RegistryEvent::ProviderRemoved` notification as for a non-persisted removal. The
/// registry is persisted even when no provider was found.
///
/// # Errors
/// Returns `RoutingError::RegistryError` when saving fails. The in-memory removal is
/// not rolled back in that case.
///
/// Returns the removed provider, or `None` if none was registered under `provider_id`.
#[cfg(feature = "persistence")]
pub async fn remove_provider_persisted(
    &mut self,
    provider_id: &str,
    registry_id: &str,
) -> Result<Option<Provider>, RoutingError> {
    let removed = self.remove_provider(provider_id);
    self.persist_all(registry_id).await?;
    Ok(removed)
}
/// Saves a snapshot (clone) of all providers to persistent storage under `registry_id`.
///
/// # Errors
/// Maps any persistence failure to `RoutingError::RegistryError`.
#[cfg(feature = "persistence")]
async fn persist_all(&self, registry_id: &str) -> Result<(), RoutingError> {
    let snapshot = PersistedProviderRegistry {
        providers: self.providers.clone(),
        registry_id: registry_id.to_string(),
    };
    match snapshot.save().await {
        Ok(()) => Ok(()),
        Err(e) => Err(RoutingError::RegistryError(format!(
            "Persistence error: {}",
            e
        ))),
    }
}
pub async fn load_default() -> Result<Self, RoutingError> {
let mut registry = Self::new();
let home = dirs::home_dir()
.ok_or_else(|| RoutingError::Io("Could not find home directory".to_string()))?;
let providers_dir = home.join(".terraphim").join("providers");
if providers_dir.exists() {
registry.load_from_dir(&providers_dir).await?;
}
Ok(registry)
}
}
/// Serializable snapshot of a registry's providers, stored through `terraphim_persistence`.
#[cfg(feature = "persistence")]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PersistedProviderRegistry {
/// Identifier of the registry; used to derive the storage key in `get_key`.
pub registry_id: String,
/// Providers keyed by provider id.
pub providers: HashMap<String, Provider>,
}
// Persistence glue for registry snapshots. `save_to_all`, `save_to_profile`, `load_config`,
// `load_from_operator` and `normalize_key` are not defined in this file.
// NOTE(review): presumably they are provided by the `Persistable` trait — confirm.
#[cfg(feature = "persistence")]
#[async_trait]
impl Persistable for PersistedProviderRegistry {
/// Creates an empty snapshot whose `registry_id` is `key`.
fn new(key: String) -> Self {
Self {
registry_id: key,
providers: HashMap::new(),
}
}
/// Saves this snapshot by delegating to `save_to_all`.
async fn save(&self) -> terraphim_persistence::Result<()> {
self.save_to_all().await
}
/// Saves this snapshot to the profile named `profile_name` by delegating to `save_to_profile`.
async fn save_to_one(&self, profile_name: &str) -> terraphim_persistence::Result<()> {
self.save_to_profile(profile_name).await
}
/// Loads the snapshot stored under this instance's key and returns it as a new value;
/// `self` only supplies the key.
async fn load(&mut self) -> terraphim_persistence::Result<Self>
where
Self: Sized,
{
// Only the second element of the tuple returned by `load_config` is used here.
let op = &self.load_config().await?.1;
let key = self.get_key();
self.load_from_operator(&key, op).await
}
/// Builds the storage key `provider_registry_<normalized registry_id>.json`.
fn get_key(&self) -> String {
let normalized = self.normalize_key(&self.registry_id);
format!("provider_registry_{}.json", normalized)
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::io::Write;
use tempfile::NamedTempFile;
/// A well-formed document yields the expected frontmatter fields and a non-empty body.
#[test]
fn test_parse_markdown_with_frontmatter() {
    let document = String::from(
        r#"---
id: "claude-opus"
name: "Claude Opus"
type: "llm"
model_id: "claude-3-opus-20240229"
api_endpoint: "https://api.anthropic.com/v1"
capabilities:
- deep-thinking
- code-generation
cost: expensive
latency: slow
keywords:
- think
- reasoning
---
# Claude Opus
Anthropic's most capable model.
"#,
    );
    let source = PathBuf::from("test.md");

    let parsed = ProviderRegistry::parse_markdown(document, source).unwrap();

    let frontmatter = &parsed.frontmatter;
    assert_eq!(frontmatter.id, "claude-opus");
    assert_eq!(frontmatter.provider_type, "llm");
    assert_eq!(frontmatter.capabilities.len(), 2);
    assert!(parsed.body.contains("Anthropic's most capable model"));
}
/// An `llm` frontmatter converts into a provider with the declared capability and cost.
#[test]
fn test_provider_from_markdown_llm() {
    let frontmatter = ProviderFrontmatter {
        id: "test-llm".to_string(),
        name: "Test LLM".to_string(),
        provider_type: "llm".to_string(),
        model_id: Some("gpt-4".to_string()),
        api_endpoint: Some("https://api.openai.com".to_string()),
        agent_id: None,
        cli_command: None,
        working_dir: None,
        capabilities: vec!["code-generation".to_string()],
        cost: "moderate".to_string(),
        latency: "medium".to_string(),
        keywords: vec!["code".to_string()],
    };
    let input = MarkdownProvider {
        frontmatter,
        body: "Test body".to_string(),
        path: PathBuf::from("test.md"),
    };

    let built = ProviderRegistry::provider_from_markdown(input).unwrap();

    assert_eq!(built.id, "test-llm");
    assert!(built.has_capability(&Capability::CodeGeneration));
    assert_eq!(built.cost_level, CostLevel::Moderate);
}
/// An `agent` frontmatter converts into an `Agent` provider with the declared cost.
#[test]
fn test_provider_from_markdown_agent() {
    let frontmatter = ProviderFrontmatter {
        id: "@coder".to_string(),
        name: "Coder Agent".to_string(),
        provider_type: "agent".to_string(),
        model_id: None,
        api_endpoint: None,
        agent_id: Some("@coder".to_string()),
        cli_command: Some("opencode".to_string()),
        working_dir: Some(PathBuf::from("/workspace")),
        capabilities: vec!["code-generation".to_string(), "code-review".to_string()],
        cost: "cheap".to_string(),
        latency: "fast".to_string(),
        keywords: vec!["implement".to_string()],
    };
    let input = MarkdownProvider {
        frontmatter,
        body: "Test body".to_string(),
        path: PathBuf::from("test.md"),
    };

    let built = ProviderRegistry::provider_from_markdown(input).unwrap();

    assert_eq!(built.id, "@coder");
    assert!(matches!(built.provider_type, ProviderType::Agent { .. }));
    assert_eq!(built.cost_level, CostLevel::Cheap);
}
/// A directory containing one valid `.md` definition loads exactly one provider.
#[tokio::test]
async fn test_load_from_dir() {
    let fixture = r#"---
id: "test-provider"
name: "Test Provider"
type: "llm"
model_id: "test-model"
capabilities:
- code-generation
cost: cheap
latency: fast
keywords:
- test
---
# Test Provider
This is a test.
"#;
    let providers_dir = tempfile::tempdir().unwrap();

    // Write through a temp file, then rename it so it carries the `.md` extension.
    let mut scratch = NamedTempFile::new_in(providers_dir.path()).unwrap();
    scratch.write_all(fixture.as_bytes()).unwrap();
    let target = providers_dir.path().join("test.md");
    std::fs::rename(scratch.path(), &target).unwrap();

    let mut registry = ProviderRegistry::new();
    let loaded = registry.load_from_dir(providers_dir.path()).await.unwrap();

    assert_eq!(loaded, 1);
    assert!(registry.get("test-provider").is_some());
}
/// Adding and removing a provider publishes matching events; removing an unknown id
/// publishes nothing.
#[test]
fn test_registry_change_notifications() {
    let mut registry = ProviderRegistry::new().with_change_notifications(16);
    let mut events = registry.subscribe_changes().unwrap();

    let llm = ProviderType::Llm {
        model_id: "gpt-4".to_string(),
        api_endpoint: "https://api.openai.com".to_string(),
    };
    registry.add_provider(Provider::new(
        "test-llm",
        "Test LLM",
        llm,
        vec![Capability::CodeGeneration],
    ));
    match events.try_recv() {
        Ok(RegistryEvent::ProviderAdded { provider_id }) => {
            assert_eq!(provider_id, "test-llm")
        }
        other => panic!("Expected ProviderAdded, got {:?}", other),
    }

    assert!(registry.remove_provider("test-llm").is_some());
    match events.try_recv() {
        Ok(RegistryEvent::ProviderRemoved { provider_id }) => {
            assert_eq!(provider_id, "test-llm")
        }
        other => panic!("Expected ProviderRemoved, got {:?}", other),
    }

    // Removing something that is not registered must not emit an event.
    assert!(registry.remove_provider("nonexistent").is_none());
    assert!(events.try_recv().is_err());
}
/// Without `with_change_notifications`, no receiver can be obtained.
#[test]
fn test_subscribe_changes_none_when_not_enabled() {
    let receiver = ProviderRegistry::new().subscribe_changes();
    assert!(receiver.is_none());
}
#[cfg(feature = "persistence")]
mod persistence_tests {
use super::*;
use serial_test::serial;
use terraphim_persistence::Persistable;
/// Initializes the memory-only `DeviceStorage` backend used by the persistence tests.
async fn init_test_persistence() {
    let initialized = terraphim_persistence::DeviceStorage::init_memory_only().await;
    initialized.expect("Failed to initialize memory-only DeviceStorage");
}
/// Builds an LLM provider fixture whose name and model id are derived from `id`.
fn test_provider(id: &str) -> Provider {
    let name = format!("Test {}", id);
    let kind = ProviderType::Llm {
        model_id: format!("model-{}", id),
        api_endpoint: "https://api.example.com".to_string(),
    };
    let capabilities = vec![Capability::CodeGeneration, Capability::CodeReview];
    Provider::new(id, name, kind, capabilities)
}
/// An empty snapshot survives a save/load round trip.
#[tokio::test]
#[serial]
async fn test_persist_and_load_empty_registry() {
    init_test_persistence().await;

    let stored = PersistedProviderRegistry {
        providers: HashMap::new(),
        registry_id: "test-empty".to_string(),
    };
    stored.save().await.unwrap();

    let mut probe = PersistedProviderRegistry::new("test-empty".to_string());
    let loaded = probe.load().await.unwrap();

    assert_eq!(loaded.providers.len(), 0);
    assert_eq!(loaded.registry_id, "test-empty");
}
/// A snapshot holding two providers survives a save/load round trip intact.
#[tokio::test]
#[serial]
async fn test_persist_and_load_with_providers() {
    init_test_persistence().await;

    let mut fixtures = HashMap::new();
    for id in ["p1", "p2"] {
        fixtures.insert(id.to_string(), test_provider(id));
    }
    let stored = PersistedProviderRegistry {
        registry_id: "test-with-providers".to_string(),
        providers: fixtures,
    };
    stored.save().await.unwrap();

    let mut probe = PersistedProviderRegistry::new("test-with-providers".to_string());
    let loaded = probe.load().await.unwrap();

    assert_eq!(loaded.providers.len(), 2);
    assert!(loaded.providers.contains_key("p1"));
    assert!(loaded.providers.contains_key("p2"));
    assert_eq!(loaded.providers["p1"].name, "Test p1");
}
/// `with_persistence` picks up providers that were saved beforehand.
#[tokio::test]
#[serial]
async fn test_with_persistence_loads_existing() {
    init_test_persistence().await;

    let stored = PersistedProviderRegistry {
        registry_id: "test-load".to_string(),
        providers: HashMap::from([("existing".to_string(), test_provider("existing"))]),
    };
    stored.save().await.unwrap();

    let registry = ProviderRegistry::with_persistence("test-load")
        .await
        .unwrap();

    assert_eq!(registry.all().len(), 1);
    assert!(registry.get("existing").is_some());
}
/// With nothing stored under the id, `with_persistence` yields an empty registry.
#[tokio::test]
#[serial]
async fn test_with_persistence_starts_empty_when_no_data() {
    init_test_persistence().await;

    let outcome = ProviderRegistry::with_persistence("test-nonexistent").await;
    let registry = outcome.unwrap();

    assert_eq!(registry.all().len(), 0);
}
/// A provider added through the persisted API is visible to a freshly loaded registry.
#[tokio::test]
#[serial]
async fn test_add_provider_persisted_roundtrip() {
    init_test_persistence().await;
    let registry_id = "test-add";

    let mut writer = ProviderRegistry::with_persistence(registry_id)
        .await
        .unwrap();
    writer
        .add_provider_persisted(test_provider("new-provider"), registry_id)
        .await
        .unwrap();
    assert_eq!(writer.all().len(), 1);

    // A second registry built from storage must see the same provider.
    let reader = ProviderRegistry::with_persistence(registry_id)
        .await
        .unwrap();
    assert_eq!(reader.all().len(), 1);
    assert!(reader.get("new-provider").is_some());
}
/// A provider removed through the persisted API is gone from a freshly loaded registry.
#[tokio::test]
#[serial]
async fn test_remove_provider_persisted() {
    init_test_persistence().await;
    let registry_id = "test-remove";

    let mut writer = ProviderRegistry::with_persistence(registry_id)
        .await
        .unwrap();
    writer
        .add_provider_persisted(test_provider("to-remove"), registry_id)
        .await
        .unwrap();
    assert_eq!(writer.all().len(), 1);

    let removed = writer
        .remove_provider_persisted("to-remove", registry_id)
        .await
        .unwrap();
    assert!(removed.is_some());
    assert_eq!(writer.all().len(), 0);

    // The removal must also be reflected in storage.
    let reader = ProviderRegistry::with_persistence(registry_id)
        .await
        .unwrap();
    assert_eq!(reader.all().len(), 0);
}
/// The plain in-memory API keeps working when the persistence feature is compiled in.
#[tokio::test]
#[serial]
async fn test_backward_compat_non_persisted_still_works() {
    let mut registry = ProviderRegistry::new();
    let fixture = test_provider("in-memory");

    registry.add_provider(fixture);

    assert_eq!(registry.all().len(), 1);
    assert!(registry.get("in-memory").is_some());
}
}
}