use std::sync::Arc;
use tokio::sync::watch;
use crate::{
discovery::KvWorkerMonitor,
kv_router::KvRouter,
model_card::ModelDeploymentCard,
types::{
generic::tensor::TensorStreamingEngine,
openai::{
chat_completions::OpenAIChatCompletionsStreamingEngine,
completions::OpenAICompletionsStreamingEngine,
embeddings::OpenAIEmbeddingsStreamingEngine, images::OpenAIImagesStreamingEngine,
videos::OpenAIVideosStreamingEngine,
},
},
};
/// Engines and routing state for the workers registered under one
/// `namespace` / `mdcsum` pair.
///
/// Every optional slot is `None` when built via `WorkerSet::new`; the engine,
/// router and monitor slots are `pub(crate)` so crate code can fill them in
/// after construction.
///
/// NOTE(review): field declaration order is also drop order — keep it stable
/// unless the engines/router/monitor are known to have order-independent drops.
pub struct WorkerSet {
/// Namespace these workers belong to; returned by `namespace()`.
namespace: String,
/// Identifier returned by `mdcsum()`; presumably a checksum of the model
/// deployment card — NOTE(review): confirm against the caller that computes it.
mdcsum: String,
/// Model deployment card; its `runtime_config` parser settings feed
/// `parsing_options()`.
card: ModelDeploymentCard,
/// OpenAI chat-completions engine, if this set serves chat.
pub(crate) chat_engine: Option<OpenAIChatCompletionsStreamingEngine>,
/// OpenAI completions engine, if this set serves completions.
pub(crate) completions_engine: Option<OpenAICompletionsStreamingEngine>,
/// OpenAI embeddings engine, if this set serves embeddings.
pub(crate) embeddings_engine: Option<OpenAIEmbeddingsStreamingEngine>,
/// OpenAI images engine, if this set serves images.
pub(crate) images_engine: Option<OpenAIImagesStreamingEngine>,
/// OpenAI videos engine, if this set serves videos.
pub(crate) videos_engine: Option<OpenAIVideosStreamingEngine>,
/// Generic tensor streaming engine, if this set serves tensor requests.
pub(crate) tensor_engine: Option<TensorStreamingEngine>,
/// Shared KV router for this set, when one has been attached.
pub(crate) kv_router: Option<Arc<KvRouter>>,
/// KV worker monitor for this set, when one has been attached.
pub(crate) worker_monitor: Option<KvWorkerMonitor>,
/// Watch channel whose current `Vec<u64>` length is reported by
/// `worker_count()`; when `None`, `worker_count()` reports 1.
/// NOTE(review): the `u64` entries look like instance ids — confirm.
instance_count_rx: Option<watch::Receiver<Vec<u64>>>,
}
impl WorkerSet {
    /// Creates a worker set for `namespace` / `mdcsum` backed by `card`.
    ///
    /// All engine slots, the KV router, the worker monitor and the instance
    /// watcher start out as `None`.
    pub fn new(namespace: String, mdcsum: String, card: ModelDeploymentCard) -> Self {
        Self {
            namespace,
            mdcsum,
            card,
            chat_engine: None,
            completions_engine: None,
            embeddings_engine: None,
            images_engine: None,
            videos_engine: None,
            tensor_engine: None,
            kv_router: None,
            worker_monitor: None,
            instance_count_rx: None,
        }
    }

    /// Namespace this set was created with.
    pub fn namespace(&self) -> &str {
        self.namespace.as_str()
    }

    /// The `mdcsum` string this set was created with.
    pub fn mdcsum(&self) -> &str {
        self.mdcsum.as_str()
    }

    /// Model deployment card backing this set.
    pub fn card(&self) -> &ModelDeploymentCard {
        &self.card
    }

    /// `true` when a chat-completions engine is attached.
    pub fn has_chat_engine(&self) -> bool {
        self.chat_engine.is_some()
    }

    /// `true` when a completions engine is attached.
    pub fn has_completions_engine(&self) -> bool {
        self.completions_engine.is_some()
    }

    /// `true` when an embeddings engine is attached.
    pub fn has_embeddings_engine(&self) -> bool {
        self.embeddings_engine.is_some()
    }

    /// `true` when an images engine is attached.
    pub fn has_images_engine(&self) -> bool {
        self.images_engine.is_some()
    }

    /// `true` when a videos engine is attached.
    pub fn has_videos_engine(&self) -> bool {
        self.videos_engine.is_some()
    }

    /// `true` when a tensor engine is attached.
    pub fn has_tensor_engine(&self) -> bool {
        self.tensor_engine.is_some()
    }

    /// `true` when either a chat or a completions engine is attached.
    pub fn has_decode_engine(&self) -> bool {
        self.chat_engine.is_some() || self.completions_engine.is_some()
    }

    /// `true` when no decode, embeddings, images, videos or tensor engine is
    /// attached — so a freshly constructed, engine-less set counts as prefill.
    pub fn is_prefill_set(&self) -> bool {
        let serves_anything = self.has_decode_engine()
            || self.has_embeddings_engine()
            || self.has_images_engine()
            || self.has_videos_engine()
            || self.has_tensor_engine();
        !serves_anything
    }

    /// Builds OpenAI parsing options from the card's runtime-config tool-call
    /// and reasoning parser settings (both cloned).
    pub fn parsing_options(&self) -> crate::protocols::openai::ParsingOptions {
        let runtime = &self.card.runtime_config;
        let tool_parser = runtime.tool_call_parser.clone();
        let reasoning_parser = runtime.reasoning_parser.clone();
        crate::protocols::openai::ParsingOptions::new(tool_parser, reasoning_parser)
    }

    /// Number of workers in the set.
    ///
    /// With a watcher installed this is the length of the watched vector's
    /// current value (possibly 0); without one it falls back to 1.
    pub fn worker_count(&self) -> usize {
        self.instance_count_rx
            .as_ref()
            .map_or(1, |rx| rx.borrow().len())
    }

    /// Installs `rx` as the source for `worker_count`, dropping any watcher
    /// that was installed before.
    pub fn set_instance_watcher(&mut self, rx: watch::Receiver<Vec<u64>>) {
        self.instance_count_rx = Some(rx);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model_card::ModelDeploymentCard;

    /// Builds an engine-less `WorkerSet` around a default card.
    fn make_worker_set(namespace: &str, mdcsum: &str) -> WorkerSet {
        WorkerSet::new(
            namespace.to_string(),
            mdcsum.to_string(),
            ModelDeploymentCard::default(),
        )
    }

    #[test]
    fn test_worker_set_basics() {
        let ws = make_worker_set("ns1", "abc123");
        assert_eq!(ws.namespace(), "ns1");
        assert_eq!(ws.mdcsum(), "abc123");
    }

    #[test]
    fn test_no_engines_by_default() {
        let ws = make_worker_set("ns1", "abc123");
        assert!(!ws.has_chat_engine());
        assert!(!ws.has_completions_engine());
        assert!(!ws.has_embeddings_engine());
        assert!(!ws.has_images_engine());
        // `is_prefill_set` also depends on the videos engine, so cover it too.
        assert!(!ws.has_videos_engine());
        assert!(!ws.has_tensor_engine());
        assert!(!ws.has_decode_engine());
        assert!(ws.is_prefill_set());
    }

    #[test]
    fn test_worker_count_without_watcher() {
        let ws = make_worker_set("ns1", "abc");
        assert_eq!(ws.worker_count(), 1);
    }

    #[test]
    fn test_worker_count_with_watcher() {
        let mut ws = make_worker_set("ns1", "abc");
        let (tx, rx) = watch::channel(vec![1, 2, 3]);
        ws.set_instance_watcher(rx);
        assert_eq!(ws.worker_count(), 3);
        tx.send(vec![1]).unwrap();
        assert_eq!(ws.worker_count(), 1);
        tx.send(vec![]).unwrap();
        assert_eq!(ws.worker_count(), 0);
    }

    #[test]
    fn test_worker_count_with_empty_watcher() {
        let mut ws = make_worker_set("ns1", "abc");
        let (_tx, rx) = watch::channel::<Vec<u64>>(vec![]);
        ws.set_instance_watcher(rx);
        assert_eq!(ws.worker_count(), 0);
    }

    #[test]
    fn test_worker_count_updates_on_join() {
        let mut ws = make_worker_set("ns1", "abc");
        let (tx, rx) = watch::channel::<Vec<u64>>(vec![]);
        ws.set_instance_watcher(rx);
        assert_eq!(ws.worker_count(), 0);
        tx.send(vec![100]).unwrap();
        assert_eq!(ws.worker_count(), 1);
        tx.send(vec![100, 200]).unwrap();
        assert_eq!(ws.worker_count(), 2);
        tx.send(vec![100, 200, 300]).unwrap();
        assert_eq!(ws.worker_count(), 3);
    }

    #[test]
    fn test_set_instance_watcher_replaces_previous() {
        let mut ws = make_worker_set("ns1", "abc");
        let (_tx1, rx1) = watch::channel::<Vec<u64>>(vec![1, 2, 3]);
        ws.set_instance_watcher(rx1);
        assert_eq!(ws.worker_count(), 3);
        // A second watcher must fully replace the first one.
        let (_tx2, rx2) = watch::channel::<Vec<u64>>(vec![7]);
        ws.set_instance_watcher(rx2);
        assert_eq!(ws.worker_count(), 1);
    }

    #[test]
    fn test_worker_count_after_sender_dropped() {
        let mut ws = make_worker_set("ns1", "abc");
        let (tx, rx) = watch::channel::<Vec<u64>>(vec![1, 2]);
        ws.set_instance_watcher(rx);
        drop(tx);
        // The receiver keeps the last published value once the sender is gone.
        assert_eq!(ws.worker_count(), 2);
    }
}