Skip to main content

smg_grpc_client/
lib.rs

1//! gRPC clients for vLLM, TensorRT-LLM, MLX, TokenSpeed, and SGLang backends.
2//!
3//! This crate provides gRPC client implementations for communicating with
4//! the vLLM engine, TensorRT-LLM engine, MLX engine, TokenSpeed scheduler,
5//! and SGLang scheduler backends.
6
/// Protobuf types from the `smg.grpc.common` package, shared by every engine
/// client (e.g. the tokenizer, admin-op, and KV-event request/response
/// messages used by the wrapper macros in this crate).
pub mod common_proto {
    // The included code is generated at build time by `tonic`/`prost`, so
    // lints that fire on it are allowed here rather than fixed.
    #![allow(clippy::all, clippy::absolute_paths, unused_qualifications)]
    tonic::include_proto!("smg.grpc.common");
}
11pub mod abort_on_drop;
12pub mod channel;
13pub mod mlx_engine;
14pub mod sglang_scheduler;
15pub mod tokenizer_bundle;
16pub mod tokenspeed_scheduler;
17pub mod trtllm_service;
18pub mod vllm_engine;
19
20// Re-export clients
21use std::sync::Arc;
22
23pub use abort_on_drop::{AbortOnDropClient, AbortOnDropStream};
24pub use channel::{connect_channel, normalize_grpc_endpoint};
25pub use mlx_engine::{proto as mlx_proto, MlxEngineClient};
26pub use sglang_scheduler::{proto as sglang_proto, SglangSchedulerClient};
27pub use tokenspeed_scheduler::{tokenspeed_proto, TokenSpeedSchedulerClient};
28use tonic::metadata::MetadataMap;
29pub use trtllm_service::{proto as trtllm_proto, TrtllmServiceClient};
30pub use vllm_engine::{proto as vllm_proto, VllmEngineClient};
31
/// Timeout passed to `collect_bundle_from_rpc` by the `get_tokenizer()`
/// wrappers that `impl_get_tokenizer!` generates.
pub const TOKENIZER_RPC_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(120);

/// Shared `get_tokenizer()` implementation for all engine clients.
///
/// Each engine's generated proto client has a `get_tokenizer` RPC method
/// with identical signature (using common proto types). This macro provides
/// the wrapper that calls `collect_bundle_from_rpc` with
/// `TOKENIZER_RPC_TIMEOUT` and extracts `(data, sha256)` from each chunk.
macro_rules! impl_get_tokenizer {
    () => {
        /// Fetch the tokenizer bundle from the backend.
        ///
        /// Issues the `GetTokenizer` RPC and hands it to
        /// `collect_bundle_from_rpc`, which receives each chunk's `data` and
        /// `sha256` fields together with `TOKENIZER_RPC_TIMEOUT`.
        ///
        /// # Errors
        ///
        /// Returns whatever boxed error `collect_bundle_from_rpc` reports for
        /// the RPC or the collected stream.
        pub async fn get_tokenizer(
            &self,
        ) -> Result<
            $crate::tokenizer_bundle::StreamBundle,
            Box<dyn std::error::Error + Send + Sync>,
        > {
            use $crate::common_proto::GetTokenizerRequest;
            let request = tonic::Request::new(GetTokenizerRequest {});
            let mut client = self.client.clone();
            $crate::tokenizer_bundle::collect_bundle_from_rpc(
                client.get_tokenizer(request),
                |chunk| (chunk.data, chunk.sha256),
                $crate::TOKENIZER_RPC_TIMEOUT,
            )
            .await
        }
    };
}
pub(crate) use impl_get_tokenizer;
59
/// Slack added to the local `flush_cache` deadline on top of the
/// `timeout_s` that is forwarded to the backend. The servicer caps its own
/// scheduler round-trip at `max(30, timeout_s + 10)` seconds, so this margin
/// has to absorb that budget plus transport overhead.
pub const FLUSH_RPC_DEADLINE_MARGIN: std::time::Duration = std::time::Duration::new(45, 0);

/// Local deadline (10.5 minutes) for the profile start/stop RPCs. Stopping a
/// profile may take a long time while the backend serializes large traces.
pub const PROFILE_RPC_DEADLINE: std::time::Duration = std::time::Duration::from_secs(10 * 60 + 30);
69
/// Compute the local deadline for an RPC that forwards a wait budget of
/// `timeout_s` seconds to the backend: that budget plus `margin`.
///
/// Zero, negative, and NaN timeouts contribute nothing, leaving just
/// `margin`. A budget too large to represent as a `Duration` (including
/// infinity) saturates to `Duration::MAX`. `Duration::from_secs_f32` would
/// panic on such input, so this keeps an extreme `timeout_s` from crashing
/// the caller.
pub(crate) fn forwarded_timeout_deadline(
    timeout_s: f32,
    margin: std::time::Duration,
) -> std::time::Duration {
    // `>` is false for NaN and for both signed zeros, so all of them map to +0.0.
    let budget_s = if timeout_s > 0.0 { timeout_s } else { 0.0 };
    // With a non-negative input, `try_from_secs_f32` only fails on overflow or
    // +inf; treat that as an effectively unbounded local deadline.
    std::time::Duration::try_from_secs_f32(budget_s)
        .unwrap_or(std::time::Duration::MAX)
        .saturating_add(margin)
}

/// Shared admin-op implementations (`flush_cache`, `start_profile`,
/// `stop_profile`) for engine clients whose protos expose the common
/// admin RPCs (request/response messages live in `common.proto`).
///
/// Every call enforces a local deadline so an unresponsive backend cannot
/// hang the gateway, and injects trace context for distributed tracing.
/// Trace-injection failures are logged and the request is still sent.
macro_rules! impl_admin_ops {
    () => {
        /// Flush the KV cache on the backend scheduler.
        ///
        /// `timeout_s` is forwarded to the backend: 0 = flush immediately
        /// (fails if requests are in flight), >0 = wait up to that many
        /// seconds for the scheduler to go idle first.
        ///
        /// The local deadline is `timeout_s` plus `FLUSH_RPC_DEADLINE_MARGIN`.
        /// Non-positive or NaN values count as 0, and unrepresentably large
        /// values saturate instead of panicking.
        pub async fn flush_cache(
            &self,
            timeout_s: f32,
        ) -> Result<$crate::common_proto::FlushCacheResponse, tonic::Status> {
            tracing::debug!("Requesting cache flush (timeout_s={timeout_s})");
            let mut request =
                tonic::Request::new($crate::common_proto::FlushCacheRequest { timeout_s });
            if let Err(e) = self.trace_injector.inject(request.metadata_mut()) {
                tracing::warn!("Failed to inject trace context: {}", e);
            }
            let deadline = $crate::forwarded_timeout_deadline(
                timeout_s,
                $crate::FLUSH_RPC_DEADLINE_MARGIN,
            );
            let mut client = self.client.clone();
            let response = tokio::time::timeout(deadline, client.flush_cache(request))
                .await
                .map_err(|_| {
                    tonic::Status::deadline_exceeded(format!(
                        "FlushCache did not complete within {deadline:?}"
                    ))
                })??;
            Ok(response.into_inner())
        }

        /// Start the profiler on the backend scheduler.
        ///
        /// Bounded locally by `PROFILE_RPC_DEADLINE`.
        pub async fn start_profile(
            &self,
            req: $crate::common_proto::StartProfileRequest,
        ) -> Result<$crate::common_proto::ProfileResponse, tonic::Status> {
            tracing::debug!("Requesting profile start");
            let mut request = tonic::Request::new(req);
            if let Err(e) = self.trace_injector.inject(request.metadata_mut()) {
                tracing::warn!("Failed to inject trace context: {}", e);
            }
            let mut client = self.client.clone();
            let response =
                tokio::time::timeout($crate::PROFILE_RPC_DEADLINE, client.start_profile(request))
                    .await
                    .map_err(|_| {
                        tonic::Status::deadline_exceeded(format!(
                            "StartProfile did not complete within {:?}",
                            $crate::PROFILE_RPC_DEADLINE
                        ))
                    })??;
            Ok(response.into_inner())
        }

        /// Stop the profiler on the backend scheduler and export traces.
        ///
        /// Bounded locally by `PROFILE_RPC_DEADLINE`.
        pub async fn stop_profile(
            &self,
        ) -> Result<$crate::common_proto::ProfileResponse, tonic::Status> {
            tracing::debug!("Requesting profile stop");
            let mut request = tonic::Request::new($crate::common_proto::StopProfileRequest {});
            if let Err(e) = self.trace_injector.inject(request.metadata_mut()) {
                tracing::warn!("Failed to inject trace context: {}", e);
            }
            let mut client = self.client.clone();
            let response =
                tokio::time::timeout($crate::PROFILE_RPC_DEADLINE, client.stop_profile(request))
                    .await
                    .map_err(|_| {
                        tonic::Status::deadline_exceeded(format!(
                            "StopProfile did not complete within {:?}",
                            $crate::PROFILE_RPC_DEADLINE
                        ))
                    })??;
            Ok(response.into_inner())
        }
    };
}
pub(crate) use impl_admin_ops;
153
/// Generates `subscribe_kv_events()` for an engine client wrapper.
///
/// Every engine's generated tonic client exposes a `subscribe_kv_events`
/// RPC whose request and response messages come from `common_proto`, so one
/// wrapper body fits all of them. The expansion clones `self.client` for the
/// call.
macro_rules! impl_subscribe_kv_events {
    () => {
        /// Open a KV cache event subscription on the backend.
        ///
        /// `start_sequence_number` is forwarded to the backend unchanged.
        /// The returned server-streaming response is long-lived and yields
        /// `KvEventBatch` messages.
        pub async fn subscribe_kv_events(
            &self,
            start_sequence_number: u64,
        ) -> Result<tonic::Streaming<$crate::common_proto::KvEventBatch>, tonic::Status> {
            let mut grpc = self.client.clone();
            grpc.subscribe_kv_events(tonic::Request::new(
                $crate::common_proto::SubscribeKvEventsRequest {
                    start_sequence_number,
                },
            ))
            .await
            .map(tonic::Response::into_inner)
        }
    };
}
pub(crate) use impl_subscribe_kv_events;
177
178/// Trait for injecting trace context into gRPC metadata.
179///
180/// Implement this trait to enable distributed tracing across gRPC calls.
181/// The default implementation is a no-op.
182pub trait TraceInjector: Send + Sync {
183    /// Inject trace context into the given metadata map.
184    ///
185    /// Returns `Ok(())` on success, or an error if injection fails.
186    fn inject(
187        &self,
188        metadata: &mut MetadataMap,
189    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
190}
191
192/// A no-op trace injector that does nothing.
193#[derive(Clone, Default)]
194pub struct NoopTraceInjector;
195
196impl TraceInjector for NoopTraceInjector {
197    fn inject(
198        &self,
199        _metadata: &mut MetadataMap,
200    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
201        Ok(())
202    }
203}
204
205/// Type alias for a boxed trace injector.
206pub type BoxedTraceInjector = Arc<dyn TraceInjector>;
207
/// Generates the constructors and standard RPCs shared by every engine client.
///
/// Expands to `connect`, `connect_with_trace_injector`, `with_trace_injector`,
/// and the uniform `health_check`, `get_model_info`, and `get_server_info`
/// wrappers.
///
/// * `$proto_client` is the fully-qualified path of the generated tonic
///   client type stored in `Self`'s `client` field.
/// * `$display_name` is the human-readable backend name used in the connect
///   log line.
///
/// The expansion builds `Self { client, trace_injector }`, so `Self` must
/// have exactly those two fields. The standard RPC wrappers name their
/// message types as `proto::...`, so a `proto` module must be in scope where
/// the macro is invoked. Engine-specific RPCs (`generate`, `embed`, etc.) are
/// added separately in the same `impl` block.
macro_rules! impl_engine_client_basics {
    ($proto_client:path, $display_name:literal) => {
        /// Connect to the backend at `endpoint` without trace propagation.
        pub async fn connect(
            endpoint: &str,
        ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
            let noop: $crate::BoxedTraceInjector = std::sync::Arc::new($crate::NoopTraceInjector);
            Self::connect_with_trace_injector(endpoint, noop).await
        }

        /// Connect to the backend at `endpoint`, propagating trace context
        /// through `trace_injector`.
        pub async fn connect_with_trace_injector(
            endpoint: &str,
            trace_injector: $crate::BoxedTraceInjector,
        ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
            tracing::debug!(
                "Connecting to {} gRPC server at {}",
                $display_name,
                endpoint
            );
            let channel = $crate::channel::connect_channel(endpoint).await?;
            Ok(Self {
                client: <$proto_client>::new(channel),
                trace_injector,
            })
        }

        /// Return this client with its trace injector replaced.
        #[must_use]
        pub fn with_trace_injector(mut self, trace_injector: $crate::BoxedTraceInjector) -> Self {
            self.trace_injector = trace_injector;
            self
        }

        /// Issue the `HealthCheck` RPC and return the backend's reply.
        pub async fn health_check(&self) -> Result<proto::HealthCheckResponse, tonic::Status> {
            tracing::debug!("Sending health check request");
            let mut grpc = self.client.clone();
            let reply = grpc
                .health_check(tonic::Request::new(proto::HealthCheckRequest {}))
                .await?
                .into_inner();
            tracing::debug!("Health check response received");
            Ok(reply)
        }

        /// Issue the `GetModelInfo` RPC and return the backend's reply.
        pub async fn get_model_info(&self) -> Result<proto::GetModelInfoResponse, tonic::Status> {
            tracing::debug!("Requesting model info");
            let mut grpc = self.client.clone();
            let reply = grpc
                .get_model_info(tonic::Request::new(proto::GetModelInfoRequest {}))
                .await?
                .into_inner();
            tracing::debug!("Model info response received");
            Ok(reply)
        }

        /// Issue the `GetServerInfo` RPC and return the backend's reply.
        pub async fn get_server_info(&self) -> Result<proto::GetServerInfoResponse, tonic::Status> {
            tracing::debug!("Requesting server info");
            let mut grpc = self.client.clone();
            let reply = grpc
                .get_server_info(tonic::Request::new(proto::GetServerInfoRequest {}))
                .await?
                .into_inner();
            tracing::debug!("Server info response received");
            Ok(reply)
        }
    };
}
pub(crate) use impl_engine_client_basics;