use std::sync::Arc;
use anyhow::Context as _;
use serde::{Deserialize, Serialize};
use tokio::sync::mpsc::Receiver;
use dynamo_runtime::{
component::{self, ComponentEndpointInfo},
pipeline::{
network::egress::push_router::PushRouter, ManyOut, Operator, RouterMode, SegmentSource,
ServiceBackend, SingleIn, Source,
},
protocols::{self, annotated::Annotated},
slug::Slug,
transports::etcd::{self, KeyValue, WatchEvent},
DistributedRuntime,
};
use super::ModelManager;
use crate::protocols::openai::chat_completions::{
NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
};
use crate::protocols::openai::completions::{CompletionRequest, CompletionResponse};
use crate::{
backend::Backend,
model_type::ModelType,
preprocessor::{BackendInput, OpenAIPreprocessor},
protocols::common::llm_backend::LLMEngineOutput,
};
use crate::{
key_value_store::{EtcdStorage, KeyValueStore, KeyValueStoreManager},
model_card::{self, ModelDeploymentCard},
};
use tracing;
/// A model registration record, exchanged through etcd as JSON.
///
/// Deserialized from watch-event values by `model_watcher`, and read back from etcd by
/// `ModelNetworkName::load_entry`.
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
pub struct ModelEntry {
    /// Public model name. Engines are registered with the model manager under this name, and
    /// it is passed to `ModelDeploymentCard::service_name_slug` to build the card's etcd key.
    pub name: String,
    /// Worker endpoint serving the model; its `namespace`, `component` and `name` are used to
    /// build the endpoint client.
    pub endpoint: protocols::Endpoint,
    /// Kind of engine the worker exposes; selects how `handle_put` wires the model up.
    pub model_type: ModelType,
}
impl ModelEntry {
    /// Whether requests for this model must go through the local preprocessing pipeline.
    ///
    /// Only [`ModelType::Backend`] entries need it; chat and completion workers receive
    /// OpenAI-style requests directly.
    pub fn requires_preprocessing(&self) -> bool {
        matches!(self.model_type, ModelType::Backend)
    }

    /// Fetch this model's [`ModelDeploymentCard`] from etcd.
    ///
    /// The card is read from bucket `model_card::BUCKET_NAME` under the key produced by
    /// `ModelDeploymentCard::service_name_slug(&self.name)`. The lookup goes through an
    /// etcd-backed key-value store built from `etcd_client` and `endpoint_id`.
    ///
    /// # Errors
    ///
    /// Returns an error if nothing is stored under the key, or if the fetch itself fails.
    pub async fn load_mdc(
        &self,
        endpoint_id: protocols::Endpoint,
        etcd_client: etcd::Client,
    ) -> anyhow::Result<ModelDeploymentCard> {
        let card_key = ModelDeploymentCard::service_name_slug(&self.name);
        let storage: Box<dyn KeyValueStore> =
            Box::new(EtcdStorage::new(etcd_client.clone(), endpoint_id));
        let card_store = Arc::new(KeyValueStoreManager::new(storage));

        let fetched = card_store
            .load::<ModelDeploymentCard>(model_card::BUCKET_NAME, &card_key)
            .await;
        match fetched {
            Ok(Some(mdc)) => Ok(mdc),
            Ok(None) => Err(anyhow::anyhow!(
                "Missing ModelDeploymentCard in etcd under key {card_key}"
            )),
            Err(err) => Err(anyhow::anyhow!(
                "Error fetching ModelDeploymentCard from etcd under key {card_key}. {err}"
            )),
        }
    }
}
/// The etcd key under which a model's [`ModelEntry`] is stored.
///
/// Holds the slugified form of `{namespace}.{component}.{endpoint}-{lease_id in hex}`.
#[derive(Debug, Clone)]
pub struct ModelNetworkName(String);
impl ModelNetworkName {
fn from_parts(namespace: &str, component: &str, endpoint: &str, lease_id: i64) -> Self {
ModelNetworkName(
Slug::slugify(&format!("{namespace}.{component}.{endpoint}-{lease_id:x}")).to_string(),
)
}
pub fn from_local(endpoint: &component::Endpoint, lease_id: i64) -> Self {
Self::from_parts(
&endpoint.component().namespace().to_string(),
&endpoint.component().name(),
endpoint.name(),
lease_id,
)
}
pub async fn load_entry(&self, etcd_client: etcd::Client) -> anyhow::Result<ModelEntry> {
let mut model_entries = etcd_client.kv_get(self.to_string(), None).await?;
if model_entries.is_empty() {
anyhow::bail!("No ModelEntry in etcd for key {self}");
}
let model_entry = model_entries.remove(0);
serde_json::from_slice(model_entry.value()).with_context(|| {
format!(
"Error deserializing JSON. Key={self}. JSON={}",
model_entry.value_str().unwrap_or("INVALID UTF-8")
)
})
}
pub async fn load_mdc(
&self,
endpoint_id: protocols::Endpoint,
etcd_client: etcd::Client,
) -> anyhow::Result<ModelDeploymentCard> {
let entry = self.load_entry(etcd_client.clone()).await?;
entry.load_mdc(endpoint_id, etcd_client).await
}
}
impl From<&ComponentEndpointInfo> for ModelNetworkName {
fn from(cei: &ComponentEndpointInfo) -> Self {
Self::from_parts(&cei.namespace, &cei.component, &cei.endpoint, cei.lease_id)
}
}
impl std::fmt::Display for ModelNetworkName {
    /// Writes the slugified name verbatim; width, fill and alignment flags are ignored.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}
/// Shared state for [`model_watcher`] and its put/delete handlers.
pub struct ModelWatchState {
    /// Key prefix of the watched etcd entries; `handle_delete` strips it from a deleted key
    /// to obtain the model name.
    pub prefix: String,
    /// Registry the watcher adds model engines to and removes them from.
    pub manager: ModelManager,
    /// Runtime used to build endpoint clients and to obtain the etcd and NATS clients.
    pub drt: DistributedRuntime,
}
pub async fn model_watcher(state: Arc<ModelWatchState>, mut events_rx: Receiver<WatchEvent>) {
tracing::debug!("model watcher started");
while let Some(event) = events_rx.recv().await {
match event {
WatchEvent::Put(kv) => {
let model_entry = match serde_json::from_slice::<ModelEntry>(kv.value()) {
Ok(model_entry) => model_entry,
Err(err) => {
tracing::error!(%err, ?kv, "Invalid JSON in model entry");
continue;
}
};
if state.manager.has_model_any(&model_entry.name) {
tracing::trace!(
service_name = model_entry.name,
"New endpoint for existing model"
);
continue;
}
match handle_put(&model_entry, state.clone()).await {
Ok(()) => {
tracing::info!(model_name = model_entry.name, "added model");
}
Err(e) => {
tracing::error!(%e, "error adding model {}", model_entry.name);
}
}
}
WatchEvent::Delete(kv) => match handle_delete(&kv, state.clone()).await {
Ok(model_name) => {
tracing::info!("removed model {}", model_name);
}
Err(e) => {
tracing::error!("error removing model: {}", e);
}
},
}
}
}
/// Handle an etcd delete event by unregistering the model named by its key.
///
/// The model name is the event key with `state.prefix` removed once from the front. A key
/// that does not start with the prefix is used unchanged. Both the chat-completions and the
/// completions registrations are removed. Their failures are ignored on purpose: `Chat` and
/// `Completion` models are registered under only one of the two, so one removal is expected
/// to miss.
///
/// NOTE(review): `handle_put` registers models under `ModelEntry::name` (taken from the JSON
/// value), while `ModelNetworkName::load_entry` suggests entry keys are network names
/// (`namespace.component.endpoint-lease`). Confirm that the key suffix really equals the
/// registered model name.
///
/// # Errors
///
/// Returns an error only if `kv.key_str()` fails (presumably a non-UTF-8 key).
async fn handle_delete(kv: &KeyValue, state: Arc<ModelWatchState>) -> anyhow::Result<&str> {
    let key = kv.key_str()?;
    tracing::debug!(key, "removing model");

    // `strip_prefix` removes the prefix exactly once. `trim_start_matches` would also strip
    // repeated occurrences and mangle a name that itself begins with the prefix text.
    let model_name = key.strip_prefix(state.prefix.as_str()).unwrap_or(key);

    // Best-effort: the model may be registered under only one of these.
    let _ = state.manager.remove_chat_completions_model(model_name);
    let _ = state.manager.remove_completions_model(model_name);
    Ok(model_name)
}
async fn handle_put(model_entry: &ModelEntry, state: Arc<ModelWatchState>) -> anyhow::Result<()> {
let endpoint_id = model_entry.endpoint.clone();
let client = state
.drt
.namespace(&endpoint_id.namespace)?
.component(&endpoint_id.component)?
.endpoint(&endpoint_id.name)
.client()
.await?;
let Some(etcd_client) = state.drt.etcd_client() else {
anyhow::bail!("Missing etcd_client");
};
let card = match model_entry.load_mdc(endpoint_id, etcd_client).await {
Ok(card) => {
tracing::debug!(card.display_name, "adding model");
Some(card)
}
Err(err) => {
tracing::info!(%err, "load_mdc did not complete");
None
}
};
match model_entry.model_type {
ModelType::Backend => {
let Some(mut card) = card else {
anyhow::bail!("Missing model deployment card");
};
let _cache_dir = Some(card.move_from_nats(state.drt.nats_client()).await?);
let frontend = SegmentSource::<
SingleIn<NvCreateChatCompletionRequest>,
ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
>::new();
let preprocessor = OpenAIPreprocessor::new(card.clone()).await?.into_operator();
let backend = Backend::from_mdc(card.clone()).await?.into_operator();
let router = PushRouter::<BackendInput, Annotated<LLMEngineOutput>>::from_client(
client.clone(),
RouterMode::Random, )
.await?;
let chat_engine = frontend
.link(preprocessor.forward_edge())?
.link(backend.forward_edge())?
.link(ServiceBackend::from_engine(Arc::new(router)))?
.link(backend.backward_edge())?
.link(preprocessor.backward_edge())?
.link(frontend)?;
state
.manager
.add_chat_completions_model(&model_entry.name, chat_engine)?;
let frontend = SegmentSource::<
SingleIn<CompletionRequest>,
ManyOut<Annotated<CompletionResponse>>,
>::new();
let preprocessor = OpenAIPreprocessor::new(card.clone()).await?.into_operator();
let backend = Backend::from_mdc(card.clone()).await?.into_operator();
let router = PushRouter::<BackendInput, Annotated<LLMEngineOutput>>::from_client(
client,
RouterMode::Random, )
.await?;
let completions_engine = frontend
.link(preprocessor.forward_edge())?
.link(backend.forward_edge())?
.link(ServiceBackend::from_engine(Arc::new(router)))?
.link(backend.backward_edge())?
.link(preprocessor.backward_edge())?
.link(frontend)?;
state
.manager
.add_completions_model(&model_entry.name, completions_engine)?;
}
ModelType::Chat => {
let push_router = PushRouter::<
NvCreateChatCompletionRequest,
Annotated<NvCreateChatCompletionStreamResponse>,
>::from_client(client, Default::default())
.await?;
let engine = Arc::new(push_router);
state
.manager
.add_chat_completions_model(&model_entry.name, engine)?;
}
ModelType::Completion => {
let push_router =
PushRouter::<CompletionRequest, Annotated<CompletionResponse>>::from_client(
client,
Default::default(),
)
.await?;
let engine = Arc::new(push_router);
state
.manager
.add_completions_model(&model_entry.name, engine)?;
}
}
Ok(())
}