1#![allow(dead_code)]
10#![allow(missing_docs)]
11
12use std::collections::{HashMap, HashSet, VecDeque};
13
/// Scheduling priority of a compute task.
///
/// Every variant carries an explicit discriminant, and the derived `Ord`
/// compares those discriminants. A larger value therefore means a more urgent
/// task: `RealTime > High > Normal > Low > Background`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
pub enum TaskPriority {
    /// Highest priority (discriminant 4).
    RealTime = 4,
    /// Above-normal priority (discriminant 3).
    High = 3,
    /// Default priority (discriminant 2); this is what `TaskPriority::default()` returns.
    #[default]
    Normal = 2,
    /// Below-normal priority (discriminant 1).
    Low = 1,
    /// Lowest priority (discriminant 0).
    Background = 0,
}
33
34#[derive(Debug, Clone)]
40pub struct ComputeTask {
41 pub name: String,
43 pub workgroup_size: [u32; 3],
45 pub dispatch_count: [u32; 3],
47 pub dependencies: Vec<String>,
49 pub priority: TaskPriority,
51 pub estimated_ms: f64,
53}
54
55impl ComputeTask {
56 pub fn new_1d(name: impl Into<String>, dispatch_x: u32) -> Self {
58 Self {
59 name: name.into(),
60 workgroup_size: [64, 1, 1],
61 dispatch_count: [dispatch_x, 1, 1],
62 dependencies: vec![],
63 priority: TaskPriority::Normal,
64 estimated_ms: 1.0,
65 }
66 }
67
68 pub fn new_2d(name: impl Into<String>, dispatch_x: u32, dispatch_y: u32) -> Self {
70 Self {
71 name: name.into(),
72 workgroup_size: [8, 8, 1],
73 dispatch_count: [dispatch_x, dispatch_y, 1],
74 dependencies: vec![],
75 priority: TaskPriority::Normal,
76 estimated_ms: 1.0,
77 }
78 }
79
80 pub fn total_workgroups(&self) -> u64 {
82 self.dispatch_count[0] as u64
83 * self.dispatch_count[1] as u64
84 * self.dispatch_count[2] as u64
85 }
86
87 pub fn total_invocations(&self) -> u64 {
89 self.total_workgroups()
90 * self.workgroup_size[0] as u64
91 * self.workgroup_size[1] as u64
92 * self.workgroup_size[2] as u64
93 }
94
95 pub fn depends_on(mut self, dep: impl Into<String>) -> Self {
97 self.dependencies.push(dep.into());
98 self
99 }
100
101 pub fn with_priority(mut self, priority: TaskPriority) -> Self {
103 self.priority = priority;
104 self
105 }
106
107 pub fn with_estimated_ms(mut self, ms: f64) -> Self {
109 self.estimated_ms = ms;
110 self
111 }
112}
113
114#[derive(Debug, Clone, Default)]
120pub struct TaskGraph {
121 tasks: HashMap<String, ComputeTask>,
123}
124
125impl TaskGraph {
126 pub fn new() -> Self {
128 Self::default()
129 }
130
131 pub fn add_task(&mut self, task: ComputeTask) {
133 self.tasks.insert(task.name.clone(), task);
134 }
135
136 pub fn remove_task(&mut self, name: &str) {
138 self.tasks.remove(name);
139 }
140
141 pub fn len(&self) -> usize {
143 self.tasks.len()
144 }
145
146 pub fn is_empty(&self) -> bool {
148 self.tasks.is_empty()
149 }
150
151 pub fn topological_sort(&self) -> Result<Vec<String>, String> {
156 let mut in_degree: HashMap<&str, usize> = HashMap::new();
158 let mut rev: HashMap<&str, Vec<&str>> = HashMap::new(); for (name, task) in &self.tasks {
161 in_degree.entry(name.as_str()).or_insert(0);
162 for dep in &task.dependencies {
163 if !self.tasks.contains_key(dep.as_str()) {
164 continue;
166 }
167 rev.entry(dep.as_str()).or_default().push(name.as_str());
169 *in_degree.entry(name.as_str()).or_insert(0) += 1;
170 }
171 }
172
173 let mut queue: VecDeque<&str> = in_degree
174 .iter()
175 .filter(|(_, d)| **d == 0)
176 .map(|(&n, _)| n)
177 .collect();
178
179 let mut queue_vec: Vec<&str> = queue.drain(..).collect();
181 queue_vec.sort();
182 queue.extend(queue_vec);
183
184 let mut order = Vec::new();
185 while let Some(name) = queue.pop_front() {
186 order.push(name.to_owned());
187 if let Some(dependents) = rev.get(name) {
188 let mut next: Vec<&str> = dependents
189 .iter()
190 .filter_map(|&d| {
191 let deg = in_degree.get_mut(d)?;
192 *deg -= 1;
193 if *deg == 0 { Some(d) } else { None }
194 })
195 .collect();
196 next.sort();
197 queue.extend(next);
198 }
199 }
200
201 if order.len() != self.tasks.len() {
202 let cycle_node = self
204 .tasks
205 .keys()
206 .find(|n| !order.contains(*n))
207 .cloned()
208 .unwrap_or_else(|| "unknown".to_owned());
209 Err(cycle_node)
210 } else {
211 Ok(order)
212 }
213 }
214
215 pub fn critical_path(&self) -> Vec<String> {
219 let order = match self.topological_sort() {
220 Ok(o) => o,
221 Err(_) => return vec![],
222 };
223
224 let mut eft: HashMap<&str, f64> = HashMap::new();
226 let mut pred: HashMap<&str, &str> = HashMap::new();
227
228 for name in &order {
229 let task = &self.tasks[name.as_str()];
230 let dep_max = task
231 .dependencies
232 .iter()
233 .filter_map(|d| eft.get(d.as_str()).copied())
234 .fold(0.0f64, f64::max);
235 let ef = dep_max + task.estimated_ms;
236 eft.insert(name.as_str(), ef);
237 if let Some(best_pred) = task
239 .dependencies
240 .iter()
241 .filter_map(|d| {
242 let t = eft.get(d.as_str()).copied()?;
243 Some((d.as_str(), t))
244 })
245 .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
246 .map(|(d, _)| d)
247 {
248 pred.insert(name.as_str(), best_pred);
249 }
250 }
251
252 let end = order.iter().max_by(|a, b| {
254 eft.get(a.as_str())
255 .unwrap_or(&0.0)
256 .partial_cmp(eft.get(b.as_str()).unwrap_or(&0.0))
257 .expect("operation should succeed")
258 });
259
260 let mut path = Vec::new();
261 let mut cur = match end {
262 Some(s) => s.as_str(),
263 None => return vec![],
264 };
265 loop {
266 path.push(cur.to_owned());
267 match pred.get(cur) {
268 Some(&p) => cur = p,
269 None => break,
270 }
271 }
272 path.reverse();
273 path
274 }
275
276 pub fn has_cycle(&self) -> bool {
278 self.topological_sort().is_err()
279 }
280}
281
/// Kind of data hazard a `ResourceBarrier` guards against.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BarrierType {
    /// The consumer reads data the producer wrote.
    ReadAfterWrite,
    /// The consumer writes data the producer read.
    WriteAfterRead,
    /// The consumer writes data the producer also wrote.
    WriteAfterWrite,
}

/// An ordering constraint between two tasks (by name) over one shared resource.
///
/// `producer` names the earlier access and `consumer` the later access that
/// must wait for it; `barrier_type` says which access kinds conflict.
#[derive(Debug, Clone)]
pub struct ResourceBarrier {
    /// Task performing the earlier access.
    pub producer: String,
    /// Task performing the later access.
    pub consumer: String,
    /// Hazard kind between the two accesses.
    pub barrier_type: BarrierType,
    /// Name of the shared resource.
    pub resource: String,
}

impl ResourceBarrier {
    /// Read-after-write barrier: `consumer` reads what `producer` wrote.
    pub fn raw(
        producer: impl Into<String>,
        consumer: impl Into<String>,
        resource: impl Into<String>,
    ) -> Self {
        Self::with_kind(producer, consumer, resource, BarrierType::ReadAfterWrite)
    }

    /// Write-after-read barrier: `consumer` overwrites what `producer` read.
    pub fn war(
        producer: impl Into<String>,
        consumer: impl Into<String>,
        resource: impl Into<String>,
    ) -> Self {
        Self::with_kind(producer, consumer, resource, BarrierType::WriteAfterRead)
    }

    /// Write-after-write barrier: `consumer` writes a resource `producer` also wrote.
    ///
    /// Completes the constructor set so every `BarrierType` variant has one.
    pub fn waw(
        producer: impl Into<String>,
        consumer: impl Into<String>,
        resource: impl Into<String>,
    ) -> Self {
        Self::with_kind(producer, consumer, resource, BarrierType::WriteAfterWrite)
    }

    // Shared constructor body for the named hazard kinds above.
    fn with_kind(
        producer: impl Into<String>,
        consumer: impl Into<String>,
        resource: impl Into<String>,
        barrier_type: BarrierType,
    ) -> Self {
        Self {
            producer: producer.into(),
            consumer: consumer.into(),
            barrier_type,
            resource: resource.into(),
        }
    }
}
339
340#[derive(Debug, Default)]
346pub struct TaskScheduler {
347 pub barriers: Vec<ResourceBarrier>,
349}
350
351impl TaskScheduler {
352 pub fn new() -> Self {
354 Self::default()
355 }
356
357 pub fn add_barrier(&mut self, barrier: ResourceBarrier) {
359 self.barriers.push(barrier);
360 }
361
362 pub fn schedule(&self, graph: &TaskGraph) -> Result<Vec<String>, String> {
366 graph.topological_sort()
367 }
368
369 pub fn batch_schedule(&self, graph: &TaskGraph) -> Result<Vec<Vec<String>>, String> {
373 let order = self.schedule(graph)?;
374 let tasks = &graph.tasks;
375
376 let mut depth: HashMap<&str, usize> = HashMap::new();
378 for name in &order {
379 let task = &tasks[name.as_str()];
380 let d = task
381 .dependencies
382 .iter()
383 .filter_map(|dep| depth.get(dep.as_str()).copied())
384 .max()
385 .map(|m| m + 1)
386 .unwrap_or(0);
387 depth.insert(name.as_str(), d);
388 }
389
390 let max_depth = depth.values().copied().max().unwrap_or(0);
391 let mut batches: Vec<Vec<String>> = vec![vec![]; max_depth + 1];
392 for name in &order {
393 let d = *depth.get(name.as_str()).unwrap_or(&0);
394 batches[d].push(name.clone());
395 }
396 Ok(batches)
397 }
398}
399
400#[derive(Debug, Clone)]
406pub struct WorkloadBalancer {
407 pub budget_ms: f64,
409 pending: Vec<(ComputeTask, f64)>,
411}
412
413impl WorkloadBalancer {
414 pub fn new(budget_ms: f64) -> Self {
416 Self {
417 budget_ms,
418 pending: vec![],
419 }
420 }
421
422 pub fn submit(&mut self, task: ComputeTask) {
424 let cost = task.estimated_ms;
425 self.pending.push((task, cost));
426 }
427
428 pub fn extract_frame_work(&mut self) -> Vec<ComputeTask> {
433 self.pending.sort_by(|a, b| {
435 b.0.priority
436 .cmp(&a.0.priority)
437 .then(b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal))
438 });
439
440 let mut remaining = self.budget_ms;
441 let mut this_frame = Vec::new();
442 let mut leftover = Vec::new();
443
444 for (task, cost) in self.pending.drain(..) {
445 if cost <= remaining || this_frame.is_empty() {
446 remaining -= cost;
447 this_frame.push(task);
448 } else {
449 leftover.push((task, cost));
450 }
451 }
452 self.pending = leftover;
453 this_frame
454 }
455
456 pub fn pending_count(&self) -> usize {
458 self.pending.len()
459 }
460}
461
462#[derive(Debug, Clone, PartialEq, Eq)]
468pub enum AsyncState {
469 Pending,
471 Running,
473 Done,
475 Failed(String),
477}
478
479#[derive(Debug, Clone)]
481pub struct AsyncResult {
482 pub name: String,
484 pub state: AsyncState,
486 pub output: Vec<u8>,
488}
489
490impl AsyncResult {
491 pub fn is_complete(&self) -> bool {
493 matches!(self.state, AsyncState::Done | AsyncState::Failed(_))
494 }
495}
496
497#[derive(Debug, Default)]
499pub struct AsyncCompute {
500 results: Vec<AsyncResult>,
502}
503
504impl AsyncCompute {
505 pub fn new() -> Self {
507 Self::default()
508 }
509
510 pub fn submit(&mut self, task: &ComputeTask) -> usize {
512 let idx = self.results.len();
513 self.results.push(AsyncResult {
514 name: task.name.clone(),
515 state: AsyncState::Pending,
516 output: vec![],
517 });
518 idx
519 }
520
521 pub fn tick(&mut self) {
526 for r in &mut self.results {
527 match r.state {
528 AsyncState::Pending => r.state = AsyncState::Running,
529 AsyncState::Running => {
530 r.state = AsyncState::Done;
531 r.output = vec![0u8; 4]; }
533 _ => {}
534 }
535 }
536 }
537
538 pub fn poll(&self, idx: usize) -> Option<&AsyncResult> {
540 self.results.get(idx)
541 }
542
543 pub fn drain_completed(&mut self) -> Vec<AsyncResult> {
545 let mut done = Vec::new();
546 let mut remaining = Vec::new();
547 for r in self.results.drain(..) {
548 if r.is_complete() {
549 done.push(r);
550 } else {
551 remaining.push(r);
552 }
553 }
554 self.results = remaining;
555 done
556 }
557}
558
/// A pipeline stage named by the source or destination side of a `PipelineBarrier`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PipelineStage {
    /// Top of the pipeline.
    Top,
    /// Vertex shading.
    Vertex,
    /// Fragment shading.
    Fragment,
    /// Compute shading.
    Compute,
    /// Transfer / copy operations.
    Transfer,
    /// Color-attachment output.
    ColorAttachment,
    /// Shader reads of a resource.
    ShaderRead,
    /// Bottom of the pipeline.
    Bottom,
}

/// An execution dependency from `src_stage` to `dst_stage`, with a debug label.
#[derive(Debug, Clone)]
pub struct PipelineBarrier {
    /// Stage that must finish first.
    pub src_stage: PipelineStage,
    /// Stage that waits.
    pub dst_stage: PipelineStage,
    /// Debug label.
    pub label: String,
    /// Set by `color_attachment_to_shader_read`; `false` from `compute_to_compute`.
    pub color_to_shader_read: bool,
}

impl PipelineBarrier {
    /// Barrier from color-attachment output to shader reads, with the flag set.
    pub fn color_attachment_to_shader_read(label: impl Into<String>) -> Self {
        Self::between(
            PipelineStage::ColorAttachment,
            PipelineStage::ShaderRead,
            label.into(),
            true,
        )
    }

    /// Barrier between two compute stages, with the flag cleared.
    pub fn compute_to_compute(label: impl Into<String>) -> Self {
        Self::between(PipelineStage::Compute, PipelineStage::Compute, label.into(), false)
    }

    /// `true` when a compute stage feeds a `ShaderRead` or `Fragment` stage.
    pub fn is_compute_read_hazard(&self) -> bool {
        matches!(
            (self.src_stage, self.dst_stage),
            (
                PipelineStage::Compute,
                PipelineStage::ShaderRead | PipelineStage::Fragment
            )
        )
    }

    // Field-by-field constructor shared by the named barriers above.
    fn between(
        src_stage: PipelineStage,
        dst_stage: PipelineStage,
        label: String,
        color_to_shader_read: bool,
    ) -> Self {
        Self {
            src_stage,
            dst_stage,
            label,
            color_to_shader_read,
        }
    }
}
627
/// A begin/end pair of nanosecond timestamps for one labelled measurement.
#[derive(Debug, Clone)]
pub struct GpuTimestampQuery {
    /// Label identifying the measured region.
    pub label: String,
    /// Timestamp passed to `begin`, in nanoseconds.
    pub start_ns: u64,
    /// Timestamp passed to `end`, in nanoseconds.
    pub end_ns: u64,
    /// `true` between `begin` and `end`.
    active: bool,
}

impl GpuTimestampQuery {
    /// Creates an inactive query with both timestamps at 0.
    pub fn new(label: impl Into<String>) -> Self {
        Self {
            label: label.into(),
            start_ns: 0,
            end_ns: 0,
            active: false,
        }
    }

    /// Records `now_ns` as the start and marks the query active.
    pub fn begin(&mut self, now_ns: u64) {
        self.active = true;
        self.start_ns = now_ns;
    }

    /// Records `now_ns` as the end and marks the query inactive.
    pub fn end(&mut self, now_ns: u64) {
        self.active = false;
        self.end_ns = now_ns;
    }

    /// `end_ns - start_ns` in microseconds, or 0 if `end_ns` precedes `start_ns`.
    pub fn elapsed_us(&self) -> f64 {
        let span_ns = self.end_ns.saturating_sub(self.start_ns);
        span_ns as f64 / 1_000.0
    }

    /// Elapsed time in milliseconds (derived from `elapsed_us`).
    pub fn elapsed_ms(&self) -> f64 {
        self.elapsed_us() / 1_000.0
    }

    /// `true` after `begin` and before the matching `end`.
    pub fn is_active(&self) -> bool {
        self.active
    }
}

/// An append-only collection of timestamp queries addressed by index.
#[derive(Debug, Default)]
pub struct TimestampPool {
    /// Queries in the order they were begun; a query's index is its handle.
    queries: Vec<GpuTimestampQuery>,
}

impl TimestampPool {
    /// Creates an empty pool.
    pub fn new() -> Self {
        Self::default()
    }

    /// Starts a new query at `now_ns` and returns its index.
    pub fn begin(&mut self, label: impl Into<String>, now_ns: u64) -> usize {
        let slot = self.queries.len();
        let mut query = GpuTimestampQuery::new(label);
        query.begin(now_ns);
        self.queries.push(query);
        slot
    }

    /// Ends query `idx` at `now_ns`; an unknown index is ignored.
    pub fn end(&mut self, idx: usize, now_ns: u64) {
        if let Some(query) = self.queries.get_mut(idx) {
            query.end(now_ns);
        }
    }

    /// Elapsed milliseconds of query `idx`, or 0.0 for an unknown index.
    pub fn elapsed_ms(&self, idx: usize) -> f64 {
        self.queries
            .get(idx)
            .map_or(0.0, GpuTimestampQuery::elapsed_ms)
    }

    /// Sum of elapsed milliseconds over all queries that are not active.
    pub fn total_ms(&self) -> f64 {
        self.queries
            .iter()
            .filter_map(|q| (!q.is_active()).then(|| q.elapsed_ms()))
            .sum()
    }

    /// Discards every query; later indices restart at 0.
    pub fn reset(&mut self) {
        self.queries.clear();
    }
}
732
/// A transient resource tracked by a frame graph, with its pass-index
/// lifetime and its assigned offset.
#[derive(Debug, Clone)]
pub struct FrameResource {
    /// Resource name (also its key in the frame graph).
    pub name: String,
    /// Size of the resource (presumably bytes — confirm with callers); 0 until declared.
    pub size: usize,
    /// Index of the first pass that reads or writes it; `usize::MAX` if declared but unused.
    pub first_use: usize,
    /// Index of the last pass that reads or writes it.
    pub last_use: usize,
    /// Offset assigned by resource aliasing; 0 until assigned.
    pub offset: usize,
}
751
752#[derive(Debug, Clone)]
754pub struct FramePass {
755 pub name: String,
757 pub reads: Vec<String>,
759 pub writes: Vec<String>,
761 pub barriers: Vec<PipelineBarrier>,
763}
764
765impl FramePass {
766 pub fn new(name: impl Into<String>) -> Self {
768 Self {
769 name: name.into(),
770 reads: vec![],
771 writes: vec![],
772 barriers: vec![],
773 }
774 }
775
776 pub fn reads(mut self, res: impl Into<String>) -> Self {
778 self.reads.push(res.into());
779 self
780 }
781
782 pub fn writes(mut self, res: impl Into<String>) -> Self {
784 self.writes.push(res.into());
785 self
786 }
787
788 pub fn barrier(mut self, b: PipelineBarrier) -> Self {
790 self.barriers.push(b);
791 self
792 }
793}
794
795#[derive(Debug, Default)]
797pub struct FrameGraph {
798 passes: Vec<FramePass>,
800 resources: HashMap<String, FrameResource>,
802}
803
804impl FrameGraph {
805 pub fn new() -> Self {
807 Self::default()
808 }
809
810 pub fn add_pass(&mut self, pass: FramePass) {
812 let idx = self.passes.len();
813 for res in pass.reads.iter().chain(pass.writes.iter()) {
815 let e = self.resources.entry(res.clone()).or_insert(FrameResource {
816 name: res.clone(),
817 size: 0,
818 first_use: idx,
819 last_use: idx,
820 offset: 0,
821 });
822 if idx < e.first_use {
824 e.first_use = idx;
825 }
826 if idx > e.last_use {
827 e.last_use = idx;
828 }
829 }
830 self.passes.push(pass);
831 }
832
833 pub fn declare_resource(&mut self, name: impl Into<String>, size: usize) {
835 let name = name.into();
836 let e = self.resources.entry(name.clone()).or_insert(FrameResource {
837 name: name.clone(),
838 size: 0,
839 first_use: usize::MAX,
840 last_use: 0,
841 offset: 0,
842 });
843 e.size = size;
844 }
845
846 pub fn alias_resources(&mut self) {
849 let names: Vec<String> = {
851 let mut v: Vec<String> = self.resources.keys().cloned().collect();
852 v.sort();
853 v
854 };
855
856 let mut allocations: Vec<(usize, usize, usize)> = Vec::new(); let pass_count = self.passes.len();
859
860 for name in &names {
861 if let Some(res) = self.resources.get_mut(name) {
862 if res.first_use > pass_count {
863 continue;
864 }
865 let mut found = None;
867 for (off, end, sz) in &mut allocations {
868 if *end < res.first_use && *sz >= res.size {
869 found = Some(*off);
870 *end = res.last_use;
871 break;
872 }
873 }
874 if let Some(off) = found {
875 res.offset = off;
876 } else {
877 let off: usize = allocations.iter().map(|(o, _, s)| o + s).max().unwrap_or(0);
878 res.offset = off;
879 let (last_use, size) = (res.last_use, res.size);
880 allocations.push((off, last_use, size));
881 }
882 }
883 }
884 }
885
886 pub fn peak_memory(&self) -> usize {
888 self.resources
889 .values()
890 .map(|r| r.offset + r.size)
891 .max()
892 .unwrap_or(0)
893 }
894
895 pub fn pass_count(&self) -> usize {
897 self.passes.len()
898 }
899
900 pub fn barriers_for_pass(&self, idx: usize) -> &[PipelineBarrier] {
902 self.passes
903 .get(idx)
904 .map(|p| p.barriers.as_slice())
905 .unwrap_or(&[])
906 }
907
908 pub fn all_barriers(&self) -> Vec<&PipelineBarrier> {
910 self.passes.iter().flat_map(|p| p.barriers.iter()).collect()
911 }
912
913 pub fn resources_for_pass(&self, idx: usize) -> Vec<&str> {
915 if let Some(pass) = self.passes.get(idx) {
916 pass.reads
917 .iter()
918 .chain(pass.writes.iter())
919 .map(|s| s.as_str())
920 .collect::<HashSet<_>>()
921 .into_iter()
922 .collect()
923 } else {
924 vec![]
925 }
926 }
927}
928
// Unit tests for every public type in this module.
#[cfg(test)]
mod tests {
    use super::*;

    // ---- TaskPriority ----------------------------------------------------

    // The derived `Ord` must rank variants by their explicit discriminants.
    #[test]
    fn test_priority_ordering() {
        assert!(TaskPriority::RealTime > TaskPriority::High);
        assert!(TaskPriority::High > TaskPriority::Normal);
        assert!(TaskPriority::Normal > TaskPriority::Low);
        assert!(TaskPriority::Low > TaskPriority::Background);
    }

    // ---- ComputeTask -----------------------------------------------------

    // 100 workgroups x 64 invocations (the 1-D default workgroup size).
    #[test]
    fn test_compute_task_invocations_1d() {
        let t = ComputeTask::new_1d("particles", 100);
        assert_eq!(t.total_invocations(), 6400);
    }

    // 8x8 workgroups x 8x8 invocations (the 2-D default workgroup size).
    #[test]
    fn test_compute_task_invocations_2d() {
        let t = ComputeTask::new_2d("shadows", 8, 8);
        assert_eq!(t.total_invocations(), 4096);
    }

    // `depends_on` appends to the dependency list.
    #[test]
    fn test_compute_task_depends_on() {
        let t = ComputeTask::new_1d("B", 1).depends_on("A");
        assert!(t.dependencies.contains(&"A".to_owned()));
    }

    // `with_priority` overrides the default `Normal` priority.
    #[test]
    fn test_compute_task_priority() {
        let t = ComputeTask::new_1d("t", 1).with_priority(TaskPriority::High);
        assert_eq!(t.priority, TaskPriority::High);
    }

    // ---- TaskGraph -------------------------------------------------------

    // A linear chain A -> B -> C must come out in chain order.
    #[test]
    fn test_task_graph_topo_sort_simple() {
        let mut g = TaskGraph::new();
        g.add_task(ComputeTask::new_1d("A", 1));
        g.add_task(ComputeTask::new_1d("B", 1).depends_on("A"));
        g.add_task(ComputeTask::new_1d("C", 1).depends_on("B"));
        let order = g.topological_sort().unwrap();
        // Map each name to its position to compare relative order.
        let pos: HashMap<&str, usize> = order
            .iter()
            .enumerate()
            .map(|(i, s)| (s.as_str(), i))
            .collect();
        assert!(pos["A"] < pos["B"]);
        assert!(pos["B"] < pos["C"]);
    }

    // Diamond A -> {B, C} -> D: every task is emitted exactly once.
    #[test]
    fn test_task_graph_topo_sort_diamond() {
        let mut g = TaskGraph::new();
        g.add_task(ComputeTask::new_1d("A", 1));
        g.add_task(ComputeTask::new_1d("B", 1).depends_on("A"));
        g.add_task(ComputeTask::new_1d("C", 1).depends_on("A"));
        g.add_task(ComputeTask::new_1d("D", 1).depends_on("B").depends_on("C"));
        let order = g.topological_sort().unwrap();
        assert_eq!(order.len(), 4);
    }

    // A and B depend on each other, which is a cycle.
    #[test]
    fn test_task_graph_cycle_detection() {
        let mut g = TaskGraph::new();
        g.add_task(ComputeTask::new_1d("A", 1).depends_on("B"));
        g.add_task(ComputeTask::new_1d("B", 1).depends_on("A"));
        assert!(g.has_cycle());
    }

    // C (10 ms) alone outweighs the A -> B chain (1 + 2 = 3 ms), so C is on the path.
    #[test]
    fn test_task_graph_critical_path() {
        let mut g = TaskGraph::new();
        g.add_task(ComputeTask::new_1d("A", 1).with_estimated_ms(1.0));
        g.add_task(
            ComputeTask::new_1d("B", 1)
                .depends_on("A")
                .with_estimated_ms(2.0),
        );
        g.add_task(ComputeTask::new_1d("C", 1).with_estimated_ms(10.0));
        let cp = g.critical_path();
        assert!(cp.contains(&"C".to_owned()));
    }

    // An empty graph sorts to an empty order rather than an error.
    #[test]
    fn test_task_graph_empty_topo() {
        let g = TaskGraph::new();
        let order = g.topological_sort().unwrap();
        assert!(order.is_empty());
    }

    // ---- TaskScheduler ---------------------------------------------------

    // `schedule` returns every task.
    #[test]
    fn test_scheduler_schedule() {
        let mut g = TaskGraph::new();
        g.add_task(ComputeTask::new_1d("X", 1));
        g.add_task(ComputeTask::new_1d("Y", 1).depends_on("X"));
        let sched = TaskScheduler::new();
        let order = sched.schedule(&g).unwrap();
        assert_eq!(order.len(), 2);
    }

    // A and B are independent and share the first batch; C waits for both.
    #[test]
    fn test_scheduler_batch_schedule() {
        let mut g = TaskGraph::new();
        g.add_task(ComputeTask::new_1d("A", 1));
        g.add_task(ComputeTask::new_1d("B", 1));
        g.add_task(ComputeTask::new_1d("C", 1).depends_on("A").depends_on("B"));
        let sched = TaskScheduler::new();
        let batches = sched.batch_schedule(&g).unwrap();
        assert!(batches[0].len() >= 2);
        assert!(batches.len() >= 2);
    }

    // ---- ResourceBarrier -------------------------------------------------

    // `raw` builds a read-after-write barrier on the named resource.
    #[test]
    fn test_resource_barrier_raw() {
        let b = ResourceBarrier::raw("write_task", "read_task", "position_buffer");
        assert_eq!(b.barrier_type, BarrierType::ReadAfterWrite);
        assert_eq!(b.resource, "position_buffer");
    }

    // `war` builds a write-after-read barrier.
    #[test]
    fn test_resource_barrier_war() {
        let b = ResourceBarrier::war("reader", "writer", "depth");
        assert_eq!(b.barrier_type, BarrierType::WriteAfterRead);
    }

    // ---- WorkloadBalancer ------------------------------------------------

    // The bound allows for the first task being admitted even if it alone
    // exceeds the budget (the largest task here costs 6 ms).
    #[test]
    fn test_workload_balancer_respects_budget() {
        let mut wb = WorkloadBalancer::new(10.0);
        wb.submit(ComputeTask::new_1d("A", 1).with_estimated_ms(3.0));
        wb.submit(ComputeTask::new_1d("B", 1).with_estimated_ms(4.0));
        wb.submit(ComputeTask::new_1d("C", 1).with_estimated_ms(6.0));
        let frame = wb.extract_frame_work();
        let total: f64 = frame.iter().map(|t| t.estimated_ms).sum();
        assert!(total <= 10.0 + 6.0);
    }

    // RealTime outranks Low even though Low was submitted first.
    #[test]
    fn test_workload_balancer_priority_order() {
        let mut wb = WorkloadBalancer::new(5.0);
        wb.submit(
            ComputeTask::new_1d("low", 1)
                .with_priority(TaskPriority::Low)
                .with_estimated_ms(2.0),
        );
        wb.submit(
            ComputeTask::new_1d("rt", 1)
                .with_priority(TaskPriority::RealTime)
                .with_estimated_ms(2.0),
        );
        let frame = wb.extract_frame_work();
        assert_eq!(frame[0].name, "rt");
    }

    // A 1 ms budget cannot fit five 1 ms tasks, yet at least one is taken.
    #[test]
    fn test_workload_balancer_pending_count() {
        let mut wb = WorkloadBalancer::new(1.0);
        for i in 0..5 {
            wb.submit(ComputeTask::new_1d(format!("t{i}"), 1).with_estimated_ms(1.0));
        }
        wb.extract_frame_work();
        assert!(wb.pending_count() < 5);
    }

    // ---- AsyncCompute ----------------------------------------------------

    // A freshly submitted task is Pending.
    #[test]
    fn test_async_compute_submit_poll() {
        let mut ac = AsyncCompute::new();
        let task = ComputeTask::new_1d("sim", 64);
        let idx = ac.submit(&task);
        let r = ac.poll(idx).unwrap();
        assert_eq!(r.state, AsyncState::Pending);
    }

    // Two ticks: Pending -> Running -> Done.
    #[test]
    fn test_async_compute_tick_to_done() {
        let mut ac = AsyncCompute::new();
        let task = ComputeTask::new_1d("sim", 1);
        let idx = ac.submit(&task);
        ac.tick();
        ac.tick();
        assert_eq!(ac.poll(idx).unwrap().state, AsyncState::Done);
    }

    // Draining removes the completed result, so its handle no longer polls.
    #[test]
    fn test_async_compute_drain_completed() {
        let mut ac = AsyncCompute::new();
        let t = ComputeTask::new_1d("t", 1);
        ac.submit(&t);
        ac.tick();
        ac.tick();
        let done = ac.drain_completed();
        assert_eq!(done.len(), 1);
        assert!(ac.poll(0).is_none());
    }

    // ---- PipelineBarrier -------------------------------------------------

    // The named constructor sets both stages and the flag.
    #[test]
    fn test_pipeline_barrier_color_to_shader_read() {
        let b = PipelineBarrier::color_attachment_to_shader_read("gbuffer");
        assert!(b.color_to_shader_read);
        assert_eq!(b.src_stage, PipelineStage::ColorAttachment);
        assert_eq!(b.dst_stage, PipelineStage::ShaderRead);
    }

    // Compute -> Compute is not a compute-read hazard.
    #[test]
    fn test_pipeline_barrier_compute_to_compute() {
        let b = PipelineBarrier::compute_to_compute("particles");
        assert_eq!(b.src_stage, PipelineStage::Compute);
        assert!(!b.is_compute_read_hazard());
    }

    // Compute -> ShaderRead is a compute-read hazard.
    #[test]
    fn test_pipeline_barrier_compute_read_hazard() {
        let b = PipelineBarrier {
            src_stage: PipelineStage::Compute,
            dst_stage: PipelineStage::ShaderRead,
            label: "test".to_owned(),
            color_to_shader_read: false,
        };
        assert!(b.is_compute_read_hazard());
    }

    // ---- Timestamps ------------------------------------------------------

    // 1,000,000 ns apart is exactly 1 ms.
    #[test]
    fn test_timestamp_query_elapsed() {
        let mut q = GpuTimestampQuery::new("render");
        q.begin(1_000_000);
        q.end(2_000_000);
        assert!((q.elapsed_ms() - 1.0).abs() < 1e-6);
    }

    // Active only between `begin` and `end`.
    #[test]
    fn test_timestamp_query_is_active() {
        let mut q = GpuTimestampQuery::new("x");
        assert!(!q.is_active());
        q.begin(0);
        assert!(q.is_active());
        q.end(100);
        assert!(!q.is_active());
    }

    // Two finished queries of 1 ms and 2 ms sum to 3 ms.
    #[test]
    fn test_timestamp_pool_total() {
        let mut pool = TimestampPool::new();
        let i0 = pool.begin("a", 0);
        pool.end(i0, 1_000_000);
        let i1 = pool.begin("b", 0);
        pool.end(i1, 2_000_000);
        let total = pool.total_ms();
        assert!((total - 3.0).abs() < 1e-6, "total={total}");
    }

    // After reset the pool holds no queries, so the total is zero.
    #[test]
    fn test_timestamp_pool_reset() {
        let mut pool = TimestampPool::new();
        pool.begin("x", 0);
        pool.reset();
        assert!((pool.total_ms()).abs() < 1e-10);
    }

    // ---- FrameGraph ------------------------------------------------------

    // Passes are counted as they are added.
    #[test]
    fn test_frame_graph_add_pass() {
        let mut fg = FrameGraph::new();
        fg.add_pass(FramePass::new("gbuffer").writes("color").writes("depth"));
        fg.add_pass(
            FramePass::new("lighting")
                .reads("color")
                .reads("depth")
                .writes("hdr"),
        );
        assert_eq!(fg.pass_count(), 2);
    }

    // A resource declared before any pass picks up its first and last use
    // from the passes that reference it.
    #[test]
    fn test_frame_graph_resource_lifetime() {
        let mut fg = FrameGraph::new();
        fg.declare_resource("color", 1024 * 1024 * 4);
        fg.add_pass(FramePass::new("p0").writes("color"));
        fg.add_pass(FramePass::new("p1").reads("color"));
        let res = &fg.resources["color"];
        assert_eq!(res.first_use, 0);
        assert_eq!(res.last_use, 1);
    }

    // A (passes 0-1) and B (pass 2) have disjoint lifetimes; aliasing must
    // still report a non-zero peak.
    #[test]
    fn test_frame_graph_aliasing() {
        let mut fg = FrameGraph::new();
        fg.declare_resource("A", 1024);
        fg.declare_resource("B", 1024);
        fg.add_pass(FramePass::new("p0").writes("A"));
        fg.add_pass(FramePass::new("p1").reads("A"));
        fg.add_pass(FramePass::new("p2").writes("B"));
        fg.alias_resources();
        let peak = fg.peak_memory();
        assert!(peak > 0);
    }

    // Barriers attached to a pass are returned for that pass index.
    #[test]
    fn test_frame_graph_barriers() {
        let mut fg = FrameGraph::new();
        fg.add_pass(
            FramePass::new("render")
                .barrier(PipelineBarrier::color_attachment_to_shader_read("test")),
        );
        let barriers = fg.barriers_for_pass(0);
        assert_eq!(barriers.len(), 1);
    }

    // `all_barriers` flattens barriers across passes.
    #[test]
    fn test_frame_graph_all_barriers() {
        let mut fg = FrameGraph::new();
        fg.add_pass(FramePass::new("p0").barrier(PipelineBarrier::compute_to_compute("c0")));
        fg.add_pass(FramePass::new("p1").barrier(PipelineBarrier::compute_to_compute("c1")));
        assert_eq!(fg.all_barriers().len(), 2);
    }
}