dynamo_runtime/pipeline/nodes/sources.rs
1// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use super::*;
5use crate::pipeline::{AsyncEngine, PipelineIO};
6
7mod base;
8mod common;
9
10pub struct Frontend<In: PipelineIO, Out: PipelineIO> {
11 edge: OnceLock<Edge<In>>,
12 sinks: Arc<Mutex<HashMap<String, oneshot::Sender<Out>>>>,
13}
14
15/// A [`ServiceFrontend`] is the interface for an [`AsyncEngine<SingleIn<Context<In>>, ManyOut<Annotated<Out>>, Error>`]
16pub struct ServiceFrontend<In: PipelineIO, Out: PipelineIO> {
17 inner: Frontend<In, Out>,
18}
19
20pub struct SegmentSource<In: PipelineIO, Out: PipelineIO> {
21 inner: Frontend<In, Out>,
22}
23
24// impl<In: DataType, Out: PipelineIO> Frontend<In, Out> {
25// pub fn new() -> Arc<Self> {
26// Arc::new(Self {
27// edge: OnceLock::new(),
28// sinks: Arc::new(Mutex::new(HashMap::new())),
29// })
30// }
31// }
32
33// impl<In: DataType, Out: PipelineIO> SegmentSource<In, Out> {
34// pub fn new() -> Arc<Self> {
35// Arc::new(Self {
36// edge: OnceLock::new(),
37// sinks: Arc::new(Mutex::new(HashMap::new())),
38// })
39// }
40// }
41
42// #[async_trait]
43// impl<In: DataType, Out: PipelineIO> Source<Context<In>> for Frontend<In, Out> {
44// async fn on_next(&self, data: Context<In>, _: private::Token) -> Result<(), PipelineError> {
45// self.edge
46// .get()
47// .ok_or(PipelineError::NoEdge)?
48// .write(data)
49// .await
50// }
51
52// fn set_edge(
53// &self,
54// edge: Edge<Context<In>>,
55// _: private::Token,
56// ) -> Result<(), PipelineError> {
57// self.edge
58// .set(edge)
59// .map_err(|_| PipelineError::EdgeAlreadySet)?;
60// Ok(())
61// }
62// }
63
64// #[async_trait]
65// impl<In: DataType, Out: PipelineIO> Sink<PipelineStream<Out>> for Frontend<In, Out> {
66// async fn on_data(
67// &self,
68// data: PipelineStream<Out>,
69// _: private::Token,
70// ) -> Result<(), PipelineError> {
71// let context = data.context();
72
73// let mut sinks = self.sinks.lock().unwrap();
74// let tx = sinks
75// .remove(context.id())
76// .ok_or(PipelineError::DetachedStreamReceiver)
77// .map_err(|e| {
78// data.context().stop_generating();
79// e
80// })?;
81// drop(sinks);
82
83// let ctx = data.context();
84// tx.send(data)
85// .map_err(|_| PipelineError::DetachedStreamReceiver)
86// .map_err(|e| {
87// ctx.stop_generating();
88// e
89// })
90// }
91// }
92
93// impl<In: DataType, Out: PipelineIO> Link<Context<In>> for Frontend<In, Out> {
94// fn link<S: Sink<Context<In>> + 'static>(&self, sink: Arc<S>) -> Result<Arc<S>, PipelineError> {
95// let edge = Edge::new(sink.clone());
96// self.set_edge(edge.into(), private::Token {})?;
97// Ok(sink)
98// }
99// }
100
101// #[async_trait]
102// impl<In: DataType, Out: PipelineIO> AsyncEngine<Context<In>, Annotated<Out>, PipelineError>
103// for Frontend<In, Out>
104// {
105// async fn generate(&self, request: Context<In>) -> Result<PipelineStream<Out>, PipelineError> {
106// let (tx, rx) = oneshot::channel::<PipelineStream<Out>>();
107// {
108// let mut sinks = self.sinks.lock().unwrap();
109// sinks.insert(request.id().to_string(), tx);
110// }
111// self.on_next(request, private::Token {}).await?;
112// rx.await.map_err(|_| PipelineError::DetachedStreamSender)
113// }
114// }
115
116// // SegmentSource
117
118// #[async_trait]
119// impl<In: DataType, Out: PipelineIO> Source<Context<In>> for SegmentSource<In, Out> {
120// async fn on_next(&self, data: Context<In>, _: private::Token) -> Result<(), PipelineError> {
121// self.edge
122// .get()
123// .ok_or(PipelineError::NoEdge)?
124// .write(data)
125// .await
126// }
127
128// fn set_edge(
129// &self,
130// edge: Edge<Context<In>>,
131// _: private::Token,
132// ) -> Result<(), PipelineError> {
133// self.edge
134// .set(edge)
135// .map_err(|_| PipelineError::EdgeAlreadySet)?;
136// Ok(())
137// }
138// }
139
140// #[async_trait]
141// impl<In: DataType, Out: PipelineIO> Sink<PipelineStream<Out>> for SegmentSource<In, Out> {
142// async fn on_data(
143// &self,
144// data: PipelineStream<Out>,
145// _: private::Token,
146// ) -> Result<(), PipelineError> {
147// let context = data.context();
148
149// let mut sinks = self.sinks.lock().unwrap();
150// let tx = sinks
151// .remove(context.id())
152// .ok_or(PipelineError::DetachedStreamReceiver)
153// .map_err(|e| {
154// data.context().stop_generating();
155// e
156// })?;
157// drop(sinks);
158
159// let ctx = data.context();
160// tx.send(data)
161// .map_err(|_| PipelineError::DetachedStreamReceiver)
162// .map_err(|e| {
163// ctx.stop_generating();
164// e
165// })
166// }
167// }
168
169// impl<In: DataType, Out: PipelineIO> Link<Context<In>> for SegmentSource<In, Out> {
170// fn link<S: Sink<Context<In>> + 'static>(&self, sink: Arc<S>) -> Result<Arc<S>, PipelineError> {
171// let edge = Edge::new(sink.clone());
172// self.set_edge(edge.into(), private::Token {})?;
173// Ok(sink)
174// }
175// }
176
177// #[async_trait]
178// impl<In: DataType, Out: PipelineIO> AsyncEngine<Context<In>, Annotated<Out>, PipelineError>
179// for SegmentSource<In, Out>
180// {
181// async fn generate(&self, request: Context<In>) -> Result<PipelineStream<Out>, PipelineError> {
182// let (tx, rx) = oneshot::channel::<PipelineStream<Out>>();
183// {
184// let mut sinks = self.sinks.lock().unwrap();
185// sinks.insert(request.id().to_string(), tx);
186// }
187// self.on_next(request, private::Token {}).await?;
188// rx.await.map_err(|_| PipelineError::DetachedStreamSender)
189// }
190// }
191
192// #[cfg(test)]
193
194// mod tests {
195// use super::*;
196
197// #[tokio::test]
198// async fn test_pipeline_source_no_edge() {
199// let source = Frontend::<(), ()>::new();
200// let stream = source.generate(().into()).await;
201// match stream {
202// Err(PipelineError::NoEdge) => (),
203// _ => panic!("Expected NoEdge error"),
204// }
205// }
206// }
207
208// pub struct IngressPort<In, Out: PipelineIO> {
209// edge: OnceLock<ServiceEngine<In, Out>>,
210// }
211
212// impl<In, Out> IngressPort<In, Out>
213// where
214// In: for<'de> Deserialize<'de> + DataType,
215// Out: PipelineIO + Serialize,
216// {
217// pub fn new() -> Arc<Self> {
218// Arc::new(IngressPort {
219// edge: OnceLock::new(),
220// })
221// }
222// }
223
224// #[async_trait]
225// impl<In, Out> AsyncEngine<Context<Vec<u8>>, Vec<u8>> for IngressPort<In, Out>
226// where
227// In: for<'de> Deserialize<'de> + DataType,
228// Out: PipelineIO + Serialize,
229// {
230// async fn generate(
231// &self,
232// request: Context<Vec<u8>>,
233// ) -> Result<EngineStream<Vec<u8>>, PipelineError> {
234// // Deserialize request
235// let request = request.try_map(|bytes| {
236// bincode::deserialize::<In>(&bytes)
237// .map_err(|err| PipelineError(format!("Failed to deserialize request: {}", err)))
238// })?;
239
240// // Forward request to edge
241// let stream = self
242// .edge
243// .get()
244// .ok_or(PipelineError("No engine to forward request to".to_string()))?
245// .generate(request)
246// .await?;
247
248// // Serialize response stream
249
250// let stream =
251// stream.map(|resp| bincode::serialize(&resp).expect("Failed to serialize response"));
252
253// Err(PipelineError(format!("Not implemented")))
254// }
255// }
256
257// fn convert_stream<T, U>(
258// stream: impl Stream<Item = ServerStream<T>> + Send + 'static,
259// ctx: Arc<dyn AsyncEngineContext>,
260// transform: Arc<dyn Fn(T) -> Result<U, StreamError> + Send + Sync>,
261// ) -> Pin<Box<dyn Stream<Item = ServerStream<U>> + Send>>
262// where
263// T: Send + 'static,
264// U: Send + 'static,
265// {
266// Box::pin(stream.flat_map(move |item| {
267// let ctx = ctx.clone();
268// let transform = transform.clone();
269// match item {
270// ServerStream::Data(data) => match transform(data) {
271// Ok(transformed) => futures::stream::iter(vec![ServerStream::Data(transformed)]),
272// Err(e) => {
273// // Trigger cancellation and propagate the error, followed by Sentinel
274// ctx.stop_generating();
275// futures::stream::iter(vec![ServerStream::Error(e), ServerStream::Sentinel])
276// }
277// },
278// other => futures::stream::iter(vec![other]),
279// }
280// })
281// // Use take_while to stop processing when encountering the Sentinel
282// .take_while(|item| futures::future::ready(!matches!(item, ServerStream::Sentinel))))
283// }