use std::sync::Arc;
use crate::{
discovery::{ModelManager, ModelWatcher},
engines::StreamingEngineAdapter,
entrypoint::{EngineConfig, RouterConfig, input::common},
grpc::service::kserve,
http::service::metrics::Metrics,
namespace::NamespaceFilter,
types::openai::{
chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
},
};
use dynamo_runtime::DistributedRuntime;
/// Launch the KServe gRPC entrypoint (plus its companion HTTP service) for the
/// given engine configuration, blocking until both services exit.
///
/// Depending on the `EngineConfig` variant, models are either discovered
/// dynamically via a background watcher (`Dynamic`) or registered directly on
/// the service's model manager (`InProcessText` / `InProcessTokens`).
///
/// # Errors
/// Propagates failures from building the service, constructing pipelines,
/// registering models, starting the watcher, or running either server.
pub async fn run(
    distributed_runtime: DistributedRuntime,
    engine_config: EngineConfig,
) -> anyhow::Result<()> {
    // Configure the service from the local model settings.
    // NOTE(review): `.port()` is fed `http_port()` — presumably the Kserve
    // service multiplexes over that port; confirm against the builder docs.
    let local_model = engine_config.local_model();
    let mut builder = kserve::KserveService::builder()
        .port(local_model.http_port())
        .http_cancel_token(Some(distributed_runtime.primary_token()))
        .with_request_template(local_model.request_template());
    if let Some(metrics_port) = local_model.http_metrics_port() {
        builder = builder.http_metrics_port(metrics_port);
    }

    // Every config variant builds the service first, so build it once here.
    let grpc_service = builder.build()?;

    match engine_config {
        // Dynamic: models arrive at runtime; start a watcher that registers
        // them on this service's model manager as they are discovered.
        EngineConfig::Dynamic { ref model, .. } => {
            let namespace_filter = NamespaceFilter::from_namespace_and_prefix(
                model.namespace(),
                model.namespace_prefix(),
            );
            run_watcher(
                distributed_runtime.clone(),
                grpc_service.state().manager_clone(),
                model.router_config().clone(),
                model.migration_limit(),
                namespace_filter,
            )
            .await?;
        }
        // In-process engine speaking text: adapt it to the streaming interface
        // and register it for both completions and chat completions.
        EngineConfig::InProcessText { engine, model, .. } => {
            let adapter = Arc::new(StreamingEngineAdapter::new(engine));
            let manager = grpc_service.model_manager();
            let checksum = model.card().mdcsum();
            manager.add_completions_model(model.service_name(), checksum, adapter.clone())?;
            manager.add_chat_completions_model(model.service_name(), checksum, adapter)?;
        }
        // In-process engine speaking tokens: wrap it in tokenizing pipelines
        // (one per OpenAI-style API) before registering.
        EngineConfig::InProcessTokens {
            engine: inner_engine,
            model,
            ..
        } => {
            let manager = grpc_service.model_manager();
            let checksum = model.card().mdcsum();
            let tokenizer = model.card().tokenizer()?;

            let chat = common::build_pipeline::<
                NvCreateChatCompletionRequest,
                NvCreateChatCompletionStreamResponse,
            >(model.card(), inner_engine.clone(), tokenizer.clone())
            .await?;
            manager.add_chat_completions_model(model.service_name(), checksum, chat)?;

            let completions = common::build_pipeline::<
                NvCreateCompletionRequest,
                NvCreateCompletionResponse,
            >(model.card(), inner_engine, tokenizer)
            .await?;
            manager.add_completions_model(model.service_name(), checksum, completions)?;
        }
    }

    // Run the gRPC and HTTP services concurrently; both stop on the runtime's
    // primary cancellation token, and an error from either aborts the pair.
    let http_service = grpc_service.http_service().clone();
    let cancel = distributed_runtime.primary_token();
    tokio::try_join!(grpc_service.run(cancel.clone()), http_service.run(cancel))?;

    distributed_runtime.shutdown();
    Ok(())
}
/// Subscribe to model discovery events and spawn a detached background task
/// that keeps `model_manager` in sync with the models visible through the
/// runtime's discovery mechanism.
///
/// Returns once the subscription is established; the watcher itself runs
/// until the runtime's primary token is cancelled.
///
/// # Errors
/// Fails if the discovery `list_and_watch` subscription cannot be created.
async fn run_watcher(
    runtime: DistributedRuntime,
    model_manager: Arc<ModelManager>,
    router_config: RouterConfig,
    migration_limit: u32,
    namespace_filter: NamespaceFilter,
) -> anyhow::Result<()> {
    let watcher = ModelWatcher::new(
        runtime.clone(),
        model_manager,
        router_config,
        migration_limit,
        None,
        Arc::new(Metrics::new()),
    );
    tracing::debug!("Waiting for remote model");

    // Watch all models; the stream ends when the primary token is cancelled.
    let discovery_stream = runtime
        .discovery()
        .list_and_watch(
            dynamo_runtime::discovery::DiscoveryQuery::AllModels,
            Some(runtime.primary_token()),
        )
        .await?;

    // Deliberately detached: the watcher lives for the runtime's lifetime,
    // so the JoinHandle is intentionally not awaited.
    let _detached = tokio::spawn(async move {
        watcher.watch(discovery_stream, namespace_filter).await;
    });
    Ok(())
}