dynamo_runtime/pipeline/network/egress/addressed_router.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::sync::Arc;
5
6use super::unified_client::RequestPlaneClient;
7use super::*;
8use crate::engine::{AsyncEngine, AsyncEngineContextProvider, Data};
9use crate::logging::inject_trace_headers_into_map;
10use crate::pipeline::network::ConnectionInfo;
11use crate::pipeline::network::NetworkStreamWrapper;
12use crate::pipeline::network::PendingConnections;
13use crate::pipeline::network::STREAM_ERR_MSG;
14use crate::pipeline::network::StreamOptions;
15use crate::pipeline::network::TwoPartCodec;
16use crate::pipeline::network::codec::TwoPartMessage;
17use crate::pipeline::network::tcp;
18use crate::pipeline::{ManyOut, PipelineError, ResponseStream, SingleIn};
19use crate::protocols::maybe_error::MaybeError;
20
21use anyhow::{Error, Result};
22use serde::Deserialize;
23use serde::Serialize;
24use tokio_stream::{StreamExt, StreamNotifyClose, wrappers::ReceiverStream};
25use tracing::Instrument;
26
27#[derive(Debug, Clone, Serialize, Deserialize)]
28#[serde(rename_all = "snake_case")]
29enum RequestType {
30    SingleIn,
31    ManyIn,
32}
33
34#[derive(Debug, Clone, Serialize, Deserialize)]
35#[serde(rename_all = "snake_case")]
36enum ResponseType {
37    SingleOut,
38    ManyOut,
39}
40
41#[derive(Debug, Clone, Serialize, Deserialize)]
42struct RequestControlMessage {
43    id: String,
44    request_type: RequestType,
45    response_type: ResponseType,
46    connection_info: ConnectionInfo,
47}
48
/// A request payload paired with the request-plane address it should be delivered to.
pub struct AddressedRequest<T> {
    request: T,
    address: String,
}

impl<T> AddressedRequest<T> {
    /// Pair `request` with its destination `address`.
    pub fn new(request: T, address: String) -> Self {
        AddressedRequest { address, request }
    }

    /// Consume the wrapper, yielding `(request, address)`.
    pub(crate) fn into_parts(self) -> (T, String) {
        let AddressedRequest { request, address } = self;
        (request, address)
    }
}
63
64pub struct AddressedPushRouter {
65    // Request transport (unified trait object - works with all transports)
66    req_client: Arc<dyn RequestPlaneClient>,
67
68    // Response transport (TCP streaming - unchanged)
69    resp_transport: Arc<tcp::server::TcpStreamServer>,
70}
71
72impl AddressedPushRouter {
73    /// Create a new router with a request plane client
74    ///
75    /// This is the unified constructor that works with any transport type.
76    /// The client is provided as a trait object, hiding the specific implementation.
77    pub fn new(
78        req_client: Arc<dyn RequestPlaneClient>,
79        resp_transport: Arc<tcp::server::TcpStreamServer>,
80    ) -> Result<Arc<Self>> {
81        Ok(Arc::new(Self {
82            req_client,
83            resp_transport,
84        }))
85    }
86}
87
88#[async_trait::async_trait]
89impl<T, U> AsyncEngine<SingleIn<AddressedRequest<T>>, ManyOut<U>, Error> for AddressedPushRouter
90where
91    T: Data + Serialize,
92    U: Data + for<'de> Deserialize<'de> + MaybeError,
93{
94    async fn generate(&self, request: SingleIn<AddressedRequest<T>>) -> Result<ManyOut<U>, Error> {
95        let request_id = request.context().id().to_string();
96        let (addressed_request, context) = request.transfer(());
97        let (request, address) = addressed_request.into_parts();
98        let engine_ctx = context.context();
99        let engine_ctx_ = engine_ctx.clone();
100
101        // registration options for the data plane in a singe in / many out configuration
102        let options = StreamOptions::builder()
103            .context(engine_ctx.clone())
104            .enable_request_stream(false)
105            .enable_response_stream(true)
106            .build()
107            .unwrap();
108
109        // register our needs with the data plane
110        // todo - generalize this with a generic data plane object which hides the specific transports
111        let pending_connections: PendingConnections = self.resp_transport.register(options).await;
112
113        // validate and unwrap the RegisteredStream object
114        let pending_response_stream = match pending_connections.into_parts() {
115            (None, Some(recv_stream)) => recv_stream,
116            _ => {
117                panic!("Invalid data plane registration for a SingleIn/ManyOut transport");
118            }
119        };
120
121        // separate out the connection info and the stream provider from the registered stream
122        let (connection_info, response_stream_provider) = pending_response_stream.into_parts();
123
124        // package up the connection info as part of the "header" component of the two part message
125        // used to issue the request on the
126        // todo -- this object should be automatically created by the register call, and achieved by to the two into_parts()
127        // calls. all the information here is provided by the [`StreamOptions`] object and/or the dataplane object
128        let control_message = RequestControlMessage {
129            id: engine_ctx.id().to_string(),
130            request_type: RequestType::SingleIn,
131            response_type: ResponseType::ManyOut,
132            connection_info,
133        };
134
135        // next build the two part message where we package the connection info and the request into
136        // a single Vec<u8> that can be sent over the wire.
137        // --- package this up in the WorkQueuePublisher ---
138        let ctrl = serde_json::to_vec(&control_message)?;
139        let data = serde_json::to_vec(&request)?;
140
141        tracing::trace!(
142            request_id,
143            "packaging two-part message; ctrl: {} bytes, data: {} bytes",
144            ctrl.len(),
145            data.len()
146        );
147
148        let msg = TwoPartMessage::from_parts(ctrl.into(), data.into());
149
150        // the request plane / work queue should provide a two part message codec that can be used
151        // or it should take a two part message directly
152        // todo - update this
153        let codec = TwoPartCodec::default();
154        let buffer = codec.encode_message(msg)?;
155
156        // TRANSPORT ABSTRACT REQUIRED - END HERE
157
158        // Send request using unified client interface
159        tracing::trace!(
160            request_id,
161            transport = self.req_client.transport_name(),
162            address = %address,
163            "Sending request via request plane client"
164        );
165
166        // Prepare trace headers using shared helper
167        let mut headers = std::collections::HashMap::new();
168        inject_trace_headers_into_map(&mut headers);
169
170        // Send request (works for all transport types)
171        let _response = self
172            .req_client
173            .send_request(address, buffer, headers)
174            .await?;
175
176        tracing::trace!(request_id, "awaiting transport handshake");
177        let response_stream = response_stream_provider
178            .await
179            .map_err(|_| PipelineError::DetachedStreamReceiver)?
180            .map_err(PipelineError::ConnectionFailed)?;
181
182        // TODO: Detect end-of-stream using Server-Sent Events (SSE)
183        let mut is_complete_final = false;
184        let stream = tokio_stream::StreamNotifyClose::new(
185            tokio_stream::wrappers::ReceiverStream::new(response_stream.rx),
186        )
187        .filter_map(move |res| {
188            if let Some(res_bytes) = res {
189                if is_complete_final {
190                    return Some(U::from_err(
191                        Error::msg(
192                            "Response received after generation ended - this should never happen",
193                        )
194                        .into(),
195                    ));
196                }
197                match serde_json::from_slice::<NetworkStreamWrapper<U>>(&res_bytes) {
198                    Ok(item) => {
199                        is_complete_final = item.complete_final;
200                        if let Some(data) = item.data {
201                            Some(data)
202                        } else if is_complete_final {
203                            None
204                        } else {
205                            Some(U::from_err(
206                                Error::msg("Empty response received - this should never happen")
207                                    .into(),
208                            ))
209                        }
210                    }
211                    Err(err) => {
212                        // legacy log print
213                        let json_str = String::from_utf8_lossy(&res_bytes);
214                        tracing::warn!(%err, %json_str, "Failed deserializing JSON to response");
215
216                        Some(U::from_err(Error::new(err).into()))
217                    }
218                }
219            } else if is_complete_final {
220                // end of stream
221                None
222            } else if engine_ctx_.is_stopped() {
223                // Gracefully end the stream if 'stop_generating()' was called. Do NOT check for
224                // 'is_killed()' here because it implies the stream ended abnormally which should be
225                // handled by the error branch below.
226                tracing::debug!("Request cancelled and then trying to read a response");
227                None
228            } else {
229                // stream ended unexpectedly
230                tracing::debug!("{STREAM_ERR_MSG}");
231                Some(U::from_err(Error::msg(STREAM_ERR_MSG).into()))
232            }
233        });
234
235        Ok(ResponseStream::new(Box::pin(stream), engine_ctx))
236    }
237}