dynamo_runtime/pipeline/network/egress/
push.rs1use async_nats::client::Client;
17use tracing as log;
18
19use super::*;
20use crate::Result;
21
22#[derive(Debug, Clone, Serialize, Deserialize)]
23#[serde(rename_all = "snake_case")]
24enum RequestType {
25 SingleIn,
26 ManyIn,
27}
28
29#[derive(Debug, Clone, Serialize, Deserialize)]
30#[serde(rename_all = "snake_case")]
31enum ResponseType {
32 SingleOut,
33 ManyOut,
34}
35
36#[derive(Debug, Clone, Serialize, Deserialize)]
37struct RequestControlMessage {
38 id: String,
39 request_type: RequestType,
40 response_type: ResponseType,
41 connection_info: ConnectionInfo,
42}
43
44pub type PushRouter<In, Out> =
45 Arc<dyn AsyncEngine<SingleIn<AddressedRequest<In>>, ManyOut<Out>, Error>>;
46
/// A request payload paired with the address it should be delivered to.
pub struct AddressedRequest<T> {
    /// The payload to deliver.
    request: T,
    /// Destination address for the payload.
    address: String,
}

impl<T> AddressedRequest<T> {
    /// Pairs `request` with the `address` it is destined for.
    pub fn new(request: T, address: String) -> Self {
        AddressedRequest { request, address }
    }

    /// Consumes the wrapper and returns `(request, address)`.
    fn into_parts(self) -> (T, String) {
        let Self { request, address } = self;
        (request, address)
    }
}
61
62pub struct AddressedPushRouter {
63 req_transport: Client,
65
66 resp_transport: Arc<tcp::server::TcpStreamServer>,
68}
69
70impl AddressedPushRouter {
71 pub fn new(
72 req_transport: Client,
73 resp_transport: Arc<tcp::server::TcpStreamServer>,
74 ) -> Result<Arc<Self>> {
75 Ok(Arc::new(Self {
76 req_transport,
77 resp_transport,
78 }))
79 }
80}
81
82#[async_trait]
83impl<T, U> AsyncEngine<SingleIn<AddressedRequest<T>>, ManyOut<U>, Error> for AddressedPushRouter
84where
85 T: Data + Serialize,
86 U: Data + for<'de> Deserialize<'de>,
87{
88 async fn generate(&self, request: SingleIn<AddressedRequest<T>>) -> Result<ManyOut<U>, Error> {
89 let request_id = request.context().id().to_string();
90 let (addressed_request, context) = request.transfer(());
91 let (request, address) = addressed_request.into_parts();
92 let engine_ctx = context.context();
93
94 let options = StreamOptions::builder()
96 .context(engine_ctx.clone())
97 .enable_request_stream(false)
98 .enable_response_stream(true)
99 .build()
100 .unwrap();
101
102 let pending_connections: PendingConnections = self.resp_transport.register(options).await;
105
106 let pending_response_stream = match pending_connections.into_parts() {
108 (None, Some(recv_stream)) => recv_stream,
109 _ => {
110 panic!("Invalid data plane registration for a SingleIn/ManyOut transport");
111 }
112 };
113
114 let (connection_info, response_stream_provider) = pending_response_stream.into_parts();
116
117 let control_message = RequestControlMessage {
122 id: engine_ctx.id().to_string(),
123 request_type: RequestType::SingleIn,
124 response_type: ResponseType::ManyOut,
125 connection_info,
126 };
127
128 let ctrl = serde_json::to_vec(&control_message)?;
132 let data = serde_json::to_vec(&request)?;
133
134 log::trace!(
135 request_id,
136 "packaging two-part message; ctrl: {} bytes, data: {} bytes",
137 ctrl.len(),
138 data.len()
139 );
140
141 let msg = TwoPartMessage::from_parts(ctrl.into(), data.into());
142
143 let codec = TwoPartCodec::default();
147 let buffer = codec.encode_message(msg)?;
148
149 log::trace!(request_id, "enqueueing two-part message to nats");
152
153 let _response = self
156 .req_transport
157 .request(address.to_string(), buffer)
158 .await?;
159
160 log::trace!(request_id, "awaiting transport handshake");
161 let response_stream = response_stream_provider
162 .await
163 .map_err(|_| PipelineError::DetatchedStreamReceiver)?
164 .map_err(PipelineError::ConnectionFailed)?;
165
166 let stream = tokio_stream::wrappers::ReceiverStream::new(response_stream.rx);
167
168 let stream = stream.filter_map(|msg| async move {
169 match serde_json::from_slice::<U>(&msg) {
170 Ok(r) => Some(r),
171 Err(err) => {
172 let json_str = String::from_utf8_lossy(&msg);
173 log::warn!(%err, %json_str, "Failed deserializing JSON to response");
174 None
175 }
176 }
177 });
178
179 Ok(ResponseStream::new(Box::pin(stream), engine_ctx))
180 }
181}