dynamo_runtime/pipeline/network/egress/
addressed_router.rs1use async_nats::client::Client;
5use async_nats::{HeaderMap, HeaderValue};
6use tracing as log;
7
8use super::*;
9use crate::logging::DistributedTraceContext;
10use crate::logging::get_distributed_tracing_context;
11use crate::{Result, protocols::maybe_error::MaybeError};
12use tokio_stream::{StreamExt, StreamNotifyClose, wrappers::ReceiverStream};
13use tracing::Instrument;
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
16#[serde(rename_all = "snake_case")]
17enum RequestType {
18 SingleIn,
19 ManyIn,
20}
21
22#[derive(Debug, Clone, Serialize, Deserialize)]
23#[serde(rename_all = "snake_case")]
24enum ResponseType {
25 SingleOut,
26 ManyOut,
27}
28
29#[derive(Debug, Clone, Serialize, Deserialize)]
30struct RequestControlMessage {
31 id: String,
32 request_type: RequestType,
33 response_type: ResponseType,
34 connection_info: ConnectionInfo,
35}
36
/// A request payload bundled with the address it should be delivered to.
///
/// `Debug` is derived so the wrapper can be logged or used in assertions; the
/// impl is only available when `T: Debug`.
#[derive(Debug)]
pub struct AddressedRequest<T> {
    /// The payload to deliver.
    request: T,
    /// Destination address; `AddressedPushRouter` uses it as the subject of
    /// the outgoing request.
    address: String,
}

impl<T> AddressedRequest<T> {
    /// Pairs `request` with the `address` it is destined for.
    pub fn new(request: T, address: String) -> Self {
        Self { request, address }
    }

    /// Consumes the wrapper and returns its `(request, address)` components.
    fn into_parts(self) -> (T, String) {
        (self.request, self.address)
    }
}
51
52pub struct AddressedPushRouter {
53 req_transport: Client,
55
56 resp_transport: Arc<tcp::server::TcpStreamServer>,
58}
59
60impl AddressedPushRouter {
61 pub fn new(
62 req_transport: Client,
63 resp_transport: Arc<tcp::server::TcpStreamServer>,
64 ) -> Result<Arc<Self>> {
65 Ok(Arc::new(Self {
66 req_transport,
67 resp_transport,
68 }))
69 }
70}
71
72#[async_trait]
73impl<T, U> AsyncEngine<SingleIn<AddressedRequest<T>>, ManyOut<U>, Error> for AddressedPushRouter
74where
75 T: Data + Serialize,
76 U: Data + for<'de> Deserialize<'de> + MaybeError,
77{
78 async fn generate(&self, request: SingleIn<AddressedRequest<T>>) -> Result<ManyOut<U>, Error> {
79 let request_id = request.context().id().to_string();
80 let (addressed_request, context) = request.transfer(());
81 let (request, address) = addressed_request.into_parts();
82 let engine_ctx = context.context();
83 let engine_ctx_ = engine_ctx.clone();
84
85 let options = StreamOptions::builder()
87 .context(engine_ctx.clone())
88 .enable_request_stream(false)
89 .enable_response_stream(true)
90 .build()
91 .unwrap();
92
93 let pending_connections: PendingConnections = self.resp_transport.register(options).await;
96
97 let pending_response_stream = match pending_connections.into_parts() {
99 (None, Some(recv_stream)) => recv_stream,
100 _ => {
101 panic!("Invalid data plane registration for a SingleIn/ManyOut transport");
102 }
103 };
104
105 let (connection_info, response_stream_provider) = pending_response_stream.into_parts();
107
108 let control_message = RequestControlMessage {
113 id: engine_ctx.id().to_string(),
114 request_type: RequestType::SingleIn,
115 response_type: ResponseType::ManyOut,
116 connection_info,
117 };
118
119 let ctrl = serde_json::to_vec(&control_message)?;
123 let data = serde_json::to_vec(&request)?;
124
125 log::trace!(
126 request_id,
127 "packaging two-part message; ctrl: {} bytes, data: {} bytes",
128 ctrl.len(),
129 data.len()
130 );
131
132 let msg = TwoPartMessage::from_parts(ctrl.into(), data.into());
133
134 let codec = TwoPartCodec::default();
138 let buffer = codec.encode_message(msg)?;
139
140 log::trace!(request_id, "enqueueing two-part message to nats");
143
144 let mut headers = HeaderMap::new();
149 if let Some(trace_context) = get_distributed_tracing_context() {
150 headers.insert("traceparent", trace_context.create_traceparent());
151 if let Some(tracestate) = trace_context.tracestate {
152 headers.insert("tracestate", tracestate);
153 }
154 if let Some(x_request_id) = trace_context.x_request_id {
155 headers.insert("x-request-id", x_request_id);
156 }
157 if let Some(x_dynamo_request_id) = trace_context.x_dynamo_request_id {
158 headers.insert("x-dynamo-request-id", x_dynamo_request_id);
159 }
160 }
161
162 let _response = self
165 .req_transport
166 .request_with_headers(address.to_string(), headers, buffer)
167 .await?;
168
169 log::trace!(request_id, "awaiting transport handshake");
170 let response_stream = response_stream_provider
171 .await
172 .map_err(|_| PipelineError::DetachedStreamReceiver)?
173 .map_err(PipelineError::ConnectionFailed)?;
174
175 let mut is_complete_final = false;
177 let stream = tokio_stream::StreamNotifyClose::new(
178 tokio_stream::wrappers::ReceiverStream::new(response_stream.rx),
179 )
180 .filter_map(move |res| {
181 if let Some(res_bytes) = res {
182 if is_complete_final {
183 return Some(U::from_err(
184 Error::msg(
185 "Response received after generation ended - this should never happen",
186 )
187 .into(),
188 ));
189 }
190 match serde_json::from_slice::<NetworkStreamWrapper<U>>(&res_bytes) {
191 Ok(item) => {
192 is_complete_final = item.complete_final;
193 if let Some(data) = item.data {
194 Some(data)
195 } else if is_complete_final {
196 None
197 } else {
198 Some(U::from_err(
199 Error::msg("Empty response received - this should never happen")
200 .into(),
201 ))
202 }
203 }
204 Err(err) => {
205 let json_str = String::from_utf8_lossy(&res_bytes);
207 log::warn!(%err, %json_str, "Failed deserializing JSON to response");
208
209 Some(U::from_err(Error::new(err).into()))
210 }
211 }
212 } else if is_complete_final {
213 None
215 } else if engine_ctx_.is_stopped() {
216 log::debug!("Request cancelled and then trying to read a response");
220 None
221 } else {
222 log::debug!("{STREAM_ERR_MSG}");
224 Some(U::from_err(Error::msg(STREAM_ERR_MSG).into()))
225 }
226 });
227
228 Ok(ResponseStream::new(Box::pin(stream), engine_ctx))
229 }
230}