use async_trait::async_trait;
use modelexpress_common::grpc::p2p::{SourceIdentity, SourceStatus, WorkerMetadata};
use std::sync::Arc;
pub mod kubernetes;
pub mod redis;
/// Result alias used throughout the metadata backend API.
///
/// The error side is a boxed, thread-safe (`Send + Sync`) dynamic error, so
/// any error type implementing `std::error::Error + Send + Sync` can be
/// propagated with `?`.
pub type MetadataResult<T> = Result<T, Box<dyn std::error::Error + Send + Sync>>;
/// Metadata record for a model source, as handed back by a metadata backend.
///
/// This is the payload type of `MetadataBackend::get_metadata`, which is
/// looked up by `source_id` and `worker_id`.
#[derive(Debug, Clone)]
pub struct ModelMetadataRecord {
/// Identifier of the source this record belongs to.
pub source_id: String,
/// Identifier of the worker this record belongs to.
pub worker_id: String,
/// Name of the model.
pub model_name: String,
/// Per-worker records attached to this entry.
pub workers: Vec<WorkerRecord>,
/// Publication timestamp.
/// NOTE(review): epoch and unit (seconds vs. milliseconds) are not visible
/// in this file - confirm against the backend implementations.
pub published_at: i64,
}
/// Summary of a single source worker instance.
///
/// This is the element type returned by `MetadataBackend::list_workers`.
#[derive(Debug, Clone)]
pub struct SourceInstanceInfo {
/// Identifier of the source.
pub source_id: String,
/// Identifier of the worker.
pub worker_id: String,
/// Name of the model.
pub model_name: String,
/// Rank of the worker.
pub worker_rank: u32,
/// Raw status value.
/// NOTE(review): presumably the numeric form of `SourceStatus` - confirm.
pub status: i32,
/// Timestamp of the last update.
/// NOTE(review): epoch and unit are not visible in this file - confirm.
pub updated_at: i64,
}
/// Backend-specific transfer metadata attached to a worker record.
///
/// Exactly one kind of payload (or none at all) is carried, which keeps the
/// "flat" storage representation (separate NIXL bytes / session-id columns
/// plus a type tag) from leaking into the rest of the code.
///
/// `Eq` is derived alongside `PartialEq` because every payload type
/// (`Vec<u8>`, `String`) has total equality.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BackendMetadataRecord {
    /// NIXL metadata, stored as raw bytes.
    Nixl(Vec<u8>),
    /// Transfer-engine session identifier.
    TransferEngine(String),
    /// No backend metadata is present.
    None,
}

impl BackendMetadataRecord {
    /// Rebuilds a record from its flat representation.
    ///
    /// # Parameters
    /// - `nixl_metadata`: NIXL metadata bytes (may be empty).
    /// - `transfer_engine_session_id`: optional transfer-engine session id.
    /// - `backend_type`: optional explicit type tag, one of the strings
    ///   produced by [`Self::backend_type_str`].
    ///
    /// # Resolution order
    /// 1. A recognised `backend_type` is authoritative:
    ///    - `"transfer_engine"` yields `TransferEngine`; a missing session id
    ///      becomes the empty string.
    ///    - `"nixl"` yields `Nixl(nixl_metadata)`, even when the bytes are empty.
    ///    - `"none"` yields `None`.
    /// 2. Otherwise (tag absent or unrecognised) the variant is inferred from
    ///    the data: a non-empty session id wins, then non-empty NIXL bytes,
    ///    and finally `None`.
    pub fn from_flat(
        nixl_metadata: Vec<u8>,
        transfer_engine_session_id: Option<String>,
        backend_type: Option<&str>,
    ) -> Self {
        match backend_type {
            Some("transfer_engine") => {
                Self::TransferEngine(transfer_engine_session_id.unwrap_or_default())
            }
            Some("nixl") => Self::Nixl(nixl_metadata),
            Some("none") => Self::None,
            // No usable tag: infer the variant from whichever payload is populated.
            // Written with match guards (rather than an `if let ... &&` chain) so the
            // code does not depend on edition-2024 let-chains.
            _ => match transfer_engine_session_id {
                Some(sid) if !sid.is_empty() => Self::TransferEngine(sid),
                _ if !nixl_metadata.is_empty() => Self::Nixl(nixl_metadata),
                _ => Self::None,
            },
        }
    }

    /// Returns the type tag for this variant: `"nixl"`, `"transfer_engine"`
    /// or `"none"`. These are the same strings [`Self::from_flat`] recognises.
    pub fn backend_type_str(&self) -> &'static str {
        match self {
            Self::Nixl(_) => "nixl",
            Self::TransferEngine(_) => "transfer_engine",
            Self::None => "none",
        }
    }
}
/// Storage-side representation of one worker's metadata.
///
/// Converts losslessly to and from the gRPC `WorkerMetadata` message via the
/// `From` impls in this module; every field here has a same-named field there.
#[derive(Debug, Clone)]
pub struct WorkerRecord {
/// Rank of the worker.
pub worker_rank: u32,
/// Backend-specific transfer metadata (NIXL bytes, transfer-engine
/// session id, or nothing).
pub backend_metadata: BackendMetadataRecord,
/// Tensors described by this worker.
pub tensors: Vec<TensorRecord>,
/// Raw status value, copied verbatim from `WorkerMetadata::status`.
/// NOTE(review): presumably the numeric form of `SourceStatus` - confirm.
pub status: i32,
/// Timestamp of the last update.
/// NOTE(review): epoch and unit are not visible in this file - confirm.
pub updated_at: i64,
/// Metadata endpoint of the worker.
/// NOTE(review): address format is not visible in this file.
pub metadata_endpoint: String,
/// Agent name associated with the worker.
pub agent_name: String,
/// gRPC endpoint of the worker.
/// NOTE(review): address format is not visible in this file.
pub worker_grpc_endpoint: String,
}
/// Storage-side description of a single tensor.
///
/// Mirrors the gRPC `TensorDescriptor` message field-for-field; see the
/// `From` impls in this module.
#[derive(Debug, Clone)]
pub struct TensorRecord {
/// Tensor name.
pub name: String,
/// Address of the tensor data.
/// NOTE(review): presumably a device/host memory address - confirm.
pub addr: u64,
/// Size of the tensor data.
/// NOTE(review): presumably in bytes - confirm.
pub size: u64,
/// Identifier of the device holding the tensor.
pub device_id: u32,
/// Data type of the tensor, as a string.
pub dtype: String,
}
impl From<WorkerMetadata> for WorkerRecord {
fn from(meta: WorkerMetadata) -> Self {
use modelexpress_common::grpc::p2p::worker_metadata::BackendMetadata;
let backend_metadata = match meta.backend_metadata {
Some(BackendMetadata::NixlMetadata(data)) => BackendMetadataRecord::Nixl(data),
Some(BackendMetadata::TransferEngineSessionId(sid)) => {
BackendMetadataRecord::TransferEngine(sid)
}
None => BackendMetadataRecord::None,
};
Self {
worker_rank: meta.worker_rank,
backend_metadata,
tensors: meta.tensors.into_iter().map(TensorRecord::from).collect(),
status: meta.status,
updated_at: meta.updated_at,
metadata_endpoint: meta.metadata_endpoint,
agent_name: meta.agent_name,
worker_grpc_endpoint: meta.worker_grpc_endpoint,
}
}
}
/// Converts a gRPC `TensorDescriptor` into a storage-side [`TensorRecord`],
/// moving every field across unchanged.
impl From<modelexpress_common::grpc::p2p::TensorDescriptor> for TensorRecord {
    fn from(desc: modelexpress_common::grpc::p2p::TensorDescriptor) -> Self {
        let modelexpress_common::grpc::p2p::TensorDescriptor {
            name,
            addr,
            size,
            device_id,
            dtype,
        } = desc;

        Self {
            name,
            addr,
            size,
            device_id,
            dtype,
        }
    }
}
impl From<WorkerRecord> for WorkerMetadata {
fn from(record: WorkerRecord) -> Self {
use modelexpress_common::grpc::p2p::worker_metadata::BackendMetadata;
let backend_metadata = match record.backend_metadata {
BackendMetadataRecord::Nixl(data) => Some(BackendMetadata::NixlMetadata(data)),
BackendMetadataRecord::TransferEngine(sid) => {
Some(BackendMetadata::TransferEngineSessionId(sid))
}
BackendMetadataRecord::None => None,
};
Self {
worker_rank: record.worker_rank,
backend_metadata,
tensors: record
.tensors
.into_iter()
.map(modelexpress_common::grpc::p2p::TensorDescriptor::from)
.collect(),
status: record.status,
updated_at: record.updated_at,
metadata_endpoint: record.metadata_endpoint,
agent_name: record.agent_name,
worker_grpc_endpoint: record.worker_grpc_endpoint,
}
}
}
impl From<TensorRecord> for modelexpress_common::grpc::p2p::TensorDescriptor {
fn from(record: TensorRecord) -> Self {
Self {
name: record.name,
addr: record.addr,
size: record.size,
device_id: record.device_id,
dtype: record.dtype,
}
}
}
/// Storage abstraction for source/worker metadata.
///
/// Implementations are handed out as `Arc<dyn MetadataBackend>` by
/// `create_backend`, hence the `Send + Sync` supertraits. Under `cfg(test)`
/// a mock implementation is generated by `mockall::automock`.
///
/// All methods are async (via `async_trait`) and report failures through
/// `MetadataResult`.
#[cfg_attr(test, mockall::automock)]
#[async_trait]
pub trait MetadataBackend: Send + Sync {
/// Connects the backend.
///
/// `create_backend` calls this once, right after constructing a backend
/// and before returning it to the caller.
async fn connect(&self) -> MetadataResult<()>;
/// Publishes `worker` metadata for the source described by `identity`,
/// under the given `worker_id`.
/// NOTE(review): overwrite/merge behaviour for an existing entry is
/// implementation-defined - confirm per backend.
async fn publish_metadata(
&self,
identity: &SourceIdentity,
worker_id: &str,
worker: WorkerMetadata,
) -> MetadataResult<()>;
/// Looks up the metadata record for `source_id` / `worker_id`.
///
/// The `Option` in the return type allows a backend to report "no such
/// record" without raising an error.
async fn get_metadata(
&self,
source_id: &str,
worker_id: &str,
) -> MetadataResult<Option<ModelMetadataRecord>>;
/// Lists worker instances, optionally restricted by `source_id` and/or
/// `status_filter`.
/// NOTE(review): presumably `None` means "no restriction" for either
/// argument - confirm per backend.
async fn list_workers(
&self,
source_id: Option<String>,
status_filter: Option<SourceStatus>,
) -> MetadataResult<Vec<SourceInstanceInfo>>;
/// Removes the metadata stored for `source_id`.
/// NOTE(review): presumably this covers every worker of the source -
/// confirm per backend.
async fn remove_metadata(&self, source_id: &str) -> MetadataResult<()>;
/// Removes the entry for a single worker (`worker_id`) of `source_id`.
async fn remove_worker(&self, source_id: &str, worker_id: &str) -> MetadataResult<()>;
/// Lists known sources as pairs of strings.
/// NOTE(review): the meaning of the two tuple elements is not visible in
/// this file (possibly `(source_id, model_name)`) - confirm.
async fn list_sources(&self) -> MetadataResult<Vec<(String, String)>>;
/// Updates the status of the worker identified by `source_id`,
/// `worker_id` and `worker_rank`, recording `updated_at` as the time of
/// the change.
/// NOTE(review): epoch and unit of `updated_at` are not visible in this
/// file - confirm.
async fn update_status(
&self,
source_id: &str,
worker_id: &str,
worker_rank: u32,
status: SourceStatus,
updated_at: i64,
) -> MetadataResult<()>;
}
pub use crate::backend_config::BackendConfig;
/// Builds the metadata backend selected by `config`, connects it, and returns
/// it as a shared trait object.
///
/// - `BackendConfig::Redis { url }` constructs a `RedisBackend` for `url`.
/// - `BackendConfig::Kubernetes { namespace }` constructs a
///   `KubernetesBackend` for `namespace` (construction itself is async and
///   fallible).
///
/// In both cases `connect()` is awaited before the backend is returned.
///
/// # Errors
/// Propagates any error from backend construction or from `connect()`.
pub async fn create_backend(config: BackendConfig) -> MetadataResult<Arc<dyn MetadataBackend>> {
    let backend: Arc<dyn MetadataBackend> = match config {
        BackendConfig::Redis { url } => {
            let redis_backend = redis::RedisBackend::new(&url);
            redis_backend.connect().await?;
            Arc::new(redis_backend)
        }
        BackendConfig::Kubernetes { namespace } => {
            let k8s_backend = kubernetes::KubernetesBackend::new(&namespace).await?;
            k8s_backend.connect().await?;
            Arc::new(k8s_backend)
        }
    };

    Ok(backend)
}