1use super::{MetadataBackend, MetadataResult, ModelMetadataRecord, TensorRecord, WorkerRecord};
9use crate::k8s_types::{ModelMetadata, ModelMetadataSpec, TensorDescriptorJson, WorkerStatus};
10use async_trait::async_trait;
11use base64::{Engine, engine::general_purpose::STANDARD as BASE64};
12use k8s_openapi::api::core::v1::ConfigMap;
13use kube::{
14 Client,
15 api::{Api, ListParams, Patch, PatchParams, PostParams},
16};
17use modelexpress_common::grpc::p2p::{SourceIdentity, SourceStatus, WorkerMetadata};
18use serde_json::json;
19use std::collections::BTreeMap;
20use tracing::{debug, info, warn};
21
22pub struct KubernetesBackend {
24 client: Client,
25 namespace: String,
26}
27
28impl KubernetesBackend {
29 pub async fn new(namespace: &str) -> MetadataResult<Self> {
31 let client = Client::try_default().await?;
32 Ok(Self {
33 client,
34 namespace: namespace.to_string(),
35 })
36 }
37
38 fn model_metadata_api(&self) -> Api<ModelMetadata> {
40 Api::namespaced(self.client.clone(), &self.namespace)
41 }
42
43 fn configmap_api(&self) -> Api<ConfigMap> {
45 Api::namespaced(self.client.clone(), &self.namespace)
46 }
47
48 async fn upsert_tensor_configmap(
52 &self,
53 source_id: &str,
54 worker_id: &str,
55 worker_rank: u32,
56 tensors: &[TensorRecord],
57 owner_name: Option<&str>,
58 owner_uid: Option<&str>,
59 ) -> MetadataResult<String> {
60 let cr_name = format!("mx-source-{}-{}", source_id, worker_id);
61 let cm_name = format!("{}-tensors-worker-{}", cr_name, worker_rank);
62
63 let tensor_json: Vec<TensorDescriptorJson> = tensors
65 .iter()
66 .map(|t| TensorDescriptorJson {
67 name: t.name.clone(),
68 addr: t.addr.to_string(),
69 size: t.size.to_string(),
70 device_id: t.device_id,
71 dtype: t.dtype.clone(),
72 })
73 .collect();
74
75 let tensors_data = serde_json::to_string_pretty(&tensor_json)?;
76
77 let mut data = BTreeMap::new();
78 data.insert("tensors.json".to_string(), tensors_data);
79
80 let mut labels = BTreeMap::new();
81 labels.insert(
82 "modelexpress.nvidia.com/mx-source-id".to_string(),
83 source_id.to_string(),
84 );
85 labels.insert(
86 "modelexpress.nvidia.com/worker".to_string(),
87 worker_rank.to_string(),
88 );
89
90 let owner_references = match (owner_name, owner_uid) {
91 (Some(name), Some(uid)) => Some(vec![
92 k8s_openapi::apimachinery::pkg::apis::meta::v1::OwnerReference {
93 api_version: "modelexpress.nvidia.com/v1alpha1".to_string(),
94 kind: "ModelMetadata".to_string(),
95 name: name.to_string(),
96 uid: uid.to_string(),
97 controller: Some(true),
98 block_owner_deletion: Some(true),
99 },
100 ]),
101 _ => None,
102 };
103
104 let cm = ConfigMap {
105 metadata: kube::api::ObjectMeta {
106 name: Some(cm_name.clone()),
107 namespace: Some(self.namespace.clone()),
108 labels: Some(labels),
109 owner_references,
110 ..Default::default()
111 },
112 data: Some(data),
113 ..Default::default()
114 };
115
116 let api = self.configmap_api();
117
118 match api.create(&PostParams::default(), &cm).await {
120 Ok(_) => {
121 debug!("Created ConfigMap {} for worker {}", cm_name, worker_rank);
122 }
123 Err(kube::Error::Api(err)) if err.code == 409 => {
124 api.patch(&cm_name, &PatchParams::default(), &Patch::Merge(&cm))
126 .await?;
127 debug!("Updated ConfigMap {} for worker {}", cm_name, worker_rank);
128 }
129 Err(e) => return Err(e.into()),
130 }
131
132 Ok(cm_name)
133 }
134
135 async fn read_tensor_configmap(&self, cm_name: &str) -> MetadataResult<Vec<TensorRecord>> {
137 let api = self.configmap_api();
138 let cm = api.get(cm_name).await?;
139
140 let tensors_json = cm
141 .data
142 .and_then(|d| d.get("tensors.json").cloned())
143 .ok_or("ConfigMap missing tensors.json")?;
144
145 let tensor_descs: Vec<TensorDescriptorJson> = serde_json::from_str(&tensors_json)?;
146
147 let tensors = tensor_descs
148 .into_iter()
149 .map(|t| {
150 let addr = t.addr.parse::<u64>().map_err(|e| {
151 format!("Invalid tensor addr '{}' for '{}': {}", t.addr, t.name, e)
152 })?;
153 let size = t.size.parse::<u64>().map_err(|e| {
154 format!("Invalid tensor size '{}' for '{}': {}", t.size, t.name, e)
155 })?;
156 Ok(TensorRecord {
157 name: t.name,
158 addr,
159 size,
160 device_id: t.device_id,
161 dtype: t.dtype,
162 })
163 })
164 .collect::<MetadataResult<Vec<_>>>()?;
165
166 Ok(tensors)
167 }
168}
169
170#[async_trait]
171impl MetadataBackend for KubernetesBackend {
172 async fn connect(&self) -> MetadataResult<()> {
173 let api = self.model_metadata_api();
175 let _ = api.list(&ListParams::default().limit(1)).await?;
176 info!(
177 "Connected to Kubernetes, using namespace '{}'",
178 self.namespace
179 );
180 Ok(())
181 }
182
183 async fn publish_metadata(
184 &self,
185 identity: &SourceIdentity,
186 worker_id: &str,
187 worker: WorkerMetadata,
188 ) -> MetadataResult<()> {
189 let source_id = crate::source_identity::compute_mx_source_id(identity);
190 let source_id = source_id.as_str();
191 let model_name = &identity.model_name;
192 let api = self.model_metadata_api();
193 let cr_name = format!("mx-source-{}-{}", source_id, worker_id);
194 let now = chrono::Utc::now().to_rfc3339();
195
196 let worker_record = WorkerRecord::from(worker);
197
198 let existing = api.get_opt(&cr_name).await?;
200
201 if existing.is_none() {
202 let new_cr = ModelMetadata {
203 metadata: kube::api::ObjectMeta {
204 name: Some(cr_name.clone()),
205 namespace: Some(self.namespace.clone()),
206 labels: Some({
207 let mut labels = BTreeMap::new();
208 labels.insert(
209 "modelexpress.nvidia.com/mx-source-id".to_string(),
210 source_id.to_string(),
211 );
212 labels.insert(
213 "modelexpress.nvidia.com/mx-worker-id".to_string(),
214 worker_id.to_string(),
215 );
216 labels
217 }),
218 ..Default::default()
219 },
220 spec: ModelMetadataSpec {
221 model_name: model_name.to_string(),
222 },
223 status: None,
224 };
225
226 match api.create(&PostParams::default(), &new_cr).await {
227 Ok(_) => {
228 info!("Created ModelMetadata CR '{}'", cr_name);
229 }
230 Err(kube::Error::Api(err)) if err.code == 409 => {
231 debug!(
232 "ModelMetadata CR '{}' already exists, proceeding to update",
233 cr_name
234 );
235 }
236 Err(e) => return Err(e.into()),
237 }
238 }
239
240 let cr = api.get(&cr_name).await?;
242 let owner_uid = cr.metadata.uid.as_deref();
243 let owner_name = cr.metadata.name.as_deref();
244
245 let cm_name = self
246 .upsert_tensor_configmap(
247 source_id,
248 worker_id,
249 worker_record.worker_rank,
250 &worker_record.tensors,
251 owner_name,
252 owner_uid,
253 )
254 .await?;
255
256 let backend_type = worker_record
257 .backend_metadata
258 .backend_type_str()
259 .to_string();
260 let (nixl_metadata, transfer_engine_session_id) = match &worker_record.backend_metadata {
261 super::BackendMetadataRecord::Nixl(data) => (BASE64.encode(data), None),
262 super::BackendMetadataRecord::TransferEngine(sid) => (String::new(), Some(sid.clone())),
263 super::BackendMetadataRecord::None => (String::new(), None),
264 };
265
266 let worker_status = WorkerStatus {
267 worker_rank: worker_record.worker_rank as i32,
268 backend_type: Some(backend_type),
269 nixl_metadata,
270 transfer_engine_session_id,
271 tensor_count: worker_record.tensors.len() as i32,
272 tensor_config_map: Some(cm_name),
273 status: WorkerStatus::status_name_from_proto(worker_record.status),
274 updated_at: Some(now.clone()),
275 metadata_endpoint: worker_record.metadata_endpoint.clone(),
276 agent_name: worker_record.agent_name.clone(),
277 worker_grpc_endpoint: worker_record.worker_grpc_endpoint.clone(),
278 };
279
280 let max_retries: u32 = 5;
281 let mut status_updated = false;
282 for attempt in 0..max_retries {
283 let current = api.get(&cr_name).await?;
284 let resource_version = current.metadata.resource_version.unwrap_or_default();
285 let generation = current.metadata.generation.unwrap_or(0);
286
287 let mut crd_status = current.status.unwrap_or_default();
288 crd_status.update_ready_condition(worker_record.status);
289
290 let status_patch = json!({
291 "metadata": { "resourceVersion": resource_version },
292 "status": {
293 "worker": worker_status,
294 "publishedAt": now,
295 "conditions": crd_status.conditions,
296 "observedGeneration": generation
297 }
298 });
299
300 match api
301 .patch_status(
302 &cr_name,
303 &PatchParams::default(),
304 &Patch::Merge(&status_patch),
305 )
306 .await
307 {
308 Ok(_) => {
309 status_updated = true;
310 break;
311 }
312 Err(kube::Error::Api(err)) if err.code == 409 => {
313 debug!(
314 "Conflict updating status for source '{}' instance '{}', retrying ({}/{})",
315 source_id,
316 worker_id,
317 attempt.saturating_add(1),
318 max_retries
319 );
320 tokio::time::sleep(std::time::Duration::from_millis(
321 100_u64.saturating_mul(u64::from(attempt).saturating_add(1)),
322 ))
323 .await;
324 }
325 Err(e) => return Err(e.into()),
326 }
327 }
328
329 if !status_updated {
330 return Err(format!(
331 "Failed to update status for source '{}' instance '{}' after {} retries",
332 source_id, worker_id, max_retries
333 )
334 .into());
335 }
336
337 info!(
338 "Published metadata for '{}' (source_id={}, worker_id={}): rank {} ({} tensors)",
339 model_name,
340 source_id,
341 worker_id,
342 worker_record.worker_rank,
343 worker_record.tensors.len(),
344 );
345
346 Ok(())
347 }
348
349 async fn get_metadata(
350 &self,
351 source_id: &str,
352 worker_id: &str,
353 ) -> MetadataResult<Option<ModelMetadataRecord>> {
354 let api = self.model_metadata_api();
355 let cr_name = format!("mx-source-{}-{}", source_id, worker_id);
356
357 let cr = match api.get_opt(&cr_name).await? {
358 Some(cr) => cr,
359 None => {
360 debug!(
361 "No ModelMetadata CR found for source_id={} worker_id={}",
362 source_id, worker_id
363 );
364 return Ok(None);
365 }
366 };
367
368 let status = match cr.status {
369 Some(s) => s,
370 None => {
371 debug!("ModelMetadata CR '{}' has no status", cr_name);
372 return Ok(None);
373 }
374 };
375
376 let mut workers = Vec::new();
377 if let Some(worker_status) = status.worker {
378 let nixl_bytes = if !worker_status.nixl_metadata.is_empty() {
379 BASE64.decode(&worker_status.nixl_metadata).map_err(|e| {
380 format!(
381 "Failed to decode NIXL metadata for worker {}: {}",
382 worker_status.worker_rank, e
383 )
384 })?
385 } else {
386 Vec::new()
387 };
388 let backend_metadata = super::BackendMetadataRecord::from_flat(
389 nixl_bytes,
390 worker_status.transfer_engine_session_id.clone(),
391 worker_status.backend_type.as_deref(),
392 );
393
394 let tensors = if let Some(cm_name) = &worker_status.tensor_config_map {
395 match self.read_tensor_configmap(cm_name).await {
396 Ok(t) => t,
397 Err(e) => {
398 warn!("Failed to read tensor ConfigMap '{}': {}", cm_name, e);
399 Vec::new()
400 }
401 }
402 } else {
403 Vec::new()
404 };
405
406 let status = WorkerStatus::status_proto_from_name(&worker_status.status);
407 let updated_at = worker_status
408 .updated_at
409 .as_deref()
410 .and_then(|s| chrono::DateTime::parse_from_rfc3339(s).ok())
411 .map(|dt| dt.timestamp_millis())
412 .unwrap_or(0);
413
414 workers.push(WorkerRecord {
415 worker_rank: worker_status.worker_rank as u32,
416 backend_metadata,
417 tensors,
418 status,
419 updated_at,
420 metadata_endpoint: worker_status.metadata_endpoint.clone(),
421 agent_name: worker_status.agent_name.clone(),
422 worker_grpc_endpoint: worker_status.worker_grpc_endpoint.clone(),
423 });
424 }
425
426 let published_at = status
427 .published_at
428 .and_then(|s| chrono::DateTime::parse_from_rfc3339(&s).ok())
429 .map(|dt| dt.timestamp())
430 .unwrap_or(0);
431
432 debug!(
433 "Retrieved metadata for source_id={} worker_id={}: {} workers",
434 source_id,
435 worker_id,
436 workers.len()
437 );
438
439 Ok(Some(ModelMetadataRecord {
440 source_id: source_id.to_string(),
441 worker_id: worker_id.to_string(),
442 model_name: cr.spec.model_name.clone(),
443 workers,
444 published_at,
445 }))
446 }
447
448 async fn list_workers(
449 &self,
450 source_id: Option<String>,
451 status_filter: Option<SourceStatus>,
452 ) -> MetadataResult<Vec<super::SourceInstanceInfo>> {
453 let api = self.model_metadata_api();
454
455 let label_selector = match source_id {
456 Some(sid) => format!("modelexpress.nvidia.com/mx-source-id={}", sid),
457 None => String::new(),
458 };
459
460 let list_params = if label_selector.is_empty() {
461 ListParams::default()
462 } else {
463 ListParams::default().labels(&label_selector)
464 };
465
466 let crs = api.list(&list_params).await?;
467 let mut result = Vec::new();
468 for cr in crs.items {
469 let sid = cr
470 .metadata
471 .labels
472 .as_ref()
473 .and_then(|l| l.get("modelexpress.nvidia.com/mx-source-id"))
474 .cloned()
475 .unwrap_or_default();
476 let iid = cr
477 .metadata
478 .labels
479 .as_ref()
480 .and_then(|l| l.get("modelexpress.nvidia.com/mx-worker-id"))
481 .cloned()
482 .unwrap_or_default();
483
484 let worker_rank = cr
485 .status
486 .as_ref()
487 .and_then(|s| s.worker.as_ref())
488 .map(|w| w.worker_rank as u32)
489 .unwrap_or(0);
490
491 if let Some(required_status) = status_filter {
492 let required_name =
493 crate::k8s_types::WorkerStatus::status_name_from_proto(required_status as i32);
494 let matches = cr
495 .status
496 .as_ref()
497 .map(|s| s.worker.as_ref().is_some_and(|w| w.status == required_name))
498 .unwrap_or(false);
499 if !matches {
500 continue;
501 }
502 }
503
504 let (status, updated_at) = cr
505 .status
506 .as_ref()
507 .and_then(|s| s.worker.as_ref())
508 .map(|w| {
509 let proto_status =
510 crate::k8s_types::WorkerStatus::status_proto_from_name(&w.status);
511 let millis = w
512 .updated_at
513 .as_deref()
514 .and_then(|ts| chrono::DateTime::parse_from_rfc3339(ts).ok())
515 .map(|dt| dt.timestamp_millis())
516 .unwrap_or(0);
517 (proto_status, millis)
518 })
519 .unwrap_or((0, 0));
520
521 result.push(super::SourceInstanceInfo {
522 source_id: sid,
523 worker_id: iid,
524 model_name: cr.spec.model_name,
525 worker_rank,
526 status,
527 updated_at,
528 });
529 }
530
531 Ok(result)
532 }
533
534 async fn remove_metadata(&self, source_id: &str) -> MetadataResult<()> {
535 let api = self.model_metadata_api();
536
537 let crs = api
539 .list(&ListParams::default().labels(&format!(
540 "modelexpress.nvidia.com/mx-source-id={}",
541 source_id
542 )))
543 .await?;
544
545 for cr in crs.items {
546 if let Some(name) = cr.metadata.name {
547 match api.delete(&name, &kube::api::DeleteParams::default()).await {
548 Ok(_) => info!("Deleted ModelMetadata CR '{}'", name),
549 Err(kube::Error::Api(err)) if err.code == 404 => {
550 debug!("ModelMetadata CR '{}' not found", name);
551 }
552 Err(e) => return Err(e.into()),
553 }
554 }
555 }
556
557 let cm_api = self.configmap_api();
559 let cms = cm_api
560 .list(&ListParams::default().labels(&format!(
561 "modelexpress.nvidia.com/mx-source-id={}",
562 source_id
563 )))
564 .await?;
565
566 for cm in cms {
567 if let Some(name) = cm.metadata.name {
568 match cm_api
569 .delete(&name, &kube::api::DeleteParams::default())
570 .await
571 {
572 Ok(_) => debug!("Deleted ConfigMap '{}'", name),
573 Err(e) => warn!("Failed to delete ConfigMap '{}': {}", name, e),
574 }
575 }
576 }
577
578 Ok(())
579 }
580
581 async fn remove_worker(&self, source_id: &str, worker_id: &str) -> MetadataResult<()> {
582 let api = self.model_metadata_api();
583 let cr_name = format!("mx-source-{}-{}", source_id, worker_id);
584
585 match api
586 .delete(&cr_name, &kube::api::DeleteParams::default())
587 .await
588 {
589 Ok(_) => info!("Deleted ModelMetadata CR '{}'", cr_name),
590 Err(kube::Error::Api(err)) if err.code == 404 => {
591 debug!("ModelMetadata CR '{}' already gone", cr_name);
592 }
593 Err(e) => return Err(e.into()),
594 }
595
596 Ok(())
597 }
598
599 async fn list_sources(&self) -> MetadataResult<Vec<(String, String)>> {
600 let api = self.model_metadata_api();
601 let crs = api.list(&ListParams::default()).await?;
602
603 let mut seen = std::collections::BTreeMap::new();
605 for cr in crs.items {
606 let source_id = cr
607 .metadata
608 .labels
609 .as_ref()
610 .and_then(|l| l.get("modelexpress.nvidia.com/mx-source-id"))
611 .cloned();
612 if let Some(sid) = source_id {
613 seen.entry(sid).or_insert_with(|| cr.spec.model_name);
614 }
615 }
616
617 Ok(seen.into_iter().collect())
618 }
619
620 async fn update_status(
621 &self,
622 source_id: &str,
623 worker_id: &str,
624 worker_rank: u32,
625 status: SourceStatus,
626 updated_at: i64,
627 ) -> MetadataResult<()> {
628 let api = self.model_metadata_api();
629 let cr_name = format!("mx-source-{}-{}", source_id, worker_id);
630 let status_name = WorkerStatus::status_name_from_proto(status as i32);
631 let updated_at_rfc3339 = chrono::DateTime::from_timestamp_millis(updated_at)
632 .map(|dt| dt.to_rfc3339())
633 .unwrap_or_else(|| chrono::Utc::now().to_rfc3339());
634
635 let max_retries: u32 = 5;
636 for attempt in 0..max_retries {
637 let current = api.get(&cr_name).await?;
638 let mut crd_status = current.status.ok_or_else(|| {
639 format!(
640 "update_status: no status in source '{}' worker '{}'",
641 source_id, worker_id
642 )
643 })?;
644
645 let mut worker = crd_status.worker.take().ok_or_else(|| {
646 format!(
647 "update_status: no worker in source '{}' worker '{}'",
648 source_id, worker_id
649 )
650 })?;
651
652 worker.status = status_name.clone();
653 worker.updated_at = Some(updated_at_rfc3339.clone());
654
655 crd_status.update_ready_condition(status as i32);
656
657 let generation = current.metadata.generation.unwrap_or(0);
658 let resource_version = current.metadata.resource_version.unwrap_or_default();
659 let status_patch = serde_json::json!({
660 "metadata": { "resourceVersion": resource_version },
661 "status": {
662 "worker": worker,
663 "conditions": crd_status.conditions,
664 "observedGeneration": generation
665 }
666 });
667
668 match api
669 .patch_status(
670 &cr_name,
671 &PatchParams::default(),
672 &Patch::Merge(&status_patch),
673 )
674 .await
675 {
676 Ok(_) => {
677 debug!(
678 "Updated status for source '{}' worker '{}' rank {} -> {}",
679 source_id, worker_id, worker_rank, status_name
680 );
681 return Ok(());
682 }
683 Err(kube::Error::Api(err)) if err.code == 409 => {
684 debug!(
685 "Conflict updating status for source '{}' worker '{}', retrying ({}/{})",
686 source_id,
687 worker_id,
688 attempt.saturating_add(1),
689 max_retries
690 );
691 tokio::time::sleep(std::time::Duration::from_millis(
692 100_u64.saturating_mul(u64::from(attempt).saturating_add(1)),
693 ))
694 .await;
695 }
696 Err(e) => return Err(e.into()),
697 }
698 }
699
700 Err(format!(
701 "Failed to update status for source '{}' worker '{}' rank {} after {} retries",
702 source_id, worker_id, worker_rank, max_retries
703 )
704 .into())
705 }
706}