Skip to main content

modelexpress_server/p2p/
state.rs

1// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! State management for P2P model metadata.
5//!
6//! `P2pStateManager` wraps a metadata backend (Redis or Kubernetes CRD).
7//! All state — model metadata and source status — is persisted to the backend,
8//! making the server stateless and horizontally scalable.
9
10use crate::p2p::backend::{BackendConfig, MetadataBackend, MetadataResult, create_backend};
11use modelexpress_common::grpc::p2p::{SourceIdentity, WorkerMetadata};
12use std::sync::Arc;
13use tokio::sync::RwLock;
14use tracing::{debug, info};
15
16// Re-export types for backwards compatibility
17pub use crate::p2p::backend::{
18    BackendMetadataRecord, ModelMetadataRecord, TensorRecord, WorkerRecord,
19};
20
/// State manager that handles P2P metadata operations.
///
/// Wraps the metadata backend abstraction and provides a simpler API for
/// common operations. Configure via `MX_METADATA_BACKEND` env var.
///
/// Cloning is cheap: the backend slot lives behind an `Arc`, so every clone
/// observes the same connected backend.
#[derive(Clone)]
pub struct P2pStateManager {
    // Shared slot for the connected backend. It starts as `None` (unless a
    // backend is injected in tests) and is filled by `connect()` or lazily by
    // the first operation that needs a backend.
    backend: Arc<RwLock<Option<Arc<dyn MetadataBackend>>>>,
    // Backend configuration used to create the backend. `None` when `new()`
    // could not resolve one from the environment, or when a backend was
    // injected directly.
    config: Option<BackendConfig>,
}
30
31impl Default for P2pStateManager {
32    fn default() -> Self {
33        Self::new()
34    }
35}
36
37impl P2pStateManager {
38    /// Create a new state manager, resolving backend config from the environment.
39    ///
40    /// Configure via `MX_METADATA_BACKEND` (required) and `REDIS_URL` /
41    /// `MX_REDIS_HOST` / `MX_REDIS_PORT` (for Redis).
42    pub fn new() -> Self {
43        Self {
44            backend: Arc::new(RwLock::new(None)),
45            config: BackendConfig::from_env().ok(),
46        }
47    }
48
49    /// Create a new state manager with an explicit backend configuration.
50    pub fn with_config(config: BackendConfig) -> Self {
51        Self {
52            backend: Arc::new(RwLock::new(None)),
53            config: Some(config),
54        }
55    }
56
57    /// Inject a pre-built backend directly (test only).
58    #[cfg(test)]
59    pub fn with_backend(backend: Arc<dyn MetadataBackend>) -> Self {
60        Self {
61            backend: Arc::new(RwLock::new(Some(backend))),
62            config: None,
63        }
64    }
65
66    /// Initialize the backend connection. Returns the backend type name on success.
67    pub async fn connect(&self) -> MetadataResult<String> {
68        let config = self.config.clone().ok_or(
69            "MX_METADATA_BACKEND is not set or invalid. Set it to 'redis' or 'kubernetes'.",
70        )?;
71
72        let backend_name = config.to_string();
73        let backend = create_backend(config).await?;
74        let mut guard = self.backend.write().await;
75        *guard = Some(backend);
76
77        info!("P2pStateManager connected (backend: {})", backend_name);
78        Ok(backend_name)
79    }
80
81    /// Get the backend, connecting lazily if not yet connected.
82    async fn get_backend(&self) -> MetadataResult<Arc<dyn MetadataBackend>> {
83        {
84            let guard = self.backend.read().await;
85            if let Some(backend) = guard.as_ref() {
86                return Ok(backend.clone());
87            }
88        }
89
90        let mut guard = self.backend.write().await;
91        if let Some(backend) = guard.as_ref() {
92            return Ok(backend.clone());
93        }
94
95        let config = self.config.clone().ok_or(
96            "MX_METADATA_BACKEND is not set or invalid. Set it to 'redis' or 'kubernetes'.",
97        )?;
98
99        let backend = create_backend(config.clone()).await?;
100        info!("P2pStateManager connected with {:?}", config);
101        *guard = Some(backend.clone());
102        Ok(backend)
103    }
104
105    // ========================================================================
106    // Model Metadata
107    // ========================================================================
108
109    /// Publish metadata for a source instance.
110    pub async fn publish_metadata(
111        &self,
112        identity: &SourceIdentity,
113        worker_id: &str,
114        worker: WorkerMetadata,
115    ) -> MetadataResult<()> {
116        self.get_backend()
117            .await?
118            .publish_metadata(identity, worker_id, worker)
119            .await
120    }
121
122    /// Get full tensor metadata for one specific instance.
123    pub async fn get_metadata(
124        &self,
125        source_id: &str,
126        worker_id: &str,
127    ) -> MetadataResult<Option<ModelMetadataRecord>> {
128        self.get_backend()
129            .await?
130            .get_metadata(source_id, worker_id)
131            .await
132    }
133
134    /// List available source instances, optionally filtered by status.
135    pub async fn list_workers(
136        &self,
137        source_id: Option<String>,
138        status_filter: Option<modelexpress_common::grpc::p2p::SourceStatus>,
139    ) -> MetadataResult<Vec<crate::p2p::backend::SourceInstanceInfo>> {
140        self.get_backend()
141            .await?
142            .list_workers(source_id, status_filter)
143            .await
144    }
145
146    /// Remove metadata by mx_source_id.
147    pub async fn remove_metadata(&self, source_id: &str) -> MetadataResult<()> {
148        self.get_backend().await?.remove_metadata(source_id).await
149    }
150
151    /// Remove a single worker by source_id and worker_id.
152    pub async fn remove_worker(&self, source_id: &str, worker_id: &str) -> MetadataResult<()> {
153        self.get_backend()
154            .await?
155            .remove_worker(source_id, worker_id)
156            .await
157    }
158
159    /// List all registered source IDs and model names.
160    pub async fn list_sources(&self) -> MetadataResult<Vec<(String, String)>> {
161        self.get_backend().await?.list_sources().await
162    }
163
164    // ========================================================================
165    // Worker Status
166    // ========================================================================
167
168    /// Update the status of a worker within its stored metadata record.
169    pub async fn update_worker_status(
170        &self,
171        source_id: &str,
172        worker_id: &str,
173        worker_rank: u32,
174        status: modelexpress_common::grpc::p2p::SourceStatus,
175    ) -> MetadataResult<()> {
176        let updated_at = chrono::Utc::now().timestamp_millis();
177        self.get_backend()
178            .await?
179            .update_status(source_id, worker_id, worker_rank, status, updated_at)
180            .await?;
181
182        debug!(
183            "Updated status for source '{}' worker '{}' rank {} -> {}",
184            source_id, worker_id, worker_rank, status as i32
185        );
186        Ok(())
187    }
188}
189
#[cfg(test)]
#[allow(clippy::expect_used)]
mod tests {
    use super::*;
    use crate::p2p::backend::MockMetadataBackend;
    use mockall::predicate::eq;
    use modelexpress_common::grpc::p2p::{
        MxSourceType, SourceIdentity, SourceStatus, TensorDescriptor,
    };

    /// Source id used by the backend-delegation tests.
    const SOURCE_ID: &str = "abc123def456abcd";
    /// Worker id used by the status-update tests.
    const STATUS_WORKER_ID: &str = "test-instance";
    /// Worker id used by the publish tests.
    const PUBLISH_WORKER_ID: &str = "a1b2c3d4";

    /// Wrap a configured mock in a manager, bypassing config and connection.
    fn manager_over(mock: MockMetadataBackend) -> P2pStateManager {
        P2pStateManager::with_backend(Arc::new(mock))
    }

    /// Identity fixture: an 8-way tensor-parallel weights source for "my-model".
    fn test_identity() -> SourceIdentity {
        SourceIdentity {
            model_name: "my-model".to_string(),
            mx_version: "0.4.0".to_string(),
            mx_source_type: MxSourceType::Weights as i32,
            backend_framework: 1,
            tensor_parallel_size: 8,
            pipeline_parallel_size: 1,
            expert_parallel_size: 0,
            dtype: "bfloat16".to_string(),
            quantization: String::new(),
            revision: String::new(),
            extra_parameters: Default::default(),
        }
    }

    // ------------------------------------------------------------------
    // Record conversions
    // ------------------------------------------------------------------

    #[test]
    fn test_tensor_record_conversion() {
        let original = TensorDescriptor {
            name: "model.layers.0.weight".to_string(),
            addr: 0x7f0000000000,
            size: 1024 * 1024 * 1024,
            device_id: 0,
            dtype: "bfloat16".to_string(),
        };

        // Descriptor -> record keeps name and size.
        let stored = TensorRecord::from(original.clone());
        assert_eq!(stored.name, "model.layers.0.weight");
        assert_eq!(stored.size, 1024 * 1024 * 1024);

        // Record -> descriptor restores name and address.
        let restored: TensorDescriptor = stored.into();
        assert_eq!(restored.name, original.name);
        assert_eq!(restored.addr, original.addr);
    }

    #[test]
    fn test_worker_record_conversion() {
        use modelexpress_common::grpc::p2p::worker_metadata::BackendMetadata;

        let tensor = TensorDescriptor {
            name: "test.weight".to_string(),
            addr: 0x1000,
            size: 4096,
            device_id: 3,
            dtype: "float16".to_string(),
        };
        let original = WorkerMetadata {
            worker_rank: 3,
            backend_metadata: Some(BackendMetadata::NixlMetadata(vec![1, 2, 3, 4, 5])),
            tensors: vec![tensor],
            status: SourceStatus::Initializing as i32,
            updated_at: 1234567890000,
            ..Default::default()
        };

        let stored = WorkerRecord::from(original.clone());
        assert_eq!(stored.worker_rank, 3);
        assert!(matches!(
            &stored.backend_metadata,
            BackendMetadataRecord::Nixl(bytes) if bytes == &vec![1, 2, 3, 4, 5]
        ));
        assert_eq!(stored.tensors.len(), 1);
        assert_eq!(stored.status, SourceStatus::Initializing as i32);
        assert_eq!(stored.updated_at, 1234567890000);

        let restored: WorkerMetadata = stored.into();
        assert_eq!(restored.worker_rank, original.worker_rank);
        assert_eq!(restored.backend_metadata, original.backend_metadata);
    }

    #[test]
    fn test_worker_record_transfer_engine_roundtrip() {
        use modelexpress_common::grpc::p2p::worker_metadata::BackendMetadata;

        let tensor = TensorDescriptor {
            name: "test.weight".to_string(),
            addr: 0x2000,
            size: 8192,
            device_id: 0,
            dtype: "float16".to_string(),
        };
        let original = WorkerMetadata {
            worker_rank: 1,
            backend_metadata: Some(BackendMetadata::TransferEngineSessionId(
                "192.168.1.10:12345".to_string(),
            )),
            tensors: vec![tensor],
            status: 0,
            updated_at: 0,
            ..Default::default()
        };

        let stored = WorkerRecord::from(original.clone());
        assert_eq!(stored.worker_rank, 1);
        assert!(matches!(
            &stored.backend_metadata,
            BackendMetadataRecord::TransferEngine(session) if session == "192.168.1.10:12345"
        ));
        assert_eq!(
            stored.backend_metadata.backend_type_str(),
            "transfer_engine"
        );

        let restored: WorkerMetadata = stored.into();
        assert_eq!(restored.worker_rank, original.worker_rank);
        assert_eq!(restored.backend_metadata, original.backend_metadata);
    }

    #[test]
    fn test_backend_metadata_from_flat_with_discriminator() {
        let session = "10.0.0.1:5000";

        // With an explicit backend_type, the discriminator decides the variant.
        let explicit_te = BackendMetadataRecord::from_flat(
            Vec::new(),
            Some(session.into()),
            Some("transfer_engine"),
        );
        assert!(matches!(
            explicit_te,
            BackendMetadataRecord::TransferEngine(ref s) if s == session
        ));

        let explicit_nixl = BackendMetadataRecord::from_flat(vec![1, 2, 3], None, Some("nixl"));
        assert!(matches!(
            explicit_nixl,
            BackendMetadataRecord::Nixl(ref bytes) if bytes == &vec![1, 2, 3]
        ));

        let explicit_none = BackendMetadataRecord::from_flat(Vec::new(), None, Some("none"));
        assert!(matches!(explicit_none, BackendMetadataRecord::None));

        // Without a backend_type (older records), the variant is inferred from
        // whichever field is populated.
        let inferred_te = BackendMetadataRecord::from_flat(Vec::new(), Some(session.into()), None);
        assert!(matches!(
            inferred_te,
            BackendMetadataRecord::TransferEngine(_)
        ));

        let inferred_nixl = BackendMetadataRecord::from_flat(vec![1, 2], None, None);
        assert!(matches!(inferred_nixl, BackendMetadataRecord::Nixl(_)));

        let inferred_none = BackendMetadataRecord::from_flat(Vec::new(), None, None);
        assert!(matches!(inferred_none, BackendMetadataRecord::None));
    }

    #[test]
    fn test_model_record_creation() {
        let rank0 = WorkerRecord {
            worker_rank: 0,
            backend_metadata: BackendMetadataRecord::Nixl(vec![10, 20, 30]),
            tensors: vec![TensorRecord {
                name: "layer.0.weight".to_string(),
                addr: 0x7f00_0000_0000,
                size: 1_000_000,
                device_id: 0,
                dtype: "bfloat16".to_string(),
            }],
            status: SourceStatus::Ready as i32,
            updated_at: 1234567890000,
            metadata_endpoint: String::new(),
            agent_name: String::new(),
            worker_grpc_endpoint: String::new(),
        };
        let rank1 = WorkerRecord {
            worker_rank: 1,
            backend_metadata: BackendMetadataRecord::Nixl(vec![40, 50, 60]),
            tensors: vec![TensorRecord {
                name: "layer.0.weight".to_string(),
                addr: 0x7f00_0000_0000,
                size: 1_000_000,
                device_id: 1,
                dtype: "bfloat16".to_string(),
            }],
            status: SourceStatus::Ready as i32,
            updated_at: 1234567890000,
            metadata_endpoint: String::new(),
            agent_name: String::new(),
            worker_grpc_endpoint: String::new(),
        };

        let model = ModelMetadataRecord {
            source_id: "abc123def456abcd".to_string(),
            worker_id: "test-instance-id".to_string(),
            model_name: "meta-llama/Llama-3.1-70B".to_string(),
            workers: vec![rank0, rank1],
            published_at: 1234567890,
        };

        assert_eq!(model.model_name, "meta-llama/Llama-3.1-70B");
        assert_eq!(model.workers.len(), 2);
        assert_eq!(model.workers[0].worker_rank, 0);
        assert_eq!(model.workers[1].worker_rank, 1);
    }

    // ------------------------------------------------------------------
    // Delegation to the backend
    // ------------------------------------------------------------------

    #[tokio::test]
    async fn test_publish_metadata_calls_backend() {
        let mut backend = MockMetadataBackend::new();
        backend
            .expect_publish_metadata()
            .withf(|id, wid, meta| {
                id.model_name == "my-model"
                    && id.tensor_parallel_size == 8
                    && wid == PUBLISH_WORKER_ID
                    && meta.worker_rank == 3
            })
            .once()
            .returning(|_, _, _| Ok(()));

        let worker = WorkerMetadata {
            worker_rank: 3,
            backend_metadata: None,
            tensors: vec![],
            status: SourceStatus::Initializing as i32,
            updated_at: 0,
            ..Default::default()
        };

        manager_over(backend)
            .publish_metadata(&test_identity(), PUBLISH_WORKER_ID, worker)
            .await
            .expect("publish_metadata failed");
    }

    #[tokio::test]
    async fn test_publish_metadata_propagates_backend_error() {
        let mut backend = MockMetadataBackend::new();
        backend
            .expect_publish_metadata()
            .once()
            .returning(|_, _, _| Err("storage unavailable".into()));

        let outcome = manager_over(backend)
            .publish_metadata(&test_identity(), PUBLISH_WORKER_ID, WorkerMetadata::default())
            .await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_connect_fails_without_config() {
        // Built by hand so that neither a backend nor a config is present.
        let unconfigured = P2pStateManager {
            backend: Arc::new(RwLock::new(None)),
            config: None,
        };
        let outcome = unconfigured.connect().await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_update_worker_status_calls_backend() {
        let mut backend = MockMetadataBackend::new();
        backend
            .expect_update_status()
            .with(
                eq(SOURCE_ID),
                eq(STATUS_WORKER_ID),
                eq(2u32),
                eq(SourceStatus::Ready),
                // The timestamp comes from the wall clock, so accept any value.
                mockall::predicate::always(),
            )
            .once()
            .returning(|_, _, _, _, _| Ok(()));

        manager_over(backend)
            .update_worker_status(SOURCE_ID, STATUS_WORKER_ID, 2, SourceStatus::Ready)
            .await
            .expect("update_worker_status failed");
    }

    #[tokio::test]
    async fn test_update_worker_status_propagates_backend_error() {
        let mut backend = MockMetadataBackend::new();
        backend
            .expect_update_status()
            .once()
            .returning(|_, _, _, _, _| Err("redis unavailable".into()));

        let outcome = manager_over(backend)
            .update_worker_status(SOURCE_ID, STATUS_WORKER_ID, 0, SourceStatus::Ready)
            .await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_list_workers_calls_backend() {
        let mut backend = MockMetadataBackend::new();
        backend
            .expect_list_workers()
            .withf(|source, filter| {
                source.as_deref() == Some(SOURCE_ID) && *filter == Some(SourceStatus::Ready)
            })
            .once()
            .returning(|_, _| {
                let instance = crate::p2p::backend::SourceInstanceInfo {
                    source_id: "abc123def456abcd".to_string(),
                    worker_id: "w1".to_string(),
                    model_name: "my-model".to_string(),
                    worker_rank: 0,
                    status: SourceStatus::Ready as i32,
                    updated_at: 1234567890000,
                };
                Ok(vec![instance])
            });

        let instances = manager_over(backend)
            .list_workers(Some(SOURCE_ID.to_string()), Some(SourceStatus::Ready))
            .await
            .expect("list_workers failed");
        assert_eq!(instances.len(), 1);
        assert_eq!(instances[0].worker_id, "w1");
    }

    #[tokio::test]
    async fn test_list_workers_propagates_backend_error() {
        let mut backend = MockMetadataBackend::new();
        backend
            .expect_list_workers()
            .once()
            .returning(|_, _| Err("backend error".into()));

        let outcome = manager_over(backend).list_workers(None, None).await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_remove_metadata_calls_backend() {
        let mut backend = MockMetadataBackend::new();
        backend
            .expect_remove_metadata()
            .with(eq(SOURCE_ID))
            .once()
            .returning(|_| Ok(()));

        manager_over(backend)
            .remove_metadata(SOURCE_ID)
            .await
            .expect("remove_metadata failed");
    }

    #[tokio::test]
    async fn test_remove_metadata_propagates_backend_error() {
        let mut backend = MockMetadataBackend::new();
        backend
            .expect_remove_metadata()
            .once()
            .returning(|_| Err("delete failed".into()));

        let outcome = manager_over(backend).remove_metadata(SOURCE_ID).await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_list_sources_calls_backend() {
        let mut backend = MockMetadataBackend::new();
        backend
            .expect_list_sources()
            .once()
            .returning(|| Ok(vec![("src1".to_string(), "model-a".to_string())]));

        let sources = manager_over(backend)
            .list_sources()
            .await
            .expect("list_sources failed");
        assert_eq!(sources.len(), 1);
        assert_eq!(sources[0].0, "src1");
        assert_eq!(sources[0].1, "model-a");
    }

    #[tokio::test]
    async fn test_update_worker_status_stores_correct_status() {
        let mut backend = MockMetadataBackend::new();
        backend
            .expect_update_status()
            .withf(|source, worker, rank, status, _stamp| {
                source == SOURCE_ID
                    && worker == STATUS_WORKER_ID
                    && *rank == 7
                    && *status == SourceStatus::Ready
            })
            .once()
            .returning(|_, _, _, _, _| Ok(()));

        manager_over(backend)
            .update_worker_status(SOURCE_ID, STATUS_WORKER_ID, 7, SourceStatus::Ready)
            .await
            .expect("update_worker_status failed");
    }
}