oxirs_embed/api/
helpers.rs1use super::ApiState;
6use crate::ModelRegistry;
7use anyhow::{anyhow, Result};
8use std::sync::Arc;
9use uuid::Uuid;
10
11pub async fn get_production_model_version(state: &ApiState) -> Result<Uuid> {
13 let models = state.models.read().await;
15
16 if models.is_empty() {
17 return Err(anyhow!("No models available"));
18 }
19
20 let mut best_model: Option<(Uuid, f64)> = None;
26
27 for (uuid, model) in models.iter() {
28 if !model.is_trained() {
29 continue; }
31
32 let stats = model.get_stats();
33
34 let mut score = 0.0;
36
37 if stats.is_trained {
39 score += 100.0;
40 }
41
42 if stats.num_entities > 0 {
50 score += (stats.num_entities as f64).ln() * 10.0;
51 }
52 if stats.num_relations > 0 {
53 score += (stats.num_relations as f64).ln() * 10.0;
54 }
55
56 if let Some(last_training) = &stats.last_training_time {
58 let days_since_training = (chrono::Utc::now() - *last_training).num_days();
59 if days_since_training <= 30 {
60 score += 20.0; }
62 }
63
64 if let Some((_, best_score)) = best_model {
66 if score > best_score {
67 best_model = Some((*uuid, score));
68 }
69 } else {
70 best_model = Some((*uuid, score));
71 }
72 }
73
74 if let Some((uuid, _)) = best_model {
76 Ok(uuid)
77 } else {
78 let (uuid, _) = models
80 .iter()
81 .next()
82 .ok_or_else(|| anyhow!("No models available in store"))?;
83 Ok(*uuid)
84 }
85}
86
87pub fn validate_api_key(api_key: &str, state: &ApiState) -> bool {
89 if !state.config.auth.require_api_key {
91 return true;
92 }
93
94 if state.config.auth.api_keys.contains(&api_key.to_string()) {
96 return true;
97 }
98
99 false
101}
102
/// Computes the cache hit rate as a percentage.
///
/// Returns `hits / total * 100.0`, or `0.0` when `total` is zero so that an
/// unused cache reports zero rather than `NaN`.
pub fn calculate_cache_hit_rate(hits: usize, total: usize) -> f64 {
    match total {
        0 => 0.0,
        lookups => {
            let ratio = hits as f64 / lookups as f64;
            ratio * 100.0
        }
    }
}
111
112pub async fn get_production_model(
118 registry: &Arc<ModelRegistry>,
119) -> Result<Arc<dyn crate::EmbeddingModel + Send + Sync>> {
120 let models_meta = registry.list_models().await;
122
123 for meta in &models_meta {
124 if let Some(_prod_version_id) = meta.production_version {
125 return Err(anyhow!(
130 "Production model '{}' is registered but not yet loaded into the model store. \
131 Use the model management API to load the model first.",
132 meta.name
133 ));
134 }
135 }
136
137 Err(anyhow!(
138 "No production model has been designated. \
139 Use the model management API to promote a model version to production."
140 ))
141}
142
143pub async fn get_production_model_from_state(
148 state: &ApiState,
149) -> Result<Arc<dyn crate::EmbeddingModel + Send + Sync>> {
150 let models_meta = state.registry.list_models().await;
152
153 let prod_meta = models_meta
154 .into_iter()
155 .find(|m| m.production_version.is_some());
156
157 let Some(meta) = prod_meta else {
158 return Err(anyhow!(
159 "No production model has been designated in the registry"
160 ));
161 };
162
163 let models = state.models.read().await;
166
167 let found = models
168 .iter()
169 .find(|(uuid, _)| **uuid == meta.model_id)
170 .map(|(_, model)| model.clone());
171
172 found.ok_or_else(|| {
173 anyhow!(
174 "Production model '{}' (id={}) is registered but not loaded into the server. \
175 Load it via the model management API first.",
176 meta.name,
177 meta.model_id
178 )
179 })
180}
181
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed registry must make the production lookup fail
    /// with one of the informative error messages.
    #[tokio::test]
    async fn test_get_production_model_empty_registry() {
        // Use the platform temp directory rather than a hard-coded Unix path.
        let registry = Arc::new(ModelRegistry::new(
            std::env::temp_dir().join("oxirs_test_registry"),
        ));
        let result = get_production_model(&registry).await;
        assert!(result.is_err());
        let msg = result.err().map(|e| e.to_string()).unwrap_or_default();
        assert!(
            msg.contains("No production model") || msg.contains("registered"),
            "Expected informative error, got: {msg}"
        );
    }

    /// With no lookups at all the rate is defined as zero.
    #[test]
    fn test_calculate_cache_hit_rate_zero_total() {
        let rate = calculate_cache_hit_rate(0, 0);
        assert_eq!(rate, 0.0);
    }

    /// Half of the lookups hitting yields 50 percent.
    #[test]
    fn test_calculate_cache_hit_rate_normal() {
        let rate = calculate_cache_hit_rate(50, 100);
        assert!((rate - 50.0).abs() < 1e-9);
    }

    /// Every lookup hitting yields 100 percent.
    #[test]
    fn test_calculate_cache_hit_rate_full() {
        let rate = calculate_cache_hit_rate(100, 100);
        assert!((rate - 100.0).abs() < 1e-9);
    }
}