use std::collections::HashMap;
use std::sync::{Arc, RwLock};
/// Description of a single model: its identity, which optional features it supports,
/// and its token limits.
///
/// Profiles are stored in a `ModelProfileRegistry`; the in-memory registry keys them by
/// `model_id`, so registering a second profile with the same id replaces the first.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ModelProfile {
    /// Identifier the profile is registered and looked up under.
    pub model_id: String,
    /// Name of the provider serving the model (e.g. `"openai"`).
    pub provider: String,
    /// Answer for the `capabilities::TOOLS` capability.
    pub supports_tools: bool,
    /// Answer for the `capabilities::STREAMING` capability.
    pub supports_streaming: bool,
    /// Answer for the `capabilities::STRUCTURED_OUTPUT` capability.
    pub supports_structured_output: bool,
    /// Maximum context window in tokens.
    /// NOTE(review): `None` presumably means "unknown / not limited" — confirm with callers.
    pub max_context_tokens: Option<u64>,
    /// Maximum number of tokens the model may emit in one response.
    /// NOTE(review): `None` presumably means "unknown / not limited" — confirm with callers.
    pub max_output_tokens: Option<u64>,
}
/// Well-known capability names understood by `ModelProfileRegistry::supports`.
///
/// Each name corresponds to one boolean field of `ModelProfile`. The in-memory registry
/// reports any other string as unsupported.
pub mod capabilities {
    /// Tool calling; answered by `ModelProfile::supports_tools`.
    pub const TOOLS: &str = "tools";

    /// Streaming responses; answered by `ModelProfile::supports_streaming`.
    pub const STREAMING: &str = "streaming";

    /// Structured output; answered by `ModelProfile::supports_structured_output`.
    pub const STRUCTURED_OUTPUT: &str = "structured_output";
}
/// A store of `ModelProfile`s that can be queried by model id.
///
/// All methods take `&self`, so implementations are expected to provide their own
/// interior mutability. The `Send + Sync` bound allows a registry to be shared across
/// threads.
pub trait ModelProfileRegistry: Send + Sync {
    /// Adds `profile` to the registry.
    fn register(&self, profile: ModelProfile);

    /// Returns an owned copy of the profile registered under `model_id`, if any.
    fn get(&self, model_id: &str) -> Option<ModelProfile>;

    /// Reports whether the model `model_id` has the named `capability`
    /// (see the `capabilities` module for well-known names).
    fn supports(&self, model_id: &str, capability: &str) -> bool;
}
/// A `ModelProfileRegistry` backed by a `HashMap` behind an `Arc<RwLock<..>>`.
///
/// Cloning is cheap and yields a handle to the *same* storage: a profile registered
/// through one clone is visible through every other clone.
#[derive(Clone, Debug, Default)]
pub struct InMemoryModelProfileRegistry {
    /// Profiles keyed by `ModelProfile::model_id`.
    profiles: Arc<RwLock<HashMap<String, ModelProfile>>>,
}
impl InMemoryModelProfileRegistry {
    /// Creates an empty registry.
    ///
    /// Each call allocates its own storage, so two registries created this way never see
    /// each other's profiles; only clones of one registry share state.
    pub fn new() -> Self {
        Self {
            profiles: Arc::default(),
        }
    }
}
impl ModelProfileRegistry for InMemoryModelProfileRegistry {
    /// Stores `profile` under its `model_id`, replacing any profile already registered
    /// under that id.
    ///
    /// A poisoned lock is recovered instead of silently dropping the registration. The
    /// only mutation this impl performs while holding the write guard is a single
    /// `HashMap::insert`, so the map is not expected to be left half-updated by a panic.
    fn register(&self, profile: ModelProfile) {
        let mut map = self
            .profiles
            .write()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        let _ = map.insert(profile.model_id.clone(), profile);
    }

    /// Returns a clone of the profile registered under `model_id`, or `None` if there is
    /// none.
    ///
    /// A poisoned lock is recovered (see `register`) rather than reported as "not found".
    fn get(&self, model_id: &str) -> Option<ModelProfile> {
        let map = self
            .profiles
            .read()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        map.get(model_id).cloned()
    }

    /// Returns `true` only if `model_id` is registered and its profile enables the named
    /// `capability`.
    ///
    /// Unknown models and capability names outside the `capabilities` module yield
    /// `false`. The profile is read in place under the lock, so no clone is made.
    fn supports(&self, model_id: &str, capability: &str) -> bool {
        let map = self
            .profiles
            .read()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        let Some(profile) = map.get(model_id) else {
            return false;
        };
        match capability {
            capabilities::TOOLS => profile.supports_tools,
            capabilities::STREAMING => profile.supports_streaming,
            capabilities::STRUCTURED_OUTPUT => profile.supports_structured_output,
            _ => false,
        }
    }
}
#[cfg(test)]
// Tests unwrap freely: an unexpected `None` should fail the test loudly.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Representative profile: tools and streaming enabled, structured output disabled.
    fn sample_profile() -> ModelProfile {
        ModelProfile {
            model_id: "gpt-4o".into(),
            provider: "openai".into(),
            supports_tools: true,
            supports_streaming: true,
            supports_structured_output: false,
            max_context_tokens: Some(128_000),
            max_output_tokens: Some(16_384),
        }
    }

    /// Fresh registry that already contains `sample_profile()`.
    fn seeded_registry() -> InMemoryModelProfileRegistry {
        let seeded = InMemoryModelProfileRegistry::new();
        seeded.register(sample_profile());
        seeded
    }

    #[test]
    fn register_and_retrieve() {
        let fetched = seeded_registry().get("gpt-4o").unwrap();
        assert_eq!(fetched.model_id, "gpt-4o");
        assert_eq!(fetched.provider, "openai");
        assert!(fetched.supports_tools);
        assert_eq!(fetched.max_context_tokens, Some(128_000));
    }

    #[test]
    fn get_returns_none_for_unknown_model() {
        let empty = InMemoryModelProfileRegistry::new();
        assert!(empty.get("nonexistent-model").is_none());
    }

    #[test]
    fn supports_tools() {
        assert!(seeded_registry().supports("gpt-4o", capabilities::TOOLS));
    }

    #[test]
    fn supports_streaming() {
        assert!(seeded_registry().supports("gpt-4o", capabilities::STREAMING));
    }

    #[test]
    fn supports_structured_output_false() {
        assert!(!seeded_registry().supports("gpt-4o", capabilities::STRUCTURED_OUTPUT));
    }

    #[test]
    fn supports_unknown_capability() {
        assert!(!seeded_registry().supports("gpt-4o", "vision"));
    }

    #[test]
    fn supports_unknown_model() {
        let empty = InMemoryModelProfileRegistry::new();
        assert!(!empty.supports("nonexistent", capabilities::TOOLS));
    }

    #[test]
    fn register_replaces_existing() {
        let registry = seeded_registry();
        let updated = ModelProfile {
            supports_structured_output: true,
            ..sample_profile()
        };
        registry.register(updated);
        assert!(registry.supports("gpt-4o", capabilities::STRUCTURED_OUTPUT));
    }

    #[test]
    fn clone_shares_state() {
        let registry = InMemoryModelProfileRegistry::new();
        let handle = registry.clone();
        registry.register(sample_profile());
        assert!(handle.get("gpt-4o").is_some());
    }
}