use std::sync::Arc;
use crate::{
discovery::{ModelManager, ModelUpdate, ModelWatcher},
endpoint_type::EndpointType,
engines::StreamingEngineAdapter,
entrypoint::{ChatEngineFactoryCallback, EngineConfig, RouterConfig, input::common},
http::service::service_v2::{self, HttpService},
namespace::NamespaceFilter,
types::openai::{
chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
},
};
use dynamo_runtime::DistributedRuntime;
use dynamo_runtime::metrics::MetricsHierarchy;
pub async fn run(
distributed_runtime: DistributedRuntime,
engine_config: EngineConfig,
) -> anyhow::Result<()> {
let local_model = engine_config.local_model();
let mut http_service_builder = match (local_model.tls_cert_path(), local_model.tls_key_path()) {
(Some(tls_cert_path), Some(tls_key_path)) => {
if !tls_cert_path.exists() {
anyhow::bail!("TLS certificate not found: {}", tls_cert_path.display());
}
if !tls_key_path.exists() {
anyhow::bail!("TLS key not found: {}", tls_key_path.display());
}
service_v2::HttpService::builder()
.enable_tls(true)
.tls_cert_path(Some(tls_cert_path.to_path_buf()))
.tls_key_path(Some(tls_key_path.to_path_buf()))
.port(local_model.http_port())
}
(None, None) => service_v2::HttpService::builder().port(local_model.http_port()),
(_, _) => {
anyhow::bail!(
"Both --tls-cert-path and --tls-key-path must be provided together to enable TLS"
);
}
};
if let Some(http_host) = local_model.http_host() {
http_service_builder = http_service_builder.host(http_host);
}
http_service_builder =
http_service_builder.cancel_token(Some(distributed_runtime.primary_token()));
http_service_builder =
http_service_builder.with_request_template(engine_config.local_model().request_template());
http_service_builder =
http_service_builder.drt_metrics(Some(distributed_runtime.get_metrics_registry().clone()));
http_service_builder =
http_service_builder.drt_discovery(Some(distributed_runtime.discovery()));
let http_service = match engine_config {
EngineConfig::Dynamic {
ref model,
ref chat_engine_factory,
} => {
http_service_builder =
http_service_builder.discovery(Some(distributed_runtime.discovery()));
let http_service = http_service_builder.build()?;
let router_config = model.router_config();
let migration_limit = model.migration_limit();
let namespace_filter = NamespaceFilter::from_namespace_and_prefix(
model.namespace(),
model.namespace_prefix(),
);
run_watcher(
distributed_runtime.clone(),
http_service.state().manager_clone(),
router_config.clone(),
migration_limit,
namespace_filter,
Arc::new(http_service.clone()),
http_service.state().metrics_clone(),
chat_engine_factory.clone(),
)
.await?;
http_service
}
EngineConfig::InProcessText { engine, model, .. } => {
let http_service = http_service_builder.build()?;
let engine = Arc::new(StreamingEngineAdapter::new(engine));
let manager = http_service.model_manager();
let checksum = model.card().mdcsum();
manager.add_completions_model(model.display_name(), checksum, engine.clone())?;
manager.add_chat_completions_model(model.display_name(), checksum, engine)?;
for endpoint_type in EndpointType::all() {
http_service.enable_model_endpoint(endpoint_type, true);
}
http_service
}
EngineConfig::InProcessTokens {
engine: inner_engine,
model,
..
} => {
let http_service = http_service_builder.build()?;
let manager = http_service.model_manager();
let checksum = model.card().mdcsum();
let tokenizer = model.card().tokenizer()?;
let chat_pipeline = common::build_pipeline::<
NvCreateChatCompletionRequest,
NvCreateChatCompletionStreamResponse,
>(model.card(), inner_engine.clone(), tokenizer.clone())
.await?;
manager.add_chat_completions_model(model.display_name(), checksum, chat_pipeline)?;
let cmpl_pipeline = common::build_pipeline::<
NvCreateCompletionRequest,
NvCreateCompletionResponse,
>(model.card(), inner_engine, tokenizer)
.await?;
manager.add_completions_model(model.display_name(), checksum, cmpl_pipeline)?;
for endpoint_type in EndpointType::all() {
http_service.enable_model_endpoint(endpoint_type, true);
}
http_service
}
};
tracing::debug!(
"Supported routes: {:?}",
http_service
.route_docs()
.iter()
.map(|rd| rd.to_string())
.collect::<Vec<String>>()
);
http_service
.run(distributed_runtime.primary_token())
.await?;
distributed_runtime.shutdown(); Ok(())
}
/// Starts the background tasks that keep the HTTP service in sync with the
/// models reported by discovery.
///
/// A `ModelWatcher` is created and subscribed to the discovery stream for all
/// models. Two tasks are then spawned:
///
/// 1. an "endpoint enabler" that receives every `ModelUpdate` the watcher
///    emits and forwards it to `update_http_endpoints` and
///    `update_model_metrics`;
/// 2. the watcher itself, which consumes the discovery stream filtered by
///    `namespace_filter`.
///
/// Neither task is awaited: their join handles are dropped when this function
/// returns, so it returns as soon as the tasks have been spawned.
///
/// # Errors
///
/// Returns an error if subscribing to the discovery stream fails.
// Every argument is a distinct collaborator handed straight to the watcher or
// to the spawned tasks, so the argument-count lint is silenced here.
#[allow(clippy::too_many_arguments)]
async fn run_watcher(
    runtime: DistributedRuntime,
    model_manager: Arc<ModelManager>,
    router_config: RouterConfig,
    migration_limit: u32,
    namespace_filter: NamespaceFilter,
    http_service: Arc<HttpService>,
    metrics: Arc<crate::http::service::metrics::Metrics>,
    chat_engine_factory: Option<ChatEngineFactoryCallback>,
) -> anyhow::Result<()> {
    let mut watcher = ModelWatcher::new(
        runtime.clone(),
        model_manager,
        router_config,
        migration_limit,
        chat_engine_factory,
        metrics.clone(),
    );

    tracing::debug!("Waiting for remote model");

    let model_events = runtime
        .discovery()
        .list_and_watch(
            dynamo_runtime::discovery::DiscoveryQuery::AllModels,
            Some(runtime.primary_token()),
        )
        .await?;

    // The watcher publishes model updates on this bounded channel (capacity 32).
    let (update_tx, mut update_rx) = tokio::sync::mpsc::channel(32);
    watcher.set_notify_on_model_update(update_tx);

    // Runs until the channel yields `None`.
    let _enabler_handle = tokio::spawn(async move {
        while let Some(update) = update_rx.recv().await {
            update_http_endpoints(http_service.clone(), update.clone());
            update_model_metrics(update, metrics.clone());
        }
    });

    let _watcher_handle = tokio::spawn(async move {
        watcher.watch(model_events, namespace_filter).await;
    });

    Ok(())
}
/// Enables or disables the HTTP endpoints that correspond to a model update.
///
/// For `ModelUpdate::Added` every endpoint type derived from the card's model
/// type is enabled; for `ModelUpdate::Removed` the same endpoint types are
/// disabled.
// NOTE(review): removal disables the endpoint types unconditionally. Presumably
// the watcher only reports `Removed` once no remaining model needs them -
// confirm against `ModelWatcher`.
fn update_http_endpoints(service: Arc<HttpService>, model_type: ModelUpdate) {
    tracing::debug!(
        "Updating HTTP service endpoints for model type: {:?}",
        model_type
    );

    // Both variants carry a card; only the target state differs.
    let (card, enabled) = match model_type {
        ModelUpdate::Added(card) => (card, true),
        ModelUpdate::Removed(card) => (card, false),
    };

    for endpoint_type in card.model_type.as_endpoint_types() {
        service.enable_model_endpoint(endpoint_type, enabled);
    }
}
/// Updates the HTTP service metrics in response to a model update.
///
/// When a model is added, the metrics are refreshed from its card via
/// `update_metrics_from_mdc`; a failure there is logged as a warning and
/// otherwise ignored. When a model is removed, the event is only logged.
fn update_model_metrics(
    model_type: ModelUpdate,
    metrics: Arc<crate::http::service::metrics::Metrics>,
) {
    // Removal needs no metrics work, so handle it first and bail out.
    let card = match model_type {
        ModelUpdate::Removed(card) => {
            tracing::debug!(model_name = card.display_name, "Model removed");
            return;
        }
        ModelUpdate::Added(card) => card,
    };

    tracing::debug!("Updating metrics for added model: {}", card.display_name);

    // Best effort: a metrics failure must not stop the update loop.
    if let Err(err) = metrics.update_metrics_from_mdc(&card) {
        tracing::warn!(%err, model_name=card.display_name, "update_metrics_from_mdc failed");
    }
}