use std::collections::HashMap;
use std::sync::OnceLock;
use crate::core::types::model::ModelInfo;
/// Capabilities a Spark model can advertise; queried through
/// `SparkModelRegistry::supports_feature`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ModelFeature {
/// Standard request/response chat completion.
ChatCompletion,
/// Incremental (streamed) completion responses.
StreamingSupport,
/// Tool / function calling.
FunctionCalling,
/// Accepts a system-role message.
SystemMessages,
/// Served over a WebSocket transport.
WebSocketSupport,
}
/// Pricing for a model, expressed in USD per one million tokens.
/// `ModelInfo` carries the same rates as per-1k-token values
/// (e.g. `input_price: 3.0` here corresponds to
/// `input_cost_per_1k_tokens: Some(0.003)`); `CostCalculator`
/// divides by 1_000_000 accordingly.
#[derive(Debug, Clone)]
pub struct ModelPricing {
// USD per 1M input (prompt) tokens.
pub input_price: f64,
// USD per 1M output (completion) tokens.
pub output_price: f64,
}
/// Hard token limits for a model.
#[derive(Debug, Clone)]
pub struct ModelLimits {
// Maximum total context window, in tokens.
pub max_context_length: u32,
// Maximum number of tokens the model may generate.
pub max_output_tokens: u32,
}
/// Everything the registry records about a single Spark model.
#[derive(Debug, Clone)]
pub struct ModelSpec {
// Crate-wide provider metadata for this model.
pub model_info: ModelInfo,
// Features this model advertises (see `ModelFeature`).
pub features: Vec<ModelFeature>,
// USD pricing per 1M tokens.
pub pricing: ModelPricing,
// Context-window and output-size token limits.
pub limits: ModelLimits,
}
/// In-memory catalog of Spark models, keyed by model id
/// (e.g. "spark-desk-v3.5").
#[derive(Debug, Clone)]
pub struct SparkModelRegistry {
// model id -> full specification
models: HashMap<String, ModelSpec>,
}
impl SparkModelRegistry {
pub fn new() -> Self {
let mut registry = Self {
models: HashMap::new(),
};
registry.initialize_models();
registry
}
fn initialize_models(&mut self) {
self.register_model(
"spark-desk-v3.5",
ModelSpec {
model_info: ModelInfo {
id: "spark-desk-v3.5".to_string(),
name: "Spark Desk v3.5".to_string(),
provider: "spark".to_string(),
max_context_length: 8192,
max_output_length: Some(4096),
supports_streaming: true,
supports_tools: true,
supports_multimodal: false,
input_cost_per_1k_tokens: Some(0.003),
output_cost_per_1k_tokens: Some(0.006),
currency: "USD".to_string(),
capabilities: vec![
crate::core::types::model::ProviderCapability::ChatCompletion,
crate::core::types::model::ProviderCapability::ChatCompletionStream,
],
..Default::default()
},
features: vec![
ModelFeature::ChatCompletion,
ModelFeature::StreamingSupport,
ModelFeature::FunctionCalling,
ModelFeature::SystemMessages,
ModelFeature::WebSocketSupport,
],
pricing: ModelPricing {
input_price: 3.0, output_price: 6.0, },
limits: ModelLimits {
max_context_length: 8192,
max_output_tokens: 4096,
},
},
);
self.register_model(
"spark-desk-v3",
ModelSpec {
model_info: ModelInfo {
id: "spark-desk-v3".to_string(),
name: "Spark Desk v3".to_string(),
provider: "spark".to_string(),
max_context_length: 8192,
max_output_length: Some(4096),
supports_streaming: true,
supports_tools: true,
supports_multimodal: false,
input_cost_per_1k_tokens: Some(0.0025),
output_cost_per_1k_tokens: Some(0.005),
currency: "USD".to_string(),
capabilities: vec![
crate::core::types::model::ProviderCapability::ChatCompletion,
crate::core::types::model::ProviderCapability::ChatCompletionStream,
],
..Default::default()
},
features: vec![
ModelFeature::ChatCompletion,
ModelFeature::StreamingSupport,
ModelFeature::FunctionCalling,
ModelFeature::SystemMessages,
ModelFeature::WebSocketSupport,
],
pricing: ModelPricing {
input_price: 2.5, output_price: 5.0, },
limits: ModelLimits {
max_context_length: 8192,
max_output_tokens: 4096,
},
},
);
self.register_model(
"spark-desk-v2",
ModelSpec {
model_info: ModelInfo {
id: "spark-desk-v2".to_string(),
name: "Spark Desk v2".to_string(),
provider: "spark".to_string(),
max_context_length: 4096,
max_output_length: Some(2048),
supports_streaming: true,
supports_tools: false,
supports_multimodal: false,
input_cost_per_1k_tokens: Some(0.002),
output_cost_per_1k_tokens: Some(0.004),
currency: "USD".to_string(),
capabilities: vec![
crate::core::types::model::ProviderCapability::ChatCompletion,
crate::core::types::model::ProviderCapability::ChatCompletionStream,
],
..Default::default()
},
features: vec![
ModelFeature::ChatCompletion,
ModelFeature::StreamingSupport,
ModelFeature::SystemMessages,
ModelFeature::WebSocketSupport,
],
pricing: ModelPricing {
input_price: 2.0, output_price: 4.0, },
limits: ModelLimits {
max_context_length: 4096,
max_output_tokens: 2048,
},
},
);
self.register_model(
"spark-desk-v1.5",
ModelSpec {
model_info: ModelInfo {
id: "spark-desk-v1.5".to_string(),
name: "Spark Desk v1.5".to_string(),
provider: "spark".to_string(),
max_context_length: 4096,
max_output_length: Some(2048),
supports_streaming: true,
supports_tools: false,
supports_multimodal: false,
input_cost_per_1k_tokens: Some(0.0015),
output_cost_per_1k_tokens: Some(0.003),
currency: "USD".to_string(),
capabilities: vec![
crate::core::types::model::ProviderCapability::ChatCompletion,
crate::core::types::model::ProviderCapability::ChatCompletionStream,
],
..Default::default()
},
features: vec![
ModelFeature::ChatCompletion,
ModelFeature::StreamingSupport,
ModelFeature::SystemMessages,
ModelFeature::WebSocketSupport,
],
pricing: ModelPricing {
input_price: 1.5, output_price: 3.0, },
limits: ModelLimits {
max_context_length: 4096,
max_output_tokens: 2048,
},
},
);
}
fn register_model(&mut self, id: &str, spec: ModelSpec) {
self.models.insert(id.to_string(), spec);
}
pub fn get_model_spec(&self, model_id: &str) -> Option<&ModelSpec> {
self.models.get(model_id)
}
pub fn list_models(&self) -> Vec<&ModelSpec> {
self.models.values().collect()
}
pub fn supports_feature(&self, model_id: &str, feature: &ModelFeature) -> bool {
self.get_model_spec(model_id)
.map(|spec| spec.features.contains(feature))
.unwrap_or(false)
}
}
impl Default for SparkModelRegistry {
fn default() -> Self {
Self::new()
}
}
/// Returns the process-wide Spark model registry, built lazily and
/// thread-safely on first access.
pub fn get_spark_registry() -> &'static SparkModelRegistry {
    static REGISTRY: OnceLock<SparkModelRegistry> = OnceLock::new();
    REGISTRY.get_or_init(SparkModelRegistry::default)
}
/// Stateless helper for computing request costs from registry pricing.
pub struct CostCalculator;

impl CostCalculator {
    /// Computes the USD cost of a request; pricing is expressed per one
    /// million tokens. Returns `None` for an unknown model id.
    pub fn calculate_cost(model_id: &str, input_tokens: u32, output_tokens: u32) -> Option<f64> {
        const TOKENS_PER_PRICE_UNIT: f64 = 1_000_000.0;
        let pricing = &get_spark_registry().get_model_spec(model_id)?.pricing;
        let input_cost = (f64::from(input_tokens) / TOKENS_PER_PRICE_UNIT) * pricing.input_price;
        let output_cost = (f64::from(output_tokens) / TOKENS_PER_PRICE_UNIT) * pricing.output_price;
        Some(input_cost + output_cost)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// All four Spark models must be present after initialization.
    #[test]
    fn test_registry_initialization() {
        let registry = get_spark_registry();
        assert_eq!(registry.list_models().len(), 4);
        for id in [
            "spark-desk-v3.5",
            "spark-desk-v3",
            "spark-desk-v2",
            "spark-desk-v1.5",
        ] {
            assert!(registry.get_model_spec(id).is_some());
        }
    }

    /// v3.5 supports function calling; v2 does not. Both stream.
    #[test]
    fn test_model_features() {
        let registry = get_spark_registry();
        let cases = [
            ("spark-desk-v3.5", ModelFeature::FunctionCalling, true),
            ("spark-desk-v3.5", ModelFeature::StreamingSupport, true),
            ("spark-desk-v2", ModelFeature::FunctionCalling, false),
            ("spark-desk-v2", ModelFeature::StreamingSupport, true),
        ];
        for (id, feature, expected) in &cases {
            assert_eq!(registry.supports_feature(id, feature), *expected);
        }
    }

    /// Context/output limits match the registered model sizes.
    #[test]
    fn test_model_limits() {
        let registry = get_spark_registry();
        for (id, context, output) in [
            ("spark-desk-v3.5", 8192u32, 4096u32),
            ("spark-desk-v2", 4096, 2048),
        ] {
            let spec = registry.get_model_spec(id).unwrap();
            assert_eq!(spec.limits.max_context_length, context);
            assert_eq!(spec.limits.max_output_tokens, output);
        }
    }

    /// 1000 input @ $3/1M + 500 output @ $6/1M = $0.006.
    #[test]
    fn test_cost_calculation() {
        let cost = CostCalculator::calculate_cost("spark-desk-v3.5", 1000, 500)
            .expect("v3.5 is a registered model");
        assert!((cost - 0.006).abs() < 0.0001);
    }

    /// Unknown ids yield `None` everywhere instead of panicking.
    #[test]
    fn test_unknown_model() {
        let registry = get_spark_registry();
        assert!(registry.get_model_spec("unknown-model").is_none());
        assert!(CostCalculator::calculate_cost("unknown-model", 1000, 500).is_none());
    }
}