1use parking_lot::RwLock;
7use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10use wgpu::{Device, DeviceType};
11
/// Chooses a compute algorithm per operation, preferring measured
/// performance history and falling back to device-aware heuristics.
#[derive(Clone)]
pub struct AdaptiveSelector {
    /// Handle to the wgpu device; retained for future use (currently unread).
    #[allow(dead_code)]
    device: Arc<Device>,
    /// Static profile of the device (type, estimated compute units, …).
    device_info: DeviceInfo,
    /// Shared, thread-safe store of past execution timings.
    performance_history: Arc<RwLock<PerformanceHistory>>,
    /// Tuning configuration; retained for future use (currently unread).
    #[allow(dead_code)]
    config: AdaptiveConfig,
}
24
25impl AdaptiveSelector {
26 pub fn new(device: Arc<Device>, device_type: DeviceType) -> Self {
28 let device_info = DeviceInfo::from_device(&device, device_type);
29
30 Self {
31 device,
32 device_info,
33 performance_history: Arc::new(RwLock::new(PerformanceHistory::new())),
34 config: AdaptiveConfig::default(),
35 }
36 }
37
38 pub fn select_algorithm(&self, operation: &str, workload: &WorkloadInfo) -> Algorithm {
40 let history = self.performance_history.read();
42 if let Some(best) = history.get_best_algorithm(operation, workload) {
43 return best;
44 }
45 drop(history);
46
47 self.select_by_heuristics(operation, workload)
49 }
50
51 fn select_by_heuristics(&self, operation: &str, workload: &WorkloadInfo) -> Algorithm {
53 match operation {
54 "matrix_multiply" => self.select_matmul_algorithm(workload),
55 "convolution" => self.select_convolution_algorithm(workload),
56 "reduction" => self.select_reduction_algorithm(workload),
57 "sort" => self.select_sort_algorithm(workload),
58 "fft" => self.select_fft_algorithm(workload),
59 _ => Algorithm::default(),
60 }
61 }
62
63 fn select_matmul_algorithm(&self, workload: &WorkloadInfo) -> Algorithm {
65 let size = workload.data_size;
66
67 if size < 1024 * 1024 {
68 Algorithm {
70 name: "matmul_naive".to_string(),
71 workgroup_size: (8, 8, 1),
72 strategy: ExecutionStrategy::Direct,
73 tuning_params: TuningParams::default(),
74 }
75 } else if size < 16 * 1024 * 1024 {
76 Algorithm {
78 name: "matmul_tiled".to_string(),
79 workgroup_size: (16, 16, 1),
80 strategy: ExecutionStrategy::Tiled { tile_size: 32 },
81 tuning_params: TuningParams {
82 use_shared_memory: true,
83 unroll_factor: 4,
84 vectorize: true,
85 },
86 }
87 } else {
88 Algorithm {
90 name: "matmul_hierarchical".to_string(),
91 workgroup_size: (16, 16, 1),
92 strategy: ExecutionStrategy::Hierarchical { levels: 2 },
93 tuning_params: TuningParams {
94 use_shared_memory: true,
95 unroll_factor: 8,
96 vectorize: true,
97 },
98 }
99 }
100 }
101
102 fn select_convolution_algorithm(&self, workload: &WorkloadInfo) -> Algorithm {
104 let kernel_size = workload.dimensions.first().copied().unwrap_or(3);
105
106 if kernel_size <= 3 {
107 Algorithm {
108 name: "conv_direct".to_string(),
109 workgroup_size: (8, 8, 1),
110 strategy: ExecutionStrategy::Direct,
111 tuning_params: TuningParams {
112 use_shared_memory: true,
113 unroll_factor: 1,
114 vectorize: false,
115 },
116 }
117 } else {
118 Algorithm {
119 name: "conv_im2col".to_string(),
120 workgroup_size: (16, 16, 1),
121 strategy: ExecutionStrategy::Transform,
122 tuning_params: TuningParams {
123 use_shared_memory: true,
124 unroll_factor: 4,
125 vectorize: true,
126 },
127 }
128 }
129 }
130
131 fn select_reduction_algorithm(&self, workload: &WorkloadInfo) -> Algorithm {
133 let compute_units = self.device_info.compute_units.unwrap_or(8);
134
135 Algorithm {
136 name: "reduce_hierarchical".to_string(),
137 workgroup_size: (256, 1, 1),
138 strategy: ExecutionStrategy::Hierarchical {
139 levels: (workload.data_size as f32).log2().ceil() as usize / 8,
140 },
141 tuning_params: TuningParams {
142 use_shared_memory: true,
143 unroll_factor: (compute_units / 8).max(1),
144 vectorize: true,
145 },
146 }
147 }
148
149 fn select_sort_algorithm(&self, workload: &WorkloadInfo) -> Algorithm {
151 if workload.data_size < 1024 {
152 Algorithm {
153 name: "sort_insertion".to_string(),
154 workgroup_size: (64, 1, 1),
155 strategy: ExecutionStrategy::Direct,
156 tuning_params: TuningParams::default(),
157 }
158 } else if workload.data_size < 1024 * 1024 {
159 Algorithm {
160 name: "sort_bitonic".to_string(),
161 workgroup_size: (128, 1, 1),
162 strategy: ExecutionStrategy::Parallel,
163 tuning_params: TuningParams {
164 use_shared_memory: true,
165 unroll_factor: 2,
166 vectorize: false,
167 },
168 }
169 } else {
170 Algorithm {
171 name: "sort_radix".to_string(),
172 workgroup_size: (256, 1, 1),
173 strategy: ExecutionStrategy::Hierarchical { levels: 4 },
174 tuning_params: TuningParams {
175 use_shared_memory: true,
176 unroll_factor: 4,
177 vectorize: true,
178 },
179 }
180 }
181 }
182
183 fn select_fft_algorithm(&self, workload: &WorkloadInfo) -> Algorithm {
185 let size = workload.data_size;
186 let is_power_of_2 = (size & (size - 1)) == 0;
187
188 if is_power_of_2 {
189 Algorithm {
190 name: "fft_cooley_tukey".to_string(),
191 workgroup_size: (256, 1, 1),
192 strategy: ExecutionStrategy::Hierarchical {
193 levels: (size as f32).log2() as usize,
194 },
195 tuning_params: TuningParams {
196 use_shared_memory: true,
197 unroll_factor: 4,
198 vectorize: true,
199 },
200 }
201 } else {
202 Algorithm {
203 name: "fft_bluestein".to_string(),
204 workgroup_size: (128, 1, 1),
205 strategy: ExecutionStrategy::Transform,
206 tuning_params: TuningParams {
207 use_shared_memory: false,
208 unroll_factor: 2,
209 vectorize: false,
210 },
211 }
212 }
213 }
214
215 pub fn record_performance(
217 &self,
218 operation: &str,
219 workload: &WorkloadInfo,
220 algorithm: &Algorithm,
221 duration: Duration,
222 ) {
223 let mut history = self.performance_history.write();
224 history.record(operation, workload, algorithm, duration);
225 }
226
227 pub fn get_statistics(&self, operation: &str) -> Option<AlgorithmStats> {
229 let history = self.performance_history.read();
230 history.get_stats(operation)
231 }
232}
233
/// Static profile of the GPU used to bias heuristic choices.
#[derive(Debug, Clone)]
pub struct DeviceInfo {
    /// Adapter class reported by wgpu (discrete, integrated, CPU, …).
    pub device_type: DeviceType,
    /// Estimated number of compute units; `None` when unknown.
    pub compute_units: Option<u32>,
    /// Memory bandwidth estimate; currently never populated (always `None`).
    pub memory_bandwidth: Option<f32>,
    /// Peak FLOPS estimate; currently never populated (always `None`).
    pub peak_flops: Option<f64>,
}
246
247impl DeviceInfo {
248 fn from_device(_device: &Device, device_type: DeviceType) -> Self {
249 let compute_units = match device_type {
250 DeviceType::DiscreteGpu => Some(64),
251 DeviceType::IntegratedGpu => Some(16),
252 _ => None,
253 };
254
255 Self {
256 device_type,
257 compute_units,
258 memory_bandwidth: None,
259 peak_flops: None,
260 }
261 }
262}
263
/// Shape description of a single compute job.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct WorkloadInfo {
    /// Total size of the input; units (bytes vs. elements) are set by the
    /// caller — TODO confirm against call sites.
    pub data_size: u64,
    /// Per-axis extents; convolution reads `dimensions[0]` as kernel size.
    pub dimensions: Vec<u32>,
    /// Size of one element, presumably in bytes — verify against callers.
    pub element_size: u32,
}
274
/// A concrete execution plan chosen by the selector.
#[derive(Debug, Clone)]
pub struct Algorithm {
    /// Kernel/variant identifier (e.g. "matmul_tiled", "fft_bluestein").
    pub name: String,
    /// Workgroup dimensions (x, y, z) to dispatch with.
    pub workgroup_size: (u32, u32, u32),
    /// High-level execution shape (direct, tiled, hierarchical, …).
    pub strategy: ExecutionStrategy,
    /// Low-level knobs (shared memory, unrolling, vectorization).
    pub tuning_params: TuningParams,
}
287
288impl Default for Algorithm {
289 fn default() -> Self {
290 Self {
291 name: "default".to_string(),
292 workgroup_size: (8, 8, 1),
293 strategy: ExecutionStrategy::Direct,
294 tuning_params: TuningParams::default(),
295 }
296 }
297}
298
/// High-level shape of how an algorithm's work is executed.
#[derive(Debug, Clone, Copy)]
pub enum ExecutionStrategy {
    /// Single direct dispatch over the whole problem.
    Direct,
    /// Blocked execution over square tiles.
    Tiled {
        /// Edge length of each tile.
        tile_size: u32,
    },
    /// Multi-pass execution over a reduction/recursion hierarchy.
    Hierarchical {
        /// Number of passes (levels) in the hierarchy.
        levels: usize,
    },
    /// Problem is transformed first (used by im2col conv and Bluestein FFT).
    Transform,
    /// Flat data-parallel execution (used by bitonic sort).
    Parallel,
}
319
/// Low-level kernel tuning knobs attached to an `Algorithm`.
#[derive(Debug, Clone)]
pub struct TuningParams {
    /// Whether the kernel should stage data in workgroup shared memory.
    pub use_shared_memory: bool,
    /// Loop unroll factor; 1 means no unrolling.
    pub unroll_factor: u32,
    /// Whether to use vectorized loads/arithmetic.
    pub vectorize: bool,
}
330
331impl Default for TuningParams {
332 fn default() -> Self {
333 Self {
334 use_shared_memory: false,
335 unroll_factor: 1,
336 vectorize: false,
337 }
338 }
339}
340
/// Bounded store of past execution timings, keyed by operation name.
struct PerformanceHistory {
    // Operation name -> chronological list of timing samples.
    records: HashMap<String, Vec<PerformanceRecord>>,
    // Eviction cap per operation (oldest samples dropped first).
    max_records_per_operation: usize,
}
346
347impl PerformanceHistory {
348 fn new() -> Self {
349 Self {
350 records: HashMap::new(),
351 max_records_per_operation: 100,
352 }
353 }
354
355 fn record(
356 &mut self,
357 operation: &str,
358 workload: &WorkloadInfo,
359 algorithm: &Algorithm,
360 duration: Duration,
361 ) {
362 let record = PerformanceRecord {
363 workload: workload.clone(),
364 algorithm_name: algorithm.name.clone(),
365 duration,
366 timestamp: Instant::now(),
367 };
368
369 let records = self.records.entry(operation.to_string()).or_default();
370 records.push(record);
371
372 if records.len() > self.max_records_per_operation {
374 records.remove(0);
375 }
376 }
377
378 fn get_best_algorithm(&self, operation: &str, workload: &WorkloadInfo) -> Option<Algorithm> {
379 let records = self.records.get(operation)?;
380
381 let mut similar: Vec<_> = records
383 .iter()
384 .filter(|r| Self::is_similar_workload(&r.workload, workload))
385 .collect();
386
387 if similar.is_empty() {
388 return None;
389 }
390
391 similar.sort_by_key(|r| r.duration);
393
394 Some(Algorithm {
396 name: similar[0].algorithm_name.clone(),
397 ..Algorithm::default()
398 })
399 }
400
401 fn is_similar_workload(w1: &WorkloadInfo, w2: &WorkloadInfo) -> bool {
402 let size_ratio = (w1.data_size as f64) / (w2.data_size as f64);
404 size_ratio > 0.8 && size_ratio < 1.2 && w1.dimensions.len() == w2.dimensions.len()
405 }
406
407 fn get_stats(&self, operation: &str) -> Option<AlgorithmStats> {
408 let records = self.records.get(operation)?;
409
410 if records.is_empty() {
411 return None;
412 }
413
414 let total_duration: Duration = records.iter().map(|r| r.duration).sum();
415 let count = records.len() as u32;
416
417 Some(AlgorithmStats {
418 total_executions: count,
419 average_duration: total_duration / count,
420 total_duration,
421 })
422 }
423}
424
/// One timing sample: which algorithm ran on which workload, and how long.
#[derive(Debug, Clone)]
struct PerformanceRecord {
    // Workload shape the sample was measured on.
    workload: WorkloadInfo,
    // Name of the algorithm that produced this timing.
    algorithm_name: String,
    // Measured execution time.
    duration: Duration,
    // Recording time; kept for future aging/decay logic (currently unread).
    #[allow(dead_code)]
    timestamp: Instant,
}
435
/// Aggregated timing statistics for one operation.
#[derive(Debug, Clone)]
pub struct AlgorithmStats {
    /// Number of recorded executions.
    pub total_executions: u32,
    /// Mean duration across recorded executions.
    pub average_duration: Duration,
    /// Sum of all recorded durations.
    pub total_duration: Duration,
}
446
/// Selector configuration; not yet consulted by the visible selection
/// logic (the owning field is marked `dead_code`).
#[derive(Debug, Clone)]
pub struct AdaptiveConfig {
    /// Whether automatic tuning is enabled — presumably gates history-based
    /// selection; confirm once the config is wired in.
    pub auto_tune: bool,
    /// Minimum samples before trusting history — not yet enforced.
    pub min_samples: usize,
}
455
456impl Default for AdaptiveConfig {
457 fn default() -> Self {
458 Self {
459 auto_tune: true,
460 min_samples: 3,
461 }
462 }
463}
464
#[cfg(test)]
mod tests {
    use super::*;

    /// Helper: builds a workload with 4-byte elements.
    fn workload(data_size: u64, dimensions: Vec<u32>) -> WorkloadInfo {
        WorkloadInfo {
            data_size,
            dimensions,
            element_size: 4,
        }
    }

    /// Workloads within ±20% size and matching rank are similar; a 2x
    /// size difference is not.
    #[test]
    fn test_workload_similarity() {
        let base = workload(1000, vec![100, 10]);
        let close = workload(1100, vec![110, 10]);
        let far = workload(2000, vec![100, 20]);

        assert!(PerformanceHistory::is_similar_workload(&base, &close));
        assert!(!PerformanceHistory::is_similar_workload(&base, &far));
    }

    /// The fallback algorithm is a small direct dispatch named "default".
    #[test]
    fn test_algorithm_default() {
        let algo = Algorithm::default();
        assert_eq!(algo.name, "default");
        assert_eq!(algo.workgroup_size, (8, 8, 1));
    }
}