Skip to main content

modelexpress_server/metadata_backend/
redis.rs

1// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! Redis backend for P2P model metadata storage.
5//!
6//! Storage layout:
//!   `mx:source:{source_id}`               — Redis Hash; field `__attributes__` stores
//!                                            JSON-serialized SourceAttributesJson (once
//!                                            per source); every other field is a worker_id
//!                                            whose value is that worker's rank (as a string).
11//!   `mx:source:{source_id}:{worker_id}` — Redis Hash; field = worker_rank (string),
12//!                                            value = JSON-serialized WorkerRecordJson.
13//!
14//! Global listing uses SCAN with pattern `mx:source:????????????????` (16-char source IDs)
15//! to enumerate source index keys without a separate secondary index.
16
17use super::{MetadataBackend, MetadataResult, ModelMetadataRecord, TensorRecord, WorkerRecord};
18use async_trait::async_trait;
19use modelexpress_common::grpc::p2p::WorkerMetadata;
20use modelexpress_common::grpc::p2p::{SourceIdentity, SourceStatus};
21use redis::AsyncCommands;
22use redis::aio::ConnectionManager;
23use serde::{Deserialize, Serialize};
24use std::sync::Arc;
25use tokio::sync::RwLock;
26use tracing::{debug, info};
27
/// Redis key layout shared by every backend operation.
///
/// Source index keys are `SOURCE_PREFIX` followed by a 16-character source id;
/// per-worker keys append `:{worker_id}` to the source index key.
mod keys {
    /// Hash field inside a source index key that holds the serialized
    /// `SourceAttributesJson`; every other field of that hash is a worker id.
    pub const ATTRIBUTES_FIELD: &str = "__attributes__";

    /// Common prefix of every key written by this backend.
    pub const SOURCE_PREFIX: &str = "mx:source:";

    /// Glob for `SCAN ... MATCH`: the prefix plus exactly sixteen `?` wildcards, so it
    /// selects source index keys (`mx:source:{16-char-id}`) but not the longer
    /// per-worker keys (`mx:source:{id}:{worker_id}`).
    pub const SOURCE_SCAN_PATTERN: &str = "mx:source:????????????????";
}
36
/// All scalar fields of a `SourceIdentity`, stored once per source as JSON in the
/// `__attributes__` field of the `mx:source:{source_id}` index hash.
///
/// `extra_parameters` from `SourceIdentity` is not copied into this struct (see the
/// `From<&SourceIdentity>` impl). Every field except `model_name` is
/// `#[serde(default)]`, so JSON that only carries `model_name` (older records)
/// still deserializes, with zero/empty values for the rest.
/// NOTE: serde serializes fields in declaration order, so reordering them changes the
/// stored JSON text.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
struct SourceAttributesJson {
    /// Model name; the only field that must be present when deserializing.
    pub model_name: String,
    #[serde(default)]
    pub mx_version: String,
    /// Raw `i32` copied from `SourceIdentity::mx_source_type` — presumably a protobuf
    /// enum discriminant; confirm against the proto definition.
    #[serde(default)]
    pub mx_source_type: i32,
    /// Raw `i32` copied from `SourceIdentity::backend_framework` — presumably a
    /// protobuf enum discriminant; confirm.
    #[serde(default)]
    pub backend_framework: i32,
    #[serde(default)]
    pub tensor_parallel_size: u32,
    #[serde(default)]
    pub pipeline_parallel_size: u32,
    #[serde(default)]
    pub expert_parallel_size: u32,
    #[serde(default)]
    pub dtype: String,
    #[serde(default)]
    pub quantization: String,
}
58
59impl From<&SourceIdentity> for SourceAttributesJson {
60    fn from(id: &SourceIdentity) -> Self {
61        Self {
62            model_name: id.model_name.clone(),
63            mx_version: id.mx_version.clone(),
64            mx_source_type: id.mx_source_type,
65            backend_framework: id.backend_framework,
66            tensor_parallel_size: id.tensor_parallel_size,
67            pipeline_parallel_size: id.pipeline_parallel_size,
68            expert_parallel_size: id.expert_parallel_size,
69            dtype: id.dtype.clone(),
70            quantization: id.quantization.clone(),
71        }
72    }
73}
74
75/// Scan Redis for all keys matching `pattern`, iterating through all SCAN cursors.
76async fn scan_keys(conn: &mut ConnectionManager, pattern: &str) -> MetadataResult<Vec<String>> {
77    let mut all_keys = Vec::new();
78    let mut cursor: u64 = 0;
79    loop {
80        let (next_cursor, batch): (u64, Vec<String>) = redis::cmd("SCAN")
81            .arg(cursor)
82            .arg("MATCH")
83            .arg(pattern)
84            .arg("COUNT")
85            .arg(100)
86            .query_async(conn)
87            .await?;
88        all_keys.extend(batch);
89        cursor = next_cursor;
90        if cursor == 0 {
91            break;
92        }
93    }
94    Ok(all_keys)
95}
96
/// Serializable version of TensorRecord for Redis storage.
///
/// NOTE: addr and size are serialized as strings to avoid Lua cjson precision issues
/// (Lua numbers are doubles, which cannot hold every `u64` exactly). On read they are
/// accepted either as a string or as a JSON number.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct TensorRecordJson {
    /// Tensor name (the tests use names like `model.layers.0.weight`).
    pub name: String,
    /// Tensor address; presumably a device memory pointer — confirm with the producer
    /// of `TensorRecord`.
    #[serde(
        serialize_with = "serialize_u64_as_string",
        deserialize_with = "deserialize_u64_from_any"
    )]
    pub addr: u64,
    /// Tensor size; presumably in bytes — confirm.
    #[serde(
        serialize_with = "serialize_u64_as_string",
        deserialize_with = "deserialize_u64_from_any"
    )]
    pub size: u64,
    /// Device index the tensor lives on (presumably a GPU ordinal — confirm).
    pub device_id: u32,
    /// Element type as a string (the tests use `bfloat16` / `float16`).
    pub dtype: String,
}
115
116fn serialize_u64_as_string<S>(value: &u64, serializer: S) -> Result<S::Ok, S::Error>
117where
118    S: serde::Serializer,
119{
120    serializer.serialize_str(&value.to_string())
121}
122
123fn deserialize_u64_from_any<'de, D>(deserializer: D) -> Result<u64, D::Error>
124where
125    D: serde::Deserializer<'de>,
126{
127    use serde::de::{self, Visitor};
128
129    struct U64Visitor;
130
131    impl<'de> Visitor<'de> for U64Visitor {
132        type Value = u64;
133
134        fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
135            formatter.write_str("a u64 as string or number")
136        }
137
138        fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E> {
139            Ok(value)
140        }
141
142        fn visit_i64<E>(self, value: i64) -> Result<Self::Value, E>
143        where
144            E: de::Error,
145        {
146            u64::try_from(value).map_err(|_| E::custom("negative value"))
147        }
148
149        fn visit_f64<E>(self, value: f64) -> Result<Self::Value, E>
150        where
151            E: de::Error,
152        {
153            // Handle floats from cjson (the problematic case)
154            Ok(value as u64)
155        }
156
157        fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
158        where
159            E: de::Error,
160        {
161            value.parse::<u64>().map_err(de::Error::custom)
162        }
163    }
164
165    deserializer.deserialize_any(U64Visitor)
166}
167
168impl From<TensorRecord> for TensorRecordJson {
169    fn from(record: TensorRecord) -> Self {
170        Self {
171            name: record.name,
172            addr: record.addr,
173            size: record.size,
174            device_id: record.device_id,
175            dtype: record.dtype,
176        }
177    }
178}
179
180impl From<TensorRecordJson> for TensorRecord {
181    fn from(json: TensorRecordJson) -> Self {
182        Self {
183            name: json.name,
184            addr: json.addr,
185            size: json.size,
186            device_id: json.device_id,
187            dtype: json.dtype,
188        }
189    }
190}
191
/// Serializable version of WorkerRecord stored as a hash field value.
///
/// Stored in the `mx:source:{source_id}:{worker_id}` hash under the field named by the
/// decimal `worker_rank`. Every field except `worker_rank` and `tensors` is
/// `#[serde(default)]`, so JSON written before those fields existed still
/// deserializes. Unknown fields (e.g. a legacy `model_name`) are ignored by serde.
/// NOTE: serde serializes fields in declaration order, so reordering them changes the
/// stored JSON text.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct WorkerRecordJson {
    pub worker_rank: u32,
    /// Explicit backend type discriminator ("nixl", "transfer_engine", "none").
    /// `None` for JSON written without this field; it is passed as-is to
    /// `BackendMetadataRecord::from_flat`.
    #[serde(default)]
    pub backend_type: Option<String>,
    /// NIXL backend metadata bytes; empty for other backends. serde_json writes a
    /// `Vec<u8>` as an array of numbers.
    #[serde(default)]
    pub nixl_metadata: Vec<u8>,
    /// Session id for the TransferEngine backend; `None` for other backends.
    #[serde(default)]
    pub transfer_engine_session_id: Option<String>,
    pub tensors: Vec<TensorRecordJson>,
    /// Status as a raw `i32`; `update_status` stores `SourceStatus as i32` here.
    #[serde(default)]
    pub status: i32,
    /// Last-update timestamp written by `update_status`. Units are not visible here;
    /// the tests use a millisecond-looking value — confirm.
    #[serde(default)]
    pub updated_at: i64,
    /// P2P: NIXL listen thread endpoint
    #[serde(default)]
    pub metadata_endpoint: String,
    /// P2P: NIXL agent name
    #[serde(default)]
    pub agent_name: String,
    /// P2P: Worker gRPC endpoint for tensor manifest
    #[serde(default)]
    pub worker_grpc_endpoint: String,
}
218
219impl WorkerRecordJson {
220    fn from_worker_record(record: WorkerRecord) -> Self {
221        let backend_type = record.backend_metadata.backend_type_str().to_string();
222        let (nixl_metadata, transfer_engine_session_id) = match record.backend_metadata {
223            super::BackendMetadataRecord::Nixl(data) => (data, None),
224            super::BackendMetadataRecord::TransferEngine(sid) => (Vec::new(), Some(sid)),
225            super::BackendMetadataRecord::None => (Vec::new(), None),
226        };
227        Self {
228            worker_rank: record.worker_rank,
229            backend_type: Some(backend_type),
230            nixl_metadata,
231            transfer_engine_session_id,
232            tensors: record
233                .tensors
234                .into_iter()
235                .map(TensorRecordJson::from)
236                .collect(),
237            status: record.status,
238            updated_at: record.updated_at,
239            metadata_endpoint: record.metadata_endpoint,
240            agent_name: record.agent_name,
241            worker_grpc_endpoint: record.worker_grpc_endpoint,
242        }
243    }
244}
245
246impl From<WorkerRecordJson> for WorkerRecord {
247    fn from(json: WorkerRecordJson) -> Self {
248        Self {
249            worker_rank: json.worker_rank,
250            backend_metadata: super::BackendMetadataRecord::from_flat(
251                json.nixl_metadata,
252                json.transfer_engine_session_id,
253                json.backend_type.as_deref(),
254            ),
255            tensors: json.tensors.into_iter().map(TensorRecord::from).collect(),
256            status: json.status,
257            updated_at: json.updated_at,
258            metadata_endpoint: json.metadata_endpoint,
259            agent_name: json.agent_name,
260            worker_grpc_endpoint: json.worker_grpc_endpoint,
261        }
262    }
263}
264
265/// Redis backend for metadata storage
266pub struct RedisBackend {
267    redis: Arc<RwLock<Option<ConnectionManager>>>,
268    redis_url: String,
269}
270
271impl RedisBackend {
272    /// Create a new Redis backend
273    pub fn new(redis_url: &str) -> Self {
274        Self {
275            redis: Arc::new(RwLock::new(None)),
276            redis_url: redis_url.to_string(),
277        }
278    }
279
280    /// Get a Redis connection, reconnecting if necessary
281    async fn get_conn(&self) -> MetadataResult<ConnectionManager> {
282        // Fast path: read lock
283        {
284            let guard = self.redis.read().await;
285            if let Some(conn) = guard.as_ref() {
286                return Ok(conn.clone());
287            }
288        }
289
290        // Slow path: write lock with double-check
291        let mut guard = self.redis.write().await;
292        if let Some(conn) = guard.as_ref() {
293            return Ok(conn.clone());
294        }
295
296        let client = redis::Client::open(self.redis_url.as_str())?;
297        let conn = ConnectionManager::new(client).await?;
298        *guard = Some(conn.clone());
299        Ok(conn)
300    }
301}
302
303#[async_trait]
304impl MetadataBackend for RedisBackend {
305    async fn connect(&self) -> MetadataResult<()> {
306        let client = redis::Client::open(self.redis_url.as_str())?;
307        let conn = ConnectionManager::new(client).await?;
308
309        let mut guard = self.redis.write().await;
310        *guard = Some(conn);
311
312        // Redact credentials from URL before logging
313        let safe_url = if self.redis_url.contains('@') {
314            if let Some(at_pos) = self.redis_url.rfind('@') {
315                let prefix = &self.redis_url[..at_pos];
316                let suffix = &self.redis_url[at_pos..];
317                if let Some(colon_pos) = prefix.rfind(':') {
318                    format!("{}:***{}", &prefix[..colon_pos], suffix)
319                } else {
320                    self.redis_url.clone()
321                }
322            } else {
323                self.redis_url.clone()
324            }
325        } else {
326            self.redis_url.clone()
327        };
328        info!("Connected to Redis at {}", safe_url);
329        Ok(())
330    }
331
332    async fn publish_metadata(
333        &self,
334        identity: &SourceIdentity,
335        worker_id: &str,
336        worker: WorkerMetadata,
337    ) -> MetadataResult<()> {
338        let source_id = crate::source_identity::compute_mx_source_id(identity);
339        let mut conn = self.get_conn().await?;
340        let worker_key = format!("{}{}:{}", keys::SOURCE_PREFIX, source_id, worker_id);
341        let source_key = format!("{}{}", keys::SOURCE_PREFIX, source_id);
342
343        let worker_record = WorkerRecord::from(worker);
344        let attr_json = serde_json::to_string(&SourceAttributesJson::from(identity))?;
345        let json = WorkerRecordJson::from_worker_record(worker_record.clone());
346        let value = serde_json::to_string(&json)?;
347
348        let mut pipe = redis::pipe();
349        pipe.hset(&worker_key, worker_record.worker_rank.to_string(), &value);
350        pipe.hset(&source_key, keys::ATTRIBUTES_FIELD, &attr_json);
351        pipe.hset(
352            &source_key,
353            worker_id,
354            worker_record.worker_rank.to_string(),
355        );
356        pipe.exec_async(&mut conn).await?;
357
358        info!(
359            "Published metadata for '{}' (source_id={source_id}, worker_id={}): rank {} ({} tensors)",
360            identity.model_name,
361            worker_id,
362            worker_record.worker_rank,
363            worker_record.tensors.len(),
364        );
365        Ok(())
366    }
367
368    async fn get_metadata(
369        &self,
370        source_id: &str,
371        worker_id: &str,
372    ) -> MetadataResult<Option<ModelMetadataRecord>> {
373        let mut conn = self.get_conn().await?;
374        let key = format!("{}{}:{}", keys::SOURCE_PREFIX, source_id, worker_id);
375
376        let fields: std::collections::HashMap<String, String> = conn.hgetall(&key).await?;
377        if fields.is_empty() {
378            debug!(
379                "No metadata found for source_id={} worker_id={}",
380                source_id, worker_id
381            );
382            return Ok(None);
383        }
384
385        // Fetch model_name from the source index key's __attributes__ field.
386        let source_key = format!("{}{}", keys::SOURCE_PREFIX, source_id);
387        let attr_json: Option<String> = conn.hget(&source_key, keys::ATTRIBUTES_FIELD).await?;
388        let model_name = attr_json
389            .and_then(|v| serde_json::from_str::<SourceAttributesJson>(&v).ok())
390            .map(|a| a.model_name)
391            .unwrap_or_default();
392
393        let mut workers: Vec<WorkerRecord> = Vec::with_capacity(fields.len());
394        for value in fields.values() {
395            let json: WorkerRecordJson = serde_json::from_str(value)?;
396            workers.push(WorkerRecord::from(json));
397        }
398        workers.sort_by_key(|w| w.worker_rank);
399
400        debug!(
401            "Retrieved metadata for source_id={} worker_id={}: {} workers",
402            source_id,
403            worker_id,
404            workers.len()
405        );
406
407        Ok(Some(ModelMetadataRecord {
408            source_id: source_id.to_string(),
409            worker_id: worker_id.to_string(),
410            model_name,
411            workers,
412            published_at: 0,
413        }))
414    }
415
416    async fn list_workers(
417        &self,
418        source_id: Option<String>,
419        status_filter: Option<SourceStatus>,
420    ) -> MetadataResult<Vec<super::SourceInstanceInfo>> {
421        let mut conn = self.get_conn().await?;
422
423        // Collect source_ids to query
424        let source_ids: Vec<String> = if let Some(sid) = source_id {
425            vec![sid]
426        } else {
427            scan_keys(&mut conn, keys::SOURCE_SCAN_PATTERN)
428                .await?
429                .into_iter()
430                .map(|k| k[keys::SOURCE_PREFIX.len()..].to_string())
431                .collect()
432        };
433
434        let mut result = Vec::new();
435
436        for sid in &source_ids {
437            let source_key = format!("{}{}", keys::SOURCE_PREFIX, sid);
438            let instance_map: std::collections::HashMap<String, String> =
439                conn.hgetall(&source_key).await?;
440
441            let model_name = instance_map
442                .get(keys::ATTRIBUTES_FIELD)
443                .and_then(|v| serde_json::from_str::<SourceAttributesJson>(v).ok())
444                .map(|a| a.model_name)
445                .unwrap_or_default();
446
447            for (iid, rank_str) in instance_map
448                .iter()
449                .filter(|(k, _)| k.as_str() != keys::ATTRIBUTES_FIELD)
450            {
451                let worker_rank: u32 = rank_str.parse().unwrap_or(0);
452                let worker_key = format!("{}{}:{}", keys::SOURCE_PREFIX, sid, iid);
453                let fields: std::collections::HashMap<String, String> =
454                    conn.hgetall(&worker_key).await?;
455                if fields.is_empty() {
456                    continue;
457                }
458
459                if let Some(required_status) = status_filter {
460                    let matches = fields.values().any(|v| {
461                        serde_json::from_str::<WorkerRecordJson>(v)
462                            .map(|j| j.status == required_status as i32)
463                            .unwrap_or(false)
464                    });
465                    if !matches {
466                        continue;
467                    }
468                }
469
470                let (status, updated_at) = fields
471                    .get(&worker_rank.to_string())
472                    .and_then(|v| serde_json::from_str::<WorkerRecordJson>(v).ok())
473                    .map(|j| (j.status, j.updated_at))
474                    .unwrap_or((0, 0));
475
476                result.push(super::SourceInstanceInfo {
477                    source_id: sid.clone(),
478                    worker_id: iid.to_string(),
479                    model_name: model_name.clone(),
480                    worker_rank,
481                    status,
482                    updated_at,
483                });
484            }
485        }
486
487        Ok(result)
488    }
489
490    async fn remove_metadata(&self, source_id: &str) -> MetadataResult<()> {
491        let mut conn = self.get_conn().await?;
492        let source_key = format!("{}{}", keys::SOURCE_PREFIX, source_id);
493
494        let instance_map: std::collections::HashMap<String, String> =
495            conn.hgetall(&source_key).await?;
496
497        let mut pipe = redis::pipe();
498        for iid in instance_map
499            .keys()
500            .filter(|k| k.as_str() != keys::ATTRIBUTES_FIELD)
501        {
502            let worker_key = format!("{}{}:{}", keys::SOURCE_PREFIX, source_id, iid);
503            pipe.del(worker_key);
504        }
505        pipe.del(&source_key);
506
507        pipe.exec_async(&mut conn).await?;
508        info!("Removed metadata for source_id={}", source_id);
509        Ok(())
510    }
511
512    async fn remove_worker(&self, source_id: &str, worker_id: &str) -> MetadataResult<()> {
513        let mut conn = self.get_conn().await?;
514        let source_key = format!("{}{}", keys::SOURCE_PREFIX, source_id);
515        let worker_key = format!("{}{}:{}", keys::SOURCE_PREFIX, source_id, worker_id);
516
517        let mut pipe = redis::pipe();
518        pipe.del(&worker_key);
519        pipe.hdel(&source_key, worker_id);
520        pipe.exec_async(&mut conn).await?;
521
522        info!(
523            "Removed worker '{}' from source_id={}",
524            worker_id, source_id
525        );
526        Ok(())
527    }
528
529    async fn list_sources(&self) -> MetadataResult<Vec<(String, String)>> {
530        let mut conn = self.get_conn().await?;
531        let source_keys = scan_keys(&mut conn, keys::SOURCE_SCAN_PATTERN).await?;
532
533        let mut sources = Vec::new();
534        for key in source_keys {
535            let source_id = key[keys::SOURCE_PREFIX.len()..].to_string();
536            let attr_json: Option<String> = conn.hget(&key, keys::ATTRIBUTES_FIELD).await?;
537            if let Some(json) = attr_json {
538                let model_name = serde_json::from_str::<SourceAttributesJson>(&json)
539                    .map(|a| a.model_name)
540                    .unwrap_or_default();
541                sources.push((source_id, model_name));
542            }
543        }
544        Ok(sources)
545    }
546
547    async fn update_status(
548        &self,
549        source_id: &str,
550        worker_id: &str,
551        worker_rank: u32,
552        status: SourceStatus,
553        updated_at: i64,
554    ) -> MetadataResult<()> {
555        let mut conn = self.get_conn().await?;
556        let key = format!("{}{}:{}", keys::SOURCE_PREFIX, source_id, worker_id);
557        let field = worker_rank.to_string();
558
559        let value: Option<String> = conn.hget(&key, &field).await?;
560        let json_str = value.ok_or_else(|| {
561            format!(
562                "update_status: rank {} not found in source '{}' worker '{}'",
563                worker_rank, source_id, worker_id
564            )
565        })?;
566
567        let mut record: WorkerRecordJson = serde_json::from_str(&json_str)?;
568        record.status = status as i32;
569        record.updated_at = updated_at;
570
571        let updated = serde_json::to_string(&record)?;
572        conn.hset::<_, _, _, ()>(&key, &field, &updated).await?;
573
574        debug!(
575            "Updated status for source '{}' worker '{}' rank {} -> {}",
576            source_id, worker_id, worker_rank, status as i32
577        );
578        Ok(())
579    }
580}
581
#[cfg(test)]
// Tests should fail loudly with context, so `expect` is allowed here.
#[allow(clippy::expect_used)]
mod tests {
    use super::*;

    // ── Key layout ──────────────────────────────────────────────────────────

    #[test]
    fn test_scan_pattern_matches_prefix_plus_16_chars() {
        assert!(keys::SOURCE_SCAN_PATTERN.starts_with(keys::SOURCE_PREFIX));
        assert_eq!(
            keys::SOURCE_SCAN_PATTERN.len(),
            keys::SOURCE_PREFIX.len() + 16
        );
    }

    // ── TensorRecordJson serialization ──────────────────────────────────────

    #[test]
    fn test_tensor_record_json_roundtrip() {
        let record = TensorRecord {
            name: "model.layers.0.weight".to_string(),
            addr: 0x7f00_0000_0000,
            size: 1_073_741_824,
            device_id: 3,
            dtype: "bfloat16".to_string(),
        };
        let json_record = TensorRecordJson::from(record.clone());
        let json = serde_json::to_string(&json_record).expect("serialize");

        // addr and size must be serialized as JSON strings, not numbers.
        assert!(
            json.contains(&format!(r#""addr":"{}""#, record.addr)),
            "addr not string-encoded: {json}"
        );
        assert!(
            json.contains(&format!(r#""size":"{}""#, record.size)),
            "size not string-encoded: {json}"
        );
        let parsed: TensorRecordJson = serde_json::from_str(&json).expect("deserialize");
        let back = TensorRecord::from(parsed);

        assert_eq!(back.name, record.name);
        assert_eq!(back.addr, record.addr);
        assert_eq!(back.size, record.size);
        assert_eq!(back.device_id, record.device_id);
        assert_eq!(back.dtype, record.dtype);
    }

    #[test]
    fn test_deserialize_u64_from_string() {
        let json = r#"{"name":"w","addr":"139948187451390","size":"134217728","device_id":0,"dtype":"f16"}"#;
        let t: TensorRecordJson = serde_json::from_str(json).expect("parse string");
        assert_eq!(t.addr, 139948187451390);
        assert_eq!(t.size, 134217728);
    }

    #[test]
    fn test_deserialize_u64_from_number() {
        let json = r#"{"name":"w","addr":1234567890,"size":4096,"device_id":0,"dtype":"f16"}"#;
        let t: TensorRecordJson = serde_json::from_str(json).expect("parse number");
        assert_eq!(t.addr, 1234567890);
    }

    #[test]
    fn test_deserialize_u64_from_float() {
        // cjson can emit floats for large integers
        let json = r#"{"name":"w","addr":1048576.0,"size":4096.0,"device_id":0,"dtype":"f16"}"#;
        let t: TensorRecordJson = serde_json::from_str(json).expect("parse float");
        assert_eq!(t.addr, 1048576);
    }

    #[test]
    fn test_deserialize_u64_rejects_negative_integer() {
        let json = r#"{"name":"w","addr":-1,"size":4096,"device_id":0,"dtype":"f16"}"#;
        assert!(serde_json::from_str::<TensorRecordJson>(json).is_err());
    }

    #[test]
    fn test_deserialize_u64_rejects_non_decimal_string() {
        let json = r#"{"name":"w","addr":"0xdead","size":"4096","device_id":0,"dtype":"f16"}"#;
        assert!(serde_json::from_str::<TensorRecordJson>(json).is_err());
    }

    // ── WorkerRecordJson serialization ──────────────────────────────────────

    #[test]
    fn test_worker_record_json_roundtrip_with_status() {
        let record = WorkerRecord {
            worker_rank: 2,
            backend_metadata: super::super::BackendMetadataRecord::Nixl(vec![
                0xde, 0xad, 0xbe, 0xef,
            ]),
            tensors: vec![TensorRecord {
                name: "t".to_string(),
                addr: 0x1000,
                size: 512,
                device_id: 2,
                dtype: "float16".to_string(),
            }],
            status: 2, // SOURCE_STATUS_READY
            updated_at: 1_700_000_000_000,
            metadata_endpoint: String::new(),
            agent_name: String::new(),
            worker_grpc_endpoint: String::new(),
        };

        let json_record = WorkerRecordJson::from_worker_record(record.clone());
        let json = serde_json::to_string(&json_record).expect("serialize");
        let parsed: WorkerRecordJson = serde_json::from_str(&json).expect("deserialize");
        let back = WorkerRecord::from(parsed);

        assert_eq!(back.worker_rank, record.worker_rank);
        assert_eq!(back.backend_metadata, record.backend_metadata);
        assert_eq!(back.status, record.status);
        assert_eq!(back.updated_at, record.updated_at);
        assert_eq!(back.tensors.len(), 1);
    }

    #[test]
    fn test_worker_record_json_backward_compat_missing_status() {
        // Records written before status/updated_at fields existed must default to 0.
        // model_name field (removed) is silently ignored by serde.
        let json = r#"{"worker_rank":0,"model_name":"m","nixl_metadata":[],"tensors":[]}"#;
        let parsed: WorkerRecordJson = serde_json::from_str(json).expect("parse legacy");
        assert_eq!(parsed.status, 0);
        assert_eq!(parsed.updated_at, 0);
    }

    // ── SourceAttributesJson ────────────────────────────────────────────────

    fn test_identity() -> modelexpress_common::grpc::p2p::SourceIdentity {
        modelexpress_common::grpc::p2p::SourceIdentity {
            mx_version: "0.3.0".to_string(),
            mx_source_type: 0,
            model_name: "deepseek-ai/DeepSeek-V3".to_string(),
            backend_framework: 1,
            tensor_parallel_size: 8,
            pipeline_parallel_size: 2,
            expert_parallel_size: 4,
            dtype: "bfloat16".to_string(),
            quantization: "fp8".to_string(),
            extra_parameters: Default::default(),
        }
    }

    #[test]
    fn test_source_attributes_from_identity() {
        let id = test_identity();
        let attr = SourceAttributesJson::from(&id);

        assert_eq!(attr.model_name, "deepseek-ai/DeepSeek-V3");
        assert_eq!(attr.mx_version, "0.3.0");
        assert_eq!(attr.tensor_parallel_size, 8);
        assert_eq!(attr.pipeline_parallel_size, 2);
        assert_eq!(attr.expert_parallel_size, 4);
        assert_eq!(attr.dtype, "bfloat16");
        assert_eq!(attr.quantization, "fp8");
        assert_eq!(attr.backend_framework, 1);
    }

    #[test]
    fn test_source_attributes_json_roundtrip() {
        let id = test_identity();
        let attr = SourceAttributesJson::from(&id);
        let json = serde_json::to_string(&attr).expect("serialize");
        let back: SourceAttributesJson = serde_json::from_str(&json).expect("deserialize");

        assert_eq!(back.model_name, attr.model_name);
        assert_eq!(back.tensor_parallel_size, attr.tensor_parallel_size);
        assert_eq!(back.pipeline_parallel_size, attr.pipeline_parallel_size);
        assert_eq!(back.expert_parallel_size, attr.expert_parallel_size);
        assert_eq!(back.dtype, attr.dtype);
        assert_eq!(back.quantization, attr.quantization);
    }

    #[test]
    fn test_source_attributes_defaults_for_missing_fields() {
        // Old records that only stored model_name should deserialize with zero defaults.
        let json = r#"{"model_name":"my-model"}"#;
        let attr: SourceAttributesJson = serde_json::from_str(json).expect("deserialize");

        assert_eq!(attr.model_name, "my-model");
        assert_eq!(attr.tensor_parallel_size, 0);
        assert_eq!(attr.pipeline_parallel_size, 0);
        assert_eq!(attr.expert_parallel_size, 0);
        assert_eq!(attr.dtype, "");
        assert_eq!(attr.quantization, "");
    }
}