1use crate::metadata_backend::{BackendConfig, MetadataBackend, MetadataResult, create_backend};
11use modelexpress_common::grpc::p2p::{SourceIdentity, WorkerMetadata};
12use std::sync::Arc;
13use tokio::sync::RwLock;
14use tracing::{debug, info};
15
16pub use crate::metadata_backend::{
18 BackendMetadataRecord, ModelMetadataRecord, TensorRecord, WorkerRecord,
19};
20
21#[derive(Clone)]
26pub struct P2pStateManager {
27 backend: Arc<RwLock<Option<Arc<dyn MetadataBackend>>>>,
28 config: Option<BackendConfig>,
29}
30
31impl Default for P2pStateManager {
32 fn default() -> Self {
33 Self::new()
34 }
35}
36
37impl P2pStateManager {
38 pub fn new() -> Self {
43 Self {
44 backend: Arc::new(RwLock::new(None)),
45 config: BackendConfig::from_env().ok(),
46 }
47 }
48
49 pub fn with_config(config: BackendConfig) -> Self {
51 Self {
52 backend: Arc::new(RwLock::new(None)),
53 config: Some(config),
54 }
55 }
56
57 #[cfg(test)]
59 pub fn with_backend(backend: Arc<dyn MetadataBackend>) -> Self {
60 Self {
61 backend: Arc::new(RwLock::new(Some(backend))),
62 config: None,
63 }
64 }
65
66 pub async fn connect(&self) -> MetadataResult<String> {
68 let config = self.config.clone().ok_or(
69 "MX_METADATA_BACKEND is not set or invalid. Set it to 'redis' or 'kubernetes'.",
70 )?;
71
72 let backend_name = config.to_string();
73 let backend = create_backend(config).await?;
74 let mut guard = self.backend.write().await;
75 *guard = Some(backend);
76
77 info!("P2pStateManager connected (backend: {})", backend_name);
78 Ok(backend_name)
79 }
80
81 async fn get_backend(&self) -> MetadataResult<Arc<dyn MetadataBackend>> {
83 {
84 let guard = self.backend.read().await;
85 if let Some(backend) = guard.as_ref() {
86 return Ok(backend.clone());
87 }
88 }
89
90 let mut guard = self.backend.write().await;
91 if let Some(backend) = guard.as_ref() {
92 return Ok(backend.clone());
93 }
94
95 let config = self.config.clone().ok_or(
96 "MX_METADATA_BACKEND is not set or invalid. Set it to 'redis' or 'kubernetes'.",
97 )?;
98
99 let backend = create_backend(config.clone()).await?;
100 info!("P2pStateManager connected with {:?}", config);
101 *guard = Some(backend.clone());
102 Ok(backend)
103 }
104
105 pub async fn publish_metadata(
111 &self,
112 identity: &SourceIdentity,
113 worker_id: &str,
114 worker: WorkerMetadata,
115 ) -> MetadataResult<()> {
116 self.get_backend()
117 .await?
118 .publish_metadata(identity, worker_id, worker)
119 .await
120 }
121
122 pub async fn get_metadata(
124 &self,
125 source_id: &str,
126 worker_id: &str,
127 ) -> MetadataResult<Option<ModelMetadataRecord>> {
128 self.get_backend()
129 .await?
130 .get_metadata(source_id, worker_id)
131 .await
132 }
133
134 pub async fn list_workers(
136 &self,
137 source_id: Option<String>,
138 status_filter: Option<modelexpress_common::grpc::p2p::SourceStatus>,
139 ) -> MetadataResult<Vec<crate::metadata_backend::SourceInstanceInfo>> {
140 self.get_backend()
141 .await?
142 .list_workers(source_id, status_filter)
143 .await
144 }
145
146 pub async fn remove_metadata(&self, source_id: &str) -> MetadataResult<()> {
148 self.get_backend().await?.remove_metadata(source_id).await
149 }
150
151 pub async fn remove_worker(&self, source_id: &str, worker_id: &str) -> MetadataResult<()> {
153 self.get_backend()
154 .await?
155 .remove_worker(source_id, worker_id)
156 .await
157 }
158
159 pub async fn list_sources(&self) -> MetadataResult<Vec<(String, String)>> {
161 self.get_backend().await?.list_sources().await
162 }
163
164 pub async fn update_worker_status(
170 &self,
171 source_id: &str,
172 worker_id: &str,
173 worker_rank: u32,
174 status: modelexpress_common::grpc::p2p::SourceStatus,
175 ) -> MetadataResult<()> {
176 let updated_at = chrono::Utc::now().timestamp_millis();
177 self.get_backend()
178 .await?
179 .update_status(source_id, worker_id, worker_rank, status, updated_at)
180 .await?;
181
182 debug!(
183 "Updated status for source '{}' worker '{}' rank {} -> {}",
184 source_id, worker_id, worker_rank, status as i32
185 );
186 Ok(())
187 }
188}
189
#[cfg(test)]
// Tests call `expect` on purpose: an unexpected `Err` should fail the test loudly.
#[allow(clippy::expect_used)]
mod tests {
    use super::*;
    use crate::metadata_backend::MockMetadataBackend;
    use mockall::predicate::eq;
    use modelexpress_common::grpc::p2p::{
        MxSourceType, SourceIdentity, SourceStatus, TensorDescriptor,
    };

    /// Source id shared by the manager-level tests.
    const SOURCE_ID: &str = "abc123def456abcd";
    /// Worker id shared by the status-update tests.
    const INSTANCE: &str = "test-instance";

    /// Wraps a configured mock in a manager, bypassing backend configuration.
    fn manager_over(mock: MockMetadataBackend) -> P2pStateManager {
        P2pStateManager::with_backend(Arc::new(mock))
    }

    /// Identity of a TP=8 bfloat16 weights source named "my-model".
    fn sample_identity() -> SourceIdentity {
        SourceIdentity {
            model_name: "my-model".to_string(),
            mx_version: "0.3.0".to_string(),
            mx_source_type: MxSourceType::Weights as i32,
            backend_framework: 1,
            tensor_parallel_size: 8,
            pipeline_parallel_size: 1,
            expert_parallel_size: 0,
            dtype: "bfloat16".to_string(),
            quantization: String::new(),
            extra_parameters: Default::default(),
        }
    }

    #[test]
    fn test_tensor_record_conversion() {
        let original = TensorDescriptor {
            name: "model.layers.0.weight".to_string(),
            addr: 0x7f0000000000,
            size: 1024 * 1024 * 1024,
            device_id: 0,
            dtype: "bfloat16".to_string(),
        };

        // Forward conversion keeps the identifying fields.
        let record = TensorRecord::from(original.clone());
        assert_eq!(record.size, 1024 * 1024 * 1024);
        assert_eq!(record.name, "model.layers.0.weight");

        // Converting back restores name and address.
        let restored: TensorDescriptor = record.into();
        assert_eq!(restored.addr, original.addr);
        assert_eq!(restored.name, original.name);
    }

    #[test]
    fn test_worker_record_conversion() {
        use modelexpress_common::grpc::p2p::worker_metadata::BackendMetadata;

        let nixl_blob = vec![1, 2, 3, 4, 5];
        let original = WorkerMetadata {
            worker_rank: 3,
            backend_metadata: Some(BackendMetadata::NixlMetadata(nixl_blob.clone())),
            tensors: vec![TensorDescriptor {
                name: "test.weight".to_string(),
                addr: 0x1000,
                size: 4096,
                device_id: 3,
                dtype: "float16".to_string(),
            }],
            status: SourceStatus::Initializing as i32,
            updated_at: 1234567890000,
            ..Default::default()
        };

        let record = WorkerRecord::from(original.clone());
        assert_eq!(record.worker_rank, 3);
        assert!(matches!(
            &record.backend_metadata,
            BackendMetadataRecord::Nixl(bytes) if bytes == &nixl_blob
        ));
        assert_eq!(record.tensors.len(), 1);
        assert_eq!(record.status, SourceStatus::Initializing as i32);
        assert_eq!(record.updated_at, 1234567890000);

        let restored: WorkerMetadata = record.into();
        assert_eq!(restored.worker_rank, original.worker_rank);
        assert_eq!(restored.backend_metadata, original.backend_metadata);
    }

    #[test]
    fn test_worker_record_transfer_engine_roundtrip() {
        use modelexpress_common::grpc::p2p::worker_metadata::BackendMetadata;

        let session_id = "192.168.1.10:12345";
        let original = WorkerMetadata {
            worker_rank: 1,
            backend_metadata: Some(BackendMetadata::TransferEngineSessionId(
                session_id.to_string(),
            )),
            tensors: vec![TensorDescriptor {
                name: "test.weight".to_string(),
                addr: 0x2000,
                size: 8192,
                device_id: 0,
                dtype: "float16".to_string(),
            }],
            status: 0,
            updated_at: 0,
            ..Default::default()
        };

        let record = WorkerRecord::from(original.clone());
        assert_eq!(record.worker_rank, 1);
        assert!(matches!(
            &record.backend_metadata,
            BackendMetadataRecord::TransferEngine(sid) if sid == session_id
        ));
        assert_eq!(
            record.backend_metadata.backend_type_str(),
            "transfer_engine"
        );

        let restored: WorkerMetadata = record.into();
        assert_eq!(restored.worker_rank, original.worker_rank);
        assert_eq!(restored.backend_metadata, original.backend_metadata);
    }

    #[test]
    fn test_backend_metadata_from_flat_with_discriminator() {
        // An explicit discriminator selects the variant directly.
        let explicit_te = BackendMetadataRecord::from_flat(
            Vec::new(),
            Some("10.0.0.1:5000".into()),
            Some("transfer_engine"),
        );
        assert!(matches!(
            explicit_te,
            BackendMetadataRecord::TransferEngine(ref sid) if sid == "10.0.0.1:5000"
        ));

        let explicit_nixl = BackendMetadataRecord::from_flat(vec![1, 2, 3], None, Some("nixl"));
        assert!(matches!(
            explicit_nixl,
            BackendMetadataRecord::Nixl(ref bytes) if bytes == &vec![1, 2, 3]
        ));

        let explicit_none = BackendMetadataRecord::from_flat(Vec::new(), None, Some("none"));
        assert!(matches!(explicit_none, BackendMetadataRecord::None));

        // Without a discriminator, the variant is inferred from which field is populated.
        let guessed_te =
            BackendMetadataRecord::from_flat(Vec::new(), Some("10.0.0.1:5000".into()), None);
        assert!(matches!(guessed_te, BackendMetadataRecord::TransferEngine(_)));

        let guessed_nixl = BackendMetadataRecord::from_flat(vec![1, 2], None, None);
        assert!(matches!(guessed_nixl, BackendMetadataRecord::Nixl(_)));

        let guessed_none = BackendMetadataRecord::from_flat(Vec::new(), None, None);
        assert!(matches!(guessed_none, BackendMetadataRecord::None));
    }

    #[test]
    fn test_model_record_creation() {
        // Both workers differ only in rank, NIXL blob and device id.
        let ready_worker = |worker_rank, nixl, device_id| WorkerRecord {
            worker_rank,
            backend_metadata: BackendMetadataRecord::Nixl(nixl),
            tensors: vec![TensorRecord {
                name: "layer.0.weight".to_string(),
                addr: 0x7f00_0000_0000,
                size: 1_000_000,
                device_id,
                dtype: "bfloat16".to_string(),
            }],
            status: SourceStatus::Ready as i32,
            updated_at: 1234567890000,
            metadata_endpoint: String::new(),
            agent_name: String::new(),
            worker_grpc_endpoint: String::new(),
        };

        let record = ModelMetadataRecord {
            source_id: SOURCE_ID.to_string(),
            worker_id: "test-instance-id".to_string(),
            model_name: "meta-llama/Llama-3.1-70B".to_string(),
            workers: vec![
                ready_worker(0, vec![10, 20, 30], 0),
                ready_worker(1, vec![40, 50, 60], 1),
            ],
            published_at: 1234567890,
        };

        assert_eq!(record.model_name, "meta-llama/Llama-3.1-70B");
        assert_eq!(record.workers.len(), 2);
        for (expected_rank, worker) in [0, 1].into_iter().zip(&record.workers) {
            assert_eq!(worker.worker_rank, expected_rank);
        }
    }

    #[test]
    fn test_backend_config_parsing() {
        let default_redis = "redis://localhost:6379";
        let default_ns = "default";

        let redis =
            BackendConfig::from_type_str("redis", "redis://myhost:6379", default_ns).expect("ok");
        assert!(matches!(redis, BackendConfig::Redis { url } if url == "redis://myhost:6379"));

        // Every Kubernetes alias resolves to the same variant with the given namespace.
        for alias in ["kubernetes", "k8s", "crd"] {
            let kube = BackendConfig::from_type_str(alias, default_redis, "prod-ns").expect("ok");
            assert!(
                matches!(kube, BackendConfig::Kubernetes { namespace } if namespace == "prod-ns")
            );
        }

        // Unknown, unsupported and empty backend names are rejected.
        for rejected in ["bogus", "memory", ""] {
            assert!(BackendConfig::from_type_str(rejected, default_redis, default_ns).is_err());
        }

        // Backend names are matched case-insensitively.
        let upper = BackendConfig::from_type_str("REDIS", default_redis, default_ns).expect("ok");
        assert!(matches!(upper, BackendConfig::Redis { .. }));
    }

    #[tokio::test]
    async fn test_publish_metadata_calls_backend() {
        let mut backend = MockMetadataBackend::new();
        backend
            .expect_publish_metadata()
            .withf(|identity, worker_id, worker| {
                worker_id == "a1b2c3d4"
                    && worker.worker_rank == 3
                    && identity.model_name == "my-model"
                    && identity.tensor_parallel_size == 8
            })
            .once()
            .returning(|_, _, _| Ok(()));

        let worker = WorkerMetadata {
            worker_rank: 3,
            backend_metadata: None,
            tensors: vec![],
            status: SourceStatus::Initializing as i32,
            updated_at: 0,
            ..Default::default()
        };
        let manager = manager_over(backend);
        manager
            .publish_metadata(&sample_identity(), "a1b2c3d4", worker)
            .await
            .expect("publish_metadata failed");
    }

    #[tokio::test]
    async fn test_publish_metadata_propagates_backend_error() {
        let mut backend = MockMetadataBackend::new();
        backend
            .expect_publish_metadata()
            .once()
            .returning(|_, _, _| Err("storage unavailable".into()));

        let manager = manager_over(backend);
        let outcome = manager
            .publish_metadata(&sample_identity(), "a1b2c3d4", WorkerMetadata::default())
            .await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_connect_fails_without_config() {
        let unconfigured = P2pStateManager {
            backend: Arc::new(RwLock::new(None)),
            config: None,
        };
        let outcome = unconfigured.connect().await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_update_worker_status_calls_backend() {
        let mut backend = MockMetadataBackend::new();
        backend
            .expect_update_status()
            .with(
                eq(SOURCE_ID),
                eq(INSTANCE),
                eq(2u32),
                eq(SourceStatus::Ready),
                mockall::predicate::always(),
            )
            .once()
            .returning(|_, _, _, _, _| Ok(()));

        let manager = manager_over(backend);
        manager
            .update_worker_status(SOURCE_ID, INSTANCE, 2, SourceStatus::Ready)
            .await
            .expect("update_worker_status failed");
    }

    #[tokio::test]
    async fn test_update_worker_status_propagates_backend_error() {
        let mut backend = MockMetadataBackend::new();
        backend
            .expect_update_status()
            .once()
            .returning(|_, _, _, _, _| Err("redis unavailable".into()));

        let manager = manager_over(backend);
        let outcome = manager
            .update_worker_status(SOURCE_ID, INSTANCE, 0, SourceStatus::Ready)
            .await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_list_workers_calls_backend() {
        let mut backend = MockMetadataBackend::new();
        backend
            .expect_list_workers()
            .withf(|source_id, status_filter| {
                *status_filter == Some(SourceStatus::Ready)
                    && source_id.as_deref() == Some(SOURCE_ID)
            })
            .once()
            .returning(|_, _| {
                Ok(vec![crate::metadata_backend::SourceInstanceInfo {
                    source_id: SOURCE_ID.to_string(),
                    worker_id: "w1".to_string(),
                    model_name: "my-model".to_string(),
                    worker_rank: 0,
                    status: SourceStatus::Ready as i32,
                    updated_at: 1234567890000,
                }])
            });

        let manager = manager_over(backend);
        let workers = manager
            .list_workers(Some(SOURCE_ID.to_string()), Some(SourceStatus::Ready))
            .await
            .expect("list_workers failed");
        let worker_ids: Vec<&str> = workers.iter().map(|w| w.worker_id.as_str()).collect();
        assert_eq!(worker_ids, ["w1"]);
    }

    #[tokio::test]
    async fn test_list_workers_propagates_backend_error() {
        let mut backend = MockMetadataBackend::new();
        backend
            .expect_list_workers()
            .once()
            .returning(|_, _| Err("backend error".into()));

        let manager = manager_over(backend);
        let outcome = manager.list_workers(None, None).await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_remove_metadata_calls_backend() {
        let mut backend = MockMetadataBackend::new();
        backend
            .expect_remove_metadata()
            .with(eq(SOURCE_ID))
            .once()
            .returning(|_| Ok(()));

        let manager = manager_over(backend);
        manager
            .remove_metadata(SOURCE_ID)
            .await
            .expect("remove_metadata failed");
    }

    #[tokio::test]
    async fn test_remove_metadata_propagates_backend_error() {
        let mut backend = MockMetadataBackend::new();
        backend
            .expect_remove_metadata()
            .once()
            .returning(|_| Err("delete failed".into()));

        let manager = manager_over(backend);
        let outcome = manager.remove_metadata(SOURCE_ID).await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_list_sources_calls_backend() {
        let mut backend = MockMetadataBackend::new();
        backend
            .expect_list_sources()
            .once()
            .returning(|| Ok(vec![("src1".to_string(), "model-a".to_string())]));

        let manager = manager_over(backend);
        let sources = manager.list_sources().await.expect("list_sources failed");
        assert_eq!(sources, vec![("src1".to_string(), "model-a".to_string())]);
    }

    #[tokio::test]
    async fn test_update_worker_status_stores_correct_status() {
        let mut backend = MockMetadataBackend::new();
        backend
            .expect_update_status()
            .withf(|source_id, worker_id, worker_rank, status, _updated_at| {
                *status == SourceStatus::Ready
                    && *worker_rank == 7
                    && source_id == SOURCE_ID
                    && worker_id == INSTANCE
            })
            .once()
            .returning(|_, _, _, _, _| Ok(()));

        let manager = manager_over(backend);
        manager
            .update_worker_status(SOURCE_ID, INSTANCE, 7, SourceStatus::Ready)
            .await
            .expect("update_worker_status failed");
    }
}