dynamo_runtime/pipeline/network/egress/addressed_router.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use async_nats::client::Client;
5use async_nats::{HeaderMap, HeaderValue};
6use tracing as log;
7
8use super::*;
9use crate::logging::DistributedTraceContext;
10use crate::logging::get_distributed_tracing_context;
11use crate::logging::inject_otel_context_into_nats_headers;
12use crate::{Result, protocols::maybe_error::MaybeError};
13use tokio_stream::{StreamExt, StreamNotifyClose, wrappers::ReceiverStream};
14use tracing::Instrument;
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
17#[serde(rename_all = "snake_case")]
18enum RequestType {
19    SingleIn,
20    ManyIn,
21}
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
24#[serde(rename_all = "snake_case")]
25enum ResponseType {
26    SingleOut,
27    ManyOut,
28}
29
30#[derive(Debug, Clone, Serialize, Deserialize)]
31struct RequestControlMessage {
32    id: String,
33    request_type: RequestType,
34    response_type: ResponseType,
35    connection_info: ConnectionInfo,
36}
37
/// A request payload paired with the request-plane address it should be delivered to.
///
/// In this file the address is used as the NATS subject for the outgoing request.
pub struct AddressedRequest<T> {
    request: T,
    address: String,
}

impl<T> AddressedRequest<T> {
    /// Wrap `request` so it can be routed to `address`.
    pub fn new(request: T, address: String) -> Self {
        AddressedRequest { address, request }
    }

    /// Consume the wrapper and hand back `(request, address)`.
    fn into_parts(self) -> (T, String) {
        let AddressedRequest { request, address } = self;
        (request, address)
    }
}
52
53pub struct AddressedPushRouter {
54    // todo: generalize with a generic
55    req_transport: Client,
56
57    // todo: generalize with a generic
58    resp_transport: Arc<tcp::server::TcpStreamServer>,
59}
60
61impl AddressedPushRouter {
62    pub fn new(
63        req_transport: Client,
64        resp_transport: Arc<tcp::server::TcpStreamServer>,
65    ) -> Result<Arc<Self>> {
66        Ok(Arc::new(Self {
67            req_transport,
68            resp_transport,
69        }))
70    }
71}
72
73#[async_trait]
74impl<T, U> AsyncEngine<SingleIn<AddressedRequest<T>>, ManyOut<U>, Error> for AddressedPushRouter
75where
76    T: Data + Serialize,
77    U: Data + for<'de> Deserialize<'de> + MaybeError,
78{
79    async fn generate(&self, request: SingleIn<AddressedRequest<T>>) -> Result<ManyOut<U>, Error> {
80        let request_id = request.context().id().to_string();
81        let (addressed_request, context) = request.transfer(());
82        let (request, address) = addressed_request.into_parts();
83        let engine_ctx = context.context();
84        let engine_ctx_ = engine_ctx.clone();
85
86        // registration options for the data plane in a singe in / many out configuration
87        let options = StreamOptions::builder()
88            .context(engine_ctx.clone())
89            .enable_request_stream(false)
90            .enable_response_stream(true)
91            .build()
92            .unwrap();
93
94        // register our needs with the data plane
95        // todo - generalize this with a generic data plane object which hides the specific transports
96        let pending_connections: PendingConnections = self.resp_transport.register(options).await;
97
98        // validate and unwrap the RegisteredStream object
99        let pending_response_stream = match pending_connections.into_parts() {
100            (None, Some(recv_stream)) => recv_stream,
101            _ => {
102                panic!("Invalid data plane registration for a SingleIn/ManyOut transport");
103            }
104        };
105
106        // separate out the connection info and the stream provider from the registered stream
107        let (connection_info, response_stream_provider) = pending_response_stream.into_parts();
108
109        // package up the connection info as part of the "header" component of the two part message
110        // used to issue the request on the
111        // todo -- this object should be automatically created by the register call, and achieved by to the two into_parts()
112        // calls. all the information here is provided by the [`StreamOptions`] object and/or the dataplane object
113        let control_message = RequestControlMessage {
114            id: engine_ctx.id().to_string(),
115            request_type: RequestType::SingleIn,
116            response_type: ResponseType::ManyOut,
117            connection_info,
118        };
119
120        // next build the two part message where we package the connection info and the request into
121        // a single Vec<u8> that can be sent over the wire.
122        // --- package this up in the WorkQueuePublisher ---
123        let ctrl = serde_json::to_vec(&control_message)?;
124        let data = serde_json::to_vec(&request)?;
125
126        log::trace!(
127            request_id,
128            "packaging two-part message; ctrl: {} bytes, data: {} bytes",
129            ctrl.len(),
130            data.len()
131        );
132
133        let msg = TwoPartMessage::from_parts(ctrl.into(), data.into());
134
135        // the request plane / work queue should provide a two part message codec that can be used
136        // or it should take a two part message directly
137        // todo - update this
138        let codec = TwoPartCodec::default();
139        let buffer = codec.encode_message(msg)?;
140
141        // TRANSPORT ABSTRACT REQUIRED - END HERE
142
143        log::trace!(request_id, "enqueueing two-part message to nats");
144
145        // Insert Trace Context into Headers
146        // Enables span to be created in push_endpoint before
147        // payload is parsed
148
149        // Prepare trace headers using the OpenTelemetry injector pattern
150        // This handles traceparent and tracestate headers according to W3C Trace Context standard
151        let mut headers = HeaderMap::new();
152        inject_otel_context_into_nats_headers(&mut headers, None);
153
154        // Add additional custom headers that aren't handled by the OpenTelemetry propagator
155        if let Some(trace_context) = get_distributed_tracing_context() {
156            if let Some(x_request_id) = trace_context.x_request_id {
157                headers.insert("x-request-id", x_request_id);
158            }
159            if let Some(x_dynamo_request_id) = trace_context.x_dynamo_request_id {
160                headers.insert("x-dynamo-request-id", x_dynamo_request_id);
161            }
162        }
163
164        // we might need to add a timeout on this if there is no subscriber to the subject; however, I think nats
165        // will handle this for us
166        let _response = self
167            .req_transport
168            .request_with_headers(address.to_string(), headers, buffer)
169            .await?;
170
171        log::trace!(request_id, "awaiting transport handshake");
172        let response_stream = response_stream_provider
173            .await
174            .map_err(|_| PipelineError::DetachedStreamReceiver)?
175            .map_err(PipelineError::ConnectionFailed)?;
176
177        // TODO: Detect end-of-stream using Server-Sent Events (SSE)
178        let mut is_complete_final = false;
179        let stream = tokio_stream::StreamNotifyClose::new(
180            tokio_stream::wrappers::ReceiverStream::new(response_stream.rx),
181        )
182        .filter_map(move |res| {
183            if let Some(res_bytes) = res {
184                if is_complete_final {
185                    return Some(U::from_err(
186                        Error::msg(
187                            "Response received after generation ended - this should never happen",
188                        )
189                        .into(),
190                    ));
191                }
192                match serde_json::from_slice::<NetworkStreamWrapper<U>>(&res_bytes) {
193                    Ok(item) => {
194                        is_complete_final = item.complete_final;
195                        if let Some(data) = item.data {
196                            Some(data)
197                        } else if is_complete_final {
198                            None
199                        } else {
200                            Some(U::from_err(
201                                Error::msg("Empty response received - this should never happen")
202                                    .into(),
203                            ))
204                        }
205                    }
206                    Err(err) => {
207                        // legacy log print
208                        let json_str = String::from_utf8_lossy(&res_bytes);
209                        log::warn!(%err, %json_str, "Failed deserializing JSON to response");
210
211                        Some(U::from_err(Error::new(err).into()))
212                    }
213                }
214            } else if is_complete_final {
215                // end of stream
216                None
217            } else if engine_ctx_.is_stopped() {
218                // Gracefully end the stream if 'stop_generating()' was called. Do NOT check for
219                // 'is_killed()' here because it implies the stream ended abnormally which should be
220                // handled by the error branch below.
221                log::debug!("Request cancelled and then trying to read a response");
222                None
223            } else {
224                // stream ended unexpectedly
225                log::debug!("{STREAM_ERR_MSG}");
226                Some(U::from_err(Error::msg(STREAM_ERR_MSG).into()))
227            }
228        });
229
230        Ok(ResponseStream::new(Box::pin(stream), engine_ctx))
231    }
232}