Skip to main content

modelexpress_server/registry/
state.rs

1// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! Lazy-connect wrapper around `RegistryBackend`, parallel to [`crate::p2p::state`].
5
6use crate::backend_config::BackendConfig;
7use crate::registry::backend::{
8    ClaimOutcome, ModelRecord, RegistryBackend, RegistryResult, create_registry_backend,
9};
10use modelexpress_common::models::{ModelProvider, ModelStatus};
11use std::sync::Arc;
12use tokio::sync::RwLock;
13use tracing::info;
14
15#[derive(Clone)]
16pub struct RegistryManager {
17    backend: Arc<RwLock<Option<Arc<dyn RegistryBackend>>>>,
18    config: Option<BackendConfig>,
19}
20
21impl Default for RegistryManager {
22    fn default() -> Self {
23        Self::new()
24    }
25}
26
27impl RegistryManager {
28    pub fn new() -> Self {
29        Self {
30            backend: Arc::new(RwLock::new(None)),
31            config: BackendConfig::from_env().ok(),
32        }
33    }
34
35    pub fn with_config(config: BackendConfig) -> Self {
36        Self {
37            backend: Arc::new(RwLock::new(None)),
38            config: Some(config),
39        }
40    }
41
42    /// Inject a pre-built backend directly (tests only).
43    #[cfg(test)]
44    pub fn with_backend(backend: Arc<dyn RegistryBackend>) -> Self {
45        Self {
46            backend: Arc::new(RwLock::new(Some(backend))),
47            config: None,
48        }
49    }
50
51    /// Eagerly connect to the configured backend. Returns the backend type name.
52    /// Idempotent: if a backend is already cached, return its name without
53    /// re-creating it.
54    pub async fn connect(&self) -> RegistryResult<String> {
55        {
56            let guard = self.backend.read().await;
57            if guard.is_some() {
58                let name = self
59                    .config
60                    .as_ref()
61                    .map(|c| c.to_string())
62                    .unwrap_or_else(|| "unknown".to_string());
63                return Ok(name);
64            }
65        }
66        let config = self.config.clone().ok_or(
67            "MX_METADATA_BACKEND is not set or invalid. Set it to 'redis' or 'kubernetes'.",
68        )?;
69        let mut guard = self.backend.write().await;
70        if guard.is_some() {
71            return Ok(config.to_string());
72        }
73        let backend_name = config.to_string();
74        let backend = create_registry_backend(config).await?;
75        *guard = Some(backend);
76        info!("RegistryManager connected (backend: {})", backend_name);
77        Ok(backend_name)
78    }
79
80    async fn get_backend(&self) -> RegistryResult<Arc<dyn RegistryBackend>> {
81        {
82            let guard = self.backend.read().await;
83            if let Some(backend) = guard.as_ref() {
84                return Ok(backend.clone());
85            }
86        }
87        let mut guard = self.backend.write().await;
88        if let Some(backend) = guard.as_ref() {
89            return Ok(backend.clone());
90        }
91        let config = self.config.clone().ok_or(
92            "MX_METADATA_BACKEND is not set or invalid. Set it to 'redis' or 'kubernetes'.",
93        )?;
94        // Use Display (redacts connection URLs for Redis) — Debug would print the full
95        // `BackendConfig` including the unredacted URL.
96        let backend_name = config.to_string();
97        let backend = create_registry_backend(config).await?;
98        info!("RegistryManager lazily connected ({})", backend_name);
99        *guard = Some(backend.clone());
100        Ok(backend)
101    }
102
103    pub async fn get_status(&self, model_name: &str) -> RegistryResult<Option<ModelStatus>> {
104        self.get_backend().await?.get_status(model_name).await
105    }
106
107    pub async fn get_model_record(&self, model_name: &str) -> RegistryResult<Option<ModelRecord>> {
108        self.get_backend().await?.get_model_record(model_name).await
109    }
110
111    pub async fn set_status(
112        &self,
113        model_name: &str,
114        provider: ModelProvider,
115        status: ModelStatus,
116        message: Option<String>,
117    ) -> RegistryResult<()> {
118        self.get_backend()
119            .await?
120            .set_status(model_name, provider, status, message)
121            .await
122    }
123
124    pub async fn touch_model(&self, model_name: &str) -> RegistryResult<()> {
125        self.get_backend().await?.touch_model(model_name).await
126    }
127
128    pub async fn delete_model(&self, model_name: &str) -> RegistryResult<()> {
129        self.get_backend().await?.delete_model(model_name).await
130    }
131
132    pub async fn get_models_by_last_used(
133        &self,
134        limit: Option<u32>,
135    ) -> RegistryResult<Vec<ModelRecord>> {
136        self.get_backend()
137            .await?
138            .get_models_by_last_used(limit)
139            .await
140    }
141
142    pub async fn get_status_counts(&self) -> RegistryResult<(u32, u32, u32)> {
143        self.get_backend().await?.get_status_counts().await
144    }
145
146    pub async fn try_claim_for_download(
147        &self,
148        model_name: &str,
149        provider: ModelProvider,
150    ) -> RegistryResult<ClaimOutcome> {
151        self.get_backend()
152            .await?
153            .try_claim_for_download(model_name, provider)
154            .await
155    }
156
157    pub async fn try_reset_error_for_retry(
158        &self,
159        model_name: &str,
160        provider: ModelProvider,
161    ) -> RegistryResult<bool> {
162        self.get_backend()
163            .await?
164            .try_reset_error_for_retry(model_name, provider)
165            .await
166    }
167}
168
#[cfg(test)]
#[allow(clippy::expect_used)]
mod tests {
    use super::*;
    use crate::registry::backend::MockRegistryBackend;
    use mockall::predicate::eq;

    /// Wrap a prepared mock in a manager so a test exercises the pass-through path
    /// without any real backend being created.
    fn manager_over(mock: MockRegistryBackend) -> RegistryManager {
        RegistryManager::with_backend(Arc::new(mock))
    }

    /// With neither a backend nor a configuration, `connect` must report an error.
    #[tokio::test]
    async fn connect_fails_when_no_config() {
        let manager = RegistryManager {
            backend: Arc::new(RwLock::new(None)),
            config: None,
        };
        let result = manager.connect().await;
        assert!(result.is_err());
    }

    /// `try_claim_for_download` reaches the backend once with the same arguments.
    #[tokio::test]
    async fn try_claim_delegates_to_backend() {
        let mut backend = MockRegistryBackend::new();
        backend
            .expect_try_claim_for_download()
            .with(eq("m"), eq(ModelProvider::HuggingFace))
            .once()
            .returning(|_, _| Ok(ClaimOutcome::Claimed));
        let manager = manager_over(backend);

        let claim = manager
            .try_claim_for_download("m", ModelProvider::HuggingFace)
            .await
            .expect("claim");

        assert_eq!(claim, ClaimOutcome::Claimed);
    }

    /// `try_reset_error_for_retry` reaches the backend once and its result is returned.
    #[tokio::test]
    async fn try_reset_error_delegates_to_backend() {
        let mut backend = MockRegistryBackend::new();
        backend
            .expect_try_reset_error_for_retry()
            .with(eq("m"), eq(ModelProvider::HuggingFace))
            .once()
            .returning(|_, _| Ok(true));
        let manager = manager_over(backend);

        let reset_won = manager
            .try_reset_error_for_retry("m", ModelProvider::HuggingFace)
            .await
            .expect("retry cas");

        assert!(reset_won);
    }

    /// A backend failure in `set_status` is handed back to the caller.
    #[tokio::test]
    async fn set_status_propagates_errors() {
        let mut backend = MockRegistryBackend::new();
        backend
            .expect_set_status()
            .once()
            .returning(|_, _, _, _| Err("backend down".into()));
        let manager = manager_over(backend);

        let result = manager
            .set_status("m", ModelProvider::HuggingFace, ModelStatus::ERROR, None)
            .await;

        assert!(result.is_err());
    }

    /// The `limit` argument is forwarded unchanged.
    #[tokio::test]
    async fn get_models_by_last_used_passes_limit() {
        let mut backend = MockRegistryBackend::new();
        backend
            .expect_get_models_by_last_used()
            .with(eq(Some(3_u32)))
            .once()
            .returning(|_| Ok(Vec::new()));
        let manager = manager_over(backend);

        let records = manager.get_models_by_last_used(Some(3)).await.expect("call");

        assert!(records.is_empty());
    }

    /// `get_status` reaches the backend once and its value is returned.
    #[tokio::test]
    async fn get_status_delegates_to_backend() {
        let mut backend = MockRegistryBackend::new();
        backend
            .expect_get_status()
            .with(eq("m"))
            .once()
            .returning(|_| Ok(Some(ModelStatus::DOWNLOADED)));
        let manager = manager_over(backend);

        let status = manager.get_status("m").await.expect("get_status");

        assert_eq!(status, Some(ModelStatus::DOWNLOADED));
    }

    /// `get_model_record` reaches the backend once and a `None` result is returned as-is.
    #[tokio::test]
    async fn get_model_record_delegates_to_backend() {
        let mut backend = MockRegistryBackend::new();
        backend
            .expect_get_model_record()
            .with(eq("m"))
            .once()
            .returning(|_| Ok(None));
        let manager = manager_over(backend);

        let record = manager
            .get_model_record("m")
            .await
            .expect("get_model_record");

        assert!(record.is_none());
    }

    /// `touch_model` reaches the backend once with the model name.
    #[tokio::test]
    async fn touch_model_delegates_to_backend() {
        let mut backend = MockRegistryBackend::new();
        backend
            .expect_touch_model()
            .with(eq("m"))
            .once()
            .returning(|_| Ok(()));
        let manager = manager_over(backend);

        manager.touch_model("m").await.expect("touch");
    }

    /// `delete_model` reaches the backend once with the model name.
    #[tokio::test]
    async fn delete_model_delegates_to_backend() {
        let mut backend = MockRegistryBackend::new();
        backend
            .expect_delete_model()
            .with(eq("m"))
            .once()
            .returning(|_| Ok(()));
        let manager = manager_over(backend);

        manager.delete_model("m").await.expect("delete");
    }

    /// `get_status_counts` returns the backend's tuple unchanged.
    #[tokio::test]
    async fn get_status_counts_delegates_to_backend() {
        let mut backend = MockRegistryBackend::new();
        backend
            .expect_get_status_counts()
            .once()
            .returning(|| Ok((2, 3, 1)));
        let manager = manager_over(backend);

        let counts = manager.get_status_counts().await.expect("counts");

        assert_eq!(counts, (2, 3, 1));
    }

    /// Every `set_status` argument, including the optional message, is forwarded.
    #[tokio::test]
    async fn set_status_forwards_all_args() {
        let mut backend = MockRegistryBackend::new();
        backend
            .expect_set_status()
            .with(
                eq("m"),
                eq(ModelProvider::HuggingFace),
                eq(ModelStatus::DOWNLOADED),
                eq(Some("done".to_string())),
            )
            .once()
            .returning(|_, _, _, _| Ok(()));
        let manager = manager_over(backend);

        let message = Some("done".to_string());
        manager
            .set_status(
                "m",
                ModelProvider::HuggingFace,
                ModelStatus::DOWNLOADED,
                message,
            )
            .await
            .expect("set_status");
    }

    /// Two consecutive calls through one manager must both land on the injected
    /// backend: the cached `Arc` is returned each time instead of a new backend being
    /// created (which would be impossible here anyway, since there is no configuration).
    #[tokio::test]
    async fn get_backend_caches_first_connection() {
        let mut backend = MockRegistryBackend::new();
        backend.expect_get_status().times(2).returning(|_| Ok(None));
        let manager = manager_over(backend);

        let first = manager.get_status("a").await.expect("first");
        assert!(first.is_none());

        let second = manager.get_status("b").await.expect("second");
        assert!(second.is_none());
    }
}