Skip to main content

oximedia_gpu/
kernel.rs

1//! GPU kernel management — kernel types, specs, and caching.
2
3#![allow(dead_code)]
4#![allow(clippy::cast_precision_loss)]
5
/// Type of GPU kernel, determining workgroup shape and shared-memory usage.
///
/// This is a fieldless enum, so it is `Copy` (cheap to pass by value) and
/// `Hash` (usable as a key in hash-based collections).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum KernelType {
    /// Video scaling kernel.
    VideoScale,
    /// Color space conversion kernel.
    ColorConvert,
    /// Histogram accumulation kernel.
    Histogram,
    /// Motion estimation kernel.
    MotionEstimate,
    /// Denoising kernel.
    Denoise,
    /// Sharpening kernel.
    Sharpen,
}
22
23impl KernelType {
24    /// Returns the preferred (x, y) workgroup size for this kernel type.
25    #[must_use]
26    pub fn workgroup_size(&self) -> (u32, u32) {
27        match self {
28            Self::VideoScale => (16, 16),
29            Self::ColorConvert => (32, 8),
30            Self::Histogram => (256, 1),
31            Self::MotionEstimate => (8, 8),
32            Self::Denoise => (16, 16),
33            Self::Sharpen => (16, 16),
34        }
35    }
36
37    /// Returns whether this kernel type requires shared memory.
38    #[must_use]
39    pub fn requires_shared_memory(&self) -> bool {
40        match self {
41            Self::Histogram | Self::MotionEstimate | Self::Denoise => true,
42            Self::VideoScale | Self::ColorConvert | Self::Sharpen => false,
43        }
44    }
45}
46
47/// Specification of a single GPU kernel invocation.
48#[derive(Debug, Clone)]
49pub struct KernelSpec {
50    /// The kernel type.
51    pub kernel_type: KernelType,
52    /// Number of input channels.
53    pub input_channels: u8,
54    /// Number of output channels.
55    pub output_channels: u8,
56    /// Image / buffer width in elements.
57    pub width: u32,
58    /// Image / buffer height in elements.
59    pub height: u32,
60}
61
62impl KernelSpec {
63    /// Creates a new `KernelSpec`.
64    #[must_use]
65    pub fn new(
66        kernel_type: KernelType,
67        input_channels: u8,
68        output_channels: u8,
69        width: u32,
70        height: u32,
71    ) -> Self {
72        Self {
73            kernel_type,
74            input_channels,
75            output_channels,
76            width,
77            height,
78        }
79    }
80
81    /// Total number of elements processed (width × height).
82    #[must_use]
83    pub fn total_elements(&self) -> u64 {
84        u64::from(self.width) * u64::from(self.height)
85    }
86
87    /// Rough estimate of floating-point operations for this kernel.
88    ///
89    /// Uses heuristic multipliers per kernel type and channel count.
90    #[must_use]
91    pub fn estimated_flops(&self) -> u64 {
92        let elements = self.total_elements();
93        let channels = u64::from(self.input_channels.max(self.output_channels));
94        let per_element: u64 = match self.kernel_type {
95            KernelType::VideoScale => 8,
96            KernelType::ColorConvert => 12,
97            KernelType::Histogram => 2,
98            KernelType::MotionEstimate => 32,
99            KernelType::Denoise => 64,
100            KernelType::Sharpen => 16,
101        };
102        elements * channels * per_element
103    }
104}
105
106/// A cache that stores multiple [`KernelSpec`] entries and provides lookup.
107#[derive(Debug, Default)]
108pub struct KernelCache {
109    specs: Vec<KernelSpec>,
110}
111
112impl KernelCache {
113    /// Creates an empty `KernelCache`.
114    #[must_use]
115    pub fn new() -> Self {
116        Self::default()
117    }
118
119    /// Adds a kernel spec to the cache.
120    pub fn add(&mut self, spec: KernelSpec) {
121        self.specs.push(spec);
122    }
123
124    /// Finds the first spec whose `kernel_type` matches `kt`.
125    #[must_use]
126    pub fn find(&self, kt: &KernelType) -> Option<&KernelSpec> {
127        self.specs.iter().find(|s| &s.kernel_type == kt)
128    }
129
130    /// Estimates total GPU memory required for all cached specs (bytes).
131    ///
132    /// Each element occupies 4 bytes (f32), multiplied by channel count.
133    #[must_use]
134    pub fn total_memory_estimate_bytes(&self) -> u64 {
135        self.specs.iter().fold(0u64, |acc, s| {
136            let channels = u64::from(s.input_channels) + u64::from(s.output_channels);
137            acc + s.total_elements() * channels * 4
138        })
139    }
140
141    /// Returns the number of specs in the cache.
142    #[must_use]
143    pub fn kernel_count(&self) -> usize {
144        self.specs.len()
145    }
146}
147
148// ---------------------------------------------------------------------------
149// Unit tests
150// ---------------------------------------------------------------------------
151
152#[cfg(test)]
153mod tests {
154    use super::*;
155
156    #[test]
157    fn test_video_scale_workgroup() {
158        assert_eq!(KernelType::VideoScale.workgroup_size(), (16, 16));
159    }
160
161    #[test]
162    fn test_histogram_workgroup() {
163        assert_eq!(KernelType::Histogram.workgroup_size(), (256, 1));
164    }
165
166    #[test]
167    fn test_motion_estimate_workgroup() {
168        assert_eq!(KernelType::MotionEstimate.workgroup_size(), (8, 8));
169    }
170
171    #[test]
172    fn test_color_convert_workgroup() {
173        assert_eq!(KernelType::ColorConvert.workgroup_size(), (32, 8));
174    }
175
176    #[test]
177    fn test_requires_shared_memory_histogram() {
178        assert!(KernelType::Histogram.requires_shared_memory());
179    }
180
181    #[test]
182    fn test_requires_shared_memory_motion() {
183        assert!(KernelType::MotionEstimate.requires_shared_memory());
184    }
185
186    #[test]
187    fn test_no_shared_memory_video_scale() {
188        assert!(!KernelType::VideoScale.requires_shared_memory());
189    }
190
191    #[test]
192    fn test_no_shared_memory_sharpen() {
193        assert!(!KernelType::Sharpen.requires_shared_memory());
194    }
195
196    #[test]
197    fn test_total_elements() {
198        let spec = KernelSpec::new(KernelType::VideoScale, 3, 3, 1920, 1080);
199        assert_eq!(spec.total_elements(), 1920 * 1080);
200    }
201
202    #[test]
203    fn test_total_elements_zero_width() {
204        let spec = KernelSpec::new(KernelType::Sharpen, 4, 4, 0, 1080);
205        assert_eq!(spec.total_elements(), 0);
206    }
207
208    #[test]
209    fn test_estimated_flops_positive() {
210        let spec = KernelSpec::new(KernelType::Denoise, 4, 4, 256, 256);
211        assert!(spec.estimated_flops() > 0);
212    }
213
214    #[test]
215    fn test_estimated_flops_denoise_greater_than_histogram() {
216        let base = KernelSpec::new(KernelType::Denoise, 4, 4, 256, 256);
217        let hist = KernelSpec::new(KernelType::Histogram, 4, 4, 256, 256);
218        assert!(base.estimated_flops() > hist.estimated_flops());
219    }
220
221    #[test]
222    fn test_cache_add_and_count() {
223        let mut cache = KernelCache::new();
224        cache.add(KernelSpec::new(KernelType::VideoScale, 4, 4, 1920, 1080));
225        cache.add(KernelSpec::new(KernelType::Histogram, 3, 3, 1920, 1080));
226        assert_eq!(cache.kernel_count(), 2);
227    }
228
229    #[test]
230    fn test_cache_find_existing() {
231        let mut cache = KernelCache::new();
232        cache.add(KernelSpec::new(KernelType::Sharpen, 4, 4, 640, 480));
233        let found = cache.find(&KernelType::Sharpen);
234        assert!(found.is_some());
235        assert_eq!(found.unwrap().width, 640);
236    }
237
238    #[test]
239    fn test_cache_find_missing() {
240        let cache = KernelCache::new();
241        assert!(cache.find(&KernelType::Denoise).is_none());
242    }
243
244    #[test]
245    fn test_cache_memory_estimate_nonzero() {
246        let mut cache = KernelCache::new();
247        cache.add(KernelSpec::new(KernelType::ColorConvert, 4, 4, 1920, 1080));
248        assert!(cache.total_memory_estimate_bytes() > 0);
249    }
250
251    #[test]
252    fn test_cache_empty_memory_estimate() {
253        let cache = KernelCache::new();
254        assert_eq!(cache.total_memory_estimate_bytes(), 0);
255    }
256
257    #[test]
258    fn test_cache_find_first_match() {
259        let mut cache = KernelCache::new();
260        cache.add(KernelSpec::new(KernelType::VideoScale, 3, 3, 100, 100));
261        cache.add(KernelSpec::new(KernelType::VideoScale, 4, 4, 200, 200));
262        // Must return the first inserted spec.
263        let found = cache.find(&KernelType::VideoScale).unwrap();
264        assert_eq!(found.width, 100);
265    }
266
267    #[test]
268    fn test_kernel_spec_new() {
269        let spec = KernelSpec::new(KernelType::MotionEstimate, 1, 2, 320, 240);
270        assert_eq!(spec.input_channels, 1);
271        assert_eq!(spec.output_channels, 2);
272        assert_eq!(spec.height, 240);
273    }
274}