1pub use bevy_app::{App, Update};
19pub use bevy_ecs::{
20 prelude::{Commands, Component, Entity, In, IntoSystem, Local, Query, ResMut, Resource, World},
21 world::CommandQueue,
22};
23use bevy_time::TimePlugin;
24
25use thiserror::Error as ThisError;
26
27use std::collections::HashMap;
28pub use std::time::{Duration, Instant};
29
30use smallvec::SmallVec;
31
32use crate::{
33 Accessing, AddContinuousServicesExt, AnyBuffer, AsAnyBuffer, AsyncServiceInput, BlockingMap,
34 BlockingServiceInput, Buffer, BufferKey, BufferKeyLifecycle, Bufferable, Buffering, Builder,
35 ContinuousQuery, ContinuousQueueView, ContinuousService, CrossflowExecutorApp, FlushParameters,
36 GetBufferedSessionsFn, Joining, OperationError, OperationResult, OperationRoster, Promise,
37 RunCommandsOnWorldExt, Scope, Service, SpawnWorkflowExt, StreamOf, StreamPack, UnhandledErrors,
38 WorkflowSettings,
39};
40
/// Test harness that owns the [`App`] driven by the helper methods implemented
/// on this type.
pub struct TestingContext {
    /// The app under test. Public, so callers can reach the app directly.
    pub app: App,
}
44
45impl TestingContext {
46 pub fn minimal_plugins() -> Self {
49 let mut app = App::new();
50 app.add_plugins((CrossflowExecutorApp::default(), TimePlugin));
51
52 TestingContext { app }
53 }
54
55 pub fn set_flush_loop_limit(&mut self, limit: Option<usize>) {
56 self.app
57 .world_mut()
58 .get_resource_or_insert_with(FlushParameters::avoid_hanging)
59 .flush_loop_limit = limit;
60 }
61
62 pub fn avoid_hanging_flush(&mut self) {
63 self.app
64 .world_mut()
65 .insert_resource(FlushParameters::avoid_hanging());
66 }
67
68 pub fn command<U>(&mut self, f: impl FnOnce(&mut Commands) -> U) -> U {
69 self.app.world_mut().command(f)
70 }
71
72 pub fn spawn_io_workflow<Request, Response, Settings>(
75 &mut self,
76 f: impl FnOnce(Scope<Request, Response, ()>, &mut Builder) -> Settings,
77 ) -> Service<Request, Response, ()>
78 where
79 Request: 'static + Send + Sync,
80 Response: 'static + Send + Sync,
81 Settings: Into<WorkflowSettings>,
82 {
83 self.command(move |commands| commands.spawn_workflow(f))
84 }
85
86 pub fn spawn_workflow<Request, Response, Streams, Settings>(
88 &mut self,
89 f: impl FnOnce(Scope<Request, Response, Streams>, &mut Builder) -> Settings,
90 ) -> Service<Request, Response, Streams>
91 where
92 Request: 'static + Send + Sync,
93 Response: 'static + Send + Sync,
94 Streams: StreamPack,
95 Settings: Into<WorkflowSettings>,
96 {
97 self.command(move |commands| commands.spawn_workflow(f))
98 }
99
100 pub fn run(&mut self, conditions: impl Into<FlushConditions>) {
101 self.run_impl::<()>(None, conditions.into());
102 }
103
104 pub fn run_while_pending<T>(&mut self, promise: &mut Promise<T>) {
105 self.run_with_conditions(promise, FlushConditions::new());
106 }
107
108 pub fn run_with_conditions<T>(
109 &mut self,
110 promise: &mut Promise<T>,
111 conditions: impl Into<FlushConditions>,
112 ) -> bool {
113 self.run_impl(Some(promise), conditions)
114 }
115
116 fn run_impl<T>(
117 &mut self,
118 mut promise: Option<&mut Promise<T>>,
119 conditions: impl Into<FlushConditions>,
120 ) -> bool {
121 let conditions = conditions.into();
122 let t_initial = std::time::Instant::now();
123 let mut count = 0;
124 while !promise.as_mut().is_some_and(|p| !p.peek().is_pending()) {
125 if let Some(timeout) = conditions.timeout {
126 let elapsed = std::time::Instant::now() - t_initial;
127 if timeout < elapsed {
128 return false;
129 }
130 }
131
132 count += 1;
133 if let Some(count_limit) = conditions.update_count {
134 if count_limit < count {
135 return false;
136 }
137 }
138
139 self.app.update();
140 }
141
142 true
143 }
144
145 pub fn no_unhandled_errors(&self) -> bool {
146 let Some(errors) = self.app.world().get_resource::<UnhandledErrors>() else {
147 return true;
148 };
149
150 errors.is_empty()
151 }
152
153 pub fn get_unhandled_errors(&self) -> Option<&UnhandledErrors> {
154 self.app.world().get_resource::<UnhandledErrors>()
155 }
156
157 pub fn assert_no_errors(&self) {
158 assert!(
159 self.no_unhandled_errors(),
160 "{:#?}",
161 self.get_unhandled_errors(),
162 );
163 }
164
165 pub fn confirm_buffers_empty(&mut self) -> Result<(), Vec<Entity>> {
167 let mut query = self
168 .app
169 .world_mut()
170 .query::<(Entity, &GetBufferedSessionsFn)>();
171 let buffers: Vec<_> = query
172 .iter(self.app.world())
173 .map(|(e, get_sessions)| (e, get_sessions.0))
174 .collect();
175
176 let mut non_empty_buffers = Vec::new();
177 for (e, get_sessions) in buffers {
178 if !get_sessions(e, self.app.world()).is_ok_and(|s| s.is_empty()) {
179 non_empty_buffers.push(e);
180 }
181 }
182
183 if non_empty_buffers.is_empty() {
184 Ok(())
185 } else {
186 Err(non_empty_buffers)
187 }
188 }
189
190 pub fn spawn_delay<T>(&mut self, duration: Duration) -> Service<T, T, StreamOf<()>>
192 where
193 T: Clone + 'static + Send + Sync,
194 {
195 self.spawn_delayed_map(duration, |t: &T| t.clone())
196 }
197
198 pub fn spawn_delayed_map<T, U, F>(
200 &mut self,
201 duration: Duration,
202 f: F,
203 ) -> Service<T, U, StreamOf<()>>
204 where
205 T: 'static + Send + Sync,
206 U: 'static + Send + Sync,
207 F: FnMut(&T) -> U + 'static + Send + Sync,
208 {
209 self.spawn_delayed_map_with_viewer(duration, f, |_| {})
210 }
211
212 pub fn spawn_delayed_map_with_viewer<T, U, F, V>(
216 &mut self,
217 duration: Duration,
218 mut f: F,
219 mut viewer: V,
220 ) -> Service<T, U, StreamOf<()>>
221 where
222 T: 'static + Send + Sync,
223 U: 'static + Send + Sync,
224 F: FnMut(&T) -> U + 'static + Send + Sync,
225 V: FnMut(&ContinuousQueueView<T, U>) + 'static + Send + Sync,
226 {
227 self.app.spawn_continuous_service(
228 Update,
229 move |In(input): In<ContinuousService<T, U, StreamOf<()>>>,
230 mut query: ContinuousQuery<T, U, StreamOf<()>>,
231 mut timers: Local<HashMap<Entity, Instant>>| {
232 if let Some(view) = query.view(&input.key) {
233 viewer(&view);
234 }
235
236 let now = Instant::now();
240
241 query.get_mut(&input.key).unwrap().for_each(|order| {
242 let order_id = order.id();
243 let t0 = *timers.entry(order_id).or_insert_with(|| {
244 order.streams().send(());
245 now
246 });
247
248 if now - t0 > duration {
249 let u = f(order.request());
250 order.respond(u);
251 timers.remove(&order_id);
252 }
253 });
254 },
255 )
256 }
257
258 #[cfg(test)]
259 pub fn spawn_async_delayed_map<T, U, F>(&mut self, duration: Duration, f: F) -> Service<T, U>
260 where
261 T: 'static + Send + Sync,
262 U: 'static + Send + Sync,
263 F: FnOnce(T) -> U + 'static + Send + Sync + Clone,
264 {
265 use crate::AddServicesExt;
266 self.app
267 .spawn_service(move |In(input): AsyncServiceInput<T>| {
268 let f = f.clone();
269 async move {
270 let start = Instant::now();
271 let mut elapsed = start.elapsed();
272 while elapsed < duration {
273 let never = async_std::future::pending::<()>();
274 let timeout = duration - elapsed;
275 let _ = async_std::future::timeout(timeout, never).await;
276 elapsed = start.elapsed();
277 }
278 f(input.request)
279 }
280 })
281 }
282}
283
/// Optional limits for how long the app keeps being updated.
///
/// Both limits default to `None`, meaning "no limit".
#[derive(Debug, Default, Clone)]
pub struct FlushConditions {
    /// Maximum wall-clock time to keep updating, if any.
    pub timeout: Option<std::time::Duration>,
    /// Maximum number of updates to perform, if any.
    pub update_count: Option<usize>,
}

/// A bare duration is interpreted as a timeout with no update-count limit.
impl From<Duration> for FlushConditions {
    fn from(value: Duration) -> Self {
        Self {
            timeout: Some(value),
            ..Self::default()
        }
    }
}

/// A bare `usize` is interpreted as an update-count limit with no timeout.
impl From<usize> for FlushConditions {
    fn from(value: usize) -> Self {
        Self {
            update_count: Some(value),
            ..Self::default()
        }
    }
}

impl FlushConditions {
    /// Creates conditions with no limits set; same as `Default::default()`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns these conditions with the timeout replaced by `timeout`.
    pub fn with_timeout(self, timeout: std::time::Duration) -> Self {
        Self {
            timeout: Some(timeout),
            ..self
        }
    }

    /// Returns these conditions with the update-count limit replaced by `count`.
    pub fn with_update_count(self, count: usize) -> Self {
        Self {
            update_count: Some(count),
            ..self
        }
    }
}
317
/// Newtype wrapping a single `f32`.
///
/// NOTE(review): presumably carries a value that a test rejected as invalid;
/// confirm against the call sites, which are not visible here.
#[derive(Debug, Clone, Copy)]
pub struct InvalidValue(pub f32);
320
321pub fn spawn_test_entities(
322 In(input): BlockingServiceInput<usize>,
323 mut commands: Commands,
324) -> SmallVec<[Entity; 8]> {
325 let mut entities = SmallVec::new();
326 for _ in 0..input.request {
327 entities.push(commands.spawn(TestComponent).id());
328 }
329
330 entities
331}
332
/// Returns a pair holding a clone of `value` followed by `value` itself.
pub fn duplicate<T: Clone>(value: T) -> (T, T) {
    let copy = value.clone();
    (copy, value)
}
336
/// Returns `value` multiplied by two.
pub fn double(value: f64) -> f64 {
    value * 2.0
}
340
/// Returns the negation of `value`.
pub fn opposite(value: f64) -> f64 {
    std::ops::Neg::neg(value)
}
344
/// Returns the sum of the two elements of the tuple.
pub fn add(pair: (f64, f64)) -> f64 {
    let (lhs, rhs) = pair;
    lhs + rhs
}
348
/// Adds up every value yielded by `values`, starting from `0.0` and
/// accumulating in iteration order. An empty input yields `0.0`.
pub fn sum<Values: IntoIterator<Item = f64>>(values: Values) -> f64 {
    let mut total = 0.0;
    for value in values {
        total += value;
    }
    total
}
352
/// Takes a `(times, value)` pair and returns `value` repeated `times` times.
pub fn repeat_string(request: (usize, String)) -> String {
    let (times, value) = request;
    value.repeat(times)
}
356
/// Concatenates every string yielded by `values`, in iteration order.
/// An empty input yields an empty string.
pub fn concat<Values: IntoIterator<Item = String>>(values: Values) -> String {
    let mut joined = String::new();
    for piece in values {
        joined.push_str(&piece);
    }
    joined
}
360
/// Collects the bytes yielded by `values` and interprets them as UTF-8.
///
/// # Errors
///
/// Returns the error from `String::from_utf8` if the bytes are not valid
/// UTF-8.
pub fn string_from_utf8<Values: IntoIterator<Item = u8>>(
    values: Values,
) -> Result<String, std::string::FromUtf8Error> {
    let bytes: Vec<u8> = values.into_iter().collect();
    String::from_utf8(bytes)
}
366
/// Returns the uppercase form of `value`, as produced by `str::to_uppercase`.
pub fn to_uppercase(value: String) -> String {
    str::to_uppercase(&value)
}
370
/// Returns the lowercase form of `value`, as produced by `str::to_lowercase`.
pub fn to_lowercase(value: String) -> String {
    str::to_lowercase(&value)
}
374
/// A value paired with a duration.
///
/// The `wait` function in this file takes one of these, waits for `duration`
/// and then returns `value`.
#[derive(Clone, Copy, Debug)]
pub struct WaitRequest<Value> {
    /// How long to wait.
    pub duration: std::time::Duration,
    /// The payload carried by the request.
    pub value: Value,
}
380
/// Waits until at least `request.duration` has elapsed, then returns
/// `request.value`. Only compiled for tests.
///
/// The wait re-checks the elapsed time after every timed-out await and loops
/// until the full duration has passed.
#[cfg(test)]
pub async fn wait<Value>(request: WaitRequest<Value>) -> Value {
    use async_std::future;

    let WaitRequest { duration, value } = request;
    let start = Instant::now();
    loop {
        let elapsed = start.elapsed();
        if elapsed >= duration {
            break;
        }

        // Await a future that never resolves, bounded by the time still left.
        let remaining = duration - elapsed;
        let never = future::pending::<()>();
        let _ = future::timeout(remaining, never).await;
    }
    value
}
396
397pub fn print_debug<T: std::fmt::Debug>(header: impl Into<String>) -> impl Fn(BlockingMap<T>) -> T {
400 let header = header.into();
401 move |input| {
402 println!(
403 "[source: {:?}, session: {:?}] {}: {:?}",
404 input.source, input.session, header, input.request,
405 );
406 input.request
407 }
408}
409
/// Unit error type; its display message states that it exists for testing
/// purposes only.
#[derive(ThisError, Debug)]
#[error("This error is for testing purposes only")]
pub struct TestError;
413
414pub fn produce_err<T>(_: T) -> Result<T, TestError> {
417 Err(TestError)
418}
419
/// Ignores its argument and always returns `None`.
pub fn produce_none<T>(_: T) -> Option<T> {
    Option::None
}
425
/// Request type taken by `repeat_service`: a service to call and how many
/// times to call it.
pub struct RepeatRequest {
    /// The service to query; it takes `()` and returns `()`.
    pub service: Service<(), (), ()>,
    /// Number of times the service is queried.
    pub count: usize,
}
430
/// Component holding the greeting text that `say_hello` prints before the
/// name. When the provider entity has none, `say_hello` uses `"Hello, "`.
#[derive(Component)]
pub struct Salutation(pub Box<str>);
433
/// Component holding the name that `say_hello` prints after the salutation.
/// When the provider entity has none, `say_hello` uses `"world"`.
#[derive(Component)]
pub struct Name(pub Box<str>);
436
/// Component holding a counter. `say_hello` and `repeat_service` each add one
/// to it per invocation when their provider entity carries this component.
#[derive(Component)]
pub struct RunCount(pub usize);
439
440pub fn say_hello(
441 In(input): BlockingServiceInput<()>,
442 salutation_query: Query<Option<&Salutation>>,
443 name_query: Query<Option<&Name>>,
444 mut run_count: Query<Option<&mut RunCount>>,
445) {
446 let salutation = salutation_query
447 .get(input.provider)
448 .ok()
449 .flatten()
450 .map(|x| &*x.0)
451 .unwrap_or("Hello, ");
452
453 let name = name_query
454 .get(input.provider)
455 .ok()
456 .flatten()
457 .map(|x| &*x.0)
458 .unwrap_or("world");
459
460 println!("{salutation}{name}");
461
462 if let Ok(Some(mut count)) = run_count.get_mut(input.provider) {
463 count.0 += 1;
464 }
465}
466
/// Async service that queries the requested service `count` times in a row.
///
/// Before the future is created, the provider entity's `RunCount` (if it has
/// one) is incremented by one. The returned future then issues one query with
/// a `()` request through `input.channel` per iteration and awaits each one
/// before starting the next. The result of each query is discarded.
pub fn repeat_service(
    In(input): AsyncServiceInput<RepeatRequest>,
    mut run_count: Query<Option<&mut RunCount>>,
) -> impl std::future::Future<Output = ()> + 'static + Send + Sync {
    // This part runs synchronously, when the service is invoked.
    if let Ok(Some(mut count)) = run_count.get_mut(input.provider) {
        count.0 += 1;
    }

    // `input` is moved into the future as a whole.
    async move {
        for _ in 0..input.request.count {
            input.channel.query((), input.request.service).await;
        }
    }
}
481
/// Unit component with no data. `spawn_test_entities` attaches it to every
/// entity it spawns.
#[derive(Component)]
pub struct TestComponent;
484
/// Wrapper around an `i32` that derives both `Component` and `Resource`, so
/// it can be attached to an entity or stored in the world as a resource.
#[derive(Component, Resource)]
pub struct Integer {
    /// The wrapped integer.
    pub value: i32,
}
489
/// Wrapper around a [`Buffer<T>`] that derives no traits of its own.
///
/// NOTE(review): going by the name, the point is a buffer type that is not
/// `Copy`; confirm against the tests that use it.
pub struct NonCopyBuffer<T> {
    // The wrapped buffer; the trait impls in this file forward to it.
    inner: Buffer<T>,
}
496
497impl<T: 'static + Send + Sync> NonCopyBuffer<T> {
498 pub fn register_downcast() {
499 let any_interface = AnyBuffer::interface_for::<T>();
500 any_interface.register_buffer_downcast(
501 std::any::TypeId::of::<NonCopyBuffer<T>>(),
502 Box::new(|buffer| {
503 Ok(Box::new(NonCopyBuffer::<T> {
504 inner: Buffer {
505 location: buffer.location,
506 _ignore: Default::default(),
507 },
508 }))
509 }),
510 );
511 }
512}
513
514impl<T> Clone for NonCopyBuffer<T> {
515 fn clone(&self) -> Self {
516 Self { inner: self.inner }
517 }
518}
519
impl<T: 'static + Send + Sync> AsAnyBuffer for NonCopyBuffer<T> {
    /// Returns the `AnyBuffer` form of the wrapped buffer.
    fn as_any_buffer(&self) -> AnyBuffer {
        self.inner.as_any_buffer()
    }

    /// Returns the exact message type hint for `T`.
    fn message_type_hint() -> crate::MessageTypeHint {
        crate::MessageTypeHint::exact::<T>()
    }
}
529
impl<T: 'static + Send + Sync> Bufferable for NonCopyBuffer<T> {
    /// This type is its own buffer type.
    type BufferType = Self;

    /// Returns `self` unchanged; the builder is not used.
    fn into_buffer(self, _builder: &mut Builder) -> Self::BufferType {
        self
    }
}
536
/// Every method here forwards to the same-named method on the wrapped
/// `Buffer<T>`, passing the arguments through unchanged.
impl<T: 'static + Send + Sync> Buffering for NonCopyBuffer<T> {
    /// Forwards to `inner.add_listener`.
    fn add_listener(&self, listener: Entity, world: &mut World) -> OperationResult {
        self.inner.add_listener(listener, world)
    }

    /// Forwards to `inner.as_input`.
    fn as_input(&self) -> smallvec::SmallVec<[Entity; 8]> {
        self.inner.as_input()
    }

    /// Forwards to `inner.buffered_count`.
    fn buffered_count(&self, session: Entity, world: &World) -> Result<usize, OperationError> {
        self.inner.buffered_count(session, world)
    }

    /// Forwards to `inner.buffered_count_for`.
    fn buffered_count_for(
        &self,
        buffer: Entity,
        session: Entity,
        world: &World,
    ) -> Result<usize, OperationError> {
        self.inner.buffered_count_for(buffer, session, world)
    }

    /// Forwards to `inner.ensure_active_session`.
    fn ensure_active_session(&self, session: Entity, world: &mut World) -> OperationResult {
        self.inner.ensure_active_session(session, world)
    }

    /// Forwards to `inner.gate_action`.
    fn gate_action(
        &self,
        session: Entity,
        action: crate::Gate,
        world: &mut World,
        roster: &mut OperationRoster,
    ) -> OperationResult {
        self.inner.gate_action(session, action, world, roster)
    }

    /// Forwards to `inner.verify_scope`.
    fn verify_scope(&self, scope: Entity) {
        self.inner.verify_scope(scope);
    }
}
577
impl<T: 'static + Send + Sync> Joining for NonCopyBuffer<T> {
    /// Joining this buffer yields a `T`.
    type Item = T;

    /// Forwards to `inner.fetch_for_join` with the same session and world.
    fn fetch_for_join(
        &self,
        session: Entity,
        world: &mut World,
    ) -> Result<Self::Item, OperationError> {
        self.inner.fetch_for_join(session, world)
    }
}
588
impl<T: 'static + Send + Sync> Accessing for NonCopyBuffer<T> {
    /// Access to this buffer is granted through a `BufferKey<T>`.
    type Key = BufferKey<T>;

    /// Forwards to `inner.add_accessor`.
    fn add_accessor(&self, accessor: Entity, world: &mut World) -> OperationResult {
        self.inner.add_accessor(accessor, world)
    }

    /// Forwards to `inner.create_key`.
    fn create_key(&self, builder: &crate::BufferKeyBuilder) -> Self::Key {
        self.inner.create_key(builder)
    }

    /// Returns `key.deep_clone()`.
    fn deep_clone_key(key: &Self::Key) -> Self::Key {
        key.deep_clone()
    }

    /// Returns `key.is_in_use()`.
    fn is_key_in_use(key: &Self::Key) -> bool {
        key.is_in_use()
    }
}