Skip to main content

dynamo_runtime/pipeline/nodes/
sources.rs

// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
3
4use super::*;
5use crate::pipeline::{AsyncEngine, PipelineIO};
6
7mod base;
8mod common;
9
10pub struct Frontend<In: PipelineIO, Out: PipelineIO> {
11    edge: OnceLock<Edge<In>>,
12    sinks: Arc<Mutex<HashMap<String, oneshot::Sender<Out>>>>,
13}
14
15/// A [`ServiceFrontend`] is the interface for an [`AsyncEngine<SingleIn<Context<In>>, ManyOut<Annotated<Out>>, Error>`]
16pub struct ServiceFrontend<In: PipelineIO, Out: PipelineIO> {
17    inner: Frontend<In, Out>,
18}
19
20pub struct SegmentSource<In: PipelineIO, Out: PipelineIO> {
21    inner: Frontend<In, Out>,
22}
23
24// impl<In: DataType, Out: PipelineIO> Frontend<In, Out> {
25//     pub fn new() -> Arc<Self> {
26//         Arc::new(Self {
27//             edge: OnceLock::new(),
28//             sinks: Arc::new(Mutex::new(HashMap::new())),
29//         })
30//     }
31// }
32
33// impl<In: DataType, Out: PipelineIO> SegmentSource<In, Out> {
34//     pub fn new() -> Arc<Self> {
35//         Arc::new(Self {
36//             edge: OnceLock::new(),
37//             sinks: Arc::new(Mutex::new(HashMap::new())),
38//         })
39//     }
40// }
41
42// #[async_trait]
43// impl<In: DataType, Out: PipelineIO> Source<Context<In>> for Frontend<In, Out> {
44//     async fn on_next(&self, data: Context<In>, _: private::Token) -> Result<(), PipelineError> {
45//         self.edge
46//             .get()
47//             .ok_or(PipelineError::NoEdge)?
48//             .write(data)
49//             .await
50//     }
51
52//     fn set_edge(
53//         &self,
54//         edge: Edge<Context<In>>>,
55//         _: private::Token,
56//     ) -> Result<(), PipelineError> {
57//         self.edge
58//             .set(edge)
59//             .map_err(|_| PipelineError::EdgeAlreadySet)?;
60//         Ok(())
61//     }
62// }
63
64// #[async_trait]
65// impl<In: DataType, Out: PipelineIO> Sink<PipelineStream<Out>> for Frontend<In, Out> {
66//     async fn on_data(
67//         &self,
68//         data: PipelineStream<Out>,
69//         _: private::Token,
70//     ) -> Result<(), PipelineError> {
71//         let context = data.context();
72
73//         let mut sinks = self.sinks.lock().unwrap();
74//         let tx = sinks
75//             .remove(context.id())
76//             .ok_or(PipelineError::DetachedStreamReceiver)
77//             .map_err(|e| {
78//                 data.context().stop_generating();
79//                 e
80//             })?;
81//         drop(sinks);
82
83//         let ctx = data.context();
84//         tx.send(data)
85//             .map_err(|_| PipelineError::DetachedStreamReceiver)
86//             .map_err(|e| {
87//                 ctx.stop_generating();
88//                 e
89//             })
90//     }
91// }
92
93// impl<In: DataType, Out: PipelineIO> Link<Context<In>> for Frontend<In, Out> {
94//     fn link<S: Sink<Context<In>> + 'static>(&self, sink: Arc<S>) -> Result<Arc<S>, PipelineError> {
95//         let edge = Edge::new(sink.clone());
96//         self.set_edge(edge.into(), private::Token {})?;
97//         Ok(sink)
98//     }
99// }
100
101// #[async_trait]
102// impl<In: DataType, Out: PipelineIO> AsyncEngine<Context<In>, Annotated<Out>, PipelineError>
103//     for Frontend<In, Out>
104// {
105//     async fn generate(&self, request: Context<In>) -> Result<PipelineStream<Out>, PipelineError> {
106//         let (tx, rx) = oneshot::channel::<PipelineStream<Out>>();
107//         {
108//             let mut sinks = self.sinks.lock().unwrap();
109//             sinks.insert(request.id().to_string(), tx);
110//         }
111//         self.on_next(request, private::Token {}).await?;
112//         rx.await.map_err(|_| PipelineError::DetachedStreamSender)
113//     }
114// }
115
116// // SegmentSource
117
118// #[async_trait]
119// impl<In: DataType, Out: PipelineIO> Source<Context<In>> for SegmentSource<In, Out> {
120//     async fn on_next(&self, data: Context<In>, _: private::Token) -> Result<(), PipelineError> {
121//         self.edge
122//             .get()
123//             .ok_or(PipelineError::NoEdge)?
124//             .write(data)
125//             .await
126//     }
127
128//     fn set_edge(
129//         &self,
130//         edge: Edge<Context<In>>>,
131//         _: private::Token,
132//     ) -> Result<(), PipelineError> {
133//         self.edge
134//             .set(edge)
135//             .map_err(|_| PipelineError::EdgeAlreadySet)?;
136//         Ok(())
137//     }
138// }
139
140// #[async_trait]
141// impl<In: DataType, Out: PipelineIO> Sink<PipelineStream<Out>> for SegmentSource<In, Out> {
142//     async fn on_data(
143//         &self,
144//         data: PipelineStream<Out>,
145//         _: private::Token,
146//     ) -> Result<(), PipelineError> {
147//         let context = data.context();
148
149//         let mut sinks = self.sinks.lock().unwrap();
150//         let tx = sinks
151//             .remove(context.id())
152//             .ok_or(PipelineError::DetachedStreamReceiver)
153//             .map_err(|e| {
154//                 data.context().stop_generating();
155//                 e
156//             })?;
157//         drop(sinks);
158
159//         let ctx = data.context();
160//         tx.send(data)
161//             .map_err(|_| PipelineError::DetachedStreamReceiver)
162//             .map_err(|e| {
163//                 ctx.stop_generating();
164//                 e
165//             })
166//     }
167// }
168
169// impl<In: DataType, Out: PipelineIO> Link<Context<In>> for SegmentSource<In, Out> {
170//     fn link<S: Sink<Context<In>> + 'static>(&self, sink: Arc<S>) -> Result<Arc<S>, PipelineError> {
171//         let edge = Edge::new(sink.clone());
172//         self.set_edge(edge.into(), private::Token {})?;
173//         Ok(sink)
174//     }
175// }
176
177// #[async_trait]
178// impl<In: DataType, Out: PipelineIO> AsyncEngine<Context<In>, Annotated<Out>, PipelineError>
179//     for SegmentSource<In, Out>
180// {
181//     async fn generate(&self, request: Context<In>) -> Result<PipelineStream<Out>, PipelineError> {
182//         let (tx, rx) = oneshot::channel::<PipelineStream<Out>>();
183//         {
184//             let mut sinks = self.sinks.lock().unwrap();
185//             sinks.insert(request.id().to_string(), tx);
186//         }
187//         self.on_next(request, private::Token {}).await?;
188//         rx.await.map_err(|_| PipelineError::DetachedStreamSender)
189//     }
190// }
191
192// #[cfg(test)]
193
194// mod tests {
195//     use super::*;
196
197//     #[tokio::test]
198//     async fn test_pipeline_source_no_edge() {
199//         let source = Frontend::<(), ()>::new();
200//         let stream = source.generate(().into()).await;
201//         match stream {
202//             Err(PipelineError::NoEdge) => (),
203//             _ => panic!("Expected NoEdge error"),
204//         }
205//     }
206// }
207
208// pub struct IngressPort<In, Out: PipelineIO> {
209//     edge: OnceLock<ServiceEngine<In, Out>>,
210// }
211
212// impl<In, Out> IngressPort<In, Out>
213// where
214//     In: for<'de> Deserialize<'de> + DataType,
215//     Out: PipelineIO + Serialize,
216// {
217//     pub fn new() -> Arc<Self> {
218//         Arc::new(IngressPort {
219//             edge: OnceLock::new(),
220//         })
221//     }
222// }
223
224// #[async_trait]
225// impl<In, Out> AsyncEngine<Context<Vec<u8>>, Vec<u8>> for IngressPort<In, Out>
226// where
227//     In: for<'de> Deserialize<'de> + DataType,
228//     Out: PipelineIO + Serialize,
229// {
230//     async fn generate(
231//         &self,
232//         request: Context<Vec<u8>>,
233//     ) -> Result<EngineStream<Vec<u8>>, PipelineError> {
234//         // Deserialize request
235//         let request = request.try_map(|bytes| {
236//             bincode::deserialize::<In>(&bytes)
237//                 .map_err(|err| PipelineError(format!("Failed to deserialize request: {}", err)))
238//         })?;
239
240//         // Forward request to edge
241//         let stream = self
242//             .edge
243//             .get()
244//             .ok_or(PipelineError("No engine to forward request to".to_string()))?
245//             .generate(request)
246//             .await?;
247
248//         // Serialize response stream
249
250//         let stream =
251//             stream.map(|resp| bincode::serialize(&resp).expect("Failed to serialize response"));
252
253//         Err(PipelineError(format!("Not implemented")))
254//     }
255// }
256
257// fn convert_stream<T, U>(
258//     stream: impl Stream<Item = ServerStream<T>> + Send + 'static,
259//     ctx: Arc<dyn AsyncEngineContext>,
260//     transform: Arc<dyn Fn(T) -> Result<U, StreamError> + Send + Sync>,
261// ) -> Pin<Box<dyn Stream<Item = ServerStream<U>> + Send>>
262// where
263//     T: Send + 'static,
264//     U: Send + 'static,
265// {
266//     Box::pin(stream.flat_map(move |item| {
267//         let ctx = ctx.clone();
268//         let transform = transform.clone();
269//         match item {
270//             ServerStream::Data(data) => match transform(data) {
271//                 Ok(transformed) => futures::stream::iter(vec![ServerStream::Data(transformed)]),
272//                 Err(e) => {
273//                     // Trigger cancellation and propagate the error, followed by Sentinel
274//                     ctx.stop_generating();
275//                     futures::stream::iter(vec![ServerStream::Error(e), ServerStream::Sentinel])
276//                 }
277//             },
278//             other => futures::stream::iter(vec![other]),
279//         }
280//     })
281//     // Use take_while to stop processing when encountering the Sentinel
282//     .take_while(|item| futures::future::ready(!matches!(item, ServerStream::Sentinel))))
283// }