Skip to main content

oxigdal_gpu_advanced/adaptive/
mod.rs

1//! Adaptive algorithm selection and auto-tuning for GPUs.
2//!
3//! This module provides intelligent algorithm selection based on hardware
4//! capabilities and workload characteristics with performance feedback loops.
5
6use parking_lot::RwLock;
7use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10use wgpu::{Device, DeviceType};
11
12/// Adaptive algorithm selector
13#[derive(Clone)]
14pub struct AdaptiveSelector {
15    /// The GPU device (used for future adaptive selection)
16    #[allow(dead_code)]
17    device: Arc<Device>,
18    device_info: DeviceInfo,
19    performance_history: Arc<RwLock<PerformanceHistory>>,
20    /// Configuration for adaptive selection (used for future tuning)
21    #[allow(dead_code)]
22    config: AdaptiveConfig,
23}
24
25impl AdaptiveSelector {
26    /// Create a new adaptive selector
27    pub fn new(device: Arc<Device>, device_type: DeviceType) -> Self {
28        let device_info = DeviceInfo::from_device(&device, device_type);
29
30        Self {
31            device,
32            device_info,
33            performance_history: Arc::new(RwLock::new(PerformanceHistory::new())),
34            config: AdaptiveConfig::default(),
35        }
36    }
37
38    /// Select best algorithm for a given operation
39    pub fn select_algorithm(&self, operation: &str, workload: &WorkloadInfo) -> Algorithm {
40        // Check if we have historical data
41        let history = self.performance_history.read();
42        if let Some(best) = history.get_best_algorithm(operation, workload) {
43            return best;
44        }
45        drop(history);
46
47        // No history, use heuristics based on hardware and workload
48        self.select_by_heuristics(operation, workload)
49    }
50
51    /// Select algorithm using hardware heuristics
52    fn select_by_heuristics(&self, operation: &str, workload: &WorkloadInfo) -> Algorithm {
53        match operation {
54            "matrix_multiply" => self.select_matmul_algorithm(workload),
55            "convolution" => self.select_convolution_algorithm(workload),
56            "reduction" => self.select_reduction_algorithm(workload),
57            "sort" => self.select_sort_algorithm(workload),
58            "fft" => self.select_fft_algorithm(workload),
59            _ => Algorithm::default(),
60        }
61    }
62
63    /// Select matrix multiplication algorithm
64    fn select_matmul_algorithm(&self, workload: &WorkloadInfo) -> Algorithm {
65        let size = workload.data_size;
66
67        if size < 1024 * 1024 {
68            // Small matrices: use naive algorithm
69            Algorithm {
70                name: "matmul_naive".to_string(),
71                workgroup_size: (8, 8, 1),
72                strategy: ExecutionStrategy::Direct,
73                tuning_params: TuningParams::default(),
74            }
75        } else if size < 16 * 1024 * 1024 {
76            // Medium matrices: use tiled algorithm
77            Algorithm {
78                name: "matmul_tiled".to_string(),
79                workgroup_size: (16, 16, 1),
80                strategy: ExecutionStrategy::Tiled { tile_size: 32 },
81                tuning_params: TuningParams {
82                    use_shared_memory: true,
83                    unroll_factor: 4,
84                    vectorize: true,
85                },
86            }
87        } else {
88            // Large matrices: use hierarchical algorithm
89            Algorithm {
90                name: "matmul_hierarchical".to_string(),
91                workgroup_size: (16, 16, 1),
92                strategy: ExecutionStrategy::Hierarchical { levels: 2 },
93                tuning_params: TuningParams {
94                    use_shared_memory: true,
95                    unroll_factor: 8,
96                    vectorize: true,
97                },
98            }
99        }
100    }
101
102    /// Select convolution algorithm
103    fn select_convolution_algorithm(&self, workload: &WorkloadInfo) -> Algorithm {
104        let kernel_size = workload.dimensions.first().copied().unwrap_or(3);
105
106        if kernel_size <= 3 {
107            Algorithm {
108                name: "conv_direct".to_string(),
109                workgroup_size: (8, 8, 1),
110                strategy: ExecutionStrategy::Direct,
111                tuning_params: TuningParams {
112                    use_shared_memory: true,
113                    unroll_factor: 1,
114                    vectorize: false,
115                },
116            }
117        } else {
118            Algorithm {
119                name: "conv_im2col".to_string(),
120                workgroup_size: (16, 16, 1),
121                strategy: ExecutionStrategy::Transform,
122                tuning_params: TuningParams {
123                    use_shared_memory: true,
124                    unroll_factor: 4,
125                    vectorize: true,
126                },
127            }
128        }
129    }
130
131    /// Select reduction algorithm
132    fn select_reduction_algorithm(&self, workload: &WorkloadInfo) -> Algorithm {
133        let compute_units = self.device_info.compute_units.unwrap_or(8);
134
135        Algorithm {
136            name: "reduce_hierarchical".to_string(),
137            workgroup_size: (256, 1, 1),
138            strategy: ExecutionStrategy::Hierarchical {
139                levels: (workload.data_size as f32).log2().ceil() as usize / 8,
140            },
141            tuning_params: TuningParams {
142                use_shared_memory: true,
143                unroll_factor: (compute_units / 8).max(1),
144                vectorize: true,
145            },
146        }
147    }
148
149    /// Select sorting algorithm
150    fn select_sort_algorithm(&self, workload: &WorkloadInfo) -> Algorithm {
151        if workload.data_size < 1024 {
152            Algorithm {
153                name: "sort_insertion".to_string(),
154                workgroup_size: (64, 1, 1),
155                strategy: ExecutionStrategy::Direct,
156                tuning_params: TuningParams::default(),
157            }
158        } else if workload.data_size < 1024 * 1024 {
159            Algorithm {
160                name: "sort_bitonic".to_string(),
161                workgroup_size: (128, 1, 1),
162                strategy: ExecutionStrategy::Parallel,
163                tuning_params: TuningParams {
164                    use_shared_memory: true,
165                    unroll_factor: 2,
166                    vectorize: false,
167                },
168            }
169        } else {
170            Algorithm {
171                name: "sort_radix".to_string(),
172                workgroup_size: (256, 1, 1),
173                strategy: ExecutionStrategy::Hierarchical { levels: 4 },
174                tuning_params: TuningParams {
175                    use_shared_memory: true,
176                    unroll_factor: 4,
177                    vectorize: true,
178                },
179            }
180        }
181    }
182
183    /// Select FFT algorithm
184    fn select_fft_algorithm(&self, workload: &WorkloadInfo) -> Algorithm {
185        let size = workload.data_size;
186        let is_power_of_2 = (size & (size - 1)) == 0;
187
188        if is_power_of_2 {
189            Algorithm {
190                name: "fft_cooley_tukey".to_string(),
191                workgroup_size: (256, 1, 1),
192                strategy: ExecutionStrategy::Hierarchical {
193                    levels: (size as f32).log2() as usize,
194                },
195                tuning_params: TuningParams {
196                    use_shared_memory: true,
197                    unroll_factor: 4,
198                    vectorize: true,
199                },
200            }
201        } else {
202            Algorithm {
203                name: "fft_bluestein".to_string(),
204                workgroup_size: (128, 1, 1),
205                strategy: ExecutionStrategy::Transform,
206                tuning_params: TuningParams {
207                    use_shared_memory: false,
208                    unroll_factor: 2,
209                    vectorize: false,
210                },
211            }
212        }
213    }
214
215    /// Record performance for an algorithm
216    pub fn record_performance(
217        &self,
218        operation: &str,
219        workload: &WorkloadInfo,
220        algorithm: &Algorithm,
221        duration: Duration,
222    ) {
223        let mut history = self.performance_history.write();
224        history.record(operation, workload, algorithm, duration);
225    }
226
227    /// Get performance statistics
228    pub fn get_statistics(&self, operation: &str) -> Option<AlgorithmStats> {
229        let history = self.performance_history.read();
230        history.get_stats(operation)
231    }
232}
233
234/// Device hardware information
235#[derive(Debug, Clone)]
236pub struct DeviceInfo {
237    /// Device type
238    pub device_type: DeviceType,
239    /// Estimated compute units
240    pub compute_units: Option<u32>,
241    /// Memory bandwidth (estimated, GB/s)
242    pub memory_bandwidth: Option<f32>,
243    /// Peak FLOPS (estimated)
244    pub peak_flops: Option<f64>,
245}
246
247impl DeviceInfo {
248    fn from_device(_device: &Device, device_type: DeviceType) -> Self {
249        let compute_units = match device_type {
250            DeviceType::DiscreteGpu => Some(64),
251            DeviceType::IntegratedGpu => Some(16),
252            _ => None,
253        };
254
255        Self {
256            device_type,
257            compute_units,
258            memory_bandwidth: None,
259            peak_flops: None,
260        }
261    }
262}
263
/// Shape and size of a workload.
///
/// Used to pick heuristics and to match against recorded history.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct WorkloadInfo {
    /// Total amount of data, in bytes.
    pub data_size: u64,
    /// Extent along each dimension (for example `[width, height]` in 2D).
    pub dimensions: Vec<u32>,
    /// Size of a single element, in bytes.
    pub element_size: u32,
}
274
/// A concrete algorithm choice: a kernel name plus its launch and tuning
/// configuration.
#[derive(Clone, Debug)]
pub struct Algorithm {
    /// Name identifying the algorithm variant.
    pub name: String,
    /// Workgroup dimensions as `(x, y, z)`.
    pub workgroup_size: (u32, u32, u32),
    /// How the work is decomposed.
    pub strategy: ExecutionStrategy,
    /// Low-level tuning knobs.
    pub tuning_params: TuningParams,
}

impl Default for Algorithm {
    /// Generic fallback: named `"default"`, an 8x8x1 workgroup, direct
    /// execution and default tuning.
    fn default() -> Self {
        let tuning_params = TuningParams::default();
        Algorithm {
            tuning_params,
            strategy: ExecutionStrategy::Direct,
            workgroup_size: (8, 8, 1),
            name: String::from("default"),
        }
    }
}

/// How an algorithm decomposes its work.
#[derive(Clone, Copy, Debug)]
pub enum ExecutionStrategy {
    /// Run the kernel directly, with no decomposition.
    Direct,
    /// Split the work into tiles.
    Tiled {
        /// Edge length of each tile.
        tile_size: u32,
    },
    /// Recursive, multi-level decomposition.
    Hierarchical {
        /// How many levels the hierarchy has.
        levels: usize,
    },
    /// Transform the data first (for example im2col or a chirp transform).
    Transform,
    /// Generic parallel execution.
    Parallel,
}

/// Low-level kernel tuning knobs.
#[derive(Clone, Debug)]
pub struct TuningParams {
    /// Whether to stage data in shared (workgroup-local) memory.
    pub use_shared_memory: bool,
    /// Loop unrolling factor (1 means no unrolling).
    pub unroll_factor: u32,
    /// Whether to use vectorised loads and arithmetic.
    pub vectorize: bool,
}

impl Default for TuningParams {
    /// Conservative settings: no shared memory, no unrolling and no
    /// vectorisation.
    fn default() -> Self {
        TuningParams {
            unroll_factor: 1,
            use_shared_memory: false,
            vectorize: false,
        }
    }
}
340
341/// Performance history tracker
342struct PerformanceHistory {
343    records: HashMap<String, Vec<PerformanceRecord>>,
344    max_records_per_operation: usize,
345}
346
347impl PerformanceHistory {
348    fn new() -> Self {
349        Self {
350            records: HashMap::new(),
351            max_records_per_operation: 100,
352        }
353    }
354
355    fn record(
356        &mut self,
357        operation: &str,
358        workload: &WorkloadInfo,
359        algorithm: &Algorithm,
360        duration: Duration,
361    ) {
362        let record = PerformanceRecord {
363            workload: workload.clone(),
364            algorithm_name: algorithm.name.clone(),
365            duration,
366            timestamp: Instant::now(),
367        };
368
369        let records = self.records.entry(operation.to_string()).or_default();
370        records.push(record);
371
372        // Keep only recent records
373        if records.len() > self.max_records_per_operation {
374            records.remove(0);
375        }
376    }
377
378    fn get_best_algorithm(&self, operation: &str, workload: &WorkloadInfo) -> Option<Algorithm> {
379        let records = self.records.get(operation)?;
380
381        // Find records with similar workload
382        let mut similar: Vec<_> = records
383            .iter()
384            .filter(|r| Self::is_similar_workload(&r.workload, workload))
385            .collect();
386
387        if similar.is_empty() {
388            return None;
389        }
390
391        // Sort by duration
392        similar.sort_by_key(|r| r.duration);
393
394        // Return the algorithm of the best performing record
395        Some(Algorithm {
396            name: similar[0].algorithm_name.clone(),
397            ..Algorithm::default()
398        })
399    }
400
401    fn is_similar_workload(w1: &WorkloadInfo, w2: &WorkloadInfo) -> bool {
402        // Similar if data size within 20% and same dimensions count
403        let size_ratio = (w1.data_size as f64) / (w2.data_size as f64);
404        size_ratio > 0.8 && size_ratio < 1.2 && w1.dimensions.len() == w2.dimensions.len()
405    }
406
407    fn get_stats(&self, operation: &str) -> Option<AlgorithmStats> {
408        let records = self.records.get(operation)?;
409
410        if records.is_empty() {
411            return None;
412        }
413
414        let total_duration: Duration = records.iter().map(|r| r.duration).sum();
415        let count = records.len() as u32;
416
417        Some(AlgorithmStats {
418            total_executions: count,
419            average_duration: total_duration / count,
420            total_duration,
421        })
422    }
423}
424
425/// Performance record
426#[derive(Debug, Clone)]
427struct PerformanceRecord {
428    workload: WorkloadInfo,
429    algorithm_name: String,
430    duration: Duration,
431    /// Timestamp when the record was created (for future LRU eviction)
432    #[allow(dead_code)]
433    timestamp: Instant,
434}
435
/// Aggregate execution statistics for an operation.
#[derive(Clone, Debug)]
pub struct AlgorithmStats {
    /// Number of executions included in the aggregate.
    pub total_executions: u32,
    /// Mean duration per execution.
    pub average_duration: Duration,
    /// Sum of all execution durations.
    pub total_duration: Duration,
}
446
/// Settings that control adaptive algorithm selection.
///
/// NOTE(review): these settings do not appear to be read anywhere in this
/// module yet.
#[derive(Clone, Debug)]
pub struct AdaptiveConfig {
    /// Whether auto-tuning is enabled.
    pub auto_tune: bool,
    /// Number of samples required before recorded history is trusted.
    pub min_samples: usize,
}

impl Default for AdaptiveConfig {
    /// Auto-tuning enabled, with three samples required.
    fn default() -> Self {
        AdaptiveConfig {
            min_samples: 3,
            auto_tune: true,
        }
    }
}
464
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a workload with 4-byte elements.
    fn workload(data_size: u64, dimensions: &[u32]) -> WorkloadInfo {
        WorkloadInfo {
            data_size,
            dimensions: dimensions.to_vec(),
            element_size: 4,
        }
    }

    #[test]
    fn test_workload_similarity() {
        let base = workload(1000, &[100, 10]);
        // 10% larger, same rank: similar.
        let near = workload(1100, &[110, 10]);
        // Twice the size: not similar.
        let far = workload(2000, &[100, 20]);

        assert!(PerformanceHistory::is_similar_workload(&base, &near));
        assert!(!PerformanceHistory::is_similar_workload(&base, &far));
    }

    #[test]
    fn test_algorithm_default() {
        let Algorithm {
            name,
            workgroup_size,
            ..
        } = Algorithm::default();
        assert_eq!(name, "default");
        assert_eq!(workgroup_size, (8, 8, 1));
    }
}