use std::collections::HashMap;
use std::sync::OnceLock;
use crate::core::providers::base::get_pricing_db;
use crate::core::types::model::ModelInfo;
/// Capabilities that the AI21 registry can attribute to a model.
///
/// The registry derives these from a model's `ModelInfo` when the model is
/// registered; see `AI21ModelRegistry::detect_features`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ModelFeature {
/// Recorded (together with `ToolCalling`) when `ModelInfo::supports_tools` is set.
FunctionCalling,
/// Recorded when `ModelInfo::supports_multimodal` is set.
VisionSupport,
/// Recorded for every registered model.
StreamingSupport,
/// Recorded for every registered model.
SystemMessages,
/// Recorded (together with `FunctionCalling`) when `ModelInfo::supports_tools` is set.
ToolCalling,
/// Recorded when the model id contains the substring `"jamba"`.
DocumentSupport,
}
/// Everything the registry stores about a single model.
#[derive(Debug, Clone)]
pub struct ModelSpec {
/// Model metadata, taken from the pricing database or from the built-in fallback table.
pub model_info: ModelInfo,
/// Features derived from `model_info` at registration time.
pub features: Vec<ModelFeature>,
/// Per-model runtime configuration derived from `model_info` at registration time.
pub config: ModelConfig,
}
/// Per-model configuration attached to a `ModelSpec`.
#[derive(Debug, Clone, Default)]
pub struct ModelConfig {
/// Concurrency limit for the model. `Default` yields `None`; the registry in
/// this file always fills it in (10 for `jamba-1.5-mini`, 5 for everything else).
/// NOTE(review): where this limit is enforced is not visible in this file.
pub max_concurrent_requests: Option<u32>,
/// Free-form string key/value parameters. The registry in this file leaves it empty.
pub custom_params: HashMap<String, String>,
}
/// In-memory catalogue of AI21 models and their derived features/configuration.
pub struct AI21ModelRegistry {
// Registered specs, keyed by model id.
models: HashMap<String, ModelSpec>,
}
// `Default` simply delegates to `AI21ModelRegistry::new`, so a default
// registry is already populated with models.
impl Default for AI21ModelRegistry {
fn default() -> Self {
Self::new()
}
}
impl AI21ModelRegistry {
    /// Creates a registry and populates it right away.
    ///
    /// Models come from the shared pricing database; if that yields nothing
    /// for AI21, a built-in fallback table is used instead.
    pub fn new() -> Self {
        let mut registry = Self {
            models: HashMap::new(),
        };
        registry.load_models();
        registry
    }

    /// Fills the registry from the pricing database.
    ///
    /// Every model id listed under the `"ai21"` provider is converted to a
    /// `ModelInfo`; ids for which the conversion returns `None` are skipped.
    /// When no model ends up registered, the built-in defaults are added.
    fn load_models(&mut self) {
        let pricing_db = get_pricing_db();
        let model_ids = pricing_db.get_provider_models("ai21");

        for model_id in &model_ids {
            if let Some(model_info) = pricing_db.to_model_info(model_id, "ai21") {
                self.insert_model(model_id.clone(), model_info);
            }
        }

        if self.models.is_empty() {
            self.add_default_models();
        }
    }

    /// Derives the features and configuration for `model_info` and stores the
    /// resulting spec under `key`, replacing any previous entry with that key.
    fn insert_model(&mut self, key: String, model_info: ModelInfo) {
        let features = self.detect_features(&model_info);
        let config = self.create_config(&model_info);
        self.models.insert(
            key,
            ModelSpec {
                model_info,
                features,
                config,
            },
        );
    }

    /// Computes the feature list for a model.
    ///
    /// * `SystemMessages` and `StreamingSupport` are always present.
    /// * `FunctionCalling` and `ToolCalling` are added when `supports_tools` is set.
    /// * `VisionSupport` is added when `supports_multimodal` is set.
    /// * `DocumentSupport` is added when the id contains `"jamba"`.
    fn detect_features(&self, model_info: &ModelInfo) -> Vec<ModelFeature> {
        let mut features = vec![ModelFeature::SystemMessages, ModelFeature::StreamingSupport];

        if model_info.supports_tools {
            features.push(ModelFeature::FunctionCalling);
            features.push(ModelFeature::ToolCalling);
        }
        if model_info.supports_multimodal {
            features.push(ModelFeature::VisionSupport);
        }
        if model_info.id.contains("jamba") {
            features.push(ModelFeature::DocumentSupport);
        }

        features
    }

    /// Builds the per-model configuration.
    ///
    /// Only `max_concurrent_requests` is set; all other fields keep their
    /// `Default` values.
    fn create_config(&self, model_info: &ModelInfo) -> ModelConfig {
        let max_concurrent = match model_info.id.as_str() {
            "jamba-1.5-mini" => 10,
            // `jamba-1.5-large` and every other model share the same limit.
            _ => 5,
        };

        ModelConfig {
            max_concurrent_requests: Some(max_concurrent),
            ..Default::default()
        }
    }

    /// Registers the built-in fallback models.
    ///
    /// Used only when the pricing database provided no AI21 models. Costs are
    /// per 1k tokens in USD.
    fn add_default_models(&mut self) {
        // Tuple layout: (id, display name, max context length, max output length,
        // input cost / 1k, output cost / 1k, supports tools, supports multimodal).
        let default_models = [
            (
                "jamba-1.5-large",
                "Jamba 1.5 Large",
                256_000,
                Some(4_096),
                0.002,
                0.008,
                true,
                false,
            ),
            (
                "jamba-1.5-mini",
                "Jamba 1.5 Mini",
                256_000,
                Some(4_096),
                0.0002,
                0.0004,
                true,
                false,
            ),
            (
                "jamba-instruct",
                "Jamba Instruct",
                256_000,
                Some(4_096),
                0.0005,
                0.0007,
                true,
                false,
            ),
        ];

        for (
            id,
            name,
            context_len,
            output_len,
            input_cost,
            output_cost,
            supports_tools,
            supports_multimodal,
        ) in default_models
        {
            let model_info = ModelInfo {
                id: id.to_string(),
                name: name.to_string(),
                provider: "ai21".to_string(),
                max_context_length: context_len,
                max_output_length: output_len,
                supports_streaming: true,
                supports_tools,
                supports_multimodal,
                input_cost_per_1k_tokens: Some(input_cost),
                output_cost_per_1k_tokens: Some(output_cost),
                currency: "USD".to_string(),
                capabilities: vec![],
                created_at: None,
                updated_at: None,
                metadata: HashMap::new(),
            };
            self.insert_model(id.to_string(), model_info);
        }
    }

    /// Returns a clone of the `ModelInfo` of every registered model.
    ///
    /// The order follows the underlying `HashMap` iteration and is therefore
    /// unspecified.
    pub fn get_all_models(&self) -> Vec<ModelInfo> {
        self.models
            .values()
            .map(|spec| spec.model_info.clone())
            .collect()
    }

    /// Looks up the full spec for `model_id`, or `None` if it is not registered.
    pub fn get_model_spec(&self, model_id: &str) -> Option<&ModelSpec> {
        self.models.get(model_id)
    }

    /// Returns `true` when `model_id` is registered and lists `feature`.
    ///
    /// Unknown model ids yield `false`.
    pub fn supports_feature(&self, model_id: &str, feature: &ModelFeature) -> bool {
        self.models
            .get(model_id)
            .is_some_and(|spec| spec.features.contains(feature))
    }

    /// Returns the ids of all registered models that list `feature`.
    ///
    /// The order follows the underlying `HashMap` iteration and is therefore
    /// unspecified.
    pub fn get_models_with_feature(&self, feature: &ModelFeature) -> Vec<String> {
        self.models
            .iter()
            .filter(|(_, spec)| spec.features.contains(feature))
            .map(|(id, _)| id.clone())
            .collect()
    }
}
/// Process-wide registry instance; empty until the first call to `get_ai21_registry`.
static AI21_REGISTRY: OnceLock<AI21ModelRegistry> = OnceLock::new();
/// Returns the shared AI21 model registry.
///
/// The registry is built with `AI21ModelRegistry::new` on the first call and
/// the same instance is returned on every later call.
pub fn get_ai21_registry() -> &'static AI21ModelRegistry {
AI21_REGISTRY.get_or_init(AI21ModelRegistry::new)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed registry is never empty (database or fallback models).
    #[test]
    fn test_model_registry_creation() {
        let registry = AI21ModelRegistry::new();
        assert!(!registry.get_all_models().is_empty());
    }

    /// Every model carries the two features that are assigned unconditionally.
    #[test]
    fn test_feature_detection() {
        let registry = get_ai21_registry();
        let models = registry.get_all_models();
        assert!(!models.is_empty());
        for model in &models {
            assert!(registry.supports_feature(&model.id, &ModelFeature::SystemMessages));
            assert!(registry.supports_feature(&model.id, &ModelFeature::StreamingSupport));
        }
    }

    /// At least one registered model supports tool calling.
    #[test]
    fn test_models_with_feature() {
        let registry = get_ai21_registry();
        let tool_models = registry.get_models_with_feature(&ModelFeature::ToolCalling);
        assert!(!tool_models.is_empty());
    }

    /// `Default` produces a populated registry, just like `new`.
    #[test]
    fn test_default_impl() {
        let registry = AI21ModelRegistry::default();
        assert!(!registry.get_all_models().is_empty());
    }

    /// A known model can be looked up and belongs to the AI21 provider.
    #[test]
    fn test_get_model_spec() {
        let registry = get_ai21_registry();
        let spec = registry
            .get_model_spec("jamba-1.5-large")
            .expect("jamba-1.5-large should be registered");
        assert_eq!(spec.model_info.provider, "ai21");
    }

    /// Unknown ids yield `None`.
    #[test]
    fn test_get_model_spec_nonexistent() {
        let registry = get_ai21_registry();
        let spec = registry.get_model_spec("nonexistent-model");
        assert!(spec.is_none());
    }

    /// Document support is only ever attributed to Jamba models.
    #[test]
    fn test_jamba_document_support() {
        let registry = get_ai21_registry();
        let document_models = registry.get_models_with_feature(&ModelFeature::DocumentSupport);
        for model in &document_models {
            assert!(model.contains("jamba"));
        }
    }

    /// Concurrency limits match the per-model table; models absent from the
    /// registry are skipped rather than failing the test.
    #[test]
    fn test_model_config() {
        let registry = get_ai21_registry();
        if let Some(spec) = registry.get_model_spec("jamba-1.5-large") {
            assert_eq!(spec.config.max_concurrent_requests, Some(5));
        }
        if let Some(spec) = registry.get_model_spec("jamba-1.5-mini") {
            assert_eq!(spec.config.max_concurrent_requests, Some(10));
        }
    }

    /// `ModelFeature` compares by variant.
    #[test]
    fn test_model_feature_equality() {
        assert_eq!(ModelFeature::FunctionCalling, ModelFeature::FunctionCalling);
        assert_ne!(ModelFeature::FunctionCalling, ModelFeature::VisionSupport);
    }

    /// Basic sanity checks on the metadata of every registered model.
    #[test]
    fn test_model_info_properties() {
        let registry = get_ai21_registry();
        let models = registry.get_all_models();
        for model in models {
            assert!(!model.id.is_empty());
            assert!(!model.name.is_empty());
            assert_eq!(model.provider, "ai21");
            assert!(model.max_context_length > 0);
            assert_eq!(model.currency, "USD");
            assert!(model.input_cost_per_1k_tokens.is_some());
            assert!(model.output_cost_per_1k_tokens.is_some());
        }
    }

    /// Repeated calls hand back the one shared instance.
    #[test]
    fn test_global_registry() {
        let registry1 = get_ai21_registry();
        let registry2 = get_ai21_registry();
        // Comparing sizes alone would also pass for two separately built
        // registries, so check reference identity as well.
        assert!(std::ptr::eq(registry1, registry2));
        assert_eq!(
            registry1.get_all_models().len(),
            registry2.get_all_models().len()
        );
    }

    /// Jamba 1.5 Large has the expected context window and feature set.
    #[test]
    fn test_jamba_large_model() {
        let registry = get_ai21_registry();
        let spec = registry
            .get_model_spec("jamba-1.5-large")
            .expect("jamba-1.5-large should be registered");
        assert_eq!(spec.model_info.max_context_length, 256_000);
        assert!(registry.supports_feature("jamba-1.5-large", &ModelFeature::ToolCalling));
        assert!(registry.supports_feature("jamba-1.5-large", &ModelFeature::FunctionCalling));
        assert!(registry.supports_feature("jamba-1.5-large", &ModelFeature::DocumentSupport));
    }

    /// Feature queries on unknown ids are `false` rather than panicking.
    #[test]
    fn test_supports_feature_nonexistent() {
        let registry = get_ai21_registry();
        assert!(!registry.supports_feature("nonexistent", &ModelFeature::FunctionCalling));
        assert!(!registry.supports_feature("nonexistent", &ModelFeature::ToolCalling));
    }
}