/// Protobuf/gRPC types for the `smg.grpc.common` package, included via
/// `tonic::include_proto!`.
pub mod common_proto {
// Lints are silenced for everything in this module; its contents come from the
// `include_proto!` expansion rather than hand-written code (presumably because
// the included code is not lint-clean — confirm before narrowing the allow list).
#![allow(clippy::all, clippy::absolute_paths, unused_qualifications)]
tonic::include_proto!("smg.grpc.common");
}
pub mod abort_on_drop;
pub mod channel;
pub mod mlx_engine;
pub mod sglang_scheduler;
pub mod tokenizer_bundle;
pub mod tokenspeed_scheduler;
pub mod trtllm_service;
pub mod vllm_engine;
use std::sync::Arc;
pub use abort_on_drop::{AbortOnDropClient, AbortOnDropStream};
pub use channel::{connect_channel, normalize_grpc_endpoint};
pub use mlx_engine::{proto as mlx_proto, MlxEngineClient};
pub use sglang_scheduler::{proto as sglang_proto, SglangSchedulerClient};
pub use tokenspeed_scheduler::{tokenspeed_proto, TokenSpeedSchedulerClient};
use tonic::metadata::MetadataMap;
pub use trtllm_service::{proto as trtllm_proto, TrtllmServiceClient};
pub use vllm_engine::{proto as vllm_proto, VllmEngineClient};
/// Emits an inherent `get_tokenizer` method.
///
/// Intended to be expanded inside an `impl` block whose `Self` has a cloneable
/// `client` field exposing a `get_tokenizer` RPC.
macro_rules! impl_get_tokenizer {
    () => {
        /// Calls the `GetTokenizer` RPC and hands the pending call to
        /// `tokenizer_bundle::collect_bundle_from_rpc`, which produces the
        /// resulting `StreamBundle`.
        ///
        /// # Errors
        ///
        /// Returns whatever error `collect_bundle_from_rpc` reports.
        pub async fn get_tokenizer(
            &self,
        ) -> Result<
            $crate::tokenizer_bundle::StreamBundle,
            Box<dyn std::error::Error + Send + Sync>,
        > {
            let empty_request =
                tonic::Request::new($crate::common_proto::GetTokenizerRequest {});
            // Work on a clone so the method can take `&self`.
            let mut rpc_client = self.client.clone();
            // NOTE(review): the 120 s duration is presumably an overall time
            // limit — confirm against `collect_bundle_from_rpc`.
            let two_minutes = std::time::Duration::from_secs(120);

            $crate::tokenizer_bundle::collect_bundle_from_rpc(
                rpc_client.get_tokenizer(empty_request),
                // Each chunk is reduced to its payload and its `sha256` field.
                |chunk| (chunk.data, chunk.sha256),
                two_minutes,
            )
            .await
        }
    };
}
pub(crate) use impl_get_tokenizer;
/// Slack added on top of the caller-supplied `timeout_s` when computing the
/// client-side deadline for the `FlushCache` RPC.
pub const FLUSH_RPC_DEADLINE_MARGIN: std::time::Duration = std::time::Duration::from_secs(45);

/// Client-side deadline applied to the `StartProfile` and `StopProfile` RPCs.
pub const PROFILE_RPC_DEADLINE: std::time::Duration = std::time::Duration::from_secs(630);

/// Computes the client-side deadline for a `FlushCache` call: the requested
/// timeout plus [`FLUSH_RPC_DEADLINE_MARGIN`].
///
/// Never panics. Negative and NaN timeouts are treated as zero. Timeouts that
/// are infinite or too large for a [`std::time::Duration`] saturate at
/// [`std::time::Duration::MAX`] instead of panicking the way
/// `Duration::from_secs_f32` and `Duration + Duration` would.
pub(crate) fn flush_rpc_deadline(timeout_s: f32) -> std::time::Duration {
    // `f32::max` returns the non-NaN operand, so NaN collapses to 0.0 here.
    let requested = std::time::Duration::try_from_secs_f32(timeout_s.max(0.0))
        .unwrap_or(std::time::Duration::MAX);
    requested.saturating_add(FLUSH_RPC_DEADLINE_MARGIN)
}

/// Emits the inherent admin methods `flush_cache`, `start_profile` and
/// `stop_profile`.
///
/// Intended to be expanded inside an `impl` block whose `Self` has a cloneable
/// `client` field exposing those RPCs and a `trace_injector` field with an
/// `inject(&mut MetadataMap)` method.
macro_rules! impl_admin_ops {
    () => {
        /// Sends a `FlushCache` request carrying `timeout_s`.
        ///
        /// The call is bounded client-side by `timeout_s` plus
        /// `FLUSH_RPC_DEADLINE_MARGIN`. A trace-injection failure is logged
        /// and otherwise ignored.
        ///
        /// # Errors
        ///
        /// Returns `deadline_exceeded` when the client-side deadline elapses,
        /// or the `tonic::Status` produced by the RPC itself.
        pub async fn flush_cache(
            &self,
            timeout_s: f32,
        ) -> Result<$crate::common_proto::FlushCacheResponse, tonic::Status> {
            tracing::debug!("Requesting cache flush (timeout_s={timeout_s})");
            let mut request =
                tonic::Request::new($crate::common_proto::FlushCacheRequest { timeout_s });
            if let Err(e) = self.trace_injector.inject(request.metadata_mut()) {
                tracing::warn!("Failed to inject trace context: {}", e);
            }
            // Panic-free even for infinite or absurdly large `timeout_s`.
            let deadline = $crate::flush_rpc_deadline(timeout_s);
            let mut client = self.client.clone();
            // First `?` unwraps the timeout result, second the RPC result.
            let response = tokio::time::timeout(deadline, client.flush_cache(request))
                .await
                .map_err(|_| {
                    tonic::Status::deadline_exceeded(format!(
                        "FlushCache did not complete within {deadline:?}"
                    ))
                })??;
            Ok(response.into_inner())
        }

        /// Sends `req` as a `StartProfile` request, bounded client-side by
        /// `PROFILE_RPC_DEADLINE`. A trace-injection failure is logged and
        /// otherwise ignored.
        ///
        /// # Errors
        ///
        /// Returns `deadline_exceeded` when the deadline elapses, or the
        /// `tonic::Status` produced by the RPC itself.
        pub async fn start_profile(
            &self,
            req: $crate::common_proto::StartProfileRequest,
        ) -> Result<$crate::common_proto::ProfileResponse, tonic::Status> {
            tracing::debug!("Requesting profile start");
            let mut request = tonic::Request::new(req);
            if let Err(e) = self.trace_injector.inject(request.metadata_mut()) {
                tracing::warn!("Failed to inject trace context: {}", e);
            }
            let mut client = self.client.clone();
            let response =
                tokio::time::timeout($crate::PROFILE_RPC_DEADLINE, client.start_profile(request))
                    .await
                    .map_err(|_| {
                        tonic::Status::deadline_exceeded(format!(
                            "StartProfile did not complete within {:?}",
                            $crate::PROFILE_RPC_DEADLINE
                        ))
                    })??;
            Ok(response.into_inner())
        }

        /// Sends a `StopProfile` request, bounded client-side by
        /// `PROFILE_RPC_DEADLINE`. A trace-injection failure is logged and
        /// otherwise ignored.
        ///
        /// # Errors
        ///
        /// Returns `deadline_exceeded` when the deadline elapses, or the
        /// `tonic::Status` produced by the RPC itself.
        pub async fn stop_profile(
            &self,
        ) -> Result<$crate::common_proto::ProfileResponse, tonic::Status> {
            tracing::debug!("Requesting profile stop");
            let mut request = tonic::Request::new($crate::common_proto::StopProfileRequest {});
            if let Err(e) = self.trace_injector.inject(request.metadata_mut()) {
                tracing::warn!("Failed to inject trace context: {}", e);
            }
            let mut client = self.client.clone();
            let response =
                tokio::time::timeout($crate::PROFILE_RPC_DEADLINE, client.stop_profile(request))
                    .await
                    .map_err(|_| {
                        tonic::Status::deadline_exceeded(format!(
                            "StopProfile did not complete within {:?}",
                            $crate::PROFILE_RPC_DEADLINE
                        ))
                    })??;
            Ok(response.into_inner())
        }
    };
}
pub(crate) use impl_admin_ops;
/// Emits an inherent `subscribe_kv_events` method.
///
/// Intended to be expanded inside an `impl` block whose `Self` has a cloneable
/// `client` field exposing a `subscribe_kv_events` RPC.
macro_rules! impl_subscribe_kv_events {
    () => {
        /// Issues a `SubscribeKvEvents` request carrying
        /// `start_sequence_number` and returns the resulting stream of
        /// `KvEventBatch` messages.
        ///
        /// # Errors
        ///
        /// Propagates the `tonic::Status` returned when the RPC fails.
        pub async fn subscribe_kv_events(
            &self,
            start_sequence_number: u64,
        ) -> Result<tonic::Streaming<$crate::common_proto::KvEventBatch>, tonic::Status> {
            let subscription = $crate::common_proto::SubscribeKvEventsRequest {
                start_sequence_number,
            };
            // A throwaway clone of the client lets this method take `&self`.
            let batches = self
                .client
                .clone()
                .subscribe_kv_events(tonic::Request::new(subscription))
                .await?
                .into_inner();
            Ok(batches)
        }
    };
}
pub(crate) use impl_subscribe_kv_events;
/// Hook for writing trace context into the metadata of an outgoing gRPC request.
///
/// Implementations must be `Send + Sync`.
pub trait TraceInjector: Send + Sync {
/// Gives the implementation mutable access to the request `metadata`.
///
/// # Errors
///
/// Returns a boxed error when injection fails.
fn inject(
&self,
metadata: &mut MetadataMap,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
}
/// A [`TraceInjector`] that leaves the metadata untouched and always succeeds.
#[derive(Clone, Default)]
pub struct NoopTraceInjector;
impl TraceInjector for NoopTraceInjector {
/// Does nothing and returns `Ok(())`.
fn inject(
&self,
_metadata: &mut MetadataMap,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
Ok(())
}
}
/// Shared, type-erased [`TraceInjector`] handle. Despite the name, this is an
/// `Arc`, not a `Box`.
pub type BoxedTraceInjector = Arc<dyn TraceInjector>;
/// Emits the constructors and basic RPC wrappers shared by the engine clients.
///
/// * `$proto_client` — path of the client type built from the connected
///   channel via `new`.
/// * `$display_name` — literal used in the connection log line.
///
/// Intended to be expanded inside an `impl` block whose `Self` consists of the
/// fields `client` and `trace_injector`, with a `proto` module in scope at the
/// expansion site providing the health/model-info/server-info message types.
macro_rules! impl_engine_client_basics {
    ($proto_client:path, $display_name:literal) => {
        /// Connects to `endpoint` using a [`NoopTraceInjector`]($crate::NoopTraceInjector).
        ///
        /// # Errors
        ///
        /// Returns the error from `connect_with_trace_injector`.
        pub async fn connect(
            endpoint: &str,
        ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
            let noop: $crate::BoxedTraceInjector =
                std::sync::Arc::new($crate::NoopTraceInjector);
            Self::connect_with_trace_injector(endpoint, noop).await
        }

        /// Connects to `endpoint` and stores `trace_injector` on the new client.
        ///
        /// # Errors
        ///
        /// Returns the error from `channel::connect_channel` when the channel
        /// cannot be established.
        pub async fn connect_with_trace_injector(
            endpoint: &str,
            trace_injector: $crate::BoxedTraceInjector,
        ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
            tracing::debug!(
                "Connecting to {} gRPC server at {}",
                $display_name,
                endpoint
            );
            let client = <$proto_client>::new($crate::channel::connect_channel(endpoint).await?);
            Ok(Self {
                client,
                trace_injector,
            })
        }

        /// Replaces the stored trace injector and returns the client.
        #[must_use]
        pub fn with_trace_injector(self, trace_injector: $crate::BoxedTraceInjector) -> Self {
            let mut updated = self;
            updated.trace_injector = trace_injector;
            updated
        }

        /// Calls the `HealthCheck` RPC with an empty request.
        ///
        /// # Errors
        ///
        /// Propagates the `tonic::Status` returned when the RPC fails.
        pub async fn health_check(&self) -> Result<proto::HealthCheckResponse, tonic::Status> {
            tracing::debug!("Sending health check request");
            let mut rpc = self.client.clone();
            let reply = rpc
                .health_check(tonic::Request::new(proto::HealthCheckRequest {}))
                .await?;
            tracing::debug!("Health check response received");
            Ok(reply.into_inner())
        }

        /// Calls the `GetModelInfo` RPC with an empty request.
        ///
        /// # Errors
        ///
        /// Propagates the `tonic::Status` returned when the RPC fails.
        pub async fn get_model_info(&self) -> Result<proto::GetModelInfoResponse, tonic::Status> {
            tracing::debug!("Requesting model info");
            let mut rpc = self.client.clone();
            let reply = rpc
                .get_model_info(tonic::Request::new(proto::GetModelInfoRequest {}))
                .await?;
            tracing::debug!("Model info response received");
            Ok(reply.into_inner())
        }

        /// Calls the `GetServerInfo` RPC with an empty request.
        ///
        /// # Errors
        ///
        /// Propagates the `tonic::Status` returned when the RPC fails.
        pub async fn get_server_info(&self) -> Result<proto::GetServerInfoResponse, tonic::Status> {
            tracing::debug!("Requesting server info");
            let mut rpc = self.client.clone();
            let reply = rpc
                .get_server_info(tonic::Request::new(proto::GetServerInfoRequest {}))
                .await?;
            tracing::debug!("Server info response received");
            Ok(reply.into_inner())
        }
    };
}
pub(crate) use impl_engine_client_basics;