1use crate::p2p::backend::{BackendConfig, MetadataBackend, MetadataResult, create_backend};
11use modelexpress_common::grpc::p2p::{SourceIdentity, WorkerMetadata};
12use std::sync::Arc;
13use tokio::sync::RwLock;
14use tracing::{debug, info};
15
16pub use crate::p2p::backend::{
18 BackendMetadataRecord, ModelMetadataRecord, TensorRecord, WorkerRecord,
19};
20
21#[derive(Clone)]
26pub struct P2pStateManager {
27 backend: Arc<RwLock<Option<Arc<dyn MetadataBackend>>>>,
28 config: Option<BackendConfig>,
29}
30
31impl Default for P2pStateManager {
32 fn default() -> Self {
33 Self::new()
34 }
35}
36
37impl P2pStateManager {
38 pub fn new() -> Self {
43 Self {
44 backend: Arc::new(RwLock::new(None)),
45 config: BackendConfig::from_env().ok(),
46 }
47 }
48
49 pub fn with_config(config: BackendConfig) -> Self {
51 Self {
52 backend: Arc::new(RwLock::new(None)),
53 config: Some(config),
54 }
55 }
56
57 #[cfg(test)]
59 pub fn with_backend(backend: Arc<dyn MetadataBackend>) -> Self {
60 Self {
61 backend: Arc::new(RwLock::new(Some(backend))),
62 config: None,
63 }
64 }
65
66 pub async fn connect(&self) -> MetadataResult<String> {
68 let config = self.config.clone().ok_or(
69 "MX_METADATA_BACKEND is not set or invalid. Set it to 'redis' or 'kubernetes'.",
70 )?;
71
72 let backend_name = config.to_string();
73 let backend = create_backend(config).await?;
74 let mut guard = self.backend.write().await;
75 *guard = Some(backend);
76
77 info!("P2pStateManager connected (backend: {})", backend_name);
78 Ok(backend_name)
79 }
80
81 async fn get_backend(&self) -> MetadataResult<Arc<dyn MetadataBackend>> {
83 {
84 let guard = self.backend.read().await;
85 if let Some(backend) = guard.as_ref() {
86 return Ok(backend.clone());
87 }
88 }
89
90 let mut guard = self.backend.write().await;
91 if let Some(backend) = guard.as_ref() {
92 return Ok(backend.clone());
93 }
94
95 let config = self.config.clone().ok_or(
96 "MX_METADATA_BACKEND is not set or invalid. Set it to 'redis' or 'kubernetes'.",
97 )?;
98
99 let backend = create_backend(config.clone()).await?;
100 info!("P2pStateManager connected with {:?}", config);
101 *guard = Some(backend.clone());
102 Ok(backend)
103 }
104
105 pub async fn publish_metadata(
111 &self,
112 identity: &SourceIdentity,
113 worker_id: &str,
114 worker: WorkerMetadata,
115 ) -> MetadataResult<()> {
116 self.get_backend()
117 .await?
118 .publish_metadata(identity, worker_id, worker)
119 .await
120 }
121
122 pub async fn get_metadata(
124 &self,
125 source_id: &str,
126 worker_id: &str,
127 ) -> MetadataResult<Option<ModelMetadataRecord>> {
128 self.get_backend()
129 .await?
130 .get_metadata(source_id, worker_id)
131 .await
132 }
133
134 pub async fn list_workers(
136 &self,
137 source_id: Option<String>,
138 status_filter: Option<modelexpress_common::grpc::p2p::SourceStatus>,
139 ) -> MetadataResult<Vec<crate::p2p::backend::SourceInstanceInfo>> {
140 self.get_backend()
141 .await?
142 .list_workers(source_id, status_filter)
143 .await
144 }
145
146 pub async fn remove_metadata(&self, source_id: &str) -> MetadataResult<()> {
148 self.get_backend().await?.remove_metadata(source_id).await
149 }
150
151 pub async fn remove_worker(&self, source_id: &str, worker_id: &str) -> MetadataResult<()> {
153 self.get_backend()
154 .await?
155 .remove_worker(source_id, worker_id)
156 .await
157 }
158
159 pub async fn list_sources(&self) -> MetadataResult<Vec<(String, String)>> {
161 self.get_backend().await?.list_sources().await
162 }
163
164 pub async fn update_worker_status(
170 &self,
171 source_id: &str,
172 worker_id: &str,
173 worker_rank: u32,
174 status: modelexpress_common::grpc::p2p::SourceStatus,
175 ) -> MetadataResult<()> {
176 let updated_at = chrono::Utc::now().timestamp_millis();
177 self.get_backend()
178 .await?
179 .update_status(source_id, worker_id, worker_rank, status, updated_at)
180 .await?;
181
182 debug!(
183 "Updated status for source '{}' worker '{}' rank {} -> {}",
184 source_id, worker_id, worker_rank, status as i32
185 );
186 Ok(())
187 }
188}
189
190#[cfg(test)]
191#[allow(clippy::expect_used)]
192mod tests {
193 use super::*;
194 use crate::p2p::backend::MockMetadataBackend;
195 use mockall::predicate::eq;
196 use modelexpress_common::grpc::p2p::{
197 MxSourceType, SourceIdentity, SourceStatus, TensorDescriptor,
198 };
199
200 fn test_identity() -> SourceIdentity {
201 SourceIdentity {
202 mx_version: "0.4.0".to_string(),
203 mx_source_type: MxSourceType::Weights as i32,
204 model_name: "my-model".to_string(),
205 backend_framework: 1,
206 tensor_parallel_size: 8,
207 pipeline_parallel_size: 1,
208 expert_parallel_size: 0,
209 dtype: "bfloat16".to_string(),
210 quantization: String::new(),
211 extra_parameters: Default::default(),
212 revision: String::new(),
213 }
214 }
215
216 #[test]
217 fn test_tensor_record_conversion() {
218 let desc = TensorDescriptor {
219 name: "model.layers.0.weight".to_string(),
220 addr: 0x7f0000000000,
221 size: 1024 * 1024 * 1024,
222 device_id: 0,
223 dtype: "bfloat16".to_string(),
224 };
225
226 let record = TensorRecord::from(desc.clone());
227 assert_eq!(record.name, "model.layers.0.weight");
228 assert_eq!(record.size, 1024 * 1024 * 1024);
229
230 let back: TensorDescriptor = record.into();
231 assert_eq!(back.name, desc.name);
232 assert_eq!(back.addr, desc.addr);
233 }
234
235 #[test]
236 fn test_worker_record_conversion() {
237 use modelexpress_common::grpc::p2p::worker_metadata::BackendMetadata;
238
239 let meta = WorkerMetadata {
240 worker_rank: 3,
241 backend_metadata: Some(BackendMetadata::NixlMetadata(vec![1, 2, 3, 4, 5])),
242 tensors: vec![TensorDescriptor {
243 name: "test.weight".to_string(),
244 addr: 0x1000,
245 size: 4096,
246 device_id: 3,
247 dtype: "float16".to_string(),
248 }],
249 status: SourceStatus::Initializing as i32,
250 updated_at: 1234567890000,
251 ..Default::default()
252 };
253
254 let record = WorkerRecord::from(meta.clone());
255 assert_eq!(record.worker_rank, 3);
256 assert!(matches!(
257 &record.backend_metadata,
258 BackendMetadataRecord::Nixl(d) if d == &vec![1, 2, 3, 4, 5]
259 ));
260 assert_eq!(record.tensors.len(), 1);
261 assert_eq!(record.status, SourceStatus::Initializing as i32);
262 assert_eq!(record.updated_at, 1234567890000);
263
264 let back: WorkerMetadata = record.into();
265 assert_eq!(back.worker_rank, meta.worker_rank);
266 assert_eq!(back.backend_metadata, meta.backend_metadata);
267 }
268
269 #[test]
270 fn test_worker_record_transfer_engine_roundtrip() {
271 use modelexpress_common::grpc::p2p::worker_metadata::BackendMetadata;
272
273 let meta = WorkerMetadata {
274 worker_rank: 1,
275 backend_metadata: Some(BackendMetadata::TransferEngineSessionId(
276 "192.168.1.10:12345".to_string(),
277 )),
278 tensors: vec![TensorDescriptor {
279 name: "test.weight".to_string(),
280 addr: 0x2000,
281 size: 8192,
282 device_id: 0,
283 dtype: "float16".to_string(),
284 }],
285 status: 0,
286 updated_at: 0,
287 ..Default::default()
288 };
289
290 let record = WorkerRecord::from(meta.clone());
291 assert_eq!(record.worker_rank, 1);
292 assert!(matches!(
293 &record.backend_metadata,
294 BackendMetadataRecord::TransferEngine(sid) if sid == "192.168.1.10:12345"
295 ));
296 assert_eq!(
297 record.backend_metadata.backend_type_str(),
298 "transfer_engine"
299 );
300
301 let back: WorkerMetadata = record.into();
302 assert_eq!(back.worker_rank, meta.worker_rank);
303 assert_eq!(back.backend_metadata, meta.backend_metadata);
304 }
305
306 #[test]
307 fn test_backend_metadata_from_flat_with_discriminator() {
308 let te = BackendMetadataRecord::from_flat(
310 Vec::new(),
311 Some("10.0.0.1:5000".into()),
312 Some("transfer_engine"),
313 );
314 assert!(matches!(te, BackendMetadataRecord::TransferEngine(ref s) if s == "10.0.0.1:5000"));
315
316 let nixl = BackendMetadataRecord::from_flat(vec![1, 2, 3], None, Some("nixl"));
317 assert!(matches!(nixl, BackendMetadataRecord::Nixl(ref d) if d == &vec![1, 2, 3]));
318
319 let none = BackendMetadataRecord::from_flat(Vec::new(), None, Some("none"));
320 assert!(matches!(none, BackendMetadataRecord::None));
321
322 let inferred_te =
324 BackendMetadataRecord::from_flat(Vec::new(), Some("10.0.0.1:5000".into()), None);
325 assert!(matches!(
326 inferred_te,
327 BackendMetadataRecord::TransferEngine(_)
328 ));
329
330 let inferred_nixl = BackendMetadataRecord::from_flat(vec![1, 2], None, None);
331 assert!(matches!(inferred_nixl, BackendMetadataRecord::Nixl(_)));
332
333 let inferred_none = BackendMetadataRecord::from_flat(Vec::new(), None, None);
334 assert!(matches!(inferred_none, BackendMetadataRecord::None));
335 }
336
337 #[test]
338 fn test_model_record_creation() {
339 let record = ModelMetadataRecord {
340 source_id: "abc123def456abcd".to_string(),
341 worker_id: "test-instance-id".to_string(),
342 model_name: "meta-llama/Llama-3.1-70B".to_string(),
343 workers: vec![
344 WorkerRecord {
345 worker_rank: 0,
346 backend_metadata: BackendMetadataRecord::Nixl(vec![10, 20, 30]),
347 tensors: vec![TensorRecord {
348 name: "layer.0.weight".to_string(),
349 addr: 0x7f00_0000_0000,
350 size: 1_000_000,
351 device_id: 0,
352 dtype: "bfloat16".to_string(),
353 }],
354 status: SourceStatus::Ready as i32,
355 updated_at: 1234567890000,
356 metadata_endpoint: String::new(),
357 agent_name: String::new(),
358 worker_grpc_endpoint: String::new(),
359 },
360 WorkerRecord {
361 worker_rank: 1,
362 backend_metadata: BackendMetadataRecord::Nixl(vec![40, 50, 60]),
363 tensors: vec![TensorRecord {
364 name: "layer.0.weight".to_string(),
365 addr: 0x7f00_0000_0000,
366 size: 1_000_000,
367 device_id: 1,
368 dtype: "bfloat16".to_string(),
369 }],
370 status: SourceStatus::Ready as i32,
371 updated_at: 1234567890000,
372 metadata_endpoint: String::new(),
373 agent_name: String::new(),
374 worker_grpc_endpoint: String::new(),
375 },
376 ],
377 published_at: 1234567890,
378 };
379
380 assert_eq!(record.model_name, "meta-llama/Llama-3.1-70B");
381 assert_eq!(record.workers.len(), 2);
382 assert_eq!(record.workers[0].worker_rank, 0);
383 assert_eq!(record.workers[1].worker_rank, 1);
384 }
385
386 #[tokio::test]
387 async fn test_publish_metadata_calls_backend() {
388 let mut mock = MockMetadataBackend::new();
389 mock.expect_publish_metadata()
390 .withf(|identity, worker_id, worker| {
391 identity.model_name == "my-model"
392 && identity.tensor_parallel_size == 8
393 && worker_id == "a1b2c3d4"
394 && worker.worker_rank == 3
395 })
396 .once()
397 .returning(|_, _, _| Ok(()));
398
399 let manager = P2pStateManager::with_backend(Arc::new(mock));
400 manager
401 .publish_metadata(
402 &test_identity(),
403 "a1b2c3d4",
404 WorkerMetadata {
405 worker_rank: 3,
406 backend_metadata: None,
407 tensors: vec![],
408 status: SourceStatus::Initializing as i32,
409 updated_at: 0,
410 ..Default::default()
411 },
412 )
413 .await
414 .expect("publish_metadata failed");
415 }
416
417 #[tokio::test]
418 async fn test_publish_metadata_propagates_backend_error() {
419 let mut mock = MockMetadataBackend::new();
420 mock.expect_publish_metadata()
421 .once()
422 .returning(|_, _, _| Err("storage unavailable".into()));
423
424 let manager = P2pStateManager::with_backend(Arc::new(mock));
425 assert!(
426 manager
427 .publish_metadata(&test_identity(), "a1b2c3d4", WorkerMetadata::default())
428 .await
429 .is_err()
430 );
431 }
432
433 #[tokio::test]
434 async fn test_connect_fails_without_config() {
435 let manager = P2pStateManager {
436 backend: Arc::new(RwLock::new(None)),
437 config: None,
438 };
439 assert!(manager.connect().await.is_err());
440 }
441
442 #[tokio::test]
443 async fn test_update_worker_status_calls_backend() {
444 let mut mock = MockMetadataBackend::new();
445 mock.expect_update_status()
446 .with(
447 eq("abc123def456abcd"),
448 eq("test-instance"),
449 eq(2u32),
450 eq(SourceStatus::Ready),
451 mockall::predicate::always(),
452 )
453 .once()
454 .returning(|_, _, _, _, _| Ok(()));
455
456 let manager = P2pStateManager::with_backend(Arc::new(mock));
457 manager
458 .update_worker_status("abc123def456abcd", "test-instance", 2, SourceStatus::Ready)
459 .await
460 .expect("update_worker_status failed");
461 }
462
463 #[tokio::test]
464 async fn test_update_worker_status_propagates_backend_error() {
465 let mut mock = MockMetadataBackend::new();
466 mock.expect_update_status()
467 .once()
468 .returning(|_, _, _, _, _| Err("redis unavailable".into()));
469
470 let manager = P2pStateManager::with_backend(Arc::new(mock));
471 assert!(
472 manager
473 .update_worker_status("abc123def456abcd", "test-instance", 0, SourceStatus::Ready)
474 .await
475 .is_err()
476 );
477 }
478
479 #[tokio::test]
480 async fn test_list_workers_calls_backend() {
481 let mut mock = MockMetadataBackend::new();
482 mock.expect_list_workers()
483 .withf(|source_id, status_filter| {
484 source_id.as_deref() == Some("abc123def456abcd")
485 && *status_filter == Some(SourceStatus::Ready)
486 })
487 .once()
488 .returning(|_, _| {
489 Ok(vec![crate::p2p::backend::SourceInstanceInfo {
490 source_id: "abc123def456abcd".to_string(),
491 worker_id: "w1".to_string(),
492 model_name: "my-model".to_string(),
493 worker_rank: 0,
494 status: SourceStatus::Ready as i32,
495 updated_at: 1234567890000,
496 }])
497 });
498
499 let manager = P2pStateManager::with_backend(Arc::new(mock));
500 let result = manager
501 .list_workers(
502 Some("abc123def456abcd".to_string()),
503 Some(SourceStatus::Ready),
504 )
505 .await
506 .expect("list_workers failed");
507 assert_eq!(result.len(), 1);
508 assert_eq!(result[0].worker_id, "w1");
509 }
510
511 #[tokio::test]
512 async fn test_list_workers_propagates_backend_error() {
513 let mut mock = MockMetadataBackend::new();
514 mock.expect_list_workers()
515 .once()
516 .returning(|_, _| Err("backend error".into()));
517
518 let manager = P2pStateManager::with_backend(Arc::new(mock));
519 assert!(manager.list_workers(None, None).await.is_err());
520 }
521
522 #[tokio::test]
523 async fn test_remove_metadata_calls_backend() {
524 let mut mock = MockMetadataBackend::new();
525 mock.expect_remove_metadata()
526 .with(eq("abc123def456abcd"))
527 .once()
528 .returning(|_| Ok(()));
529
530 let manager = P2pStateManager::with_backend(Arc::new(mock));
531 manager
532 .remove_metadata("abc123def456abcd")
533 .await
534 .expect("remove_metadata failed");
535 }
536
537 #[tokio::test]
538 async fn test_remove_metadata_propagates_backend_error() {
539 let mut mock = MockMetadataBackend::new();
540 mock.expect_remove_metadata()
541 .once()
542 .returning(|_| Err("delete failed".into()));
543
544 let manager = P2pStateManager::with_backend(Arc::new(mock));
545 assert!(manager.remove_metadata("abc123def456abcd").await.is_err());
546 }
547
548 #[tokio::test]
549 async fn test_list_sources_calls_backend() {
550 let mut mock = MockMetadataBackend::new();
551 mock.expect_list_sources()
552 .once()
553 .returning(|| Ok(vec![("src1".to_string(), "model-a".to_string())]));
554
555 let manager = P2pStateManager::with_backend(Arc::new(mock));
556 let result = manager.list_sources().await.expect("list_sources failed");
557 assert_eq!(result.len(), 1);
558 assert_eq!(result[0].0, "src1");
559 assert_eq!(result[0].1, "model-a");
560 }
561
562 #[tokio::test]
563 async fn test_update_worker_status_stores_correct_status() {
564 let mut mock = MockMetadataBackend::new();
565 mock.expect_update_status()
566 .withf(|source_id, worker_id, worker_rank, status, _updated_at| {
567 source_id == "abc123def456abcd"
568 && worker_id == "test-instance"
569 && *worker_rank == 7
570 && *status == SourceStatus::Ready
571 })
572 .once()
573 .returning(|_, _, _, _, _| Ok(()));
574
575 let manager = P2pStateManager::with_backend(Arc::new(mock));
576 manager
577 .update_worker_status("abc123def456abcd", "test-instance", 7, SourceStatus::Ready)
578 .await
579 .expect("update_worker_status failed");
580 }
581}