Skip to main content

modelexpress_server/
state.rs

1// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! State management for P2P model metadata.
5//!
6//! `P2pStateManager` wraps a metadata backend (Redis or Kubernetes CRD).
7//! All state — model metadata and source status — is persisted to the backend,
8//! making the server stateless and horizontally scalable.
9
10use crate::metadata_backend::{BackendConfig, MetadataBackend, MetadataResult, create_backend};
11use modelexpress_common::grpc::p2p::{SourceIdentity, WorkerMetadata};
12use std::sync::Arc;
13use tokio::sync::RwLock;
14use tracing::{debug, info};
15
16// Re-export types for backwards compatibility
17pub use crate::metadata_backend::{
18    BackendMetadataRecord, ModelMetadataRecord, TensorRecord, WorkerRecord,
19};
20
21/// State manager that handles P2P metadata operations.
22///
23/// Wraps the metadata backend abstraction and provides a simpler API for
24/// common operations. Configure via `MX_METADATA_BACKEND` env var.
25#[derive(Clone)]
26pub struct P2pStateManager {
27    backend: Arc<RwLock<Option<Arc<dyn MetadataBackend>>>>,
28    config: Option<BackendConfig>,
29}
30
31impl Default for P2pStateManager {
32    fn default() -> Self {
33        Self::new()
34    }
35}
36
37impl P2pStateManager {
38    /// Create a new state manager, resolving backend config from the environment.
39    ///
40    /// Configure via `MX_METADATA_BACKEND` (required) and `REDIS_URL` /
41    /// `MX_REDIS_HOST` / `MX_REDIS_PORT` (for Redis).
42    pub fn new() -> Self {
43        Self {
44            backend: Arc::new(RwLock::new(None)),
45            config: BackendConfig::from_env().ok(),
46        }
47    }
48
49    /// Create a new state manager with an explicit backend configuration.
50    pub fn with_config(config: BackendConfig) -> Self {
51        Self {
52            backend: Arc::new(RwLock::new(None)),
53            config: Some(config),
54        }
55    }
56
57    /// Inject a pre-built backend directly (test only).
58    #[cfg(test)]
59    pub fn with_backend(backend: Arc<dyn MetadataBackend>) -> Self {
60        Self {
61            backend: Arc::new(RwLock::new(Some(backend))),
62            config: None,
63        }
64    }
65
66    /// Initialize the backend connection. Returns the backend type name on success.
67    pub async fn connect(&self) -> MetadataResult<String> {
68        let config = self.config.clone().ok_or(
69            "MX_METADATA_BACKEND is not set or invalid. Set it to 'redis' or 'kubernetes'.",
70        )?;
71
72        let backend_name = config.to_string();
73        let backend = create_backend(config).await?;
74        let mut guard = self.backend.write().await;
75        *guard = Some(backend);
76
77        info!("P2pStateManager connected (backend: {})", backend_name);
78        Ok(backend_name)
79    }
80
81    /// Get the backend, connecting lazily if not yet connected.
82    async fn get_backend(&self) -> MetadataResult<Arc<dyn MetadataBackend>> {
83        {
84            let guard = self.backend.read().await;
85            if let Some(backend) = guard.as_ref() {
86                return Ok(backend.clone());
87            }
88        }
89
90        let mut guard = self.backend.write().await;
91        if let Some(backend) = guard.as_ref() {
92            return Ok(backend.clone());
93        }
94
95        let config = self.config.clone().ok_or(
96            "MX_METADATA_BACKEND is not set or invalid. Set it to 'redis' or 'kubernetes'.",
97        )?;
98
99        let backend = create_backend(config.clone()).await?;
100        info!("P2pStateManager connected with {:?}", config);
101        *guard = Some(backend.clone());
102        Ok(backend)
103    }
104
105    // ========================================================================
106    // Model Metadata
107    // ========================================================================
108
109    /// Publish metadata for a source instance.
110    pub async fn publish_metadata(
111        &self,
112        identity: &SourceIdentity,
113        worker_id: &str,
114        worker: WorkerMetadata,
115    ) -> MetadataResult<()> {
116        self.get_backend()
117            .await?
118            .publish_metadata(identity, worker_id, worker)
119            .await
120    }
121
122    /// Get full tensor metadata for one specific instance.
123    pub async fn get_metadata(
124        &self,
125        source_id: &str,
126        worker_id: &str,
127    ) -> MetadataResult<Option<ModelMetadataRecord>> {
128        self.get_backend()
129            .await?
130            .get_metadata(source_id, worker_id)
131            .await
132    }
133
134    /// List available source instances, optionally filtered by status.
135    pub async fn list_workers(
136        &self,
137        source_id: Option<String>,
138        status_filter: Option<modelexpress_common::grpc::p2p::SourceStatus>,
139    ) -> MetadataResult<Vec<crate::metadata_backend::SourceInstanceInfo>> {
140        self.get_backend()
141            .await?
142            .list_workers(source_id, status_filter)
143            .await
144    }
145
146    /// Remove metadata by mx_source_id.
147    pub async fn remove_metadata(&self, source_id: &str) -> MetadataResult<()> {
148        self.get_backend().await?.remove_metadata(source_id).await
149    }
150
151    /// Remove a single worker by source_id and worker_id.
152    pub async fn remove_worker(&self, source_id: &str, worker_id: &str) -> MetadataResult<()> {
153        self.get_backend()
154            .await?
155            .remove_worker(source_id, worker_id)
156            .await
157    }
158
159    /// List all registered source IDs and model names.
160    pub async fn list_sources(&self) -> MetadataResult<Vec<(String, String)>> {
161        self.get_backend().await?.list_sources().await
162    }
163
164    // ========================================================================
165    // Worker Status
166    // ========================================================================
167
168    /// Update the status of a worker within its stored metadata record.
169    pub async fn update_worker_status(
170        &self,
171        source_id: &str,
172        worker_id: &str,
173        worker_rank: u32,
174        status: modelexpress_common::grpc::p2p::SourceStatus,
175    ) -> MetadataResult<()> {
176        let updated_at = chrono::Utc::now().timestamp_millis();
177        self.get_backend()
178            .await?
179            .update_status(source_id, worker_id, worker_rank, status, updated_at)
180            .await?;
181
182        debug!(
183            "Updated status for source '{}' worker '{}' rank {} -> {}",
184            source_id, worker_id, worker_rank, status as i32
185        );
186        Ok(())
187    }
188}
189
// `expect` is allowed in this module: a failed expectation should abort the
// test immediately with a descriptive message.
#[cfg(test)]
#[allow(clippy::expect_used)]
mod tests {
    use super::*;
    use crate::metadata_backend::MockMetadataBackend;
    use mockall::predicate::eq;
    use modelexpress_common::grpc::p2p::{
        MxSourceType, SourceIdentity, SourceStatus, TensorDescriptor,
    };

    /// Build the `SourceIdentity` fixture shared by the publish tests.
    fn test_identity() -> SourceIdentity {
        SourceIdentity {
            mx_version: "0.3.0".to_string(),
            mx_source_type: MxSourceType::Weights as i32,
            model_name: "my-model".to_string(),
            backend_framework: 1,
            tensor_parallel_size: 8,
            pipeline_parallel_size: 1,
            expert_parallel_size: 0,
            dtype: "bfloat16".to_string(),
            quantization: String::new(),
            extra_parameters: Default::default(),
        }
    }

    // ------------------------------------------------------------------
    // Record conversions (re-exported types)
    // ------------------------------------------------------------------

    #[test]
    fn test_tensor_record_conversion() {
        let desc = TensorDescriptor {
            name: "model.layers.0.weight".to_string(),
            addr: 0x7f0000000000,
            size: 1024 * 1024 * 1024,
            device_id: 0,
            dtype: "bfloat16".to_string(),
        };

        let record = TensorRecord::from(desc.clone());
        assert_eq!(record.name, "model.layers.0.weight");
        assert_eq!(record.size, 1024 * 1024 * 1024);

        let back: TensorDescriptor = record.into();
        assert_eq!(back.name, desc.name);
        assert_eq!(back.addr, desc.addr);
    }

    #[test]
    fn test_worker_record_conversion() {
        use modelexpress_common::grpc::p2p::worker_metadata::BackendMetadata;

        let meta = WorkerMetadata {
            worker_rank: 3,
            backend_metadata: Some(BackendMetadata::NixlMetadata(vec![1, 2, 3, 4, 5])),
            tensors: vec![TensorDescriptor {
                name: "test.weight".to_string(),
                addr: 0x1000,
                size: 4096,
                device_id: 3,
                dtype: "float16".to_string(),
            }],
            status: SourceStatus::Initializing as i32,
            updated_at: 1234567890000,
            ..Default::default()
        };

        let record = WorkerRecord::from(meta.clone());
        assert_eq!(record.worker_rank, 3);
        assert!(matches!(
            &record.backend_metadata,
            BackendMetadataRecord::Nixl(d) if d == &vec![1, 2, 3, 4, 5]
        ));
        assert_eq!(record.tensors.len(), 1);
        assert_eq!(record.status, SourceStatus::Initializing as i32);
        assert_eq!(record.updated_at, 1234567890000);

        let back: WorkerMetadata = record.into();
        assert_eq!(back.worker_rank, meta.worker_rank);
        assert_eq!(back.backend_metadata, meta.backend_metadata);
    }

    #[test]
    fn test_worker_record_transfer_engine_roundtrip() {
        use modelexpress_common::grpc::p2p::worker_metadata::BackendMetadata;

        let meta = WorkerMetadata {
            worker_rank: 1,
            backend_metadata: Some(BackendMetadata::TransferEngineSessionId(
                "192.168.1.10:12345".to_string(),
            )),
            tensors: vec![TensorDescriptor {
                name: "test.weight".to_string(),
                addr: 0x2000,
                size: 8192,
                device_id: 0,
                dtype: "float16".to_string(),
            }],
            status: 0,
            updated_at: 0,
            ..Default::default()
        };

        let record = WorkerRecord::from(meta.clone());
        assert_eq!(record.worker_rank, 1);
        assert!(matches!(
            &record.backend_metadata,
            BackendMetadataRecord::TransferEngine(sid) if sid == "192.168.1.10:12345"
        ));
        assert_eq!(
            record.backend_metadata.backend_type_str(),
            "transfer_engine"
        );

        let back: WorkerMetadata = record.into();
        assert_eq!(back.worker_rank, meta.worker_rank);
        assert_eq!(back.backend_metadata, meta.backend_metadata);
    }

    #[test]
    fn test_backend_metadata_from_flat_with_discriminator() {
        // Explicit backend_type takes precedence
        let te = BackendMetadataRecord::from_flat(
            Vec::new(),
            Some("10.0.0.1:5000".into()),
            Some("transfer_engine"),
        );
        assert!(matches!(te, BackendMetadataRecord::TransferEngine(ref s) if s == "10.0.0.1:5000"));

        let nixl = BackendMetadataRecord::from_flat(vec![1, 2, 3], None, Some("nixl"));
        assert!(matches!(nixl, BackendMetadataRecord::Nixl(ref d) if d == &vec![1, 2, 3]));

        let none = BackendMetadataRecord::from_flat(Vec::new(), None, Some("none"));
        assert!(matches!(none, BackendMetadataRecord::None));

        // Backwards compat: missing backend_type infers from fields
        let inferred_te =
            BackendMetadataRecord::from_flat(Vec::new(), Some("10.0.0.1:5000".into()), None);
        assert!(matches!(
            inferred_te,
            BackendMetadataRecord::TransferEngine(_)
        ));

        let inferred_nixl = BackendMetadataRecord::from_flat(vec![1, 2], None, None);
        assert!(matches!(inferred_nixl, BackendMetadataRecord::Nixl(_)));

        let inferred_none = BackendMetadataRecord::from_flat(Vec::new(), None, None);
        assert!(matches!(inferred_none, BackendMetadataRecord::None));
    }

    #[test]
    fn test_model_record_creation() {
        let record = ModelMetadataRecord {
            source_id: "abc123def456abcd".to_string(),
            worker_id: "test-instance-id".to_string(),
            model_name: "meta-llama/Llama-3.1-70B".to_string(),
            workers: vec![
                WorkerRecord {
                    worker_rank: 0,
                    backend_metadata: BackendMetadataRecord::Nixl(vec![10, 20, 30]),
                    tensors: vec![TensorRecord {
                        name: "layer.0.weight".to_string(),
                        addr: 0x7f00_0000_0000,
                        size: 1_000_000,
                        device_id: 0,
                        dtype: "bfloat16".to_string(),
                    }],
                    status: SourceStatus::Ready as i32,
                    updated_at: 1234567890000,
                    metadata_endpoint: String::new(),
                    agent_name: String::new(),
                    worker_grpc_endpoint: String::new(),
                },
                WorkerRecord {
                    worker_rank: 1,
                    backend_metadata: BackendMetadataRecord::Nixl(vec![40, 50, 60]),
                    tensors: vec![TensorRecord {
                        name: "layer.0.weight".to_string(),
                        addr: 0x7f00_0000_0000,
                        size: 1_000_000,
                        device_id: 1,
                        dtype: "bfloat16".to_string(),
                    }],
                    status: SourceStatus::Ready as i32,
                    updated_at: 1234567890000,
                    metadata_endpoint: String::new(),
                    agent_name: String::new(),
                    worker_grpc_endpoint: String::new(),
                },
            ],
            published_at: 1234567890,
        };

        assert_eq!(record.model_name, "meta-llama/Llama-3.1-70B");
        assert_eq!(record.workers.len(), 2);
        assert_eq!(record.workers[0].worker_rank, 0);
        assert_eq!(record.workers[1].worker_rank, 1);
    }

    #[test]
    fn test_backend_config_parsing() {
        let default_redis = "redis://localhost:6379";
        let default_ns = "default";

        // Redis
        let config =
            BackendConfig::from_type_str("redis", "redis://myhost:6379", default_ns).expect("ok");
        assert!(matches!(config, BackendConfig::Redis { url } if url == "redis://myhost:6379"));

        // Kubernetes aliases
        for alias in &["kubernetes", "k8s", "crd"] {
            let config = BackendConfig::from_type_str(alias, default_redis, "prod-ns").expect("ok");
            assert!(
                matches!(config, BackendConfig::Kubernetes { namespace } if namespace == "prod-ns")
            );
        }

        // Unknown returns Err
        assert!(BackendConfig::from_type_str("bogus", default_redis, default_ns).is_err());
        assert!(BackendConfig::from_type_str("memory", default_redis, default_ns).is_err());
        assert!(BackendConfig::from_type_str("", default_redis, default_ns).is_err());

        // Case insensitive
        let config = BackendConfig::from_type_str("REDIS", default_redis, default_ns).expect("ok");
        assert!(matches!(config, BackendConfig::Redis { .. }));
    }

    // ------------------------------------------------------------------
    // P2pStateManager delegation (mock backend injected via `with_backend`)
    // ------------------------------------------------------------------

    #[tokio::test]
    async fn test_publish_metadata_calls_backend() {
        let mut mock = MockMetadataBackend::new();
        mock.expect_publish_metadata()
            .withf(|identity, worker_id, worker| {
                identity.model_name == "my-model"
                    && identity.tensor_parallel_size == 8
                    && worker_id == "a1b2c3d4"
                    && worker.worker_rank == 3
            })
            .once()
            .returning(|_, _, _| Ok(()));

        let manager = P2pStateManager::with_backend(Arc::new(mock));
        manager
            .publish_metadata(
                &test_identity(),
                "a1b2c3d4",
                WorkerMetadata {
                    worker_rank: 3,
                    backend_metadata: None,
                    tensors: vec![],
                    status: SourceStatus::Initializing as i32,
                    updated_at: 0,
                    ..Default::default()
                },
            )
            .await
            .expect("publish_metadata failed");
    }

    #[tokio::test]
    async fn test_publish_metadata_propagates_backend_error() {
        let mut mock = MockMetadataBackend::new();
        mock.expect_publish_metadata()
            .once()
            .returning(|_, _, _| Err("storage unavailable".into()));

        let manager = P2pStateManager::with_backend(Arc::new(mock));
        assert!(
            manager
                .publish_metadata(&test_identity(), "a1b2c3d4", WorkerMetadata::default())
                .await
                .is_err()
        );
    }

    #[tokio::test]
    async fn test_connect_fails_without_config() {
        let manager = P2pStateManager {
            backend: Arc::new(RwLock::new(None)),
            config: None,
        };
        assert!(manager.connect().await.is_err());
    }

    /// With neither a backend nor a configuration, the lazy connection path
    /// must report an error instead of reaching a backend.
    #[tokio::test]
    async fn test_operations_fail_without_config() {
        let manager = P2pStateManager {
            backend: Arc::new(RwLock::new(None)),
            config: None,
        };
        assert!(manager.list_sources().await.is_err());
        assert!(manager.remove_metadata("abc123def456abcd").await.is_err());
    }

    #[tokio::test]
    async fn test_get_metadata_calls_backend() {
        let mut mock = MockMetadataBackend::new();
        mock.expect_get_metadata()
            .with(eq("abc123def456abcd"), eq("w1"))
            .once()
            .returning(|_, _| Ok(None));

        let manager = P2pStateManager::with_backend(Arc::new(mock));
        let result = manager
            .get_metadata("abc123def456abcd", "w1")
            .await
            .expect("get_metadata failed");
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn test_get_metadata_propagates_backend_error() {
        let mut mock = MockMetadataBackend::new();
        mock.expect_get_metadata()
            .once()
            .returning(|_, _| Err("lookup failed".into()));

        let manager = P2pStateManager::with_backend(Arc::new(mock));
        assert!(
            manager
                .get_metadata("abc123def456abcd", "w1")
                .await
                .is_err()
        );
    }

    #[tokio::test]
    async fn test_update_worker_status_calls_backend() {
        let mut mock = MockMetadataBackend::new();
        mock.expect_update_status()
            .with(
                eq("abc123def456abcd"),
                eq("test-instance"),
                eq(2u32),
                eq(SourceStatus::Ready),
                mockall::predicate::always(),
            )
            .once()
            .returning(|_, _, _, _, _| Ok(()));

        let manager = P2pStateManager::with_backend(Arc::new(mock));
        manager
            .update_worker_status("abc123def456abcd", "test-instance", 2, SourceStatus::Ready)
            .await
            .expect("update_worker_status failed");
    }

    #[tokio::test]
    async fn test_update_worker_status_propagates_backend_error() {
        let mut mock = MockMetadataBackend::new();
        mock.expect_update_status()
            .once()
            .returning(|_, _, _, _, _| Err("redis unavailable".into()));

        let manager = P2pStateManager::with_backend(Arc::new(mock));
        assert!(
            manager
                .update_worker_status("abc123def456abcd", "test-instance", 0, SourceStatus::Ready)
                .await
                .is_err()
        );
    }

    #[tokio::test]
    async fn test_list_workers_calls_backend() {
        let mut mock = MockMetadataBackend::new();
        mock.expect_list_workers()
            .withf(|source_id, status_filter| {
                source_id.as_deref() == Some("abc123def456abcd")
                    && *status_filter == Some(SourceStatus::Ready)
            })
            .once()
            .returning(|_, _| {
                Ok(vec![crate::metadata_backend::SourceInstanceInfo {
                    source_id: "abc123def456abcd".to_string(),
                    worker_id: "w1".to_string(),
                    model_name: "my-model".to_string(),
                    worker_rank: 0,
                    status: SourceStatus::Ready as i32,
                    updated_at: 1234567890000,
                }])
            });

        let manager = P2pStateManager::with_backend(Arc::new(mock));
        let result = manager
            .list_workers(
                Some("abc123def456abcd".to_string()),
                Some(SourceStatus::Ready),
            )
            .await
            .expect("list_workers failed");
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].worker_id, "w1");
    }

    #[tokio::test]
    async fn test_list_workers_propagates_backend_error() {
        let mut mock = MockMetadataBackend::new();
        mock.expect_list_workers()
            .once()
            .returning(|_, _| Err("backend error".into()));

        let manager = P2pStateManager::with_backend(Arc::new(mock));
        assert!(manager.list_workers(None, None).await.is_err());
    }

    #[tokio::test]
    async fn test_remove_metadata_calls_backend() {
        let mut mock = MockMetadataBackend::new();
        mock.expect_remove_metadata()
            .with(eq("abc123def456abcd"))
            .once()
            .returning(|_| Ok(()));

        let manager = P2pStateManager::with_backend(Arc::new(mock));
        manager
            .remove_metadata("abc123def456abcd")
            .await
            .expect("remove_metadata failed");
    }

    #[tokio::test]
    async fn test_remove_metadata_propagates_backend_error() {
        let mut mock = MockMetadataBackend::new();
        mock.expect_remove_metadata()
            .once()
            .returning(|_| Err("delete failed".into()));

        let manager = P2pStateManager::with_backend(Arc::new(mock));
        assert!(manager.remove_metadata("abc123def456abcd").await.is_err());
    }

    #[tokio::test]
    async fn test_remove_worker_calls_backend() {
        let mut mock = MockMetadataBackend::new();
        mock.expect_remove_worker()
            .with(eq("abc123def456abcd"), eq("w1"))
            .once()
            .returning(|_, _| Ok(()));

        let manager = P2pStateManager::with_backend(Arc::new(mock));
        manager
            .remove_worker("abc123def456abcd", "w1")
            .await
            .expect("remove_worker failed");
    }

    #[tokio::test]
    async fn test_list_sources_calls_backend() {
        let mut mock = MockMetadataBackend::new();
        mock.expect_list_sources()
            .once()
            .returning(|| Ok(vec![("src1".to_string(), "model-a".to_string())]));

        let manager = P2pStateManager::with_backend(Arc::new(mock));
        let result = manager.list_sources().await.expect("list_sources failed");
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].0, "src1");
        assert_eq!(result[0].1, "model-a");
    }

    #[tokio::test]
    async fn test_update_worker_status_stores_correct_status() {
        let mut mock = MockMetadataBackend::new();
        mock.expect_update_status()
            .withf(|source_id, worker_id, worker_rank, status, _updated_at| {
                source_id == "abc123def456abcd"
                    && worker_id == "test-instance"
                    && *worker_rank == 7
                    && *status == SourceStatus::Ready
            })
            .once()
            .returning(|_, _, _, _, _| Ok(()));

        let manager = P2pStateManager::with_backend(Arc::new(mock));
        manager
            .update_worker_status("abc123def456abcd", "test-instance", 7, SourceStatus::Ready)
            .await
            .expect("update_worker_status failed");
    }
}