use super::ApiState;
use crate::ModelRegistry;
use anyhow::{anyhow, Result};
use std::sync::Arc;
use uuid::Uuid;
/// Picks the UUID of the model that should serve production traffic.
///
/// Only models whose `is_trained()` returns `true` are ranked. Each candidate receives a
/// heuristic score assembled from `get_stats()`:
/// - `100` if `stats.is_trained` is set,
/// - `10 * ln(num_entities)` when there is at least one entity,
/// - `10 * ln(num_relations)` when there is at least one relation,
/// - `20` if `last_training_time` is present and at most 30 whole days before now.
///
/// The highest score wins; on a tie the candidate visited first is kept. If no model is
/// trained, the first entry yielded by the store's iterator is returned instead
/// (NOTE(review): that order depends on the store's map type — confirm it is meaningful).
///
/// # Errors
/// Returns an error if the model store is empty.
pub async fn get_production_model_version(state: &ApiState) -> Result<Uuid> {
    let store = state.models.read().await;
    if store.is_empty() {
        return Err(anyhow!("No models available"));
    }

    let mut winner: Option<(Uuid, f64)> = None;
    for (id, model) in store.iter().filter(|(_, m)| m.is_trained()) {
        let stats = model.get_stats();

        let trained_bonus = if stats.is_trained { 100.0 } else { 0.0 };
        // ln(0) is -inf, so empty counts contribute nothing instead.
        let entity_term = if stats.num_entities > 0 {
            (stats.num_entities as f64).ln() * 10.0
        } else {
            0.0
        };
        let relation_term = if stats.num_relations > 0 {
            (stats.num_relations as f64).ln() * 10.0
        } else {
            0.0
        };
        let recency_bonus = match &stats.last_training_time {
            Some(trained_at) if (chrono::Utc::now() - *trained_at).num_days() <= 30 => 20.0,
            _ => 0.0,
        };

        // Summed left-to-right in the same order the components were accumulated.
        let score = trained_bonus + entity_term + relation_term + recency_bonus;

        // Strictly-greater comparison: an equal score never displaces an earlier model.
        let beats_current = match winner {
            Some((_, best_so_far)) => score > best_so_far,
            None => true,
        };
        if beats_current {
            winner = Some((*id, score));
        }
    }

    match winner {
        Some((id, _)) => Ok(id),
        None => store
            .iter()
            .next()
            .map(|(id, _)| *id)
            .ok_or_else(|| anyhow!("No models available in store")),
    }
}
/// Checks whether `api_key` is accepted by the server's auth configuration.
///
/// Returns `true` unconditionally when `config.auth.require_api_key` is disabled.
/// Otherwise returns `true` only if `api_key` is byte-for-byte equal to one of the entries
/// in `config.auth.api_keys`.
///
/// Keys are compared without stopping at the first mismatching byte, so response timing
/// does not reveal how much of a guessed key was correct. No allocation happens per call.
pub fn validate_api_key(api_key: &str, state: &ApiState) -> bool {
    /// Byte equality that inspects every byte of equal-length inputs instead of
    /// returning early on the first difference.
    fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
        if a.len() != b.len() {
            return false;
        }
        a.iter().zip(b).fold(0u8, |diff, (x, y)| diff | (x ^ y)) == 0
    }

    let auth = &state.config.auth;
    if !auth.require_api_key {
        return true;
    }
    auth.api_keys
        .iter()
        .any(|known| constant_time_eq(known.as_bytes(), api_key.as_bytes()))
}
/// Returns the cache hit rate as a percentage (`hits / total * 100`).
///
/// A `total` of zero yields `0.0` rather than dividing by zero. The result is not
/// clamped, so `hits > total` produces a value above 100.
pub fn calculate_cache_hit_rate(hits: usize, total: usize) -> f64 {
    if total == 0 {
        return 0.0;
    }
    let fraction = hits as f64 / total as f64;
    fraction * 100.0
}
/// Explains why a production model cannot be obtained from registry metadata alone.
///
/// This function only consults `registry.list_models()` and never returns `Ok`:
/// - if some registered model has `production_version` set, the error names the first
///   such model and says it must be loaded before use;
/// - otherwise the error says no model has been promoted to production.
///
/// # Errors
/// Always returns an error, as described above.
pub async fn get_production_model(
    registry: &Arc<ModelRegistry>,
) -> Result<Arc<dyn crate::EmbeddingModel + Send + Sync>> {
    let catalogue = registry.list_models().await;
    let designated = catalogue
        .iter()
        .find(|entry| entry.production_version.is_some());

    Err(match designated {
        Some(entry) => anyhow!(
            "Production model '{}' is registered but not yet loaded into the model store. \
             Use the model management API to load the model first.",
            entry.name
        ),
        None => anyhow!(
            "No production model has been designated. \
             Use the model management API to promote a model version to production."
        ),
    })
}
/// Returns the loaded model instance that the registry designates for production.
///
/// Finds the first registry entry whose `production_version` is set, then looks up
/// that entry's `model_id` in `state.models`.
///
/// # Errors
/// - No registry entry has a production version designated.
/// - The designated model is registered but absent from `state.models`.
pub async fn get_production_model_from_state(
    state: &ApiState,
) -> Result<Arc<dyn crate::EmbeddingModel + Send + Sync>> {
    let Some(meta) = state
        .registry
        .list_models()
        .await
        .into_iter()
        .find(|m| m.production_version.is_some())
    else {
        return Err(anyhow!(
            "No production model has been designated in the registry"
        ));
    };

    let models = state.models.read().await;
    // The store is keyed by model UUID, so a keyed lookup replaces a full scan.
    let found = models.get(&meta.model_id).cloned();
    found.ok_or_else(|| {
        anyhow!(
            "Production model '{}' (id={}) is registered but not loaded into the server. \
             Load it via the model management API first.",
            meta.name,
            meta.model_id
        )
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    #[tokio::test]
    async fn test_get_production_model_empty_registry() {
        // Portable temp location instead of a hard-coded Unix `/tmp` path.
        let dir: PathBuf = std::env::temp_dir().join("oxirs_test_registry");
        let registry = Arc::new(ModelRegistry::new(dir));
        let result = get_production_model(&registry).await;
        let Err(err) = result else {
            panic!("expected get_production_model to fail for an empty registry");
        };
        let msg = err.to_string();
        assert!(
            msg.contains("No production model") || msg.contains("registered"),
            "Expected informative error, got: {msg}"
        );
    }

    #[test]
    fn test_calculate_cache_hit_rate_zero_total() {
        let rate = calculate_cache_hit_rate(0, 0);
        assert_eq!(rate, 0.0);
    }

    #[test]
    fn test_calculate_cache_hit_rate_normal() {
        let rate = calculate_cache_hit_rate(50, 100);
        assert!((rate - 50.0).abs() < 1e-9);
    }

    #[test]
    fn test_calculate_cache_hit_rate_partial() {
        let rate = calculate_cache_hit_rate(1, 4);
        assert!((rate - 25.0).abs() < 1e-9);
    }

    #[test]
    fn test_calculate_cache_hit_rate_full() {
        let rate = calculate_cache_hit_rate(100, 100);
        assert!((rate - 100.0).abs() < 1e-9);
    }
}