1#![allow(dead_code)]
4#![allow(clippy::cast_precision_loss)]
5
/// Kind of compute kernel.
///
/// Each variant maps to a fixed workgroup shape and a fixed flag saying whether it
/// needs shared memory; see [`KernelType::workgroup_size`] and
/// [`KernelType::requires_shared_memory`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum KernelType {
    VideoScale,
    ColorConvert,
    Histogram,
    MotionEstimate,
    Denoise,
    Sharpen,
}

impl KernelType {
    /// Returns the two-dimensional workgroup size for this kernel as a `(u32, u32)` pair
    /// (presumably `(x, y)` — confirm against the dispatch site).
    #[must_use]
    pub fn workgroup_size(&self) -> (u32, u32) {
        match self {
            Self::Histogram => (256, 1),
            Self::ColorConvert => (32, 8),
            Self::MotionEstimate => (8, 8),
            // The square 16x16 tile is shared by every remaining kernel.
            Self::VideoScale | Self::Denoise | Self::Sharpen => (16, 16),
        }
    }

    /// Returns `true` for the kernel kinds that are flagged as needing shared memory.
    ///
    /// The match is kept exhaustive on purpose so that adding a variant forces a
    /// decision here.
    #[must_use]
    pub fn requires_shared_memory(&self) -> bool {
        match self {
            Self::VideoScale | Self::ColorConvert | Self::Sharpen => false,
            Self::Histogram | Self::MotionEstimate | Self::Denoise => true,
        }
    }
}
46
47#[derive(Debug, Clone)]
49pub struct KernelSpec {
50 pub kernel_type: KernelType,
52 pub input_channels: u8,
54 pub output_channels: u8,
56 pub width: u32,
58 pub height: u32,
60}
61
62impl KernelSpec {
63 #[must_use]
65 pub fn new(
66 kernel_type: KernelType,
67 input_channels: u8,
68 output_channels: u8,
69 width: u32,
70 height: u32,
71 ) -> Self {
72 Self {
73 kernel_type,
74 input_channels,
75 output_channels,
76 width,
77 height,
78 }
79 }
80
81 #[must_use]
83 pub fn total_elements(&self) -> u64 {
84 u64::from(self.width) * u64::from(self.height)
85 }
86
87 #[must_use]
91 pub fn estimated_flops(&self) -> u64 {
92 let elements = self.total_elements();
93 let channels = u64::from(self.input_channels.max(self.output_channels));
94 let per_element: u64 = match self.kernel_type {
95 KernelType::VideoScale => 8,
96 KernelType::ColorConvert => 12,
97 KernelType::Histogram => 2,
98 KernelType::MotionEstimate => 32,
99 KernelType::Denoise => 64,
100 KernelType::Sharpen => 16,
101 };
102 elements * channels * per_element
103 }
104}
105
106#[derive(Debug, Default)]
108pub struct KernelCache {
109 specs: Vec<KernelSpec>,
110}
111
112impl KernelCache {
113 #[must_use]
115 pub fn new() -> Self {
116 Self::default()
117 }
118
119 pub fn add(&mut self, spec: KernelSpec) {
121 self.specs.push(spec);
122 }
123
124 #[must_use]
126 pub fn find(&self, kt: &KernelType) -> Option<&KernelSpec> {
127 self.specs.iter().find(|s| &s.kernel_type == kt)
128 }
129
130 #[must_use]
134 pub fn total_memory_estimate_bytes(&self) -> u64 {
135 self.specs.iter().fold(0u64, |acc, s| {
136 let channels = u64::from(s.input_channels) + u64::from(s.output_channels);
137 acc + s.total_elements() * channels * 4
138 })
139 }
140
141 #[must_use]
143 pub fn kernel_count(&self) -> usize {
144 self.specs.len()
145 }
146}
147
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a spec whose input and output channel counts are equal.
    fn spec_with(kind: KernelType, channels: u8, width: u32, height: u32) -> KernelSpec {
        KernelSpec::new(kind, channels, channels, width, height)
    }

    /// Asserts the workgroup size reported for `kind`.
    fn assert_workgroup(kind: KernelType, expected: (u32, u32)) {
        assert_eq!(kind.workgroup_size(), expected);
    }

    #[test]
    fn test_video_scale_workgroup() {
        assert_workgroup(KernelType::VideoScale, (16, 16));
    }

    #[test]
    fn test_histogram_workgroup() {
        assert_workgroup(KernelType::Histogram, (256, 1));
    }

    #[test]
    fn test_motion_estimate_workgroup() {
        assert_workgroup(KernelType::MotionEstimate, (8, 8));
    }

    #[test]
    fn test_color_convert_workgroup() {
        assert_workgroup(KernelType::ColorConvert, (32, 8));
    }

    #[test]
    fn test_requires_shared_memory_histogram() {
        let needs_shared = KernelType::Histogram.requires_shared_memory();
        assert!(needs_shared);
    }

    #[test]
    fn test_requires_shared_memory_motion() {
        let needs_shared = KernelType::MotionEstimate.requires_shared_memory();
        assert!(needs_shared);
    }

    #[test]
    fn test_no_shared_memory_video_scale() {
        let needs_shared = KernelType::VideoScale.requires_shared_memory();
        assert!(!needs_shared);
    }

    #[test]
    fn test_no_shared_memory_sharpen() {
        let needs_shared = KernelType::Sharpen.requires_shared_memory();
        assert!(!needs_shared);
    }

    #[test]
    fn test_total_elements() {
        let spec = spec_with(KernelType::VideoScale, 3, 1920, 1080);
        assert_eq!(spec.total_elements(), 2_073_600);
    }

    #[test]
    fn test_total_elements_zero_width() {
        let spec = spec_with(KernelType::Sharpen, 4, 0, 1080);
        assert_eq!(spec.total_elements(), 0);
    }

    #[test]
    fn test_estimated_flops_positive() {
        let flops = spec_with(KernelType::Denoise, 4, 256, 256).estimated_flops();
        assert!(flops > 0);
    }

    #[test]
    fn test_estimated_flops_denoise_greater_than_histogram() {
        let denoise = spec_with(KernelType::Denoise, 4, 256, 256).estimated_flops();
        let histogram = spec_with(KernelType::Histogram, 4, 256, 256).estimated_flops();
        assert!(denoise > histogram);
    }

    #[test]
    fn test_cache_add_and_count() {
        let mut cache = KernelCache::new();
        cache.add(spec_with(KernelType::VideoScale, 4, 1920, 1080));
        cache.add(spec_with(KernelType::Histogram, 3, 1920, 1080));
        assert_eq!(cache.kernel_count(), 2);
    }

    #[test]
    fn test_cache_find_existing() {
        let mut cache = KernelCache::new();
        cache.add(spec_with(KernelType::Sharpen, 4, 640, 480));
        let found = cache
            .find(&KernelType::Sharpen)
            .expect("operation should succeed in test");
        assert_eq!(found.width, 640);
    }

    #[test]
    fn test_cache_find_missing() {
        let empty = KernelCache::new();
        assert!(empty.find(&KernelType::Denoise).is_none());
    }

    #[test]
    fn test_cache_memory_estimate_nonzero() {
        let mut cache = KernelCache::new();
        cache.add(spec_with(KernelType::ColorConvert, 4, 1920, 1080));
        assert!(cache.total_memory_estimate_bytes() > 0);
    }

    #[test]
    fn test_cache_empty_memory_estimate() {
        assert_eq!(KernelCache::new().total_memory_estimate_bytes(), 0);
    }

    #[test]
    fn test_cache_find_first_match() {
        let mut cache = KernelCache::new();
        cache.add(spec_with(KernelType::VideoScale, 3, 100, 100));
        cache.add(spec_with(KernelType::VideoScale, 4, 200, 200));
        let first = cache
            .find(&KernelType::VideoScale)
            .expect("find should return a result");
        assert_eq!(first.width, 100);
    }

    #[test]
    fn test_kernel_spec_new() {
        let spec = KernelSpec::new(KernelType::MotionEstimate, 1, 2, 320, 240);
        assert_eq!(
            (spec.input_channels, spec.output_channels, spec.height),
            (1, 2, 240)
        );
    }
}