1use std::collections::HashMap;
4use std::time::{Duration, Instant};
5
6use super::backend::{
7 BackendCapabilities, BackendContext, BufferHandle, BufferUsage, ComputePipelineHandle,
8 GpuBackend, GpuCommand, PipelineLayout, ShaderHandle, ShaderStage, SoftwareContext,
9};
10
/// Access pattern a shader declares for a bound resource.
///
/// Mirrors typical GPU binding qualifiers (read-only storage,
/// write-only storage, read-write storage).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AccessMode {
    /// Resource is only read by the shader.
    ReadOnly,
    /// Resource is only written by the shader.
    WriteOnly,
    /// Resource may be both read and written.
    ReadWrite,
}
22
/// A single concrete resource binding: the slot it occupies, the
/// resource attached, and the declared shader access.
#[derive(Debug, Clone)]
pub struct BindGroupEntry {
    /// Binding slot index within the bind group.
    pub binding: u32,
    /// The buffer or texture bound at this slot.
    pub buffer_or_texture: BindingResource,
    /// Declared shader access for this resource.
    pub access: AccessMode,
}
34
/// The kinds of GPU resources that can be attached to a binding slot.
#[derive(Debug, Clone)]
pub enum BindingResource {
    /// A buffer resource.
    Buffer(BufferHandle),
    /// A texture resource.
    Texture(super::backend::TextureHandle),
}
41
/// Describes the set of binding slots a pipeline expects, without the
/// actual resources attached.
#[derive(Debug, Clone)]
pub struct BindGroupLayout {
    /// Layout entries, one per binding slot.
    pub entries: Vec<BindGroupLayoutEntry>,
}
47
48impl BindGroupLayout {
49 pub fn new() -> Self { Self { entries: Vec::new() } }
50
51 pub fn push(mut self, binding: u32, access: AccessMode) -> Self {
52 self.entries.push(BindGroupLayoutEntry { binding, access });
53 self
54 }
55}
56
// The default layout is empty, matching `BindGroupLayout::new`.
impl Default for BindGroupLayout {
    fn default() -> Self { Self::new() }
}
60
/// One slot in a [`BindGroupLayout`]: binding index plus access mode.
#[derive(Debug, Clone)]
pub struct BindGroupLayoutEntry {
    /// Binding slot index.
    pub binding: u32,
    /// Declared shader access for whatever gets bound here.
    pub access: AccessMode,
}
67
/// A created compute pipeline: the shader, its expected bindings, the
/// workgroup dimensions it was created with, and the backend handle.
#[derive(Debug, Clone)]
pub struct ComputePipeline {
    /// Handle to the compiled compute shader.
    pub shader: ShaderHandle,
    /// Binding slots this pipeline expects.
    pub bind_group_layout: BindGroupLayout,
    /// Workgroup (local) size in x/y/z invocations.
    pub workgroup_size: [u32; 3],
    /// Backend handle identifying the created pipeline.
    pub handle: ComputePipelineHandle,
}
80
/// A GPU-side buffer plus the bookkeeping needed to interpret it as a
/// typed array: total byte size and per-element size.
#[derive(Debug, Clone)]
pub struct ComputeBuffer {
    /// Backend handle for the underlying buffer.
    pub handle: BufferHandle,
    /// Total size in bytes.
    pub size: usize,
    /// Size in bytes of one element.
    pub element_size: usize,
}
92
93impl ComputeBuffer {
94 pub fn element_count(&self) -> usize {
96 if self.element_size == 0 { 0 } else { self.size / self.element_size }
97 }
98}
99
/// Rolling log of compute dispatch timings.
///
/// Keeps at most `max_records` entries; once full, the oldest record is
/// evicted when a new one is pushed.
pub struct ComputeProfiler {
    // Recorded dispatches, oldest first.
    records: Vec<DispatchRecord>,
    // Upper bound on `records.len()`; 0 disables recording entirely.
    max_records: usize,
}

/// Timing information for a single recorded dispatch.
#[derive(Debug, Clone)]
pub struct DispatchRecord {
    /// Caller-supplied label for the dispatch.
    pub label: String,
    /// Workgroup counts in x/y/z.
    pub workgroups: [u32; 3],
    /// Measured wall-clock duration.
    pub duration: Duration,
}

impl ComputeProfiler {
    /// Creates a profiler retaining at most `max_records` entries.
    ///
    /// The initial allocation is capped at 4096 entries so very large
    /// limits do not reserve huge buffers up front.
    pub fn new(max_records: usize) -> Self {
        Self {
            records: Vec::with_capacity(max_records.min(4096)),
            max_records,
        }
    }

    /// Appends a dispatch record, evicting the oldest entry when full.
    ///
    /// Fix: with `max_records == 0` the old code called `remove(0)` on an
    /// empty `Vec` and panicked; a zero limit now simply disables recording.
    pub fn record(&mut self, label: &str, workgroups: [u32; 3], duration: Duration) {
        if self.max_records == 0 {
            return;
        }
        // O(n) front removal is acceptable for the small, bounded windows
        // this profiler is constructed with.
        while self.records.len() >= self.max_records {
            self.records.remove(0);
        }
        self.records.push(DispatchRecord {
            label: label.to_string(),
            workgroups,
            duration,
        });
    }

    /// Mean duration over all retained records; `Duration::ZERO` when empty.
    pub fn average_duration(&self) -> Duration {
        if self.records.is_empty() {
            return Duration::ZERO;
        }
        let total: Duration = self.records.iter().map(|r| r.duration).sum();
        // Retained count is bounded by `max_records`; a window exceeding
        // u32::MAX records is not a realistic configuration.
        total / self.records.len() as u32
    }

    /// Number of records currently retained (not total ever recorded).
    pub fn total_dispatches(&self) -> usize {
        self.records.len()
    }

    /// Drops all retained records.
    pub fn clear(&mut self) {
        self.records.clear();
    }

    /// Most recently recorded dispatch, if any.
    pub fn last(&self) -> Option<&DispatchRecord> {
        self.records.last()
    }

    /// All retained records, oldest first.
    pub fn records(&self) -> &[DispatchRecord] {
        &self.records
    }
}
160
/// CPU emulation of a GPU compute kernel: a fixed workgroup (local)
/// size that can be dispatched over a grid of workgroups.
pub struct CpuKernel {
    /// Local size in x/y/z invocations per workgroup.
    pub workgroup_size: [u32; 3],
}

impl CpuKernel {
    /// Creates a kernel with the given workgroup (local) size.
    pub fn new(workgroup_size: [u32; 3]) -> Self {
        Self { workgroup_size }
    }

    /// Invokes `f` once per global invocation id `(x, y, z)` across the
    /// given grid of workgroups.
    ///
    /// Iteration order matches a z-major walk of workgroups (z, then y,
    /// then x) with the same ordering for invocations inside each group.
    pub fn dispatch<F>(&self, groups: [u32; 3], mut f: F)
    where
        F: FnMut(u32, u32, u32),
    {
        let [local_x, local_y, local_z] = self.workgroup_size;
        let [groups_x, groups_y, _groups_z] = groups;
        let plane = (groups_x as u64) * (groups_y as u64);
        let group_total = plane * (groups[2] as u64);
        for flat in 0..group_total {
            // Decompose the flat group index back into (gx, gy, gz).
            // `plane` is non-zero whenever this loop body runs.
            let gz = (flat / plane) as u32;
            let rem = flat % plane;
            let gy = (rem / groups_x as u64) as u32;
            let gx = (rem % groups_x as u64) as u32;
            for lz in 0..local_z {
                for ly in 0..local_y {
                    for lx in 0..local_x {
                        f(gx * local_x + lx, gy * local_y + ly, gz * local_z + lz);
                    }
                }
            }
        }
    }

    /// Total number of `f` invocations a dispatch over `groups` performs:
    /// the product of all six dimensions, widened to u64.
    pub fn total_invocations(&self, groups: [u32; 3]) -> u64 {
        self.workgroup_size
            .iter()
            .chain(groups.iter())
            .map(|&d| d as u64)
            .product()
    }
}
210
/// High-level compute API over an abstract GPU backend.
///
/// Owns the backend, caches created pipelines, and profiles dispatches.
pub struct ComputeContext {
    /// Which backend implementation is in use.
    pub backend_type: GpuBackend,
    // Boxed backend all GPU work is delegated to.
    backend: Box<dyn BackendContext>,
    // Capabilities derived from `backend_type` at construction.
    capabilities: BackendCapabilities,
    // Rolling dispatch-timing log.
    profiler: ComputeProfiler,
    // Pipelines created through this context, keyed by handle id.
    pipelines: HashMap<u64, ComputePipeline>,
}
223
224impl ComputeContext {
225 pub fn new(backend: Box<dyn BackendContext>, backend_type: GpuBackend) -> Self {
226 let capabilities = BackendCapabilities::for_backend(backend_type);
227 Self {
228 backend_type,
229 backend,
230 capabilities,
231 profiler: ComputeProfiler::new(1024),
232 pipelines: HashMap::new(),
233 }
234 }
235
236 pub fn software() -> Self {
238 Self::new(Box::new(SoftwareContext::new()), GpuBackend::Software)
239 }
240
241 pub fn create_storage_buffer<T: Copy>(&mut self, data: &[T]) -> ComputeBuffer {
243 let element_size = std::mem::size_of::<T>();
244 let byte_size = element_size * data.len();
245 let handle = self.backend.create_buffer(byte_size, BufferUsage::STORAGE);
246
247 let byte_slice = unsafe {
249 std::slice::from_raw_parts(data.as_ptr() as *const u8, byte_size)
250 };
251 self.backend.write_buffer(handle, byte_slice);
252
253 ComputeBuffer {
254 handle,
255 size: byte_size,
256 element_size,
257 }
258 }
259
260 pub fn create_empty_buffer<T>(&mut self, count: usize) -> ComputeBuffer {
262 let element_size = std::mem::size_of::<T>();
263 let byte_size = element_size * count;
264 let handle = self.backend.create_buffer(byte_size, BufferUsage::STORAGE);
265 ComputeBuffer {
266 handle,
267 size: byte_size,
268 element_size,
269 }
270 }
271
272 pub fn create_pipeline(
274 &mut self,
275 source: &str,
276 layout: BindGroupLayout,
277 workgroup_size: [u32; 3],
278 ) -> ComputePipeline {
279 let shader = self.backend.create_shader(source, ShaderStage::Compute);
280 let pl = PipelineLayout::default();
281 let handle = self.backend.create_compute_pipeline(shader, &pl);
282 let pipeline = ComputePipeline {
283 shader,
284 bind_group_layout: layout,
285 workgroup_size,
286 handle,
287 };
288 self.pipelines.insert(handle.0, pipeline.clone());
289 pipeline
290 }
291
292 pub fn dispatch(&mut self, pipeline: &ComputePipeline, x: u32, y: u32, z: u32) {
294 let start = Instant::now();
295
296 if self.backend_type == GpuBackend::Software {
297 }
300
301 self.backend.submit(&[GpuCommand::Dispatch {
302 pipeline: pipeline.handle,
303 x,
304 y,
305 z,
306 }]);
307
308 let elapsed = start.elapsed();
309 self.profiler.record("dispatch", [x, y, z], elapsed);
310 }
311
312 pub fn indirect_dispatch(&mut self, pipeline: &ComputePipeline, args_buffer: &ComputeBuffer) {
314 let data = self.backend.read_buffer(args_buffer.handle);
316 let mut groups = [1u32, 1, 1];
317 if data.len() >= 12 {
318 for i in 0..3 {
319 let bytes = [data[i * 4], data[i * 4 + 1], data[i * 4 + 2], data[i * 4 + 3]];
320 groups[i] = u32::from_le_bytes(bytes);
321 }
322 }
323 self.dispatch(pipeline, groups[0], groups[1], groups[2]);
324 }
325
326 pub fn memory_barrier(&mut self) {
328 self.backend.submit(&[GpuCommand::Barrier]);
329 }
330
331 pub fn read_back<T: Copy + Default>(&self, buffer: &ComputeBuffer) -> Vec<T> {
333 let data = self.backend.read_buffer(buffer.handle);
334 let elem_size = std::mem::size_of::<T>();
335 if elem_size == 0 {
336 return Vec::new();
337 }
338 let count = data.len() / elem_size;
339 let mut result = vec![T::default(); count];
340 unsafe {
341 let dst = std::slice::from_raw_parts_mut(
342 result.as_mut_ptr() as *mut u8,
343 count * elem_size,
344 );
345 dst.copy_from_slice(&data[..count * elem_size]);
346 }
347 result
348 }
349
350 pub fn write_buffer<T: Copy>(&mut self, buffer: &ComputeBuffer, data: &[T]) {
352 let byte_size = std::mem::size_of::<T>() * data.len();
353 let byte_slice = unsafe {
354 std::slice::from_raw_parts(data.as_ptr() as *const u8, byte_size)
355 };
356 self.backend.write_buffer(buffer.handle, byte_slice);
357 }
358
359 pub fn profiler(&self) -> &ComputeProfiler {
361 &self.profiler
362 }
363
364 pub fn profiler_mut(&mut self) -> &mut ComputeProfiler {
366 &mut self.profiler
367 }
368
369 pub fn supports_compute(&self) -> bool {
371 self.capabilities.compute_shaders
372 }
373
374 pub fn destroy_buffer(&mut self, buffer: &ComputeBuffer) {
376 self.backend.destroy_buffer(buffer.handle);
377 }
378}
379
/// Runs a CPU compute dispatch across up to `num_threads` OS threads.
///
/// Each invocation receives `(thread_index, x, y, z)` where `(x, y, z)` is
/// the global invocation id. Workgroups are partitioned contiguously across
/// threads. When one (effective) thread suffices, everything runs inline on
/// the caller's thread with thread index 0.
pub fn cpu_parallel_dispatch<F>(
    workgroup_size: [u32; 3],
    groups: [u32; 3],
    num_threads: usize,
    f: F,
) where
    F: Fn(usize, u32, u32, u32) + Send + Sync,
{
    let [local_x, local_y, local_z] = workgroup_size;
    let [groups_x, groups_y, groups_z] = groups;
    let plane = (groups_x as usize) * (groups_y as usize);
    let total_groups = plane * (groups_z as usize);
    // Never spawn more threads than there are workgroups.
    let worker_count = num_threads.max(1).min(total_groups);

    // Executes every invocation of the workgroups whose flat indices lie
    // in [lo, hi), decomposing each flat index into its 3-D group id.
    // `plane` is non-zero whenever the loop body runs.
    let run_span = |tid: usize, lo: usize, hi: usize| {
        for flat in lo..hi {
            let gz = (flat / plane) as u32;
            let rem = flat % plane;
            let gy = (rem / groups_x as usize) as u32;
            let gx = (rem % groups_x as usize) as u32;
            for lz in 0..local_z {
                for ly in 0..local_y {
                    for lx in 0..local_x {
                        f(tid, gx * local_x + lx, gy * local_y + ly, gz * local_z + lz);
                    }
                }
            }
        }
    };

    if worker_count <= 1 {
        // Inline path: same iteration order as the threaded path, tid 0.
        run_span(0, 0, total_groups);
        return;
    }

    // Ceiling division so every workgroup lands in exactly one span.
    let span = (total_groups + worker_count - 1) / worker_count;
    let run_span_ref = &run_span;
    std::thread::scope(|scope| {
        for tid in 0..worker_count {
            let lo = tid * span;
            let hi = ((tid + 1) * span).min(total_groups);
            scope.spawn(move || run_span_ref(tid, lo, hi));
        }
    });
}
435
#[cfg(test)]
mod tests {
    use super::*;

    // All tests use the software backend, so they run without a GPU.

    #[test]
    fn compute_buffer_element_count() {
        // 40 bytes / 4-byte elements = 10 elements.
        let buf = ComputeBuffer {
            handle: BufferHandle(1),
            size: 40,
            element_size: 4,
        };
        assert_eq!(buf.element_count(), 10);
    }

    #[test]
    fn create_storage_buffer_f32() {
        let mut ctx = ComputeContext::software();
        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        let buf = ctx.create_storage_buffer(&data);
        // Four f32s = 16 bytes.
        assert_eq!(buf.size, 16);
        assert_eq!(buf.element_size, 4);
        assert_eq!(buf.element_count(), 4);

        // Upload then read back must be lossless.
        let readback: Vec<f32> = ctx.read_back(&buf);
        assert_eq!(readback, vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn create_storage_buffer_u32() {
        // Same round-trip as above but with an integer element type.
        let mut ctx = ComputeContext::software();
        let data: Vec<u32> = vec![10, 20, 30];
        let buf = ctx.create_storage_buffer(&data);
        let readback: Vec<u32> = ctx.read_back(&buf);
        assert_eq!(readback, vec![10, 20, 30]);
    }

    #[test]
    fn create_empty_buffer() {
        let mut ctx = ComputeContext::software();
        let buf = ctx.create_empty_buffer::<f32>(8);
        assert_eq!(buf.size, 32);
        assert_eq!(buf.element_count(), 8);
    }

    #[test]
    fn dispatch_pipeline() {
        let mut ctx = ComputeContext::software();
        let layout = BindGroupLayout::new().push(0, AccessMode::ReadWrite);
        let pipeline = ctx.create_pipeline("void main(){}", layout, [64, 1, 1]);
        ctx.dispatch(&pipeline, 4, 1, 1);
        // Each dispatch leaves exactly one profiler record.
        assert_eq!(ctx.profiler().total_dispatches(), 1);
    }

    #[test]
    fn indirect_dispatch() {
        let mut ctx = ComputeContext::software();
        let layout = BindGroupLayout::new();
        let pipeline = ctx.create_pipeline("void main(){}", layout, [1, 1, 1]);

        // Workgroup counts [2, 1, 1] stored as little-endian u32s.
        let args: Vec<u32> = vec![2, 1, 1];
        let args_buf = ctx.create_storage_buffer(&args);
        ctx.indirect_dispatch(&pipeline, &args_buf);
        assert_eq!(ctx.profiler().total_dispatches(), 1);
    }

    #[test]
    fn memory_barrier() {
        // Only checks the barrier submits without panicking.
        let mut ctx = ComputeContext::software();
        ctx.memory_barrier();
    }

    #[test]
    fn write_and_read_back() {
        let mut ctx = ComputeContext::software();
        let buf = ctx.create_empty_buffer::<u32>(4);
        ctx.write_buffer(&buf, &[100u32, 200, 300, 400]);
        let result: Vec<u32> = ctx.read_back(&buf);
        assert_eq!(result, vec![100, 200, 300, 400]);
    }

    #[test]
    fn profiler_average() {
        let mut profiler = ComputeProfiler::new(10);
        profiler.record("a", [1, 1, 1], Duration::from_millis(10));
        profiler.record("b", [1, 1, 1], Duration::from_millis(20));
        assert_eq!(profiler.total_dispatches(), 2);
        // Mean of 10ms and 20ms.
        let avg = profiler.average_duration();
        assert_eq!(avg, Duration::from_millis(15));
    }

    #[test]
    fn profiler_rolling() {
        // Capacity 3, five records pushed: the oldest two are evicted.
        let mut profiler = ComputeProfiler::new(3);
        for i in 0..5 {
            profiler.record(&format!("d{}", i), [1, 1, 1], Duration::from_millis(i as u64));
        }
        assert_eq!(profiler.total_dispatches(), 3);
        assert_eq!(profiler.last().unwrap().label, "d4");
    }

    #[test]
    fn profiler_clear() {
        let mut profiler = ComputeProfiler::new(10);
        profiler.record("x", [1, 1, 1], Duration::from_millis(5));
        profiler.clear();
        assert_eq!(profiler.total_dispatches(), 0);
        assert_eq!(profiler.average_duration(), Duration::ZERO);
    }

    #[test]
    fn cpu_kernel_dispatch() {
        // 2x2x1 local size over 2x1x1 groups => 8 invocations covering
        // global x in 0..4, y in 0..2.
        let kernel = CpuKernel::new([2, 2, 1]);
        let mut invocations = Vec::new();
        kernel.dispatch([2, 1, 1], |x, y, z| {
            invocations.push((x, y, z));
        });
        assert_eq!(invocations.len(), 8);
        assert!(invocations.contains(&(0, 0, 0)));
        assert!(invocations.contains(&(3, 1, 0)));
    }

    #[test]
    fn cpu_kernel_total_invocations() {
        let kernel = CpuKernel::new([8, 8, 1]);
        assert_eq!(kernel.total_invocations([4, 4, 1]), 8 * 8 * 4 * 4);
    }

    #[test]
    fn cpu_parallel_dispatch_runs() {
        use std::sync::atomic::{AtomicU32, Ordering};
        // 2x1x1 local size over 4x1x1 groups on two threads => 8 calls.
        let counter = AtomicU32::new(0);
        cpu_parallel_dispatch([2, 1, 1], [4, 1, 1], 2, |_tid, _x, _y, _z| {
            counter.fetch_add(1, Ordering::Relaxed);
        });
        assert_eq!(counter.load(Ordering::Relaxed), 8);
    }

    #[test]
    fn cpu_parallel_dispatch_single_thread() {
        use std::sync::atomic::{AtomicU32, Ordering};
        let counter = AtomicU32::new(0);
        cpu_parallel_dispatch([1, 1, 1], [3, 2, 1], 1, |_tid, _x, _y, _z| {
            counter.fetch_add(1, Ordering::Relaxed);
        });
        assert_eq!(counter.load(Ordering::Relaxed), 6);
    }

    #[test]
    fn bind_group_layout_builder() {
        let layout = BindGroupLayout::new()
            .push(0, AccessMode::ReadOnly)
            .push(1, AccessMode::WriteOnly)
            .push(2, AccessMode::ReadWrite);
        assert_eq!(layout.entries.len(), 3);
        assert_eq!(layout.entries[1].access, AccessMode::WriteOnly);
    }

    #[test]
    fn supports_compute_software() {
        let ctx = ComputeContext::software();
        assert!(ctx.supports_compute());
    }

    #[test]
    fn destroy_buffer() {
        let mut ctx = ComputeContext::software();
        let buf = ctx.create_storage_buffer(&[1u32, 2, 3]);
        ctx.destroy_buffer(&buf);
        // Relies on the software backend returning an empty read for a
        // destroyed buffer handle.
        let readback: Vec<u32> = ctx.read_back(&buf);
        assert!(readback.is_empty());
    }
}