use std::{collections::HashSet, sync::Arc};
use dashmap::{DashMap, mapref::entry::Entry};
use tokio::sync::oneshot;
use super::worker_monitor::LoadThresholdConfig;
use super::{KvWorkerMonitor, Model, RuntimeConfigWatch, WorkerSet, runtime_config_watch};
use dynamo_runtime::{
component::{Endpoint, build_transport_type},
discovery::DiscoverySpec,
prelude::DistributedRuntimeProvider,
protocols::EndpointId,
};
use crate::{
kv_router::{
KvRouter, KvRouterConfig, protocols::WorkerId, router_endpoint_id,
scheduler::DefaultWorkerSelector,
},
local_model::runtime_config::DisaggregatedEndpoint,
model_card::ModelDeploymentCard,
types::{
generic::tensor::TensorStreamingEngine,
openai::{
chat_completions::OpenAIChatCompletionsStreamingEngine,
completions::OpenAICompletionsStreamingEngine,
embeddings::OpenAIEmbeddingsStreamingEngine, images::OpenAIImagesStreamingEngine,
videos::OpenAIVideosStreamingEngine,
},
},
};
/// Coordination state for pairing a decode WorkerSet with its prefill endpoint.
///
/// Whichever side arrives first (decode registration or prefill activation)
/// stores its half of a oneshot channel here; the other side completes the
/// handshake and the entry is removed.
enum PrefillActivationState {
    /// Decode side registered first; holds the sender that will deliver the
    /// prefill endpoint once it becomes available.
    DecodeWaiting(oneshot::Sender<Endpoint>),
    /// Prefill endpoint arrived first; holds the receiver the decode side
    /// will pick up when it registers.
    PrefillReady(oneshot::Receiver<Endpoint>),
}
/// Errors returned by [`ModelManager`] lookup and registration operations.
#[derive(Debug, thiserror::Error)]
pub enum ModelManagerError {
    /// No model with the given name is registered.
    #[error("Model not found: {0}")]
    ModelNotFound(String),
    /// A model with the given name already has the engine kind being added.
    #[error("Model already exists: {0}")]
    ModelAlreadyExists(String),
}
/// Concurrent registry of served models and their supporting state.
///
/// All maps are lock-free `DashMap`s, so the manager can be shared across
/// tasks behind a plain `Arc` without an outer lock.
pub struct ModelManager {
    /// Model display name -> aggregated model state (worker sets / engines).
    models: DashMap<String, Arc<Model>>,
    /// Storage key -> deployment card snapshot.
    cards: DashMap<String, ModelDeploymentCard>,
    /// "{model}:{namespace}" -> pending prefill/decode pairing state.
    prefill_router_activators: DashMap<String, PrefillActivationState>,
    /// Endpoint id -> shared runtime-config watch (created lazily).
    runtime_configs: DashMap<EndpointId, RuntimeConfigWatch>,
}
impl Default for ModelManager {
fn default() -> Self {
Self::new()
}
}
impl ModelManager {
pub fn new() -> Self {
Self {
models: DashMap::new(),
cards: DashMap::new(),
prefill_router_activators: DashMap::new(),
runtime_configs: DashMap::new(),
}
}
/// Returns the model entry for `model_name`, creating an empty one if absent.
pub fn get_or_create_model(&self, model_name: &str) -> Arc<Model> {
    let entry = self
        .models
        .entry(model_name.to_string())
        .or_insert_with(|| Arc::new(Model::new(model_name.to_string())));
    Arc::clone(entry.value())
}
/// Returns the model entry for `model_name`, or `None` if unknown.
pub fn get_model(&self, model_name: &str) -> Option<Arc<Model>> {
    let guard = self.models.get(model_name)?;
    Some(Arc::clone(guard.value()))
}
/// Drops the model entry for `model_name` if it no longer holds any
/// worker sets; a no-op for unknown or still-populated models.
pub fn remove_model_if_empty(&self, model_name: &str) {
    let removed = self
        .models
        .remove_if(model_name, |_, model| model.is_empty());
    if removed.is_some() {
        tracing::info!(model_name, "Removed empty model from manager");
    }
}
/// Registers `worker_set` under `namespace` for `model_name`, creating the
/// model entry on first use.
pub fn add_worker_set(&self, model_name: &str, namespace: &str, worker_set: WorkerSet) {
    self.get_or_create_model(model_name)
        .add_worker_set(namespace.to_string(), Arc::new(worker_set));
}
/// Removes the worker set registered under `namespace` for `model_name`.
///
/// Returns the removed set, or `None` if either the model or the namespace
/// is unknown. If the removal leaves the model with no worker sets, the
/// model entry itself is dropped from the registry.
pub fn remove_worker_set(&self, model_name: &str, namespace: &str) -> Option<Arc<WorkerSet>> {
    let model = self.models.get(model_name)?;
    let removed = model.remove_worker_set(namespace);
    // Release the DashMap read guard before `remove_model_if_empty` re-enters
    // the same map; holding a shard guard across `remove_if` could deadlock.
    drop(model);
    self.remove_model_if_empty(model_name);
    removed
}
/// Returns a snapshot of every stored model deployment card.
pub fn get_model_cards(&self) -> Vec<ModelDeploymentCard> {
    self.cards
        .iter()
        .map(|entry| entry.value().clone())
        .collect()
}
/// Stores `card` under `key`, replacing any existing card for that key.
pub fn save_model_card(&self, key: &str, card: ModelDeploymentCard) -> anyhow::Result<()> {
    self.cards.insert(key.to_owned(), card);
    Ok(())
}
/// Removes and returns the card stored under `key`, if any.
pub fn remove_model_card(&self, key: &str) -> Option<ModelDeploymentCard> {
    self.cards.remove(key).map(|(_key, card)| card)
}
/// True when `model` exists and reports a decode engine.
pub fn has_decode_model(&self, model: &str) -> bool {
    match self.models.get(model) {
        Some(m) => m.has_decode_engine(),
        None => false,
    }
}
/// True when `model` exists and reports a prefill worker set.
pub fn has_prefill_model(&self, model: &str) -> bool {
    match self.models.get(model) {
        Some(m) => m.has_prefill(),
        None => false,
    }
}
/// True when `model` is known in any capacity (decode or prefill).
pub fn has_model_any(&self, model: &str) -> bool {
    self.has_decode_model(model) || self.has_prefill_model(model)
}
pub fn model_display_names(&self) -> HashSet<String> {
self.models
.iter()
.filter(|entry| entry.value().is_displayable())
.map(|entry| entry.key().clone())
.collect()
}
pub fn list_chat_completions_models(&self) -> Vec<String> {
self.models
.iter()
.filter(|entry| entry.value().has_chat_engine())
.map(|entry| entry.key().clone())
.collect()
}
pub fn list_completions_models(&self) -> Vec<String> {
self.models
.iter()
.filter(|entry| entry.value().has_completions_engine())
.map(|entry| entry.key().clone())
.collect()
}
pub fn list_embeddings_models(&self) -> Vec<String> {
self.models
.iter()
.filter(|entry| entry.value().has_embeddings_engine())
.map(|entry| entry.key().clone())
.collect()
}
pub fn list_tensor_models(&self) -> Vec<String> {
self.models
.iter()
.filter(|entry| entry.value().has_tensor_engine())
.map(|entry| entry.key().clone())
.collect()
}
pub fn list_images_models(&self) -> Vec<String> {
self.models
.iter()
.filter(|entry| entry.value().has_images_engine())
.map(|entry| entry.key().clone())
.collect()
}
pub fn list_videos_models(&self) -> Vec<String> {
self.models
.iter()
.filter(|entry| entry.value().has_videos_engine())
.map(|entry| entry.key().clone())
.collect()
}
pub fn list_prefill_models(&self) -> Vec<String> {
self.models
.iter()
.filter(|entry| entry.value().has_prefill())
.map(|entry| entry.key().clone())
.collect()
}
pub fn get_embeddings_engine(
&self,
model: &str,
) -> Result<OpenAIEmbeddingsStreamingEngine, ModelManagerError> {
self.models
.get(model)
.ok_or_else(|| ModelManagerError::ModelNotFound(model.to_string()))?
.get_embeddings_engine()
}
pub fn get_completions_engine(
&self,
model: &str,
) -> Result<OpenAICompletionsStreamingEngine, ModelManagerError> {
self.models
.get(model)
.ok_or_else(|| ModelManagerError::ModelNotFound(model.to_string()))?
.get_completions_engine()
}
pub fn get_chat_completions_engine(
&self,
model: &str,
) -> Result<OpenAIChatCompletionsStreamingEngine, ModelManagerError> {
self.models
.get(model)
.ok_or_else(|| ModelManagerError::ModelNotFound(model.to_string()))?
.get_chat_engine()
}
pub fn get_tensor_engine(
&self,
model: &str,
) -> Result<TensorStreamingEngine, ModelManagerError> {
self.models
.get(model)
.ok_or_else(|| ModelManagerError::ModelNotFound(model.to_string()))?
.get_tensor_engine()
}
pub fn get_images_engine(
&self,
model: &str,
) -> Result<OpenAIImagesStreamingEngine, ModelManagerError> {
self.models
.get(model)
.ok_or_else(|| ModelManagerError::ModelNotFound(model.to_string()))?
.get_images_engine()
}
pub fn get_videos_engine(
&self,
model: &str,
) -> Result<OpenAIVideosStreamingEngine, ModelManagerError> {
self.models
.get(model)
.ok_or_else(|| ModelManagerError::ModelNotFound(model.to_string()))?
.get_videos_engine()
}
pub fn get_chat_completions_engine_with_parsing(
&self,
model: &str,
) -> Result<
(
OpenAIChatCompletionsStreamingEngine,
crate::protocols::openai::ParsingOptions,
),
ModelManagerError,
> {
self.models
.get(model)
.ok_or_else(|| ModelManagerError::ModelNotFound(model.to_string()))?
.get_chat_engine_with_parsing()
}
pub fn get_completions_engine_with_parsing(
&self,
model: &str,
) -> Result<
(
OpenAICompletionsStreamingEngine,
crate::protocols::openai::ParsingOptions,
),
ModelManagerError,
> {
self.models
.get(model)
.ok_or_else(|| ModelManagerError::ModelNotFound(model.to_string()))?
.get_completions_engine_with_parsing()
}
pub fn add_chat_completions_model(
&self,
model: &str,
card_checksum: &str,
engine: OpenAIChatCompletionsStreamingEngine,
) -> Result<(), ModelManagerError> {
let model_entry = self.get_or_create_model(model);
if model_entry.has_chat_engine() {
return Err(ModelManagerError::ModelAlreadyExists(model.to_string()));
}
let namespace = format!("__local_chat_{}", model);
let mut ws = WorkerSet::new(
namespace.clone(),
card_checksum.to_string(),
ModelDeploymentCard::default(),
);
ws.chat_engine = Some(engine);
model_entry.add_worker_set(namespace, Arc::new(ws));
Ok(())
}
pub fn add_completions_model(
&self,
model: &str,
card_checksum: &str,
engine: OpenAICompletionsStreamingEngine,
) -> Result<(), ModelManagerError> {
let model_entry = self.get_or_create_model(model);
if model_entry.has_completions_engine() {
return Err(ModelManagerError::ModelAlreadyExists(model.to_string()));
}
let namespace = format!("__local_completions_{}", model);
let mut ws = WorkerSet::new(
namespace.clone(),
card_checksum.to_string(),
ModelDeploymentCard::default(),
);
ws.completions_engine = Some(engine);
model_entry.add_worker_set(namespace, Arc::new(ws));
Ok(())
}
pub fn add_embeddings_model(
&self,
model: &str,
card_checksum: &str,
engine: OpenAIEmbeddingsStreamingEngine,
) -> Result<(), ModelManagerError> {
let model_entry = self.get_or_create_model(model);
if model_entry.has_embeddings_engine() {
return Err(ModelManagerError::ModelAlreadyExists(model.to_string()));
}
let namespace = format!("__local_embeddings_{}", model);
let mut ws = WorkerSet::new(
namespace.clone(),
card_checksum.to_string(),
ModelDeploymentCard::default(),
);
ws.embeddings_engine = Some(engine);
model_entry.add_worker_set(namespace, Arc::new(ws));
Ok(())
}
pub fn add_tensor_model(
&self,
model: &str,
card_checksum: &str,
engine: TensorStreamingEngine,
) -> Result<(), ModelManagerError> {
let model_entry = self.get_or_create_model(model);
if model_entry.has_tensor_engine() {
return Err(ModelManagerError::ModelAlreadyExists(model.to_string()));
}
let namespace = format!("__local_tensor_{}", model);
let mut ws = WorkerSet::new(
namespace.clone(),
card_checksum.to_string(),
ModelDeploymentCard::default(),
);
ws.tensor_engine = Some(engine);
model_entry.add_worker_set(namespace, Arc::new(ws));
Ok(())
}
pub fn add_images_model(
&self,
model: &str,
card_checksum: &str,
engine: OpenAIImagesStreamingEngine,
) -> Result<(), ModelManagerError> {
let model_entry = self.get_or_create_model(model);
if model_entry.has_images_engine() {
return Err(ModelManagerError::ModelAlreadyExists(model.to_string()));
}
let namespace = format!("__local_images_{}", model);
let mut ws = WorkerSet::new(
namespace.clone(),
card_checksum.to_string(),
ModelDeploymentCard::default(),
);
ws.images_engine = Some(engine);
model_entry.add_worker_set(namespace, Arc::new(ws));
Ok(())
}
pub fn add_videos_model(
&self,
model: &str,
card_checksum: &str,
engine: OpenAIVideosStreamingEngine,
) -> Result<(), ModelManagerError> {
let model_entry = self.get_or_create_model(model);
if model_entry.has_videos_engine() {
return Err(ModelManagerError::ModelAlreadyExists(model.to_string()));
}
let namespace = format!("__local_videos_{}", model);
let mut ws = WorkerSet::new(
namespace.clone(),
card_checksum.to_string(),
ModelDeploymentCard::default(),
);
ws.videos_engine = Some(engine);
model_entry.add_worker_set(namespace, Arc::new(ws));
Ok(())
}
pub fn add_prefill_model(
&self,
model: &str,
card_checksum: &str,
) -> Result<(), ModelManagerError> {
let model_entry = self.get_or_create_model(model);
if model_entry.has_prefill() {
return Err(ModelManagerError::ModelAlreadyExists(model.to_string()));
}
let namespace = format!("__local_prefill_{}", model);
let ws = WorkerSet::new(
namespace.clone(),
card_checksum.to_string(),
ModelDeploymentCard::default(),
);
model_entry.add_worker_set(namespace, Arc::new(ws));
Ok(())
}
/// Removes the entire model entry, returning it if it existed.
pub fn remove_model(&self, model: &str) -> Option<Arc<Model>> {
    let (_name, entry) = self.models.remove(model)?;
    Some(entry)
}
/// Shared implementation for removing a locally-hosted engine registered
/// under the synthetic `__local_{kind}_{model}` namespace.
///
/// # Errors
/// `ModelNotFound` when the model or its local namespace is not registered.
fn remove_local_model(&self, model: &str, kind: &str) -> Result<(), ModelManagerError> {
    let namespace = format!("__local_{}_{}", kind, model);
    self.remove_worker_set(model, &namespace)
        .map(|_| ())
        .ok_or_else(|| ModelManagerError::ModelNotFound(model.to_string()))
}
/// Removes the local chat-completions registration for `model`.
pub fn remove_chat_completions_model(&self, model: &str) -> Result<(), ModelManagerError> {
    self.remove_local_model(model, "chat")
}
/// Removes the local completions registration for `model`.
pub fn remove_completions_model(&self, model: &str) -> Result<(), ModelManagerError> {
    self.remove_local_model(model, "completions")
}
/// Removes the local tensor registration for `model`.
pub fn remove_tensor_model(&self, model: &str) -> Result<(), ModelManagerError> {
    self.remove_local_model(model, "tensor")
}
/// Removes the local embeddings registration for `model`.
pub fn remove_embeddings_model(&self, model: &str) -> Result<(), ModelManagerError> {
    self.remove_local_model(model, "embeddings")
}
/// Removes the local image-generation registration for `model`.
pub fn remove_images_model(&self, model: &str) -> Result<(), ModelManagerError> {
    self.remove_local_model(model, "images")
}
/// Removes the local video-generation registration for `model`.
pub fn remove_videos_model(&self, model: &str) -> Result<(), ModelManagerError> {
    self.remove_local_model(model, "videos")
}
/// Builds a KV-aware router for `endpoint` and returns it.
///
/// Sequence: create a client for the endpoint, derive and advertise a router
/// endpoint in the discovery service, attach the shared runtime-config
/// watcher, then construct a [`KvRouter`] with the default worker selector.
///
/// # Errors
/// Propagates failures from client creation, transport construction,
/// discovery registration, watcher setup, or router construction.
pub async fn kv_chooser_for(
    &self,
    endpoint: &Endpoint,
    kv_cache_block_size: u32,
    kv_router_config: Option<KvRouterConfig>,
    worker_type: &'static str,
) -> anyhow::Result<Arc<KvRouter>> {
    let client = endpoint.client().await?;
    let discovery = endpoint.component().drt().discovery();
    let instance_id = discovery.instance_id();
    // Derive the router's own endpoint id from the worker endpoint's
    // namespace/component.
    let router_endpoint_id =
        router_endpoint_id(endpoint.id().namespace, endpoint.id().component);
    let transport = build_transport_type(endpoint, &router_endpoint_id, instance_id).await?;
    // Advertise the router endpoint so other components can discover it.
    let discovery_spec = DiscoverySpec::Endpoint {
        namespace: router_endpoint_id.namespace.clone(),
        component: router_endpoint_id.component.clone(),
        endpoint: router_endpoint_id.name.clone(),
        transport,
    };
    discovery.register(discovery_spec).await?;
    // Reuse (or lazily create) the per-endpoint runtime-config watch so the
    // router observes per-worker configuration updates.
    let workers_with_configs = self.get_or_create_runtime_config_watcher(endpoint).await?;
    let selector = Box::new(DefaultWorkerSelector::new(kv_router_config));
    let chooser = KvRouter::new(
        endpoint.clone(),
        client,
        workers_with_configs,
        kv_cache_block_size,
        Some(selector),
        kv_router_config,
        worker_type,
    )
    .await?;
    Ok(Arc::new(chooser))
}
/// Builds the composite key used by the prefill-activator map: `"{model}:{namespace}"`.
pub(crate) fn model_namespace_key(model_name: &str, namespace: &str) -> String {
    [model_name, namespace].join(":")
}
/// Called when a decode WorkerSet registers: obtains a receiver that will
/// yield the prefill [`Endpoint`] for `model_name`/`namespace`.
///
/// Returns `Some(receiver)` either immediately (the prefill side already
/// activated and stored its receiver) or with a freshly created channel
/// whose sender is stored for a later [`Self::activate_prefill_router`].
/// Returns `None` if a decode side is already waiting on this key
/// (duplicate registration).
pub fn register_prefill_router(
    &self,
    model_name: &str,
    namespace: &str,
) -> Option<oneshot::Receiver<Endpoint>> {
    let key = Self::model_namespace_key(model_name, namespace);
    match self.prefill_router_activators.remove(&key) {
        // Prefill arrived first: hand its stored receiver straight back.
        Some((_, PrefillActivationState::PrefillReady(rx))) => {
            tracing::debug!(
                model_name = %model_name,
                namespace = %namespace,
                "Prefill endpoint already available for namespace, returning receiver"
            );
            Some(rx)
        }
        // A decode side is already waiting: restore the entry we just
        // removed and report the conflict.
        // NOTE(review): between the remove and the re-insert the entry is
        // briefly absent from the map; assumed harmless — confirm no
        // concurrent caller relies on its continuous presence.
        Some((key, PrefillActivationState::DecodeWaiting(tx))) => {
            tracing::error!(
                model_name = %model_name,
                namespace = %namespace,
                "Decode WorkerSet already registered for this prefill router"
            );
            self.prefill_router_activators
                .insert(key, PrefillActivationState::DecodeWaiting(tx));
            None
        }
        // First arrival: stash the sender so the prefill side can complete
        // the handshake later.
        None => {
            let (tx, rx) = oneshot::channel();
            self.prefill_router_activators
                .insert(key, PrefillActivationState::DecodeWaiting(tx));
            tracing::debug!(
                model_name = %model_name,
                namespace = %namespace,
                "No prefill endpoint for namespace yet, storing sender for future activation"
            );
            Some(rx)
        }
    }
}
/// Called when a prefill endpoint becomes available: delivers `endpoint` to
/// the decode side registered for `model_name`/`namespace`, or stores it for
/// a decode side that has not registered yet.
///
/// # Errors
/// Fails if the paired receiver was dropped, or if a prefill endpoint was
/// already activated for this key.
pub fn activate_prefill_router(
    &self,
    model_name: &str,
    namespace: &str,
    endpoint: Endpoint,
) -> anyhow::Result<()> {
    let key = Self::model_namespace_key(model_name, namespace);
    match self.prefill_router_activators.remove(&key) {
        // Decode side is waiting: complete its oneshot with the endpoint.
        Some((_, PrefillActivationState::DecodeWaiting(sender))) => {
            sender.send(endpoint).map_err(|_| {
                // send only fails if the receiver was dropped.
                anyhow::anyhow!(
                    "Failed to send endpoint to prefill router activator for {}:{}",
                    model_name,
                    namespace
                )
            })?;
            tracing::info!(
                model_name = %model_name,
                namespace = %namespace,
                "Activated prefill router for decode WorkerSet"
            );
            Ok(())
        }
        // A prefill endpoint was already stored for this key: double
        // activation is an error. Note the removed entry is not restored.
        Some((_, PrefillActivationState::PrefillReady(_))) => {
            anyhow::bail!(
                "Prefill router for {}:{} already activated",
                model_name,
                namespace
            );
        }
        // No decode side yet: pre-load a channel with the endpoint and store
        // the receiver for a future register_prefill_router call.
        None => {
            let (tx, rx) = oneshot::channel();
            tx.send(endpoint).map_err(|_| {
                anyhow::anyhow!(
                    "Failed to send endpoint for prefill model {}:{}",
                    model_name,
                    namespace
                )
            })?;
            self.prefill_router_activators
                .insert(key, PrefillActivationState::PrefillReady(rx));
            tracing::info!(
                model_name = %model_name,
                namespace = %namespace,
                "Stored prefill endpoint for future decode WorkerSet registration"
            );
            Ok(())
        }
    }
}
/// Drops any pending activation state for `model_name`/`namespace`, e.g.
/// when the corresponding WorkerSet is removed. No-op if none exists.
pub fn remove_prefill_activator(&self, model_name: &str, namespace: &str) {
    let key = Self::model_namespace_key(model_name, namespace);
    if self.prefill_router_activators.remove(&key).is_none() {
        return;
    }
    tracing::debug!(
        model_name = %model_name,
        namespace = %namespace,
        "Cleaned up prefill router activator for removed WorkerSet"
    );
}
/// Resolves the effective load-threshold config for `model`, if the model
/// is registered.
pub fn load_threshold_config(
    &self,
    model: &str,
    config: Option<&LoadThresholdConfig>,
) -> Option<LoadThresholdConfig> {
    self.models
        .get(model)
        .and_then(|entry| entry.load_threshold_config(config))
}
/// Returns the KV worker monitor for `model`'s worker set in `namespace`,
/// if both exist.
pub fn get_worker_monitor_for_namespace(
    &self,
    model: &str,
    namespace: &str,
) -> Option<KvWorkerMonitor> {
    self.models
        .get(model)
        .and_then(|entry| entry.get_worker_monitor_for_namespace(namespace))
}
/// Returns every model that currently resolves a load-threshold config,
/// paired with that config.
pub fn list_busy_thresholds(&self) -> Vec<(String, LoadThresholdConfig)> {
    self.models
        .iter()
        .filter_map(|entry| {
            entry
                .value()
                .load_threshold_config(None)
                .map(|config| (entry.key().clone(), config))
        })
        .collect()
}
/// Returns the runtime-config watch for `endpoint`, creating it on first use.
///
/// Uses a fast-path read, then builds the watch outside the map lock (the
/// construction awaits) and inserts it via the entry API. If another task
/// raced us and inserted first, its watch wins and ours is dropped.
pub async fn get_or_create_runtime_config_watcher(
    &self,
    endpoint: &Endpoint,
) -> anyhow::Result<RuntimeConfigWatch> {
    let endpoint_id = endpoint.id();
    // Fast path: watcher already exists.
    if let Some(existing) = self.runtime_configs.get(&endpoint_id) {
        return Ok(existing.clone());
    }
    // Slow path: construct the watch without holding any map guard, since
    // this awaits.
    let rx = runtime_config_watch(endpoint).await?;
    let result = match self.runtime_configs.entry(endpoint_id) {
        // Lost the race: keep the watch another task inserted; `rx` is dropped.
        Entry::Occupied(e) => e.get().clone(),
        Entry::Vacant(e) => {
            e.insert(rx.clone());
            rx
        }
    };
    Ok(result)
}
/// Returns the disaggregated endpoint reported by worker `worker_id` on
/// `endpoint_id`, if a config watch exists and the worker advertised one.
pub fn get_disaggregated_endpoint(
    &self,
    endpoint_id: &EndpointId,
    worker_id: WorkerId,
) -> Option<DisaggregatedEndpoint> {
    let watch = self.runtime_configs.get(endpoint_id)?;
    let snapshot = watch.borrow();
    let worker_config = snapshot.get(&worker_id)?;
    worker_config.disaggregated_endpoint.clone()
}
}
#[cfg(test)]
// Unit tests covering ModelManager's worker-set bookkeeping, card storage,
// and the prefill-router activation handshake.
mod tests {
    use super::*;
    use crate::model_card::ModelDeploymentCard;
    // Builds a bare WorkerSet with a default deployment card for tests.
    fn make_worker_set(namespace: &str, mdcsum: &str) -> WorkerSet {
        WorkerSet::new(
            namespace.to_string(),
            mdcsum.to_string(),
            ModelDeploymentCard::default(),
        )
    }
    // --- Worker-set lifecycle ---
    #[test]
    fn test_add_and_get_worker_set() {
        let mm = ModelManager::new();
        let ws = make_worker_set("ns1", "abc");
        mm.add_worker_set("llama", "ns1", ws);
        let model = mm.get_model("llama");
        assert!(model.is_some());
        let model = model.unwrap();
        assert!(model.has_worker_set("ns1"));
        assert_eq!(model.worker_set_count(), 1);
    }
    #[test]
    fn test_add_worker_set_creates_model() {
        let mm = ModelManager::new();
        assert!(mm.get_model("llama").is_none());
        mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "abc"));
        assert!(mm.get_model("llama").is_some());
    }
    #[test]
    fn test_remove_worker_set_removes_empty_model() {
        let mm = ModelManager::new();
        mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "abc"));
        assert!(mm.get_model("llama").is_some());
        let removed = mm.remove_worker_set("llama", "ns1");
        assert!(removed.is_some());
        assert_eq!(removed.unwrap().namespace(), "ns1");
        // Removing the last worker set also drops the model entry.
        assert!(mm.get_model("llama").is_none());
    }
    #[test]
    fn test_remove_worker_set_keeps_model_with_remaining() {
        let mm = ModelManager::new();
        mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "abc"));
        mm.add_worker_set("llama", "ns2", make_worker_set("ns2", "abc"));
        mm.remove_worker_set("llama", "ns1");
        let model = mm.get_model("llama").unwrap();
        assert!(!model.has_worker_set("ns1"));
        assert!(model.has_worker_set("ns2"));
        assert_eq!(model.worker_set_count(), 1);
    }
    #[test]
    fn test_remove_worker_set_nonexistent_model() {
        let mm = ModelManager::new();
        assert!(mm.remove_worker_set("llama", "ns1").is_none());
    }
    #[test]
    fn test_remove_worker_set_nonexistent_namespace() {
        let mm = ModelManager::new();
        mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "abc"));
        assert!(mm.remove_worker_set("llama", "ns2").is_none());
        assert!(mm.get_model("llama").is_some());
    }
    #[test]
    fn test_remove_model_if_empty_noop_when_not_empty() {
        let mm = ModelManager::new();
        mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "abc"));
        mm.remove_model_if_empty("llama");
        assert!(mm.get_model("llama").is_some());
    }
    #[test]
    fn test_remove_model_if_empty_noop_when_missing() {
        let mm = ModelManager::new();
        mm.remove_model_if_empty("nonexistent");
    }
    #[test]
    fn test_remove_model() {
        let mm = ModelManager::new();
        mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "abc"));
        mm.add_worker_set("llama", "ns2", make_worker_set("ns2", "abc"));
        let removed = mm.remove_model("llama");
        assert!(removed.is_some());
        assert!(mm.get_model("llama").is_none());
    }
    #[test]
    fn test_get_or_create_model_idempotent() {
        let mm = ModelManager::new();
        let m1 = mm.get_or_create_model("llama");
        let m2 = mm.get_or_create_model("llama");
        // Same Arc is returned for repeated lookups of one name.
        assert!(Arc::ptr_eq(&m1, &m2));
    }
    // --- Capability predicates and listings ---
    #[test]
    fn test_has_decode_model() {
        let mm = ModelManager::new();
        assert!(!mm.has_decode_model("llama"));
        mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "abc"));
        // A bare worker set carries no decode engine.
        assert!(!mm.has_decode_model("llama"));
    }
    #[test]
    fn test_has_prefill_model() {
        let mm = ModelManager::new();
        mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "abc"));
        assert!(mm.has_prefill_model("llama"));
    }
    #[test]
    fn test_has_model_any() {
        let mm = ModelManager::new();
        assert!(!mm.has_model_any("llama"));
        mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "abc"));
        assert!(mm.has_model_any("llama"));
    }
    #[test]
    fn test_model_display_names_includes_prefill() {
        let mm = ModelManager::new();
        mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "abc"));
        let names = mm.model_display_names();
        assert!(names.contains("llama"));
    }
    #[test]
    fn test_model_display_names_empty() {
        let mm = ModelManager::new();
        assert!(mm.model_display_names().is_empty());
    }
    #[test]
    fn test_list_prefill_models() {
        let mm = ModelManager::new();
        mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "abc"));
        mm.add_worker_set("gpt", "ns1", make_worker_set("ns1", "def"));
        let prefill = mm.list_prefill_models();
        assert_eq!(prefill.len(), 2);
        assert!(prefill.contains(&"llama".to_string()));
        assert!(prefill.contains(&"gpt".to_string()));
    }
    // --- Model card storage ---
    #[test]
    fn test_save_and_remove_model_card() {
        let mm = ModelManager::new();
        let card = ModelDeploymentCard::default();
        mm.save_model_card("instance/key/1", card.clone()).unwrap();
        let cards = mm.get_model_cards();
        assert_eq!(cards.len(), 1);
        let removed = mm.remove_model_card("instance/key/1");
        assert!(removed.is_some());
        assert!(mm.get_model_cards().is_empty());
    }
    #[test]
    fn test_remove_model_card_nonexistent() {
        let mm = ModelManager::new();
        assert!(mm.remove_model_card("nonexistent").is_none());
    }
    // --- Prefill router activation handshake ---
    #[test]
    fn test_prefill_router_register_new() {
        let mm = ModelManager::new();
        let rx = mm.register_prefill_router("llama", "ns1");
        assert!(rx.is_some());
    }
    #[test]
    fn test_prefill_router_double_register_returns_none() {
        let mm = ModelManager::new();
        let rx1 = mm.register_prefill_router("llama", "ns1");
        assert!(rx1.is_some());
        // Second decode registration for the same key is rejected.
        let rx2 = mm.register_prefill_router("llama", "ns1");
        assert!(rx2.is_none());
    }
    #[test]
    fn test_prefill_router_different_namespaces_independent() {
        let mm = ModelManager::new();
        let rx1 = mm.register_prefill_router("llama", "ns1");
        let rx2 = mm.register_prefill_router("llama", "ns2");
        assert!(rx1.is_some());
        assert!(rx2.is_some());
    }
    #[test]
    fn test_prefill_router_different_models_independent() {
        let mm = ModelManager::new();
        let rx1 = mm.register_prefill_router("llama", "ns1");
        let rx2 = mm.register_prefill_router("gpt", "ns1");
        assert!(rx1.is_some());
        assert!(rx2.is_some());
    }
    #[test]
    fn test_prefill_router_remove_allows_reregister() {
        let mm = ModelManager::new();
        let rx = mm.register_prefill_router("llama", "ns1");
        assert!(rx.is_some());
        mm.remove_prefill_activator("llama", "ns1");
        let rx2 = mm.register_prefill_router("llama", "ns1");
        assert!(rx2.is_some());
    }
    #[test]
    fn test_prefill_router_remove_nonexistent_noop() {
        let mm = ModelManager::new();
        mm.remove_prefill_activator("llama", "ns1");
    }
    #[test]
    fn test_model_namespace_key_format() {
        assert_eq!(
            ModelManager::model_namespace_key("llama", "ns1"),
            "llama:ns1"
        );
        assert_eq!(
            ModelManager::model_namespace_key("gpt-4", "default-abc"),
            "gpt-4:default-abc"
        );
    }
}