// trueno/brick/simd_config/mod.rs
//! SIMD Configuration and Lazy Initialization
//!
//! LCP-07: Lazy AMX/SIMD tile configuration for expensive state setup.
//! LCP-13: Unroll-and-tail vectorization patterns.

use super::ComputeBackend;

// ----------------------------------------------------------------------------
// LCP-07: Lazy AMX Tile Config
// ----------------------------------------------------------------------------

/// SIMD backend state for lazy initialization.
///
/// AMX (Advanced Matrix Extensions) and AVX-512 require tile configuration
/// that's expensive to set up. This tracks whether initialization has occurred.
///
/// Transitions are driven by `LazySimdConfig::ensure_ready`:
/// `Uninitialized -> Configuring -> Ready` on success; `Failed` marks a
/// backend that could not be initialized (callers fall back to scalar code).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SimdBackendState {
    /// Not initialized - will configure on first use
    #[default]
    Uninitialized,
    /// Configuration in progress
    Configuring,
    /// Ready to use
    Ready,
    /// Failed to initialize (fallback to scalar)
    Failed,
}

/// Lazy SIMD tile configuration manager.
///
/// Defers expensive SIMD state setup until actually needed.
///
/// NOTE(review): no internal synchronization — `ensure_ready` takes
/// `&mut self` — so wrap an instance in a lock if it is shared across threads.
#[derive(Debug)]
pub struct LazySimdConfig {
    /// Current state
    state: SimdBackendState,
    /// Best available backend (detected once at construction)
    best_backend: ComputeBackend,
    /// Whether AMX is supported
    amx_supported: bool,
    /// Tile configuration (for AMX); `None` until `ensure_ready` runs
    tile_config: Option<AmxTileConfig>,
}

/// AMX tile configuration (8x8 tile palette).
///
/// Mirrors the parameters programmed via the `LDTILECFG` instruction.
/// NOTE(review): AMX palette 1 defines 8 tile registers of up to 16 rows x
/// 64 bytes each — confirm the "8x8" wording in the summary is intended.
#[derive(Debug, Clone, Copy, Default)]
pub struct AmxTileConfig {
    /// Palette ID (0-1)
    pub palette: u8,
    /// Start row
    pub start_row: u8,
    /// Number of rows per tile
    pub rows: u8,
    /// Bytes per row
    pub bytes_per_row: u16,
}

57impl LazySimdConfig {
58    /// Create new lazy config, detecting best backend.
59    #[must_use]
60    pub fn new() -> Self {
61        Self {
62            state: SimdBackendState::Uninitialized,
63            best_backend: Self::detect_best_backend(),
64            amx_supported: Self::detect_amx(),
65            tile_config: None,
66        }
67    }
68
69    /// Detect best available SIMD backend.
70    fn detect_best_backend() -> ComputeBackend {
71        #[cfg(target_arch = "x86_64")]
72        {
73            if is_x86_feature_detected!("avx512f") {
74                return ComputeBackend::Avx512;
75            }
76            if is_x86_feature_detected!("avx2") {
77                return ComputeBackend::Avx2;
78            }
79            if is_x86_feature_detected!("sse2") {
80                return ComputeBackend::Sse2;
81            }
82        }
83        #[cfg(target_arch = "aarch64")]
84        {
85            // NEON is always available on aarch64
86            return ComputeBackend::Neon;
87        }
88        ComputeBackend::Scalar
89    }
90
91    /// Detect AMX support (Intel Sapphire Rapids+).
92    fn detect_amx() -> bool {
93        #[cfg(target_arch = "x86_64")]
94        {
95            // AMX requires specific CPUID checks
96            // For now, return false as AMX is rare
97            false
98        }
99        #[cfg(not(target_arch = "x86_64"))]
100        {
101            false
102        }
103    }
104
105    /// Ensure SIMD is configured, initializing lazily if needed.
106    pub fn ensure_ready(&mut self) -> Result<ComputeBackend, SimdBackendState> {
107        match self.state {
108            SimdBackendState::Ready => Ok(self.best_backend),
109            SimdBackendState::Failed => Err(SimdBackendState::Failed),
110            SimdBackendState::Configuring => Err(SimdBackendState::Configuring),
111            SimdBackendState::Uninitialized => {
112                self.state = SimdBackendState::Configuring;
113
114                // Configure AMX tiles if supported
115                if self.amx_supported {
116                    self.tile_config = Some(AmxTileConfig {
117                        palette: 1,
118                        start_row: 0,
119                        rows: 16,
120                        bytes_per_row: 64,
121                    });
122                    // In real implementation, would call LDTILECFG here
123                }
124
125                self.state = SimdBackendState::Ready;
126                Ok(self.best_backend)
127            }
128        }
129    }
130
131    /// Get current state.
132    #[must_use]
133    pub fn state(&self) -> SimdBackendState {
134        self.state
135    }
136
137    /// Get best backend without initializing.
138    #[must_use]
139    pub fn best_backend(&self) -> ComputeBackend {
140        self.best_backend
141    }
142
143    /// Check if AMX is supported.
144    #[must_use]
145    pub fn has_amx(&self) -> bool {
146        self.amx_supported
147    }
148
149    /// Reset to uninitialized state.
150    pub fn reset(&mut self) {
151        self.state = SimdBackendState::Uninitialized;
152        self.tile_config = None;
153    }
154}
155
156impl Default for LazySimdConfig {
157    fn default() -> Self {
158        Self::new()
159    }
160}
161
// ----------------------------------------------------------------------------
// LCP-13: Unroll-and-Tail Vectorization
// ----------------------------------------------------------------------------

/// Unroll factor for SIMD loops.
///
/// Wider vector units benefit from deeper unrolling; see
/// `UnrollFactor::for_backend` for the per-backend mapping.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnrollFactor {
    /// No unrolling (1x)
    None,
    /// 2x unroll
    X2,
    /// 4x unroll
    X4,
    /// 8x unroll (AVX-512)
    X8,
}

179impl UnrollFactor {
180    /// Get numeric factor.
181    #[must_use]
182    pub fn value(&self) -> usize {
183        match self {
184            UnrollFactor::None => 1,
185            UnrollFactor::X2 => 2,
186            UnrollFactor::X4 => 4,
187            UnrollFactor::X8 => 8,
188        }
189    }
190
191    /// Get optimal factor for backend.
192    #[must_use]
193    pub fn for_backend(backend: ComputeBackend) -> Self {
194        match backend {
195            ComputeBackend::Avx512 => UnrollFactor::X8,
196            ComputeBackend::Avx2 => UnrollFactor::X4,
197            ComputeBackend::Sse2 | ComputeBackend::Neon => UnrollFactor::X2,
198            _ => UnrollFactor::None,
199        }
200    }
201}
202
/// Helper for unroll-and-tail loop pattern.
///
/// Processes data in unrolled chunks, then handles the tail.
/// Not a `std::iter::Iterator`: callers drain chunks via `next_chunk`, then
/// handle the remainder reported by `tail_range`.
#[derive(Debug)]
pub struct UnrollTailIterator {
    /// Total elements
    total: usize,
    /// Current position (index of the next unprocessed element)
    position: usize,
    /// Elements per unrolled iteration (>= 1, from `UnrollFactor::value`)
    chunk_size: usize,
}

216impl UnrollTailIterator {
217    /// Create iterator for given size and unroll factor.
218    pub fn new(total: usize, factor: UnrollFactor) -> Self {
219        Self { total, position: 0, chunk_size: factor.value() }
220    }
221
222    /// Get number of full unrolled iterations.
223    #[must_use]
224    pub fn full_iterations(&self) -> usize {
225        self.total / self.chunk_size
226    }
227
228    /// Get tail size (remainder).
229    #[must_use]
230    pub fn tail_size(&self) -> usize {
231        self.total % self.chunk_size
232    }
233
234    /// Check if there's a tail to process.
235    #[must_use]
236    pub fn has_tail(&self) -> bool {
237        self.tail_size() > 0
238    }
239
240    /// Get next chunk range for unrolled iteration.
241    pub fn next_chunk(&mut self) -> Option<(usize, usize)> {
242        if self.position + self.chunk_size <= self.total {
243            let start = self.position;
244            self.position += self.chunk_size;
245            Some((start, start + self.chunk_size))
246        } else {
247            None
248        }
249    }
250
251    /// Get tail range (call after all chunks consumed).
252    pub fn tail_range(&self) -> Option<(usize, usize)> {
253        let tail_start = self.full_iterations() * self.chunk_size;
254        if tail_start < self.total {
255            Some((tail_start, self.total))
256        } else {
257            None
258        }
259    }
260}
261
262/// Process a slice with unroll-and-tail pattern.
263///
264/// # Example
265/// ```ignore
266/// let result = unroll_tail_process(
267///     &data,
268///     UnrollFactor::X4,
269///     |chunk| chunk.iter().sum::<f32>(), // Unrolled body
270///     |elem| *elem,                       // Tail body
271/// );
272/// ```
273pub fn unroll_tail_process<T, U, F, G>(
274    data: &[T],
275    factor: UnrollFactor,
276    mut process_chunk: F,
277    mut process_elem: G,
278) -> Vec<U>
279where
280    F: FnMut(&[T]) -> U,
281    G: FnMut(&T) -> U,
282{
283    let mut iter = UnrollTailIterator::new(data.len(), factor);
284    let mut results =
285        Vec::with_capacity(iter.full_iterations() + if iter.has_tail() { 1 } else { 0 });
286
287    // Process full chunks
288    while let Some((start, end)) = iter.next_chunk() {
289        results.push(process_chunk(&data[start..end]));
290    }
291
292    // Process tail
293    if let Some((start, end)) = iter.tail_range() {
294        for elem in &data[start..end] {
295            results.push(process_elem(elem));
296        }
297    }
298
299    results
300}
301
// Unit tests live in `tests.rs` next to this module (compiled only for tests).
#[cfg(test)]
mod tests;