dynamo_runtime/pipeline/network/egress/
addressed_router.rs1use async_nats::client::Client;
5use async_nats::{HeaderMap, HeaderValue};
6use tracing as log;
7
8use super::*;
9use crate::logging::DistributedTraceContext;
10use crate::logging::get_distributed_tracing_context;
11use crate::logging::inject_otel_context_into_nats_headers;
12use crate::{Result, protocols::maybe_error::MaybeError};
13use tokio_stream::{StreamExt, StreamNotifyClose, wrappers::ReceiverStream};
14use tracing::Instrument;
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
17#[serde(rename_all = "snake_case")]
18enum RequestType {
19 SingleIn,
20 ManyIn,
21}
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
24#[serde(rename_all = "snake_case")]
25enum ResponseType {
26 SingleOut,
27 ManyOut,
28}
29
30#[derive(Debug, Clone, Serialize, Deserialize)]
31struct RequestControlMessage {
32 id: String,
33 request_type: RequestType,
34 response_type: ResponseType,
35 connection_info: ConnectionInfo,
36}
37
/// A request payload paired with the address it should be delivered to.
pub struct AddressedRequest<T> {
    /// The payload to send.
    request: T,
    /// Destination address; used as the subject of the transport request.
    address: String,
}

impl<T> AddressedRequest<T> {
    /// Bundles `request` with its destination `address`.
    pub fn new(request: T, address: String) -> Self {
        AddressedRequest { request, address }
    }

    /// Consumes the wrapper and returns `(request, address)`.
    fn into_parts(self) -> (T, String) {
        let Self { request, address } = self;
        (request, address)
    }
}
52
53pub struct AddressedPushRouter {
54 req_transport: Client,
56
57 resp_transport: Arc<tcp::server::TcpStreamServer>,
59}
60
61impl AddressedPushRouter {
62 pub fn new(
63 req_transport: Client,
64 resp_transport: Arc<tcp::server::TcpStreamServer>,
65 ) -> Result<Arc<Self>> {
66 Ok(Arc::new(Self {
67 req_transport,
68 resp_transport,
69 }))
70 }
71}
72
73#[async_trait]
74impl<T, U> AsyncEngine<SingleIn<AddressedRequest<T>>, ManyOut<U>, Error> for AddressedPushRouter
75where
76 T: Data + Serialize,
77 U: Data + for<'de> Deserialize<'de> + MaybeError,
78{
79 async fn generate(&self, request: SingleIn<AddressedRequest<T>>) -> Result<ManyOut<U>, Error> {
80 let request_id = request.context().id().to_string();
81 let (addressed_request, context) = request.transfer(());
82 let (request, address) = addressed_request.into_parts();
83 let engine_ctx = context.context();
84 let engine_ctx_ = engine_ctx.clone();
85
86 let options = StreamOptions::builder()
88 .context(engine_ctx.clone())
89 .enable_request_stream(false)
90 .enable_response_stream(true)
91 .build()
92 .unwrap();
93
94 let pending_connections: PendingConnections = self.resp_transport.register(options).await;
97
98 let pending_response_stream = match pending_connections.into_parts() {
100 (None, Some(recv_stream)) => recv_stream,
101 _ => {
102 panic!("Invalid data plane registration for a SingleIn/ManyOut transport");
103 }
104 };
105
106 let (connection_info, response_stream_provider) = pending_response_stream.into_parts();
108
109 let control_message = RequestControlMessage {
114 id: engine_ctx.id().to_string(),
115 request_type: RequestType::SingleIn,
116 response_type: ResponseType::ManyOut,
117 connection_info,
118 };
119
120 let ctrl = serde_json::to_vec(&control_message)?;
124 let data = serde_json::to_vec(&request)?;
125
126 log::trace!(
127 request_id,
128 "packaging two-part message; ctrl: {} bytes, data: {} bytes",
129 ctrl.len(),
130 data.len()
131 );
132
133 let msg = TwoPartMessage::from_parts(ctrl.into(), data.into());
134
135 let codec = TwoPartCodec::default();
139 let buffer = codec.encode_message(msg)?;
140
141 log::trace!(request_id, "enqueueing two-part message to nats");
144
145 let mut headers = HeaderMap::new();
152 inject_otel_context_into_nats_headers(&mut headers, None);
153
154 if let Some(trace_context) = get_distributed_tracing_context() {
156 if let Some(x_request_id) = trace_context.x_request_id {
157 headers.insert("x-request-id", x_request_id);
158 }
159 if let Some(x_dynamo_request_id) = trace_context.x_dynamo_request_id {
160 headers.insert("x-dynamo-request-id", x_dynamo_request_id);
161 }
162 }
163
164 let _response = self
167 .req_transport
168 .request_with_headers(address.to_string(), headers, buffer)
169 .await?;
170
171 log::trace!(request_id, "awaiting transport handshake");
172 let response_stream = response_stream_provider
173 .await
174 .map_err(|_| PipelineError::DetachedStreamReceiver)?
175 .map_err(PipelineError::ConnectionFailed)?;
176
177 let mut is_complete_final = false;
179 let stream = tokio_stream::StreamNotifyClose::new(
180 tokio_stream::wrappers::ReceiverStream::new(response_stream.rx),
181 )
182 .filter_map(move |res| {
183 if let Some(res_bytes) = res {
184 if is_complete_final {
185 return Some(U::from_err(
186 Error::msg(
187 "Response received after generation ended - this should never happen",
188 )
189 .into(),
190 ));
191 }
192 match serde_json::from_slice::<NetworkStreamWrapper<U>>(&res_bytes) {
193 Ok(item) => {
194 is_complete_final = item.complete_final;
195 if let Some(data) = item.data {
196 Some(data)
197 } else if is_complete_final {
198 None
199 } else {
200 Some(U::from_err(
201 Error::msg("Empty response received - this should never happen")
202 .into(),
203 ))
204 }
205 }
206 Err(err) => {
207 let json_str = String::from_utf8_lossy(&res_bytes);
209 log::warn!(%err, %json_str, "Failed deserializing JSON to response");
210
211 Some(U::from_err(Error::new(err).into()))
212 }
213 }
214 } else if is_complete_final {
215 None
217 } else if engine_ctx_.is_stopped() {
218 log::debug!("Request cancelled and then trying to read a response");
222 None
223 } else {
224 log::debug!("{STREAM_ERR_MSG}");
226 Some(U::from_err(Error::msg(STREAM_ERR_MSG).into()))
227 }
228 });
229
230 Ok(ResponseStream::new(Box::pin(stream), engine_ctx))
231 }
232}