crossflow/
testing.rs

1/*
2 * Copyright (C) 2024 Open Source Robotics Foundation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *
16*/
17
18pub use bevy_app::{App, Update};
19pub use bevy_ecs::{
20    prelude::{Commands, Component, Entity, In, IntoSystem, Local, Query, ResMut, Resource, World},
21    world::CommandQueue,
22};
23use bevy_time::TimePlugin;
24
25use thiserror::Error as ThisError;
26
27use std::collections::HashMap;
28pub use std::time::{Duration, Instant};
29
30use smallvec::SmallVec;
31
32use crate::{
33    Accessing, AddContinuousServicesExt, AnyBuffer, AsAnyBuffer, AsyncServiceInput, BlockingMap,
34    BlockingServiceInput, Buffer, BufferKey, BufferKeyLifecycle, Bufferable, Buffering, Builder,
35    ContinuousQuery, ContinuousQueueView, ContinuousService, CrossflowExecutorApp, FlushParameters,
36    GetBufferedSessionsFn, Joining, OperationError, OperationResult, OperationRoster, Promise,
37    RunCommandsOnWorldExt, Scope, Service, SpawnWorkflowExt, StreamOf, StreamPack, UnhandledErrors,
38    WorkflowSettings,
39};
40
41pub struct TestingContext {
42    pub app: App,
43}
44
45impl TestingContext {
46    /// Make a testing context with the minimum plugins needed for crossflow
47    /// to work properly.
48    pub fn minimal_plugins() -> Self {
49        let mut app = App::new();
50        app.add_plugins((CrossflowExecutorApp::default(), TimePlugin));
51
52        TestingContext { app }
53    }
54
55    pub fn set_flush_loop_limit(&mut self, limit: Option<usize>) {
56        self.app
57            .world_mut()
58            .get_resource_or_insert_with(FlushParameters::avoid_hanging)
59            .flush_loop_limit = limit;
60    }
61
62    pub fn avoid_hanging_flush(&mut self) {
63        self.app
64            .world_mut()
65            .insert_resource(FlushParameters::avoid_hanging());
66    }
67
68    pub fn command<U>(&mut self, f: impl FnOnce(&mut Commands) -> U) -> U {
69        self.app.world_mut().command(f)
70    }
71
72    /// Build a simple workflow with a single input and output, and no streams
73    /// or settings.
74    pub fn spawn_io_workflow<Request, Response, Settings>(
75        &mut self,
76        f: impl FnOnce(Scope<Request, Response, ()>, &mut Builder) -> Settings,
77    ) -> Service<Request, Response, ()>
78    where
79        Request: 'static + Send + Sync,
80        Response: 'static + Send + Sync,
81        Settings: Into<WorkflowSettings>,
82    {
83        self.command(move |commands| commands.spawn_workflow(f))
84    }
85
86    /// Build any kind of workflow with any settings.
87    pub fn spawn_workflow<Request, Response, Streams, Settings>(
88        &mut self,
89        f: impl FnOnce(Scope<Request, Response, Streams>, &mut Builder) -> Settings,
90    ) -> Service<Request, Response, Streams>
91    where
92        Request: 'static + Send + Sync,
93        Response: 'static + Send + Sync,
94        Streams: StreamPack,
95        Settings: Into<WorkflowSettings>,
96    {
97        self.command(move |commands| commands.spawn_workflow(f))
98    }
99
100    pub fn run(&mut self, conditions: impl Into<FlushConditions>) {
101        self.run_impl::<()>(None, conditions.into());
102    }
103
104    pub fn run_while_pending<T>(&mut self, promise: &mut Promise<T>) {
105        self.run_with_conditions(promise, FlushConditions::new());
106    }
107
108    pub fn run_with_conditions<T>(
109        &mut self,
110        promise: &mut Promise<T>,
111        conditions: impl Into<FlushConditions>,
112    ) -> bool {
113        self.run_impl(Some(promise), conditions)
114    }
115
116    fn run_impl<T>(
117        &mut self,
118        mut promise: Option<&mut Promise<T>>,
119        conditions: impl Into<FlushConditions>,
120    ) -> bool {
121        let conditions = conditions.into();
122        let t_initial = std::time::Instant::now();
123        let mut count = 0;
124        while !promise.as_mut().is_some_and(|p| !p.peek().is_pending()) {
125            if let Some(timeout) = conditions.timeout {
126                let elapsed = std::time::Instant::now() - t_initial;
127                if timeout < elapsed {
128                    return false;
129                }
130            }
131
132            count += 1;
133            if let Some(count_limit) = conditions.update_count {
134                if count_limit < count {
135                    return false;
136                }
137            }
138
139            self.app.update();
140        }
141
142        true
143    }
144
145    pub fn no_unhandled_errors(&self) -> bool {
146        let Some(errors) = self.app.world().get_resource::<UnhandledErrors>() else {
147            return true;
148        };
149
150        errors.is_empty()
151    }
152
153    pub fn get_unhandled_errors(&self) -> Option<&UnhandledErrors> {
154        self.app.world().get_resource::<UnhandledErrors>()
155    }
156
157    pub fn assert_no_errors(&self) {
158        assert!(
159            self.no_unhandled_errors(),
160            "{:#?}",
161            self.get_unhandled_errors(),
162        );
163    }
164
165    // Check that all buffers in the world are empty
166    pub fn confirm_buffers_empty(&mut self) -> Result<(), Vec<Entity>> {
167        let mut query = self
168            .app
169            .world_mut()
170            .query::<(Entity, &GetBufferedSessionsFn)>();
171        let buffers: Vec<_> = query
172            .iter(self.app.world())
173            .map(|(e, get_sessions)| (e, get_sessions.0))
174            .collect();
175
176        let mut non_empty_buffers = Vec::new();
177        for (e, get_sessions) in buffers {
178            if !get_sessions(e, self.app.world()).is_ok_and(|s| s.is_empty()) {
179                non_empty_buffers.push(e);
180            }
181        }
182
183        if non_empty_buffers.is_empty() {
184            Ok(())
185        } else {
186            Err(non_empty_buffers)
187        }
188    }
189
190    /// Create a service that passes along its inputs after a delay.
191    pub fn spawn_delay<T>(&mut self, duration: Duration) -> Service<T, T, StreamOf<()>>
192    where
193        T: Clone + 'static + Send + Sync,
194    {
195        self.spawn_delayed_map(duration, |t: &T| t.clone())
196    }
197
198    /// Create a service that applies a map to an input after a delay.
199    pub fn spawn_delayed_map<T, U, F>(
200        &mut self,
201        duration: Duration,
202        f: F,
203    ) -> Service<T, U, StreamOf<()>>
204    where
205        T: 'static + Send + Sync,
206        U: 'static + Send + Sync,
207        F: FnMut(&T) -> U + 'static + Send + Sync,
208    {
209        self.spawn_delayed_map_with_viewer(duration, f, |_| {})
210    }
211
212    /// Create a service that applies a map to an input after a delay and allows
213    /// you to view the current set of requests. Its output stream will be
214    /// triggered when the timer begins for the request.
215    pub fn spawn_delayed_map_with_viewer<T, U, F, V>(
216        &mut self,
217        duration: Duration,
218        mut f: F,
219        mut viewer: V,
220    ) -> Service<T, U, StreamOf<()>>
221    where
222        T: 'static + Send + Sync,
223        U: 'static + Send + Sync,
224        F: FnMut(&T) -> U + 'static + Send + Sync,
225        V: FnMut(&ContinuousQueueView<T, U>) + 'static + Send + Sync,
226    {
227        self.app.spawn_continuous_service(
228            Update,
229            move |In(input): In<ContinuousService<T, U, StreamOf<()>>>,
230                  mut query: ContinuousQuery<T, U, StreamOf<()>>,
231                  mut timers: Local<HashMap<Entity, Instant>>| {
232                if let Some(view) = query.view(&input.key) {
233                    viewer(&view);
234                }
235
236                // Use a single now and elapsed for the entire cycle of this
237                // system so that race conditions don't cause later orders to
238                // "finish" before earlier orders.
239                let now = Instant::now();
240
241                query.get_mut(&input.key).unwrap().for_each(|order| {
242                    let order_id = order.id();
243                    let t0 = *timers.entry(order_id).or_insert_with(|| {
244                        order.streams().send(());
245                        now
246                    });
247
248                    if now - t0 > duration {
249                        let u = f(order.request());
250                        order.respond(u);
251                        timers.remove(&order_id);
252                    }
253                });
254            },
255        )
256    }
257
258    #[cfg(test)]
259    pub fn spawn_async_delayed_map<T, U, F>(&mut self, duration: Duration, f: F) -> Service<T, U>
260    where
261        T: 'static + Send + Sync,
262        U: 'static + Send + Sync,
263        F: FnOnce(T) -> U + 'static + Send + Sync + Clone,
264    {
265        use crate::AddServicesExt;
266        self.app
267            .spawn_service(move |In(input): AsyncServiceInput<T>| {
268                let f = f.clone();
269                async move {
270                    let start = Instant::now();
271                    let mut elapsed = start.elapsed();
272                    while elapsed < duration {
273                        let never = async_std::future::pending::<()>();
274                        let timeout = duration - elapsed;
275                        let _ = async_std::future::timeout(timeout, never).await;
276                        elapsed = start.elapsed();
277                    }
278                    f(input.request)
279                }
280            })
281    }
282}
283
/// Limits that make the test harness's update loop give up.
///
/// Each limit is optional; a `None` field imposes no limit of that kind.
#[derive(Clone, Debug, Default)]
pub struct FlushConditions {
    /// Stop once this much wall-clock time has passed.
    pub timeout: Option<Duration>,
    /// Stop once this many app updates have been run.
    pub update_count: Option<usize>,
}

/// A bare [`Duration`] converts to conditions limited only by that timeout.
impl From<Duration> for FlushConditions {
    fn from(timeout: Duration) -> Self {
        FlushConditions {
            timeout: Some(timeout),
            update_count: None,
        }
    }
}

/// A bare `usize` converts to conditions limited only by that many updates.
impl From<usize> for FlushConditions {
    fn from(update_count: usize) -> Self {
        FlushConditions {
            timeout: None,
            update_count: Some(update_count),
        }
    }
}

impl FlushConditions {
    /// Conditions with no limits at all.
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the timeout, keeping any existing update count.
    pub fn with_timeout(self, timeout: Duration) -> Self {
        Self {
            timeout: Some(timeout),
            ..self
        }
    }

    /// Set the update count, keeping any existing timeout.
    pub fn with_update_count(self, count: usize) -> Self {
        Self {
            update_count: Some(count),
            ..self
        }
    }
}
317
/// Wrapper for an `f32` that a test considers invalid.
#[derive(Copy, Clone, Debug)]
pub struct InvalidValue(pub f32);
320
321pub fn spawn_test_entities(
322    In(input): BlockingServiceInput<usize>,
323    mut commands: Commands,
324) -> SmallVec<[Entity; 8]> {
325    let mut entities = SmallVec::new();
326    for _ in 0..input.request {
327        entities.push(commands.spawn(TestComponent).id());
328    }
329
330    entities
331}
332
/// Return two copies of `value`: a clone first, then the original.
pub fn duplicate<T: Clone>(value: T) -> (T, T) {
    let copy = value.clone();
    (copy, value)
}

/// Multiply `value` by two.
pub fn double(value: f64) -> f64 {
    value * 2.0
}

/// Negate `value`.
pub fn opposite(value: f64) -> f64 {
    std::ops::Neg::neg(value)
}

/// Add the two elements of a pair.
pub fn add(pair: (f64, f64)) -> f64 {
    let (lhs, rhs) = pair;
    lhs + rhs
}

/// Sum all `values`, starting from positive `0.0`. An empty input yields `0.0`.
pub fn sum<Values: IntoIterator<Item = f64>>(values: Values) -> f64 {
    let mut total = 0.0;
    for value in values {
        total += value;
    }
    total
}
352
/// Repeat the string in the pair the requested number of times.
pub fn repeat_string(request: (usize, String)) -> String {
    let (count, text) = request;
    text.repeat(count)
}

/// Concatenate all `values` in order into one string.
pub fn concat<Values: IntoIterator<Item = String>>(values: Values) -> String {
    let mut joined = String::new();
    for piece in values {
        joined.push_str(&piece);
    }
    joined
}

/// Collect `values` into a byte vector and decode it as UTF-8.
pub fn string_from_utf8<Values: IntoIterator<Item = u8>>(
    values: Values,
) -> Result<String, std::string::FromUtf8Error> {
    let bytes: Vec<u8> = values.into_iter().collect();
    String::from_utf8(bytes)
}

/// Uppercase `value`.
pub fn to_uppercase(value: String) -> String {
    str::to_uppercase(&value)
}

/// Lowercase `value`.
pub fn to_lowercase(value: String) -> String {
    str::to_lowercase(&value)
}
374
/// Request for the `wait` helper: hold on to `value` until `duration` has elapsed.
#[derive(Debug, Clone, Copy)]
pub struct WaitRequest<Value> {
    /// How long to wait before handing `value` back.
    pub duration: Duration,
    /// The value returned once the wait is over.
    pub value: Value,
}
380
/// Resolve to `request.value` once `request.duration` has elapsed.
///
/// Used to force certain branches to lose races in tests, or to validate
/// async execution with delays. The remaining time is re-measured after each
/// timeout, and the timeout is re-armed until the full duration has passed.
/// Presumably this guards against timers that fire slightly early.
#[cfg(test)]
pub async fn wait<Value>(request: WaitRequest<Value>) -> Value {
    use async_std::future;
    let WaitRequest { duration, value } = request;
    let start = Instant::now();
    loop {
        let elapsed = start.elapsed();
        if elapsed >= duration {
            break;
        }
        // `pending` never resolves, so this just sleeps for the remaining time.
        let _ = future::timeout(duration - elapsed, future::pending::<()>()).await;
    }
    value
}
396
397/// Use this to add a blocking map to the chain that simply prints a debug
398/// message and then passes the data along.
399pub fn print_debug<T: std::fmt::Debug>(header: impl Into<String>) -> impl Fn(BlockingMap<T>) -> T {
400    let header = header.into();
401    move |input| {
402        println!(
403            "[source: {:?}, session: {:?}] {}: {:?}",
404            input.source, input.session, header, input.request,
405        );
406        input.request
407    }
408}
409
410#[derive(ThisError, Debug)]
411#[error("This error is for testing purposes only")]
412pub struct TestError;
413
414/// Use this to create a blocking map that simply produces an error.
415/// Used for testing special operations for the [`Result`] type.
416pub fn produce_err<T>(_: T) -> Result<T, TestError> {
417    Err(TestError)
418}
419
/// Blocking map that discards its input and always yields [`None`].
///
/// Useful for testing the special operations available for [`Option`] outputs.
pub fn produce_none<T>(_input: T) -> Option<T> {
    Option::None
}
425
426pub struct RepeatRequest {
427    pub service: Service<(), (), ()>,
428    pub count: usize,
429}
430
431#[derive(Component)]
432pub struct Salutation(pub Box<str>);
433
434#[derive(Component)]
435pub struct Name(pub Box<str>);
436
437#[derive(Component)]
438pub struct RunCount(pub usize);
439
440pub fn say_hello(
441    In(input): BlockingServiceInput<()>,
442    salutation_query: Query<Option<&Salutation>>,
443    name_query: Query<Option<&Name>>,
444    mut run_count: Query<Option<&mut RunCount>>,
445) {
446    let salutation = salutation_query
447        .get(input.provider)
448        .ok()
449        .flatten()
450        .map(|x| &*x.0)
451        .unwrap_or("Hello, ");
452
453    let name = name_query
454        .get(input.provider)
455        .ok()
456        .flatten()
457        .map(|x| &*x.0)
458        .unwrap_or("world");
459
460    println!("{salutation}{name}");
461
462    if let Ok(Some(mut count)) = run_count.get_mut(input.provider) {
463        count.0 += 1;
464    }
465}
466
467pub fn repeat_service(
468    In(input): AsyncServiceInput<RepeatRequest>,
469    mut run_count: Query<Option<&mut RunCount>>,
470) -> impl std::future::Future<Output = ()> + 'static + Send + Sync {
471    if let Ok(Some(mut count)) = run_count.get_mut(input.provider) {
472        count.0 += 1;
473    }
474
475    async move {
476        for _ in 0..input.request.count {
477            input.channel.query((), input.request.service).await;
478        }
479    }
480}
481
482#[derive(Component)]
483pub struct TestComponent;
484
485#[derive(Component, Resource)]
486pub struct Integer {
487    pub value: i32,
488}
489
490/// This is an ordinary buffer newtype whose only purpose is to test the
491/// #[joined(noncopy_buffer)] feature. We intentionally do not implement
492/// the Copy trait for it.
493pub struct NonCopyBuffer<T> {
494    inner: Buffer<T>,
495}
496
497impl<T: 'static + Send + Sync> NonCopyBuffer<T> {
498    pub fn register_downcast() {
499        let any_interface = AnyBuffer::interface_for::<T>();
500        any_interface.register_buffer_downcast(
501            std::any::TypeId::of::<NonCopyBuffer<T>>(),
502            Box::new(|buffer| {
503                Ok(Box::new(NonCopyBuffer::<T> {
504                    inner: Buffer {
505                        location: buffer.location,
506                        _ignore: Default::default(),
507                    },
508                }))
509            }),
510        );
511    }
512}
513
514impl<T> Clone for NonCopyBuffer<T> {
515    fn clone(&self) -> Self {
516        Self { inner: self.inner }
517    }
518}
519
/// Type-erasure for `NonCopyBuffer<T>`: it erases to the wrapped buffer and
/// advertises `T` as its exact message type.
impl<T: 'static + Send + Sync> AsAnyBuffer for NonCopyBuffer<T> {
    // Forward to the wrapped buffer's own conversion.
    fn as_any_buffer(&self) -> AnyBuffer {
        self.inner.as_any_buffer()
    }

    // The message type is exactly `T`, matching the wrapped `Buffer<T>`.
    fn message_type_hint() -> crate::MessageTypeHint {
        crate::MessageTypeHint::exact::<T>()
    }
}
529
530impl<T: 'static + Send + Sync> Bufferable for NonCopyBuffer<T> {
531    type BufferType = Self;
532    fn into_buffer(self, _builder: &mut Builder) -> Self::BufferType {
533        self
534    }
535}
536
/// Every [`Buffering`] operation is forwarded unchanged to the wrapped
/// `Buffer<T>`, so for these operations `NonCopyBuffer<T>` behaves like the
/// plain buffer it wraps.
impl<T: 'static + Send + Sync> Buffering for NonCopyBuffer<T> {
    fn add_listener(&self, listener: Entity, world: &mut World) -> OperationResult {
        self.inner.add_listener(listener, world)
    }

    fn as_input(&self) -> smallvec::SmallVec<[Entity; 8]> {
        self.inner.as_input()
    }

    fn buffered_count(&self, session: Entity, world: &World) -> Result<usize, OperationError> {
        self.inner.buffered_count(session, world)
    }

    fn buffered_count_for(
        &self,
        buffer: Entity,
        session: Entity,
        world: &World,
    ) -> Result<usize, OperationError> {
        self.inner.buffered_count_for(buffer, session, world)
    }

    fn ensure_active_session(&self, session: Entity, world: &mut World) -> OperationResult {
        self.inner.ensure_active_session(session, world)
    }

    fn gate_action(
        &self,
        session: Entity,
        action: crate::Gate,
        world: &mut World,
        roster: &mut OperationRoster,
    ) -> OperationResult {
        self.inner.gate_action(session, action, world, roster)
    }

    fn verify_scope(&self, scope: Entity) {
        self.inner.verify_scope(scope);
    }
}
577
/// Joining a `NonCopyBuffer<T>` yields items of type `T`, fetched from the
/// wrapped buffer.
impl<T: 'static + Send + Sync> Joining for NonCopyBuffer<T> {
    type Item = T;
    // Forward to the wrapped buffer for the given session.
    fn fetch_for_join(
        &self,
        session: Entity,
        world: &mut World,
    ) -> Result<Self::Item, OperationError> {
        self.inner.fetch_for_join(session, world)
    }
}
588
/// Accessing a `NonCopyBuffer<T>` uses the same [`BufferKey<T>`] keys as the
/// wrapped `Buffer<T>`.
impl<T: 'static + Send + Sync> Accessing for NonCopyBuffer<T> {
    type Key = BufferKey<T>;
    // Accessor registration and key creation are forwarded to the wrapped buffer.
    fn add_accessor(&self, accessor: Entity, world: &mut World) -> OperationResult {
        self.inner.add_accessor(accessor, world)
    }

    fn create_key(&self, builder: &crate::BufferKeyBuilder) -> Self::Key {
        self.inner.create_key(builder)
    }

    // The key-level operations below delegate to the key's own methods.
    fn deep_clone_key(key: &Self::Key) -> Self::Key {
        key.deep_clone()
    }

    fn is_key_in_use(key: &Self::Key) -> bool {
        key.is_in_use()
    }
}