modelexpress_server/metadata_backend/kubernetes.rs

// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

//! Kubernetes CRD backend for P2P model metadata storage.
//!
//! Each source worker is stored as a `ModelMetadata` custom resource; the worker's
//! tensor descriptors are kept in a companion ConfigMap owned by that resource.

8use super::{MetadataBackend, MetadataResult, ModelMetadataRecord, TensorRecord, WorkerRecord};
9use crate::k8s_types::{ModelMetadata, ModelMetadataSpec, TensorDescriptorJson, WorkerStatus};
10use async_trait::async_trait;
11use base64::{Engine, engine::general_purpose::STANDARD as BASE64};
12use k8s_openapi::api::core::v1::ConfigMap;
13use kube::{
14    Client,
15    api::{Api, ListParams, Patch, PatchParams, PostParams},
16};
17use modelexpress_common::grpc::p2p::{SourceIdentity, SourceStatus, WorkerMetadata};
18use serde_json::json;
19use std::collections::BTreeMap;
20use tracing::{debug, info, warn};
21
22/// Kubernetes backend for metadata storage
23pub struct KubernetesBackend {
24    client: Client,
25    namespace: String,
26}
27
28impl KubernetesBackend {
29    /// Create a new Kubernetes backend
30    pub async fn new(namespace: &str) -> MetadataResult<Self> {
31        let client = Client::try_default().await?;
32        Ok(Self {
33            client,
34            namespace: namespace.to_string(),
35        })
36    }
37
38    /// Get the API handle for ModelMetadata CRD
39    fn model_metadata_api(&self) -> Api<ModelMetadata> {
40        Api::namespaced(self.client.clone(), &self.namespace)
41    }
42
43    /// Get the API handle for ConfigMaps
44    fn configmap_api(&self) -> Api<ConfigMap> {
45        Api::namespaced(self.client.clone(), &self.namespace)
46    }
47
48    /// Create or update a ConfigMap with tensor descriptors for a worker.
49    /// If `owner_uid` and `owner_name` are provided, sets ownerReferences
50    /// so K8s garbage-collects ConfigMaps when the parent CR is deleted.
51    async fn upsert_tensor_configmap(
52        &self,
53        source_id: &str,
54        worker_id: &str,
55        worker_rank: u32,
56        tensors: &[TensorRecord],
57        owner_name: Option<&str>,
58        owner_uid: Option<&str>,
59    ) -> MetadataResult<String> {
60        let cr_name = format!("mx-source-{}-{}", source_id, worker_id);
61        let cm_name = format!("{}-tensors-worker-{}", cr_name, worker_rank);
62
63        // Convert tensors to JSON
64        let tensor_json: Vec<TensorDescriptorJson> = tensors
65            .iter()
66            .map(|t| TensorDescriptorJson {
67                name: t.name.clone(),
68                addr: t.addr.to_string(),
69                size: t.size.to_string(),
70                device_id: t.device_id,
71                dtype: t.dtype.clone(),
72            })
73            .collect();
74
75        let tensors_data = serde_json::to_string_pretty(&tensor_json)?;
76
77        let mut data = BTreeMap::new();
78        data.insert("tensors.json".to_string(), tensors_data);
79
80        let mut labels = BTreeMap::new();
81        labels.insert(
82            "modelexpress.nvidia.com/mx-source-id".to_string(),
83            source_id.to_string(),
84        );
85        labels.insert(
86            "modelexpress.nvidia.com/worker".to_string(),
87            worker_rank.to_string(),
88        );
89
90        let owner_references = match (owner_name, owner_uid) {
91            (Some(name), Some(uid)) => Some(vec![
92                k8s_openapi::apimachinery::pkg::apis::meta::v1::OwnerReference {
93                    api_version: "modelexpress.nvidia.com/v1alpha1".to_string(),
94                    kind: "ModelMetadata".to_string(),
95                    name: name.to_string(),
96                    uid: uid.to_string(),
97                    controller: Some(true),
98                    block_owner_deletion: Some(true),
99                },
100            ]),
101            _ => None,
102        };
103
104        let cm = ConfigMap {
105            metadata: kube::api::ObjectMeta {
106                name: Some(cm_name.clone()),
107                namespace: Some(self.namespace.clone()),
108                labels: Some(labels),
109                owner_references,
110                ..Default::default()
111            },
112            data: Some(data),
113            ..Default::default()
114        };
115
116        let api = self.configmap_api();
117
118        // Try to create, if exists then patch
119        match api.create(&PostParams::default(), &cm).await {
120            Ok(_) => {
121                debug!("Created ConfigMap {} for worker {}", cm_name, worker_rank);
122            }
123            Err(kube::Error::Api(err)) if err.code == 409 => {
124                // Already exists — use merge patch to avoid SSA field manager conflicts
125                api.patch(&cm_name, &PatchParams::default(), &Patch::Merge(&cm))
126                    .await?;
127                debug!("Updated ConfigMap {} for worker {}", cm_name, worker_rank);
128            }
129            Err(e) => return Err(e.into()),
130        }
131
132        Ok(cm_name)
133    }
134
135    /// Read tensor descriptors from a ConfigMap
136    async fn read_tensor_configmap(&self, cm_name: &str) -> MetadataResult<Vec<TensorRecord>> {
137        let api = self.configmap_api();
138        let cm = api.get(cm_name).await?;
139
140        let tensors_json = cm
141            .data
142            .and_then(|d| d.get("tensors.json").cloned())
143            .ok_or("ConfigMap missing tensors.json")?;
144
145        let tensor_descs: Vec<TensorDescriptorJson> = serde_json::from_str(&tensors_json)?;
146
147        let tensors = tensor_descs
148            .into_iter()
149            .map(|t| {
150                let addr = t.addr.parse::<u64>().map_err(|e| {
151                    format!("Invalid tensor addr '{}' for '{}': {}", t.addr, t.name, e)
152                })?;
153                let size = t.size.parse::<u64>().map_err(|e| {
154                    format!("Invalid tensor size '{}' for '{}': {}", t.size, t.name, e)
155                })?;
156                Ok(TensorRecord {
157                    name: t.name,
158                    addr,
159                    size,
160                    device_id: t.device_id,
161                    dtype: t.dtype,
162                })
163            })
164            .collect::<MetadataResult<Vec<_>>>()?;
165
166        Ok(tensors)
167    }
168}
169
170#[async_trait]
171impl MetadataBackend for KubernetesBackend {
172    async fn connect(&self) -> MetadataResult<()> {
173        // Test connection by listing CRDs (will fail if no permissions)
174        let api = self.model_metadata_api();
175        let _ = api.list(&ListParams::default().limit(1)).await?;
176        info!(
177            "Connected to Kubernetes, using namespace '{}'",
178            self.namespace
179        );
180        Ok(())
181    }
182
183    async fn publish_metadata(
184        &self,
185        identity: &SourceIdentity,
186        worker_id: &str,
187        worker: WorkerMetadata,
188    ) -> MetadataResult<()> {
189        let source_id = crate::source_identity::compute_mx_source_id(identity);
190        let source_id = source_id.as_str();
191        let model_name = &identity.model_name;
192        let api = self.model_metadata_api();
193        let cr_name = format!("mx-source-{}-{}", source_id, worker_id);
194        let now = chrono::Utc::now().to_rfc3339();
195
196        let worker_record = WorkerRecord::from(worker);
197
198        // First, ensure the CR exists
199        let existing = api.get_opt(&cr_name).await?;
200
201        if existing.is_none() {
202            let new_cr = ModelMetadata {
203                metadata: kube::api::ObjectMeta {
204                    name: Some(cr_name.clone()),
205                    namespace: Some(self.namespace.clone()),
206                    labels: Some({
207                        let mut labels = BTreeMap::new();
208                        labels.insert(
209                            "modelexpress.nvidia.com/mx-source-id".to_string(),
210                            source_id.to_string(),
211                        );
212                        labels.insert(
213                            "modelexpress.nvidia.com/mx-worker-id".to_string(),
214                            worker_id.to_string(),
215                        );
216                        labels
217                    }),
218                    ..Default::default()
219                },
220                spec: ModelMetadataSpec {
221                    model_name: model_name.to_string(),
222                },
223                status: None,
224            };
225
226            match api.create(&PostParams::default(), &new_cr).await {
227                Ok(_) => {
228                    info!("Created ModelMetadata CR '{}'", cr_name);
229                }
230                Err(kube::Error::Api(err)) if err.code == 409 => {
231                    debug!(
232                        "ModelMetadata CR '{}' already exists, proceeding to update",
233                        cr_name
234                    );
235                }
236                Err(e) => return Err(e.into()),
237            }
238        }
239
240        // Get CR UID for ownerReferences on ConfigMaps
241        let cr = api.get(&cr_name).await?;
242        let owner_uid = cr.metadata.uid.as_deref();
243        let owner_name = cr.metadata.name.as_deref();
244
245        let cm_name = self
246            .upsert_tensor_configmap(
247                source_id,
248                worker_id,
249                worker_record.worker_rank,
250                &worker_record.tensors,
251                owner_name,
252                owner_uid,
253            )
254            .await?;
255
256        let backend_type = worker_record
257            .backend_metadata
258            .backend_type_str()
259            .to_string();
260        let (nixl_metadata, transfer_engine_session_id) = match &worker_record.backend_metadata {
261            super::BackendMetadataRecord::Nixl(data) => (BASE64.encode(data), None),
262            super::BackendMetadataRecord::TransferEngine(sid) => (String::new(), Some(sid.clone())),
263            super::BackendMetadataRecord::None => (String::new(), None),
264        };
265
266        let worker_status = WorkerStatus {
267            worker_rank: worker_record.worker_rank as i32,
268            backend_type: Some(backend_type),
269            nixl_metadata,
270            transfer_engine_session_id,
271            tensor_count: worker_record.tensors.len() as i32,
272            tensor_config_map: Some(cm_name),
273            status: WorkerStatus::status_name_from_proto(worker_record.status),
274            updated_at: Some(now.clone()),
275            metadata_endpoint: worker_record.metadata_endpoint.clone(),
276            agent_name: worker_record.agent_name.clone(),
277            worker_grpc_endpoint: worker_record.worker_grpc_endpoint.clone(),
278        };
279
280        let max_retries: u32 = 5;
281        let mut status_updated = false;
282        for attempt in 0..max_retries {
283            let current = api.get(&cr_name).await?;
284            let resource_version = current.metadata.resource_version.unwrap_or_default();
285            let generation = current.metadata.generation.unwrap_or(0);
286
287            let mut crd_status = current.status.unwrap_or_default();
288            crd_status.update_ready_condition(worker_record.status);
289
290            let status_patch = json!({
291                "metadata": { "resourceVersion": resource_version },
292                "status": {
293                    "worker": worker_status,
294                    "publishedAt": now,
295                    "conditions": crd_status.conditions,
296                    "observedGeneration": generation
297                }
298            });
299
300            match api
301                .patch_status(
302                    &cr_name,
303                    &PatchParams::default(),
304                    &Patch::Merge(&status_patch),
305                )
306                .await
307            {
308                Ok(_) => {
309                    status_updated = true;
310                    break;
311                }
312                Err(kube::Error::Api(err)) if err.code == 409 => {
313                    debug!(
314                        "Conflict updating status for source '{}' instance '{}', retrying ({}/{})",
315                        source_id,
316                        worker_id,
317                        attempt.saturating_add(1),
318                        max_retries
319                    );
320                    tokio::time::sleep(std::time::Duration::from_millis(
321                        100_u64.saturating_mul(u64::from(attempt).saturating_add(1)),
322                    ))
323                    .await;
324                }
325                Err(e) => return Err(e.into()),
326            }
327        }
328
329        if !status_updated {
330            return Err(format!(
331                "Failed to update status for source '{}' instance '{}' after {} retries",
332                source_id, worker_id, max_retries
333            )
334            .into());
335        }
336
337        info!(
338            "Published metadata for '{}' (source_id={}, worker_id={}): rank {} ({} tensors)",
339            model_name,
340            source_id,
341            worker_id,
342            worker_record.worker_rank,
343            worker_record.tensors.len(),
344        );
345
346        Ok(())
347    }
348
349    async fn get_metadata(
350        &self,
351        source_id: &str,
352        worker_id: &str,
353    ) -> MetadataResult<Option<ModelMetadataRecord>> {
354        let api = self.model_metadata_api();
355        let cr_name = format!("mx-source-{}-{}", source_id, worker_id);
356
357        let cr = match api.get_opt(&cr_name).await? {
358            Some(cr) => cr,
359            None => {
360                debug!(
361                    "No ModelMetadata CR found for source_id={} worker_id={}",
362                    source_id, worker_id
363                );
364                return Ok(None);
365            }
366        };
367
368        let status = match cr.status {
369            Some(s) => s,
370            None => {
371                debug!("ModelMetadata CR '{}' has no status", cr_name);
372                return Ok(None);
373            }
374        };
375
376        let mut workers = Vec::new();
377        if let Some(worker_status) = status.worker {
378            let nixl_bytes = if !worker_status.nixl_metadata.is_empty() {
379                BASE64.decode(&worker_status.nixl_metadata).map_err(|e| {
380                    format!(
381                        "Failed to decode NIXL metadata for worker {}: {}",
382                        worker_status.worker_rank, e
383                    )
384                })?
385            } else {
386                Vec::new()
387            };
388            let backend_metadata = super::BackendMetadataRecord::from_flat(
389                nixl_bytes,
390                worker_status.transfer_engine_session_id.clone(),
391                worker_status.backend_type.as_deref(),
392            );
393
394            let tensors = if let Some(cm_name) = &worker_status.tensor_config_map {
395                match self.read_tensor_configmap(cm_name).await {
396                    Ok(t) => t,
397                    Err(e) => {
398                        warn!("Failed to read tensor ConfigMap '{}': {}", cm_name, e);
399                        Vec::new()
400                    }
401                }
402            } else {
403                Vec::new()
404            };
405
406            let status = WorkerStatus::status_proto_from_name(&worker_status.status);
407            let updated_at = worker_status
408                .updated_at
409                .as_deref()
410                .and_then(|s| chrono::DateTime::parse_from_rfc3339(s).ok())
411                .map(|dt| dt.timestamp_millis())
412                .unwrap_or(0);
413
414            workers.push(WorkerRecord {
415                worker_rank: worker_status.worker_rank as u32,
416                backend_metadata,
417                tensors,
418                status,
419                updated_at,
420                metadata_endpoint: worker_status.metadata_endpoint.clone(),
421                agent_name: worker_status.agent_name.clone(),
422                worker_grpc_endpoint: worker_status.worker_grpc_endpoint.clone(),
423            });
424        }
425
426        let published_at = status
427            .published_at
428            .and_then(|s| chrono::DateTime::parse_from_rfc3339(&s).ok())
429            .map(|dt| dt.timestamp())
430            .unwrap_or(0);
431
432        debug!(
433            "Retrieved metadata for source_id={} worker_id={}: {} workers",
434            source_id,
435            worker_id,
436            workers.len()
437        );
438
439        Ok(Some(ModelMetadataRecord {
440            source_id: source_id.to_string(),
441            worker_id: worker_id.to_string(),
442            model_name: cr.spec.model_name.clone(),
443            workers,
444            published_at,
445        }))
446    }
447
448    async fn list_workers(
449        &self,
450        source_id: Option<String>,
451        status_filter: Option<SourceStatus>,
452    ) -> MetadataResult<Vec<super::SourceInstanceInfo>> {
453        let api = self.model_metadata_api();
454
455        let label_selector = match source_id {
456            Some(sid) => format!("modelexpress.nvidia.com/mx-source-id={}", sid),
457            None => String::new(),
458        };
459
460        let list_params = if label_selector.is_empty() {
461            ListParams::default()
462        } else {
463            ListParams::default().labels(&label_selector)
464        };
465
466        let crs = api.list(&list_params).await?;
467        let mut result = Vec::new();
468        for cr in crs.items {
469            let sid = cr
470                .metadata
471                .labels
472                .as_ref()
473                .and_then(|l| l.get("modelexpress.nvidia.com/mx-source-id"))
474                .cloned()
475                .unwrap_or_default();
476            let iid = cr
477                .metadata
478                .labels
479                .as_ref()
480                .and_then(|l| l.get("modelexpress.nvidia.com/mx-worker-id"))
481                .cloned()
482                .unwrap_or_default();
483
484            let worker_rank = cr
485                .status
486                .as_ref()
487                .and_then(|s| s.worker.as_ref())
488                .map(|w| w.worker_rank as u32)
489                .unwrap_or(0);
490
491            if let Some(required_status) = status_filter {
492                let required_name =
493                    crate::k8s_types::WorkerStatus::status_name_from_proto(required_status as i32);
494                let matches = cr
495                    .status
496                    .as_ref()
497                    .map(|s| s.worker.as_ref().is_some_and(|w| w.status == required_name))
498                    .unwrap_or(false);
499                if !matches {
500                    continue;
501                }
502            }
503
504            let (status, updated_at) = cr
505                .status
506                .as_ref()
507                .and_then(|s| s.worker.as_ref())
508                .map(|w| {
509                    let proto_status =
510                        crate::k8s_types::WorkerStatus::status_proto_from_name(&w.status);
511                    let millis = w
512                        .updated_at
513                        .as_deref()
514                        .and_then(|ts| chrono::DateTime::parse_from_rfc3339(ts).ok())
515                        .map(|dt| dt.timestamp_millis())
516                        .unwrap_or(0);
517                    (proto_status, millis)
518                })
519                .unwrap_or((0, 0));
520
521            result.push(super::SourceInstanceInfo {
522                source_id: sid,
523                worker_id: iid,
524                model_name: cr.spec.model_name,
525                worker_rank,
526                status,
527                updated_at,
528            });
529        }
530
531        Ok(result)
532    }
533
534    async fn remove_metadata(&self, source_id: &str) -> MetadataResult<()> {
535        let api = self.model_metadata_api();
536
537        // Delete all CRs for this source_id via label selector
538        let crs = api
539            .list(&ListParams::default().labels(&format!(
540                "modelexpress.nvidia.com/mx-source-id={}",
541                source_id
542            )))
543            .await?;
544
545        for cr in crs.items {
546            if let Some(name) = cr.metadata.name {
547                match api.delete(&name, &kube::api::DeleteParams::default()).await {
548                    Ok(_) => info!("Deleted ModelMetadata CR '{}'", name),
549                    Err(kube::Error::Api(err)) if err.code == 404 => {
550                        debug!("ModelMetadata CR '{}' not found", name);
551                    }
552                    Err(e) => return Err(e.into()),
553                }
554            }
555        }
556
557        // ConfigMaps are garbage-collected via ownerReferences; also sweep by label
558        let cm_api = self.configmap_api();
559        let cms = cm_api
560            .list(&ListParams::default().labels(&format!(
561                "modelexpress.nvidia.com/mx-source-id={}",
562                source_id
563            )))
564            .await?;
565
566        for cm in cms {
567            if let Some(name) = cm.metadata.name {
568                match cm_api
569                    .delete(&name, &kube::api::DeleteParams::default())
570                    .await
571                {
572                    Ok(_) => debug!("Deleted ConfigMap '{}'", name),
573                    Err(e) => warn!("Failed to delete ConfigMap '{}': {}", name, e),
574                }
575            }
576        }
577
578        Ok(())
579    }
580
581    async fn remove_worker(&self, source_id: &str, worker_id: &str) -> MetadataResult<()> {
582        let api = self.model_metadata_api();
583        let cr_name = format!("mx-source-{}-{}", source_id, worker_id);
584
585        match api
586            .delete(&cr_name, &kube::api::DeleteParams::default())
587            .await
588        {
589            Ok(_) => info!("Deleted ModelMetadata CR '{}'", cr_name),
590            Err(kube::Error::Api(err)) if err.code == 404 => {
591                debug!("ModelMetadata CR '{}' already gone", cr_name);
592            }
593            Err(e) => return Err(e.into()),
594        }
595
596        Ok(())
597    }
598
599    async fn list_sources(&self) -> MetadataResult<Vec<(String, String)>> {
600        let api = self.model_metadata_api();
601        let crs = api.list(&ListParams::default()).await?;
602
603        // De-duplicate by source_id (multiple instances share the same source_id)
604        let mut seen = std::collections::BTreeMap::new();
605        for cr in crs.items {
606            let source_id = cr
607                .metadata
608                .labels
609                .as_ref()
610                .and_then(|l| l.get("modelexpress.nvidia.com/mx-source-id"))
611                .cloned();
612            if let Some(sid) = source_id {
613                seen.entry(sid).or_insert_with(|| cr.spec.model_name);
614            }
615        }
616
617        Ok(seen.into_iter().collect())
618    }
619
620    async fn update_status(
621        &self,
622        source_id: &str,
623        worker_id: &str,
624        worker_rank: u32,
625        status: SourceStatus,
626        updated_at: i64,
627    ) -> MetadataResult<()> {
628        let api = self.model_metadata_api();
629        let cr_name = format!("mx-source-{}-{}", source_id, worker_id);
630        let status_name = WorkerStatus::status_name_from_proto(status as i32);
631        let updated_at_rfc3339 = chrono::DateTime::from_timestamp_millis(updated_at)
632            .map(|dt| dt.to_rfc3339())
633            .unwrap_or_else(|| chrono::Utc::now().to_rfc3339());
634
635        let max_retries: u32 = 5;
636        for attempt in 0..max_retries {
637            let current = api.get(&cr_name).await?;
638            let mut crd_status = current.status.ok_or_else(|| {
639                format!(
640                    "update_status: no status in source '{}' worker '{}'",
641                    source_id, worker_id
642                )
643            })?;
644
645            let mut worker = crd_status.worker.take().ok_or_else(|| {
646                format!(
647                    "update_status: no worker in source '{}' worker '{}'",
648                    source_id, worker_id
649                )
650            })?;
651
652            worker.status = status_name.clone();
653            worker.updated_at = Some(updated_at_rfc3339.clone());
654
655            crd_status.update_ready_condition(status as i32);
656
657            let generation = current.metadata.generation.unwrap_or(0);
658            let resource_version = current.metadata.resource_version.unwrap_or_default();
659            let status_patch = serde_json::json!({
660                "metadata": { "resourceVersion": resource_version },
661                "status": {
662                    "worker": worker,
663                    "conditions": crd_status.conditions,
664                    "observedGeneration": generation
665                }
666            });
667
668            match api
669                .patch_status(
670                    &cr_name,
671                    &PatchParams::default(),
672                    &Patch::Merge(&status_patch),
673                )
674                .await
675            {
676                Ok(_) => {
677                    debug!(
678                        "Updated status for source '{}' worker '{}' rank {} -> {}",
679                        source_id, worker_id, worker_rank, status_name
680                    );
681                    return Ok(());
682                }
683                Err(kube::Error::Api(err)) if err.code == 409 => {
684                    debug!(
685                        "Conflict updating status for source '{}' worker '{}', retrying ({}/{})",
686                        source_id,
687                        worker_id,
688                        attempt.saturating_add(1),
689                        max_retries
690                    );
691                    tokio::time::sleep(std::time::Duration::from_millis(
692                        100_u64.saturating_mul(u64::from(attempt).saturating_add(1)),
693                    ))
694                    .await;
695                }
696                Err(e) => return Err(e.into()),
697            }
698        }
699
700        Err(format!(
701            "Failed to update status for source '{}' worker '{}' rank {} after {} retries",
702            source_id, worker_id, worker_rank, max_retries
703        )
704        .into())
705    }
706}