// dynamo_runtime/pipeline/network/egress/addressed_router.rs
use std::sync::Arc;
5
6use super::unified_client::RequestPlaneClient;
7use super::*;
8use crate::engine::{AsyncEngine, AsyncEngineContextProvider, Data};
9use crate::logging::inject_trace_headers_into_map;
10use crate::pipeline::network::ConnectionInfo;
11use crate::pipeline::network::NetworkStreamWrapper;
12use crate::pipeline::network::PendingConnections;
13use crate::pipeline::network::STREAM_ERR_MSG;
14use crate::pipeline::network::StreamOptions;
15use crate::pipeline::network::TwoPartCodec;
16use crate::pipeline::network::codec::TwoPartMessage;
17use crate::pipeline::network::tcp;
18use crate::pipeline::{ManyOut, PipelineError, ResponseStream, SingleIn};
19use crate::protocols::maybe_error::MaybeError;
20
21use anyhow::{Error, Result};
22use serde::Deserialize;
23use serde::Serialize;
24use tokio_stream::{StreamExt, StreamNotifyClose, wrappers::ReceiverStream};
25use tracing::Instrument;
26
27#[derive(Debug, Clone, Serialize, Deserialize)]
28#[serde(rename_all = "snake_case")]
29enum RequestType {
30 SingleIn,
31 ManyIn,
32}
33
34#[derive(Debug, Clone, Serialize, Deserialize)]
35#[serde(rename_all = "snake_case")]
36enum ResponseType {
37 SingleOut,
38 ManyOut,
39}
40
41#[derive(Debug, Clone, Serialize, Deserialize)]
42struct RequestControlMessage {
43 id: String,
44 request_type: RequestType,
45 response_type: ResponseType,
46 connection_info: ConnectionInfo,
47}
48
/// A request payload paired with the address it should be sent to.
pub struct AddressedRequest<T> {
    request: T,
    address: String,
}

impl<T> AddressedRequest<T> {
    /// Bundles `request` with its destination `address`.
    pub fn new(request: T, address: String) -> Self {
        Self { address, request }
    }

    /// Consumes the wrapper, yielding `(request, address)`.
    pub(crate) fn into_parts(self) -> (T, String) {
        let Self { request, address } = self;
        (request, address)
    }
}
63
64pub struct AddressedPushRouter {
65 req_client: Arc<dyn RequestPlaneClient>,
67
68 resp_transport: Arc<tcp::server::TcpStreamServer>,
70}
71
72impl AddressedPushRouter {
73 pub fn new(
78 req_client: Arc<dyn RequestPlaneClient>,
79 resp_transport: Arc<tcp::server::TcpStreamServer>,
80 ) -> Result<Arc<Self>> {
81 Ok(Arc::new(Self {
82 req_client,
83 resp_transport,
84 }))
85 }
86}
87
88#[async_trait::async_trait]
89impl<T, U> AsyncEngine<SingleIn<AddressedRequest<T>>, ManyOut<U>, Error> for AddressedPushRouter
90where
91 T: Data + Serialize,
92 U: Data + for<'de> Deserialize<'de> + MaybeError,
93{
94 async fn generate(&self, request: SingleIn<AddressedRequest<T>>) -> Result<ManyOut<U>, Error> {
95 let request_id = request.context().id().to_string();
96 let (addressed_request, context) = request.transfer(());
97 let (request, address) = addressed_request.into_parts();
98 let engine_ctx = context.context();
99 let engine_ctx_ = engine_ctx.clone();
100
101 let options = StreamOptions::builder()
103 .context(engine_ctx.clone())
104 .enable_request_stream(false)
105 .enable_response_stream(true)
106 .build()
107 .unwrap();
108
109 let pending_connections: PendingConnections = self.resp_transport.register(options).await;
112
113 let pending_response_stream = match pending_connections.into_parts() {
115 (None, Some(recv_stream)) => recv_stream,
116 _ => {
117 panic!("Invalid data plane registration for a SingleIn/ManyOut transport");
118 }
119 };
120
121 let (connection_info, response_stream_provider) = pending_response_stream.into_parts();
123
124 let control_message = RequestControlMessage {
129 id: engine_ctx.id().to_string(),
130 request_type: RequestType::SingleIn,
131 response_type: ResponseType::ManyOut,
132 connection_info,
133 };
134
135 let ctrl = serde_json::to_vec(&control_message)?;
139 let data = serde_json::to_vec(&request)?;
140
141 tracing::trace!(
142 request_id,
143 "packaging two-part message; ctrl: {} bytes, data: {} bytes",
144 ctrl.len(),
145 data.len()
146 );
147
148 let msg = TwoPartMessage::from_parts(ctrl.into(), data.into());
149
150 let codec = TwoPartCodec::default();
154 let buffer = codec.encode_message(msg)?;
155
156 tracing::trace!(
160 request_id,
161 transport = self.req_client.transport_name(),
162 address = %address,
163 "Sending request via request plane client"
164 );
165
166 let mut headers = std::collections::HashMap::new();
168 inject_trace_headers_into_map(&mut headers);
169
170 let _response = self
172 .req_client
173 .send_request(address, buffer, headers)
174 .await?;
175
176 tracing::trace!(request_id, "awaiting transport handshake");
177 let response_stream = response_stream_provider
178 .await
179 .map_err(|_| PipelineError::DetachedStreamReceiver)?
180 .map_err(PipelineError::ConnectionFailed)?;
181
182 let mut is_complete_final = false;
184 let stream = tokio_stream::StreamNotifyClose::new(
185 tokio_stream::wrappers::ReceiverStream::new(response_stream.rx),
186 )
187 .filter_map(move |res| {
188 if let Some(res_bytes) = res {
189 if is_complete_final {
190 return Some(U::from_err(
191 Error::msg(
192 "Response received after generation ended - this should never happen",
193 )
194 .into(),
195 ));
196 }
197 match serde_json::from_slice::<NetworkStreamWrapper<U>>(&res_bytes) {
198 Ok(item) => {
199 is_complete_final = item.complete_final;
200 if let Some(data) = item.data {
201 Some(data)
202 } else if is_complete_final {
203 None
204 } else {
205 Some(U::from_err(
206 Error::msg("Empty response received - this should never happen")
207 .into(),
208 ))
209 }
210 }
211 Err(err) => {
212 let json_str = String::from_utf8_lossy(&res_bytes);
214 tracing::warn!(%err, %json_str, "Failed deserializing JSON to response");
215
216 Some(U::from_err(Error::new(err).into()))
217 }
218 }
219 } else if is_complete_final {
220 None
222 } else if engine_ctx_.is_stopped() {
223 tracing::debug!("Request cancelled and then trying to read a response");
227 None
228 } else {
229 tracing::debug!("{STREAM_ERR_MSG}");
231 Some(U::from_err(Error::msg(STREAM_ERR_MSG).into()))
232 }
233 });
234
235 Ok(ResponseStream::new(Box::pin(stream), engine_ctx))
236 }
237}