use std::collections::HashMap;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use crate::protocols::tensor;
/// Optional bootstrap host/port pair attached to a runtime config.
///
/// Both fields default to `None` when missing from the input and are omitted
/// from serialized output when `None`.
/// NOTE(review): presumably this is the rendezvous address used for
/// disaggregated (prefill/decode split) serving — confirm with callers.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
pub struct DisaggregatedEndpoint {
/// Bootstrap host name or address; skipped when serializing if `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub bootstrap_host: Option<String>,
/// Bootstrap port; skipped when serializing if `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub bootstrap_port: Option<u16>,
}
/// Runtime configuration reported for a model, serializable with serde.
///
/// Missing fields fall back to the serde defaults declared on each field
/// (see `default_data_parallel_start_rank`, `default_data_parallel_size`,
/// `default_local_indexer`). Engine-specific extras live in `runtime_data`
/// and are accessed through `set_engine_specific` / `get_engine_specific`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ModelRuntimeConfig {
// NOTE: the next five `Option` fields have no `skip_serializing_if`, so a
// `None` value is serialized as `null` rather than omitted (unlike the
// `Option` fields further down).
/// Total number of KV-cache blocks, if reported.
pub total_kv_blocks: Option<u64>,
/// Maximum number of sequences, if reported.
pub max_num_seqs: Option<u64>,
/// Maximum number of batched tokens, if reported.
pub max_num_batched_tokens: Option<u64>,
/// Name of the tool-call parser, if any.
pub tool_call_parser: Option<String>,
/// Name of the reasoning parser, if any.
pub reasoning_parser: Option<String>,
/// First data-parallel rank; defaults to `0` when absent from the input.
#[serde(default = "default_data_parallel_start_rank")]
pub data_parallel_start_rank: u32,
/// Data-parallel size; defaults to `1` when absent from the input.
#[serde(default = "default_data_parallel_size")]
pub data_parallel_size: u32,
/// Whether the local indexer is enabled; defaults to `true` when absent.
#[serde(default = "default_local_indexer")]
pub enable_local_indexer: bool,
/// Free-form engine-specific values keyed by name; empty when absent and
/// omitted from serialized output when empty.
#[serde(default, skip_serializing_if = "HashMap::is_empty")]
pub runtime_data: HashMap<String, serde_json::Value>,
/// Optional tensor model configuration (type defined in `protocols::tensor`);
/// omitted from serialized output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tensor_model_config: Option<tensor::TensorModelConfig>,
/// Optional disaggregated bootstrap endpoint; omitted when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub disaggregated_endpoint: Option<DisaggregatedEndpoint>,
}
/// Serde fallback for `ModelRuntimeConfig::data_parallel_start_rank`.
///
/// Returns rank zero, i.e. data-parallel numbering starts at the first rank.
const fn default_data_parallel_start_rank() -> u32 {
    u32::MIN
}

/// Serde fallback for `ModelRuntimeConfig::data_parallel_size`.
///
/// Returns `1`: a single data-parallel replica.
const fn default_data_parallel_size() -> u32 {
    1
}

/// Serde fallback for `ModelRuntimeConfig::enable_local_indexer`.
///
/// Returns `true`, so the local indexer is on unless explicitly disabled.
const fn default_local_indexer() -> bool {
    true
}
impl Default for ModelRuntimeConfig {
fn default() -> Self {
Self {
total_kv_blocks: None,
max_num_seqs: None,
max_num_batched_tokens: None,
tool_call_parser: None,
reasoning_parser: None,
data_parallel_start_rank: default_data_parallel_start_rank(),
data_parallel_size: default_data_parallel_size(),
enable_local_indexer: true,
runtime_data: HashMap::new(),
tensor_model_config: None,
disaggregated_endpoint: None,
}
}
}
impl ModelRuntimeConfig {
    /// Creates a config equal to `ModelRuntimeConfig::default()`.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Serializes `value` to JSON and stores it in `runtime_data` under `key`,
    /// replacing any previous entry for that key.
    ///
    /// # Errors
    ///
    /// Returns an error if `value` cannot be converted to a `serde_json::Value`.
    /// In that case `runtime_data` is left unchanged.
    pub fn set_engine_specific<T: Serialize>(&mut self, key: &str, value: T) -> anyhow::Result<()> {
        self.runtime_data
            .insert(key.to_string(), serde_json::to_value(value)?);
        Ok(())
    }

    /// Looks up `key` in `runtime_data` and deserializes the stored value as `T`.
    ///
    /// Returns `Ok(None)` when the key is absent.
    ///
    /// # Errors
    ///
    /// Returns an error if a value is present but cannot be deserialized as `T`.
    pub fn get_engine_specific<T: DeserializeOwned>(&self, key: &str) -> anyhow::Result<Option<T>> {
        match self.runtime_data.get(key) {
            // `&serde_json::Value` is itself a `Deserializer`, so deserialize straight
            // from the borrowed value instead of deep-cloning it for `from_value`.
            Some(value) => Ok(Some(T::deserialize(value)?)),
            None => Ok(None),
        }
    }
}