use std::sync::Arc;
use dashmap::DashMap;
use rand::Rng;
use super::worker_monitor::LoadThresholdConfig;
use super::worker_set::WorkerSet;
use super::{KvWorkerMonitor, ModelManagerError};
use crate::protocols::openai::ParsingOptions;
use crate::types::{
generic::tensor::TensorStreamingEngine,
openai::{
chat_completions::OpenAIChatCompletionsStreamingEngine,
completions::OpenAICompletionsStreamingEngine, embeddings::OpenAIEmbeddingsStreamingEngine,
images::OpenAIImagesStreamingEngine, videos::OpenAIVideosStreamingEngine,
},
};
/// A named model together with the worker sets registered for it.
///
/// Worker sets live in a concurrent map keyed by the string handed to
/// `add_worker_set` (called `namespace` throughout this file), so every
/// method on this type takes `&self`.
pub struct Model {
/// Model name, exactly as passed to `Model::new`.
name: String,
/// Registered worker sets, keyed by namespace.
worker_sets: DashMap<String, Arc<WorkerSet>>,
}
impl Model {
    /// Creates a model called `name` with no worker sets registered.
    pub fn new(name: String) -> Self {
        Self {
            name,
            worker_sets: DashMap::new(),
        }
    }

    /// Returns the model's name.
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Registers `worker_set` under `namespace` and logs the registration.
    ///
    /// A worker set already stored under the same `namespace` is replaced.
    pub fn add_worker_set(&self, namespace: String, worker_set: Arc<WorkerSet>) {
        tracing::info!(
            model = %self.name,
            namespace = %namespace,
            "Adding worker set to model"
        );
        self.worker_sets.insert(namespace, worker_set);
    }

    /// Reports whether `candidate_checksum` agrees with the worker set
    /// registered under `ws_key`.
    ///
    /// Returns `true` when nothing is registered under `ws_key`, or when the
    /// registered set's `mdcsum()` equals `candidate_checksum`. Worker sets
    /// stored under other keys are not consulted.
    pub fn is_checksum_compatible(&self, ws_key: &str, candidate_checksum: &str) -> bool {
        match self.worker_sets.get(ws_key) {
            Some(existing_ws) => existing_ws.mdcsum() == candidate_checksum,
            None => true,
        }
    }

    /// Removes the worker set stored under `namespace` and returns it.
    ///
    /// Returns `None` (and logs nothing) when no such worker set exists.
    pub fn remove_worker_set(&self, namespace: &str) -> Option<Arc<WorkerSet>> {
        let removed = self.worker_sets.remove(namespace).map(|(_, ws)| ws);
        if removed.is_some() {
            tracing::info!(
                model = %self.name,
                namespace = %namespace,
                remaining_sets = self.worker_sets.len(),
                "Removed worker set from model"
            );
        }
        removed
    }

    /// Returns `true` if a worker set is registered under `namespace`.
    pub fn has_worker_set(&self, namespace: &str) -> bool {
        self.worker_sets.contains_key(namespace)
    }

    /// Returns a shared handle to the worker set registered under
    /// `namespace`, if any.
    pub fn get_worker_set(&self, namespace: &str) -> Option<Arc<WorkerSet>> {
        self.worker_sets
            .get(namespace)
            .map(|entry| Arc::clone(entry.value()))
    }

    /// Returns `true` if no worker set is registered.
    pub fn is_empty(&self) -> bool {
        self.worker_sets.is_empty()
    }

    /// Returns the number of registered worker sets.
    pub fn worker_set_count(&self) -> usize {
        self.worker_sets.len()
    }

    /// Returns `true` if any worker set reports a decode engine.
    pub fn has_decode_engine(&self) -> bool {
        self.any_worker_set(|ws| ws.has_decode_engine())
    }

    /// Returns `true` if any worker set is a prefill set.
    pub fn has_prefill(&self) -> bool {
        self.any_worker_set(|ws| ws.is_prefill_set())
    }

    /// Returns `true` if any worker set reports a chat engine.
    pub fn has_chat_engine(&self) -> bool {
        self.any_worker_set(|ws| ws.has_chat_engine())
    }

    /// Returns `true` if any worker set reports a completions engine.
    pub fn has_completions_engine(&self) -> bool {
        self.any_worker_set(|ws| ws.has_completions_engine())
    }

    /// Returns `true` if any worker set reports an embeddings engine.
    pub fn has_embeddings_engine(&self) -> bool {
        self.any_worker_set(|ws| ws.has_embeddings_engine())
    }

    /// Returns `true` if any worker set reports a tensor engine.
    pub fn has_tensor_engine(&self) -> bool {
        self.any_worker_set(|ws| ws.has_tensor_engine())
    }

    /// Returns `true` if any worker set reports an images engine.
    pub fn has_images_engine(&self) -> bool {
        self.any_worker_set(|ws| ws.has_images_engine())
    }

    /// Returns `true` if any worker set reports a videos engine.
    pub fn has_videos_engine(&self) -> bool {
        self.any_worker_set(|ws| ws.has_videos_engine())
    }

    /// Decides whether this model should be shown.
    ///
    /// The model is displayable when some worker set with a non-zero worker
    /// count either
    ///
    /// * exposes a serving engine (chat, completions, embeddings, images,
    ///   tensor or videos), or
    /// * is a prefill set, provided that *no* worker set of this model
    ///   (regardless of its worker count) exposes a serving engine.
    ///
    /// The map is walked once, so both conditions are evaluated against the
    /// same view of each worker set.
    pub fn is_displayable(&self) -> bool {
        let mut any_serving_engine = false;
        let mut live_prefill_set = false;

        for entry in self.worker_sets.iter() {
            let ws = entry.value();
            let serves = ws.has_chat_engine()
                || ws.has_completions_engine()
                || ws.has_embeddings_engine()
                || ws.has_images_engine()
                || ws.has_tensor_engine()
                || ws.has_videos_engine();
            let has_workers = ws.worker_count() > 0;

            // A live serving set settles the question; no need to look further.
            if serves && has_workers {
                return true;
            }

            any_serving_engine |= serves;
            live_prefill_set |= has_workers && ws.is_prefill_set();
        }

        // Prefill sets only count when the model has no serving engine at all.
        !any_serving_engine && live_prefill_set
    }

    /// Returns the chat engine of a worker set chosen by
    /// `select_worker_set_with`.
    ///
    /// # Errors
    ///
    /// `ModelManagerError::ModelNotFound` (carrying the model name) when no
    /// worker set with at least one worker has a chat engine.
    pub fn get_chat_engine(
        &self,
    ) -> Result<OpenAIChatCompletionsStreamingEngine, ModelManagerError> {
        self.select_engine(|ws| ws.chat_engine.clone())
    }

    /// Returns the completions engine of a worker set chosen by
    /// `select_worker_set_with`.
    ///
    /// # Errors
    ///
    /// `ModelManagerError::ModelNotFound` (carrying the model name) when no
    /// worker set with at least one worker has a completions engine.
    pub fn get_completions_engine(
        &self,
    ) -> Result<OpenAICompletionsStreamingEngine, ModelManagerError> {
        self.select_engine(|ws| ws.completions_engine.clone())
    }

    /// Returns the embeddings engine of a worker set chosen by
    /// `select_worker_set_with`.
    ///
    /// # Errors
    ///
    /// `ModelManagerError::ModelNotFound` (carrying the model name) when no
    /// worker set with at least one worker has an embeddings engine.
    pub fn get_embeddings_engine(
        &self,
    ) -> Result<OpenAIEmbeddingsStreamingEngine, ModelManagerError> {
        self.select_engine(|ws| ws.embeddings_engine.clone())
    }

    /// Returns the images engine of a worker set chosen by
    /// `select_worker_set_with`.
    ///
    /// # Errors
    ///
    /// `ModelManagerError::ModelNotFound` (carrying the model name) when no
    /// worker set with at least one worker has an images engine.
    pub fn get_images_engine(&self) -> Result<OpenAIImagesStreamingEngine, ModelManagerError> {
        self.select_engine(|ws| ws.images_engine.clone())
    }

    /// Returns the videos engine of a worker set chosen by
    /// `select_worker_set_with`.
    ///
    /// # Errors
    ///
    /// `ModelManagerError::ModelNotFound` (carrying the model name) when no
    /// worker set with at least one worker has a videos engine.
    pub fn get_videos_engine(&self) -> Result<OpenAIVideosStreamingEngine, ModelManagerError> {
        self.select_engine(|ws| ws.videos_engine.clone())
    }

    /// Returns the tensor engine of a worker set chosen by
    /// `select_worker_set_with`.
    ///
    /// # Errors
    ///
    /// `ModelManagerError::ModelNotFound` (carrying the model name) when no
    /// worker set with at least one worker has a tensor engine.
    pub fn get_tensor_engine(&self) -> Result<TensorStreamingEngine, ModelManagerError> {
        self.select_engine(|ws| ws.tensor_engine.clone())
    }

    /// Like [`Model::get_chat_engine`], but also returns the
    /// `parsing_options()` of the *same* worker set the engine came from.
    ///
    /// # Errors
    ///
    /// `ModelManagerError::ModelNotFound` (carrying the model name) when no
    /// worker set with at least one worker has a chat engine.
    pub fn get_chat_engine_with_parsing(
        &self,
    ) -> Result<(OpenAIChatCompletionsStreamingEngine, ParsingOptions), ModelManagerError> {
        self.select_engine(|ws| {
            ws.chat_engine
                .clone()
                .map(|engine| (engine, ws.parsing_options()))
        })
    }

    /// Like [`Model::get_completions_engine`], but also returns the
    /// `parsing_options()` of the *same* worker set the engine came from.
    ///
    /// # Errors
    ///
    /// `ModelManagerError::ModelNotFound` (carrying the model name) when no
    /// worker set with at least one worker has a completions engine.
    pub fn get_completions_engine_with_parsing(
        &self,
    ) -> Result<(OpenAICompletionsStreamingEngine, ParsingOptions), ModelManagerError> {
        self.select_engine(|ws| {
            ws.completions_engine
                .clone()
                .map(|engine| (engine, ws.parsing_options()))
        })
    }

    /// Optionally updates, then reads, the load-threshold configuration.
    ///
    /// When `config` is `Some`, it is applied to the monitor of every worker
    /// set that has one. The returned value is the configuration reported by
    /// the first monitor met while iterating the map (read after any update);
    /// `None` means no worker set has a monitor.
    pub fn load_threshold_config(
        &self,
        config: Option<&LoadThresholdConfig>,
    ) -> Option<LoadThresholdConfig> {
        let mut result = None;
        for entry in self.worker_sets.iter() {
            if let Some(monitor) = &entry.value().worker_monitor {
                if let Some(cfg) = config {
                    monitor.set_load_threshold_config(cfg);
                }
                if result.is_none() {
                    result = Some(monitor.load_threshold_config());
                }
            }
        }
        result
    }

    /// Returns a clone of the worker monitor belonging to the worker set
    /// registered under `namespace`.
    ///
    /// `None` when the namespace is unknown or its worker set has no monitor.
    pub fn get_worker_monitor_for_namespace(&self, namespace: &str) -> Option<KvWorkerMonitor> {
        self.worker_sets
            .get(namespace)
            .and_then(|entry| entry.value().worker_monitor.clone())
    }

    /// Returns the sum of `worker_count()` over all registered worker sets.
    pub fn total_workers(&self) -> usize {
        self.worker_sets
            .iter()
            .map(|entry| entry.value().worker_count())
            .sum()
    }

    /// Returns `true` if `pred` holds for at least one registered worker set.
    fn any_worker_set<P>(&self, pred: P) -> bool
    where
        P: Fn(&WorkerSet) -> bool,
    {
        self.worker_sets.iter().any(|entry| pred(entry.value()))
    }

    /// Runs `select_worker_set_with(extract)` and maps "nothing eligible" to
    /// `ModelManagerError::ModelNotFound` carrying this model's name.
    fn select_engine<T, F>(&self, extract: F) -> Result<T, ModelManagerError>
    where
        F: Fn(&WorkerSet) -> Option<T>,
    {
        self.select_worker_set_with(extract)
            .ok_or_else(|| ModelManagerError::ModelNotFound(self.name.clone()))
    }

    /// Picks one worker set and returns what `extract` yields for it.
    ///
    /// A worker set is eligible when its `worker_count()` is non-zero and
    /// `extract` returns `Some` for it. With several eligible sets, one is
    /// drawn at random with probability proportional to its worker count.
    /// Returns `None` when no worker set is eligible.
    fn select_worker_set_with<T, F>(&self, extract: F) -> Option<T>
    where
        F: Fn(&WorkerSet) -> Option<T>,
    {
        // Fast path: with a single registered set there is nothing to weigh,
        // so skip the allocation and the RNG.
        if self.worker_sets.len() == 1 {
            return self.worker_sets.iter().next().and_then(|entry| {
                let ws = entry.value();
                if ws.worker_count() == 0 {
                    None
                } else {
                    extract(ws)
                }
            });
        }

        // Extracted value paired with its weight (the set's worker count).
        let eligible: Vec<(T, usize)> = self
            .worker_sets
            .iter()
            .filter_map(|entry| {
                let ws = entry.value();
                let count = ws.worker_count();
                if count == 0 {
                    return None;
                }
                extract(ws).map(|val| (val, count))
            })
            .collect();

        if eligible.is_empty() {
            return None;
        }
        if eligible.len() == 1 {
            return eligible.into_iter().next().map(|(val, _)| val);
        }

        // Every weight is >= 1, so the range below is never empty.
        let total_weight: usize = eligible.iter().map(|(_, w)| *w).sum();
        let mut pick = rand::rng().random_range(0..total_weight);
        for (val, weight) in eligible {
            if pick < weight {
                return Some(val);
            }
            pick -= weight;
        }
        // `pick < total_weight` guarantees the loop returned; kept as a
        // non-panicking fallback.
        None
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model_card::ModelDeploymentCard;
    use tokio::sync::watch;

    /// Builds a worker set with a default deployment card, no engines and no
    /// instance watcher.
    fn make_worker_set(namespace: &str, mdcsum: &str) -> Arc<WorkerSet> {
        Arc::new(WorkerSet::new(
            namespace.to_string(),
            mdcsum.to_string(),
            ModelDeploymentCard::default(),
        ))
    }

    /// Builds a worker set whose instance watcher initially reports
    /// `worker_ids`. The returned sender lets a test change the reported ids;
    /// it must be kept alive for as long as the watcher is read.
    fn make_worker_set_with_count(
        namespace: &str,
        mdcsum: &str,
        worker_ids: Vec<u64>,
    ) -> (Arc<WorkerSet>, watch::Sender<Vec<u64>>) {
        let (tx, rx) = watch::channel(worker_ids);
        let mut ws = WorkerSet::new(
            namespace.to_string(),
            mdcsum.to_string(),
            ModelDeploymentCard::default(),
        );
        ws.set_instance_watcher(rx);
        (Arc::new(ws), tx)
    }

    #[test]
    fn test_model_new() {
        let model = Model::new("llama".to_string());
        assert_eq!(model.name(), "llama");
        assert!(model.is_empty());
        assert_eq!(model.worker_set_count(), 0);
    }

    #[test]
    fn test_add_remove_worker_set() {
        let model = Model::new("llama".to_string());
        let ws = make_worker_set("ns1", "abc");
        model.add_worker_set("ns1".to_string(), ws);
        assert!(!model.is_empty());
        assert_eq!(model.worker_set_count(), 1);
        assert!(model.has_worker_set("ns1"));
        assert!(!model.has_worker_set("ns2"));

        let removed = model.remove_worker_set("ns1");
        assert!(removed.is_some());
        assert!(model.is_empty());

        let removed_again = model.remove_worker_set("ns1");
        assert!(removed_again.is_none());
    }

    #[test]
    fn test_get_worker_set() {
        let model = Model::new("llama".to_string());
        let ws = make_worker_set("ns1", "abc");
        model.add_worker_set("ns1".to_string(), ws);

        let retrieved = model.get_worker_set("ns1");
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().namespace(), "ns1");
        assert!(model.get_worker_set("ns2").is_none());
    }

    #[test]
    fn test_multiple_worker_sets_same_checksum() {
        let model = Model::new("llama".to_string());
        model.add_worker_set("ns1".to_string(), make_worker_set("ns1", "abc"));
        model.add_worker_set("ns2".to_string(), make_worker_set("ns2", "abc"));
        assert_eq!(model.worker_set_count(), 2);
        assert!(model.has_worker_set("ns1"));
        assert!(model.has_worker_set("ns2"));

        model.remove_worker_set("ns1");
        assert_eq!(model.worker_set_count(), 1);
        assert!(!model.has_worker_set("ns1"));
        assert!(model.has_worker_set("ns2"));
    }

    #[test]
    fn test_multiple_worker_sets_different_checksums() {
        let model = Model::new("llama".to_string());
        model.add_worker_set("ns1".to_string(), make_worker_set("ns1", "abc"));
        model.add_worker_set("ns2".to_string(), make_worker_set("ns2", "def"));
        assert_eq!(model.worker_set_count(), 2);
        assert!(model.has_worker_set("ns1"));
        assert!(model.has_worker_set("ns2"));
    }

    #[test]
    fn test_is_checksum_compatible_no_existing_worker_set() {
        let model = Model::new("llama".to_string());
        assert!(model.is_checksum_compatible("ns1", "abc"));
        assert!(model.is_checksum_compatible("ns1", "xyz"));
    }

    #[test]
    fn test_is_checksum_compatible_matching_checksum() {
        let model = Model::new("llama".to_string());
        model.add_worker_set("ns1".to_string(), make_worker_set("ns1", "abc"));
        assert!(model.is_checksum_compatible("ns1", "abc"));
    }

    #[test]
    fn test_is_checksum_compatible_mismatched_checksum() {
        let model = Model::new("llama".to_string());
        model.add_worker_set("ns1".to_string(), make_worker_set("ns1", "abc"));
        assert!(!model.is_checksum_compatible("ns1", "def"));
    }

    #[test]
    fn test_is_checksum_compatible_different_ws_key() {
        let model = Model::new("llama".to_string());
        model.add_worker_set("ns1".to_string(), make_worker_set("ns1", "abc"));
        // Only the set stored under the queried key is compared.
        assert!(model.is_checksum_compatible("ns2", "def"));
        assert!(model.is_checksum_compatible("ns2", "abc"));
    }

    #[test]
    fn test_no_engines_means_prefill() {
        let model = Model::new("llama".to_string());
        model.add_worker_set("ns1".to_string(), make_worker_set("ns1", "abc"));
        assert!(model.has_prefill());
        assert!(!model.has_decode_engine());
        assert!(!model.has_chat_engine());
        assert!(!model.has_completions_engine());
        assert!(!model.has_embeddings_engine());
        assert!(!model.has_tensor_engine());
        assert!(!model.has_images_engine());
        assert!(!model.has_videos_engine());
    }

    #[test]
    fn test_get_engine_returns_error_without_engines() {
        let model = Model::new("llama".to_string());
        model.add_worker_set("ns1".to_string(), make_worker_set("ns1", "abc"));
        assert!(model.get_chat_engine().is_err());
        assert!(model.get_completions_engine().is_err());
        assert!(model.get_embeddings_engine().is_err());
        assert!(model.get_images_engine().is_err());
        assert!(model.get_videos_engine().is_err());
        assert!(model.get_tensor_engine().is_err());
    }

    #[test]
    fn test_select_worker_set_with_extracts_namespace() {
        let model = Model::new("llama".to_string());
        // Zero, one and two engine-less worker sets all yield an error.
        assert!(model.get_chat_engine().is_err());
        model.add_worker_set("ns1".to_string(), make_worker_set("ns1", "abc"));
        assert!(model.get_chat_engine().is_err());
        model.add_worker_set("ns2".to_string(), make_worker_set("ns2", "abc"));
        assert!(model.get_chat_engine().is_err());
    }

    #[test]
    fn test_total_workers_no_watcher() {
        let model = Model::new("llama".to_string());
        assert_eq!(model.total_workers(), 0);
        model.add_worker_set("ns1".to_string(), make_worker_set("ns1", "abc"));
        assert_eq!(model.total_workers(), 1);
        model.add_worker_set("ns2".to_string(), make_worker_set("ns2", "abc"));
        assert_eq!(model.total_workers(), 2);
    }

    #[test]
    fn test_total_workers_with_watcher() {
        let model = Model::new("llama".to_string());
        let (ws1, _tx1) = make_worker_set_with_count("ns1", "abc", vec![1, 2, 3]);
        let (ws2, _tx2) = make_worker_set_with_count("ns2", "abc", vec![10, 20]);
        model.add_worker_set("ns1".to_string(), ws1);
        model.add_worker_set("ns2".to_string(), ws2);
        assert_eq!(model.total_workers(), 5);
    }

    #[test]
    fn test_total_workers_updates_dynamically() {
        let model = Model::new("llama".to_string());
        let (ws1, tx1) = make_worker_set_with_count("ns1", "abc", vec![1, 2]);
        model.add_worker_set("ns1".to_string(), ws1);
        assert_eq!(model.total_workers(), 2);
        tx1.send(vec![1]).unwrap();
        assert_eq!(model.total_workers(), 1);
        tx1.send(vec![]).unwrap();
        assert_eq!(model.total_workers(), 0);
    }

    #[test]
    fn test_zero_worker_single_set_filtered() {
        let model = Model::new("llama".to_string());
        let (ws, _tx) = make_worker_set_with_count("ns1", "abc", vec![]);
        model.add_worker_set("ns1".to_string(), ws);
        assert!(model.get_chat_engine().is_err());
        assert!(model.get_completions_engine().is_err());
    }

    #[test]
    fn test_zero_worker_multi_set_filtered() {
        let model = Model::new("llama".to_string());
        let (ws1, _tx1) = make_worker_set_with_count("ns1", "abc", vec![]);
        let (ws2, _tx2) = make_worker_set_with_count("ns2", "abc", vec![]);
        model.add_worker_set("ns1".to_string(), ws1);
        model.add_worker_set("ns2".to_string(), ws2);
        assert!(model.get_chat_engine().is_err());
    }

    #[test]
    fn test_is_displayable_requires_live_workers() {
        let model = Model::new("llama".to_string());
        // No worker sets at all.
        assert!(!model.is_displayable());

        // A worker set with zero workers never makes the model displayable.
        let (ws, tx) = make_worker_set_with_count("ns1", "abc", vec![]);
        model.add_worker_set("ns1".to_string(), ws);
        assert!(!model.is_displayable());

        // An engine-less set is a prefill set (see
        // `test_no_engines_means_prefill`); once it has a worker and no
        // serving engine exists anywhere, the model becomes displayable.
        tx.send(vec![1]).unwrap();
        assert!(model.is_displayable());
    }
}