dynamo_runtime/pipeline/network/egress/
addressed_router.rs1use std::sync::Arc;
5use std::time::Instant;
6
7use super::unified_client::RequestPlaneClient;
8use super::*;
9use crate::dynamo_nvtx_range;
10use crate::engine::{AsyncEngine, AsyncEngineContextProvider, Data};
11use crate::error::{DynamoError, ErrorType};
12use crate::logging::inject_trace_headers_into_map;
13use crate::metrics::frontend_perf::STAGE_DURATION_SECONDS;
14use crate::metrics::request_plane::{
15 REQUEST_PLANE_INFLIGHT, REQUEST_PLANE_QUEUE_SECONDS, REQUEST_PLANE_ROUNDTRIP_TTFT_SECONDS,
16 REQUEST_PLANE_SEND_SECONDS,
17};
18use crate::pipeline::network::ConnectionInfo;
19use crate::pipeline::network::NetworkStreamWrapper;
20use crate::pipeline::network::PendingConnections;
21use crate::pipeline::network::StreamOptions;
22use crate::pipeline::network::TwoPartCodec;
23use crate::pipeline::network::codec::TwoPartMessage;
24use crate::pipeline::network::tcp;
25use crate::pipeline::{ManyOut, PipelineError, ResponseStream, SingleIn};
26use crate::protocols::maybe_error::MaybeError;
27
28use anyhow::{Error, Result};
29use futures::stream::Stream;
30use serde::Deserialize;
31use serde::Serialize;
32use std::pin::Pin;
33use std::task::{Context, Poll};
34use tokio_stream::{StreamExt, StreamNotifyClose, wrappers::ReceiverStream};
35use tracing::Instrument;
36
37#[derive(Debug, Clone, Serialize, Deserialize)]
38#[serde(rename_all = "snake_case")]
39enum RequestType {
40 SingleIn,
41 ManyIn,
42}
43
44#[derive(Debug, Clone, Serialize, Deserialize)]
45#[serde(rename_all = "snake_case")]
46enum ResponseType {
47 SingleOut,
48 ManyOut,
49}
50
51#[derive(Debug, Clone, Serialize, Deserialize)]
52struct RequestControlMessage {
53 id: String,
54 request_type: RequestType,
55 response_type: ResponseType,
56 connection_info: ConnectionInfo,
57 #[serde(default, skip_serializing_if = "Option::is_none")]
61 frontend_send_ts_ns: Option<u64>,
62}
63
64struct InflightGuard {
68 armed: bool,
69}
70
71impl InflightGuard {
72 fn new() -> Self {
73 Self { armed: true }
74 }
75
76 fn disarm(mut self) {
79 self.armed = false;
80 }
81}
82
83impl Drop for InflightGuard {
84 fn drop(&mut self) {
85 if self.armed {
86 REQUEST_PLANE_INFLIGHT.dec();
87 }
88 }
89}
90
91struct InflightDecStream<S> {
93 inner: S,
94}
95
96impl<S, T> Stream for InflightDecStream<S>
97where
98 S: Stream<Item = T> + Unpin,
99{
100 type Item = T;
101
102 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
103 Pin::new(&mut self.inner).poll_next(cx)
104 }
105}
106
107impl<S> Drop for InflightDecStream<S> {
108 fn drop(&mut self) {
109 REQUEST_PLANE_INFLIGHT.dec();
110 }
111}
112
/// A request payload paired with the address it should be sent to.
///
/// `Debug` is derived so the public type can be used in logs and assertions;
/// the implementation is only available when `T: Debug`.
#[derive(Debug)]
pub struct AddressedRequest<T> {
    /// The payload to deliver.
    request: T,
    /// Destination address, passed verbatim to the request-plane client.
    address: String,
}

impl<T> AddressedRequest<T> {
    /// Pairs `request` with the destination `address`.
    pub fn new(request: T, address: String) -> Self {
        Self { request, address }
    }

    /// Splits the value back into `(request, address)`.
    pub(crate) fn into_parts(self) -> (T, String) {
        let Self { request, address } = self;
        (request, address)
    }
}
127
128pub struct AddressedPushRouter {
129 req_client: Arc<dyn RequestPlaneClient>,
131
132 resp_transport: Arc<tcp::server::TcpStreamServer>,
134}
135
136impl AddressedPushRouter {
137 pub fn new(
142 req_client: Arc<dyn RequestPlaneClient>,
143 resp_transport: Arc<tcp::server::TcpStreamServer>,
144 ) -> Result<Arc<Self>> {
145 Ok(Arc::new(Self {
146 req_client,
147 resp_transport,
148 }))
149 }
150}
151
152#[async_trait::async_trait]
153impl<T, U> AsyncEngine<SingleIn<AddressedRequest<T>>, ManyOut<U>, Error> for AddressedPushRouter
154where
155 T: Data + Serialize,
156 U: Data + for<'de> Deserialize<'de> + MaybeError,
157{
158 async fn generate(&self, request: SingleIn<AddressedRequest<T>>) -> Result<ManyOut<U>, Error> {
159 let queue_start = Instant::now();
160 REQUEST_PLANE_INFLIGHT.inc();
161 let inflight_guard = InflightGuard::new();
162
163 let request_id = request.context().id().to_string();
164 let (addressed_request, context) = request.transfer(());
165 let (request, address) = addressed_request.into_parts();
166 let engine_ctx = context.context();
167 let engine_ctx_ = engine_ctx.clone();
168
169 let options = StreamOptions::builder()
171 .context(engine_ctx.clone())
172 .enable_request_stream(false)
173 .enable_response_stream(true)
174 .build()
175 .unwrap();
176
177 let pending_connections: PendingConnections = self.resp_transport.register(options).await;
180
181 let pending_response_stream = match pending_connections.into_parts() {
183 (None, Some(recv_stream)) => recv_stream,
184 _ => {
185 panic!("Invalid data plane registration for a SingleIn/ManyOut transport");
186 }
187 };
188
189 let (connection_info, response_stream_provider) = pending_response_stream.into_parts();
191
192 let control_message = RequestControlMessage {
197 id: engine_ctx.id().to_string(),
198 request_type: RequestType::SingleIn,
199 response_type: ResponseType::ManyOut,
200 connection_info,
201 frontend_send_ts_ns: None,
202 };
203
204 let ctrl = serde_json::to_vec(&control_message)?;
208 let data = serde_json::to_vec(&request)?;
209
210 tracing::trace!(
211 request_id,
212 "packaging two-part message; ctrl: {} bytes, data: {} bytes",
213 ctrl.len(),
214 data.len()
215 );
216
217 let msg = TwoPartMessage::from_parts(ctrl.into(), data.into());
218
219 let codec = TwoPartCodec::default();
223 let buffer = {
224 let _nvtx = dynamo_nvtx_range!("codec.encode");
225 codec.encode_message(msg)?
226 };
227
228 REQUEST_PLANE_QUEUE_SECONDS.observe(queue_start.elapsed().as_secs_f64());
229 let tx_start = Instant::now();
230
231 tracing::trace!(
235 request_id,
236 transport = self.req_client.transport_name(),
237 address = %address,
238 "Sending request via request plane client"
239 );
240
241 let mut headers = std::collections::HashMap::new();
243 inject_trace_headers_into_map(&mut headers);
244 headers.insert("request-id".to_string(), request_id.clone());
245
246 let send_ts_ns = std::time::SystemTime::now()
249 .duration_since(std::time::UNIX_EPOCH)
250 .unwrap_or_default()
251 .as_nanos() as u64;
252 headers.insert("x-frontend-send-ts-ns".to_string(), send_ts_ns.to_string());
253
254 let _nvtx_send = dynamo_nvtx_range!("transport.tcp.send");
256 let _response = self
257 .req_client
258 .send_request(address, buffer, headers)
259 .await?;
260 drop(_nvtx_send);
261 REQUEST_PLANE_SEND_SECONDS.observe(tx_start.elapsed().as_secs_f64());
262
263 let _nvtx_wait = dynamo_nvtx_range!("transport.tcp.wait_backend");
264 tracing::trace!(request_id, "awaiting transport handshake");
265 let response_stream = response_stream_provider
266 .await
267 .map_err(|_| PipelineError::DetachedStreamReceiver)?
268 .map_err(PipelineError::ConnectionFailed)?;
269 drop(_nvtx_wait);
270
271 let mut is_complete_final = false;
273 let mut first_response = true;
274 let stream = tokio_stream::StreamNotifyClose::new(
275 tokio_stream::wrappers::ReceiverStream::new(response_stream.rx),
276 )
277 .filter_map(move |res| {
278 if let Some(res_bytes) = res {
279 if first_response {
280 first_response = false;
281 let roundtrip_ttft = tx_start.elapsed().as_secs_f64();
282 REQUEST_PLANE_ROUNDTRIP_TTFT_SECONDS.observe(roundtrip_ttft);
283 STAGE_DURATION_SECONDS
284 .with_label_values(&["transport_roundtrip"])
285 .observe(queue_start.elapsed().as_secs_f64());
286 }
287 if is_complete_final {
288 let err = DynamoError::msg(
289 "Response received after generation ended - this should never happen",
290 );
291 return Some(U::from_err(err));
292 }
293 match serde_json::from_slice::<NetworkStreamWrapper<U>>(&res_bytes) {
294 Ok(item) => {
295 is_complete_final = item.complete_final;
296 if let Some(data) = item.data {
297 Some(data)
298 } else if is_complete_final {
299 None
300 } else {
301 let err = DynamoError::msg(
302 "Empty response received - this should never happen",
303 );
304 Some(U::from_err(err))
305 }
306 }
307 Err(err) => {
308 let json_str = String::from_utf8_lossy(&res_bytes);
310 tracing::warn!(%err, %json_str, "Failed deserializing JSON to response");
311
312 Some(U::from_err(DynamoError::msg(err.to_string())))
313 }
314 }
315 } else if is_complete_final {
316 None
318 } else if engine_ctx_.is_stopped() {
319 tracing::debug!("Request cancelled and then trying to read a response");
323 None
324 } else {
325 let err = DynamoError::builder()
327 .error_type(ErrorType::Disconnected)
328 .message("Stream ended before generation completed")
329 .build();
330 tracing::debug!("{err}");
331 Some(U::from_err(err))
332 }
333 });
334
335 inflight_guard.disarm();
336 let stream = InflightDecStream { inner: stream };
337 Ok(ResponseStream::new(Box::pin(stream), engine_ctx))
338 }
339}