dynamo_runtime/pipeline/network/egress/
addressed_router.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use async_nats::client::Client;
5use async_nats::{HeaderMap, HeaderValue};
6use tracing as log;
7
8use super::*;
9use crate::logging::DistributedTraceContext;
10use crate::logging::get_distributed_tracing_context;
11use crate::{Result, protocols::maybe_error::MaybeError};
12use tokio_stream::{StreamExt, StreamNotifyClose, wrappers::ReceiverStream};
13use tracing::Instrument;
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
16#[serde(rename_all = "snake_case")]
17enum RequestType {
18    SingleIn,
19    ManyIn,
20}
21
22#[derive(Debug, Clone, Serialize, Deserialize)]
23#[serde(rename_all = "snake_case")]
24enum ResponseType {
25    SingleOut,
26    ManyOut,
27}
28
29#[derive(Debug, Clone, Serialize, Deserialize)]
30struct RequestControlMessage {
31    id: String,
32    request_type: RequestType,
33    response_type: ResponseType,
34    connection_info: ConnectionInfo,
35}
36
/// A request payload paired with the address it should be sent to.
pub struct AddressedRequest<T> {
    /// The payload to deliver.
    request: T,
    /// Destination address; `AddressedPushRouter::generate` passes it as the
    /// subject of the request-transport call.
    address: String,
}

impl<T> AddressedRequest<T> {
    /// Bundles `request` with the `address` it is destined for.
    pub fn new(request: T, address: String) -> Self {
        AddressedRequest { request, address }
    }

    /// Consumes the wrapper and returns `(request, address)`.
    fn into_parts(self) -> (T, String) {
        let Self { request, address } = self;
        (request, address)
    }
}
51
52pub struct AddressedPushRouter {
53    // todo: generalize with a generic
54    req_transport: Client,
55
56    // todo: generalize with a generic
57    resp_transport: Arc<tcp::server::TcpStreamServer>,
58}
59
60impl AddressedPushRouter {
61    pub fn new(
62        req_transport: Client,
63        resp_transport: Arc<tcp::server::TcpStreamServer>,
64    ) -> Result<Arc<Self>> {
65        Ok(Arc::new(Self {
66            req_transport,
67            resp_transport,
68        }))
69    }
70}
71
72#[async_trait]
73impl<T, U> AsyncEngine<SingleIn<AddressedRequest<T>>, ManyOut<U>, Error> for AddressedPushRouter
74where
75    T: Data + Serialize,
76    U: Data + for<'de> Deserialize<'de> + MaybeError,
77{
78    async fn generate(&self, request: SingleIn<AddressedRequest<T>>) -> Result<ManyOut<U>, Error> {
79        let request_id = request.context().id().to_string();
80        let (addressed_request, context) = request.transfer(());
81        let (request, address) = addressed_request.into_parts();
82        let engine_ctx = context.context();
83        let engine_ctx_ = engine_ctx.clone();
84
85        // registration options for the data plane in a singe in / many out configuration
86        let options = StreamOptions::builder()
87            .context(engine_ctx.clone())
88            .enable_request_stream(false)
89            .enable_response_stream(true)
90            .build()
91            .unwrap();
92
93        // register our needs with the data plane
94        // todo - generalize this with a generic data plane object which hides the specific transports
95        let pending_connections: PendingConnections = self.resp_transport.register(options).await;
96
97        // validate and unwrap the RegisteredStream object
98        let pending_response_stream = match pending_connections.into_parts() {
99            (None, Some(recv_stream)) => recv_stream,
100            _ => {
101                panic!("Invalid data plane registration for a SingleIn/ManyOut transport");
102            }
103        };
104
105        // separate out the connection info and the stream provider from the registered stream
106        let (connection_info, response_stream_provider) = pending_response_stream.into_parts();
107
108        // package up the connection info as part of the "header" component of the two part message
109        // used to issue the request on the
110        // todo -- this object should be automatically created by the register call, and achieved by to the two into_parts()
111        // calls. all the information here is provided by the [`StreamOptions`] object and/or the dataplane object
112        let control_message = RequestControlMessage {
113            id: engine_ctx.id().to_string(),
114            request_type: RequestType::SingleIn,
115            response_type: ResponseType::ManyOut,
116            connection_info,
117        };
118
119        // next build the two part message where we package the connection info and the request into
120        // a single Vec<u8> that can be sent over the wire.
121        // --- package this up in the WorkQueuePublisher ---
122        let ctrl = serde_json::to_vec(&control_message)?;
123        let data = serde_json::to_vec(&request)?;
124
125        log::trace!(
126            request_id,
127            "packaging two-part message; ctrl: {} bytes, data: {} bytes",
128            ctrl.len(),
129            data.len()
130        );
131
132        let msg = TwoPartMessage::from_parts(ctrl.into(), data.into());
133
134        // the request plane / work queue should provide a two part message codec that can be used
135        // or it should take a two part message directly
136        // todo - update this
137        let codec = TwoPartCodec::default();
138        let buffer = codec.encode_message(msg)?;
139
140        // TRANSPORT ABSTRACT REQUIRED - END HERE
141
142        log::trace!(request_id, "enqueueing two-part message to nats");
143
144        // Insert Trace Context into Headers
145        // Enables span to be created in push_endpoint before
146        // payload is parsed
147
148        let mut headers = HeaderMap::new();
149        if let Some(trace_context) = get_distributed_tracing_context() {
150            headers.insert("traceparent", trace_context.create_traceparent());
151            if let Some(tracestate) = trace_context.tracestate {
152                headers.insert("tracestate", tracestate);
153            }
154            if let Some(x_request_id) = trace_context.x_request_id {
155                headers.insert("x-request-id", x_request_id);
156            }
157            if let Some(x_dynamo_request_id) = trace_context.x_dynamo_request_id {
158                headers.insert("x-dynamo-request-id", x_dynamo_request_id);
159            }
160        }
161
162        // we might need to add a timeout on this if there is no subscriber to the subject; however, I think nats
163        // will handle this for us
164        let _response = self
165            .req_transport
166            .request_with_headers(address.to_string(), headers, buffer)
167            .await?;
168
169        log::trace!(request_id, "awaiting transport handshake");
170        let response_stream = response_stream_provider
171            .await
172            .map_err(|_| PipelineError::DetachedStreamReceiver)?
173            .map_err(PipelineError::ConnectionFailed)?;
174
175        // TODO: Detect end-of-stream using Server-Sent Events (SSE)
176        let mut is_complete_final = false;
177        let stream = tokio_stream::StreamNotifyClose::new(
178            tokio_stream::wrappers::ReceiverStream::new(response_stream.rx),
179        )
180        .filter_map(move |res| {
181            if let Some(res_bytes) = res {
182                if is_complete_final {
183                    return Some(U::from_err(
184                        Error::msg(
185                            "Response received after generation ended - this should never happen",
186                        )
187                        .into(),
188                    ));
189                }
190                match serde_json::from_slice::<NetworkStreamWrapper<U>>(&res_bytes) {
191                    Ok(item) => {
192                        is_complete_final = item.complete_final;
193                        if let Some(data) = item.data {
194                            Some(data)
195                        } else if is_complete_final {
196                            None
197                        } else {
198                            Some(U::from_err(
199                                Error::msg("Empty response received - this should never happen")
200                                    .into(),
201                            ))
202                        }
203                    }
204                    Err(err) => {
205                        // legacy log print
206                        let json_str = String::from_utf8_lossy(&res_bytes);
207                        log::warn!(%err, %json_str, "Failed deserializing JSON to response");
208
209                        Some(U::from_err(Error::new(err).into()))
210                    }
211                }
212            } else if is_complete_final {
213                // end of stream
214                None
215            } else if engine_ctx_.is_stopped() {
216                // Gracefully end the stream if 'stop_generating()' was called. Do NOT check for
217                // 'is_killed()' here because it implies the stream ended abnormally which should be
218                // handled by the error branch below.
219                log::debug!("Request cancelled and then trying to read a response");
220                None
221            } else {
222                // stream ended unexpectedly
223                log::debug!("{STREAM_ERR_MSG}");
224                Some(U::from_err(Error::msg(STREAM_ERR_MSG).into()))
225            }
226        });
227
228        Ok(ResponseStream::new(Box::pin(stream), engine_ctx))
229    }
230}