Skip to main content

modelexpress_server/
metadata_backend.rs

1// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! Metadata backend abstraction for P2P model metadata.
5//!
6//! Supports two persistent backends:
7//! - **Redis**: Persistent storage via Redis keys + atomic Lua merge
8//! - **Kubernetes**: CRDs and ConfigMaps for native K8s integration
9//!
10//! Select the backend via `MX_METADATA_BACKEND=redis` or `MX_METADATA_BACKEND=kubernetes`.
11
12use async_trait::async_trait;
13use modelexpress_common::grpc::p2p::{SourceIdentity, SourceStatus, WorkerMetadata};
14use std::sync::Arc;
15
16pub mod kubernetes;
17pub mod redis;
18
/// Result type for metadata operations.
///
/// The error side is a boxed `std::error::Error + Send + Sync` trait object, so any backend
/// error type can be propagated with `?` and the result can be moved across threads.
pub type MetadataResult<T> = Result<T, Box<dyn std::error::Error + Send + Sync>>;
21
/// Model metadata record returned from backends.
///
/// Produced by [`MetadataBackend::get_metadata`] for one `(source_id, worker_id)` pair.
#[derive(Debug, Clone)]
pub struct ModelMetadataRecord {
    /// 16-char hex key derived from SourceIdentity hash
    pub source_id: String,
    /// Unique identifier for this running worker (UUID)
    pub worker_id: String,
    /// Human-readable model name from SourceIdentity
    pub model_name: String,
    /// Worker entries for this worker, each a [`WorkerRecord`] carrying its own
    /// `worker_rank`, backend metadata and tensor descriptors.
    pub workers: Vec<WorkerRecord>,
    /// Publication timestamp.
    /// NOTE(review): units are not visible here; presumably unix millis like
    /// `WorkerRecord::updated_at` — confirm against the Redis/Kubernetes backends.
    pub published_at: i64,
}
34
/// Lightweight reference to a source worker (no tensor metadata).
/// Used by `list_workers` to support the `ListSources` RPC.
#[derive(Debug, Clone)]
pub struct SourceInstanceInfo {
    /// `mx_source_id` of the source this worker belongs to.
    pub source_id: String,
    /// Identifier of this running worker among the replicas of the source.
    pub worker_id: String,
    /// Human-readable model name of the source.
    pub model_name: String,
    /// Global rank of this worker.
    pub worker_rank: u32,
    /// Worker lifecycle status (maps to `SourceStatus` proto enum).
    /// Stored as the raw `i32` proto value rather than the enum type.
    pub status: i32,
    /// Timestamp of last status update (unix millis).
    pub updated_at: i64,
}
49
/// Backend-specific metadata attached to a worker.
///
/// Mirrors the `backend_metadata` oneof of the gRPC `WorkerMetadata` message, with an explicit
/// [`BackendMetadataRecord::None`] standing in for "not provided".
#[derive(Debug, Clone, PartialEq)]
pub enum BackendMetadataRecord {
    /// Serialized NIXL agent metadata for RDMA connections
    Nixl(Vec<u8>),
    /// Mooncake TransferEngine session ID ("ip:port")
    TransferEngine(String),
    /// No backend metadata provided
    None,
}

impl BackendMetadataRecord {
    /// Rebuild a record from the flat fields stored by the Redis JSON and K8s CRD formats.
    ///
    /// A recognised `backend_type` tag (`"nixl"`, `"transfer_engine"`, `"none"`) decides the
    /// variant outright; for `"transfer_engine"` a missing session ID becomes an empty string.
    /// Any other tag, or no tag at all (records persisted before `backend_type` existed),
    /// falls back to inspecting the fields themselves.
    pub fn from_flat(
        nixl_metadata: Vec<u8>,
        transfer_engine_session_id: Option<String>,
        backend_type: Option<&str>,
    ) -> Self {
        match backend_type {
            Some("nixl") => Self::Nixl(nixl_metadata),
            Some("none") => Self::None,
            Some("transfer_engine") => {
                Self::TransferEngine(transfer_engine_session_id.unwrap_or_default())
            }
            _ => Self::infer_from_fields(nixl_metadata, transfer_engine_session_id),
        }
    }

    /// Legacy inference: a non-empty session ID wins, then non-empty NIXL bytes, else `None`.
    fn infer_from_fields(
        nixl_metadata: Vec<u8>,
        transfer_engine_session_id: Option<String>,
    ) -> Self {
        match transfer_engine_session_id.filter(|session| !session.is_empty()) {
            Some(session) => Self::TransferEngine(session),
            None if nixl_metadata.is_empty() => Self::None,
            None => Self::Nixl(nixl_metadata),
        }
    }

    /// Stable string tag persisted alongside the record.
    ///
    /// [`Self::from_flat`] maps each tag back to the same variant.
    pub fn backend_type_str(&self) -> &'static str {
        match *self {
            Self::None => "none",
            Self::TransferEngine(..) => "transfer_engine",
            Self::Nixl(..) => "nixl",
        }
    }
}
103
/// Worker metadata record.
///
/// Storage-side counterpart of the gRPC `WorkerMetadata` message; this file provides `From`
/// conversions in both directions.
#[derive(Debug, Clone)]
pub struct WorkerRecord {
    /// Global rank of this worker.
    pub worker_rank: u32,
    /// Transport-specific connection metadata (NIXL bytes, TransferEngine session, or none).
    pub backend_metadata: BackendMetadataRecord,
    /// Tensor descriptors published for this worker.
    pub tensors: Vec<TensorRecord>,
    /// Worker lifecycle status (maps to `SourceStatus` proto enum)
    pub status: i32,
    /// Timestamp of last status update (unix millis)
    pub updated_at: i64,
    /// P2P: NIXL listen thread endpoint (host:port)
    pub metadata_endpoint: String,
    /// P2P: NIXL agent name for remote identification
    pub agent_name: String,
    /// P2P: Worker gRPC endpoint for tensor manifest (host:port)
    pub worker_grpc_endpoint: String,
}
121
/// Tensor descriptor record.
///
/// Field-for-field copy of the gRPC `TensorDescriptor` (this file converts in both directions).
#[derive(Debug, Clone)]
pub struct TensorRecord {
    /// Tensor name.
    pub name: String,
    /// NOTE(review): presumably the base address of the tensor's memory region on the
    /// source worker — confirm.
    pub addr: u64,
    /// NOTE(review): presumably the region length in bytes — confirm.
    pub size: u64,
    /// Device index the tensor resides on (presumably a GPU ordinal — confirm).
    pub device_id: u32,
    /// Element type, as a string (format not defined in this file).
    pub dtype: String,
}
131
132// Conversions from gRPC types
133impl From<WorkerMetadata> for WorkerRecord {
134    fn from(meta: WorkerMetadata) -> Self {
135        use modelexpress_common::grpc::p2p::worker_metadata::BackendMetadata;
136        let backend_metadata = match meta.backend_metadata {
137            Some(BackendMetadata::NixlMetadata(data)) => BackendMetadataRecord::Nixl(data),
138            Some(BackendMetadata::TransferEngineSessionId(sid)) => {
139                BackendMetadataRecord::TransferEngine(sid)
140            }
141            None => BackendMetadataRecord::None,
142        };
143        Self {
144            worker_rank: meta.worker_rank,
145            backend_metadata,
146            tensors: meta.tensors.into_iter().map(TensorRecord::from).collect(),
147            status: meta.status,
148            updated_at: meta.updated_at,
149            metadata_endpoint: meta.metadata_endpoint,
150            agent_name: meta.agent_name,
151            worker_grpc_endpoint: meta.worker_grpc_endpoint,
152        }
153    }
154}
155
156impl From<modelexpress_common::grpc::p2p::TensorDescriptor> for TensorRecord {
157    fn from(desc: modelexpress_common::grpc::p2p::TensorDescriptor) -> Self {
158        Self {
159            name: desc.name,
160            addr: desc.addr,
161            size: desc.size,
162            device_id: desc.device_id,
163            dtype: desc.dtype,
164        }
165    }
166}
167
168// Conversions back to gRPC types
169impl From<WorkerRecord> for WorkerMetadata {
170    fn from(record: WorkerRecord) -> Self {
171        use modelexpress_common::grpc::p2p::worker_metadata::BackendMetadata;
172        let backend_metadata = match record.backend_metadata {
173            BackendMetadataRecord::Nixl(data) => Some(BackendMetadata::NixlMetadata(data)),
174            BackendMetadataRecord::TransferEngine(sid) => {
175                Some(BackendMetadata::TransferEngineSessionId(sid))
176            }
177            BackendMetadataRecord::None => None,
178        };
179        Self {
180            worker_rank: record.worker_rank,
181            backend_metadata,
182            tensors: record
183                .tensors
184                .into_iter()
185                .map(modelexpress_common::grpc::p2p::TensorDescriptor::from)
186                .collect(),
187            status: record.status,
188            updated_at: record.updated_at,
189            metadata_endpoint: record.metadata_endpoint,
190            agent_name: record.agent_name,
191            worker_grpc_endpoint: record.worker_grpc_endpoint,
192        }
193    }
194}
195
196impl From<TensorRecord> for modelexpress_common::grpc::p2p::TensorDescriptor {
197    fn from(record: TensorRecord) -> Self {
198        Self {
199            name: record.name,
200            addr: record.addr,
201            size: record.size,
202            device_id: record.device_id,
203            dtype: record.dtype,
204        }
205    }
206}
207
/// Trait for metadata backend implementations.
///
/// Backends are handed out as `Arc<dyn MetadataBackend>` (see [`create_backend`]), hence the
/// `Send + Sync` bound. Every fallible method returns [`MetadataResult`].
#[cfg_attr(test, mockall::automock)]
#[async_trait]
pub trait MetadataBackend: Send + Sync {
    /// Connect to the backend (initialize connections, etc.)
    ///
    /// [`create_backend`] calls this once, right after constructing the backend.
    async fn connect(&self) -> MetadataResult<()>;

    /// Publish metadata for a source worker.
    /// `worker_id` uniquely identifies this running pod/process among all replicas
    /// with the same identity. The backend derives `mx_source_id` from `identity`.
    async fn publish_metadata(
        &self,
        identity: &SourceIdentity,
        worker_id: &str,
        worker: WorkerMetadata,
    ) -> MetadataResult<()>;

    /// Get full tensor metadata for one specific worker.
    /// Returns `None` if the worker is not found.
    async fn get_metadata(
        &self,
        source_id: &str,
        worker_id: &str,
    ) -> MetadataResult<Option<ModelMetadataRecord>>;

    /// List available workers, optionally filtered by source and status.
    ///
    /// - `source_id`: if `Some`, return only workers for that source; if `None`, all sources.
    /// - `status_filter`: if `Some(s)`, return only workers whose status is `s`.
    ///   NOTE(review): earlier wording ("where all workers have status `s`") hinted at
    ///   per-source filtering — confirm which semantics the Redis/Kubernetes backends implement.
    async fn list_workers(
        &self,
        source_id: Option<String>,
        status_filter: Option<SourceStatus>,
    ) -> MetadataResult<Vec<SourceInstanceInfo>>;

    /// Remove all workers of a source by mx_source_id
    async fn remove_metadata(&self, source_id: &str) -> MetadataResult<()>;

    /// Remove a single worker by source_id and worker_id.
    /// Used by the reaper to garbage-collect individual stale entries.
    async fn remove_worker(&self, source_id: &str, worker_id: &str) -> MetadataResult<()>;

    /// List all registered source IDs and their model names, as `(source_id, model_name)`
    /// pairs (tuple order presumed from this wording — confirm against implementations).
    async fn list_sources(&self) -> MetadataResult<Vec<(String, String)>>;

    /// Patch the status of a single worker.
    ///
    /// The worker is addressed by `source_id` + `worker_id`. `worker_rank` selects the rank
    /// entry to patch (NOTE(review): confirm how each backend uses it). `updated_at` is the
    /// new status timestamp, presumably unix millis like `WorkerRecord::updated_at`.
    async fn update_status(
        &self,
        source_id: &str,
        worker_id: &str,
        worker_rank: u32,
        status: SourceStatus,
        updated_at: i64,
    ) -> MetadataResult<()>;
}
262
/// Configuration for metadata backends.
#[derive(Debug, Clone)]
pub enum BackendConfig {
    /// Redis backend — persistent, horizontally scalable
    Redis { url: String },
    /// Kubernetes CRD backend — native K8s integration
    Kubernetes { namespace: String },
}

impl std::fmt::Display for BackendConfig {
    /// Writes the canonical backend name (`redis` / `kubernetes`) without connection details.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Redis { .. } => write!(f, "redis"),
            Self::Kubernetes { .. } => write!(f, "kubernetes"),
        }
    }
}

impl BackendConfig {
    /// Create backend config from environment variables.
    ///
    /// `MX_METADATA_BACKEND` is required. Valid values (case-insensitive, surrounding
    /// whitespace ignored):
    /// - `redis`: Redis, URL from [`BackendConfig::redis_url_from_env`]
    /// - `kubernetes` | `k8s` | `crd`: Kubernetes CRD, namespace from
    ///   `MX_METADATA_NAMESPACE`, else `POD_NAMESPACE`, else `default`
    ///
    /// Variables that are set but empty (or whitespace-only) are treated as unset.
    ///
    /// # Errors
    /// Returns a message quoting the offending value when `MX_METADATA_BACKEND` is missing,
    /// empty, or names no known backend.
    pub fn from_env() -> Result<Self, String> {
        // An unset or non-Unicode value becomes "" and is rejected by `from_type_str`.
        let backend_type = std::env::var("MX_METADATA_BACKEND").unwrap_or_default();
        let redis_url = Self::redis_url_from_env();
        let k8s_namespace = Self::k8s_namespace_from_env();
        Self::from_type_str(&backend_type, &redis_url, &k8s_namespace)
    }

    /// Parse a backend type string into a config. Testable without env vars.
    ///
    /// Matching is case-insensitive and ignores leading/trailing whitespace, so values such
    /// as `"Redis"` or `"k8s\n"` (easy to produce from YAML or shell) are accepted.
    ///
    /// # Errors
    /// Returns a message quoting `backend_type` exactly as given when it is not recognised.
    pub fn from_type_str(
        backend_type: &str,
        redis_url: &str,
        k8s_namespace: &str,
    ) -> Result<Self, String> {
        match backend_type.trim().to_lowercase().as_str() {
            "redis" => Ok(Self::Redis {
                url: redis_url.to_string(),
            }),
            "kubernetes" | "k8s" | "crd" => Ok(Self::Kubernetes {
                namespace: k8s_namespace.to_string(),
            }),
            // Echo the raw input (not the normalized form) so stray casing/whitespace is visible.
            _ => Err(format!(
                "MX_METADATA_BACKEND='{}' is not valid. Use 'redis' or 'kubernetes'.",
                backend_type
            )),
        }
    }

    /// Resolve the Redis connection URL from the environment.
    ///
    /// Precedence: `REDIS_URL` as-is (surrounding whitespace trimmed); otherwise
    /// `redis://{host}:{port}` built from `MX_REDIS_HOST` / `REDIS_HOST` (default `localhost`)
    /// and `MX_REDIS_PORT` / `REDIS_PORT` (default `6379`). Empty values fall through to the
    /// next source instead of yielding an empty URL or something like `redis://:6379`.
    pub fn redis_url_from_env() -> String {
        if let Some(url) = Self::non_empty_env("REDIS_URL") {
            return url;
        }
        let host = Self::non_empty_env("MX_REDIS_HOST")
            .or_else(|| Self::non_empty_env("REDIS_HOST"))
            .unwrap_or_else(|| "localhost".to_string());
        let port = Self::non_empty_env("MX_REDIS_PORT")
            .or_else(|| Self::non_empty_env("REDIS_PORT"))
            .unwrap_or_else(|| "6379".to_string());
        format!("redis://{}:{}", host, port)
    }

    /// Resolve the Kubernetes namespace: `MX_METADATA_NAMESPACE`, then `POD_NAMESPACE`,
    /// then `"default"`. Empty values are skipped rather than used as a namespace.
    fn k8s_namespace_from_env() -> String {
        Self::non_empty_env("MX_METADATA_NAMESPACE")
            .or_else(|| Self::non_empty_env("POD_NAMESPACE"))
            .unwrap_or_else(|| "default".to_string())
    }

    /// Read `key` from the environment with surrounding whitespace trimmed.
    ///
    /// Returns `None` when the variable is unset, not valid Unicode, or empty after trimming.
    /// Manifests often declare `VAR: ""` as a placeholder, which must not override fallbacks.
    fn non_empty_env(key: &str) -> Option<String> {
        std::env::var(key)
            .ok()
            .map(|value| value.trim().to_string())
            .filter(|value| !value.is_empty())
    }
}
333
334/// Create a backend from configuration.
335pub async fn create_backend(config: BackendConfig) -> MetadataResult<Arc<dyn MetadataBackend>> {
336    match config {
337        BackendConfig::Redis { url } => {
338            let backend = redis::RedisBackend::new(&url);
339            backend.connect().await?;
340            Ok(Arc::new(backend) as Arc<dyn MetadataBackend>)
341        }
342        BackendConfig::Kubernetes { namespace } => {
343            let backend = kubernetes::KubernetesBackend::new(&namespace).await?;
344            backend.connect().await?;
345            Ok(Arc::new(backend) as Arc<dyn MetadataBackend>)
346        }
347    }
348}