// crossflow/channel.rs

/*
 * Copyright (C) 2023 Open Source Robotics Foundation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
*/

18use bevy_ecs::{
19    prelude::{Entity, Resource, World},
20    system::Commands,
21    world::CommandQueue,
22};
23
24use tokio::sync::mpsc::{
25    unbounded_channel, UnboundedReceiver as TokioReceiver, UnboundedSender as TokioSender,
26};
27
28use std::sync::Arc;
29
30use crate::{OperationError, OperationRoster, Promise, Provider, RequestExt, StreamPack};
31
32/// Provides asynchronous access to the [`World`], allowing you to issue queries
33/// or commands and then await the result.
34#[derive(Clone)]
35pub struct Channel {
36    inner: Arc<InnerChannel>,
37}
38
39impl Channel {
40    /// Run a query in the world and receive the promise of the query's output.
41    pub fn query<P>(&self, request: P::Request, provider: P) -> Promise<P::Response>
42    where
43        P: Provider,
44        P::Request: 'static + Send + Sync,
45        P::Response: 'static + Send + Sync,
46        P::Streams: 'static + StreamPack,
47        P: 'static + Send + Sync,
48    {
49        self.command(move |commands| commands.request(request, provider).take().response)
50            .flatten()
51    }
52
53    /// Get access to a [`Commands`] for the [`World`]
54    pub fn command<F, U>(&self, f: F) -> Promise<U>
55    where
56        F: FnOnce(&mut Commands) -> U + 'static + Send,
57        U: 'static + Send,
58    {
59        let (sender, promise) = Promise::new();
60        self.inner
61            .sender
62            .send(Box::new(
63                move |world: &mut World, _: &mut OperationRoster| {
64                    let mut command_queue = CommandQueue::default();
65                    let mut commands = Commands::new(&mut command_queue, world);
66                    let u = f(&mut commands);
67                    command_queue.apply(world);
68                    let _ = sender.send(u);
69                },
70            ))
71            .ok();
72
73        promise
74    }
75
76    pub(crate) fn for_streams<Streams: StreamPack>(
77        &self,
78        world: &World,
79    ) -> Result<Streams::StreamChannels, OperationError> {
80        Ok(Streams::make_stream_channels(&self.inner, world))
81    }
82
83    pub(crate) fn new(source: Entity, session: Entity, sender: TokioSender<ChannelItem>) -> Self {
84        Self {
85            inner: Arc::new(InnerChannel {
86                source,
87                session,
88                sender,
89            }),
90        }
91    }
92}
93
94#[derive(Clone)]
95pub struct InnerChannel {
96    pub(crate) source: Entity,
97    pub(crate) session: Entity,
98    pub(crate) sender: TokioSender<ChannelItem>,
99}
100
101impl InnerChannel {
102    pub fn source(&self) -> Entity {
103        self.source
104    }
105
106    pub fn sender(&self) -> &TokioSender<ChannelItem> {
107        &self.sender
108    }
109}
110
111pub(crate) type ChannelItem = Box<dyn FnOnce(&mut World, &mut OperationRoster) + Send>;
112pub(crate) type ChannelSender = TokioSender<ChannelItem>;
113pub(crate) type ChannelReceiver = TokioReceiver<ChannelItem>;
114
115#[derive(Resource)]
116pub(crate) struct ChannelQueue {
117    pub(crate) sender: ChannelSender,
118    pub(crate) receiver: ChannelReceiver,
119}
120
121impl ChannelQueue {
122    pub(crate) fn new() -> Self {
123        let (sender, receiver) = unbounded_channel();
124        Self { sender, receiver }
125    }
126}
127
128impl Default for ChannelQueue {
129    fn default() -> Self {
130        Self::new()
131    }
132}
133
#[cfg(test)]
mod tests {
    use crate::{prelude::*, testing::*};
    use bevy_ecs::system::EntityCommands;
    use std::time::Duration;

    /// Each round asks the `repeat` service to invoke the `hello` service 5
    /// times. After 5 rounds, `hello` should have run 25 times and `repeat`
    /// 5 times.
    #[test]
    fn test_channel_request() {
        let mut context = TestingContext::minimal_plugins();

        let (hello, repeat) = context.command(|commands| {
            let hello =
                commands.spawn_service(say_hello.with(|entity: &mut EntityCommands| {
                    entity.insert((
                        Salutation("Guten tag, ".into()),
                        Name("tester".into()),
                        RunCount(0),
                    ));
                }));
            let repeat =
                commands.spawn_service(repeat_service.with(|entity: &mut EntityCommands| {
                    entity.insert(RunCount(0));
                }));
            (hello, repeat)
        });

        for _round in 0..5 {
            let mut promise = context.command(|commands| {
                let request = RepeatRequest {
                    service: hello,
                    count: 5,
                };
                commands.request(request, repeat).take().response
            });

            let conditions = FlushConditions::new().with_timeout(Duration::from_secs(5));
            context.run_with_conditions(&mut promise, conditions);

            assert!(promise.peek().is_available());
            assert!(context.no_unhandled_errors());
        }

        let world = context.app.world();
        let runs_of = |provider| world.get::<RunCount>(provider).unwrap().0;
        assert_eq!(runs_of(hello.provider()), 25);
        assert_eq!(runs_of(repeat.provider()), 5);
    }
}