Skip to main content

optirs_gpu/memory/management/
prefetching.rs

1// Memory prefetching for GPU memory management
2//
3// This module provides advanced prefetching strategies to improve GPU memory
4// performance by anticipating future memory access patterns and proactively
5// loading data before it's needed.
6
7#[allow(dead_code)]
8use std::collections::{BTreeMap, HashMap, VecDeque};
9use std::ptr::NonNull;
10use std::sync::{Arc, Mutex, RwLock};
11use std::time::{Duration, Instant};
12
/// Main prefetching engine.
///
/// Owns the configured strategies, the shared access-history tracker, the
/// queue of pending prefetch requests, and the cache of prefetched data.
pub struct PrefetchingEngine {
    /// Engine-wide configuration knobs.
    config: PrefetchConfig,
    /// Aggregate prefetching statistics.
    stats: PrefetchStats,
    /// Active prefetching strategies, consulted on every recorded access.
    strategies: Vec<Box<dyn PrefetchStrategy>>,
    /// Access pattern history shared by all strategies.
    access_history: AccessHistoryTracker,
    /// Prefetch requests awaiting issue.
    prefetch_queue: VecDeque<PrefetchRequest>,
    /// Cache of prefetched data, probed on every recorded access.
    prefetch_cache: PrefetchCache,
    /// Performance monitoring (hit rate / accuracy over time).
    performance_monitor: PerformanceMonitor,
}
30
/// Prefetching configuration.
#[derive(Debug, Clone)]
pub struct PrefetchConfig {
    /// Enable automatic prefetching on every recorded access.
    pub auto_prefetch: bool,
    /// Maximum prefetch distance (bytes) from the triggering access.
    pub max_prefetch_distance: usize,
    /// Prefetch window size (number of requests a strategy may generate).
    pub prefetch_window: usize,
    /// Minimum confidence/frequency a request needs before it is issued.
    pub min_access_frequency: f64,
    /// Enable adaptive prefetching.
    pub enable_adaptive: bool,
    /// Enable pattern-based (sequential) prefetching.
    pub enable_pattern_based: bool,
    /// Enable stride-based prefetching.
    pub enable_stride_based: bool,
    /// Enable ML-based prefetching.
    pub enable_ml_based: bool,
    /// Prefetch aggressiveness (0.0 to 1.0); scales concurrent issue count.
    pub aggressiveness: f64,
    /// Byte budget for the prefetched-data cache.
    pub cache_size: usize,
    /// Enable performance monitoring.
    pub enable_monitoring: bool,
    /// Number of recent accesses retained in the history tracker.
    pub history_window: usize,
}
59
60impl Default for PrefetchConfig {
61    fn default() -> Self {
62        Self {
63            auto_prefetch: true,
64            max_prefetch_distance: 1024 * 1024, // 1MB
65            prefetch_window: 64,
66            min_access_frequency: 0.1,
67            enable_adaptive: true,
68            enable_pattern_based: true,
69            enable_stride_based: true,
70            enable_ml_based: false,
71            aggressiveness: 0.5,
72            cache_size: 16 * 1024 * 1024, // 16MB
73            enable_monitoring: true,
74            history_window: 1000,
75        }
76    }
77}
78
/// Prefetching statistics.
#[derive(Debug, Clone, Default)]
pub struct PrefetchStats {
    /// Total prefetch requests issued.
    pub total_requests: u64,
    /// Successful prefetches (data was later used).
    pub successful_prefetches: u64,
    /// Failed prefetches (data was never used).
    pub failed_prefetches: u64,
    /// Prefetch accuracy ratio (successful / total).
    pub accuracy_ratio: f64,
    /// Total bytes prefetched.
    pub total_bytes_prefetched: u64,
    /// Bytes prefetched that were actually consumed.
    pub useful_bytes_prefetched: u64,
    /// Cache hit rate.
    pub cache_hit_rate: f64,
    /// Average prefetch latency.
    pub average_latency: Duration,
    /// Bandwidth saved by prefetching.
    pub bandwidth_saved: u64,
    /// Per-strategy performance, keyed by strategy name.
    pub strategy_stats: HashMap<String, StrategyStats>,
}

/// Statistics for an individual prefetch strategy.
#[derive(Debug, Clone, Default)]
pub struct StrategyStats {
    /// Requests generated by this strategy.
    pub requests: u64,
    /// Requests whose prefetched data was used.
    pub hits: u64,
    /// Requests whose prefetched data went unused.
    pub misses: u64,
    /// hits / requests.
    pub accuracy: f64,
    /// Average latency of this strategy's prefetches.
    pub latency: Duration,
}
113
/// Memory access tracking.
///
/// Maintains a bounded window of recent accesses plus derived pattern,
/// stride, sequential-run, and per-address frequency state used by the
/// prefetch strategies.
pub struct AccessHistoryTracker {
    /// Bounded window of recent accesses (oldest at the front).
    access_history: VecDeque<MemoryAccess>,
    /// Observed access patterns and how often each recurs.
    patterns: HashMap<AccessPattern, PatternFrequency>,
    /// Stride state, keyed by context id.
    stride_patterns: HashMap<usize, StrideInfo>,
    /// Sequential-run state, keyed by context id.
    sequential_tracking: HashMap<usize, SequentialInfo>,
    /// Per-address access frequency, keyed by address.
    frequency_map: HashMap<usize, AccessFrequency>,
}
127
/// A single recorded memory access.
#[derive(Debug, Clone)]
pub struct MemoryAccess {
    /// Memory address accessed.
    pub address: usize,
    /// Access size in bytes.
    pub size: usize,
    /// Access timestamp.
    pub timestamp: Instant,
    /// Access type (read/write).
    pub access_type: AccessType,
    /// Thread/context ID that performed the access.
    pub context_id: u32,
    /// GPU kernel ID, if the access came from a kernel.
    pub kernel_id: Option<u32>,
}

/// Access type enumeration.
#[derive(Debug, Clone, PartialEq)]
pub enum AccessType {
    Read,
    Write,
    ReadWrite,
}

/// An access pattern: a classified sequence of address deltas.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct AccessPattern {
    /// Pattern classification.
    pub pattern_type: PatternType,
    /// Address deltas between consecutive accesses (newest minus older).
    pub deltas: Vec<isize>,
    /// Number of accesses in the classified window.
    pub size: usize,
}

/// Pattern type enumeration.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub enum PatternType {
    Sequential,
    Strided,
    Random,
    Irregular,
    Custom(String),
}

/// How often a pattern recurs and how much we trust it.
#[derive(Debug, Clone)]
pub struct PatternFrequency {
    /// Times this pattern has been observed.
    pub count: u32,
    /// Timestamp of the most recent sighting.
    pub last_seen: Instant,
    /// Confidence in [0.0, 1.0], grows with `count`.
    pub confidence: f64,
    /// Measured prediction accuracy for this pattern.
    pub prediction_accuracy: f64,
}

/// Stride pattern state for one context.
#[derive(Debug, Clone)]
pub struct StrideInfo {
    /// Current stride in bytes (may be negative).
    pub stride: isize,
    /// Consecutive accesses matching this stride.
    pub frequency: u32,
    /// Address of the most recent access.
    pub last_address: usize,
    /// Confidence in [0.0, 1.0], grows with `frequency`.
    pub confidence: f64,
    /// When the current stride run started.
    pub start_time: Instant,
}

/// Sequential-run state for one context.
#[derive(Debug, Clone)]
pub struct SequentialInfo {
    /// First address of the current run.
    pub start_address: usize,
    /// Most recent address in the run.
    pub current_address: usize,
    /// Number of accesses in the run.
    pub length: usize,
    /// 1 for forward, -1 for backward, 0 while undetermined.
    pub direction: i8,
    /// Timestamp of the last access in the run.
    pub last_access: Instant,
}

/// Per-address access-frequency record.
#[derive(Debug, Clone)]
pub struct AccessFrequency {
    /// Total accesses to this address.
    pub count: u32,
    /// Timestamp of the first access.
    pub first_access: Instant,
    /// Timestamp of the most recent access.
    pub last_access: Instant,
    /// Running mean of the inter-access interval.
    pub average_interval: Duration,
}
211
212impl AccessHistoryTracker {
213    pub fn new(capacity: usize) -> Self {
214        Self {
215            access_history: VecDeque::with_capacity(capacity),
216            patterns: HashMap::new(),
217            stride_patterns: HashMap::new(),
218            sequential_tracking: HashMap::new(),
219            frequency_map: HashMap::new(),
220        }
221    }
222
223    /// Record a memory access
224    pub fn record_access(&mut self, access: MemoryAccess) {
225        // Add to history
226        self.access_history.push_back(access.clone());
227        if self.access_history.len() > self.access_history.capacity() {
228            self.access_history.pop_front();
229        }
230
231        // Update frequency map
232        let freq = self
233            .frequency_map
234            .entry(access.address)
235            .or_insert_with(|| AccessFrequency {
236                count: 0,
237                first_access: access.timestamp,
238                last_access: access.timestamp,
239                average_interval: Duration::from_secs(0),
240            });
241
242        let interval = if freq.count > 0 {
243            access.timestamp.duration_since(freq.last_access)
244        } else {
245            Duration::from_secs(0)
246        };
247
248        freq.count += 1;
249        freq.last_access = access.timestamp;
250        freq.average_interval = if freq.count > 1 {
251            Duration::from_nanos(
252                (freq.average_interval.as_nanos() as u64 * (freq.count - 1) as u64
253                    + interval.as_nanos() as u64)
254                    / freq.count as u64,
255            )
256        } else {
257            interval
258        };
259
260        // Detect patterns
261        self.detect_patterns(&access);
262        self.detect_strides(&access);
263        self.track_sequential_access(&access);
264    }
265
266    fn detect_patterns(&mut self, current_access: &MemoryAccess) {
267        let window_size = 8;
268        if self.access_history.len() < window_size {
269            return;
270        }
271
272        let recent: Vec<&MemoryAccess> =
273            self.access_history.iter().rev().take(window_size).collect();
274        let mut deltas = Vec::new();
275
276        for i in 1..recent.len() {
277            let delta = recent[i - 1].address as isize - recent[i].address as isize;
278            deltas.push(delta);
279        }
280
281        // Classify pattern type
282        let pattern_type = if deltas.iter().all(|&d| d == deltas[0]) {
283            if deltas[0] == 0 {
284                PatternType::Random
285            } else if deltas[0].abs() < 128 {
286                PatternType::Sequential
287            } else {
288                PatternType::Strided
289            }
290        } else {
291            PatternType::Irregular
292        };
293
294        let pattern = AccessPattern {
295            pattern_type,
296            deltas,
297            size: window_size,
298        };
299
300        // Update pattern frequency
301        let freq = self
302            .patterns
303            .entry(pattern)
304            .or_insert_with(|| PatternFrequency {
305                count: 0,
306                last_seen: current_access.timestamp,
307                confidence: 0.0,
308                prediction_accuracy: 0.0,
309            });
310
311        freq.count += 1;
312        freq.last_seen = current_access.timestamp;
313        freq.confidence = (freq.count as f64 / 100.0).min(1.0);
314    }
315
316    fn detect_strides(&mut self, current_access: &MemoryAccess) {
317        if self.access_history.len() < 2 {
318            return;
319        }
320
321        let prev_access = &self.access_history[self.access_history.len() - 2];
322        let stride = current_access.address as isize - prev_access.address as isize;
323
324        let stride_info = self
325            .stride_patterns
326            .entry(current_access.context_id as usize)
327            .or_insert_with(|| StrideInfo {
328                stride: 0,
329                frequency: 0,
330                last_address: prev_access.address,
331                confidence: 0.0,
332                start_time: current_access.timestamp,
333            });
334
335        if stride == stride_info.stride {
336            stride_info.frequency += 1;
337            stride_info.confidence = (stride_info.frequency as f64 / 10.0).min(1.0);
338        } else {
339            stride_info.stride = stride;
340            stride_info.frequency = 1;
341            stride_info.confidence = 0.1;
342            stride_info.start_time = current_access.timestamp;
343        }
344
345        stride_info.last_address = current_access.address;
346    }
347
348    fn track_sequential_access(&mut self, current_access: &MemoryAccess) {
349        let seq_info = self
350            .sequential_tracking
351            .entry(current_access.context_id as usize)
352            .or_insert_with(|| SequentialInfo {
353                start_address: current_access.address,
354                current_address: current_access.address,
355                length: 1,
356                direction: 0,
357                last_access: current_access.timestamp,
358            });
359
360        let address_diff = current_access.address as isize - seq_info.current_address as isize;
361
362        if address_diff.abs() <= current_access.size as isize * 2 {
363            // Likely sequential
364            if seq_info.direction == 0 {
365                seq_info.direction = if address_diff > 0 { 1 } else { -1 };
366            }
367
368            if (seq_info.direction > 0 && address_diff > 0)
369                || (seq_info.direction < 0 && address_diff < 0)
370            {
371                seq_info.length += 1;
372                seq_info.current_address = current_access.address;
373                seq_info.last_access = current_access.timestamp;
374            } else {
375                // Reset sequence
376                seq_info.start_address = current_access.address;
377                seq_info.current_address = current_access.address;
378                seq_info.length = 1;
379                seq_info.direction = 0;
380                seq_info.last_access = current_access.timestamp;
381            }
382        } else {
383            // Non-sequential, reset
384            seq_info.start_address = current_access.address;
385            seq_info.current_address = current_access.address;
386            seq_info.length = 1;
387            seq_info.direction = 0;
388            seq_info.last_access = current_access.timestamp;
389        }
390    }
391
392    /// Get predicted next accesses
393    pub fn predict_next_accesses(&self, count: usize) -> Vec<PredictedAccess> {
394        let mut predictions = Vec::new();
395
396        // Sequential predictions
397        for seq_info in self.sequential_tracking.values() {
398            if seq_info.length >= 3 && seq_info.last_access.elapsed() < Duration::from_millis(100) {
399                let next_addr = if seq_info.direction > 0 {
400                    seq_info.current_address + 64 // Typical cache line size
401                } else {
402                    seq_info.current_address.saturating_sub(64)
403                };
404
405                predictions.push(PredictedAccess {
406                    address: next_addr,
407                    size: 64,
408                    confidence: 0.8,
409                    strategy: "Sequential".to_string(),
410                    estimated_time: Duration::from_micros(100),
411                });
412            }
413        }
414
415        // Stride predictions
416        for (context_id, stride_info) in &self.stride_patterns {
417            if stride_info.confidence > 0.5 && stride_info.frequency >= 3 {
418                let next_addr = (stride_info.last_address as isize + stride_info.stride) as usize;
419                predictions.push(PredictedAccess {
420                    address: next_addr,
421                    size: 64,
422                    confidence: stride_info.confidence,
423                    strategy: "Stride".to_string(),
424                    estimated_time: Duration::from_micros(150),
425                });
426            }
427        }
428
429        predictions.truncate(count);
430        predictions
431    }
432}
433
/// A predicted future memory access.
#[derive(Debug, Clone)]
pub struct PredictedAccess {
    /// Predicted address.
    pub address: usize,
    /// Predicted access size in bytes.
    pub size: usize,
    /// Confidence in [0.0, 1.0].
    pub confidence: f64,
    /// Name of the strategy that produced the prediction.
    pub strategy: String,
    /// Estimated time until the access occurs.
    pub estimated_time: Duration,
}

/// A prefetch request produced by a strategy.
#[derive(Debug, Clone)]
pub struct PrefetchRequest {
    /// Target address to prefetch.
    pub address: usize,
    /// Size to prefetch in bytes.
    pub size: usize,
    /// Priority level.
    pub priority: PrefetchPriority,
    /// Strategy that generated this request.
    pub strategy: String,
    /// Confidence in this prefetch, in [0.0, 1.0].
    pub confidence: f64,
    /// Request creation timestamp.
    pub timestamp: Instant,
    /// Deadline for prefetch completion, if any.
    pub deadline: Option<Instant>,
}

/// Prefetch priority levels.
///
/// Derived `Ord` follows declaration order: Low < Normal < High < Critical.
#[derive(Debug, Clone, PartialEq, Ord, PartialOrd, Eq)]
pub enum PrefetchPriority {
    Low,
    Normal,
    High,
    Critical,
}

/// Byte-budgeted LRU cache for prefetched data.
pub struct PrefetchCache {
    /// Cache entries keyed by address.
    entries: BTreeMap<usize, CacheEntry>,
    /// Byte budget for all cached payloads combined.
    size_limit: usize,
    /// Bytes currently cached.
    current_size: usize,
    /// Addresses in LRU order (least recent at the front).
    lru_order: VecDeque<usize>,
    /// Hit/miss/eviction statistics.
    stats: CacheStats,
}

/// One cached prefetched region.
#[derive(Debug, Clone)]
pub struct CacheEntry {
    /// Address this entry covers.
    pub address: usize,
    /// Payload size in bytes.
    pub size: usize,
    /// The prefetched bytes.
    pub data: Vec<u8>,
    /// When the entry was prefetched.
    pub prefetch_time: Instant,
    /// Most recent access, `None` until first use.
    pub last_access: Option<Instant>,
    /// Number of times the entry was served.
    pub access_count: u32,
    /// Strategy that produced the entry.
    pub strategy: String,
}

/// Cache statistics.
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
    pub hits: u64,
    pub misses: u64,
    pub evictions: u64,
    pub total_size: usize,
    pub utilization: f64,
}
507
508impl PrefetchCache {
509    pub fn new(size_limit: usize) -> Self {
510        Self {
511            entries: BTreeMap::new(),
512            size_limit,
513            current_size: 0,
514            lru_order: VecDeque::new(),
515            stats: CacheStats::default(),
516        }
517    }
518
519    /// Insert prefetched data into cache
520    pub fn insert(&mut self, address: usize, data: Vec<u8>, strategy: String) -> bool {
521        let size = data.len();
522
523        // Check if we need to evict entries
524        while self.current_size + size > self.size_limit && !self.entries.is_empty() {
525            self.evict_lru();
526        }
527
528        if self.current_size + size <= self.size_limit {
529            let entry = CacheEntry {
530                address,
531                size,
532                data,
533                prefetch_time: Instant::now(),
534                last_access: None,
535                access_count: 0,
536                strategy,
537            };
538
539            self.entries.insert(address, entry);
540            self.lru_order.push_back(address);
541            self.current_size += size;
542            true
543        } else {
544            false
545        }
546    }
547
548    /// Check if data is in cache and mark as accessed
549    pub fn get(&mut self, address: usize, size: usize) -> Option<&[u8]> {
550        if let Some(entry) = self.entries.get_mut(&address) {
551            if entry.size >= size {
552                entry.last_access = Some(Instant::now());
553                entry.access_count += 1;
554
555                // Update LRU order
556                if let Some(pos) = self.lru_order.iter().position(|&addr| addr == address) {
557                    self.lru_order.remove(pos);
558                    self.lru_order.push_back(address);
559                }
560
561                self.stats.hits += 1;
562                return Some(&entry.data[..size]);
563            }
564        }
565
566        self.stats.misses += 1;
567        None
568    }
569
570    fn evict_lru(&mut self) {
571        if let Some(address) = self.lru_order.pop_front() {
572            if let Some(entry) = self.entries.remove(&address) {
573                self.current_size -= entry.size;
574                self.stats.evictions += 1;
575            }
576        }
577    }
578
579    /// Get cache statistics
580    pub fn get_stats(&self) -> &CacheStats {
581        &self.stats
582    }
583}
584
/// A pluggable prefetch strategy.
pub trait PrefetchStrategy: Send + Sync {
    /// Short, stable strategy name (used as the stats key).
    fn name(&self) -> &str;
    /// Whether this strategy has enough signal to prefetch for `access`.
    fn can_prefetch(&self, access: &MemoryAccess, history: &AccessHistoryTracker) -> bool;
    /// Generate prefetch requests triggered by `access`.
    fn generate_requests(
        &self,
        access: &MemoryAccess,
        history: &AccessHistoryTracker,
    ) -> Vec<PrefetchRequest>;
    /// Snapshot of this strategy's statistics.
    fn get_statistics(&self) -> StrategyStats;
    /// Re-derive strategy parameters from the engine configuration.
    fn configure(&mut self, config: &PrefetchConfig);
}
597
/// Sequential prefetching strategy: extends detected sequential runs by
/// whole cache lines in the run's direction.
pub struct SequentialPrefetcher {
    // Per-strategy hit/miss statistics.
    stats: StrategyStats,
    // Strategy-local tunables.
    config: SequentialConfig,
}

/// Sequential prefetcher configuration.
#[derive(Debug, Clone)]
pub struct SequentialConfig {
    /// Maximum distance (bytes) to prefetch ahead of the triggering access.
    pub prefetch_distance: usize,
    /// Minimum run length before prefetching kicks in.
    pub min_sequence_length: usize,
    /// Maximum requests generated per triggering access.
    pub max_prefetch_count: usize,
}

impl Default for SequentialConfig {
    fn default() -> Self {
        Self {
            prefetch_distance: 1024,
            min_sequence_length: 3,
            max_prefetch_count: 8,
        }
    }
}

impl SequentialPrefetcher {
    /// Creates a prefetcher with zeroed statistics and the given tunables.
    pub fn new(config: SequentialConfig) -> Self {
        Self {
            stats: StrategyStats::default(),
            config,
        }
    }
}
630
631impl PrefetchStrategy for SequentialPrefetcher {
632    fn name(&self) -> &str {
633        "Sequential"
634    }
635
636    fn can_prefetch(&self, access: &MemoryAccess, history: &AccessHistoryTracker) -> bool {
637        if let Some(seq_info) = history
638            .sequential_tracking
639            .get(&(access.context_id as usize))
640        {
641            seq_info.length >= self.config.min_sequence_length
642                && seq_info.last_access.elapsed() < Duration::from_millis(50)
643        } else {
644            false
645        }
646    }
647
648    fn generate_requests(
649        &self,
650        access: &MemoryAccess,
651        history: &AccessHistoryTracker,
652    ) -> Vec<PrefetchRequest> {
653        let mut requests = Vec::new();
654
655        if let Some(seq_info) = history
656            .sequential_tracking
657            .get(&(access.context_id as usize))
658        {
659            let mut next_addr = access.address;
660            let step = 64; // Cache line size
661
662            for i in 0..self.config.max_prefetch_count {
663                next_addr = if seq_info.direction > 0 {
664                    next_addr + step
665                } else {
666                    next_addr.saturating_sub(step)
667                };
668
669                if (next_addr as isize - access.address as isize).abs()
670                    > self.config.prefetch_distance as isize
671                {
672                    break;
673                }
674
675                let confidence = (1.0 - i as f64 * 0.1).max(0.1);
676
677                requests.push(PrefetchRequest {
678                    address: next_addr,
679                    size: 64,
680                    priority: PrefetchPriority::Normal,
681                    strategy: self.name().to_string(),
682                    confidence,
683                    timestamp: Instant::now(),
684                    deadline: Some(Instant::now() + Duration::from_millis(10)),
685                });
686            }
687        }
688
689        requests
690    }
691
692    fn get_statistics(&self) -> StrategyStats {
693        self.stats.clone()
694    }
695
696    fn configure(&mut self, config: &PrefetchConfig) {
697        self.config.prefetch_distance = config.max_prefetch_distance;
698        self.config.max_prefetch_count = config.prefetch_window;
699    }
700}
701
/// Stride-based prefetching strategy: projects a detected constant stride
/// several steps forward.
pub struct StridePrefetcher {
    // Per-strategy hit/miss statistics.
    stats: StrategyStats,
    // Strategy-local tunables.
    config: StrideConfig,
}

/// Stride prefetcher configuration.
#[derive(Debug, Clone)]
pub struct StrideConfig {
    /// Minimum stride confidence required to prefetch.
    pub min_confidence: f64,
    /// Largest stride magnitude (bytes) considered prefetchable.
    pub max_stride: isize,
    /// Number of stride steps to prefetch ahead.
    pub prefetch_degree: usize,
}

impl Default for StrideConfig {
    fn default() -> Self {
        Self {
            min_confidence: 0.6,
            max_stride: 4096,
            prefetch_degree: 4,
        }
    }
}

impl StridePrefetcher {
    /// Creates a prefetcher with zeroed statistics and the given tunables.
    pub fn new(config: StrideConfig) -> Self {
        Self {
            stats: StrategyStats::default(),
            config,
        }
    }
}
734
735impl PrefetchStrategy for StridePrefetcher {
736    fn name(&self) -> &str {
737        "Stride"
738    }
739
740    fn can_prefetch(&self, access: &MemoryAccess, history: &AccessHistoryTracker) -> bool {
741        if let Some(stride_info) = history.stride_patterns.get(&(access.context_id as usize)) {
742            stride_info.confidence >= self.config.min_confidence
743                && stride_info.stride.abs() <= self.config.max_stride
744                && stride_info.stride != 0
745        } else {
746            false
747        }
748    }
749
750    fn generate_requests(
751        &self,
752        access: &MemoryAccess,
753        history: &AccessHistoryTracker,
754    ) -> Vec<PrefetchRequest> {
755        let mut requests = Vec::new();
756
757        if let Some(stride_info) = history.stride_patterns.get(&(access.context_id as usize)) {
758            let mut next_addr = access.address;
759
760            for i in 0..self.config.prefetch_degree {
761                next_addr = (next_addr as isize + stride_info.stride) as usize;
762                let confidence = stride_info.confidence * (1.0 - i as f64 * 0.15);
763
764                requests.push(PrefetchRequest {
765                    address: next_addr,
766                    size: access.size,
767                    priority: PrefetchPriority::Normal,
768                    strategy: self.name().to_string(),
769                    confidence,
770                    timestamp: Instant::now(),
771                    deadline: Some(Instant::now() + Duration::from_millis(15)),
772                });
773            }
774        }
775
776        requests
777    }
778
779    fn get_statistics(&self) -> StrategyStats {
780        self.stats.clone()
781    }
782
783    fn configure(&mut self, config: &PrefetchConfig) {
784        self.config.prefetch_degree = config.prefetch_window;
785    }
786}
787
/// Performance monitoring for prefetching.
pub struct PerformanceMonitor {
    /// Bounded history of performance samples (oldest at the front).
    history: VecDeque<PerfSample>,
    /// Metrics aggregated from the current history window.
    current_metrics: PerfMetrics,
    /// Monitoring configuration.
    config: MonitorConfig,
}

/// One performance sample.
#[derive(Debug, Clone)]
pub struct PerfSample {
    /// When the sample was taken.
    pub timestamp: Instant,
    /// Prefetch-cache hit rate at sample time.
    pub cache_hit_rate: f64,
    /// Prefetch accuracy at sample time.
    pub prefetch_accuracy: f64,
    /// Bandwidth utilization at sample time.
    pub bandwidth_utilization: f64,
    /// Observed prefetch latency.
    pub latency: Duration,
}

/// Metrics aggregated over the sample history.
#[derive(Debug, Clone, Default)]
pub struct PerfMetrics {
    /// Mean cache hit rate over the window.
    pub average_hit_rate: f64,
    /// Mean prefetch accuracy over the window.
    pub average_accuracy: f64,
    /// Mean bandwidth utilization over the window.
    pub average_bandwidth: f64,
    /// Mean latency over the window.
    pub average_latency: Duration,
    /// Recent-minus-older hit-rate delta (positive = improving).
    pub trend_hit_rate: f64,
    /// Recent-minus-older accuracy delta (positive = improving).
    pub trend_accuracy: f64,
}

/// Monitor configuration.
#[derive(Debug, Clone)]
pub struct MonitorConfig {
    /// Interval between samples.
    pub sample_interval: Duration,
    /// Maximum samples retained.
    pub history_size: usize,
    /// Whether to compute trend deltas.
    pub enable_trends: bool,
}

impl Default for MonitorConfig {
    fn default() -> Self {
        Self {
            sample_interval: Duration::from_secs(1),
            history_size: 100,
            enable_trends: true,
        }
    }
}
836
837impl PerformanceMonitor {
838    pub fn new(config: MonitorConfig) -> Self {
839        Self {
840            history: VecDeque::with_capacity(config.history_size),
841            current_metrics: PerfMetrics::default(),
842            config,
843        }
844    }
845
846    /// Record a performance sample
847    pub fn record_sample(&mut self, sample: PerfSample) {
848        self.history.push_back(sample);
849        if self.history.len() > self.config.history_size {
850            self.history.pop_front();
851        }
852
853        self.update_metrics();
854    }
855
856    fn update_metrics(&mut self) {
857        if self.history.is_empty() {
858            return;
859        }
860
861        let count = self.history.len() as f64;
862        self.current_metrics.average_hit_rate =
863            self.history.iter().map(|s| s.cache_hit_rate).sum::<f64>() / count;
864        self.current_metrics.average_accuracy = self
865            .history
866            .iter()
867            .map(|s| s.prefetch_accuracy)
868            .sum::<f64>()
869            / count;
870        self.current_metrics.average_bandwidth = self
871            .history
872            .iter()
873            .map(|s| s.bandwidth_utilization)
874            .sum::<f64>()
875            / count;
876
877        let total_latency_nanos: u64 = self
878            .history
879            .iter()
880            .map(|s| s.latency.as_nanos() as u64)
881            .sum();
882        self.current_metrics.average_latency =
883            Duration::from_nanos(total_latency_nanos / count as u64);
884
885        // Calculate trends
886        if self.config.enable_trends && self.history.len() >= 10 {
887            let recent_hit_rate: f64 = self
888                .history
889                .iter()
890                .rev()
891                .take(5)
892                .map(|s| s.cache_hit_rate)
893                .sum::<f64>()
894                / 5.0;
895            let older_hit_rate: f64 = self
896                .history
897                .iter()
898                .rev()
899                .skip(5)
900                .take(5)
901                .map(|s| s.cache_hit_rate)
902                .sum::<f64>()
903                / 5.0;
904            self.current_metrics.trend_hit_rate = recent_hit_rate - older_hit_rate;
905
906            let recent_accuracy: f64 = self
907                .history
908                .iter()
909                .rev()
910                .take(5)
911                .map(|s| s.prefetch_accuracy)
912                .sum::<f64>()
913                / 5.0;
914            let older_accuracy: f64 = self
915                .history
916                .iter()
917                .rev()
918                .skip(5)
919                .take(5)
920                .map(|s| s.prefetch_accuracy)
921                .sum::<f64>()
922                / 5.0;
923            self.current_metrics.trend_accuracy = recent_accuracy - older_accuracy;
924        }
925    }
926
927    /// Get current performance metrics
928    pub fn get_metrics(&self) -> &PerfMetrics {
929        &self.current_metrics
930    }
931}
932
933impl PrefetchingEngine {
934    pub fn new(config: PrefetchConfig) -> Self {
935        let mut strategies: Vec<Box<dyn PrefetchStrategy>> = Vec::new();
936
937        if config.enable_pattern_based {
938            strategies.push(Box::new(SequentialPrefetcher::new(
939                SequentialConfig::default(),
940            )));
941        }
942
943        if config.enable_stride_based {
944            strategies.push(Box::new(StridePrefetcher::new(StrideConfig::default())));
945        }
946
947        let access_history = AccessHistoryTracker::new(config.history_window);
948        let prefetch_cache = PrefetchCache::new(config.cache_size);
949        let performance_monitor = PerformanceMonitor::new(MonitorConfig::default());
950
951        Self {
952            config,
953            stats: PrefetchStats::default(),
954            strategies,
955            access_history,
956            prefetch_queue: VecDeque::new(),
957            prefetch_cache,
958            performance_monitor,
959        }
960    }
961
962    /// Record a memory access and potentially trigger prefetching
963    pub fn record_access(&mut self, access: MemoryAccess) -> Vec<PrefetchRequest> {
964        // Record access in history
965        self.access_history.record_access(access.clone());
966
967        // Check cache for hit/miss
968        let cache_hit = self
969            .prefetch_cache
970            .get(access.address, access.size)
971            .is_some();
972        if cache_hit {
973            self.stats.successful_prefetches += 1;
974        }
975
976        let mut new_requests = Vec::new();
977
978        if self.config.auto_prefetch {
979            // Generate prefetch requests from strategies
980            for strategy in &self.strategies {
981                if strategy.can_prefetch(&access, &self.access_history) {
982                    let requests = strategy.generate_requests(&access, &self.access_history);
983                    for request in requests {
984                        if self.should_issue_prefetch(&request) {
985                            new_requests.push(request);
986                        }
987                    }
988                }
989            }
990
991            // Add requests to queue
992            for request in &new_requests {
993                self.prefetch_queue.push_back(request.clone());
994                self.stats.total_requests += 1;
995            }
996        }
997
998        new_requests
999    }
1000
1001    fn should_issue_prefetch(&self, request: &PrefetchRequest) -> bool {
1002        // Check if already in cache
1003        if self.prefetch_cache.entries.contains_key(&request.address) {
1004            return false;
1005        }
1006
1007        // Check confidence threshold
1008        if request.confidence < self.config.min_access_frequency {
1009            return false;
1010        }
1011
1012        // Check prefetch distance
1013        if request.size > self.config.max_prefetch_distance {
1014            return false;
1015        }
1016
1017        true
1018    }
1019
1020    /// Process prefetch queue and issue prefetches
1021    pub fn process_prefetch_queue(&mut self) -> Vec<PrefetchRequest> {
1022        let mut issued_requests = Vec::new();
1023        let max_concurrent = (self.config.aggressiveness * 10.0) as usize + 1;
1024
1025        // Sort by priority and confidence
1026        let mut pending: Vec<PrefetchRequest> = self.prefetch_queue.drain(..).collect();
1027        pending.sort_by(|a, b| {
1028            b.priority.cmp(&a.priority).then_with(|| {
1029                b.confidence
1030                    .partial_cmp(&a.confidence)
1031                    .unwrap_or(std::cmp::Ordering::Equal)
1032            })
1033        });
1034
1035        for request in pending.into_iter().take(max_concurrent) {
1036            // Simulate prefetch (in real implementation, this would trigger actual memory load)
1037            let dummy_data = vec![0u8; request.size];
1038            if self
1039                .prefetch_cache
1040                .insert(request.address, dummy_data, request.strategy.clone())
1041            {
1042                issued_requests.push(request);
1043            }
1044        }
1045
1046        issued_requests
1047    }
1048
1049    /// Update statistics and performance metrics
1050    pub fn update_performance(&mut self) {
1051        let cache_stats = self.prefetch_cache.get_stats();
1052
1053        // Update accuracy ratio
1054        let total_prefetches = self.stats.successful_prefetches + self.stats.failed_prefetches;
1055        if total_prefetches > 0 {
1056            self.stats.accuracy_ratio =
1057                self.stats.successful_prefetches as f64 / total_prefetches as f64;
1058        }
1059
1060        // Update cache hit rate
1061        let total_accesses = cache_stats.hits + cache_stats.misses;
1062        if total_accesses > 0 {
1063            self.stats.cache_hit_rate = cache_stats.hits as f64 / total_accesses as f64;
1064        }
1065
1066        // Record performance sample
1067        let sample = PerfSample {
1068            timestamp: Instant::now(),
1069            cache_hit_rate: self.stats.cache_hit_rate,
1070            prefetch_accuracy: self.stats.accuracy_ratio,
1071            bandwidth_utilization: 0.8, // Would be calculated from actual usage
1072            latency: self.stats.average_latency,
1073        };
1074
1075        self.performance_monitor.record_sample(sample);
1076    }
1077
    /// Get prefetching statistics.
    ///
    /// Returns a shared borrow of the accumulated counters and ratios;
    /// derived ratios are only refreshed by `update_performance`.
    pub fn get_stats(&self) -> &PrefetchStats {
        &self.stats
    }
1082
    /// Get cache statistics.
    ///
    /// Returns a shared borrow of the prefetch cache's own hit/miss counters.
    pub fn get_cache_stats(&self) -> &CacheStats {
        self.prefetch_cache.get_stats()
    }
1087
    /// Get performance metrics.
    ///
    /// Returns a shared borrow of the monitor's aggregated metrics, built
    /// from the samples recorded by `update_performance`.
    pub fn get_performance_metrics(&self) -> &PerfMetrics {
        self.performance_monitor.get_metrics()
    }
1092
    /// Get access history.
    ///
    /// Returns a shared borrow of the tracker that records every access seen
    /// by `record_access`.
    pub fn get_access_history(&self) -> &AccessHistoryTracker {
        &self.access_history
    }
1097
1098    /// Configure prefetching engine
1099    pub fn configure(&mut self, config: PrefetchConfig) {
1100        self.config = config.clone();
1101        for strategy in &mut self.strategies {
1102            strategy.configure(&config);
1103        }
1104    }
1105}
1106
/// Thread-safe prefetching engine wrapper
///
/// Wraps a [`PrefetchingEngine`] in `Arc<Mutex<..>>` so it can be shared and
/// used from multiple threads; each method locks the mutex for the duration
/// of the call.
pub struct ThreadSafePrefetchingEngine {
    // The wrapped engine; the mutex serializes all access.
    engine: Arc<Mutex<PrefetchingEngine>>,
}
1111
1112impl ThreadSafePrefetchingEngine {
1113    pub fn new(config: PrefetchConfig) -> Self {
1114        Self {
1115            engine: Arc::new(Mutex::new(PrefetchingEngine::new(config))),
1116        }
1117    }
1118
1119    pub fn record_access(&self, access: MemoryAccess) -> Vec<PrefetchRequest> {
1120        let mut engine = self.engine.lock().expect("lock poisoned");
1121        engine.record_access(access)
1122    }
1123
1124    pub fn process_prefetch_queue(&self) -> Vec<PrefetchRequest> {
1125        let mut engine = self.engine.lock().expect("lock poisoned");
1126        engine.process_prefetch_queue()
1127    }
1128
1129    pub fn get_stats(&self) -> PrefetchStats {
1130        let engine = self.engine.lock().expect("lock poisoned");
1131        engine.get_stats().clone()
1132    }
1133
1134    pub fn update_performance(&self) {
1135        let mut engine = self.engine.lock().expect("lock poisoned");
1136        engine.update_performance();
1137    }
1138}
1139
#[cfg(test)]
mod tests {
    use super::*;

    /// Default configuration must register at least one strategy.
    #[test]
    fn test_prefetch_engine_creation() {
        let config = PrefetchConfig::default();
        let engine = PrefetchingEngine::new(config);
        assert!(!engine.strategies.is_empty());
    }

    /// Recording one access must grow the history by one entry.
    #[test]
    fn test_access_history_tracking() {
        let mut tracker = AccessHistoryTracker::new(100);

        let access = MemoryAccess {
            address: 0x1000,
            size: 64,
            timestamp: Instant::now(),
            access_type: AccessType::Read,
            context_id: 1,
            kernel_id: Some(100),
        };

        tracker.record_access(access);
        assert_eq!(tracker.access_history.len(), 1);
    }

    /// Inserted data must be retrievable byte-for-byte.
    #[test]
    fn test_prefetch_cache() {
        let mut cache = PrefetchCache::new(1024);

        let data = vec![1, 2, 3, 4];
        assert!(cache.insert(0x1000, data, "Test".to_string()));

        let retrieved = cache.get(0x1000, 4);
        assert!(retrieved.is_some());
        assert_eq!(retrieved.expect("unwrap failed"), &[1, 2, 3, 4]);
    }

    #[test]
    fn test_sequential_prefetcher() {
        let config = SequentialConfig::default();
        let prefetcher = SequentialPrefetcher::new(config);
        assert_eq!(prefetcher.name(), "Sequential");
    }

    #[test]
    fn test_stride_prefetcher() {
        let config = StrideConfig::default();
        let prefetcher = StridePrefetcher::new(config);
        assert_eq!(prefetcher.name(), "Stride");
    }

    /// A recorded sample must be reflected in the aggregated metrics.
    #[test]
    fn test_performance_monitor() {
        let config = MonitorConfig::default();
        let mut monitor = PerformanceMonitor::new(config);

        let sample = PerfSample {
            timestamp: Instant::now(),
            cache_hit_rate: 0.8,
            prefetch_accuracy: 0.7,
            bandwidth_utilization: 0.9,
            latency: Duration::from_millis(5),
        };

        monitor.record_sample(sample);
        let metrics = monitor.get_metrics();
        assert!(metrics.average_hit_rate > 0.0);
    }

    /// The thread-safe wrapper must accept accesses without panicking.
    #[test]
    fn test_thread_safe_engine() {
        let config = PrefetchConfig::default();
        let engine = ThreadSafePrefetchingEngine::new(config);

        let access = MemoryAccess {
            address: 0x2000,
            size: 128,
            timestamp: Instant::now(),
            access_type: AccessType::Read,
            context_id: 2,
            kernel_id: Some(200),
        };

        // The request list may legitimately be empty; binding to `_requests`
        // fixes the previous unused-variable warning.
        let _requests = engine.record_access(access);
    }
}