dynamo_runtime/pipeline/network/egress/
addressed_router.rs1use std::collections::BTreeMap;
5use std::sync::Arc;
6use std::time::Instant;
7
8use super::unified_client::RequestPlaneClient;
9use super::*;
10use crate::component::Instance;
11use crate::discovery::EndpointInstanceId;
12use crate::dynamo_nvtx_range;
13use crate::engine::{AsyncEngine, AsyncEngineContextProvider, Data};
14use crate::error::{DynamoError, ErrorType};
15use crate::logging::inject_trace_headers_into_map;
16use crate::metrics::frontend_perf::STAGE_DURATION_SECONDS;
17use crate::metrics::request_plane::{
18 REQUEST_PLANE_INFLIGHT, REQUEST_PLANE_QUEUE_SECONDS, REQUEST_PLANE_ROUNDTRIP_TTFT_SECONDS,
19 REQUEST_PLANE_SEND_SECONDS,
20};
21use crate::pipeline::network::ConnectionInfo;
22use crate::pipeline::network::NetworkStreamWrapper;
23use crate::pipeline::network::PendingConnections;
24use crate::pipeline::network::StreamOptions;
25use crate::pipeline::network::TwoPartCodec;
26use crate::pipeline::network::codec::TwoPartMessage;
27use crate::pipeline::network::tcp;
28use crate::pipeline::{ManyOut, PipelineError, ResponseStream, SingleIn};
29use crate::protocols::maybe_error::MaybeError;
30use crate::traits::DistributedRuntimeProvider;
31
32use anyhow::{Error, Result};
33use futures::stream::Stream;
34use serde::Deserialize;
35use serde::Serialize;
36use std::pin::Pin;
37use std::task::{Context, Poll};
38use tokio_stream::{StreamExt, StreamNotifyClose, wrappers::ReceiverStream};
39use tracing::Instrument;
40
41const CONTROL_MESSAGE_MAX_BYTES: usize = 128 * 1024;
42
43fn serialize_control_message(control_message: &RequestControlMessage) -> Result<Vec<u8>, Error> {
44 let ctrl = serde_json::to_vec(control_message)?;
45 if ctrl.len() > CONTROL_MESSAGE_MAX_BYTES {
46 return Err(PipelineError::Generic(format!(
47 "request control message too large: {} bytes exceeds limit {}",
48 ctrl.len(),
49 CONTROL_MESSAGE_MAX_BYTES
50 ))
51 .into());
52 }
53 Ok(ctrl)
54}
55
56struct InflightGuard {
60 armed: bool,
61}
62
63impl InflightGuard {
64 fn new() -> Self {
65 Self { armed: true }
66 }
67
68 fn disarm(mut self) {
71 self.armed = false;
72 }
73}
74
75impl Drop for InflightGuard {
76 fn drop(&mut self) {
77 if self.armed {
78 REQUEST_PLANE_INFLIGHT.dec();
79 }
80 }
81}
82
83struct InflightDecStream<S> {
85 inner: S,
86}
87
88impl<S, T> Stream for InflightDecStream<S>
89where
90 S: Stream<Item = T> + Unpin,
91{
92 type Item = T;
93
94 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
95 Pin::new(&mut self.inner).poll_next(cx)
96 }
97}
98
99impl<S> Drop for InflightDecStream<S> {
100 fn drop(&mut self) {
101 REQUEST_PLANE_INFLIGHT.dec();
102 }
103}
104
105pub struct AddressedRequest<T> {
106 request: T,
107 address: String,
108 instance: Option<Instance>,
111}
112
113impl<T> AddressedRequest<T> {
114 pub fn new(request: T, address: String) -> Self {
115 Self {
116 request,
117 address,
118 instance: None,
119 }
120 }
121
122 pub fn with_instance(request: T, address: String, instance: Instance) -> Self {
123 Self {
124 request,
125 address,
126 instance: Some(instance),
127 }
128 }
129
130 pub fn for_instance(request: T, instance: Instance) -> Self {
131 let address = instance.transport.address().to_string();
132 Self::with_instance(request, address, instance)
133 }
134
135 pub(crate) fn into_parts(self) -> (T, String, Option<Instance>) {
136 (self.request, self.address, self.instance)
137 }
138}
139
140pub struct AddressedPushRouter {
141 req_client: Arc<dyn RequestPlaneClient>,
143
144 resp_transport: Arc<tcp::server::TcpStreamServer>,
146}
147
148impl AddressedPushRouter {
149 pub fn new(
154 req_client: Arc<dyn RequestPlaneClient>,
155 resp_transport: Arc<tcp::server::TcpStreamServer>,
156 ) -> Result<Arc<Self>> {
157 Ok(Arc::new(Self {
158 req_client,
159 resp_transport,
160 }))
161 }
162
163 pub async fn from_runtime_provider(
164 provider: &impl DistributedRuntimeProvider,
165 ) -> Result<Arc<Self>> {
166 let manager = provider.drt().network_manager();
167 let req_client = manager.create_client()?;
168 let resp_transport = provider.drt().tcp_server().await?;
169
170 tracing::debug!(
171 transport = req_client.transport_name(),
172 "Creating AddressedPushRouter with request plane client"
173 );
174
175 Self::new(req_client, resp_transport)
176 }
177
178 pub async fn cancel_instance_streams(&self, instance_id: &EndpointInstanceId) -> usize {
180 self.resp_transport
181 .cancel_instance_streams(instance_id)
182 .await
183 }
184
185 pub async fn clear_instance_tombstone(&self, instance_id: &EndpointInstanceId) {
187 self.resp_transport
188 .clear_instance_tombstone(instance_id)
189 .await
190 }
191}
192
193#[async_trait::async_trait]
194impl<T, U> AsyncEngine<SingleIn<AddressedRequest<T>>, ManyOut<U>, Error> for AddressedPushRouter
195where
196 T: Data + Serialize,
197 U: Data + for<'de> Deserialize<'de> + MaybeError,
198{
199 async fn generate(&self, request: SingleIn<AddressedRequest<T>>) -> Result<ManyOut<U>, Error> {
200 let queue_start = Instant::now();
201 REQUEST_PLANE_INFLIGHT.inc();
202 let inflight_guard = InflightGuard::new();
203
204 let request_id = request.context().id().to_string();
205 let (addressed_request, context) = request.transfer(());
206 let (request, address, instance_info) = addressed_request.into_parts();
207 let engine_ctx = context.context();
208 let engine_ctx_ = engine_ctx.clone();
209
210 let options = StreamOptions::builder()
212 .context(engine_ctx.clone())
213 .enable_request_stream(false)
214 .enable_response_stream(true)
215 .build()
216 .unwrap();
217
218 let pending_connections: PendingConnections = self.resp_transport.register(options).await;
221
222 let pending_response_stream = match pending_connections.into_parts() {
224 (None, Some(recv_stream)) => recv_stream,
225 _ => {
226 panic!("Invalid data plane registration for a SingleIn/ManyOut transport");
227 }
228 };
229
230 let (connection_info, response_stream_provider) = pending_response_stream.into_parts();
232
233 let recv_subject: Option<String> =
235 serde_json::from_str::<tcp::TcpStreamConnectionInfo>(&connection_info.info)
236 .ok()
237 .map(|ci| ci.subject);
238
239 if let (Some(subject), Some(inst)) = (&recv_subject, &instance_info) {
242 let endpoint_instance_id = inst.endpoint_instance_id();
243 if !self
244 .resp_transport
245 .associate_instance(subject, &endpoint_instance_id)
246 .await
247 {
248 return Err(anyhow::anyhow!(
249 DynamoError::builder()
250 .error_type(ErrorType::Disconnected)
251 .message(
252 "Worker removed before request could be sent (tombstoned instance)"
253 )
254 .build()
255 ));
256 }
257 }
258
259 let control_message = RequestControlMessage {
264 id: engine_ctx.id().to_string(),
265 request_type: RequestType::SingleIn,
266 response_type: ResponseType::ManyOut,
267 connection_info,
268 metadata: context.metadata().clone(),
269 frontend_send_ts_ns: None,
270 };
271
272 let ctrl = match serialize_control_message(&control_message) {
276 Ok(v) => v,
277 Err(e) => {
278 if let Some(subject) = &recv_subject {
279 self.resp_transport.cancel_recv_stream(subject).await;
280 }
281 return Err(e);
282 }
283 };
284 let data = match serde_json::to_vec(&request) {
285 Ok(v) => v,
286 Err(e) => {
287 if let Some(subject) = &recv_subject {
288 self.resp_transport.cancel_recv_stream(subject).await;
289 }
290 return Err(e.into());
291 }
292 };
293
294 tracing::trace!(
295 request_id,
296 "packaging two-part message; ctrl: {} bytes, data: {} bytes",
297 ctrl.len(),
298 data.len()
299 );
300
301 let msg = TwoPartMessage::from_parts(ctrl.into(), data.into());
302
303 let codec = TwoPartCodec::default();
307 let buffer = match codec.encode_message(msg) {
308 Ok(v) => v,
309 Err(e) => {
310 if let Some(subject) = &recv_subject {
311 self.resp_transport.cancel_recv_stream(subject).await;
312 }
313 return Err(e.into());
314 }
315 };
316
317 REQUEST_PLANE_QUEUE_SECONDS.observe(queue_start.elapsed().as_secs_f64());
318 let tx_start = Instant::now();
319
320 tracing::trace!(
324 request_id,
325 transport = self.req_client.transport_name(),
326 address = %address,
327 "Sending request via request plane client"
328 );
329
330 let mut headers = std::collections::HashMap::new();
332 inject_trace_headers_into_map(&mut headers);
333 headers.insert("request-id".to_string(), request_id.clone());
334
335 let send_ts_ns = std::time::SystemTime::now()
338 .duration_since(std::time::UNIX_EPOCH)
339 .unwrap_or_default()
340 .as_nanos() as u64;
341 headers.insert("x-frontend-send-ts-ns".to_string(), send_ts_ns.to_string());
342
343 let _nvtx_send = dynamo_nvtx_range!("transport.tcp.send");
345 let send_result = self.req_client.send_request(address, buffer, headers).await;
346 drop(_nvtx_send);
347
348 if let Err(e) = send_result {
349 if let Some(subject) = &recv_subject {
350 self.resp_transport.cancel_recv_stream(subject).await;
351 }
352 return Err(e);
353 }
354 REQUEST_PLANE_SEND_SECONDS.observe(tx_start.elapsed().as_secs_f64());
355
356 let _nvtx_wait = dynamo_nvtx_range!("transport.tcp.wait_backend");
357 tracing::trace!(request_id, "awaiting transport handshake");
358
359 let response_stream = match response_stream_provider.await {
362 Ok(Ok(stream)) => stream,
363 Ok(Err(e)) => {
364 if let Some(subject) = &recv_subject {
371 self.resp_transport.cancel_recv_stream(subject).await;
372 }
373 return Err(anyhow::anyhow!(
374 DynamoError::builder()
375 .error_type(ErrorType::CannotConnect)
376 .message(format!(
377 "Worker generate() failed before response stream: {e}"
378 ))
379 .build()
380 ));
381 }
382 Err(_recv_err) => {
383 if let Some(subject) = &recv_subject {
386 self.resp_transport.cancel_recv_stream(subject).await;
387 }
388 return Err(anyhow::anyhow!(
389 DynamoError::builder()
390 .error_type(ErrorType::Disconnected)
391 .message("Worker disconnected before response stream was established")
392 .build()
393 ));
394 }
395 };
396 drop(_nvtx_wait);
397
398 let mut is_complete_final = false;
400 let mut first_response = true;
401 let stream = tokio_stream::StreamNotifyClose::new(
402 tokio_stream::wrappers::ReceiverStream::new(response_stream.rx),
403 )
404 .filter_map(move |res| {
405 if let Some(res_bytes) = res {
406 if first_response {
407 first_response = false;
408 let roundtrip_ttft = tx_start.elapsed().as_secs_f64();
409 REQUEST_PLANE_ROUNDTRIP_TTFT_SECONDS.observe(roundtrip_ttft);
410 STAGE_DURATION_SECONDS
411 .with_label_values(&["transport_roundtrip"])
412 .observe(queue_start.elapsed().as_secs_f64());
413 }
414 if is_complete_final {
415 let err = DynamoError::msg(
416 "Response received after generation ended - this should never happen",
417 );
418 return Some(U::from_err(err));
419 }
420 match serde_json::from_slice::<NetworkStreamWrapper<U>>(&res_bytes) {
421 Ok(item) => {
422 is_complete_final = item.complete_final;
423 if let Some(data) = item.data {
424 Some(data)
425 } else if is_complete_final {
426 None
427 } else {
428 let err = DynamoError::msg(
429 "Empty response received - this should never happen",
430 );
431 Some(U::from_err(err))
432 }
433 }
434 Err(err) => {
435 let json_str = String::from_utf8_lossy(&res_bytes);
437 tracing::warn!(%err, %json_str, "Failed deserializing JSON to response");
438
439 Some(U::from_err(DynamoError::msg(err.to_string())))
440 }
441 }
442 } else if is_complete_final {
443 None
445 } else if engine_ctx_.is_stopped() {
446 tracing::debug!("Request cancelled and then trying to read a response");
450 None
451 } else {
452 let err = DynamoError::builder()
454 .error_type(ErrorType::Disconnected)
455 .message("Stream ended before generation completed")
456 .build();
457 tracing::debug!("{err}");
458 Some(U::from_err(err))
459 }
460 });
461
462 inflight_guard.disarm();
463 let stream = InflightDecStream { inner: stream };
464 Ok(ResponseStream::new(Box::pin(stream), engine_ctx))
465 }
466}
467
#[cfg(test)]
mod tests {
    use super::{
        CONTROL_MESSAGE_MAX_BYTES, ConnectionInfo, RequestControlMessage, RequestType,
        ResponseType, serialize_control_message,
    };
    use std::collections::BTreeMap;

    /// Builds a minimal SingleIn/ManyOut control message that carries `metadata`.
    fn base_control_message(metadata: BTreeMap<String, String>) -> RequestControlMessage {
        let connection_info = ConnectionInfo {
            transport: "tcp".to_string(),
            info: "{}".to_string(),
        };
        RequestControlMessage {
            id: "request-123".to_string(),
            request_type: RequestType::SingleIn,
            response_type: ResponseType::ManyOut,
            connection_info,
            metadata,
            frontend_send_ts_ns: None,
        }
    }

    #[test]
    fn serialize_control_message_succeeds_under_limit() {
        let metadata = BTreeMap::from([("x-tiny-blob".to_string(), "alpha".to_string())]);
        let message = base_control_message(metadata);

        let encoded = serialize_control_message(&message)
            .expect("control message should serialize under the limit");
        assert!(encoded.len() <= CONTROL_MESSAGE_MAX_BYTES);
    }

    #[test]
    fn serialize_control_message_errors_over_limit() {
        // A single metadata value as long as the limit guarantees the encoded JSON exceeds it.
        let oversized_value = "x".repeat(CONTROL_MESSAGE_MAX_BYTES);
        let metadata = BTreeMap::from([("x-large-blob".to_string(), oversized_value)]);
        let message = base_control_message(metadata);

        let rendered = serialize_control_message(&message)
            .expect_err("oversized control message should fail")
            .to_string();
        let limit_text = CONTROL_MESSAGE_MAX_BYTES.to_string();
        assert!(rendered.contains("request control message too large"));
        assert!(rendered.contains(&limit_text));
    }
}