use crate::core::types::{model::ModelInfo, model::ProviderCapability};
use std::collections::HashMap;
/// Registry of model metadata for OpenAI-compatible providers.
///
/// Explicitly registered models are looked up by id in `known_models`;
/// any other id falls back to the `default_*` limits (see `get_model_info`).
#[derive(Debug, Clone)]
pub struct OpenAILikeModelRegistry {
    // Model id -> its registered configuration.
    known_models: HashMap<String, OpenAILikeModelConfig>,
    // Context window (tokens) reported for models not in `known_models`.
    default_context_length: u32,
    // Output limit (tokens) reported for models not in `known_models`.
    default_output_length: u32,
}
/// Per-model configuration registered with an `OpenAILikeModelRegistry`.
#[derive(Debug, Clone)]
pub struct OpenAILikeModelConfig {
    /// Model identifier; used as the registry key and also as the display name.
    pub id: String,
    /// Maximum context window, in tokens.
    pub max_context_length: u32,
    /// Maximum output length in tokens, if the provider enforces one.
    pub max_output_length: Option<u32>,
    /// Whether streaming responses are supported.
    pub supports_streaming: bool,
    /// Whether tool/function calling is supported.
    pub supports_tools: bool,
    /// Whether multimodal input is supported.
    pub supports_multimodal: bool,
    /// Input price per 1k tokens (USD), when known.
    pub input_cost_per_1k: Option<f64>,
    /// Output price per 1k tokens (USD), when known.
    pub output_cost_per_1k: Option<f64>,
}
impl Default for OpenAILikeModelConfig {
fn default() -> Self {
Self {
id: "unknown".to_string(),
max_context_length: 4096,
max_output_length: Some(4096),
supports_streaming: true,
supports_tools: false,
supports_multimodal: false,
input_cost_per_1k: None,
output_cost_per_1k: None,
}
}
}
impl Default for OpenAILikeModelRegistry {
fn default() -> Self {
Self::new()
}
}
impl OpenAILikeModelRegistry {
    /// Creates an empty registry with conservative fallbacks (4096-token
    /// context and output) for models that are not registered.
    pub fn new() -> Self {
        Self {
            known_models: HashMap::new(),
            default_context_length: 4096,
            default_output_length: 4096,
        }
    }

    /// Creates a registry with defaults suited to modern OpenAI-compatible
    /// endpoints: a 128k-token context window and a 4096-token output cap.
    pub fn with_defaults() -> Self {
        // Route through the builder methods instead of assigning fields
        // directly so the two construction paths cannot drift apart.
        Self::new()
            .with_default_context_length(128_000)
            .with_default_output_length(4096)
    }

    /// Sets the context length reported for unknown models (builder-style).
    pub fn with_default_context_length(mut self, length: u32) -> Self {
        self.default_context_length = length;
        self
    }

    /// Sets the output length reported for unknown models (builder-style).
    pub fn with_default_output_length(mut self, length: u32) -> Self {
        self.default_output_length = length;
        self
    }

    /// Registers a model configuration, replacing any existing entry with
    /// the same id.
    pub fn register_model(&mut self, config: OpenAILikeModelConfig) {
        self.known_models.insert(config.id.clone(), config);
    }

    /// Returns metadata for `model_id`: the registered configuration when
    /// known, otherwise permissive defaults built from the registry's
    /// `default_*` limits.
    pub fn get_model_info(&self, model_id: &str) -> ModelInfo {
        match self.known_models.get(model_id) {
            Some(config) => self.model_info_from_config(config),
            None => self.create_default_model_info(model_id),
        }
    }

    /// Converts a registered configuration into a `ModelInfo`.
    ///
    /// Shared by `get_model_info` and `get_all_models` so the mapping is
    /// defined in exactly one place.
    fn model_info_from_config(&self, config: &OpenAILikeModelConfig) -> ModelInfo {
        ModelInfo {
            id: config.id.clone(),
            // No separate display name is tracked; reuse the id.
            name: config.id.clone(),
            provider: "openai_like".to_string(),
            max_context_length: config.max_context_length,
            max_output_length: config.max_output_length,
            supports_streaming: config.supports_streaming,
            supports_tools: config.supports_tools,
            supports_multimodal: config.supports_multimodal,
            capabilities: self.build_capabilities(config),
            input_cost_per_1k_tokens: config.input_cost_per_1k,
            output_cost_per_1k_tokens: config.output_cost_per_1k,
            currency: "USD".to_string(),
            created_at: None,
            updated_at: None,
            metadata: HashMap::new(),
        }
    }

    /// Builds optimistic metadata for a model that was never registered:
    /// chat, streaming, and tools assumed supported; multimodal assumed
    /// unsupported; pricing unknown.
    fn create_default_model_info(&self, model_id: &str) -> ModelInfo {
        ModelInfo {
            id: model_id.to_string(),
            name: model_id.to_string(),
            provider: "openai_like".to_string(),
            max_context_length: self.default_context_length,
            max_output_length: Some(self.default_output_length),
            supports_streaming: true,
            supports_tools: true,
            supports_multimodal: false,
            // NOTE(review): unlike `build_capabilities`, tool support here is
            // advertised without `FunctionCalling` — confirm the asymmetry is
            // intentional before unifying the two paths.
            capabilities: vec![
                ProviderCapability::ChatCompletion,
                ProviderCapability::ChatCompletionStream,
                ProviderCapability::ToolCalling,
            ],
            input_cost_per_1k_tokens: None,
            output_cost_per_1k_tokens: None,
            currency: "USD".to_string(),
            created_at: None,
            updated_at: None,
            metadata: HashMap::new(),
        }
    }

    /// Derives the capability list implied by a configuration's flags.
    fn build_capabilities(&self, config: &OpenAILikeModelConfig) -> Vec<ProviderCapability> {
        let mut capabilities = vec![ProviderCapability::ChatCompletion];
        if config.supports_streaming {
            capabilities.push(ProviderCapability::ChatCompletionStream);
        }
        if config.supports_tools {
            capabilities.push(ProviderCapability::ToolCalling);
            capabilities.push(ProviderCapability::FunctionCalling);
        }
        capabilities
    }

    /// Returns whether `model_id` has an explicitly registered configuration.
    pub fn is_known_model(&self, model_id: &str) -> bool {
        self.known_models.contains_key(model_id)
    }

    /// Returns metadata for every registered model (iteration order is
    /// unspecified, as with `HashMap`).
    pub fn get_all_models(&self) -> Vec<ModelInfo> {
        // Iterate values directly — the previous keys() + get_model_info()
        // form performed a redundant second hash lookup per model.
        self.known_models
            .values()
            .map(|config| self.model_info_from_config(config))
            .collect()
    }

    /// OpenAI-compatible endpoints accept arbitrary model ids, so every id
    /// is considered supported.
    pub fn supports_model(&self, _model_id: &str) -> bool {
        true
    }
}
/// Returns a process-wide shared registry, lazily initialized on first use
/// via `OpenAILikeModelRegistry::with_defaults()`.
pub fn get_openai_like_registry() -> &'static OpenAILikeModelRegistry {
    static REGISTRY: std::sync::LazyLock<OpenAILikeModelRegistry> =
        std::sync::LazyLock::new(OpenAILikeModelRegistry::with_defaults);
    // Fix: the return expression was garbled to `®ISTRY` (mojibake for
    // `&REGISTRY`), which does not compile.
    &REGISTRY
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_unknown_model_returns_default_info() {
        let reg = OpenAILikeModelRegistry::new();
        let model = reg.get_model_info("my-custom-model");
        // Unknown ids echo back the id and carry the fallback provider/flags.
        assert_eq!(model.id, "my-custom-model");
        assert_eq!(model.name, "my-custom-model");
        assert_eq!(model.provider, "openai_like");
        assert!(model.supports_streaming);
    }

    #[test]
    fn test_all_models_supported() {
        let reg = OpenAILikeModelRegistry::new();
        // The registry is permissive: any identifier is accepted.
        for id in ["any-model-name", "gpt-4", "llama-2-70b", "custom/my-model"] {
            assert!(reg.supports_model(id));
        }
    }

    #[test]
    fn test_known_model_returns_specific_info() {
        let mut reg = OpenAILikeModelRegistry::new();
        let config = OpenAILikeModelConfig {
            id: "llama-2-70b".to_string(),
            max_context_length: 4096,
            max_output_length: Some(2048),
            supports_streaming: true,
            supports_tools: false,
            supports_multimodal: false,
            input_cost_per_1k: Some(0.0001),
            output_cost_per_1k: Some(0.0002),
        };
        reg.register_model(config);

        let model = reg.get_model_info("llama-2-70b");
        assert_eq!(model.max_context_length, 4096);
        assert_eq!(model.max_output_length, Some(2048));
        assert!(!model.supports_tools);
    }

    #[test]
    fn test_custom_defaults() {
        let reg = OpenAILikeModelRegistry::new()
            .with_default_context_length(32000)
            .with_default_output_length(8000);
        // Builder overrides flow through to unknown-model metadata.
        let model = reg.get_model_info("unknown-model");
        assert_eq!(model.max_context_length, 32000);
        assert_eq!(model.max_output_length, Some(8000));
    }

    #[test]
    fn test_is_known_model() {
        let mut reg = OpenAILikeModelRegistry::new();
        reg.register_model(OpenAILikeModelConfig {
            id: "known-model".to_string(),
            ..Default::default()
        });
        assert!(reg.is_known_model("known-model"));
        assert!(!reg.is_known_model("unknown-model"));
    }

    #[test]
    fn test_static_registry() {
        assert!(get_openai_like_registry().supports_model("any-model"));
    }
}