Skip to main content

animato_gpu/
batch.rs

1//! Batched tween evaluation.
2
3use animato_core::{Easing, Update};
4use animato_tween::Tween;
5use core::fmt;
6use std::sync::mpsc;
7
8const SHADER_SOURCE: &str = include_str!("shaders/tween.wgsl");
9
10#[repr(C)]
11#[derive(Clone, Copy, Debug, bytemuck::Pod, bytemuck::Zeroable)]
12struct GpuTweenInput {
13    start: f32,
14    end: f32,
15    duration: f32,
16    elapsed: f32,
17    easing_id: u32,
18    _pad0: u32,
19    _pad1: u32,
20    _pad2: u32,
21}
22
/// Which evaluation path a [`GpuAnimationBatch`] is currently using.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GpuBackend {
    /// Values are produced by the deterministic CPU fallback.
    Cpu,
    /// A wgpu device and queue are available and no CPU fallback is forced.
    Gpu,
}
31
/// Error returned by GPU initialization and GPU dispatch.
///
/// Implements [`std::error::Error`] and [`fmt::Display`] so it can be
/// propagated with `?` into boxed/dynamic error types and shown to users.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum GpuBatchError {
    /// No suitable wgpu adapter was found.
    AdapterUnavailable,
    /// Requesting the wgpu device failed.
    ///
    /// The payload is the stringified underlying error. The same variant also
    /// carries device poll and buffer-mapping failures raised while dispatching.
    RequestDevice(String),
}

impl fmt::Display for GpuBatchError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::AdapterUnavailable => f.write_str("no suitable wgpu adapter was found"),
            // Kept deliberately generic: this variant is reused for more than
            // the initial device request.
            Self::RequestDevice(reason) => write!(f, "wgpu device error: {reason}"),
        }
    }
}

impl std::error::Error for GpuBatchError {}
40
41struct GpuResources {
42    device: wgpu::Device,
43    queue: wgpu::Queue,
44    pipeline: wgpu::ComputePipeline,
45    bind_group_layout: wgpu::BindGroupLayout,
46}
47
48impl fmt::Debug for GpuResources {
49    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
50        f.debug_struct("GpuResources")
51            .field("device", &"wgpu::Device")
52            .field("queue", &"wgpu::Queue")
53            .field("pipeline", &"tween.wgsl::main")
54            .field("bind_group_layout", &"tween storage buffers")
55            .finish()
56    }
57}
58
59/// Batch of `Tween<f32>` values evaluated together.
60///
61/// The public API is intentionally small: push tweens, tick the batch, then
62/// read the current values. The CPU fallback preserves exact `Tween<f32>`
63/// behavior, including delays, looping, and advanced/custom easing.
64#[derive(Debug)]
65pub struct GpuAnimationBatch {
66    tweens: Vec<Tween<f32>>,
67    values: Vec<f32>,
68    inputs: Vec<GpuTweenInput>,
69    resources: Option<GpuResources>,
70    force_cpu: bool,
71}
72
73impl Default for GpuAnimationBatch {
74    fn default() -> Self {
75        Self::new_cpu()
76    }
77}
78
79impl GpuAnimationBatch {
80    /// Create a CPU-only batch.
81    pub fn new_cpu() -> Self {
82        Self {
83            tweens: Vec::new(),
84            values: Vec::new(),
85            inputs: Vec::new(),
86            resources: None,
87            force_cpu: false,
88        }
89    }
90
91    /// Create a batch from an existing wgpu device and queue.
92    ///
93    /// If an unsupported easing is pushed later, [`backend`](Self::backend)
94    /// reports [`GpuBackend::Cpu`] and ticks continue through the exact CPU
95    /// fallback path.
96    pub fn new(device: wgpu::Device, queue: wgpu::Queue) -> Result<Self, GpuBatchError> {
97        let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
98            label: Some("animato-gpu tween.wgsl"),
99            source: wgpu::ShaderSource::Wgsl(SHADER_SOURCE.into()),
100        });
101        let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
102            label: Some("animato-gpu tween bind group layout"),
103            entries: &[
104                wgpu::BindGroupLayoutEntry {
105                    binding: 0,
106                    visibility: wgpu::ShaderStages::COMPUTE,
107                    ty: wgpu::BindingType::Buffer {
108                        ty: wgpu::BufferBindingType::Storage { read_only: true },
109                        has_dynamic_offset: false,
110                        min_binding_size: None,
111                    },
112                    count: None,
113                },
114                wgpu::BindGroupLayoutEntry {
115                    binding: 1,
116                    visibility: wgpu::ShaderStages::COMPUTE,
117                    ty: wgpu::BindingType::Buffer {
118                        ty: wgpu::BufferBindingType::Storage { read_only: false },
119                        has_dynamic_offset: false,
120                        min_binding_size: None,
121                    },
122                    count: None,
123                },
124            ],
125        });
126        let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
127            label: Some("animato-gpu tween pipeline layout"),
128            bind_group_layouts: &[Some(&bind_group_layout)],
129            immediate_size: 0,
130        });
131        let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
132            label: Some("animato-gpu tween pipeline"),
133            layout: Some(&pipeline_layout),
134            module: &shader,
135            entry_point: Some("main"),
136            compilation_options: Default::default(),
137            cache: None,
138        });
139        Ok(Self {
140            tweens: Vec::new(),
141            values: Vec::new(),
142            inputs: Vec::new(),
143            resources: Some(GpuResources {
144                device,
145                queue,
146                pipeline,
147                bind_group_layout,
148            }),
149            force_cpu: false,
150        })
151    }
152
153    /// Try to create a GPU-backed batch using the default wgpu adapter.
154    pub fn try_new_auto() -> Result<Self, GpuBatchError> {
155        let instance = wgpu::Instance::new(wgpu::InstanceDescriptor::new_without_display_handle());
156        let adapter = pollster::block_on(instance.request_adapter(&wgpu::RequestAdapterOptions {
157            power_preference: wgpu::PowerPreference::HighPerformance,
158            compatible_surface: None,
159            force_fallback_adapter: false,
160        }))
161        .map_err(|_| GpuBatchError::AdapterUnavailable)?;
162
163        let (device, queue) = pollster::block_on(adapter.request_device(&wgpu::DeviceDescriptor {
164            label: Some("animato-gpu device"),
165            ..Default::default()
166        }))
167        .map_err(|err| GpuBatchError::RequestDevice(err.to_string()))?;
168
169        Self::new(device, queue)
170    }
171
172    /// Create a GPU-backed batch when possible, otherwise return CPU fallback.
173    ///
174    /// This function never panics because GPU availability is environmental.
175    pub fn new_auto() -> Self {
176        Self::try_new_auto().unwrap_or_else(|_| Self::new_cpu())
177    }
178
179    /// Push a tween and return its batch index.
180    pub fn push(&mut self, tween: Tween<f32>) -> usize {
181        if classic_easing_id(&tween.easing).is_none() {
182            self.force_cpu = true;
183        }
184        let index = self.tweens.len();
185        self.values.push(tween.value());
186        self.tweens.push(tween);
187        index
188    }
189
190    /// Advance every tween by `dt` seconds and refresh the output buffer.
191    pub fn tick(&mut self, dt: f32) {
192        for tween in &mut self.tweens {
193            tween.update(dt);
194        }
195
196        if self.backend() == GpuBackend::Gpu {
197            self.prepare_gpu_inputs();
198            match self.dispatch_gpu() {
199                Ok(()) => return,
200                Err(_) => {
201                    self.force_cpu = true;
202                }
203            }
204        }
205
206        self.refresh_cpu_values();
207    }
208
209    /// Current output values, in insertion order.
210    pub fn read_back(&self) -> &[f32] {
211        &self.values
212    }
213
214    /// Currently active backend.
215    pub fn backend(&self) -> GpuBackend {
216        if self.resources.is_some() && !self.force_cpu {
217            GpuBackend::Gpu
218        } else {
219            GpuBackend::Cpu
220        }
221    }
222
223    /// Number of tweens in the batch.
224    pub fn len(&self) -> usize {
225        self.tweens.len()
226    }
227
228    /// `true` when the batch contains no tweens.
229    pub fn is_empty(&self) -> bool {
230        self.tweens.is_empty()
231    }
232
233    /// Remove all tweens and output values.
234    pub fn clear(&mut self) {
235        self.tweens.clear();
236        self.values.clear();
237        self.inputs.clear();
238        self.force_cpu = false;
239    }
240
241    /// WGSL shader source used by the GPU backend.
242    pub fn shader_source() -> &'static str {
243        SHADER_SOURCE
244    }
245
246    fn prepare_gpu_inputs(&mut self) {
247        self.inputs.clear();
248        self.inputs.reserve(self.tweens.len());
249        for tween in &self.tweens {
250            let easing_id = classic_easing_id(&tween.easing).unwrap_or(0);
251            let (start, end) = if tween.is_ping_pong_reversed() {
252                (tween.end, tween.start)
253            } else {
254                (tween.start, tween.end)
255            };
256            self.inputs.push(GpuTweenInput {
257                start,
258                end,
259                duration: tween.duration,
260                elapsed: tween.elapsed(),
261                easing_id,
262                _pad0: 0,
263                _pad1: 0,
264                _pad2: 0,
265            });
266        }
267    }
268
269    fn dispatch_gpu(&mut self) -> Result<(), GpuBatchError> {
270        let resources = self
271            .resources
272            .as_ref()
273            .ok_or(GpuBatchError::AdapterUnavailable)?;
274        if self.inputs.is_empty() {
275            return Ok(());
276        }
277
278        let input_bytes = bytemuck::cast_slice(&self.inputs);
279        let output_size = (self.values.len() * core::mem::size_of::<f32>()) as wgpu::BufferAddress;
280
281        let input_buffer = resources.device.create_buffer(&wgpu::BufferDescriptor {
282            label: Some("animato-gpu tween input"),
283            size: input_bytes.len() as wgpu::BufferAddress,
284            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
285            mapped_at_creation: false,
286        });
287        resources.queue.write_buffer(&input_buffer, 0, input_bytes);
288
289        let output_buffer = resources.device.create_buffer(&wgpu::BufferDescriptor {
290            label: Some("animato-gpu tween output"),
291            size: output_size,
292            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
293            mapped_at_creation: false,
294        });
295        let readback_buffer = resources.device.create_buffer(&wgpu::BufferDescriptor {
296            label: Some("animato-gpu tween readback"),
297            size: output_size,
298            usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
299            mapped_at_creation: false,
300        });
301
302        let bind_group = resources
303            .device
304            .create_bind_group(&wgpu::BindGroupDescriptor {
305                label: Some("animato-gpu tween bind group"),
306                layout: &resources.bind_group_layout,
307                entries: &[
308                    wgpu::BindGroupEntry {
309                        binding: 0,
310                        resource: input_buffer.as_entire_binding(),
311                    },
312                    wgpu::BindGroupEntry {
313                        binding: 1,
314                        resource: output_buffer.as_entire_binding(),
315                    },
316                ],
317            });
318
319        let mut encoder =
320            resources
321                .device
322                .create_command_encoder(&wgpu::CommandEncoderDescriptor {
323                    label: Some("animato-gpu tween encoder"),
324                });
325        {
326            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
327                label: Some("animato-gpu tween pass"),
328                timestamp_writes: None,
329            });
330            pass.set_pipeline(&resources.pipeline);
331            pass.set_bind_group(0, &bind_group, &[]);
332            pass.dispatch_workgroups(self.inputs.len().div_ceil(64) as u32, 1, 1);
333        }
334        encoder.copy_buffer_to_buffer(&output_buffer, 0, &readback_buffer, 0, output_size);
335        resources.queue.submit(Some(encoder.finish()));
336
337        let slice = readback_buffer.slice(..);
338        let (sender, receiver) = mpsc::channel();
339        slice.map_async(wgpu::MapMode::Read, move |result| {
340            let _ = sender.send(result);
341        });
342        resources
343            .device
344            .poll(wgpu::PollType::wait_indefinitely())
345            .map_err(|err| GpuBatchError::RequestDevice(err.to_string()))?;
346        receiver
347            .recv()
348            .map_err(|err| GpuBatchError::RequestDevice(err.to_string()))?
349            .map_err(|err| GpuBatchError::RequestDevice(err.to_string()))?;
350
351        {
352            let mapped = slice.get_mapped_range();
353            let values: &[f32] = bytemuck::cast_slice(&mapped);
354            self.values.copy_from_slice(values);
355        }
356        readback_buffer.unmap();
357
358        Ok(())
359    }
360
361    fn refresh_cpu_values(&mut self) {
362        for (tween, value) in self.tweens.iter().zip(self.values.iter_mut()) {
363            *value = tween.value();
364        }
365    }
366}
367
368#[inline]
369fn classic_easing_id(easing: &Easing) -> Option<u32> {
370    Some(match easing {
371        Easing::Linear => 0,
372        Easing::EaseInQuad => 1,
373        Easing::EaseOutQuad => 2,
374        Easing::EaseInOutQuad => 3,
375        Easing::EaseInCubic => 4,
376        Easing::EaseOutCubic => 5,
377        Easing::EaseInOutCubic => 6,
378        Easing::EaseInQuart => 7,
379        Easing::EaseOutQuart => 8,
380        Easing::EaseInOutQuart => 9,
381        Easing::EaseInQuint => 10,
382        Easing::EaseOutQuint => 11,
383        Easing::EaseInOutQuint => 12,
384        Easing::EaseInSine => 13,
385        Easing::EaseOutSine => 14,
386        Easing::EaseInOutSine => 15,
387        Easing::EaseInExpo => 16,
388        Easing::EaseOutExpo => 17,
389        Easing::EaseInOutExpo => 18,
390        Easing::EaseInCirc => 19,
391        Easing::EaseOutCirc => 20,
392        Easing::EaseInOutCirc => 21,
393        Easing::EaseInBack => 22,
394        Easing::EaseOutBack => 23,
395        Easing::EaseInOutBack => 24,
396        Easing::EaseInElastic => 25,
397        Easing::EaseOutElastic => 26,
398        Easing::EaseInOutElastic => 27,
399        Easing::EaseInBounce => 28,
400        Easing::EaseOutBounce => 29,
401        Easing::EaseInOutBounce => 30,
402        _ => return None,
403    })
404}
405
#[cfg(test)]
mod tests {
    use super::*;
    use animato_core::Easing;

    /// A CPU batch must track a standalone tween ticked by the same amount.
    #[test]
    fn cpu_batch_matches_regular_tween_values() {
        let make_tween = || {
            Tween::new(0.0_f32, 100.0)
                .duration(1.0)
                .easing(Easing::EaseOutCubic)
                .build()
        };
        let mut reference = make_tween();
        let mut batch = GpuAnimationBatch::new_cpu();
        batch.push(make_tween());

        reference.update(0.25);
        batch.tick(0.25);

        let drift = (batch.read_back()[0] - reference.value()).abs();
        assert!(drift < 0.0001);
    }

    /// Pushing an easing without a shader id leaves a CPU batch on the CPU.
    #[test]
    fn unsupported_easing_keeps_cpu_backend() {
        let bezier = Tween::new(0.0_f32, 1.0)
            .easing(Easing::CubicBezier(0.25, 0.1, 0.25, 1.0))
            .build();
        let mut batch = GpuAnimationBatch::new_cpu();
        batch.push(bezier);
        assert_eq!(batch.backend(), GpuBackend::Cpu);
    }

    /// The embedded WGSL text contains the expected markers.
    #[test]
    fn shader_source_is_embedded() {
        let wgsl = GpuAnimationBatch::shader_source();
        for marker in ["@compute", "ease_out_bounce"] {
            assert!(wgsl.contains(marker));
        }
    }

    /// Exercises the default constructor, size accessors, `clear`, and ticking
    /// an empty batch.
    #[test]
    fn default_len_clear_and_empty_tick_are_cpu_safe() {
        let mut batch = GpuAnimationBatch::default();

        // Freshly constructed: CPU backend, no tweens.
        assert_eq!(batch.backend(), GpuBackend::Cpu);
        assert!(batch.is_empty());
        assert_eq!(batch.len(), 0);
        batch.tick(0.25);
        assert!(batch.read_back().is_empty());

        // One tween pushed: its start value is visible before any tick.
        let tween = Tween::new(1.0_f32, 3.0).duration(1.0).build();
        let index = batch.push(tween);
        assert_eq!(index, 0);
        assert_eq!(batch.len(), 1);
        assert_eq!(batch.read_back(), &[1.0]);

        // Cleared: back to the empty state.
        batch.clear();
        assert!(batch.is_empty());
        assert!(batch.read_back().is_empty());
        assert_eq!(batch.backend(), GpuBackend::Cpu);
    }

    /// Ids must equal each easing's position in this list.
    #[test]
    fn supported_easing_ids_cover_all_shader_variants() {
        let supported = [
            Easing::Linear,
            Easing::EaseInQuad,
            Easing::EaseOutQuad,
            Easing::EaseInOutQuad,
            Easing::EaseInCubic,
            Easing::EaseOutCubic,
            Easing::EaseInOutCubic,
            Easing::EaseInQuart,
            Easing::EaseOutQuart,
            Easing::EaseInOutQuart,
            Easing::EaseInQuint,
            Easing::EaseOutQuint,
            Easing::EaseInOutQuint,
            Easing::EaseInSine,
            Easing::EaseOutSine,
            Easing::EaseInOutSine,
            Easing::EaseInExpo,
            Easing::EaseOutExpo,
            Easing::EaseInOutExpo,
            Easing::EaseInCirc,
            Easing::EaseOutCirc,
            Easing::EaseInOutCirc,
            Easing::EaseInBack,
            Easing::EaseOutBack,
            Easing::EaseInOutBack,
            Easing::EaseInElastic,
            Easing::EaseOutElastic,
            Easing::EaseInOutElastic,
            Easing::EaseInBounce,
            Easing::EaseOutBounce,
            Easing::EaseInOutBounce,
        ];

        for (expected_id, easing) in (0_u32..).zip(supported.iter()) {
            assert_eq!(classic_easing_id(easing), Some(expected_id));
        }
        assert_eq!(classic_easing_id(&Easing::Steps(4)), None);
    }

    /// Two tweens with different settings, one of them looping forever.
    #[test]
    fn cpu_fallback_handles_multiple_tweens_and_loops() {
        let looping = Tween::new(0.0_f32, 10.0)
            .duration(1.0)
            .looping(animato_tween::Loop::Forever)
            .build();
        let eased = Tween::new(10.0_f32, 0.0)
            .duration(2.0)
            .easing(Easing::EaseInOutQuad)
            .build();

        let mut batch = GpuAnimationBatch::new_cpu();
        batch.push(looping);
        batch.push(eased);

        batch.tick(1.25);

        let values = batch.read_back();
        assert!((values[0] - 2.5).abs() < 0.001);
        assert!(values[1] < 5.0);
    }

    /// Whatever backend is picked, a push and a tick must yield one value.
    #[test]
    fn auto_constructor_falls_back_or_reports_gpu_without_panicking() {
        let mut batch = GpuAnimationBatch::new_auto();

        assert!(matches!(batch.backend(), GpuBackend::Cpu | GpuBackend::Gpu));
        let tween = Tween::new(0.0_f32, 1.0).duration(0.1).build();
        batch.push(tween);
        batch.tick(0.1);
        assert_eq!(batch.read_back().len(), 1);
    }

    /// Equality distinguishes variants and `Debug` exposes the payload.
    #[test]
    fn gpu_error_debug_and_equality_are_stable() {
        let adapter = GpuBatchError::AdapterUnavailable;
        let device = GpuBatchError::RequestDevice("lost".to_owned());

        assert_eq!(adapter, GpuBatchError::AdapterUnavailable);
        assert_ne!(adapter, device);
        let rendered = format!("{device:?}");
        assert!(rendered.contains("lost"));
    }
}