use bevy::ecs::schedule::{LogLevel, ScheduleBuildSettings, ScheduleLabel};
use bevy::prelude::*;
use bevy::tasks::AsyncComputeTaskPool;
use bevy::{prelude::World, time::Time};
use crossbeam_channel::Receiver;
use std::default;
use std::{collections::VecDeque, time::Duration};
/// Handle to a background task spawned by [`spawn_task`] for a worker entity.
///
/// Removed from the entity by [`finish_task_and_store_result`] once the result has been received.
#[derive(Component, Debug)]
pub struct WorkTask<T: TaskWorkerTrait + Send + Sync> {
/// `Time<Virtual>::elapsed()` at the moment the task was spawned.
pub started_at_render_time: Duration,
/// Incremented each time [`finish_task_and_store_result`] polls this task.
pub update_frames_elapsed: u32,
/// Receiving end of the channel the background task sends its single result on.
pub recv: Receiver<TaskResultRaw<T>>,
}
/// Raw output sent back over the channel by a background task.
///
/// NOTE(review): the derives add `T: Debug` / `T: Default` bounds (plus bounds on
/// `T::TaskResultPure`); confirm those are wanted.
#[derive(Debug, Default)]
pub struct TaskResultRaw<T: TaskWorkerTrait + Send + Sync> {
/// Value returned by [`TaskWorkerTrait::work`].
pub result: T::TaskResultPure,
/// Amount of simulation the task covered: `timestep * substep_count`.
pub simulated_time: Duration,
}
/// A received task result together with render-time bookkeeping captured when it arrived.
pub struct TaskResult<T: TaskWorkerTrait + Send + Sync> {
/// The value and simulated duration produced by the background task.
pub result_raw: TaskResultRaw<T>,
/// Virtual time that elapsed between spawning the task and receiving its result.
pub render_time_elapsed_during_the_simulation: Duration,
/// `Time<Virtual>::elapsed()` when the task was spawned.
pub started_at_render_time: Duration,
/// How many times the task was polled before its result was received.
pub update_frames_elapsed: u32,
}
/// FIFO queue of finished [`TaskResult`]s waiting to be written back, stored on the worker entity.
#[derive(Component)]
pub struct TaskResults<T: TaskWorkerTrait + Send + Sync> {
    pub results: VecDeque<TaskResult<T>>,
}

impl<T> Default for TaskResults<T>
where
    T: TaskWorkerTrait + Send + Sync,
{
    /// Creates an empty queue.
    fn default() -> Self {
        TaskResults {
            results: Default::default(),
        }
    }
}
#[derive(Default)]
pub struct BackgroundFixedUpdatePlugin<T: TaskWorkerTrait> {
pub phantom: std::marker::PhantomData<T>,
}
impl<T: TaskWorkerTrait> Plugin for BackgroundFixedUpdatePlugin<T> {
fn build(&self, app: &mut App) {
app.add_systems(
bevy::app::prelude::RunFixedMainLoop,
FixedMain::run_schedule::<T>,
);
app.edit_schedule(FixedMain, |schedule| {
schedule
.add_systems(HandleTask::run_schedule)
.set_build_settings(ScheduleBuildSettings {
ambiguity_detection: LogLevel::Error,
..default()
});
});
app.init_schedule(PreWriteBack);
app.edit_schedule(WriteBack, |schedule| {
schedule
.add_systems(handle_task::<T>)
.set_build_settings(ScheduleBuildSettings {
ambiguity_detection: LogLevel::Error,
..default()
});
});
app.edit_schedule(SpawnTask, |schedule| {
schedule
.add_systems((extract::<T>, spawn_task::<T>).chain())
.set_build_settings(ScheduleBuildSettings {
ambiguity_detection: LogLevel::Error,
..default()
});
});
app.edit_schedule(PostWriteBack, |schedule| {
schedule.set_build_settings(ScheduleBuildSettings {
ambiguity_detection: LogLevel::Error,
..default()
});
});
}
}
/// Accumulator that keeps the background simulation in step with render time.
#[derive(Component, Debug, Default, Reflect, Clone)]
pub struct TaskToRenderTime {
/// Seconds of render time accumulated by `FixedMain::run_schedule`, minus the simulated
/// time of every task result consumed so far.
pub diff: f64,
/// `update_frames_elapsed` of the most recently written-back task result.
pub last_task_frame_count: u32,
}
/// Length of one simulation step for a worker entity.
///
/// Declares [`SubstepCount`] as a required component.
#[derive(Component, Reflect, Clone)]
#[require(SubstepCount)]
pub struct Timestep {
    pub timestep: Duration,
}

impl Default for Timestep {
    /// Defaults to 60 steps per second.
    fn default() -> Self {
        const STEPS_PER_SECOND: f64 = 60.0;
        Timestep {
            // `recip()` is exactly `1.0 / x`.
            timestep: Duration::from_secs_f64(STEPS_PER_SECOND.recip()),
        }
    }
}
/// Number of `Timestep`s a single background task simulates.
#[derive(Component, Reflect, Clone)]
pub struct SubstepCount(pub u32);

impl default::Default for SubstepCount {
    /// One substep per task.
    fn default() -> Self {
        SubstepCount(1)
    }
}
/// Marks an entity as a background worker driven by `T`.
///
/// Declares [`TaskToRenderTime`] and [`Timestep`] as required components.
#[derive(Clone, Component)]
#[require(TaskToRenderTime, Timestep)]
pub struct TaskWorker<T: TaskWorkerTrait> {
/// The user-supplied worker; cloned into each spawned background task.
pub worker: T,
}
/// Work performed for a worker entity, split into three steps:
/// main-thread extraction, background work, and main-thread write-back.
pub trait TaskWorkerTrait: Clone + Send + Sync + 'static {
/// Input copied out of the `World` by [`TaskWorkerTrait::extract`]. It is stored as a component
/// on the worker entity and a clone is moved into the background task.
type TaskExtractedData: Clone + Send + Sync + 'static + Component;
/// Value produced by [`TaskWorkerTrait::work`] on the background task.
type TaskResultPure: Clone + Send + Sync + 'static;
/// Called with exclusive world access (by the `extract` system) to gather input for the next task.
fn extract(&self, worker_entity: Entity, world: &mut World) -> Self::TaskExtractedData;
/// Runs on the [`AsyncComputeTaskPool`]. `timestep` and `substep_count` are read from the
/// entity's [`Timestep`] and [`SubstepCount`] when the task is spawned; the result is tagged
/// with a simulated time of `timestep * substep_count`.
fn work(
&self,
worker_entity: Entity,
data: Self::TaskExtractedData,
timestep: Duration,
substep_count: u32,
) -> Self::TaskResultPure;
/// Called with exclusive world access (by `handle_task`) to apply one finished result to the world.
fn write_back(&self, worker_entity: Entity, result: TaskResult<Self>, world: &mut World);
}
/// System-set labels named after phases of the fixed main loop.
///
/// NOTE(review): not referenced anywhere in this file — confirm whether it is used elsewhere
/// or can be removed.
#[derive(SystemSet, Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum FixedMainLoop {
Before,
During,
After,
}
/// Schedule run first by `HandleTask::run_schedule`, before results are written back.
#[derive(ScheduleLabel, Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct PreWriteBack;
/// Schedule that applies queued task results to the world (`handle_task::<T>` runs here).
#[derive(ScheduleLabel, Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct WriteBack;
/// Schedule that extracts fresh input and spawns the next background task.
#[derive(ScheduleLabel, Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct SpawnTask;
/// Schedule run last by `HandleTask::run_schedule`, after the next task has been spawned.
#[derive(ScheduleLabel, Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct PostWriteBack;
/// Schedule run by `FixedMain::run_schedule` each time a task result is consumed.
///
/// NOTE(review): `bevy::prelude` may also export a `FixedMain`; this local definition takes
/// precedence over the glob import.
#[derive(Debug, Hash, PartialEq, Eq, Clone, ScheduleLabel)]
pub struct FixedMain;
impl FixedMain {
pub fn run_schedule<T: TaskWorkerTrait>(
world: &mut World,
mut has_run_at_least_once: Local<bool>,
) {
if !*has_run_at_least_once {
world.run_schedule(SpawnTask);
*has_run_at_least_once = true;
return;
}
world
.run_system_cached(finish_task_and_store_result::<T>)
.unwrap();
let clock = world.resource::<Time>().as_generic();
let mut query = world.query::<(&mut TaskToRenderTime, &Timestep, &SubstepCount)>();
let Ok((mut task_to_render_time, timestep, substep_count)) = query.get_single_mut(world)
else {
return;
};
task_to_render_time.diff += clock.delta().as_secs_f64();
if task_to_render_time.diff < (timestep.timestep.as_secs_f64() * substep_count.0 as f64) {
return;
}
let simulated_time = {
let mut query = world.query::<&TaskResults<T>>();
let task_result = query.single(world).results.front();
task_result.map(|task_result| task_result.result_raw.simulated_time)
};
let Some(simulated_time) = simulated_time else {
return;
};
let mut query = world.query::<&mut TaskToRenderTime>();
let mut task_to_render_time = query.single_mut(world);
task_to_render_time.diff -= simulated_time.as_secs_f64();
let _ = world.try_schedule_scope(FixedMain, |world, schedule| {
schedule.run(world);
});
}
}
#[derive(Debug, Hash, PartialEq, Eq, Clone, ScheduleLabel)]
pub struct HandleTask;
impl HandleTask {
pub fn run_schedule(world: &mut World) {
let _ = world.try_schedule_scope(PreWriteBack, |world, schedule| {
schedule.run(world);
});
let _ = world.try_schedule_scope(WriteBack, |world, schedule| {
schedule.run(world);
});
let _ = world.try_schedule_scope(SpawnTask, |world, schedule| {
schedule.run(world);
});
let _ = world.try_schedule_scope(PostWriteBack, |world, schedule| {
schedule.run(world);
});
}
}
/// Calls the worker's [`TaskWorkerTrait::extract`] with exclusive world access and stores the
/// returned data as a component on the worker entity, where [`spawn_task`] reads it.
///
/// Logs at `info` level and does nothing unless exactly one entity has both a
/// [`TaskWorker<T>`] and a [`Timestep`].
pub fn extract<T: TaskWorkerTrait>(world: &mut World) {
    let Ok((entity_ctx, worker)) = world
        .query_filtered::<(Entity, &TaskWorker<T>), With<Timestep>>()
        .get_single(world)
    else {
        info!("No correct entity found.");
        return;
    };
    // Clone the worker so the shared borrow of `world` ends before `extract` needs `&mut World`.
    let extractor = worker.worker.clone();
    let extracted_data = extractor.extract(entity_ctx, world);
    // The data is not used again here, so move it into the component instead of cloning it.
    world.entity_mut(entity_ctx).insert(extracted_data);
}
#[expect(clippy::type_complexity)]
pub fn spawn_task<T: TaskWorkerTrait>(
mut commands: Commands,
q_context: Query<(
Entity,
&TaskWorker<T>,
&Timestep,
&SubstepCount,
&T::TaskExtractedData,
)>,
virtual_time: Res<Time<Virtual>>,
) {
let Ok((entity_ctx, worker, timestep, substep_count, extracted_data)) = q_context.get_single()
else {
return;
};
let timestep = timestep.timestep;
let substep_count = substep_count.0;
let (sender, recv) = crossbeam_channel::unbounded();
let extracted_data = extracted_data.clone();
let worker = worker.clone();
let thread_pool = AsyncComputeTaskPool::get();
thread_pool
.spawn(async move {
let simulated_time = timestep * substep_count;
profiling::scope!("Task execution");
let result_data =
worker
.worker
.work(entity_ctx, extracted_data, timestep, substep_count);
let result = TaskResultRaw::<T> {
result: result_data,
simulated_time,
};
let _ = sender.send(result);
})
.detach();
commands.entity(entity_ctx).insert(WorkTask {
recv,
started_at_render_time: virtual_time.elapsed(),
update_frames_elapsed: 0,
});
}
/// Polls the in-flight [`WorkTask`] of the single worker entity.
///
/// Once a result has arrived, the task is removed and a [`TaskResult`] is appended to the
/// entity's [`TaskResults`] queue; the queue is inserted if the entity has none yet.
///
/// Every call counts one more frame for the task. After more than 60 frames the call blocks on
/// the channel instead of polling it. A disconnected channel is treated as "no result".
#[expect(clippy::type_complexity)]
pub fn finish_task_and_store_result<T: TaskWorkerTrait>(
    mut commands: Commands,
    time: Res<Time<Virtual>>,
    mut q_tasks: Query<(Entity, &mut WorkTask<T>, Option<&mut TaskResults<T>>)>,
) {
    // Frame count after which the main thread stops polling and waits for the worker.
    const MAX_POLLED_FRAMES: u32 = 60;

    let Ok((entity, mut work_task, existing_results)) = q_tasks.get_single_mut() else {
        return;
    };
    work_task.update_frames_elapsed += 1;

    let received = if work_task.update_frames_elapsed > MAX_POLLED_FRAMES {
        work_task.recv.recv().ok()
    } else {
        work_task.recv.try_recv().ok()
    };
    let Some(result_raw) = received else {
        return;
    };

    commands.entity(entity).remove::<WorkTask<T>>();
    let started_at = work_task.started_at_render_time;
    let finished = TaskResult::<T> {
        result_raw,
        render_time_elapsed_during_the_simulation: time.elapsed() - started_at,
        started_at_render_time: started_at,
        update_frames_elapsed: work_task.update_frames_elapsed,
    };
    match existing_results {
        Some(mut queue) => queue.results.push_back(finished),
        None => {
            let mut queue = TaskResults::<T>::default();
            queue.results.push_back(finished);
            commands.entity(entity).insert(queue);
        }
    }
}
/// Pops at most one queued [`TaskResult`] per worker entity and passes it to that worker's
/// [`TaskWorkerTrait::write_back`], recording its frame count in
/// [`TaskToRenderTime::last_task_frame_count`].
pub(crate) fn handle_task<T: TaskWorkerTrait>(world: &mut World) {
    let mut state = world.query::<(
        Entity,
        &mut TaskResults<T>,
        &TaskWorker<T>,
        &mut TaskToRenderTime,
    )>();
    // Collect first: `write_back` needs `&mut World`, which the query iteration is borrowing.
    let pending: Vec<_> = state
        .iter_mut(world)
        .filter_map(|(entity, mut queue, task_worker, mut render_sync)| {
            let next = queue.results.pop_front()?;
            render_sync.last_task_frame_count = next.update_frames_elapsed;
            Some((entity, task_worker.clone(), next))
        })
        .collect();
    for (entity, task_worker, next) in pending {
        task_worker.worker.write_back(entity, next, world);
    }
}