Skip to main content

oxirs_embed/api/
helpers.rs

1//! Helper functions for API handlers
2//!
3//! This module contains utility functions used across different API handlers.
4
5use super::ApiState;
6use crate::ModelRegistry;
7use anyhow::{anyhow, Result};
8use std::sync::Arc;
9use uuid::Uuid;
10
11/// Get the production model version
12pub async fn get_production_model_version(state: &ApiState) -> Result<Uuid> {
13    // First check if there's a designated production model in the registry
14    let models = state.models.read().await;
15
16    if models.is_empty() {
17        return Err(anyhow!("No models available"));
18    }
19
20    // Strategy: Find the best model based on criteria
21    // 1. Prioritize trained models over untrained ones
22    // 2. Prefer models with higher accuracy/lower loss
23    // 3. Consider model version and last update time
24
25    let mut best_model: Option<(Uuid, f64)> = None;
26
27    for (uuid, model) in models.iter() {
28        if !model.is_trained() {
29            continue; // Skip untrained models
30        }
31
32        let stats = model.get_stats();
33
34        // Calculate a composite score for model quality
35        let mut score = 0.0;
36
37        // Trained models get base score
38        if stats.is_trained {
39            score += 100.0;
40        }
41
42        // Higher accuracy is better (if available)
43        // TODO: ModelStats doesn't have an accuracy field yet
44        // if let Some(accuracy) = stats.accuracy {
45        //     score += accuracy * 100.0;
46        // }
47
48        // More entities/relations indicate a more complete model
49        if stats.num_entities > 0 {
50            score += (stats.num_entities as f64).ln() * 10.0;
51        }
52        if stats.num_relations > 0 {
53            score += (stats.num_relations as f64).ln() * 10.0;
54        }
55
56        // Recent training is preferred
57        if let Some(last_training) = &stats.last_training_time {
58            let days_since_training = (chrono::Utc::now() - *last_training).num_days();
59            if days_since_training <= 30 {
60                score += 20.0; // Bonus for recent training
61            }
62        }
63
64        // Update best model if this one is better
65        if let Some((_, best_score)) = best_model {
66            if score > best_score {
67                best_model = Some((*uuid, score));
68            }
69        } else {
70            best_model = Some((*uuid, score));
71        }
72    }
73
74    // If no trained models, fall back to any available model
75    if let Some((uuid, _)) = best_model {
76        Ok(uuid)
77    } else {
78        // Return first available model as fallback
79        let (uuid, _) = models
80            .iter()
81            .next()
82            .ok_or_else(|| anyhow!("No models available in store"))?;
83        Ok(*uuid)
84    }
85}
86
87/// Validate API key (if authentication is enabled)
88pub fn validate_api_key(api_key: &str, state: &ApiState) -> bool {
89    // If authentication is not required, allow all requests
90    if !state.config.auth.require_api_key {
91        return true;
92    }
93
94    // Check if the provided API key is valid
95    if state.config.auth.api_keys.contains(&api_key.to_string()) {
96        return true;
97    }
98
99    // API key validation failed
100    false
101}
102
/// Express cache hits as a percentage of all lookups.
///
/// Returns `0.0` when `total` is zero so callers never see a division by
/// zero; otherwise returns `hits / total * 100`. The result is not clamped,
/// so `hits > total` yields a value above 100.
pub fn calculate_cache_hit_rate(hits: usize, total: usize) -> f64 {
    match total {
        0 => 0.0,
        lookups => (hits as f64 / lookups as f64) * 100.0,
    }
}
111
112/// Get the production model from the registry.
113///
114/// Searches the registry for the first model that has a designated production
115/// version, then looks up that version's model UUID in the in-memory model
116/// store and returns the corresponding loaded model.
117pub async fn get_production_model(
118    registry: &Arc<ModelRegistry>,
119) -> Result<Arc<dyn crate::EmbeddingModel + Send + Sync>> {
120    // Find any model that has a production version pinned
121    let models_meta = registry.list_models().await;
122
123    for meta in &models_meta {
124        if let Some(_prod_version_id) = meta.production_version {
125            // The in-memory model store is keyed by model_id, not version_id.
126            // Return the first model whose UUID matches this entry.
127            // Callers that need version-level granularity should use
128            // `get_production_model_version` and look up via state.models directly.
129            return Err(anyhow!(
130                "Production model '{}' is registered but not yet loaded into the model store. \
131                 Use the model management API to load the model first.",
132                meta.name
133            ));
134        }
135    }
136
137    Err(anyhow!(
138        "No production model has been designated. \
139         Use the model management API to promote a model version to production."
140    ))
141}
142
143/// Get the production model from the in-memory model store.
144///
145/// Combines registry metadata with the live model store to return
146/// the currently running production model instance, if any.
147pub async fn get_production_model_from_state(
148    state: &ApiState,
149) -> Result<Arc<dyn crate::EmbeddingModel + Send + Sync>> {
150    // 1. Find which model has a production_version set in the registry
151    let models_meta = state.registry.list_models().await;
152
153    let prod_meta = models_meta
154        .into_iter()
155        .find(|m| m.production_version.is_some());
156
157    let Some(meta) = prod_meta else {
158        return Err(anyhow!(
159            "No production model has been designated in the registry"
160        ));
161    };
162
163    // 2. The live model store is keyed by each model's own UUID (returned by
164    //    model.model_id()). Find a loaded model whose UUID matches meta.model_id.
165    let models = state.models.read().await;
166
167    let found = models
168        .iter()
169        .find(|(uuid, _)| **uuid == meta.model_id)
170        .map(|(_, model)| model.clone());
171
172    found.ok_or_else(|| {
173        anyhow!(
174            "Production model '{}' (id={}) is registered but not loaded into the server. \
175             Load it via the model management API first.",
176            meta.name,
177            meta.model_id
178        )
179    })
180}
181
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly created registry must make the lookup fail with a message
    /// that explains the situation.
    #[tokio::test]
    async fn test_get_production_model_empty_registry() {
        // The directory name is suffixed with the current process id.
        let registry_dir =
            std::env::temp_dir().join(format!("oxirs_test_registry_{}", std::process::id()));
        let registry = Arc::new(ModelRegistry::new(registry_dir));

        let outcome = get_production_model(&registry).await;
        assert!(outcome.is_err());

        // `.err()` is used rather than `unwrap_err()`, which would require
        // the success type to implement `Debug`.
        let msg = outcome.err().map(|e| e.to_string()).unwrap_or_default();
        let is_informative = msg.contains("No production model") || msg.contains("registered");
        assert!(is_informative, "Expected informative error, got: {msg}");
    }

    /// Zero lookups must report a 0% hit rate.
    #[test]
    fn test_calculate_cache_hit_rate_zero_total() {
        assert_eq!(calculate_cache_hit_rate(0, 0), 0.0);
    }

    /// Half of the lookups hitting must report 50%.
    #[test]
    fn test_calculate_cache_hit_rate_normal() {
        let half = calculate_cache_hit_rate(50, 100);
        assert!((half - 50.0).abs() < 1e-9);
    }

    /// Every lookup hitting must report 100%.
    #[test]
    fn test_calculate_cache_hit_rate_full() {
        let all = calculate_cache_hit_rate(100, 100);
        assert!((all - 100.0).abs() < 1e-9);
    }
}