oxirs_embed/api/
helpers.rs

//! Helper functions for API handlers
//!
//! This module contains utility functions used across different API handlers.
4
5use super::ApiState;
6use anyhow::{anyhow, Result};
7use std::sync::Arc;
8use uuid::Uuid;
9
10/// Get the production model version
11pub async fn get_production_model_version(state: &ApiState) -> Result<Uuid> {
12    // First check if there's a designated production model in the registry
13    let models = state.models.read().await;
14
15    if models.is_empty() {
16        return Err(anyhow!("No models available"));
17    }
18
19    // Strategy: Find the best model based on criteria
20    // 1. Prioritize trained models over untrained ones
21    // 2. Prefer models with higher accuracy/lower loss
22    // 3. Consider model version and last update time
23
24    let mut best_model: Option<(Uuid, f64)> = None;
25
26    for (uuid, model) in models.iter() {
27        if !model.is_trained() {
28            continue; // Skip untrained models
29        }
30
31        let stats = model.get_stats();
32
33        // Calculate a composite score for model quality
34        let mut score = 0.0;
35
36        // Trained models get base score
37        if stats.is_trained {
38            score += 100.0;
39        }
40
41        // Higher accuracy is better (if available)
42        // TODO: ModelStats doesn't have an accuracy field yet
43        // if let Some(accuracy) = stats.accuracy {
44        //     score += accuracy * 100.0;
45        // }
46
47        // More entities/relations indicate a more complete model
48        score += (stats.num_entities as f64).ln() * 10.0;
49        score += (stats.num_relations as f64).ln() * 10.0;
50
51        // Recent training is preferred
52        if let Some(last_training) = &stats.last_training_time {
53            let days_since_training = (chrono::Utc::now() - *last_training).num_days();
54            if days_since_training <= 30 {
55                score += 20.0; // Bonus for recent training
56            }
57        }
58
59        // Update best model if this one is better
60        if let Some((_, best_score)) = best_model {
61            if score > best_score {
62                best_model = Some((*uuid, score));
63            }
64        } else {
65            best_model = Some((*uuid, score));
66        }
67    }
68
69    // If no trained models, fall back to any available model
70    if let Some((uuid, _)) = best_model {
71        Ok(uuid)
72    } else {
73        // Return first available model as fallback
74        let (uuid, _) = models.iter().next().unwrap();
75        Ok(*uuid)
76    }
77}
78
79/// Validate API key (if authentication is enabled)
80pub fn validate_api_key(api_key: &str, state: &ApiState) -> bool {
81    // If authentication is not required, allow all requests
82    if !state.config.auth.require_api_key {
83        return true;
84    }
85
86    // Check if the provided API key is valid
87    if state.config.auth.api_keys.contains(&api_key.to_string()) {
88        return true;
89    }
90
91    // API key validation failed
92    false
93}
94
/// Express cache hits as a percentage of all lookups.
///
/// Returns `hits / total * 100`. When `total` is zero, it returns `0.0` instead
/// of dividing by zero.
pub fn calculate_cache_hit_rate(hits: usize, total: usize) -> f64 {
    match total {
        0 => 0.0,
        lookups => {
            let ratio = hits as f64 / lookups as f64;
            ratio * 100.0
        }
    }
}
103
104/// Get the production model (not just the version)
105/// This is a stub that needs to be properly implemented based on the registry type
106pub async fn get_production_model<T>(
107    _registry: &T,
108) -> Result<Arc<dyn crate::EmbeddingModel + Send + Sync>> {
109    Err(anyhow!("get_production_model not yet implemented"))
110}