Skip to main content

oxiphysics_gpu/
scheduler.rs

1// Copyright 2026 COOLJAPAN OU (Team KitaSan)
2// SPDX-License-Identifier: Apache-2.0
3
4//! GPU workload scheduler.
5//!
6//! Provides task graphs, topological scheduling, resource barriers, async
7//! compute simulation, frame graphs, and timestamp queries — all CPU-side.
8
9#![allow(dead_code)]
10#![allow(missing_docs)]
11
12use std::collections::{HashMap, HashSet, VecDeque};
13
14// ---------------------------------------------------------------------------
15// TaskPriority
16// ---------------------------------------------------------------------------
17
/// Scheduling priority of a compute task.
///
/// Every variant carries an explicit discriminant and the derived `Ord`
/// compares those values, so a larger number means more urgent work:
/// `Background < Low < Normal < High < RealTime`. The default is `Normal`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
pub enum TaskPriority {
    /// Background processing with no frame deadline.
    Background = 0,
    /// Can be deferred to future frames if time is tight.
    Low = 1,
    /// Standard work.
    #[default]
    Normal = 2,
    /// High-importance work that should not be deferred.
    High = 3,
    /// Must complete within the current frame; highest urgency.
    RealTime = 4,
}
33
34// ---------------------------------------------------------------------------
35// ComputeTask
36// ---------------------------------------------------------------------------
37
38/// A single GPU compute dispatch.
39#[derive(Debug, Clone)]
40pub struct ComputeTask {
41    /// Unique task name.
42    pub name: String,
43    /// Workgroup size in (x, y, z).
44    pub workgroup_size: [u32; 3],
45    /// Dispatch count in (x, y, z).
46    pub dispatch_count: [u32; 3],
47    /// Names of tasks that must complete before this one.
48    pub dependencies: Vec<String>,
49    /// Priority for the scheduler.
50    pub priority: TaskPriority,
51    /// Estimated execution time in milliseconds.
52    pub estimated_ms: f64,
53}
54
55impl ComputeTask {
56    /// Create a simple 1-D compute task.
57    pub fn new_1d(name: impl Into<String>, dispatch_x: u32) -> Self {
58        Self {
59            name: name.into(),
60            workgroup_size: [64, 1, 1],
61            dispatch_count: [dispatch_x, 1, 1],
62            dependencies: vec![],
63            priority: TaskPriority::Normal,
64            estimated_ms: 1.0,
65        }
66    }
67
68    /// Create a 2-D compute task.
69    pub fn new_2d(name: impl Into<String>, dispatch_x: u32, dispatch_y: u32) -> Self {
70        Self {
71            name: name.into(),
72            workgroup_size: [8, 8, 1],
73            dispatch_count: [dispatch_x, dispatch_y, 1],
74            dependencies: vec![],
75            priority: TaskPriority::Normal,
76            estimated_ms: 1.0,
77        }
78    }
79
80    /// Total number of workgroup invocations.
81    pub fn total_workgroups(&self) -> u64 {
82        self.dispatch_count[0] as u64
83            * self.dispatch_count[1] as u64
84            * self.dispatch_count[2] as u64
85    }
86
87    /// Total number of shader invocations.
88    pub fn total_invocations(&self) -> u64 {
89        self.total_workgroups()
90            * self.workgroup_size[0] as u64
91            * self.workgroup_size[1] as u64
92            * self.workgroup_size[2] as u64
93    }
94
95    /// Add a dependency by name.
96    pub fn depends_on(mut self, dep: impl Into<String>) -> Self {
97        self.dependencies.push(dep.into());
98        self
99    }
100
101    /// Set priority.
102    pub fn with_priority(mut self, priority: TaskPriority) -> Self {
103        self.priority = priority;
104        self
105    }
106
107    /// Set estimated execution time.
108    pub fn with_estimated_ms(mut self, ms: f64) -> Self {
109        self.estimated_ms = ms;
110        self
111    }
112}
113
114// ---------------------------------------------------------------------------
115// TaskGraph
116// ---------------------------------------------------------------------------
117
118/// A directed acyclic graph of compute tasks.
119#[derive(Debug, Clone, Default)]
120pub struct TaskGraph {
121    /// All tasks keyed by name.
122    tasks: HashMap<String, ComputeTask>,
123}
124
125impl TaskGraph {
126    /// Create an empty task graph.
127    pub fn new() -> Self {
128        Self::default()
129    }
130
131    /// Add a task to the graph.  Replaces any existing task with the same name.
132    pub fn add_task(&mut self, task: ComputeTask) {
133        self.tasks.insert(task.name.clone(), task);
134    }
135
136    /// Remove a task by name.
137    pub fn remove_task(&mut self, name: &str) {
138        self.tasks.remove(name);
139    }
140
141    /// Number of tasks.
142    pub fn len(&self) -> usize {
143        self.tasks.len()
144    }
145
146    /// True when there are no tasks.
147    pub fn is_empty(&self) -> bool {
148        self.tasks.is_empty()
149    }
150
151    /// Topological sort using Kahn's algorithm.
152    ///
153    /// Returns `Ok(order)` where `order` is a valid execution order, or
154    /// `Err(cycle)` naming one task involved in a cycle.
155    pub fn topological_sort(&self) -> Result<Vec<String>, String> {
156        // Build adjacency and in-degree maps
157        let mut in_degree: HashMap<&str, usize> = HashMap::new();
158        let mut rev: HashMap<&str, Vec<&str>> = HashMap::new(); // task -> tasks that depend on it
159
160        for (name, task) in &self.tasks {
161            in_degree.entry(name.as_str()).or_insert(0);
162            for dep in &task.dependencies {
163                if !self.tasks.contains_key(dep.as_str()) {
164                    // Unknown dependency — skip
165                    continue;
166                }
167                // dep -> name (name depends on dep)
168                rev.entry(dep.as_str()).or_default().push(name.as_str());
169                *in_degree.entry(name.as_str()).or_insert(0) += 1;
170            }
171        }
172
173        let mut queue: VecDeque<&str> = in_degree
174            .iter()
175            .filter(|(_, d)| **d == 0)
176            .map(|(&n, _)| n)
177            .collect();
178
179        // Sort for determinism
180        let mut queue_vec: Vec<&str> = queue.drain(..).collect();
181        queue_vec.sort();
182        queue.extend(queue_vec);
183
184        let mut order = Vec::new();
185        while let Some(name) = queue.pop_front() {
186            order.push(name.to_owned());
187            if let Some(dependents) = rev.get(name) {
188                let mut next: Vec<&str> = dependents
189                    .iter()
190                    .filter_map(|&d| {
191                        let deg = in_degree.get_mut(d)?;
192                        *deg -= 1;
193                        if *deg == 0 { Some(d) } else { None }
194                    })
195                    .collect();
196                next.sort();
197                queue.extend(next);
198            }
199        }
200
201        if order.len() != self.tasks.len() {
202            // Find a node still in a cycle
203            let cycle_node = self
204                .tasks
205                .keys()
206                .find(|n| !order.contains(*n))
207                .cloned()
208                .unwrap_or_else(|| "unknown".to_owned());
209            Err(cycle_node)
210        } else {
211            Ok(order)
212        }
213    }
214
215    /// Compute the critical path (longest chain by estimated_ms).
216    ///
217    /// Returns the list of task names on the critical path.
218    pub fn critical_path(&self) -> Vec<String> {
219        let order = match self.topological_sort() {
220            Ok(o) => o,
221            Err(_) => return vec![],
222        };
223
224        // Compute earliest finish times
225        let mut eft: HashMap<&str, f64> = HashMap::new();
226        let mut pred: HashMap<&str, &str> = HashMap::new();
227
228        for name in &order {
229            let task = &self.tasks[name.as_str()];
230            let dep_max = task
231                .dependencies
232                .iter()
233                .filter_map(|d| eft.get(d.as_str()).copied())
234                .fold(0.0f64, f64::max);
235            let ef = dep_max + task.estimated_ms;
236            eft.insert(name.as_str(), ef);
237            // Track predecessor that provides dep_max
238            if let Some(best_pred) = task
239                .dependencies
240                .iter()
241                .filter_map(|d| {
242                    let t = eft.get(d.as_str()).copied()?;
243                    Some((d.as_str(), t))
244                })
245                .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
246                .map(|(d, _)| d)
247            {
248                pred.insert(name.as_str(), best_pred);
249            }
250        }
251
252        // Walk back from the task with maximum eft
253        let end = order.iter().max_by(|a, b| {
254            eft.get(a.as_str())
255                .unwrap_or(&0.0)
256                .partial_cmp(eft.get(b.as_str()).unwrap_or(&0.0))
257                .expect("operation should succeed")
258        });
259
260        let mut path = Vec::new();
261        let mut cur = match end {
262            Some(s) => s.as_str(),
263            None => return vec![],
264        };
265        loop {
266            path.push(cur.to_owned());
267            match pred.get(cur) {
268                Some(&p) => cur = p,
269                None => break,
270            }
271        }
272        path.reverse();
273        path
274    }
275
276    /// Check whether the graph contains a cycle.
277    pub fn has_cycle(&self) -> bool {
278        self.topological_sort().is_err()
279    }
280}
281
282// ---------------------------------------------------------------------------
283// ResourceBarrier
284// ---------------------------------------------------------------------------
285
/// Kind of memory hazard a [`ResourceBarrier`] protects against.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BarrierType {
    /// Read-after-write: a reader must wait for a preceding writer.
    ReadAfterWrite,
    /// Write-after-read: a writer must wait for a preceding reader.
    WriteAfterRead,
    /// Write-after-write: serialise two writers.
    WriteAfterWrite,
}

/// An ordering barrier between two tasks that access the same resource.
///
/// `producer` is the task whose access happens first and `consumer` is the
/// task that must wait for it. For a write-after-read barrier the
/// "producer" is the earlier *reader* and the "consumer" is the later writer.
#[derive(Debug, Clone)]
pub struct ResourceBarrier {
    /// Name of the task whose access must complete first.
    pub producer: String,
    /// Name of the task that must wait for the producer.
    pub consumer: String,
    /// The kind of hazard this barrier prevents.
    pub barrier_type: BarrierType,
    /// Resource being protected (e.g. buffer name).
    pub resource: String,
}

impl ResourceBarrier {
    /// Create a barrier of any kind.
    pub fn new(
        producer: impl Into<String>,
        consumer: impl Into<String>,
        barrier_type: BarrierType,
        resource: impl Into<String>,
    ) -> Self {
        Self {
            producer: producer.into(),
            consumer: consumer.into(),
            barrier_type,
            resource: resource.into(),
        }
    }

    /// Create a read-after-write barrier (`consumer` reads what `producer` wrote).
    pub fn raw(
        producer: impl Into<String>,
        consumer: impl Into<String>,
        resource: impl Into<String>,
    ) -> Self {
        Self::new(producer, consumer, BarrierType::ReadAfterWrite, resource)
    }

    /// Create a write-after-read barrier (`consumer` overwrites what `producer` read).
    pub fn war(
        producer: impl Into<String>,
        consumer: impl Into<String>,
        resource: impl Into<String>,
    ) -> Self {
        Self::new(producer, consumer, BarrierType::WriteAfterRead, resource)
    }

    /// Create a write-after-write barrier (two writers, `producer` first).
    pub fn waw(
        producer: impl Into<String>,
        consumer: impl Into<String>,
        resource: impl Into<String>,
    ) -> Self {
        Self::new(producer, consumer, BarrierType::WriteAfterWrite, resource)
    }
}
339
340// ---------------------------------------------------------------------------
341// TaskScheduler
342// ---------------------------------------------------------------------------
343
344/// Schedules a task graph into an execution order.
345#[derive(Debug, Default)]
346pub struct TaskScheduler {
347    /// Barriers to inject between tasks.
348    pub barriers: Vec<ResourceBarrier>,
349}
350
351impl TaskScheduler {
352    /// Create a new scheduler.
353    pub fn new() -> Self {
354        Self::default()
355    }
356
357    /// Add a resource barrier.
358    pub fn add_barrier(&mut self, barrier: ResourceBarrier) {
359        self.barriers.push(barrier);
360    }
361
362    /// Schedule a task graph.
363    ///
364    /// Returns the topological execution order, or an error if a cycle exists.
365    pub fn schedule(&self, graph: &TaskGraph) -> Result<Vec<String>, String> {
366        graph.topological_sort()
367    }
368
369    /// Schedule and group independent tasks into parallel batches.
370    ///
371    /// Each inner `Vec` contains tasks that can run concurrently.
372    pub fn batch_schedule(&self, graph: &TaskGraph) -> Result<Vec<Vec<String>>, String> {
373        let order = self.schedule(graph)?;
374        let tasks = &graph.tasks;
375
376        // Compute the depth of each task (longest dependency chain)
377        let mut depth: HashMap<&str, usize> = HashMap::new();
378        for name in &order {
379            let task = &tasks[name.as_str()];
380            let d = task
381                .dependencies
382                .iter()
383                .filter_map(|dep| depth.get(dep.as_str()).copied())
384                .max()
385                .map(|m| m + 1)
386                .unwrap_or(0);
387            depth.insert(name.as_str(), d);
388        }
389
390        let max_depth = depth.values().copied().max().unwrap_or(0);
391        let mut batches: Vec<Vec<String>> = vec![vec![]; max_depth + 1];
392        for name in &order {
393            let d = *depth.get(name.as_str()).unwrap_or(&0);
394            batches[d].push(name.clone());
395        }
396        Ok(batches)
397    }
398}
399
400// ---------------------------------------------------------------------------
401// WorkloadBalancer
402// ---------------------------------------------------------------------------
403
404/// Splits large dispatches across frames to stay within a time budget.
405#[derive(Debug, Clone)]
406pub struct WorkloadBalancer {
407    /// GPU time budget per frame in milliseconds.
408    pub budget_ms: f64,
409    /// Accumulated pending tasks with their estimated costs.
410    pending: Vec<(ComputeTask, f64)>,
411}
412
413impl WorkloadBalancer {
414    /// Create a new balancer with the given budget.
415    pub fn new(budget_ms: f64) -> Self {
416        Self {
417            budget_ms,
418            pending: vec![],
419        }
420    }
421
422    /// Submit a task for scheduling.
423    pub fn submit(&mut self, task: ComputeTask) {
424        let cost = task.estimated_ms;
425        self.pending.push((task, cost));
426    }
427
428    /// Extract tasks that fit within this frame's budget.
429    ///
430    /// Higher-priority tasks are selected first.  Returns the tasks to
431    /// execute this frame.
432    pub fn extract_frame_work(&mut self) -> Vec<ComputeTask> {
433        // Sort by descending priority then descending estimated_ms
434        self.pending.sort_by(|a, b| {
435            b.0.priority
436                .cmp(&a.0.priority)
437                .then(b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal))
438        });
439
440        let mut remaining = self.budget_ms;
441        let mut this_frame = Vec::new();
442        let mut leftover = Vec::new();
443
444        for (task, cost) in self.pending.drain(..) {
445            if cost <= remaining || this_frame.is_empty() {
446                remaining -= cost;
447                this_frame.push(task);
448            } else {
449                leftover.push((task, cost));
450            }
451        }
452        self.pending = leftover;
453        this_frame
454    }
455
456    /// Number of pending tasks.
457    pub fn pending_count(&self) -> usize {
458        self.pending.len()
459    }
460}
461
462// ---------------------------------------------------------------------------
463// AsyncCompute
464// ---------------------------------------------------------------------------
465
466/// State of an async compute task.
467#[derive(Debug, Clone, PartialEq, Eq)]
468pub enum AsyncState {
469    /// Waiting to be dispatched.
470    Pending,
471    /// Currently running.
472    Running,
473    /// Completed successfully.
474    Done,
475    /// Failed with an error message.
476    Failed(String),
477}
478
479/// A promise-like result from an async compute submission.
480#[derive(Debug, Clone)]
481pub struct AsyncResult {
482    /// Task name.
483    pub name: String,
484    /// Current state.
485    pub state: AsyncState,
486    /// Simulated output data (bytes).
487    pub output: Vec<u8>,
488}
489
490impl AsyncResult {
491    /// True when the task has finished (successfully or not).
492    pub fn is_complete(&self) -> bool {
493        matches!(self.state, AsyncState::Done | AsyncState::Failed(_))
494    }
495}
496
497/// Simulated async compute queue.
498#[derive(Debug, Default)]
499pub struct AsyncCompute {
500    /// All submitted tasks.
501    results: Vec<AsyncResult>,
502}
503
504impl AsyncCompute {
505    /// Create a new async compute queue.
506    pub fn new() -> Self {
507        Self::default()
508    }
509
510    /// Submit a task for async execution.  Returns an index into the result list.
511    pub fn submit(&mut self, task: &ComputeTask) -> usize {
512        let idx = self.results.len();
513        self.results.push(AsyncResult {
514            name: task.name.clone(),
515            state: AsyncState::Pending,
516            output: vec![],
517        });
518        idx
519    }
520
521    /// Advance all pending tasks by one simulated tick.
522    ///
523    /// - Pending → Running
524    /// - Running → Done (with placeholder output)
525    pub fn tick(&mut self) {
526        for r in &mut self.results {
527            match r.state {
528                AsyncState::Pending => r.state = AsyncState::Running,
529                AsyncState::Running => {
530                    r.state = AsyncState::Done;
531                    r.output = vec![0u8; 4]; // placeholder
532                }
533                _ => {}
534            }
535        }
536    }
537
538    /// Get the result for submission index `idx`.
539    pub fn poll(&self, idx: usize) -> Option<&AsyncResult> {
540        self.results.get(idx)
541    }
542
543    /// Drain completed results.
544    pub fn drain_completed(&mut self) -> Vec<AsyncResult> {
545        let mut done = Vec::new();
546        let mut remaining = Vec::new();
547        for r in self.results.drain(..) {
548            if r.is_complete() {
549                done.push(r);
550            } else {
551                remaining.push(r);
552            }
553        }
554        self.results = remaining;
555        done
556    }
557}
558
559// ---------------------------------------------------------------------------
560// PipelineBarrier
561// ---------------------------------------------------------------------------
562
/// A pipeline stage used as the source or destination of a memory barrier.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PipelineStage {
    /// Beginning of the pipeline.
    Top,
    /// Vertex shading.
    Vertex,
    /// Fragment / pixel shading.
    Fragment,
    /// Compute dispatch.
    Compute,
    /// Transfer / copy operations.
    Transfer,
    /// Color attachment output.
    ColorAttachment,
    /// Shader read.
    ShaderRead,
    /// End of pipeline.
    Bottom,
}

/// A pipeline memory barrier between passes: work in `src_stage` must finish
/// before work in `dst_stage` may start.
#[derive(Debug, Clone)]
pub struct PipelineBarrier {
    /// Stage that must complete before the barrier.
    pub src_stage: PipelineStage,
    /// Stage that must wait after the barrier.
    pub dst_stage: PipelineStage,
    /// Human-readable label.
    pub label: String,
    /// Whether this is a color attachment → shader read transition.
    pub color_to_shader_read: bool,
}

impl PipelineBarrier {
    /// Shared field initialiser for the preset constructors below.
    fn between(
        src_stage: PipelineStage,
        dst_stage: PipelineStage,
        label: String,
        color_to_shader_read: bool,
    ) -> Self {
        Self {
            src_stage,
            dst_stage,
            label,
            color_to_shader_read,
        }
    }

    /// Barrier from color attachment output to shader reads (flag set).
    pub fn color_attachment_to_shader_read(label: impl Into<String>) -> Self {
        Self::between(
            PipelineStage::ColorAttachment,
            PipelineStage::ShaderRead,
            label.into(),
            true,
        )
    }

    /// Barrier between two compute stages, e.g. for storage buffer hazards.
    pub fn compute_to_compute(label: impl Into<String>) -> Self {
        Self::between(
            PipelineStage::Compute,
            PipelineStage::Compute,
            label.into(),
            false,
        )
    }

    /// True when compute output is consumed by a shader-read or fragment stage.
    pub fn is_compute_read_hazard(&self) -> bool {
        matches!(
            (self.src_stage, self.dst_stage),
            (PipelineStage::Compute, PipelineStage::ShaderRead)
                | (PipelineStage::Compute, PipelineStage::Fragment)
        )
    }
}
627
628// ---------------------------------------------------------------------------
629// GpuTimestampQuery
630// ---------------------------------------------------------------------------
631
/// A begin/end pair of simulated GPU timestamps, in nanoseconds.
#[derive(Debug, Clone)]
pub struct GpuTimestampQuery {
    /// Label for this query.
    pub label: String,
    /// Simulated start time (nanoseconds).
    pub start_ns: u64,
    /// Simulated end time (nanoseconds).
    pub end_ns: u64,
    /// Set by `begin`, cleared by `end`.
    active: bool,
}

impl GpuTimestampQuery {
    /// Create an inactive query with both timestamps at zero.
    pub fn new(label: impl Into<String>) -> Self {
        let label = label.into();
        Self {
            label,
            start_ns: 0,
            end_ns: 0,
            active: false,
        }
    }

    /// Record the start timestamp and mark the query active.
    pub fn begin(&mut self, now_ns: u64) {
        self.active = true;
        self.start_ns = now_ns;
    }

    /// Record the end timestamp and mark the query inactive.
    pub fn end(&mut self, now_ns: u64) {
        self.active = false;
        self.end_ns = now_ns;
    }

    /// Elapsed time in microseconds; 0 when `end_ns` precedes `start_ns`.
    pub fn elapsed_us(&self) -> f64 {
        let nanos = self.end_ns.saturating_sub(self.start_ns);
        nanos as f64 / 1_000.0
    }

    /// Elapsed time in milliseconds (derived from [`Self::elapsed_us`]).
    pub fn elapsed_ms(&self) -> f64 {
        self.elapsed_us() / 1_000.0
    }

    /// True between a `begin` and the following `end`.
    pub fn is_active(&self) -> bool {
        self.active
    }
}

/// A pool of timestamp queries for profiling a frame, addressed by index.
#[derive(Debug, Default)]
pub struct TimestampPool {
    /// All queries, in allocation order.
    queries: Vec<GpuTimestampQuery>,
}

impl TimestampPool {
    /// Create an empty pool.
    pub fn new() -> Self {
        Self {
            queries: Vec::new(),
        }
    }

    /// Allocate a query, begin it at `now_ns`, and return its index.
    pub fn begin(&mut self, label: impl Into<String>, now_ns: u64) -> usize {
        let slot = self.queries.len();
        let mut query = GpuTimestampQuery::new(label);
        query.begin(now_ns);
        self.queries.push(query);
        slot
    }

    /// End the query at `idx`; an out-of-range index is ignored.
    pub fn end(&mut self, idx: usize, now_ns: u64) {
        if let Some(query) = self.queries.get_mut(idx) {
            query.end(now_ns);
        }
    }

    /// Elapsed ms for query `idx`, or 0.0 when `idx` is out of range.
    pub fn elapsed_ms(&self, idx: usize) -> f64 {
        self.queries
            .get(idx)
            .map_or(0.0, GpuTimestampQuery::elapsed_ms)
    }

    /// Sum of elapsed ms over every query that is not currently active.
    pub fn total_ms(&self) -> f64 {
        self.queries
            .iter()
            .filter_map(|q| {
                if q.is_active() {
                    None
                } else {
                    Some(q.elapsed_ms())
                }
            })
            .sum()
    }

    /// Discard all queries; indices restart from 0.
    pub fn reset(&mut self) {
        self.queries.clear();
    }
}
732
733// ---------------------------------------------------------------------------
734// FrameGraph
735// ---------------------------------------------------------------------------
736
737/// A transient resource in the frame graph.
738#[derive(Debug, Clone)]
739pub struct FrameResource {
740    /// Resource name.
741    pub name: String,
742    /// Size in bytes.
743    pub size: usize,
744    /// First pass index that uses this resource.
745    pub first_use: usize,
746    /// Last pass index that uses this resource.
747    pub last_use: usize,
748    /// Allocated byte offset (set during aliasing).
749    pub offset: usize,
750}
751
752/// A render pass in the frame graph.
753#[derive(Debug, Clone)]
754pub struct FramePass {
755    /// Pass name.
756    pub name: String,
757    /// Resources read by this pass.
758    pub reads: Vec<String>,
759    /// Resources written by this pass.
760    pub writes: Vec<String>,
761    /// Pipeline barriers to inject before this pass.
762    pub barriers: Vec<PipelineBarrier>,
763}
764
765impl FramePass {
766    /// Create a new frame pass.
767    pub fn new(name: impl Into<String>) -> Self {
768        Self {
769            name: name.into(),
770            reads: vec![],
771            writes: vec![],
772            barriers: vec![],
773        }
774    }
775
776    /// Declare a resource read.
777    pub fn reads(mut self, res: impl Into<String>) -> Self {
778        self.reads.push(res.into());
779        self
780    }
781
782    /// Declare a resource write.
783    pub fn writes(mut self, res: impl Into<String>) -> Self {
784        self.writes.push(res.into());
785        self
786    }
787
788    /// Add a pipeline barrier.
789    pub fn barrier(mut self, b: PipelineBarrier) -> Self {
790        self.barriers.push(b);
791        self
792    }
793}
794
795/// A full-frame resource graph with transient resource aliasing.
796#[derive(Debug, Default)]
797pub struct FrameGraph {
798    /// All passes in submission order.
799    passes: Vec<FramePass>,
800    /// Declared transient resources.
801    resources: HashMap<String, FrameResource>,
802}
803
804impl FrameGraph {
805    /// Create an empty frame graph.
806    pub fn new() -> Self {
807        Self::default()
808    }
809
810    /// Add a render pass.
811    pub fn add_pass(&mut self, pass: FramePass) {
812        let idx = self.passes.len();
813        // Track resource lifetimes
814        for res in pass.reads.iter().chain(pass.writes.iter()) {
815            let e = self.resources.entry(res.clone()).or_insert(FrameResource {
816                name: res.clone(),
817                size: 0,
818                first_use: idx,
819                last_use: idx,
820                offset: 0,
821            });
822            // Update first_use if this is earlier (handles declare_resource + add_pass order)
823            if idx < e.first_use {
824                e.first_use = idx;
825            }
826            if idx > e.last_use {
827                e.last_use = idx;
828            }
829        }
830        self.passes.push(pass);
831    }
832
833    /// Declare a transient resource with its size.
834    pub fn declare_resource(&mut self, name: impl Into<String>, size: usize) {
835        let name = name.into();
836        let e = self.resources.entry(name.clone()).or_insert(FrameResource {
837            name: name.clone(),
838            size: 0,
839            first_use: usize::MAX,
840            last_use: 0,
841            offset: 0,
842        });
843        e.size = size;
844    }
845
846    /// Run a simple aliasing pass: resources that do not overlap in lifetime
847    /// share the same memory offset.
848    pub fn alias_resources(&mut self) {
849        // Greedy aliasing: sort by first_use, then assign offsets
850        let names: Vec<String> = {
851            let mut v: Vec<String> = self.resources.keys().cloned().collect();
852            v.sort();
853            v
854        };
855
856        // Track which offsets are "free" at each pass
857        let mut allocations: Vec<(usize, usize, usize)> = Vec::new(); // (offset, end_pass, size)
858        let pass_count = self.passes.len();
859
860        for name in &names {
861            if let Some(res) = self.resources.get_mut(name) {
862                if res.first_use > pass_count {
863                    continue;
864                }
865                // Find a free slot
866                let mut found = None;
867                for (off, end, sz) in &mut allocations {
868                    if *end < res.first_use && *sz >= res.size {
869                        found = Some(*off);
870                        *end = res.last_use;
871                        break;
872                    }
873                }
874                if let Some(off) = found {
875                    res.offset = off;
876                } else {
877                    let off: usize = allocations.iter().map(|(o, _, s)| o + s).max().unwrap_or(0);
878                    res.offset = off;
879                    let (last_use, size) = (res.last_use, res.size);
880                    allocations.push((off, last_use, size));
881                }
882            }
883        }
884    }
885
886    /// Compute the peak memory required (maximum end of any allocation).
887    pub fn peak_memory(&self) -> usize {
888        self.resources
889            .values()
890            .map(|r| r.offset + r.size)
891            .max()
892            .unwrap_or(0)
893    }
894
895    /// Number of passes.
896    pub fn pass_count(&self) -> usize {
897        self.passes.len()
898    }
899
900    /// Get all barriers for a given pass by index.
901    pub fn barriers_for_pass(&self, idx: usize) -> &[PipelineBarrier] {
902        self.passes
903            .get(idx)
904            .map(|p| p.barriers.as_slice())
905            .unwrap_or(&[])
906    }
907
908    /// Collect all pipeline barriers across the frame in order.
909    pub fn all_barriers(&self) -> Vec<&PipelineBarrier> {
910        self.passes.iter().flat_map(|p| p.barriers.iter()).collect()
911    }
912
913    /// Find all resources used by pass at index `idx`.
914    pub fn resources_for_pass(&self, idx: usize) -> Vec<&str> {
915        if let Some(pass) = self.passes.get(idx) {
916            pass.reads
917                .iter()
918                .chain(pass.writes.iter())
919                .map(|s| s.as_str())
920                .collect::<HashSet<_>>()
921                .into_iter()
922                .collect()
923        } else {
924            vec![]
925        }
926    }
927}
928
929// ---------------------------------------------------------------------------
930// Tests
931// ---------------------------------------------------------------------------
932
// Unit tests for the scheduler module's CPU-side types (task graphs,
// scheduling, barriers, async compute, timestamps, frame graphs).
#[cfg(test)]
mod tests {
    use super::*;

    // --- TaskPriority tests ---

    // Derived `Ord` on a fieldless enum compares discriminants. The explicit
    // values (RealTime = 4 ... Background = 0) define this ordering, not the
    // order in which the variants are declared.
    #[test]
    fn test_priority_ordering() {
        assert!(TaskPriority::RealTime > TaskPriority::High);
        assert!(TaskPriority::High > TaskPriority::Normal);
        assert!(TaskPriority::Normal > TaskPriority::Low);
        assert!(TaskPriority::Low > TaskPriority::Background);
    }

    // --- ComputeTask tests ---

    #[test]
    fn test_compute_task_invocations_1d() {
        let t = ComputeTask::new_1d("particles", 100);
        // dispatch 100 × 1 × 1, workgroup 64 × 1 × 1 → 6400 invocations
        assert_eq!(t.total_invocations(), 6400);
    }

    #[test]
    fn test_compute_task_invocations_2d() {
        let t = ComputeTask::new_2d("shadows", 8, 8);
        // dispatch 8×8×1, workgroup 8×8×1 → 8*8*8*8 = 4096
        assert_eq!(t.total_invocations(), 4096);
    }

    #[test]
    fn test_compute_task_depends_on() {
        let t = ComputeTask::new_1d("B", 1).depends_on("A");
        assert!(t.dependencies.contains(&"A".to_owned()));
    }

    #[test]
    fn test_compute_task_priority() {
        let t = ComputeTask::new_1d("t", 1).with_priority(TaskPriority::High);
        assert_eq!(t.priority, TaskPriority::High);
    }

    // --- TaskGraph tests ---

    // Linear chain A -> B -> C. Only relative positions are checked, via a
    // name -> index map built from the sorted order.
    #[test]
    fn test_task_graph_topo_sort_simple() {
        let mut g = TaskGraph::new();
        g.add_task(ComputeTask::new_1d("A", 1));
        g.add_task(ComputeTask::new_1d("B", 1).depends_on("A"));
        g.add_task(ComputeTask::new_1d("C", 1).depends_on("B"));
        let order = g.topological_sort().unwrap();
        let pos: HashMap<&str, usize> = order
            .iter()
            .enumerate()
            .map(|(i, s)| (s.as_str(), i))
            .collect();
        assert!(pos["A"] < pos["B"]);
        assert!(pos["B"] < pos["C"]);
    }

    // Diamond A -> {B, C} -> D. Only checks that all four tasks are emitted.
    // The relative order within the diamond is not asserted.
    #[test]
    fn test_task_graph_topo_sort_diamond() {
        let mut g = TaskGraph::new();
        g.add_task(ComputeTask::new_1d("A", 1));
        g.add_task(ComputeTask::new_1d("B", 1).depends_on("A"));
        g.add_task(ComputeTask::new_1d("C", 1).depends_on("A"));
        g.add_task(ComputeTask::new_1d("D", 1).depends_on("B").depends_on("C"));
        let order = g.topological_sort().unwrap();
        assert_eq!(order.len(), 4);
    }

    // A and B depend on each other, forming a two-node cycle.
    #[test]
    fn test_task_graph_cycle_detection() {
        let mut g = TaskGraph::new();
        g.add_task(ComputeTask::new_1d("A", 1).depends_on("B"));
        g.add_task(ComputeTask::new_1d("B", 1).depends_on("A"));
        assert!(g.has_cycle());
    }

    #[test]
    fn test_task_graph_critical_path() {
        let mut g = TaskGraph::new();
        g.add_task(ComputeTask::new_1d("A", 1).with_estimated_ms(1.0));
        g.add_task(
            ComputeTask::new_1d("B", 1)
                .depends_on("A")
                .with_estimated_ms(2.0),
        );
        g.add_task(ComputeTask::new_1d("C", 1).with_estimated_ms(10.0));
        let cp = g.critical_path();
        // C alone is the critical path (10ms vs A+B = 3ms)
        assert!(cp.contains(&"C".to_owned()));
    }

    // An empty graph sorts successfully to an empty order (not an error).
    #[test]
    fn test_task_graph_empty_topo() {
        let g = TaskGraph::new();
        let order = g.topological_sort().unwrap();
        assert!(order.is_empty());
    }

    // --- TaskScheduler tests ---

    #[test]
    fn test_scheduler_schedule() {
        let mut g = TaskGraph::new();
        g.add_task(ComputeTask::new_1d("X", 1));
        g.add_task(ComputeTask::new_1d("Y", 1).depends_on("X"));
        let sched = TaskScheduler::new();
        let order = sched.schedule(&g).unwrap();
        assert_eq!(order.len(), 2);
    }

    #[test]
    fn test_scheduler_batch_schedule() {
        let mut g = TaskGraph::new();
        g.add_task(ComputeTask::new_1d("A", 1));
        g.add_task(ComputeTask::new_1d("B", 1));
        g.add_task(ComputeTask::new_1d("C", 1).depends_on("A").depends_on("B"));
        let sched = TaskScheduler::new();
        let batches = sched.batch_schedule(&g).unwrap();
        // A and B should be in the same batch (depth 0)
        assert!(batches[0].len() >= 2);
        // C in a later batch
        assert!(batches.len() >= 2);
    }

    // --- ResourceBarrier tests ---

    #[test]
    fn test_resource_barrier_raw() {
        let b = ResourceBarrier::raw("write_task", "read_task", "position_buffer");
        assert_eq!(b.barrier_type, BarrierType::ReadAfterWrite);
        assert_eq!(b.resource, "position_buffer");
    }

    #[test]
    fn test_resource_barrier_war() {
        let b = ResourceBarrier::war("reader", "writer", "depth");
        assert_eq!(b.barrier_type, BarrierType::WriteAfterRead);
    }

    // --- WorkloadBalancer tests ---

    // NOTE(review): the 10.0 + 6.0 bound assumes the balancer may overshoot the
    // budget by at most one task (here the largest, 6.0 ms). Confirm this
    // against `extract_frame_work` before tightening the assertion.
    #[test]
    fn test_workload_balancer_respects_budget() {
        let mut wb = WorkloadBalancer::new(10.0);
        wb.submit(ComputeTask::new_1d("A", 1).with_estimated_ms(3.0));
        wb.submit(ComputeTask::new_1d("B", 1).with_estimated_ms(4.0));
        wb.submit(ComputeTask::new_1d("C", 1).with_estimated_ms(6.0));
        let frame = wb.extract_frame_work();
        let total: f64 = frame.iter().map(|t| t.estimated_ms).sum();
        // At most budget + one overflow (for non-empty guarantee)
        assert!(total <= 10.0 + 6.0);
    }

    // "low" is submitted before "rt". The RealTime task must still come out
    // first, so extraction is by priority rather than submission order.
    #[test]
    fn test_workload_balancer_priority_order() {
        let mut wb = WorkloadBalancer::new(5.0);
        wb.submit(
            ComputeTask::new_1d("low", 1)
                .with_priority(TaskPriority::Low)
                .with_estimated_ms(2.0),
        );
        wb.submit(
            ComputeTask::new_1d("rt", 1)
                .with_priority(TaskPriority::RealTime)
                .with_estimated_ms(2.0),
        );
        let frame = wb.extract_frame_work();
        // RealTime should be first
        assert_eq!(frame[0].name, "rt");
    }

    // Five 1 ms tasks against a 1 ms budget. At least one task must be
    // extracted, so fewer than five remain pending.
    #[test]
    fn test_workload_balancer_pending_count() {
        let mut wb = WorkloadBalancer::new(1.0);
        for i in 0..5 {
            wb.submit(ComputeTask::new_1d(format!("t{i}"), 1).with_estimated_ms(1.0));
        }
        wb.extract_frame_work();
        assert!(wb.pending_count() < 5);
    }

    // --- AsyncCompute tests ---

    // A freshly submitted task reports Pending before any tick.
    #[test]
    fn test_async_compute_submit_poll() {
        let mut ac = AsyncCompute::new();
        let task = ComputeTask::new_1d("sim", 64);
        let idx = ac.submit(&task);
        let r = ac.poll(idx).unwrap();
        assert_eq!(r.state, AsyncState::Pending);
    }

    #[test]
    fn test_async_compute_tick_to_done() {
        let mut ac = AsyncCompute::new();
        let task = ComputeTask::new_1d("sim", 1);
        let idx = ac.submit(&task);
        ac.tick(); // Pending → Running
        ac.tick(); // Running → Done
        assert_eq!(ac.poll(idx).unwrap().state, AsyncState::Done);
    }

    // After two ticks the single task is Done. Draining returns it and removes
    // it, so polling its index (0) afterwards yields None.
    #[test]
    fn test_async_compute_drain_completed() {
        let mut ac = AsyncCompute::new();
        let t = ComputeTask::new_1d("t", 1);
        ac.submit(&t);
        ac.tick();
        ac.tick();
        let done = ac.drain_completed();
        assert_eq!(done.len(), 1);
        assert!(ac.poll(0).is_none()); // drained
    }

    // --- PipelineBarrier tests ---

    #[test]
    fn test_pipeline_barrier_color_to_shader_read() {
        let b = PipelineBarrier::color_attachment_to_shader_read("gbuffer");
        assert!(b.color_to_shader_read);
        assert_eq!(b.src_stage, PipelineStage::ColorAttachment);
        assert_eq!(b.dst_stage, PipelineStage::ShaderRead);
    }

    #[test]
    fn test_pipeline_barrier_compute_to_compute() {
        let b = PipelineBarrier::compute_to_compute("particles");
        assert_eq!(b.src_stage, PipelineStage::Compute);
        assert!(!b.is_compute_read_hazard()); // dst is also Compute
    }

    // Built with a struct literal to get a Compute -> ShaderRead transition,
    // which must be reported as a compute read hazard.
    #[test]
    fn test_pipeline_barrier_compute_read_hazard() {
        let b = PipelineBarrier {
            src_stage: PipelineStage::Compute,
            dst_stage: PipelineStage::ShaderRead,
            label: "test".to_owned(),
            color_to_shader_read: false,
        };
        assert!(b.is_compute_read_hazard());
    }

    // --- GpuTimestampQuery tests ---

    #[test]
    fn test_timestamp_query_elapsed() {
        let mut q = GpuTimestampQuery::new("render");
        q.begin(1_000_000); // 1 ms in ns
        q.end(2_000_000); // 2 ms in ns
        assert!((q.elapsed_ms() - 1.0).abs() < 1e-6);
    }

    // A query is active only between `begin` and `end`.
    #[test]
    fn test_timestamp_query_is_active() {
        let mut q = GpuTimestampQuery::new("x");
        assert!(!q.is_active());
        q.begin(0);
        assert!(q.is_active());
        q.end(100);
        assert!(!q.is_active());
    }

    // Two queries spanning 1 ms and 2 ms (timestamps in ns) sum to 3 ms.
    #[test]
    fn test_timestamp_pool_total() {
        let mut pool = TimestampPool::new();
        let i0 = pool.begin("a", 0);
        pool.end(i0, 1_000_000);
        let i1 = pool.begin("b", 0);
        pool.end(i1, 2_000_000);
        let total = pool.total_ms();
        assert!((total - 3.0).abs() < 1e-6, "total={total}");
    }

    // Query "x" is begun but never ended. After `reset` the pool reports a
    // total of zero.
    #[test]
    fn test_timestamp_pool_reset() {
        let mut pool = TimestampPool::new();
        pool.begin("x", 0);
        pool.reset();
        assert!((pool.total_ms()).abs() < 1e-10);
    }

    // --- FrameGraph tests ---

    #[test]
    fn test_frame_graph_add_pass() {
        let mut fg = FrameGraph::new();
        fg.add_pass(FramePass::new("gbuffer").writes("color").writes("depth"));
        fg.add_pass(
            FramePass::new("lighting")
                .reads("color")
                .reads("depth")
                .writes("hdr"),
        );
        assert_eq!(fg.pass_count(), 2);
    }

    // `first_use` / `last_use` are pass indices. "color" is written in pass 0
    // and read in pass 1.
    #[test]
    fn test_frame_graph_resource_lifetime() {
        let mut fg = FrameGraph::new();
        fg.declare_resource("color", 1024 * 1024 * 4);
        fg.add_pass(FramePass::new("p0").writes("color"));
        fg.add_pass(FramePass::new("p1").reads("color"));
        let res = &fg.resources["color"];
        assert_eq!(res.first_use, 0);
        assert_eq!(res.last_use, 1);
    }

    // NOTE(review): this only asserts `peak > 0`, so it does not check that A
    // and B actually share memory. If `alias_resources` guarantees aliasing
    // for non-overlapping lifetimes, consider asserting `peak == 1024`.
    #[test]
    fn test_frame_graph_aliasing() {
        let mut fg = FrameGraph::new();
        fg.declare_resource("A", 1024);
        fg.declare_resource("B", 1024);
        fg.add_pass(FramePass::new("p0").writes("A"));
        fg.add_pass(FramePass::new("p1").reads("A"));
        fg.add_pass(FramePass::new("p2").writes("B"));
        fg.alias_resources();
        // B's lifetime starts after A ends, so they may share memory
        let peak = fg.peak_memory();
        assert!(peak > 0);
    }

    #[test]
    fn test_frame_graph_barriers() {
        let mut fg = FrameGraph::new();
        fg.add_pass(
            FramePass::new("render")
                .barrier(PipelineBarrier::color_attachment_to_shader_read("test")),
        );
        let barriers = fg.barriers_for_pass(0);
        assert_eq!(barriers.len(), 1);
    }

    // One barrier in each of two passes. `all_barriers` flattens them into a
    // single list of two.
    #[test]
    fn test_frame_graph_all_barriers() {
        let mut fg = FrameGraph::new();
        fg.add_pass(FramePass::new("p0").barrier(PipelineBarrier::compute_to_compute("c0")));
        fg.add_pass(FramePass::new("p1").barrier(PipelineBarrier::compute_to_compute("c1")));
        assert_eq!(fg.all_barriers().len(), 2);
    }
}