dynamo_runtime/pipeline/nodes/sources/
common.rs

// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
15
16use crate::engine::AsyncEngineContextProvider;
17
18use super::*;
19
/// Implements the shared frontend plumbing for a wrapper type `$name<In, Out>`.
///
/// The target type must be a struct whose only field is `inner`, initialised
/// from `Frontend::default()`. The expansion generates:
///
/// - an inherent `new()` that returns the wrapper inside an `Arc`;
/// - `Source<In>`: `on_next` and `set_edge` forward to `inner`;
/// - `Sink<Out>` (only when `Out: AsyncEngineContextProvider`): `on_data`
///   forwards to `inner`;
/// - `AsyncEngine<In, Out, Error>` (only when `In: Sync`): `generate`
///   forwards to `inner`.
macro_rules! impl_frontend {
    ($name:ident) => {
        impl<In, Out> $name<In, Out>
        where
            In: PipelineIO,
            Out: PipelineIO,
        {
            /// Builds a wrapper around a default `Frontend` and returns it in an `Arc`.
            pub fn new() -> Arc<Self> {
                let wrapper = Self {
                    inner: Frontend::default(),
                };
                Arc::new(wrapper)
            }
        }

        #[async_trait]
        impl<In, Out> Source<In> for $name<In, Out>
        where
            In: PipelineIO,
            Out: PipelineIO,
        {
            async fn on_next(&self, item: In, token: private::Token) -> Result<(), Error> {
                self.inner.on_next(item, token).await
            }

            fn set_edge(&self, edge: Edge<In>, token: private::Token) -> Result<(), PipelineError> {
                self.inner.set_edge(edge, token)
            }
        }

        #[async_trait]
        impl<In, Out> Sink<Out> for $name<In, Out>
        where
            In: PipelineIO,
            Out: PipelineIO + AsyncEngineContextProvider,
        {
            async fn on_data(&self, output: Out, token: private::Token) -> Result<(), Error> {
                self.inner.on_data(output, token).await
            }
        }

        #[async_trait]
        impl<In, Out> AsyncEngine<In, Out, Error> for $name<In, Out>
        where
            In: PipelineIO + Sync,
            Out: PipelineIO,
        {
            async fn generate(&self, input: In) -> Result<Out, Error> {
                self.inner.generate(input).await
            }
        }
    };
}
60
61impl_frontend!(ServiceFrontend);
62impl_frontend!(SegmentSource);
63
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pipeline::{ManyOut, PipelineErrorExt, SingleIn};

    /// A bare `Frontend` with no downstream edge must reject `generate`
    /// with `PipelineError::NoEdge`.
    #[tokio::test]
    async fn test_pipeline_source_no_edge() {
        let source = Frontend::<SingleIn<()>, ManyOut<()>>::default();
        let err = source
            .generate(().into())
            .await
            .unwrap_err()
            .try_into_pipeline_error()
            .unwrap();

        assert!(matches!(err, PipelineError::NoEdge), "Expected NoEdge error");
    }

    /// `ServiceFrontend` (generated by `impl_frontend!`) forwards `generate` to
    /// its inner default `Frontend`, so an unconnected instance must report the
    /// same `NoEdge` error.
    #[tokio::test]
    async fn test_service_frontend_no_edge() {
        let frontend = ServiceFrontend::<SingleIn<()>, ManyOut<()>>::new();
        let err = frontend
            .generate(().into())
            .await
            .unwrap_err()
            .try_into_pipeline_error()
            .unwrap();

        assert!(matches!(err, PipelineError::NoEdge), "Expected NoEdge error");
    }

    /// Same delegation check for `SegmentSource` (also generated by `impl_frontend!`).
    #[tokio::test]
    async fn test_segment_source_no_edge() {
        let source = SegmentSource::<SingleIn<()>, ManyOut<()>>::new();
        let err = source
            .generate(().into())
            .await
            .unwrap_err()
            .try_into_pipeline_error()
            .unwrap();

        assert!(matches!(err, PipelineError::NoEdge), "Expected NoEdge error");
    }
}