1pub mod simd;
35
36#[cfg(feature = "wasm")]
37pub mod wasm;
38
39use bytemuck::{Pod, Zeroable};
40use std::borrow::Cow;
41use std::sync::Arc;
42use thiserror::Error;
43use wgpu::util::DeviceExt;
44
45use cliffy_core::GA3;
46pub use simd::{addition_simd, geometric_product_simd, sandwich_simd, SimdBatch};
47
48#[cfg(feature = "wasm")]
49pub use wasm::*;
50
/// Errors produced by the GPU compute path.
#[derive(Error, Debug)]
pub enum GpuError {
    /// No suitable wgpu adapter was found on this system.
    #[error("Failed to request GPU adapter")]
    AdapterNotFound,

    /// An adapter was found but the device request was rejected.
    #[error("Failed to request GPU device: {0}")]
    DeviceRequestFailed(#[from] wgpu::RequestDeviceError),

    /// Two input batches that must be the same length were not.
    #[error("Buffer size mismatch: expected {expected}, got {actual}")]
    BufferSizeMismatch { expected: usize, actual: usize },

    /// A dispatch or buffer readback failed; payload is a human-readable reason.
    #[error("GPU computation failed: {0}")]
    ComputeFailed(String),

    /// WebGPU support is unavailable in this environment.
    #[error("WebGPU not available")]
    WebGpuNotAvailable,
}
69
/// GPU-side multivector for 3D geometric algebra.
///
/// `coeffs` holds the 8 basis-blade coefficients; the vector part lives at
/// indices 1, 2 and 4 (see `vector`/`get_vector` below), consistent with
/// binary blade indexing — presumably [1, e1, e2, e12, e3, e13, e23, e123].
/// `#[repr(C)]` + `Pod`/`Zeroable` allow direct byte-casts into wgpu storage
/// buffers; `f32` matches the shader-side float width.
#[repr(C)]
#[derive(Clone, Copy, Debug, Pod, Zeroable, Default)]
pub struct GpuMultivector {
    pub coeffs: [f32; 8],
}
86
87impl GpuMultivector {
88 pub fn zero() -> Self {
90 Self { coeffs: [0.0; 8] }
91 }
92
93 pub fn scalar(s: f32) -> Self {
95 let mut mv = Self::zero();
96 mv.coeffs[0] = s;
97 mv
98 }
99
100 pub fn vector(x: f32, y: f32, z: f32) -> Self {
102 let mut mv = Self::zero();
103 mv.coeffs[1] = x;
104 mv.coeffs[2] = y;
105 mv.coeffs[4] = z;
106 mv
107 }
108
109 pub fn get_scalar(&self) -> f32 {
111 self.coeffs[0]
112 }
113
114 pub fn get_vector(&self) -> (f32, f32, f32) {
116 (self.coeffs[1], self.coeffs[2], self.coeffs[4])
117 }
118}
119
120impl From<&GA3> for GpuMultivector {
121 fn from(mv: &GA3) -> Self {
122 let mut coeffs = [0.0f32; 8];
123 let slice = mv.as_slice();
126 for (i, &c) in slice.iter().enumerate() {
127 if i < 8 {
128 coeffs[i] = c as f32;
129 }
130 }
131 Self { coeffs }
132 }
133}
134
135impl From<GpuMultivector> for GA3 {
136 fn from(gpu_mv: GpuMultivector) -> Self {
137 let coeffs: Vec<f64> = gpu_mv.coeffs.iter().map(|&c| c as f64).collect();
138 GA3::from_slice(&coeffs)
139 }
140}
141
/// Default minimum batch size for routing work to the GPU — presumably
/// chosen so buffer upload/readback overhead is amortized; tune per hardware.
pub const GPU_DISPATCH_THRESHOLD: usize = 256;
145
/// Owns the wgpu device/queue plus one precompiled compute pipeline per
/// geometric-algebra kernel. All pipelines share a single bind group layout
/// of two read-only input buffers and one writable output buffer.
pub struct GpuContext {
    device: Arc<wgpu::Device>,
    queue: Arc<wgpu::Queue>,
    geometric_product_pipeline: wgpu::ComputePipeline,
    addition_pipeline: wgpu::ComputePipeline,
    sandwich_pipeline: wgpu::ComputePipeline,
    exp_pipeline: wgpu::ComputePipeline,
    rotor_slerp_pipeline: wgpu::ComputePipeline,
    bind_group_layout: wgpu::BindGroupLayout,
}
160
impl GpuContext {
    /// Initialize the GPU context: request an adapter and device through
    /// wgpu, compile the shared WGSL shader, and build one compute pipeline
    /// per kernel entry point.
    ///
    /// # Errors
    /// Returns [`GpuError::AdapterNotFound`] if no adapter matches the
    /// request, or [`GpuError::DeviceRequestFailed`] if the device request
    /// is rejected.
    pub async fn new() -> Result<Self, GpuError> {
        let instance = wgpu::Instance::new(&wgpu::InstanceDescriptor {
            backends: wgpu::Backends::all(),
            ..Default::default()
        });

        // Compute-only context: no surface needed, prefer the fast adapter.
        let adapter = instance
            .request_adapter(&wgpu::RequestAdapterOptions {
                power_preference: wgpu::PowerPreference::HighPerformance,
                compatible_surface: None,
                force_fallback_adapter: false,
            })
            .await
            .ok_or(GpuError::AdapterNotFound)?;

        let (device, queue) = adapter
            .request_device(
                &wgpu::DeviceDescriptor {
                    label: Some("Cliffy GPU Device"),
                    required_features: wgpu::Features::empty(),
                    required_limits: wgpu::Limits::default(),
                    memory_hints: wgpu::MemoryHints::Performance,
                },
                None,
            )
            .await?;

        // Arc so device/queue handles can be shared cheaply.
        let device = Arc::new(device);
        let queue = Arc::new(queue);

        // All kernels live in one WGSL file, embedded at compile time.
        let shader_source = include_str!("../shaders/geometric.wgsl");
        let shader_module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("Geometric Algebra Shader"),
            source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(shader_source)),
        });

        // Shared layout for every kernel: bindings 0 and 1 are read-only
        // input storage buffers, binding 2 is the writable output buffer.
        let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
            label: Some("Geometric Compute Bind Group Layout"),
            entries: &[
                wgpu::BindGroupLayoutEntry {
                    binding: 0,
                    visibility: wgpu::ShaderStages::COMPUTE,
                    ty: wgpu::BindingType::Buffer {
                        ty: wgpu::BufferBindingType::Storage { read_only: true },
                        has_dynamic_offset: false,
                        min_binding_size: None,
                    },
                    count: None,
                },
                wgpu::BindGroupLayoutEntry {
                    binding: 1,
                    visibility: wgpu::ShaderStages::COMPUTE,
                    ty: wgpu::BindingType::Buffer {
                        ty: wgpu::BufferBindingType::Storage { read_only: true },
                        has_dynamic_offset: false,
                        min_binding_size: None,
                    },
                    count: None,
                },
                wgpu::BindGroupLayoutEntry {
                    binding: 2,
                    visibility: wgpu::ShaderStages::COMPUTE,
                    ty: wgpu::BindingType::Buffer {
                        ty: wgpu::BufferBindingType::Storage { read_only: false },
                        has_dynamic_offset: false,
                        min_binding_size: None,
                    },
                    count: None,
                },
            ],
        });

        let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
            label: Some("Geometric Compute Pipeline Layout"),
            bind_group_layouts: &[&bind_group_layout],
            push_constant_ranges: &[],
        });

        // One pipeline per entry point; the names must match the function
        // names declared in geometric.wgsl.
        let geometric_product_pipeline =
            device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some("Geometric Product Pipeline"),
                layout: Some(&pipeline_layout),
                module: &shader_module,
                entry_point: Some("geometric_product_kernel"),
                compilation_options: Default::default(),
                cache: None,
            });

        let addition_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
            label: Some("Addition Pipeline"),
            layout: Some(&pipeline_layout),
            module: &shader_module,
            entry_point: Some("addition_kernel"),
            compilation_options: Default::default(),
            cache: None,
        });

        let sandwich_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
            label: Some("Sandwich Pipeline"),
            layout: Some(&pipeline_layout),
            module: &shader_module,
            entry_point: Some("sandwich_kernel"),
            compilation_options: Default::default(),
            cache: None,
        });

        let exp_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
            label: Some("Exponential Pipeline"),
            layout: Some(&pipeline_layout),
            module: &shader_module,
            entry_point: Some("exp_kernel"),
            compilation_options: Default::default(),
            cache: None,
        });

        let rotor_slerp_pipeline =
            device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some("Rotor Slerp Pipeline"),
                layout: Some(&pipeline_layout),
                module: &shader_module,
                entry_point: Some("rotor_slerp_kernel"),
                compilation_options: Default::default(),
                cache: None,
            });

        Ok(Self {
            device,
            queue,
            geometric_product_pipeline,
            addition_pipeline,
            sandwich_pipeline,
            exp_pipeline,
            rotor_slerp_pipeline,
            bind_group_layout,
        })
    }

    /// Elementwise geometric product of two equal-length batches.
    ///
    /// # Errors
    /// [`GpuError::BufferSizeMismatch`] when lengths differ;
    /// [`GpuError::ComputeFailed`] if the dispatch/readback fails.
    pub async fn batch_geometric_product(
        &self,
        a: &[GA3],
        b: &[GA3],
    ) -> Result<Vec<GA3>, GpuError> {
        if a.len() != b.len() {
            return Err(GpuError::BufferSizeMismatch {
                expected: a.len(),
                actual: b.len(),
            });
        }

        // Convert the f64 CPU multivectors to the f32 GPU layout.
        let a_gpu: Vec<GpuMultivector> = a.iter().map(|mv| mv.into()).collect();
        let b_gpu: Vec<GpuMultivector> = b.iter().map(|mv| mv.into()).collect();

        let result = self.run_binary_kernel(&self.geometric_product_pipeline, &a_gpu, &b_gpu)?;

        Ok(result.into_iter().map(Into::into).collect())
    }

    /// Elementwise addition of two equal-length batches.
    ///
    /// # Errors
    /// Same as [`Self::batch_geometric_product`].
    pub async fn batch_addition(&self, a: &[GA3], b: &[GA3]) -> Result<Vec<GA3>, GpuError> {
        if a.len() != b.len() {
            return Err(GpuError::BufferSizeMismatch {
                expected: a.len(),
                actual: b.len(),
            });
        }

        let a_gpu: Vec<GpuMultivector> = a.iter().map(|mv| mv.into()).collect();
        let b_gpu: Vec<GpuMultivector> = b.iter().map(|mv| mv.into()).collect();

        let result = self.run_binary_kernel(&self.addition_pipeline, &a_gpu, &b_gpu)?;

        Ok(result.into_iter().map(Into::into).collect())
    }

    /// Elementwise sandwich transform: applies `rotors[i]` to `vectors[i]`.
    ///
    /// # Errors
    /// Same as [`Self::batch_geometric_product`].
    pub async fn batch_sandwich(
        &self,
        rotors: &[GA3],
        vectors: &[GA3],
    ) -> Result<Vec<GA3>, GpuError> {
        if rotors.len() != vectors.len() {
            return Err(GpuError::BufferSizeMismatch {
                expected: rotors.len(),
                actual: vectors.len(),
            });
        }

        let rotors_gpu: Vec<GpuMultivector> = rotors.iter().map(|mv| mv.into()).collect();
        let vectors_gpu: Vec<GpuMultivector> = vectors.iter().map(|mv| mv.into()).collect();

        let result = self.run_binary_kernel(&self.sandwich_pipeline, &rotors_gpu, &vectors_gpu)?;

        Ok(result.into_iter().map(Into::into).collect())
    }

    /// Elementwise exponential of a batch.
    pub async fn batch_exp(&self, a: &[GA3]) -> Result<Vec<GA3>, GpuError> {
        let a_gpu: Vec<GpuMultivector> = a.iter().map(|mv| mv.into()).collect();

        // exp is unary, but every pipeline shares the binary (in, in, out)
        // bind layout, so `a` is passed twice; presumably exp_kernel
        // ignores binding 1 — confirm against geometric.wgsl.
        let result = self.run_binary_kernel(&self.exp_pipeline, &a_gpu, &a_gpu)?;

        Ok(result.into_iter().map(Into::into).collect())
    }

    /// Elementwise rotor slerp between `a[i]` and `b[i]` at parameter `t`.
    ///
    /// # Errors
    /// Same as [`Self::batch_geometric_product`].
    pub async fn batch_rotor_slerp(
        &self,
        a: &[GA3],
        b: &[GA3],
        t: f32,
    ) -> Result<Vec<GA3>, GpuError> {
        if a.len() != b.len() {
            return Err(GpuError::BufferSizeMismatch {
                expected: a.len(),
                actual: b.len(),
            });
        }

        let a_gpu: Vec<GpuMultivector> = a.iter().map(|mv| mv.into()).collect();
        // NOTE(review): `t` is smuggled to the shader by overwriting each
        // b-rotor's coefficient 0 — the same slot `scalar()`/`get_scalar()`
        // treat as the scalar part — so the kernel presumably reconstructs
        // b's scalar (e.g. from unit-norm) or reads only b's bivector part.
        // Confirm against rotor_slerp_kernel in geometric.wgsl.
        let b_gpu: Vec<GpuMultivector> = b
            .iter()
            .map(|mv| {
                let mut gpu_mv: GpuMultivector = mv.into();
                gpu_mv.coeffs[0] = t;
                gpu_mv
            })
            .collect();

        let result = self.run_binary_kernel(&self.rotor_slerp_pipeline, &a_gpu, &b_gpu)?;

        Ok(result.into_iter().map(Into::into).collect())
    }

    /// Upload `a` and `b`, dispatch `pipeline` with one invocation per
    /// element, and synchronously read the results back.
    ///
    /// Assumes `a.len() == b.len()` (public callers validate); only `a`
    /// drives dispatch and output sizing. Blocks the calling thread on
    /// `device.poll` — not suitable as-is for environments that forbid
    /// blocking (e.g. the browser main thread).
    fn run_binary_kernel(
        &self,
        pipeline: &wgpu::ComputePipeline,
        a: &[GpuMultivector],
        b: &[GpuMultivector],
    ) -> Result<Vec<GpuMultivector>, GpuError> {
        let count = a.len();
        if count == 0 {
            return Ok(Vec::new());
        }

        // Inputs: created pre-filled via create_buffer_init (STORAGE only).
        let a_buffer = self
            .device
            .create_buffer_init(&wgpu::util::BufferInitDescriptor {
                label: Some("Input A Buffer"),
                contents: bytemuck::cast_slice(a),
                usage: wgpu::BufferUsages::STORAGE,
            });

        let b_buffer = self
            .device
            .create_buffer_init(&wgpu::util::BufferInitDescriptor {
                label: Some("Input B Buffer"),
                contents: bytemuck::cast_slice(b),
                usage: wgpu::BufferUsages::STORAGE,
            });

        // One output element per input element, so output size == size of a.
        let output_size = std::mem::size_of_val(a) as u64;
        let output_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("Output Buffer"),
            size: output_size,
            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
            mapped_at_creation: false,
        });

        // STORAGE buffers can't be mapped for reading, so results are
        // copied into a separate MAP_READ staging buffer.
        let staging_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("Staging Buffer"),
            size: output_size,
            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("Compute Bind Group"),
            layout: &self.bind_group_layout,
            entries: &[
                wgpu::BindGroupEntry {
                    binding: 0,
                    resource: a_buffer.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 1,
                    resource: b_buffer.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 2,
                    resource: output_buffer.as_entire_binding(),
                },
            ],
        });

        let mut encoder = self
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("Compute Encoder"),
            });

        // Scope the pass so its borrow of the encoder ends before
        // copy_buffer_to_buffer / finish.
        {
            let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("Compute Pass"),
                timestamp_writes: None,
            });
            compute_pass.set_pipeline(pipeline);
            compute_pass.set_bind_group(0, &bind_group, &[]);

            // 64 must match @workgroup_size in the WGSL kernels; div_ceil
            // rounds up so every element is covered.
            let workgroup_count = count.div_ceil(64) as u32;
            compute_pass.dispatch_workgroups(workgroup_count, 1, 1);
        }

        encoder.copy_buffer_to_buffer(&output_buffer, 0, &staging_buffer, 0, output_size);

        self.queue.submit(std::iter::once(encoder.finish()));

        // Async map + channel: the callback fires during poll(Wait) below.
        let buffer_slice = staging_buffer.slice(..);
        let (sender, receiver) = std::sync::mpsc::channel();
        buffer_slice.map_async(wgpu::MapMode::Read, move |result| {
            let _ = sender.send(result);
        });

        // Block until the GPU work (and the map callback) completes.
        self.device.poll(wgpu::Maintain::Wait);

        // First ? handles a dropped sender, second ? a failed mapping.
        receiver
            .recv()
            .map_err(|e| GpuError::ComputeFailed(e.to_string()))?
            .map_err(|e| GpuError::ComputeFailed(format!("{:?}", e)))?;

        let data = buffer_slice.get_mapped_range();
        let result: Vec<GpuMultivector> = bytemuck::cast_slice(&data).to_vec();
        // The mapped range must be dropped before unmap.
        drop(data);
        staging_buffer.unmap();

        Ok(result)
    }

    /// True when `batch_size` reaches the global dispatch threshold.
    /// Note: uses the constant, not any per-dispatcher threshold.
    pub fn should_use_gpu(&self, batch_size: usize) -> bool {
        batch_size >= GPU_DISPATCH_THRESHOLD
    }

    /// Static placeholder description; adapter details are not captured.
    pub fn device_info(&self) -> String {
        "Cliffy GPU Context (wgpu)".to_string()
    }
}
534
/// Routes batch operations to the GPU when the batch length reaches
/// `threshold`, otherwise to the CPU SIMD path. `gpu_ctx` is `None` when
/// no GPU is available (or when built via `cpu_only`).
pub struct AutoDispatcher {
    gpu_ctx: Option<GpuContext>,
    threshold: usize,
}
543
544impl AutoDispatcher {
545 pub async fn new() -> Self {
547 let gpu_ctx = GpuContext::new().await.ok();
548 Self {
549 gpu_ctx,
550 threshold: GPU_DISPATCH_THRESHOLD,
551 }
552 }
553
554 pub async fn with_threshold(threshold: usize) -> Self {
556 let gpu_ctx = GpuContext::new().await.ok();
557 Self { gpu_ctx, threshold }
558 }
559
560 pub fn cpu_only() -> Self {
562 Self {
563 gpu_ctx: None,
564 threshold: GPU_DISPATCH_THRESHOLD,
565 }
566 }
567
568 pub fn has_gpu(&self) -> bool {
570 self.gpu_ctx.is_some()
571 }
572
573 pub fn threshold(&self) -> usize {
575 self.threshold
576 }
577
578 pub async fn geometric_product(&self, a: &[GA3], b: &[GA3]) -> Result<Vec<GA3>, GpuError> {
582 if let Some(ref ctx) = self.gpu_ctx {
583 if a.len() >= self.threshold {
584 return ctx.batch_geometric_product(a, b).await;
585 }
586 }
587
588 if a.len() != b.len() {
590 return Err(GpuError::BufferSizeMismatch {
591 expected: a.len(),
592 actual: b.len(),
593 });
594 }
595
596 let a_gpu = SimdBatch::from_ga3(a);
597 let b_gpu = SimdBatch::from_ga3(b);
598 let result = SimdBatch::geometric_product(&a_gpu, &b_gpu);
599 Ok(SimdBatch::to_ga3(&result))
600 }
601
602 pub async fn addition(&self, a: &[GA3], b: &[GA3]) -> Result<Vec<GA3>, GpuError> {
606 if let Some(ref ctx) = self.gpu_ctx {
607 if a.len() >= self.threshold {
608 return ctx.batch_addition(a, b).await;
609 }
610 }
611
612 if a.len() != b.len() {
614 return Err(GpuError::BufferSizeMismatch {
615 expected: a.len(),
616 actual: b.len(),
617 });
618 }
619
620 let a_gpu = SimdBatch::from_ga3(a);
621 let b_gpu = SimdBatch::from_ga3(b);
622 let result = SimdBatch::addition(&a_gpu, &b_gpu);
623 Ok(SimdBatch::to_ga3(&result))
624 }
625
626 pub async fn sandwich(&self, rotors: &[GA3], vectors: &[GA3]) -> Result<Vec<GA3>, GpuError> {
630 if let Some(ref ctx) = self.gpu_ctx {
631 if rotors.len() >= self.threshold {
632 return ctx.batch_sandwich(rotors, vectors).await;
633 }
634 }
635
636 if rotors.len() != vectors.len() {
638 return Err(GpuError::BufferSizeMismatch {
639 expected: rotors.len(),
640 actual: vectors.len(),
641 });
642 }
643
644 let rotors_gpu = SimdBatch::from_ga3(rotors);
645 let vectors_gpu = SimdBatch::from_ga3(vectors);
646 let result = SimdBatch::sandwich(&rotors_gpu, &vectors_gpu);
647 Ok(SimdBatch::to_ga3(&result))
648 }
649
650 pub async fn exp(&self, a: &[GA3]) -> Result<Vec<GA3>, GpuError> {
654 if let Some(ref ctx) = self.gpu_ctx {
655 if a.len() >= self.threshold {
656 return ctx.batch_exp(a).await;
657 }
658 }
659
660 let a_gpu = SimdBatch::from_ga3(a);
662 let result = SimdBatch::exp(&a_gpu);
663 Ok(SimdBatch::to_ga3(&result))
664 }
665
666 pub async fn rotor_slerp(&self, a: &[GA3], b: &[GA3], t: f32) -> Result<Vec<GA3>, GpuError> {
670 if let Some(ref ctx) = self.gpu_ctx {
671 if a.len() >= self.threshold {
672 return ctx.batch_rotor_slerp(a, b, t).await;
673 }
674 }
675
676 if a.len() != b.len() {
678 return Err(GpuError::BufferSizeMismatch {
679 expected: a.len(),
680 actual: b.len(),
681 });
682 }
683
684 let a_gpu = SimdBatch::from_ga3(a);
685 let b_gpu = SimdBatch::from_ga3(b);
686 let result = SimdBatch::rotor_slerp(&a_gpu, &b_gpu, t);
687 Ok(SimdBatch::to_ga3(&result))
688 }
689}
690
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_gpu_multivector_zero() {
        let mv = GpuMultivector::zero();
        assert!(mv.coeffs.iter().all(|&c| c == 0.0));
    }

    #[test]
    fn test_gpu_multivector_scalar() {
        let mv = GpuMultivector::scalar(5.0);
        assert_eq!(mv.get_scalar(), 5.0);
    }

    #[test]
    fn test_gpu_multivector_vector() {
        let mv = GpuMultivector::vector(1.0, 2.0, 3.0);
        assert_eq!(mv.get_vector(), (1.0, 2.0, 3.0));
    }

    #[test]
    fn test_ga3_roundtrip() {
        // Fix: this previously imported `amari_core::Vector`, which is
        // inconsistent with the `cliffy_core` crate used everywhere else
        // in this module (leftover from a crate rename).
        use cliffy_core::Vector;
        let vec = Vector::<3, 0, 0>::from_components(1.0, 2.0, 3.0);
        let original = GA3::from_vector(&vec);
        let gpu: GpuMultivector = (&original).into();
        let back: GA3 = gpu.into();

        // Blade indices 1, 2, 4 hold the e1/e2/e3 components.
        let x = back.get(1);
        let y = back.get(2);
        let z = back.get(4);
        // f64 -> f32 -> f64 loses precision, so compare with a tolerance.
        assert!((x - 1.0).abs() < 1e-5);
        assert!((y - 2.0).abs() < 1e-5);
        assert!((z - 3.0).abs() < 1e-5);
    }

    #[test]
    fn test_dispatch_threshold() {
        let dispatcher = AutoDispatcher::cpu_only();
        assert!(!dispatcher.has_gpu());
        assert_eq!(dispatcher.threshold(), GPU_DISPATCH_THRESHOLD);
    }
}