dynamo_runtime/pipeline/nodes/sources/
common.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use crate::engine::AsyncEngineContextProvider;
5
6use super::*;
7
// Generates the boilerplate for a frontend wrapper type `$frontend<In, Out>` that owns
// a single `inner: Frontend` field: a constructor plus `Source`, `Sink` and `AsyncEngine`
// impls. Every generated method simply forwards its arguments to `inner`.
macro_rules! impl_frontend {
    ($frontend:ident) => {
        impl<In, Out> $frontend<In, Out>
        where
            In: PipelineIO,
            Out: PipelineIO,
        {
            /// Builds a new instance around a default-constructed `Frontend` and
            /// returns it wrapped in an `Arc`.
            pub fn new() -> Arc<Self> {
                let inner = Frontend::default();
                Arc::new(Self { inner })
            }
        }

        // Source side: both calls are forwarded unchanged to the wrapped frontend.
        #[async_trait]
        impl<In, Out> Source<In> for $frontend<In, Out>
        where
            In: PipelineIO,
            Out: PipelineIO,
        {
            async fn on_next(&self, data: In, token: private::Token) -> Result<(), Error> {
                self.inner.on_next(data, token).await
            }

            fn set_edge(&self, edge: Edge<In>, token: private::Token) -> Result<(), PipelineError> {
                self.inner.set_edge(edge, token)
            }
        }

        // Sink side: only available when `Out` also provides an engine context.
        #[async_trait]
        impl<In, Out> Sink<Out> for $frontend<In, Out>
        where
            In: PipelineIO,
            Out: PipelineIO + AsyncEngineContextProvider,
        {
            async fn on_data(&self, data: Out, token: private::Token) -> Result<(), Error> {
                self.inner.on_data(data, token).await
            }
        }

        // Engine entry point: requires `In: Sync`; the request is handed to the
        // wrapped frontend and its result is returned as-is.
        #[async_trait]
        impl<In, Out> AsyncEngine<In, Out, Error> for $frontend<In, Out>
        where
            In: PipelineIO + Sync,
            Out: PipelineIO,
        {
            async fn generate(&self, request: In) -> Result<Out, Error> {
                self.inner.generate(request).await
            }
        }
    };
}
48
49impl_frontend!(ServiceFrontend);
50impl_frontend!(SegmentSource);
51
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pipeline::{ManyOut, PipelineErrorExt, SingleIn};

    /// Calling `generate` on a default `Frontend` that never had an edge set must
    /// fail with `PipelineError::NoEdge`.
    #[tokio::test]
    async fn test_pipeline_source_no_edge() {
        let frontend: Frontend<SingleIn<()>, ManyOut<()>> = Frontend::default();

        // The call is expected to fail; `unwrap_err` panics if it unexpectedly succeeds.
        let outcome = frontend.generate(().into()).await;
        let pipeline_err = outcome
            .unwrap_err()
            .try_into_pipeline_error()
            .unwrap();

        if !matches!(pipeline_err, PipelineError::NoEdge) {
            panic!("Expected NoEdge error");
        }
    }
}