dynamo_runtime/pipeline/nodes/sources.rs

// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use super::*;
use crate::pipeline::{AsyncEngine, PipelineIO};

mod base;
mod common;

pub struct Frontend<In: PipelineIO, Out: PipelineIO> {
    edge: OnceLock<Edge<In>>,
    sinks: Arc<Mutex<HashMap<String, oneshot::Sender<Out>>>>,
}
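// `Frontend` is the entry node of a pipeline: `edge` holds the downstream link
// and is set exactly once (hence the `OnceLock`) when the node is linked, while
// `sinks` maps request ids to oneshot senders so that the response stream
// arriving back at this node can be routed to the caller awaiting in
// `generate` (see the commented-out impls below).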

/// A [`ServiceFrontend`] is the interface for an [`AsyncEngine<SingleIn<Context<In>>, ManyOut<Annotated<Out>>, Error>`]
pub struct ServiceFrontend<In: PipelineIO, Out: PipelineIO> {
    inner: Frontend<In, Out>,
}

pub struct SegmentSource<In: PipelineIO, Out: PipelineIO> {
    inner: Frontend<In, Out>,
}
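// `ServiceFrontend` and `SegmentSource` are thin wrappers around the same
// `Frontend` plumbing; the trait impls sketched in the comments below currently
// duplicate the `Frontend` behaviour rather than delegating to `inner`.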

// impl<In: DataType, Out: PipelineIO> Frontend<In, Out> {
//     pub fn new() -> Arc<Self> {
//         Arc::new(Self {
//             edge: OnceLock::new(),
//             sinks: Arc::new(Mutex::new(HashMap::new())),
//         })
//     }
// }

// impl<In: DataType, Out: PipelineIO> SegmentSource<In, Out> {
//     pub fn new() -> Arc<Self> {
//         Arc::new(Self {
//             edge: OnceLock::new(),
//             sinks: Arc::new(Mutex::new(HashMap::new())),
//         })
//     }
// }

// #[async_trait]
// impl<In: DataType, Out: PipelineIO> Source<Context<In>> for Frontend<In, Out> {
//     async fn on_next(&self, data: Context<In>, _: private::Token) -> Result<(), PipelineError> {
//         self.edge
//             .get()
//             .ok_or(PipelineError::NoEdge)?
//             .write(data)
//             .await
//     }

//     fn set_edge(
//         &self,
//         edge: Edge<Context<In>>,
//         _: private::Token,
//     ) -> Result<(), PipelineError> {
//         self.edge
//             .set(edge)
//             .map_err(|_| PipelineError::EdgeAlreadySet)?;
//         Ok(())
//     }
// }

// #[async_trait]
// impl<In: DataType, Out: PipelineIO> Sink<PipelineStream<Out>> for Frontend<In, Out> {
//     async fn on_data(
//         &self,
//         data: PipelineStream<Out>,
//         _: private::Token,
//     ) -> Result<(), PipelineError> {
//         let context = data.context();

//         let mut sinks = self.sinks.lock().unwrap();
//         let tx = sinks
//             .remove(context.id())
//             .ok_or(PipelineError::DetachedStreamReceiver)
//             .map_err(|e| {
//                 data.context().stop_generating();
//                 e
//             })?;
//         drop(sinks);

//         let ctx = data.context();
//         tx.send(data)
//             .map_err(|_| PipelineError::DetachedStreamReceiver)
//             .map_err(|e| {
//                 ctx.stop_generating();
//                 e
//             })
//     }
// }
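// On the sink side, `on_data` removes the oneshot sender registered under the
// stream's request id and hands the whole `PipelineStream` back through it. If
// the entry is missing, or the receiver has already been dropped, generation is
// stopped via the stream's context and `DetachedStreamReceiver` is returned.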

// impl<In: DataType, Out: PipelineIO> Link<Context<In>> for Frontend<In, Out> {
//     fn link<S: Sink<Context<In>> + 'static>(&self, sink: Arc<S>) -> Result<Arc<S>, PipelineError> {
//         let edge = Edge::new(sink.clone());
//         self.set_edge(edge.into(), private::Token {})?;
//         Ok(sink)
//     }
// }
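// Linking wraps the downstream sink in an `Edge` and stores it in the
// `OnceLock`; a second call to `link` fails with `EdgeAlreadySet`.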

// #[async_trait]
// impl<In: DataType, Out: PipelineIO> AsyncEngine<Context<In>, Annotated<Out>, PipelineError>
//     for Frontend<In, Out>
// {
//     async fn generate(&self, request: Context<In>) -> Result<PipelineStream<Out>, PipelineError> {
//         let (tx, rx) = oneshot::channel::<PipelineStream<Out>>();
//         {
//             let mut sinks = self.sinks.lock().unwrap();
//             sinks.insert(request.id().to_string(), tx);
//         }
//         self.on_next(request, private::Token {}).await?;
//         rx.await.map_err(|_| PipelineError::DetachedStreamSender)
//     }
// }
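// Putting the pieces together, `generate` registers a oneshot sender under the
// request id, pushes the request down the edge via `on_next`, and awaits the
// receiver; once the response stream reaches this node's `Sink::on_data`, it is
// sent back through that oneshot and returned to the caller.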

// // SegmentSource

// #[async_trait]
// impl<In: DataType, Out: PipelineIO> Source<Context<In>> for SegmentSource<In, Out> {
//     async fn on_next(&self, data: Context<In>, _: private::Token) -> Result<(), PipelineError> {
//         self.edge
//             .get()
//             .ok_or(PipelineError::NoEdge)?
//             .write(data)
//             .await
//     }

//     fn set_edge(
//         &self,
//         edge: Edge<Context<In>>,
//         _: private::Token,
//     ) -> Result<(), PipelineError> {
//         self.edge
//             .set(edge)
//             .map_err(|_| PipelineError::EdgeAlreadySet)?;
//         Ok(())
//     }
// }

// #[async_trait]
// impl<In: DataType, Out: PipelineIO> Sink<PipelineStream<Out>> for SegmentSource<In, Out> {
//     async fn on_data(
//         &self,
//         data: PipelineStream<Out>,
//         _: private::Token,
//     ) -> Result<(), PipelineError> {
//         let context = data.context();

//         let mut sinks = self.sinks.lock().unwrap();
//         let tx = sinks
//             .remove(context.id())
//             .ok_or(PipelineError::DetachedStreamReceiver)
//             .map_err(|e| {
//                 data.context().stop_generating();
//                 e
//             })?;
//         drop(sinks);

//         let ctx = data.context();
//         tx.send(data)
//             .map_err(|_| PipelineError::DetachedStreamReceiver)
//             .map_err(|e| {
//                 ctx.stop_generating();
//                 e
//             })
//     }
// }

// impl<In: DataType, Out: PipelineIO> Link<Context<In>> for SegmentSource<In, Out> {
//     fn link<S: Sink<Context<In>> + 'static>(&self, sink: Arc<S>) -> Result<Arc<S>, PipelineError> {
//         let edge = Edge::new(sink.clone());
//         self.set_edge(edge.into(), private::Token {})?;
//         Ok(sink)
//     }
// }

// #[async_trait]
// impl<In: DataType, Out: PipelineIO> AsyncEngine<Context<In>, Annotated<Out>, PipelineError>
//     for SegmentSource<In, Out>
// {
//     async fn generate(&self, request: Context<In>) -> Result<PipelineStream<Out>, PipelineError> {
//         let (tx, rx) = oneshot::channel::<PipelineStream<Out>>();
//         {
//             let mut sinks = self.sinks.lock().unwrap();
//             sinks.insert(request.id().to_string(), tx);
//         }
//         self.on_next(request, private::Token {}).await?;
//         rx.await.map_err(|_| PipelineError::DetachedStreamSender)
//     }
// }
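// The `SegmentSource` impls above are copies of the `Frontend` ones. If they
// are re-enabled, one option (a sketch only, assuming the trait signatures stay
// as written above and `SegmentSource` keeps its `inner: Frontend<In, Out>`
// field) would be to delegate instead of duplicating:
//
// #[async_trait]
// impl<In: PipelineIO, Out: PipelineIO> Source<Context<In>> for SegmentSource<In, Out> {
//     async fn on_next(&self, data: Context<In>, token: private::Token) -> Result<(), PipelineError> {
//         self.inner.on_next(data, token).await
//     }
//
//     fn set_edge(&self, edge: Edge<Context<In>>, token: private::Token) -> Result<(), PipelineError> {
//         self.inner.set_edge(edge, token)
//     }
// }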

// #[cfg(test)]
// mod tests {
//     use super::*;

//     #[tokio::test]
//     async fn test_pipeline_source_no_edge() {
//         let source = Frontend::<(), ()>::new();
//         let stream = source.generate(().into()).await;
//         match stream {
//             Err(PipelineError::NoEdge) => (),
//             _ => panic!("Expected NoEdge error"),
//         }
//     }
// }

// pub struct IngressPort<In, Out: PipelineIO> {
//     edge: OnceLock<ServiceEngine<In, Out>>,
// }

// impl<In, Out> IngressPort<In, Out>
// where
//     In: for<'de> Deserialize<'de> + DataType,
//     Out: PipelineIO + Serialize,
// {
//     pub fn new() -> Arc<Self> {
//         Arc::new(IngressPort {
//             edge: OnceLock::new(),
//         })
//     }
// }

// #[async_trait]
// impl<In, Out> AsyncEngine<Context<Vec<u8>>, Vec<u8>> for IngressPort<In, Out>
// where
//     In: for<'de> Deserialize<'de> + DataType,
//     Out: PipelineIO + Serialize,
// {
//     async fn generate(
//         &self,
//         request: Context<Vec<u8>>,
//     ) -> Result<EngineStream<Vec<u8>>, PipelineError> {
//         // Deserialize request
//         let request = request.try_map(|bytes| {
//             bincode::deserialize::<In>(&bytes)
//                 .map_err(|err| PipelineError(format!("Failed to deserialize request: {}", err)))
//         })?;

//         // Forward request to edge
//         let stream = self
//             .edge
//             .get()
//             .ok_or(PipelineError("No engine to forward request to".to_string()))?
//             .generate(request)
//             .await?;

//         // Serialize response stream
//         let stream =
//             stream.map(|resp| bincode::serialize(&resp).expect("Failed to serialize response"));

//         Err(PipelineError(format!("Not implemented")))
//     }
// }
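// `IngressPort` is the wire-facing entry point in this draft: it holds a typed
// `ServiceEngine` behind a `OnceLock`, deserializes incoming `Vec<u8>` requests
// into `In` with bincode, forwards them to that engine, and is meant to
// serialize the response stream back to bytes. As written, the serialization
// path is unfinished and `generate` still ends in a "Not implemented" error.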

// fn convert_stream<T, U>(
//     stream: impl Stream<Item = ServerStream<T>> + Send + 'static,
//     ctx: Arc<dyn AsyncEngineContext>,
//     transform: Arc<dyn Fn(T) -> Result<U, StreamError> + Send + Sync>,
// ) -> Pin<Box<dyn Stream<Item = ServerStream<U>> + Send>>
// where
//     T: Send + 'static,
//     U: Send + 'static,
// {
//     Box::pin(stream.flat_map(move |item| {
//         let ctx = ctx.clone();
//         let transform = transform.clone();
//         match item {
//             ServerStream::Data(data) => match transform(data) {
//                 Ok(transformed) => futures::stream::iter(vec![ServerStream::Data(transformed)]),
//                 Err(e) => {
//                     // Trigger cancellation and propagate the error, followed by Sentinel
//                     ctx.stop_generating();
//                     futures::stream::iter(vec![ServerStream::Error(e), ServerStream::Sentinel])
//                 }
//             },
//             other => futures::stream::iter(vec![other]),
//         }
//     })
//     // Use take_while to stop processing when encountering the Sentinel
//     .take_while(|item| futures::future::ready(!matches!(item, ServerStream::Sentinel))))
// }
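// Summary of the commented-out `convert_stream` above: each `Data` item is run
// through the fallible `transform`; on the first error, generation is stopped
// via the context and an `Error` item is emitted, followed by a `Sentinel`. The
// trailing `take_while` then terminates the output stream at the `Sentinel`
// without forwarding the sentinel itself; non-`Data` items pass through the
// flat_map unchanged, though a `Sentinel` from upstream likewise ends the
// stream.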