Skip to main content

cuda_rust_wasm/runtime/
coalescing.rs

1//! Memory Coalescing Analyzer
2//!
3//! Analyzes memory access patterns in kernel code to detect coalesced
4//! vs. scattered accesses. Coalesced accesses (threads in a warp accessing
5//! consecutive addresses) are critical for GPU performance — up to 32x
6//! difference between coalesced and uncoalesced patterns.
7//!
8//! This module provides:
9//! - Static pattern analysis from array access expressions
10//! - Runtime access pattern recording and analysis
11//! - Optimization suggestions
12
13use std::fmt;
14use std::collections::HashMap;
15
/// Memory access pattern classification for the accesses issued by one warp.
///
/// Besides `PartialEq`, the type derives `Eq` and `Hash` (every payload is a
/// plain `usize`), so a pattern can be used directly as a `HashMap` or
/// `HashSet` key instead of going through its `Debug` string.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AccessPattern {
    /// Perfectly coalesced: thread i accesses address base + i * elem_size
    FullyCoalesced,
    /// Strided: thread i accesses base + i * stride (stride > elem_size).
    /// `stride` is the byte distance between the addresses of adjacent lanes.
    Strided { stride: usize },
    /// Random/scattered: no detectable pattern
    Scattered,
    /// Broadcast: all threads access same address
    Broadcast,
    /// Block-cyclic: threads access in groups.
    ///
    /// NOTE(review): nothing in this file constructs this variant, and the
    /// unit of `block_size` (lanes vs. bytes) is not established here.
    BlockCyclic { block_size: usize },
}
30
/// One load or store issued by a single lane, captured for later analysis.
#[derive(Debug, Clone)]
pub struct MemoryAccess {
    /// Lane (thread index inside the warp) that issued the access.
    /// Expected to be 0-31 for a 32-wide warp; the value is not validated.
    pub lane_id: u32,
    /// Byte address of the first byte touched.
    pub address: usize,
    /// `true` for a store, `false` for a load.
    pub is_write: bool,
    /// Width of the accessed element, in bytes.
    pub elem_size: usize,
}
43
44/// Result of coalescing analysis.
45#[derive(Debug, Clone)]
46pub struct CoalescingReport {
47    /// Detected pattern.
48    pub pattern: AccessPattern,
49    /// Number of memory transactions needed (fewer = better).
50    /// Ideal: 1 transaction for 32 threads. Worst: 32 transactions.
51    pub transactions: u32,
52    /// Efficiency: useful bytes / total bytes transferred (0.0 to 1.0).
53    pub efficiency: f64,
54    /// Cache line utilization (assuming 128-byte cache lines).
55    pub cache_lines_touched: u32,
56    /// Optimization suggestion.
57    pub suggestion: String,
58}
59
60impl fmt::Display for CoalescingReport {
61    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
62        write!(f, "Coalescing: {:?}, {} transactions, {:.1}% efficiency — {}",
63            self.pattern, self.transactions, self.efficiency * 100.0, self.suggestion)
64    }
65}
66
/// Size, in bytes, of one GPU L1 cache line; used to count the distinct
/// cache lines a warp touches.
const CACHE_LINE_SIZE: usize = 128;
/// Size, in bytes, of one GPU memory transaction; used to count how many
/// transactions a warp needs.
const TRANSACTION_SIZE: usize = 32;
71
72/// Analyze a set of memory accesses from one warp (32 threads).
73pub fn analyze_warp_access(accesses: &[MemoryAccess]) -> CoalescingReport {
74    if accesses.is_empty() {
75        return CoalescingReport {
76            pattern: AccessPattern::FullyCoalesced,
77            transactions: 0,
78            efficiency: 1.0,
79            cache_lines_touched: 0,
80            suggestion: "No accesses to analyze".into(),
81        };
82    }
83
84    let elem_size = accesses[0].elem_size;
85
86    // Sort by lane_id
87    let mut sorted = accesses.to_vec();
88    sorted.sort_by_key(|a| a.lane_id);
89
90    // Check for broadcast
91    if sorted.windows(2).all(|w| w[0].address == w[1].address) {
92        return CoalescingReport {
93            pattern: AccessPattern::Broadcast,
94            transactions: 1,
95            efficiency: elem_size as f64 / TRANSACTION_SIZE as f64,
96            cache_lines_touched: 1,
97            suggestion: "Broadcast access — consider using shared memory or constant cache".into(),
98        };
99    }
100
101    // Detect stride pattern
102    let mut strides = Vec::new();
103    for window in sorted.windows(2) {
104        if window[1].address >= window[0].address {
105            strides.push(window[1].address - window[0].address);
106        }
107    }
108
109    // Count unique cache lines touched
110    let mut cache_lines: Vec<usize> = accesses.iter()
111        .map(|a| a.address / CACHE_LINE_SIZE)
112        .collect();
113    cache_lines.sort();
114    cache_lines.dedup();
115    let cache_lines_touched = cache_lines.len() as u32;
116
117    // Count memory transactions (32-byte segments)
118    let mut segments: Vec<usize> = accesses.iter()
119        .map(|a| a.address / TRANSACTION_SIZE)
120        .collect();
121    segments.sort();
122    segments.dedup();
123    let transactions = segments.len() as u32;
124
125    let useful_bytes = accesses.len() * elem_size;
126    let total_bytes = transactions as usize * TRANSACTION_SIZE;
127    let efficiency = if total_bytes > 0 {
128        useful_bytes as f64 / total_bytes as f64
129    } else {
130        1.0
131    };
132
133    // Classify pattern
134    let is_uniform_stride = !strides.is_empty() && strides.iter().all(|&s| s == strides[0]);
135
136    let (pattern, suggestion) = if is_uniform_stride {
137        let stride = strides[0];
138        if stride == elem_size {
139            (AccessPattern::FullyCoalesced,
140             "Fully coalesced — optimal memory access pattern".into())
141        } else if stride == 0 {
142            (AccessPattern::Broadcast,
143             "Broadcast — consider constant memory cache".into())
144        } else {
145            let stride_ratio = stride / elem_size;
146            (AccessPattern::Strided { stride },
147             format!("Stride-{} access — consider transposing data layout or using shared memory tiling", stride_ratio))
148        }
149    } else {
150        (AccessPattern::Scattered,
151         "Scattered access — consider sorting indices or using texture cache".into())
152    };
153
154    CoalescingReport {
155        pattern,
156        transactions,
157        efficiency,
158        cache_lines_touched,
159        suggestion,
160    }
161}
162
163/// Simulate warp access pattern for a linear index expression.
164///
165/// Models: `address = base + (thread_id * stride + offset) * elem_size`
166pub fn simulate_linear_access(
167    base: usize,
168    stride: usize,
169    offset: usize,
170    elem_size: usize,
171    warp_size: u32,
172) -> Vec<MemoryAccess> {
173    (0..warp_size).map(|lane| {
174        MemoryAccess {
175            lane_id: lane,
176            address: base + (lane as usize * stride + offset) * elem_size,
177            is_write: false,
178            elem_size,
179        }
180    }).collect()
181}
182
183/// Simulate column-major access pattern (common anti-pattern).
184///
185/// Models: `address = base + (thread_id * num_cols + col) * elem_size`
186pub fn simulate_column_access(
187    base: usize,
188    num_cols: usize,
189    col: usize,
190    elem_size: usize,
191    warp_size: u32,
192) -> Vec<MemoryAccess> {
193    (0..warp_size).map(|lane| {
194        MemoryAccess {
195            lane_id: lane,
196            address: base + (lane as usize * num_cols + col) * elem_size,
197            is_write: false,
198            elem_size,
199        }
200    }).collect()
201}
202
203/// Runtime access pattern recorder.
204pub struct AccessRecorder {
205    accesses: Vec<Vec<MemoryAccess>>,
206    current_warp: Vec<MemoryAccess>,
207}
208
209impl AccessRecorder {
210    /// Create a new recorder.
211    pub fn new() -> Self {
212        Self {
213            accesses: Vec::new(),
214            current_warp: Vec::new(),
215        }
216    }
217
218    /// Record a memory access.
219    pub fn record(&mut self, lane_id: u32, address: usize, elem_size: usize, is_write: bool) {
220        self.current_warp.push(MemoryAccess {
221            lane_id,
222            address,
223            is_write,
224            elem_size,
225        });
226
227        if self.current_warp.len() >= 32 {
228            self.flush_warp();
229        }
230    }
231
232    /// Flush current warp to history.
233    pub fn flush_warp(&mut self) {
234        if !self.current_warp.is_empty() {
235            self.accesses.push(std::mem::take(&mut self.current_warp));
236        }
237    }
238
239    /// Analyze all recorded access patterns.
240    pub fn analyze(&mut self) -> Vec<CoalescingReport> {
241        self.flush_warp();
242        self.accesses.iter().map(|warp| analyze_warp_access(warp)).collect()
243    }
244
245    /// Get a summary of all recorded patterns.
246    pub fn summary(&mut self) -> AccessSummary {
247        let reports = self.analyze();
248        let mut pattern_counts: HashMap<String, usize> = HashMap::new();
249        let mut total_efficiency = 0.0;
250        let mut total_transactions = 0u32;
251
252        for report in &reports {
253            let key = format!("{:?}", report.pattern);
254            *pattern_counts.entry(key).or_insert(0) += 1;
255            total_efficiency += report.efficiency;
256            total_transactions += report.transactions;
257        }
258
259        let count = reports.len();
260        AccessSummary {
261            total_warps_analyzed: count,
262            avg_efficiency: if count > 0 { total_efficiency / count as f64 } else { 0.0 },
263            total_transactions,
264            pattern_distribution: pattern_counts,
265        }
266    }
267}
268
/// Aggregate statistics over a set of per-warp coalescing reports.
#[derive(Debug)]
pub struct AccessSummary {
    /// Number of warp reports that were aggregated.
    pub total_warps_analyzed: usize,
    /// Mean of the per-warp efficiency values.
    pub avg_efficiency: f64,
    /// Sum of the per-warp transaction counts.
    pub total_transactions: u32,
    /// How many warps fell into each pattern, keyed by the pattern's name.
    pub pattern_distribution: HashMap<String, usize>,
}

impl fmt::Display for AccessSummary {
    /// Renders the totals on one line, with the average efficiency as a
    /// percentage rounded to one decimal place. The pattern distribution is
    /// not part of the output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let percent = self.avg_efficiency * 100.0;
        write!(
            f,
            "Access Summary: {warps} warps, {percent:.1}% avg efficiency, {total} transactions",
            warps = self.total_warps_analyzed,
            percent = percent,
            total = self.total_transactions,
        )
    }
}
286
287// ── Tests ──────────────────────────────────────────────────────────
288
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds one 4-byte read per address, numbering lanes in slice order.
    fn reads_at(addresses: &[usize]) -> Vec<MemoryAccess> {
        let mut warp = Vec::with_capacity(addresses.len());
        for (lane, &address) in addresses.iter().enumerate() {
            warp.push(MemoryAccess {
                lane_id: lane as u32,
                address,
                is_write: false,
                elem_size: 4,
            });
        }
        warp
    }

    /// Records one full warp in which lane `i` reads 4 bytes at `i * step`.
    fn record_warp(recorder: &mut AccessRecorder, step: usize) {
        for lane in 0..32 {
            recorder.record(lane, (lane as usize) * step, 4, false);
        }
    }

    #[test]
    fn test_coalesced_access() {
        // Lane i reads base + i * 4: the ideal layout for f32 data.
        let report = analyze_warp_access(&simulate_linear_access(0, 1, 0, 4, 32));
        assert_eq!(report.pattern, AccessPattern::FullyCoalesced);
        assert!(report.efficiency > 0.9);
    }

    #[test]
    fn test_strided_access() {
        // Lane i reads base + i * 512 * 4: one column of a 512-wide matrix.
        let report = analyze_warp_access(&simulate_column_access(0, 512, 0, 4, 32));
        if let AccessPattern::Strided { stride } = report.pattern {
            assert_eq!(stride, 512 * 4);
        } else {
            panic!("Expected strided pattern, got {:?}", report.pattern);
        }
        assert!(report.efficiency < 0.2, "Strided access should have low efficiency: {}", report.efficiency);
    }

    #[test]
    fn test_broadcast_access() {
        // Every lane reads the same address.
        let report = analyze_warp_access(&reads_at(&[1000; 32]));
        assert_eq!(report.pattern, AccessPattern::Broadcast);
        assert_eq!(report.transactions, 1);
    }

    #[test]
    fn test_scattered_access() {
        let addresses: [usize; 32] = [
            100, 5000, 200, 9000, 50, 7000, 300, 2000,
            400, 6000, 150, 8000, 250, 3000, 350, 1000,
            450, 4000, 550, 10000, 650, 11000, 750, 12000,
            850, 13000, 950, 14000, 1050, 15000, 1150, 16000,
        ];
        let report = analyze_warp_access(&reads_at(&addresses));
        assert_eq!(report.pattern, AccessPattern::Scattered);
        assert!(report.transactions > 1);
    }

    #[test]
    fn test_recorder() {
        let mut recorder = AccessRecorder::new();
        record_warp(&mut recorder, 4);
        let reports = recorder.analyze();
        assert_eq!(reports.len(), 1);
        assert_eq!(reports[0].pattern, AccessPattern::FullyCoalesced);
    }

    #[test]
    fn test_summary() {
        // Two warps: the first coalesced, the second strided.
        let mut recorder = AccessRecorder::new();
        record_warp(&mut recorder, 4);
        record_warp(&mut recorder, 2048);
        let summary = recorder.summary();
        assert_eq!(summary.total_warps_analyzed, 2);
        assert!(summary.avg_efficiency > 0.0);
    }

    #[test]
    fn test_report_display() {
        let rendered = CoalescingReport {
            pattern: AccessPattern::FullyCoalesced,
            transactions: 4,
            efficiency: 1.0,
            cache_lines_touched: 1,
            suggestion: "Optimal".into(),
        }
        .to_string();
        assert!(rendered.contains("100.0%"));
    }

    #[test]
    fn test_empty_access() {
        let no_accesses: Vec<MemoryAccess> = Vec::new();
        let report = analyze_warp_access(&no_accesses);
        assert_eq!(report.transactions, 0);
        assert_eq!(report.pattern, AccessPattern::FullyCoalesced);
    }
}
383}