modelexpress_server/registry/
state.rs1use crate::backend_config::BackendConfig;
7use crate::registry::backend::{
8 ClaimOutcome, ModelRecord, RegistryBackend, RegistryResult, create_registry_backend,
9};
10use modelexpress_common::models::{ModelProvider, ModelStatus};
11use std::sync::Arc;
12use tokio::sync::RwLock;
13use tracing::info;
14
15#[derive(Clone)]
16pub struct RegistryManager {
17 backend: Arc<RwLock<Option<Arc<dyn RegistryBackend>>>>,
18 config: Option<BackendConfig>,
19}
20
21impl Default for RegistryManager {
22 fn default() -> Self {
23 Self::new()
24 }
25}
26
27impl RegistryManager {
28 pub fn new() -> Self {
29 Self {
30 backend: Arc::new(RwLock::new(None)),
31 config: BackendConfig::from_env().ok(),
32 }
33 }
34
35 pub fn with_config(config: BackendConfig) -> Self {
36 Self {
37 backend: Arc::new(RwLock::new(None)),
38 config: Some(config),
39 }
40 }
41
42 #[cfg(test)]
44 pub fn with_backend(backend: Arc<dyn RegistryBackend>) -> Self {
45 Self {
46 backend: Arc::new(RwLock::new(Some(backend))),
47 config: None,
48 }
49 }
50
51 pub async fn connect(&self) -> RegistryResult<String> {
55 {
56 let guard = self.backend.read().await;
57 if guard.is_some() {
58 let name = self
59 .config
60 .as_ref()
61 .map(|c| c.to_string())
62 .unwrap_or_else(|| "unknown".to_string());
63 return Ok(name);
64 }
65 }
66 let config = self.config.clone().ok_or(
67 "MX_METADATA_BACKEND is not set or invalid. Set it to 'redis' or 'kubernetes'.",
68 )?;
69 let mut guard = self.backend.write().await;
70 if guard.is_some() {
71 return Ok(config.to_string());
72 }
73 let backend_name = config.to_string();
74 let backend = create_registry_backend(config).await?;
75 *guard = Some(backend);
76 info!("RegistryManager connected (backend: {})", backend_name);
77 Ok(backend_name)
78 }
79
80 async fn get_backend(&self) -> RegistryResult<Arc<dyn RegistryBackend>> {
81 {
82 let guard = self.backend.read().await;
83 if let Some(backend) = guard.as_ref() {
84 return Ok(backend.clone());
85 }
86 }
87 let mut guard = self.backend.write().await;
88 if let Some(backend) = guard.as_ref() {
89 return Ok(backend.clone());
90 }
91 let config = self.config.clone().ok_or(
92 "MX_METADATA_BACKEND is not set or invalid. Set it to 'redis' or 'kubernetes'.",
93 )?;
94 let backend_name = config.to_string();
97 let backend = create_registry_backend(config).await?;
98 info!("RegistryManager lazily connected ({})", backend_name);
99 *guard = Some(backend.clone());
100 Ok(backend)
101 }
102
103 pub async fn get_status(&self, model_name: &str) -> RegistryResult<Option<ModelStatus>> {
104 self.get_backend().await?.get_status(model_name).await
105 }
106
107 pub async fn get_model_record(&self, model_name: &str) -> RegistryResult<Option<ModelRecord>> {
108 self.get_backend().await?.get_model_record(model_name).await
109 }
110
111 pub async fn set_status(
112 &self,
113 model_name: &str,
114 provider: ModelProvider,
115 status: ModelStatus,
116 message: Option<String>,
117 ) -> RegistryResult<()> {
118 self.get_backend()
119 .await?
120 .set_status(model_name, provider, status, message)
121 .await
122 }
123
124 pub async fn touch_model(&self, model_name: &str) -> RegistryResult<()> {
125 self.get_backend().await?.touch_model(model_name).await
126 }
127
128 pub async fn delete_model(&self, model_name: &str) -> RegistryResult<()> {
129 self.get_backend().await?.delete_model(model_name).await
130 }
131
132 pub async fn get_models_by_last_used(
133 &self,
134 limit: Option<u32>,
135 ) -> RegistryResult<Vec<ModelRecord>> {
136 self.get_backend()
137 .await?
138 .get_models_by_last_used(limit)
139 .await
140 }
141
142 pub async fn get_status_counts(&self) -> RegistryResult<(u32, u32, u32)> {
143 self.get_backend().await?.get_status_counts().await
144 }
145
146 pub async fn try_claim_for_download(
147 &self,
148 model_name: &str,
149 provider: ModelProvider,
150 ) -> RegistryResult<ClaimOutcome> {
151 self.get_backend()
152 .await?
153 .try_claim_for_download(model_name, provider)
154 .await
155 }
156
157 pub async fn try_reset_error_for_retry(
158 &self,
159 model_name: &str,
160 provider: ModelProvider,
161 ) -> RegistryResult<bool> {
162 self.get_backend()
163 .await?
164 .try_reset_error_for_retry(model_name, provider)
165 .await
166 }
167}
168
#[cfg(test)]
// The tests below unwrap results with `expect`, so that lint is relaxed for this module.
#[allow(clippy::expect_used)]
mod tests {
    use super::*;
    use crate::registry::backend::MockRegistryBackend;
    use mockall::predicate::eq;

    /// Wraps a prepared mock in a manager via the test-only `with_backend` constructor.
    fn manager_over(mock: MockRegistryBackend) -> RegistryManager {
        RegistryManager::with_backend(Arc::new(mock))
    }

    /// With neither a backend nor a configuration, `connect` must return an error.
    #[tokio::test]
    async fn connect_fails_when_no_config() {
        let manager = RegistryManager {
            backend: Arc::new(RwLock::new(None)),
            config: None,
        };
        let result = manager.connect().await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn try_claim_delegates_to_backend() {
        let mut backend = MockRegistryBackend::new();
        backend
            .expect_try_claim_for_download()
            .with(eq("m"), eq(ModelProvider::HuggingFace))
            .times(1)
            .returning(|_, _| Ok(ClaimOutcome::Claimed));

        let manager = manager_over(backend);
        let claimed = manager
            .try_claim_for_download("m", ModelProvider::HuggingFace)
            .await
            .expect("claim");
        assert_eq!(claimed, ClaimOutcome::Claimed);
    }

    #[tokio::test]
    async fn try_reset_error_delegates_to_backend() {
        let mut backend = MockRegistryBackend::new();
        backend
            .expect_try_reset_error_for_retry()
            .with(eq("m"), eq(ModelProvider::HuggingFace))
            .times(1)
            .returning(|_, _| Ok(true));

        let manager = manager_over(backend);
        let reset = manager
            .try_reset_error_for_retry("m", ModelProvider::HuggingFace)
            .await
            .expect("retry cas");
        assert!(reset);
    }

    /// An error produced by the backend must reach the caller of `set_status`.
    #[tokio::test]
    async fn set_status_propagates_errors() {
        let mut backend = MockRegistryBackend::new();
        backend
            .expect_set_status()
            .times(1)
            .returning(|_, _, _, _| Err("backend down".into()));

        let manager = manager_over(backend);
        let result = manager
            .set_status("m", ModelProvider::HuggingFace, ModelStatus::ERROR, None)
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn get_models_by_last_used_passes_limit() {
        let mut backend = MockRegistryBackend::new();
        backend
            .expect_get_models_by_last_used()
            .with(eq(Some(3_u32)))
            .times(1)
            .returning(|_| Ok(Vec::new()));

        let manager = manager_over(backend);
        let records = manager
            .get_models_by_last_used(Some(3))
            .await
            .expect("call");
        assert!(records.is_empty());
    }

    #[tokio::test]
    async fn get_status_delegates_to_backend() {
        let mut backend = MockRegistryBackend::new();
        backend
            .expect_get_status()
            .with(eq("m"))
            .times(1)
            .returning(|_| Ok(Some(ModelStatus::DOWNLOADED)));

        let manager = manager_over(backend);
        let status = manager.get_status("m").await.expect("get_status");
        assert_eq!(status, Some(ModelStatus::DOWNLOADED));
    }

    #[tokio::test]
    async fn get_model_record_delegates_to_backend() {
        let mut backend = MockRegistryBackend::new();
        backend
            .expect_get_model_record()
            .with(eq("m"))
            .times(1)
            .returning(|_| Ok(None));

        let manager = manager_over(backend);
        let record = manager
            .get_model_record("m")
            .await
            .expect("get_model_record");
        assert!(record.is_none());
    }

    #[tokio::test]
    async fn touch_model_delegates_to_backend() {
        let mut backend = MockRegistryBackend::new();
        backend
            .expect_touch_model()
            .with(eq("m"))
            .times(1)
            .returning(|_| Ok(()));

        let manager = manager_over(backend);
        manager.touch_model("m").await.expect("touch");
    }

    #[tokio::test]
    async fn delete_model_delegates_to_backend() {
        let mut backend = MockRegistryBackend::new();
        backend
            .expect_delete_model()
            .with(eq("m"))
            .times(1)
            .returning(|_| Ok(()));

        let manager = manager_over(backend);
        manager.delete_model("m").await.expect("delete");
    }

    #[tokio::test]
    async fn get_status_counts_delegates_to_backend() {
        let mut backend = MockRegistryBackend::new();
        backend
            .expect_get_status_counts()
            .times(1)
            .returning(|| Ok((2, 3, 1)));

        let manager = manager_over(backend);
        let counts = manager.get_status_counts().await.expect("counts");
        assert_eq!(counts, (2, 3, 1));
    }

    /// Every argument given to `set_status` must arrive at the backend unchanged.
    #[tokio::test]
    async fn set_status_forwards_all_args() {
        let mut backend = MockRegistryBackend::new();
        backend
            .expect_set_status()
            .with(
                eq("m"),
                eq(ModelProvider::HuggingFace),
                eq(ModelStatus::DOWNLOADED),
                eq(Some("done".to_string())),
            )
            .times(1)
            .returning(|_, _, _, _| Ok(()));

        let manager = manager_over(backend);
        manager
            .set_status(
                "m",
                ModelProvider::HuggingFace,
                ModelStatus::DOWNLOADED,
                Some("done".to_string()),
            )
            .await
            .expect("set_status");
    }

    /// Two consecutive calls are both answered by the single injected backend:
    /// the mock expects exactly two `get_status` invocations.
    #[tokio::test]
    async fn get_backend_caches_first_connection() {
        let mut backend = MockRegistryBackend::new();
        backend
            .expect_get_status()
            .times(2)
            .returning(|_| Ok(None));

        let manager = manager_over(backend);
        let first = manager.get_status("a").await.expect("first");
        assert!(first.is_none());
        let second = manager.get_status("b").await.expect("second");
        assert!(second.is_none());
    }
}