use kube::CustomResource;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
// Spec of the `ModelMetadata` custom resource.
//
// The `kube` attribute below declares the resource as namespaced, in API group
// `modelexpress.nvidia.com` at version `v1alpha1`, with plural `modelmetadatas`, short name
// `mxmeta`, and `ModelMetadataStatus` as its status type.
//
// NOTE(review): plain `//` comments are used deliberately. `///` doc comments on a type that
// derives `JsonSchema` are copied into the generated schema as descriptions, which would
// change the emitted CRD.
#[derive(CustomResource, Clone, Debug, Deserialize, Serialize, JsonSchema)]
#[kube(
group = "modelexpress.nvidia.com",
version = "v1alpha1",
kind = "ModelMetadata",
plural = "modelmetadatas",
shortname = "mxmeta",
namespaced,
status = "ModelMetadataStatus"
)]
pub struct ModelMetadataSpec {
// Name of the model; serialized under the key `modelName`.
#[serde(rename = "modelName")]
pub model_name: String,
}
// Status of a `ModelMetadata` resource (named by `status = "ModelMetadataStatus"` on the spec).
//
// Every field is marked `#[serde(default)]`, so keys missing from the input deserialize to the
// field type's default value instead of failing.
//
// Plain `//` comments on purpose: `///` docs would be copied into the derived JSON schema.
#[derive(Clone, Debug, Default, Deserialize, Serialize, JsonSchema)]
pub struct ModelMetadataStatus {
// Worker details, or `None` when the key is absent.
#[serde(default)]
pub worker: Option<WorkerStatus>,
// Conditions attached to this resource; `set_condition` keeps at most one entry per `type`.
#[serde(default)]
pub conditions: Vec<Condition>,
// Serialized as `observedGeneration`; defaults to 0.
// NOTE(review): presumably the `metadata.generation` last acted upon — confirm with the writer.
#[serde(rename = "observedGeneration", default)]
pub observed_generation: i64,
// Serialized as `publishedAt`.
// NOTE(review): presumably a timestamp string; the format is not fixed in this file.
#[serde(rename = "publishedAt", default)]
pub published_at: Option<String>,
}
// Per-worker section of `ModelMetadataStatus`.
//
// Fields are serialized in camelCase through explicit `rename` attributes. Only `workerRank` is
// required on input; every other field is marked `default` and falls back to its type's default.
//
// Plain `//` comments on purpose: `///` docs would be copied into the derived JSON schema.
#[derive(Clone, Debug, Default, Deserialize, Serialize, JsonSchema)]
pub struct WorkerStatus {
// Rank of the worker; the only field without a serde default.
#[serde(rename = "workerRank")]
pub worker_rank: i32,
// Serialized as `backendType`. NOTE(review): set of valid values is not defined in this file.
#[serde(rename = "backendType", default)]
pub backend_type: Option<String>,
// Serialized as `nixlMetadata`. NOTE(review): encoding of this string is not visible here.
#[serde(rename = "nixlMetadata", default)]
pub nixl_metadata: String,
// Serialized as `transferEngineSessionId`.
#[serde(rename = "transferEngineSessionId", default)]
pub transfer_engine_session_id: Option<String>,
// Serialized as `tensorCount`; defaults to 0.
#[serde(rename = "tensorCount", default)]
pub tensor_count: i32,
// Serialized as `tensorConfigMap`.
// NOTE(review): presumably the name of a ConfigMap holding tensor descriptors — confirm.
#[serde(rename = "tensorConfigMap", default)]
pub tensor_config_map: Option<String>,
// Worker status as a string.
// NOTE(review): presumably one of the names produced by `status_name_from_proto`
// ("Unknown", "Initializing", "Ready", "Stale") — confirm against the writer.
#[serde(default)]
pub status: String,
// Serialized as `updatedAt`. NOTE(review): presumably a timestamp string; format not fixed here.
#[serde(rename = "updatedAt", default)]
pub updated_at: Option<String>,
// Serialized as `metadataEndpoint`.
#[serde(rename = "metadataEndpoint", default)]
pub metadata_endpoint: String,
// Serialized as `agentName`.
#[serde(rename = "agentName", default)]
pub agent_name: String,
// Serialized as `workerGrpcEndpoint`.
#[serde(rename = "workerGrpcEndpoint", default)]
pub worker_grpc_endpoint: String,
}
impl WorkerStatus {
    /// Converts a numeric (proto) worker status into its display name.
    ///
    /// `1`, `2` and `3` map to `"Initializing"`, `"Ready"` and `"Stale"`. Every other value,
    /// including `0`, yields `"Unknown"`.
    pub fn status_name_from_proto(status: i32) -> String {
        let label = match status {
            1 => "Initializing",
            2 => "Ready",
            3 => "Stale",
            // 0 is the explicit "Unknown" code; unrecognised codes share its name.
            _ => "Unknown",
        };
        String::from(label)
    }

    /// Converts a status name back into its numeric (proto) value.
    ///
    /// The comparison is exact and case-sensitive. Any name other than `"Initializing"`,
    /// `"Ready"` or `"Stale"` (including `"Unknown"` and the empty string) returns `0`.
    pub fn status_proto_from_name(name: &str) -> i32 {
        const KNOWN: [(&str, i32); 3] = [("Initializing", 1), ("Ready", 2), ("Stale", 3)];
        KNOWN
            .iter()
            .find(|(label, _)| *label == name)
            .map_or(0, |&(_, code)| code)
    }
}
// One entry of `ModelMetadataStatus::conditions`.
//
// `type` and `status` are required on input; the remaining fields are marked `default` and
// deserialize to `None` when absent.
//
// Plain `//` comments on purpose: `///` docs would be copied into the derived JSON schema.
#[derive(Clone, Debug, Default, Deserialize, Serialize, JsonSchema)]
pub struct Condition {
// Condition type, serialized as `type` (`type` is a Rust keyword, hence the trailing
// underscore). `set_condition` uses it as the lookup key.
#[serde(rename = "type")]
pub type_: String,
// Condition status string; callers in this file pass "True" or "False".
pub status: String,
// Short reason for the current status.
#[serde(default)]
pub reason: Option<String>,
// Free-form message for the current status.
#[serde(default)]
pub message: Option<String>,
// Serialized as `lastTransitionTime`. `set_condition` writes an RFC 3339 UTC timestamp here
// when the condition is created or its `status` changes.
#[serde(rename = "lastTransitionTime", default)]
pub last_transition_time: Option<String>,
}
impl ModelMetadataStatus {
    /// Creates or updates the condition whose `type` equals `type_`.
    ///
    /// `status`, `reason` and `message` are always overwritten with the given values.
    /// `last_transition_time` is set to the current UTC time (RFC 3339) when the condition is
    /// first added or when `status` differs from the stored one; otherwise the previous
    /// timestamp is kept.
    pub fn set_condition(&mut self, type_: &str, status: &str, reason: &str, message: &str) {
        let timestamp = chrono::Utc::now().to_rfc3339();
        match self.conditions.iter_mut().find(|cond| cond.type_ == type_) {
            Some(cond) => {
                // Compare before overwriting so a real status change can be detected.
                let transitioned = cond.status != status;
                cond.status = status.to_owned();
                cond.reason = Some(reason.to_owned());
                cond.message = Some(message.to_owned());
                if transitioned {
                    cond.last_transition_time = Some(timestamp);
                }
            }
            None => self.conditions.push(Condition {
                type_: type_.to_owned(),
                status: status.to_owned(),
                reason: Some(reason.to_owned()),
                message: Some(message.to_owned()),
                last_transition_time: Some(timestamp),
            }),
        }
    }

    /// Sets the `Ready` condition from a numeric (proto) worker status.
    ///
    /// The value `2` produces `Ready=True` with reason `WorkerReady`. Any other value produces
    /// `Ready=False` with reason `Worker<Name>`, where `<Name>` comes from
    /// [`WorkerStatus::status_name_from_proto`].
    pub fn update_ready_condition(&mut self, worker_proto_status: i32) {
        if worker_proto_status != 2 {
            let reason = format!(
                "Worker{}",
                WorkerStatus::status_name_from_proto(worker_proto_status)
            );
            self.set_condition("Ready", "False", &reason, "Worker is not ready");
            return;
        }
        self.set_condition("Ready", "True", "WorkerReady", "Worker is ready");
    }
}
/// JSON representation of a single tensor descriptor.
///
/// Field names are serialized as written (snake_case); no serde renames are applied.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct TensorDescriptorJson {
/// Tensor name (the tests in this file use values such as `model.layers.0.weight`).
pub name: String,
/// Address kept as a string; the tests in this file parse it as a decimal `u64`.
// NOTE(review): presumably a string so full 64-bit values survive JSON tooling — confirm.
pub addr: String,
/// Size kept as a string; the tests in this file parse it as a decimal `u64`.
// NOTE(review): unit is presumably bytes — confirm against the producer.
pub size: String,
/// Identifier of the device associated with the tensor.
pub device_id: u32,
/// Element type name (the tests in this file use `bfloat16` and `float16`).
pub dtype: String,
}
/// Normalises a model name into a lowercase, dash-separated identifier.
///
/// Processing steps:
/// 1. the name is lowercased;
/// 2. `/` and `_` become `-`;
/// 3. every character that is not an ASCII letter or digit, `-` or `.` is dropped;
/// 4. leading and trailing `-` are trimmed (dots are left untouched).
///
/// The result may be empty, for example for `""` or `"///"`.
pub fn sanitize_model_name(model_name: &str) -> String {
    let mut cleaned = String::with_capacity(model_name.len());
    for ch in model_name.to_lowercase().chars() {
        match ch {
            '/' | '_' => cleaned.push('-'),
            '-' | '.' => cleaned.push(ch),
            _ if ch.is_ascii_alphanumeric() => cleaned.push(ch),
            // Anything else (spaces, symbols, non-ASCII) is removed.
            _ => {}
        }
    }
    cleaned.trim_matches('-').to_string()
}
#[cfg(test)]
// `expect` is acceptable in tests: a failed expectation should abort the test with a message.
#[allow(clippy::expect_used)]
mod tests {
    use super::*;

    /// Sleeps briefly so that two consecutive `set_condition` calls observe different
    /// wall-clock timestamps.
    ///
    /// Without this, tests comparing `lastTransitionTime` across back-to-back calls depend on
    /// the resolution of the system clock and can fail (or pass vacuously) on coarse clocks.
    fn wait_for_clock_tick() {
        std::thread::sleep(std::time::Duration::from_millis(2));
    }

    /// Every known status converts proto -> name -> proto without loss.
    #[test]
    fn test_status_roundtrip() {
        for (proto, name) in [
            (0, "Unknown"),
            (1, "Initializing"),
            (2, "Ready"),
            (3, "Stale"),
        ] {
            assert_eq!(WorkerStatus::status_name_from_proto(proto), name);
            assert_eq!(WorkerStatus::status_proto_from_name(name), proto);
        }
    }

    /// The name written for proto value 0 reads back as 0.
    #[test]
    fn test_status_unknown_roundtrip() {
        let written = WorkerStatus::status_name_from_proto(0);
        assert_eq!(written, "Unknown");
        let read_back = WorkerStatus::status_proto_from_name(&written);
        assert_eq!(
            read_back, 0,
            "Unknown status must roundtrip to proto value 0"
        );
    }

    /// Out-of-range proto values fall back to "Unknown".
    #[test]
    fn test_status_name_from_proto_unknown() {
        assert_eq!(WorkerStatus::status_name_from_proto(99), "Unknown");
        assert_eq!(WorkerStatus::status_name_from_proto(4), "Unknown");
    }

    /// Unrecognised names (including wrong case) map to 0.
    #[test]
    fn test_status_proto_from_name_unknown() {
        assert_eq!(WorkerStatus::status_proto_from_name("Unknown"), 0);
        assert_eq!(WorkerStatus::status_proto_from_name(""), 0);
        assert_eq!(WorkerStatus::status_proto_from_name("ready"), 0);
    }

    /// Typical `org/model` names are lowercased and slash-separated parts joined with `-`.
    #[test]
    fn test_sanitize_model_name() {
        assert_eq!(
            sanitize_model_name("deepseek-ai/DeepSeek-V3"),
            "deepseek-ai-deepseek-v3"
        );
        assert_eq!(
            sanitize_model_name("meta-llama/Llama-3.1-70B"),
            "meta-llama-llama-3.1-70b"
        );
        assert_eq!(sanitize_model_name("simple-model"), "simple-model");
    }

    /// Disallowed characters are dropped; underscores become dashes.
    #[test]
    fn test_sanitize_model_name_special_chars() {
        assert_eq!(sanitize_model_name("Llama@3.1+8B"), "llama3.18b");
        assert_eq!(sanitize_model_name("model with spaces"), "modelwithspaces");
        assert_eq!(
            sanitize_model_name("org_name/model_v2"),
            "org-name-model-v2"
        );
    }

    /// Empty and dash-only inputs collapse to the empty string; outer dashes are trimmed.
    #[test]
    fn test_sanitize_model_name_edge_cases() {
        assert_eq!(sanitize_model_name(""), "");
        assert_eq!(sanitize_model_name("///"), "");
        assert_eq!(sanitize_model_name("---"), "");
        assert_eq!(sanitize_model_name("-model-"), "model");
    }

    /// A descriptor survives a JSON round trip and its string fields parse as `u64`.
    #[test]
    fn test_tensor_descriptor_json_roundtrip() {
        let original = TensorDescriptorJson {
            name: "model.layers.0.weight".to_string(),
            addr: "139948187451390".to_string(),
            size: "134217728".to_string(),
            device_id: 0,
            dtype: "bfloat16".to_string(),
        };
        let json = serde_json::to_string(&original).expect("serialize");
        let parsed: TensorDescriptorJson = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(parsed.name, original.name);
        assert_eq!(parsed.addr, original.addr);
        assert_eq!(parsed.size, original.size);
        assert_eq!(parsed.device_id, original.device_id);
        assert_eq!(parsed.dtype, original.dtype);
        let addr: u64 = parsed.addr.parse().expect("addr should parse as u64");
        assert_eq!(addr, 139948187451390);
        let size: u64 = parsed.size.parse().expect("size should parse as u64");
        assert_eq!(size, 134217728);
    }

    /// `u64::MAX` stored as a string round-trips without loss.
    #[test]
    fn test_tensor_descriptor_json_large_values() {
        let desc = TensorDescriptorJson {
            name: "test".to_string(),
            addr: u64::MAX.to_string(),
            size: u64::MAX.to_string(),
            device_id: 7,
            dtype: "float16".to_string(),
        };
        let json = serde_json::to_string(&desc).expect("serialize");
        let parsed: TensorDescriptorJson = serde_json::from_str(&json).expect("deserialize");
        let addr: u64 = parsed.addr.parse().expect("max u64 addr should parse");
        assert_eq!(addr, u64::MAX);
    }

    /// Setting a condition on an empty status appends a fully populated entry.
    #[test]
    fn test_set_condition_inserts_new() {
        let mut status = ModelMetadataStatus::default();
        assert!(status.conditions.is_empty());
        status.set_condition("Ready", "True", "WorkerPublished", "Published");
        assert_eq!(status.conditions.len(), 1);
        let cond = &status.conditions[0];
        assert_eq!(cond.type_, "Ready");
        assert_eq!(cond.status, "True");
        assert_eq!(cond.reason.as_deref(), Some("WorkerPublished"));
        assert_eq!(cond.message.as_deref(), Some("Published"));
        assert!(cond.last_transition_time.is_some());
    }

    /// A status change updates the entry in place and refreshes `lastTransitionTime`.
    #[test]
    fn test_set_condition_updates_existing() {
        let mut status = ModelMetadataStatus::default();
        status.set_condition("Ready", "True", "WorkerPublished", "Published");
        let original_time = status.conditions[0].last_transition_time.clone();
        // Guarantee the second call sees a later timestamp than the first.
        wait_for_clock_tick();
        status.set_condition("Ready", "False", "WorkerStale", "Worker is stale");
        assert_eq!(status.conditions.len(), 1);
        let cond = &status.conditions[0];
        assert_eq!(cond.status, "False");
        assert_eq!(cond.reason.as_deref(), Some("WorkerStale"));
        assert_ne!(
            cond.last_transition_time, original_time,
            "lastTransitionTime must change on status transition"
        );
    }

    /// Re-setting the same status updates reason/message but keeps `lastTransitionTime`.
    #[test]
    fn test_set_condition_same_status_preserves_transition_time() {
        let mut status = ModelMetadataStatus::default();
        status.set_condition("Ready", "True", "WorkerPublished", "Published");
        let original_time = status.conditions[0].last_transition_time.clone();
        // Let the clock advance so an erroneous timestamp refresh would be detectable.
        wait_for_clock_tick();
        status.set_condition("Ready", "True", "StillReady", "Still ready");
        assert_eq!(status.conditions.len(), 1);
        assert_eq!(status.conditions[0].reason.as_deref(), Some("StillReady"));
        assert_eq!(
            status.conditions[0].last_transition_time, original_time,
            "lastTransitionTime must not change when status stays the same"
        );
    }

    /// Proto status 2 yields `Ready=True` with reason `WorkerReady`.
    #[test]
    fn test_update_ready_condition_ready() {
        let mut status = ModelMetadataStatus::default();
        status.update_ready_condition(2);
        assert_eq!(status.conditions.len(), 1);
        let cond = &status.conditions[0];
        assert_eq!(cond.type_, "Ready");
        assert_eq!(cond.status, "True");
        assert_eq!(cond.reason.as_deref(), Some("WorkerReady"));
    }

    /// Every non-ready proto status yields `Ready=False` with a `Worker<Name>` reason.
    #[test]
    fn test_update_ready_condition_not_ready_states() {
        for (proto, expected_reason) in [
            (0, "WorkerUnknown"),
            (1, "WorkerInitializing"),
            (3, "WorkerStale"),
        ] {
            let mut status = ModelMetadataStatus::default();
            status.update_ready_condition(proto);
            assert_eq!(status.conditions.len(), 1);
            let cond = &status.conditions[0];
            assert_eq!(cond.type_, "Ready");
            assert_eq!(cond.status, "False");
            assert_eq!(
                cond.reason.as_deref(),
                Some(expected_reason),
                "proto status {} should produce reason {}",
                proto,
                expected_reason
            );
        }
    }

    /// Walks Initializing -> Ready -> Stale and checks status, reason and transition time.
    #[test]
    fn test_update_ready_condition_transition() {
        let mut status = ModelMetadataStatus::default();

        status.update_ready_condition(1);
        assert_eq!(status.conditions[0].status, "False");
        let time_false = status.conditions[0].last_transition_time.clone();

        // Guarantee the False -> True transition is stamped with a later time.
        wait_for_clock_tick();
        status.update_ready_condition(2);
        assert_eq!(status.conditions[0].status, "True");
        assert_ne!(
            status.conditions[0].last_transition_time, time_false,
            "lastTransitionTime must change on False->True transition"
        );

        status.update_ready_condition(3);
        assert_eq!(status.conditions[0].status, "False");
        assert_eq!(status.conditions[0].reason.as_deref(), Some("WorkerStale"));
    }
}