// cliffy_gpu/src/lib.rs
//! # Cliffy GPU
//!
//! WebGPU compute shaders and SIMD-optimized CPU operations for geometric algebra.
//!
//! This crate provides hardware-accelerated geometric algebra operations using:
//! - **WebGPU**: Massive parallel computation on GPU
//! - **SIMD**: CPU-optimized operations using portable SIMD intrinsics
//!
//! ## Overview
//!
//! Every browser becomes a compute node with cliffy-gpu:
//!
//! ```ignore
//! use cliffy_gpu::{GpuContext, GpuMultivector, AutoDispatcher};
//! use cliffy_core::GA3;
//!
//! // Initialize auto dispatcher (chooses GPU or SIMD-CPU based on batch size)
//! let dispatcher = AutoDispatcher::new().await;
//!
//! // Batch geometric products - automatically dispatched
//! let a_batch: Vec<GA3> = vec![...];
//! let b_batch: Vec<GA3> = vec![...];
//! let results = dispatcher.geometric_product(&a_batch, &b_batch).await?;
//! ```
//!
//! ## Features
//!
//! - **Batch Operations**: Process thousands of multivectors in parallel
//! - **Auto Dispatch**: Automatic CPU/GPU selection based on batch size
//! - **SIMD Fallback**: Optimized CPU operations when GPU unavailable
//! - **WASM Support**: Works in browsers with WebGPU
//! - **Compute Shaders**: WGSL shaders for geometric product, sandwich, exp, slerp
34pub mod simd;
35
36#[cfg(feature = "wasm")]
37pub mod wasm;
38
39use bytemuck::{Pod, Zeroable};
40use std::borrow::Cow;
41use std::sync::Arc;
42use thiserror::Error;
43use wgpu::util::DeviceExt;
44
45use cliffy_core::GA3;
46pub use simd::{addition_simd, geometric_product_simd, sandwich_simd, SimdBatch};
47
48#[cfg(feature = "wasm")]
49pub use wasm::*;
50
/// Errors that can occur during GPU operations.
#[derive(Error, Debug)]
pub enum GpuError {
    /// No suitable GPU adapter was found (e.g. no WebGPU/native backend).
    #[error("Failed to request GPU adapter")]
    AdapterNotFound,

    /// An adapter was found but the logical device request was rejected.
    #[error("Failed to request GPU device: {0}")]
    DeviceRequestFailed(#[from] wgpu::RequestDeviceError),

    /// Two paired input batches had different lengths.
    #[error("Buffer size mismatch: expected {expected}, got {actual}")]
    BufferSizeMismatch { expected: usize, actual: usize },

    /// A compute dispatch or buffer readback failed.
    #[error("GPU computation failed: {0}")]
    ComputeFailed(String),

    /// WebGPU is not available on this platform/browser.
    #[error("WebGPU not available")]
    WebGpuNotAvailable,
}
69
/// GPU-compatible multivector representation.
///
/// Uses 8 f32 coefficients for Cl(3,0) geometric algebra:
/// - coeffs[0]: scalar (1)
/// - coeffs[1]: e1
/// - coeffs[2]: e2
/// - coeffs[3]: e12
/// - coeffs[4]: e3
/// - coeffs[5]: e13
/// - coeffs[6]: e23
/// - coeffs[7]: e123 (pseudoscalar)
// `repr(C)` + `Pod`/`Zeroable` make this safe to memcpy straight into GPU
// storage buffers via `bytemuck::cast_slice`; the layout must stay in sync
// with the WGSL shader's array-of-8-f32 element type.
#[repr(C)]
#[derive(Clone, Copy, Debug, Pod, Zeroable, Default)]
pub struct GpuMultivector {
    // Basis-blade coefficients in the order documented above.
    pub coeffs: [f32; 8],
}
86
87impl GpuMultivector {
88    /// Create a new GPU multivector with all zeros.
89    pub fn zero() -> Self {
90        Self { coeffs: [0.0; 8] }
91    }
92
93    /// Create a scalar multivector.
94    pub fn scalar(s: f32) -> Self {
95        let mut mv = Self::zero();
96        mv.coeffs[0] = s;
97        mv
98    }
99
100    /// Create a vector multivector (e1, e2, e3 components).
101    pub fn vector(x: f32, y: f32, z: f32) -> Self {
102        let mut mv = Self::zero();
103        mv.coeffs[1] = x;
104        mv.coeffs[2] = y;
105        mv.coeffs[4] = z;
106        mv
107    }
108
109    /// Get the scalar component.
110    pub fn get_scalar(&self) -> f32 {
111        self.coeffs[0]
112    }
113
114    /// Get the vector components (e1, e2, e3).
115    pub fn get_vector(&self) -> (f32, f32, f32) {
116        (self.coeffs[1], self.coeffs[2], self.coeffs[4])
117    }
118}
119
120impl From<&GA3> for GpuMultivector {
121    fn from(mv: &GA3) -> Self {
122        let mut coeffs = [0.0f32; 8];
123        // GA3 = Multivector<3,0,0> has 8 components
124        // Map from amari-core's storage to our layout
125        let slice = mv.as_slice();
126        for (i, &c) in slice.iter().enumerate() {
127            if i < 8 {
128                coeffs[i] = c as f32;
129            }
130        }
131        Self { coeffs }
132    }
133}
134
135impl From<GpuMultivector> for GA3 {
136    fn from(gpu_mv: GpuMultivector) -> Self {
137        let coeffs: Vec<f64> = gpu_mv.coeffs.iter().map(|&c| c as f64).collect();
138        GA3::from_slice(&coeffs)
139    }
140}
141
/// Threshold for automatic GPU dispatch.
/// Below this count, CPU is often faster due to GPU overhead
/// (buffer creation, submission, and readback latency).
pub const GPU_DISPATCH_THRESHOLD: usize = 256;
145
/// GPU compute context for geometric algebra operations.
///
/// Manages WebGPU device, queue, and compute pipelines for
/// parallel geometric algebra computation.
pub struct GpuContext {
    // Shared handles so batches can be submitted from multiple owners.
    device: Arc<wgpu::Device>,
    queue: Arc<wgpu::Queue>,
    // One pre-built pipeline per WGSL entry point in shaders/geometric.wgsl.
    geometric_product_pipeline: wgpu::ComputePipeline,
    addition_pipeline: wgpu::ComputePipeline,
    sandwich_pipeline: wgpu::ComputePipeline,
    exp_pipeline: wgpu::ComputePipeline,
    rotor_slerp_pipeline: wgpu::ComputePipeline,
    // Shared layout: bindings 0/1 are read-only inputs, binding 2 is output.
    bind_group_layout: wgpu::BindGroupLayout,
}
160
impl GpuContext {
    /// Create a new GPU context.
    ///
    /// This initializes WebGPU and creates all compute pipelines.
    ///
    /// # Errors
    ///
    /// Returns [`GpuError::AdapterNotFound`] when no adapter is available,
    /// or [`GpuError::DeviceRequestFailed`] when the device request fails.
    pub async fn new() -> Result<Self, GpuError> {
        // Enumerate all backends (Vulkan/Metal/DX12/GL/WebGPU) and let wgpu pick.
        let instance = wgpu::Instance::new(&wgpu::InstanceDescriptor {
            backends: wgpu::Backends::all(),
            ..Default::default()
        });

        let adapter = instance
            .request_adapter(&wgpu::RequestAdapterOptions {
                power_preference: wgpu::PowerPreference::HighPerformance,
                // Compute-only workload: no window surface to be compatible with.
                compatible_surface: None,
                force_fallback_adapter: false,
            })
            .await
            .ok_or(GpuError::AdapterNotFound)?;

        let (device, queue) = adapter
            .request_device(
                &wgpu::DeviceDescriptor {
                    label: Some("Cliffy GPU Device"),
                    // Default features/limits keep us runnable on baseline WebGPU.
                    required_features: wgpu::Features::empty(),
                    required_limits: wgpu::Limits::default(),
                    memory_hints: wgpu::MemoryHints::Performance,
                },
                None,
            )
            .await?;

        let device = Arc::new(device);
        let queue = Arc::new(queue);

        // All five kernels live in one WGSL module, embedded at compile time
        // so there is no runtime file I/O.
        let shader_source = include_str!("../shaders/geometric.wgsl");
        let shader_module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("Geometric Algebra Shader"),
            source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(shader_source)),
        });

        // Layout shared by every kernel: binding 0 and 1 are read-only input
        // storage buffers, binding 2 is the writable output buffer. This must
        // match the @group(0) @binding(...) declarations in the shader.
        let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
            label: Some("Geometric Compute Bind Group Layout"),
            entries: &[
                wgpu::BindGroupLayoutEntry {
                    binding: 0,
                    visibility: wgpu::ShaderStages::COMPUTE,
                    ty: wgpu::BindingType::Buffer {
                        ty: wgpu::BufferBindingType::Storage { read_only: true },
                        has_dynamic_offset: false,
                        min_binding_size: None,
                    },
                    count: None,
                },
                wgpu::BindGroupLayoutEntry {
                    binding: 1,
                    visibility: wgpu::ShaderStages::COMPUTE,
                    ty: wgpu::BindingType::Buffer {
                        ty: wgpu::BufferBindingType::Storage { read_only: true },
                        has_dynamic_offset: false,
                        min_binding_size: None,
                    },
                    count: None,
                },
                wgpu::BindGroupLayoutEntry {
                    binding: 2,
                    visibility: wgpu::ShaderStages::COMPUTE,
                    ty: wgpu::BindingType::Buffer {
                        ty: wgpu::BufferBindingType::Storage { read_only: false },
                        has_dynamic_offset: false,
                        min_binding_size: None,
                    },
                    count: None,
                },
            ],
        });

        let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
            label: Some("Geometric Compute Pipeline Layout"),
            bind_group_layouts: &[&bind_group_layout],
            push_constant_ranges: &[],
        });

        // One pipeline per entry point; all share the same module and layout.
        let geometric_product_pipeline =
            device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some("Geometric Product Pipeline"),
                layout: Some(&pipeline_layout),
                module: &shader_module,
                entry_point: Some("geometric_product_kernel"),
                compilation_options: Default::default(),
                cache: None,
            });

        let addition_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
            label: Some("Addition Pipeline"),
            layout: Some(&pipeline_layout),
            module: &shader_module,
            entry_point: Some("addition_kernel"),
            compilation_options: Default::default(),
            cache: None,
        });

        let sandwich_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
            label: Some("Sandwich Pipeline"),
            layout: Some(&pipeline_layout),
            module: &shader_module,
            entry_point: Some("sandwich_kernel"),
            compilation_options: Default::default(),
            cache: None,
        });

        let exp_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
            label: Some("Exponential Pipeline"),
            layout: Some(&pipeline_layout),
            module: &shader_module,
            entry_point: Some("exp_kernel"),
            compilation_options: Default::default(),
            cache: None,
        });

        let rotor_slerp_pipeline =
            device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some("Rotor Slerp Pipeline"),
                layout: Some(&pipeline_layout),
                module: &shader_module,
                entry_point: Some("rotor_slerp_kernel"),
                compilation_options: Default::default(),
                cache: None,
            });

        Ok(Self {
            device,
            queue,
            geometric_product_pipeline,
            addition_pipeline,
            sandwich_pipeline,
            exp_pipeline,
            rotor_slerp_pipeline,
            bind_group_layout,
        })
    }

    /// Batch geometric product: a[i] * b[i] for all i.
    ///
    /// Computes the geometric product of corresponding elements
    /// from two input arrays in parallel on the GPU.
    ///
    /// # Errors
    ///
    /// Returns [`GpuError::BufferSizeMismatch`] when the slices differ in length.
    pub async fn batch_geometric_product(
        &self,
        a: &[GA3],
        b: &[GA3],
    ) -> Result<Vec<GA3>, GpuError> {
        if a.len() != b.len() {
            return Err(GpuError::BufferSizeMismatch {
                expected: a.len(),
                actual: b.len(),
            });
        }

        // Convert to GPU format (f64 -> f32 narrowing happens here)
        let a_gpu: Vec<GpuMultivector> = a.iter().map(|mv| mv.into()).collect();
        let b_gpu: Vec<GpuMultivector> = b.iter().map(|mv| mv.into()).collect();

        let result = self.run_binary_kernel(&self.geometric_product_pipeline, &a_gpu, &b_gpu)?;

        Ok(result.into_iter().map(Into::into).collect())
    }

    /// Batch addition: a[i] + b[i] for all i.
    ///
    /// # Errors
    ///
    /// Returns [`GpuError::BufferSizeMismatch`] when the slices differ in length.
    pub async fn batch_addition(&self, a: &[GA3], b: &[GA3]) -> Result<Vec<GA3>, GpuError> {
        if a.len() != b.len() {
            return Err(GpuError::BufferSizeMismatch {
                expected: a.len(),
                actual: b.len(),
            });
        }

        let a_gpu: Vec<GpuMultivector> = a.iter().map(|mv| mv.into()).collect();
        let b_gpu: Vec<GpuMultivector> = b.iter().map(|mv| mv.into()).collect();

        let result = self.run_binary_kernel(&self.addition_pipeline, &a_gpu, &b_gpu)?;

        Ok(result.into_iter().map(Into::into).collect())
    }

    /// Batch sandwich product: rotor[i] * x[i] * ~rotor[i] for all i.
    ///
    /// The sandwich product applies a rotation to each element.
    ///
    /// # Errors
    ///
    /// Returns [`GpuError::BufferSizeMismatch`] when the slices differ in length.
    pub async fn batch_sandwich(
        &self,
        rotors: &[GA3],
        vectors: &[GA3],
    ) -> Result<Vec<GA3>, GpuError> {
        if rotors.len() != vectors.len() {
            return Err(GpuError::BufferSizeMismatch {
                expected: rotors.len(),
                actual: vectors.len(),
            });
        }

        let rotors_gpu: Vec<GpuMultivector> = rotors.iter().map(|mv| mv.into()).collect();
        let vectors_gpu: Vec<GpuMultivector> = vectors.iter().map(|mv| mv.into()).collect();

        let result = self.run_binary_kernel(&self.sandwich_pipeline, &rotors_gpu, &vectors_gpu)?;

        Ok(result.into_iter().map(Into::into).collect())
    }

    /// Batch exponential: exp(a[i]) for all i.
    ///
    /// The exponential map converts bivectors to rotors.
    pub async fn batch_exp(&self, a: &[GA3]) -> Result<Vec<GA3>, GpuError> {
        let a_gpu: Vec<GpuMultivector> = a.iter().map(|mv| mv.into()).collect();

        // For unary operations, use same input for both buffers
        // (binding 1 is presumably ignored by exp_kernel — shader contract).
        let result = self.run_binary_kernel(&self.exp_pipeline, &a_gpu, &a_gpu)?;

        Ok(result.into_iter().map(Into::into).collect())
    }

    /// Batch rotor SLERP: interpolate from a[i] to b[i] by t.
    ///
    /// Spherical linear interpolation for smooth rotation blending.
    ///
    /// # Errors
    ///
    /// Returns [`GpuError::BufferSizeMismatch`] when the slices differ in length.
    pub async fn batch_rotor_slerp(
        &self,
        a: &[GA3],
        b: &[GA3],
        t: f32,
    ) -> Result<Vec<GA3>, GpuError> {
        if a.len() != b.len() {
            return Err(GpuError::BufferSizeMismatch {
                expected: a.len(),
                actual: b.len(),
            });
        }

        let a_gpu: Vec<GpuMultivector> = a.iter().map(|mv| mv.into()).collect();
        // Encode t in the first coefficient of b
        // NOTE(review): this overwrites b[i]'s scalar coefficient to smuggle
        // `t` to the kernel — rotor_slerp_kernel must read it back from
        // there; confirm against shaders/geometric.wgsl if changing this.
        let b_gpu: Vec<GpuMultivector> = b
            .iter()
            .map(|mv| {
                let mut gpu_mv: GpuMultivector = mv.into();
                gpu_mv.coeffs[0] = t;
                gpu_mv
            })
            .collect();

        let result = self.run_binary_kernel(&self.rotor_slerp_pipeline, &a_gpu, &b_gpu)?;

        Ok(result.into_iter().map(Into::into).collect())
    }

    /// Run a binary compute kernel (two input buffers, one output).
    ///
    /// Uploads `a` and `b` as storage buffers, dispatches one thread per
    /// element, copies the output into a mappable staging buffer, and blocks
    /// until readback completes. Returns the raw GPU-format results.
    fn run_binary_kernel(
        &self,
        pipeline: &wgpu::ComputePipeline,
        a: &[GpuMultivector],
        b: &[GpuMultivector],
    ) -> Result<Vec<GpuMultivector>, GpuError> {
        let count = a.len();
        if count == 0 {
            // Nothing to dispatch; avoid creating zero-sized buffers.
            return Ok(Vec::new());
        }

        // Create input buffers
        // GpuMultivector is Pod, so cast_slice reinterprets the slices as bytes.
        let a_buffer = self
            .device
            .create_buffer_init(&wgpu::util::BufferInitDescriptor {
                label: Some("Input A Buffer"),
                contents: bytemuck::cast_slice(a),
                usage: wgpu::BufferUsages::STORAGE,
            });

        let b_buffer = self
            .device
            .create_buffer_init(&wgpu::util::BufferInitDescriptor {
                label: Some("Input B Buffer"),
                contents: bytemuck::cast_slice(b),
                usage: wgpu::BufferUsages::STORAGE,
            });

        // Create output buffer
        // One output element per element of `a`, so the byte size matches `a`.
        let output_size = std::mem::size_of_val(a) as u64;
        let output_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("Output Buffer"),
            size: output_size,
            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
            mapped_at_creation: false,
        });

        // Create staging buffer for reading results
        // (STORAGE buffers can't be MAP_READ, hence the extra copy.)
        let staging_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("Staging Buffer"),
            size: output_size,
            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        // Create bind group
        let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("Compute Bind Group"),
            layout: &self.bind_group_layout,
            entries: &[
                wgpu::BindGroupEntry {
                    binding: 0,
                    resource: a_buffer.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 1,
                    resource: b_buffer.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 2,
                    resource: output_buffer.as_entire_binding(),
                },
            ],
        });

        // Encode and submit
        let mut encoder = self
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("Compute Encoder"),
            });

        {
            let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("Compute Pass"),
                timestamp_writes: None,
            });
            compute_pass.set_pipeline(pipeline);
            compute_pass.set_bind_group(0, &bind_group, &[]);

            // Dispatch workgroups (64 threads per group)
            // The divisor must match @workgroup_size in the WGSL; the last
            // group may contain idle threads past `count`.
            let workgroup_count = count.div_ceil(64) as u32;
            compute_pass.dispatch_workgroups(workgroup_count, 1, 1);
        }

        // Copy output to staging buffer
        encoder.copy_buffer_to_buffer(&output_buffer, 0, &staging_buffer, 0, output_size);

        self.queue.submit(std::iter::once(encoder.finish()));

        // Read results
        // map_async delivers its result via callback; a channel hands it back
        // to this thread once poll() drives the device to completion.
        let buffer_slice = staging_buffer.slice(..);
        let (sender, receiver) = std::sync::mpsc::channel();
        buffer_slice.map_async(wgpu::MapMode::Read, move |result| {
            let _ = sender.send(result);
        });

        // Blocks until submitted work (and the map callback) completes.
        // NOTE(review): blocking poll is a native-target pattern; on wasm the
        // browser drives the device — verify before enabling this path there.
        self.device.poll(wgpu::Maintain::Wait);

        // First ? = channel closed without a message; second ? = mapping failed.
        receiver
            .recv()
            .map_err(|e| GpuError::ComputeFailed(e.to_string()))?
            .map_err(|e| GpuError::ComputeFailed(format!("{:?}", e)))?;

        let data = buffer_slice.get_mapped_range();
        let result: Vec<GpuMultivector> = bytemuck::cast_slice(&data).to_vec();
        // The mapped range must be dropped before unmap() is legal.
        drop(data);
        staging_buffer.unmap();

        Ok(result)
    }

    /// Check if GPU dispatch is recommended for the given batch size.
    pub fn should_use_gpu(&self, batch_size: usize) -> bool {
        batch_size >= GPU_DISPATCH_THRESHOLD
    }

    /// Get the device info for debugging.
    pub fn device_info(&self) -> String {
        "Cliffy GPU Context (wgpu)".to_string()
    }
}
534
/// Automatic dispatcher that chooses CPU (SIMD) or GPU based on batch size.
///
/// For small batches (< threshold), SIMD-optimized CPU operations are used.
/// For large batches (>= threshold), GPU compute shaders are used.
pub struct AutoDispatcher {
    // `None` when GPU initialization failed — every call then takes the CPU path.
    gpu_ctx: Option<GpuContext>,
    // Minimum batch size for which GPU dispatch is attempted.
    threshold: usize,
}
543
544impl AutoDispatcher {
545    /// Create a new auto dispatcher, attempting to initialize GPU.
546    pub async fn new() -> Self {
547        let gpu_ctx = GpuContext::new().await.ok();
548        Self {
549            gpu_ctx,
550            threshold: GPU_DISPATCH_THRESHOLD,
551        }
552    }
553
554    /// Create with a custom threshold.
555    pub async fn with_threshold(threshold: usize) -> Self {
556        let gpu_ctx = GpuContext::new().await.ok();
557        Self { gpu_ctx, threshold }
558    }
559
560    /// Create a CPU-only dispatcher (no GPU).
561    pub fn cpu_only() -> Self {
562        Self {
563            gpu_ctx: None,
564            threshold: GPU_DISPATCH_THRESHOLD,
565        }
566    }
567
568    /// Check if GPU is available.
569    pub fn has_gpu(&self) -> bool {
570        self.gpu_ctx.is_some()
571    }
572
573    /// Get the current dispatch threshold.
574    pub fn threshold(&self) -> usize {
575        self.threshold
576    }
577
578    /// Batch geometric product with automatic dispatch.
579    ///
580    /// Uses GPU for large batches, SIMD-optimized CPU for small batches.
581    pub async fn geometric_product(&self, a: &[GA3], b: &[GA3]) -> Result<Vec<GA3>, GpuError> {
582        if let Some(ref ctx) = self.gpu_ctx {
583            if a.len() >= self.threshold {
584                return ctx.batch_geometric_product(a, b).await;
585            }
586        }
587
588        // SIMD-optimized CPU fallback
589        if a.len() != b.len() {
590            return Err(GpuError::BufferSizeMismatch {
591                expected: a.len(),
592                actual: b.len(),
593            });
594        }
595
596        let a_gpu = SimdBatch::from_ga3(a);
597        let b_gpu = SimdBatch::from_ga3(b);
598        let result = SimdBatch::geometric_product(&a_gpu, &b_gpu);
599        Ok(SimdBatch::to_ga3(&result))
600    }
601
602    /// Batch addition with automatic dispatch.
603    ///
604    /// Uses GPU for large batches, SIMD-optimized CPU for small batches.
605    pub async fn addition(&self, a: &[GA3], b: &[GA3]) -> Result<Vec<GA3>, GpuError> {
606        if let Some(ref ctx) = self.gpu_ctx {
607            if a.len() >= self.threshold {
608                return ctx.batch_addition(a, b).await;
609            }
610        }
611
612        // SIMD-optimized CPU fallback
613        if a.len() != b.len() {
614            return Err(GpuError::BufferSizeMismatch {
615                expected: a.len(),
616                actual: b.len(),
617            });
618        }
619
620        let a_gpu = SimdBatch::from_ga3(a);
621        let b_gpu = SimdBatch::from_ga3(b);
622        let result = SimdBatch::addition(&a_gpu, &b_gpu);
623        Ok(SimdBatch::to_ga3(&result))
624    }
625
626    /// Batch sandwich product with automatic dispatch.
627    ///
628    /// Uses GPU for large batches, SIMD-optimized CPU for small batches.
629    pub async fn sandwich(&self, rotors: &[GA3], vectors: &[GA3]) -> Result<Vec<GA3>, GpuError> {
630        if let Some(ref ctx) = self.gpu_ctx {
631            if rotors.len() >= self.threshold {
632                return ctx.batch_sandwich(rotors, vectors).await;
633            }
634        }
635
636        // SIMD-optimized CPU fallback
637        if rotors.len() != vectors.len() {
638            return Err(GpuError::BufferSizeMismatch {
639                expected: rotors.len(),
640                actual: vectors.len(),
641            });
642        }
643
644        let rotors_gpu = SimdBatch::from_ga3(rotors);
645        let vectors_gpu = SimdBatch::from_ga3(vectors);
646        let result = SimdBatch::sandwich(&rotors_gpu, &vectors_gpu);
647        Ok(SimdBatch::to_ga3(&result))
648    }
649
650    /// Batch exponential with automatic dispatch.
651    ///
652    /// Uses GPU for large batches, SIMD-optimized CPU for small batches.
653    pub async fn exp(&self, a: &[GA3]) -> Result<Vec<GA3>, GpuError> {
654        if let Some(ref ctx) = self.gpu_ctx {
655            if a.len() >= self.threshold {
656                return ctx.batch_exp(a).await;
657            }
658        }
659
660        // SIMD-optimized CPU fallback
661        let a_gpu = SimdBatch::from_ga3(a);
662        let result = SimdBatch::exp(&a_gpu);
663        Ok(SimdBatch::to_ga3(&result))
664    }
665
666    /// Batch rotor SLERP with automatic dispatch.
667    ///
668    /// Uses GPU for large batches, SIMD-optimized CPU for small batches.
669    pub async fn rotor_slerp(&self, a: &[GA3], b: &[GA3], t: f32) -> Result<Vec<GA3>, GpuError> {
670        if let Some(ref ctx) = self.gpu_ctx {
671            if a.len() >= self.threshold {
672                return ctx.batch_rotor_slerp(a, b, t).await;
673            }
674        }
675
676        // SIMD-optimized CPU fallback
677        if a.len() != b.len() {
678            return Err(GpuError::BufferSizeMismatch {
679                expected: a.len(),
680                actual: b.len(),
681            });
682        }
683
684        let a_gpu = SimdBatch::from_ga3(a);
685        let b_gpu = SimdBatch::from_ga3(b);
686        let result = SimdBatch::rotor_slerp(&a_gpu, &b_gpu, t);
687        Ok(SimdBatch::to_ga3(&result))
688    }
689}
690
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_gpu_multivector_zero() {
        // Every coefficient of the zero multivector is exactly 0.0.
        assert_eq!(GpuMultivector::zero().coeffs, [0.0f32; 8]);
    }

    #[test]
    fn test_gpu_multivector_scalar() {
        assert_eq!(GpuMultivector::scalar(5.0).get_scalar(), 5.0);
    }

    #[test]
    fn test_gpu_multivector_vector() {
        let components = GpuMultivector::vector(1.0, 2.0, 3.0).get_vector();
        assert_eq!(components, (1.0, 2.0, 3.0));
    }

    #[test]
    fn test_ga3_roundtrip() {
        use amari_core::Vector;
        let vec = Vector::<3, 0, 0>::from_components(1.0, 2.0, 3.0);
        let original = GA3::from_vector(&vec);
        let back: GA3 = GpuMultivector::from(&original).into();

        // Vector components live at coefficient indices 1, 2 and 4
        // (e1, e2, e3); the f64 -> f32 -> f64 trip loses a little precision.
        for (idx, expected) in [(1usize, 1.0f64), (2, 2.0), (4, 3.0)] {
            assert!((back.get(idx) - expected).abs() < 1e-5);
        }
    }

    #[test]
    fn test_dispatch_threshold() {
        // CPU-only mode: no GPU context, default threshold intact.
        let dispatcher = AutoDispatcher::cpu_only();
        assert!(!dispatcher.has_gpu());
        assert_eq!(dispatcher.threshold(), GPU_DISPATCH_THRESHOLD);
    }
}