dynamo_runtime/pipeline/network/egress/
addressed_router.rs1use std::sync::Arc;
5use std::time::Instant;
6
7use super::unified_client::RequestPlaneClient;
8use super::*;
9use crate::component::Instance;
10use crate::discovery::EndpointInstanceId;
11use crate::dynamo_nvtx_range;
12use crate::engine::{AsyncEngine, AsyncEngineContextProvider, Data};
13use crate::error::{DynamoError, ErrorType};
14use crate::logging::inject_trace_headers_into_map;
15use crate::metrics::frontend_perf::STAGE_DURATION_SECONDS;
16use crate::metrics::request_plane::{
17 REQUEST_PLANE_INFLIGHT, REQUEST_PLANE_QUEUE_SECONDS, REQUEST_PLANE_ROUNDTRIP_TTFT_SECONDS,
18 REQUEST_PLANE_SEND_SECONDS,
19};
20use crate::pipeline::network::ConnectionInfo;
21use crate::pipeline::network::NetworkStreamWrapper;
22use crate::pipeline::network::PendingConnections;
23use crate::pipeline::network::StreamOptions;
24use crate::pipeline::network::TwoPartCodec;
25use crate::pipeline::network::codec::TwoPartMessage;
26use crate::pipeline::network::tcp;
27use crate::pipeline::{ManyOut, PipelineError, ResponseStream, SingleIn};
28use crate::protocols::maybe_error::MaybeError;
29
30use anyhow::{Error, Result};
31use futures::stream::Stream;
32use serde::Deserialize;
33use serde::Serialize;
34use std::pin::Pin;
35use std::task::{Context, Poll};
36use tokio_stream::{StreamExt, StreamNotifyClose, wrappers::ReceiverStream};
37use tracing::Instrument;
38
39#[derive(Debug, Clone, Serialize, Deserialize)]
40#[serde(rename_all = "snake_case")]
41enum RequestType {
42 SingleIn,
43 ManyIn,
44}
45
46#[derive(Debug, Clone, Serialize, Deserialize)]
47#[serde(rename_all = "snake_case")]
48enum ResponseType {
49 SingleOut,
50 ManyOut,
51}
52
53#[derive(Debug, Clone, Serialize, Deserialize)]
54struct RequestControlMessage {
55 id: String,
56 request_type: RequestType,
57 response_type: ResponseType,
58 connection_info: ConnectionInfo,
59 #[serde(default, skip_serializing_if = "Option::is_none")]
63 frontend_send_ts_ns: Option<u64>,
64}
65
66struct InflightGuard {
70 armed: bool,
71}
72
73impl InflightGuard {
74 fn new() -> Self {
75 Self { armed: true }
76 }
77
78 fn disarm(mut self) {
81 self.armed = false;
82 }
83}
84
85impl Drop for InflightGuard {
86 fn drop(&mut self) {
87 if self.armed {
88 REQUEST_PLANE_INFLIGHT.dec();
89 }
90 }
91}
92
93struct InflightDecStream<S> {
95 inner: S,
96}
97
98impl<S, T> Stream for InflightDecStream<S>
99where
100 S: Stream<Item = T> + Unpin,
101{
102 type Item = T;
103
104 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
105 Pin::new(&mut self.inner).poll_next(cx)
106 }
107}
108
109impl<S> Drop for InflightDecStream<S> {
110 fn drop(&mut self) {
111 REQUEST_PLANE_INFLIGHT.dec();
112 }
113}
114
115pub struct AddressedRequest<T> {
116 request: T,
117 address: String,
118 instance: Option<Instance>,
121}
122
123impl<T> AddressedRequest<T> {
124 pub fn new(request: T, address: String) -> Self {
125 Self {
126 request,
127 address,
128 instance: None,
129 }
130 }
131
132 pub fn with_instance(request: T, address: String, instance: Instance) -> Self {
133 Self {
134 request,
135 address,
136 instance: Some(instance),
137 }
138 }
139
140 pub(crate) fn into_parts(self) -> (T, String, Option<Instance>) {
141 (self.request, self.address, self.instance)
142 }
143}
144
145pub struct AddressedPushRouter {
146 req_client: Arc<dyn RequestPlaneClient>,
148
149 resp_transport: Arc<tcp::server::TcpStreamServer>,
151}
152
153impl AddressedPushRouter {
154 pub fn new(
159 req_client: Arc<dyn RequestPlaneClient>,
160 resp_transport: Arc<tcp::server::TcpStreamServer>,
161 ) -> Result<Arc<Self>> {
162 Ok(Arc::new(Self {
163 req_client,
164 resp_transport,
165 }))
166 }
167
168 pub async fn cancel_instance_streams(&self, instance_id: &EndpointInstanceId) -> usize {
170 self.resp_transport
171 .cancel_instance_streams(instance_id)
172 .await
173 }
174
175 pub async fn clear_instance_tombstone(&self, instance_id: &EndpointInstanceId) {
177 self.resp_transport
178 .clear_instance_tombstone(instance_id)
179 .await
180 }
181}
182
183#[async_trait::async_trait]
184impl<T, U> AsyncEngine<SingleIn<AddressedRequest<T>>, ManyOut<U>, Error> for AddressedPushRouter
185where
186 T: Data + Serialize,
187 U: Data + for<'de> Deserialize<'de> + MaybeError,
188{
189 async fn generate(&self, request: SingleIn<AddressedRequest<T>>) -> Result<ManyOut<U>, Error> {
190 let queue_start = Instant::now();
191 REQUEST_PLANE_INFLIGHT.inc();
192 let inflight_guard = InflightGuard::new();
193
194 let request_id = request.context().id().to_string();
195 let (addressed_request, context) = request.transfer(());
196 let (request, address, instance_info) = addressed_request.into_parts();
197 let engine_ctx = context.context();
198 let engine_ctx_ = engine_ctx.clone();
199
200 let options = StreamOptions::builder()
202 .context(engine_ctx.clone())
203 .enable_request_stream(false)
204 .enable_response_stream(true)
205 .build()
206 .unwrap();
207
208 let pending_connections: PendingConnections = self.resp_transport.register(options).await;
211
212 let pending_response_stream = match pending_connections.into_parts() {
214 (None, Some(recv_stream)) => recv_stream,
215 _ => {
216 panic!("Invalid data plane registration for a SingleIn/ManyOut transport");
217 }
218 };
219
220 let (connection_info, response_stream_provider) = pending_response_stream.into_parts();
222
223 let recv_subject: Option<String> =
225 serde_json::from_str::<tcp::TcpStreamConnectionInfo>(&connection_info.info)
226 .ok()
227 .map(|ci| ci.subject);
228
229 if let (Some(subject), Some(inst)) = (&recv_subject, &instance_info) {
232 let endpoint_instance_id = inst.endpoint_instance_id();
233 if !self
234 .resp_transport
235 .associate_instance(subject, &endpoint_instance_id)
236 .await
237 {
238 return Err(anyhow::anyhow!(
239 DynamoError::builder()
240 .error_type(ErrorType::Disconnected)
241 .message(
242 "Worker removed before request could be sent (tombstoned instance)"
243 )
244 .build()
245 ));
246 }
247 }
248
249 let control_message = RequestControlMessage {
254 id: engine_ctx.id().to_string(),
255 request_type: RequestType::SingleIn,
256 response_type: ResponseType::ManyOut,
257 connection_info,
258 frontend_send_ts_ns: None,
259 };
260
261 let ctrl = match serde_json::to_vec(&control_message) {
265 Ok(v) => v,
266 Err(e) => {
267 if let Some(subject) = &recv_subject {
268 self.resp_transport.cancel_recv_stream(subject).await;
269 }
270 return Err(e.into());
271 }
272 };
273 let data = match serde_json::to_vec(&request) {
274 Ok(v) => v,
275 Err(e) => {
276 if let Some(subject) = &recv_subject {
277 self.resp_transport.cancel_recv_stream(subject).await;
278 }
279 return Err(e.into());
280 }
281 };
282
283 tracing::trace!(
284 request_id,
285 "packaging two-part message; ctrl: {} bytes, data: {} bytes",
286 ctrl.len(),
287 data.len()
288 );
289
290 let msg = TwoPartMessage::from_parts(ctrl.into(), data.into());
291
292 let codec = TwoPartCodec::default();
296 let buffer = match codec.encode_message(msg) {
297 Ok(v) => v,
298 Err(e) => {
299 if let Some(subject) = &recv_subject {
300 self.resp_transport.cancel_recv_stream(subject).await;
301 }
302 return Err(e.into());
303 }
304 };
305
306 REQUEST_PLANE_QUEUE_SECONDS.observe(queue_start.elapsed().as_secs_f64());
307 let tx_start = Instant::now();
308
309 tracing::trace!(
313 request_id,
314 transport = self.req_client.transport_name(),
315 address = %address,
316 "Sending request via request plane client"
317 );
318
319 let mut headers = std::collections::HashMap::new();
321 inject_trace_headers_into_map(&mut headers);
322 headers.insert("request-id".to_string(), request_id.clone());
323
324 let send_ts_ns = std::time::SystemTime::now()
327 .duration_since(std::time::UNIX_EPOCH)
328 .unwrap_or_default()
329 .as_nanos() as u64;
330 headers.insert("x-frontend-send-ts-ns".to_string(), send_ts_ns.to_string());
331
332 let _nvtx_send = dynamo_nvtx_range!("transport.tcp.send");
334 let send_result = self.req_client.send_request(address, buffer, headers).await;
335 drop(_nvtx_send);
336
337 if let Err(e) = send_result {
338 if let Some(subject) = &recv_subject {
339 self.resp_transport.cancel_recv_stream(subject).await;
340 }
341 return Err(e);
342 }
343 REQUEST_PLANE_SEND_SECONDS.observe(tx_start.elapsed().as_secs_f64());
344
345 let _nvtx_wait = dynamo_nvtx_range!("transport.tcp.wait_backend");
346 tracing::trace!(request_id, "awaiting transport handshake");
347
348 let response_stream = match response_stream_provider.await {
351 Ok(Ok(stream)) => stream,
352 Ok(Err(e)) => {
353 if let Some(subject) = &recv_subject {
360 self.resp_transport.cancel_recv_stream(subject).await;
361 }
362 return Err(anyhow::anyhow!(
363 DynamoError::builder()
364 .error_type(ErrorType::CannotConnect)
365 .message(format!(
366 "Worker generate() failed before response stream: {e}"
367 ))
368 .build()
369 ));
370 }
371 Err(_recv_err) => {
372 if let Some(subject) = &recv_subject {
375 self.resp_transport.cancel_recv_stream(subject).await;
376 }
377 return Err(anyhow::anyhow!(
378 DynamoError::builder()
379 .error_type(ErrorType::Disconnected)
380 .message("Worker disconnected before response stream was established")
381 .build()
382 ));
383 }
384 };
385 drop(_nvtx_wait);
386
387 let mut is_complete_final = false;
389 let mut first_response = true;
390 let stream = tokio_stream::StreamNotifyClose::new(
391 tokio_stream::wrappers::ReceiverStream::new(response_stream.rx),
392 )
393 .filter_map(move |res| {
394 if let Some(res_bytes) = res {
395 if first_response {
396 first_response = false;
397 let roundtrip_ttft = tx_start.elapsed().as_secs_f64();
398 REQUEST_PLANE_ROUNDTRIP_TTFT_SECONDS.observe(roundtrip_ttft);
399 STAGE_DURATION_SECONDS
400 .with_label_values(&["transport_roundtrip"])
401 .observe(queue_start.elapsed().as_secs_f64());
402 }
403 if is_complete_final {
404 let err = DynamoError::msg(
405 "Response received after generation ended - this should never happen",
406 );
407 return Some(U::from_err(err));
408 }
409 match serde_json::from_slice::<NetworkStreamWrapper<U>>(&res_bytes) {
410 Ok(item) => {
411 is_complete_final = item.complete_final;
412 if let Some(data) = item.data {
413 Some(data)
414 } else if is_complete_final {
415 None
416 } else {
417 let err = DynamoError::msg(
418 "Empty response received - this should never happen",
419 );
420 Some(U::from_err(err))
421 }
422 }
423 Err(err) => {
424 let json_str = String::from_utf8_lossy(&res_bytes);
426 tracing::warn!(%err, %json_str, "Failed deserializing JSON to response");
427
428 Some(U::from_err(DynamoError::msg(err.to_string())))
429 }
430 }
431 } else if is_complete_final {
432 None
434 } else if engine_ctx_.is_stopped() {
435 tracing::debug!("Request cancelled and then trying to read a response");
439 None
440 } else {
441 let err = DynamoError::builder()
443 .error_type(ErrorType::Disconnected)
444 .message("Stream ended before generation completed")
445 .build();
446 tracing::debug!("{err}");
447 Some(U::from_err(err))
448 }
449 });
450
451 inflight_guard.disarm();
452 let stream = InflightDecStream { inner: stream };
453 Ok(ResponseStream::new(Box::pin(stream), engine_ctx))
454 }
455}