use modelexpress_common::grpc::p2p::SourceIdentity;
use sha2::{Digest, Sha256};
use std::collections::BTreeMap;
/// Derives the short identifier for a source from its identity.
///
/// The identity is rendered with `canonical_json`, hashed with SHA-256, and
/// the first 16 lowercase hexadecimal characters of the digest are returned.
pub fn compute_mx_source_id(identity: &SourceIdentity) -> String {
    // Number of leading hex characters of the digest kept as the ID.
    const ID_HEX_LEN: usize = 16;

    let hash = Sha256::digest(canonical_json(identity).as_bytes());
    let mut hex = format!("{:x}", hash);
    // The hex rendering is pure ASCII, so cutting at a byte offset cannot
    // land inside a multi-byte character.
    hex.truncate(ID_HEX_LEN);
    hex
}
/// Checks that an identity carries the fields required to compute an ID.
///
/// # Errors
///
/// Returns an error message when `model_name` is the empty string. No other
/// field is inspected.
pub fn validate_identity(identity: &SourceIdentity) -> Result<(), String> {
    if identity.model_name.is_empty() {
        Err(String::from("identity.model_name is required"))
    } else {
        Ok(())
    }
}
/// Renders an identity as a normalized JSON string suitable for hashing.
///
/// Normalization rules:
/// - `mx_version`, `model_name`, `dtype`, `quantization` and `revision` are
///   lowercased.
/// - `mx_source_type`, `backend_framework` and the three parallelism sizes
///   are emitted unchanged.
/// - Every `extra_parameters` key and value is lowercased. When two keys
///   differ only by case, the entry whose *original* key sorts first wins
///   (e.g. `"Foo"` beats `"foo"`), independent of the source map's own
///   iteration order.
///
/// NOTE(review): the order of the top-level keys in the output is whatever
/// `serde_json` produces for this build's configuration; the pinned hashes in
/// the test module are what guard it — confirm before changing features.
fn canonical_json(identity: &SourceIdentity) -> String {
    // Borrow the entries into an ordered map so they are visited in ascending
    // original-key order regardless of how the source map iterates.
    let ordered: BTreeMap<&String, &String> = identity.extra_parameters.iter().collect();

    let mut normalized_extra: BTreeMap<String, String> = BTreeMap::new();
    for (key, value) in ordered {
        // First writer wins: a later key that collides after lowercasing is
        // dropped rather than overwriting the earlier one.
        normalized_extra
            .entry(key.to_lowercase())
            .or_insert_with(|| value.to_lowercase());
    }

    let mx_version = identity.mx_version.to_lowercase();
    let model_name = identity.model_name.to_lowercase();
    let dtype = identity.dtype.to_lowercase();
    let quantization = identity.quantization.to_lowercase();
    let revision = identity.revision.to_lowercase();

    let document = serde_json::json!({
        "mx_version": mx_version,
        "mx_source_type": identity.mx_source_type,
        "model_name": model_name,
        "backend_framework": identity.backend_framework,
        "tensor_parallel_size": identity.tensor_parallel_size,
        "pipeline_parallel_size": identity.pipeline_parallel_size,
        "expert_parallel_size": identity.expert_parallel_size,
        "dtype": dtype,
        "quantization": quantization,
        "extra_parameters": normalized_extra,
        "revision": revision,
    });
    document.to_string()
}
#[cfg(test)]
// Permits `expect` in test code; nothing in this module currently calls it.
#[allow(clippy::expect_used)]
mod tests {
    use super::*;

    /// Baseline identity shared by the tests; each test overrides the fields
    /// it cares about.
    fn base_identity() -> SourceIdentity {
        SourceIdentity {
            mx_version: String::from("0.4.0"),
            mx_source_type: 0,
            model_name: String::from("deepseek-ai/DeepSeek-V3"),
            backend_framework: 1,
            tensor_parallel_size: 8,
            pipeline_parallel_size: 1,
            expert_parallel_size: 0,
            dtype: String::from("bfloat16"),
            quantization: String::new(),
            extra_parameters: Default::default(),
            revision: String::new(),
        }
    }

    /// ID of the unmodified baseline identity.
    fn base_id() -> String {
        compute_mx_source_id(&base_identity())
    }

    /// Baseline identity with the given extra parameters inserted in slice
    /// order.
    fn with_extras(pairs: &[(&str, &str)]) -> SourceIdentity {
        let mut identity = base_identity();
        for &(key, value) in pairs {
            identity
                .extra_parameters
                .insert(key.to_string(), value.to_string());
        }
        identity
    }

    /// Baseline identity with `revision` replaced.
    fn with_revision(revision: &str) -> SourceIdentity {
        SourceIdentity {
            revision: revision.to_string(),
            ..base_identity()
        }
    }

    #[test]
    fn test_id_is_16_chars() {
        let id = base_id();
        assert_eq!(id.len(), 16);
        assert!(id.bytes().all(|b| b.is_ascii_hexdigit()));
    }

    #[test]
    fn test_deterministic() {
        let first = base_id();
        let second = base_id();
        assert_eq!(first, second);
    }

    #[test]
    fn test_case_insensitive() {
        let upper = SourceIdentity {
            model_name: "DEEPSEEK-AI/DEEPSEEK-V3".to_string(),
            dtype: "BFLOAT16".to_string(),
            ..base_identity()
        };
        assert_eq!(base_id(), compute_mx_source_id(&upper));
    }

    #[test]
    fn test_different_tp_gives_different_id() {
        let tp4 = SourceIdentity {
            tensor_parallel_size: 4,
            ..base_identity()
        };
        assert_ne!(base_id(), compute_mx_source_id(&tp4));
    }

    #[test]
    fn test_different_dtype_gives_different_id() {
        let fp8 = SourceIdentity {
            dtype: "float8_e4m3fn".to_string(),
            ..base_identity()
        };
        assert_ne!(base_id(), compute_mx_source_id(&fp8));
    }

    #[test]
    fn test_different_source_type_gives_different_id() {
        let lora = SourceIdentity {
            mx_source_type: 1,
            ..base_identity()
        };
        assert_ne!(base_id(), compute_mx_source_id(&lora));
    }

    #[test]
    fn test_extra_parameters_sorted() {
        // Same entries, opposite insertion order.
        let a = with_extras(&[("z_key", "val"), ("a_key", "val")]);
        let b = with_extras(&[("a_key", "val"), ("z_key", "val")]);
        assert_eq!(compute_mx_source_id(&a), compute_mx_source_id(&b));
    }

    #[test]
    fn test_different_revision_gives_different_id() {
        let pinned = with_revision("abc123def4567890");
        assert_ne!(base_id(), compute_mx_source_id(&pinned));
    }

    #[test]
    fn test_revision_case_insensitive() {
        let upper = with_revision("ABC123DEF4567890");
        let lower = with_revision("abc123def4567890");
        assert_eq!(compute_mx_source_id(&upper), compute_mx_source_id(&lower));
    }

    // The three tests below pin exact IDs. Going by their names, the expected
    // values are meant to agree with a Python implementation of the same
    // scheme — confirm against that implementation before changing them.

    #[test]
    fn test_python_cross_check_base_identity() {
        assert_eq!(base_id(), "e2438ef16adcf628");
    }

    #[test]
    fn test_python_cross_check_with_revision() {
        let pinned = with_revision("abc123def4567890");
        assert_eq!(compute_mx_source_id(&pinned), "7b7803769825576e");
    }

    #[test]
    fn test_python_cross_check_case_colliding_extra() {
        // Two keys that collapse to the same lowercase key.
        let id = with_extras(&[("Foo", "a"), ("foo", "b")]);
        assert_eq!(compute_mx_source_id(&id), "deecf6684507f09c");
    }

    #[test]
    fn test_validate_requires_model_name() {
        let id = SourceIdentity {
            model_name: String::new(),
            ..base_identity()
        };
        assert!(validate_identity(&id).is_err());
    }

    #[test]
    fn test_validate_passes() {
        assert!(validate_identity(&base_identity()).is_ok());
    }
}