dynamo_runtime/pipeline/nodes/sources.rs
1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16use super::*;
17use crate::pipeline::{AsyncEngine, PipelineIO};
18
19mod base;
20mod common;
21
22pub struct Frontend<In: PipelineIO, Out: PipelineIO> {
23 edge: OnceLock<Edge<In>>,
24 sinks: Arc<Mutex<HashMap<String, oneshot::Sender<Out>>>>,
25}
26
27/// A [`ServiceFrontend`] is the interface for an [`AsyncEngine<SingleIn<Context<In>>, ManyOut<Annotated<Out>>, Error>`]
28pub struct ServiceFrontend<In: PipelineIO, Out: PipelineIO> {
29 inner: Frontend<In, Out>,
30}
31
32pub struct SegmentSource<In: PipelineIO, Out: PipelineIO> {
33 inner: Frontend<In, Out>,
34}
35
36// impl<In: DataType, Out: PipelineIO> Frontend<In, Out> {
37// pub fn new() -> Arc<Self> {
38// Arc::new(Self {
39// edge: OnceLock::new(),
40// sinks: Arc::new(Mutex::new(HashMap::new())),
41// })
42// }
43// }
44
45// impl<In: DataType, Out: PipelineIO> SegmentSource<In, Out> {
46// pub fn new() -> Arc<Self> {
47// Arc::new(Self {
48// edge: OnceLock::new(),
49// sinks: Arc::new(Mutex::new(HashMap::new())),
50// })
51// }
52// }
53
54// #[async_trait]
55// impl<In: DataType, Out: PipelineIO> Source<Context<In>> for Frontend<In, Out> {
56// async fn on_next(&self, data: Context<In>, _: private::Token) -> Result<(), PipelineError> {
57// self.edge
58// .get()
59// .ok_or(PipelineError::NoEdge)?
60// .write(data)
61// .await
62// }
63
64// fn set_edge(
65// &self,
66// edge: Edge<Context<In>>>,
67// _: private::Token,
68// ) -> Result<(), PipelineError> {
69// self.edge
70// .set(edge)
71// .map_err(|_| PipelineError::EdgeAlreadySet)?;
72// Ok(())
73// }
74// }
75
76// #[async_trait]
77// impl<In: DataType, Out: PipelineIO> Sink<PipelineStream<Out>> for Frontend<In, Out> {
78// async fn on_data(
79// &self,
80// data: PipelineStream<Out>,
81// _: private::Token,
82// ) -> Result<(), PipelineError> {
83// let context = data.context();
84
85// let mut sinks = self.sinks.lock().unwrap();
86// let tx = sinks
87// .remove(context.id())
88// .ok_or(PipelineError::DetachedStreamReceiver)
89// .map_err(|e| {
90// data.context().stop_generating();
91// e
92// })?;
93// drop(sinks);
94
95// let ctx = data.context();
96// tx.send(data)
97// .map_err(|_| PipelineError::DetachedStreamReceiver)
98// .map_err(|e| {
99// ctx.stop_generating();
100// e
101// })
102// }
103// }
104
105// impl<In: DataType, Out: PipelineIO> Link<Context<In>> for Frontend<In, Out> {
106// fn link<S: Sink<Context<In>> + 'static>(&self, sink: Arc<S>) -> Result<Arc<S>, PipelineError> {
107// let edge = Edge::new(sink.clone());
108// self.set_edge(edge.into(), private::Token {})?;
109// Ok(sink)
110// }
111// }
112
113// #[async_trait]
114// impl<In: DataType, Out: PipelineIO> AsyncEngine<Context<In>, Annotated<Out>, PipelineError>
115// for Frontend<In, Out>
116// {
117// async fn generate(&self, request: Context<In>) -> Result<PipelineStream<Out>, PipelineError> {
118// let (tx, rx) = oneshot::channel::<PipelineStream<Out>>();
119// {
120// let mut sinks = self.sinks.lock().unwrap();
121// sinks.insert(request.id().to_string(), tx);
122// }
123// self.on_next(request, private::Token {}).await?;
124// rx.await.map_err(|_| PipelineError::DetachedStreamSender)
125// }
126// }
127
128// // SegmentSource
129
130// #[async_trait]
131// impl<In: DataType, Out: PipelineIO> Source<Context<In>> for SegmentSource<In, Out> {
132// async fn on_next(&self, data: Context<In>, _: private::Token) -> Result<(), PipelineError> {
133// self.edge
134// .get()
135// .ok_or(PipelineError::NoEdge)?
136// .write(data)
137// .await
138// }
139
140// fn set_edge(
141// &self,
142// edge: Edge<Context<In>>>,
143// _: private::Token,
144// ) -> Result<(), PipelineError> {
145// self.edge
146// .set(edge)
147// .map_err(|_| PipelineError::EdgeAlreadySet)?;
148// Ok(())
149// }
150// }
151
152// #[async_trait]
153// impl<In: DataType, Out: PipelineIO> Sink<PipelineStream<Out>> for SegmentSource<In, Out> {
154// async fn on_data(
155// &self,
156// data: PipelineStream<Out>,
157// _: private::Token,
158// ) -> Result<(), PipelineError> {
159// let context = data.context();
160
161// let mut sinks = self.sinks.lock().unwrap();
162// let tx = sinks
163// .remove(context.id())
164// .ok_or(PipelineError::DetachedStreamReceiver)
165// .map_err(|e| {
166// data.context().stop_generating();
167// e
168// })?;
169// drop(sinks);
170
171// let ctx = data.context();
172// tx.send(data)
173// .map_err(|_| PipelineError::DetachedStreamReceiver)
174// .map_err(|e| {
175// ctx.stop_generating();
176// e
177// })
178// }
179// }
180
181// impl<In: DataType, Out: PipelineIO> Link<Context<In>> for SegmentSource<In, Out> {
182// fn link<S: Sink<Context<In>> + 'static>(&self, sink: Arc<S>) -> Result<Arc<S>, PipelineError> {
183// let edge = Edge::new(sink.clone());
184// self.set_edge(edge.into(), private::Token {})?;
185// Ok(sink)
186// }
187// }
188
189// #[async_trait]
190// impl<In: DataType, Out: PipelineIO> AsyncEngine<Context<In>, Annotated<Out>, PipelineError>
191// for SegmentSource<In, Out>
192// {
193// async fn generate(&self, request: Context<In>) -> Result<PipelineStream<Out>, PipelineError> {
194// let (tx, rx) = oneshot::channel::<PipelineStream<Out>>();
195// {
196// let mut sinks = self.sinks.lock().unwrap();
197// sinks.insert(request.id().to_string(), tx);
198// }
199// self.on_next(request, private::Token {}).await?;
200// rx.await.map_err(|_| PipelineError::DetachedStreamSender)
201// }
202// }
203
204// #[cfg(test)]
205
206// mod tests {
207// use super::*;
208
209// #[tokio::test]
210// async fn test_pipeline_source_no_edge() {
211// let source = Frontend::<(), ()>::new();
212// let stream = source.generate(().into()).await;
213// match stream {
214// Err(PipelineError::NoEdge) => (),
215// _ => panic!("Expected NoEdge error"),
216// }
217// }
218// }
219
220// pub struct IngressPort<In, Out: PipelineIO> {
221// edge: OnceLock<ServiceEngine<In, Out>>,
222// }
223
224// impl<In, Out> IngressPort<In, Out>
225// where
226// In: for<'de> Deserialize<'de> + DataType,
227// Out: PipelineIO + Serialize,
228// {
229// pub fn new() -> Arc<Self> {
230// Arc::new(IngressPort {
231// edge: OnceLock::new(),
232// })
233// }
234// }
235
236// #[async_trait]
237// impl<In, Out> AsyncEngine<Context<Vec<u8>>, Vec<u8>> for IngressPort<In, Out>
238// where
239// In: for<'de> Deserialize<'de> + DataType,
240// Out: PipelineIO + Serialize,
241// {
242// async fn generate(
243// &self,
244// request: Context<Vec<u8>>,
245// ) -> Result<EngineStream<Vec<u8>>, PipelineError> {
246// // Deserialize request
247// let request = request.try_map(|bytes| {
248// bincode::deserialize::<In>(&bytes)
249// .map_err(|err| PipelineError(format!("Failed to deserialize request: {}", err)))
250// })?;
251
252// // Forward request to edge
253// let stream = self
254// .edge
255// .get()
256// .ok_or(PipelineError("No engine to forward request to".to_string()))?
257// .generate(request)
258// .await?;
259
260// // Serialize response stream
261
262// let stream =
263// stream.map(|resp| bincode::serialize(&resp).expect("Failed to serialize response"));
264
265// Err(PipelineError(format!("Not implemented")))
266// }
267// }
268
269// fn convert_stream<T, U>(
270// stream: impl Stream<Item = ServerStream<T>> + Send + 'static,
271// ctx: Arc<dyn AsyncEngineContext>,
272// transform: Arc<dyn Fn(T) -> Result<U, StreamError> + Send + Sync>,
273// ) -> Pin<Box<dyn Stream<Item = ServerStream<U>> + Send>>
274// where
275// T: Send + 'static,
276// U: Send + 'static,
277// {
278// Box::pin(stream.flat_map(move |item| {
279// let ctx = ctx.clone();
280// let transform = transform.clone();
281// match item {
282// ServerStream::Data(data) => match transform(data) {
283// Ok(transformed) => futures::stream::iter(vec![ServerStream::Data(transformed)]),
284// Err(e) => {
285// // Trigger cancellation and propagate the error, followed by Sentinel
286// ctx.stop_generating();
287// futures::stream::iter(vec![ServerStream::Error(e), ServerStream::Sentinel])
288// }
289// },
290// other => futures::stream::iter(vec![other]),
291// }
292// })
293// // Use take_while to stop processing when encountering the Sentinel
294// .take_while(|item| futures::future::ready(!matches!(item, ServerStream::Sentinel))))
295// }