dynamo_runtime/pipeline/network/egress/push.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16use async_nats::client::Client;
17use tracing as log;
18
19use super::*;
20use crate::Result;
21
22#[derive(Debug, Clone, Serialize, Deserialize)]
23#[serde(rename_all = "snake_case")]
24enum RequestType {
25    SingleIn,
26    ManyIn,
27}
28
29#[derive(Debug, Clone, Serialize, Deserialize)]
30#[serde(rename_all = "snake_case")]
31enum ResponseType {
32    SingleOut,
33    ManyOut,
34}
35
36#[derive(Debug, Clone, Serialize, Deserialize)]
37struct RequestControlMessage {
38    id: String,
39    request_type: RequestType,
40    response_type: ResponseType,
41    connection_info: ConnectionInfo,
42}
43
44pub type PushRouter<In, Out> =
45    Arc<dyn AsyncEngine<SingleIn<AddressedRequest<In>>, ManyOut<Out>, Error>>;
46
/// A request payload paired with the request-plane address it should be delivered to.
pub struct AddressedRequest<T> {
    /// The request payload, serialized to JSON when sent.
    request: T,
    /// Destination address; [`AddressedPushRouter`] uses it as the NATS request subject.
    address: String,
}
51
52impl<T> AddressedRequest<T> {
53    pub fn new(request: T, address: String) -> Self {
54        Self { request, address }
55    }
56
57    fn into_parts(self) -> (T, String) {
58        (self.request, self.address)
59    }
60}
61
62pub struct AddressedPushRouter {
63    // todo: generalize with a generic
64    req_transport: Client,
65
66    // todo: generalize with a generic
67    resp_transport: Arc<tcp::server::TcpStreamServer>,
68}
69
70impl AddressedPushRouter {
71    pub fn new(
72        req_transport: Client,
73        resp_transport: Arc<tcp::server::TcpStreamServer>,
74    ) -> Result<Arc<Self>> {
75        Ok(Arc::new(Self {
76            req_transport,
77            resp_transport,
78        }))
79    }
80}
81
82#[async_trait]
83impl<T, U> AsyncEngine<SingleIn<AddressedRequest<T>>, ManyOut<U>, Error> for AddressedPushRouter
84where
85    T: Data + Serialize,
86    U: Data + for<'de> Deserialize<'de>,
87{
88    async fn generate(&self, request: SingleIn<AddressedRequest<T>>) -> Result<ManyOut<U>, Error> {
89        let request_id = request.context().id().to_string();
90        let (addressed_request, context) = request.transfer(());
91        let (request, address) = addressed_request.into_parts();
92        let engine_ctx = context.context();
93
94        // registration options for the data plane in a singe in / many out configuration
95        let options = StreamOptions::builder()
96            .context(engine_ctx.clone())
97            .enable_request_stream(false)
98            .enable_response_stream(true)
99            .build()
100            .unwrap();
101
102        // register our needs with the data plane
103        // todo - generalize this with a generic data plane object which hides the specific transports
104        let pending_connections: PendingConnections = self.resp_transport.register(options).await;
105
106        // validate and unwrap the RegisteredStream object
107        let pending_response_stream = match pending_connections.into_parts() {
108            (None, Some(recv_stream)) => recv_stream,
109            _ => {
110                panic!("Invalid data plane registration for a SingleIn/ManyOut transport");
111            }
112        };
113
114        // separate out the the connection info and the stream provider from the registered stream
115        let (connection_info, response_stream_provider) = pending_response_stream.into_parts();
116
117        // package up the connection info as part of the "header" component of the two part message
118        // used to issue the request on the
119        // todo -- this object should be automatically created by the register call, and achieved by to the two into_parts()
120        // calls. all the information here is provided by the [`StreamOptions`] object and/or the dataplane object
121        let control_message = RequestControlMessage {
122            id: engine_ctx.id().to_string(),
123            request_type: RequestType::SingleIn,
124            response_type: ResponseType::ManyOut,
125            connection_info,
126        };
127
128        // next build the two part message where we package the connection info and the request into
129        // a single Vec<u8> that can be sent over the wire.
130        // --- package this up in the WorkQueuePublisher ---
131        let ctrl = serde_json::to_vec(&control_message)?;
132        let data = serde_json::to_vec(&request)?;
133
134        log::trace!(
135            request_id,
136            "packaging two-part message; ctrl: {} bytes, data: {} bytes",
137            ctrl.len(),
138            data.len()
139        );
140
141        let msg = TwoPartMessage::from_parts(ctrl.into(), data.into());
142
143        // the request plane / work queue should provide a two part message codec that can be used
144        // or it should take a two part message directly
145        // todo - update this
146        let codec = TwoPartCodec::default();
147        let buffer = codec.encode_message(msg)?;
148
149        // TRANSPORT ABSTRACT REQUIRED - END HERE
150
151        log::trace!(request_id, "enqueueing two-part message to nats");
152
153        // we might need to add a timeout on this if there is no subscriber to the subject; however, I think nats
154        // will handle this for us
155        let _response = self
156            .req_transport
157            .request(address.to_string(), buffer)
158            .await?;
159
160        log::trace!(request_id, "awaiting transport handshake");
161        let response_stream = response_stream_provider
162            .await
163            .map_err(|_| PipelineError::DetatchedStreamReceiver)?
164            .map_err(PipelineError::ConnectionFailed)?;
165
166        let stream = tokio_stream::wrappers::ReceiverStream::new(response_stream.rx);
167
168        let stream = stream.filter_map(|msg| async move {
169            match serde_json::from_slice::<U>(&msg) {
170                Ok(r) => Some(r),
171                Err(err) => {
172                    let json_str = String::from_utf8_lossy(&msg);
173                    log::warn!(%err, %json_str, "Failed deserializing JSON to response");
174                    None
175                }
176            }
177        });
178
179        Ok(ResponseStream::new(Box::pin(stream), engine_ctx))
180    }
181}