use super::ComputeBackend;
7
/// Lifecycle state of the lazily-configured SIMD backend.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SimdBackendState {
    /// No configuration attempted yet (initial state).
    #[default]
    Uninitialized,
    /// Configuration in progress (transient state inside `ensure_ready`).
    Configuring,
    /// Backend configured and usable.
    Ready,
    /// Configuration failed; `ensure_ready` will keep returning an error.
    Failed,
}
28
/// Lazily-initialized SIMD backend configuration.
///
/// Backend and AMX detection run eagerly in `new`; the AMX tile
/// configuration (if any) is only materialized on the first call to
/// `ensure_ready`.
#[derive(Debug)]
pub struct LazySimdConfig {
    // Current lifecycle state; starts as `Uninitialized`.
    state: SimdBackendState,
    // Best compute backend detected at construction time.
    best_backend: ComputeBackend,
    // Whether AMX was detected (currently always false — see `detect_amx`).
    amx_supported: bool,
    // Tile layout populated by `ensure_ready` when AMX is supported.
    tile_config: Option<AmxTileConfig>,
}
43
/// Intel AMX tile register configuration parameters.
///
/// NOTE(review): field names mirror the AMX tile-configuration layout
/// (palette id, start row, row count, bytes per row) — confirm exact
/// semantics against the consumer of this struct.
#[derive(Debug, Clone, Copy, Default)]
pub struct AmxTileConfig {
    /// Tile palette identifier.
    pub palette: u8,
    /// First row to use within the tile.
    pub start_row: u8,
    /// Number of rows in the tile.
    pub rows: u8,
    /// Bytes per tile row.
    pub bytes_per_row: u16,
}
56
57impl LazySimdConfig {
58 #[must_use]
60 pub fn new() -> Self {
61 Self {
62 state: SimdBackendState::Uninitialized,
63 best_backend: Self::detect_best_backend(),
64 amx_supported: Self::detect_amx(),
65 tile_config: None,
66 }
67 }
68
69 fn detect_best_backend() -> ComputeBackend {
71 #[cfg(target_arch = "x86_64")]
72 {
73 if is_x86_feature_detected!("avx512f") {
74 return ComputeBackend::Avx512;
75 }
76 if is_x86_feature_detected!("avx2") {
77 return ComputeBackend::Avx2;
78 }
79 if is_x86_feature_detected!("sse2") {
80 return ComputeBackend::Sse2;
81 }
82 }
83 #[cfg(target_arch = "aarch64")]
84 {
85 return ComputeBackend::Neon;
87 }
88 ComputeBackend::Scalar
89 }
90
91 fn detect_amx() -> bool {
93 #[cfg(target_arch = "x86_64")]
94 {
95 false
98 }
99 #[cfg(not(target_arch = "x86_64"))]
100 {
101 false
102 }
103 }
104
105 pub fn ensure_ready(&mut self) -> Result<ComputeBackend, SimdBackendState> {
107 match self.state {
108 SimdBackendState::Ready => Ok(self.best_backend),
109 SimdBackendState::Failed => Err(SimdBackendState::Failed),
110 SimdBackendState::Configuring => Err(SimdBackendState::Configuring),
111 SimdBackendState::Uninitialized => {
112 self.state = SimdBackendState::Configuring;
113
114 if self.amx_supported {
116 self.tile_config = Some(AmxTileConfig {
117 palette: 1,
118 start_row: 0,
119 rows: 16,
120 bytes_per_row: 64,
121 });
122 }
124
125 self.state = SimdBackendState::Ready;
126 Ok(self.best_backend)
127 }
128 }
129 }
130
131 #[must_use]
133 pub fn state(&self) -> SimdBackendState {
134 self.state
135 }
136
137 #[must_use]
139 pub fn best_backend(&self) -> ComputeBackend {
140 self.best_backend
141 }
142
143 #[must_use]
145 pub fn has_amx(&self) -> bool {
146 self.amx_supported
147 }
148
149 pub fn reset(&mut self) {
151 self.state = SimdBackendState::Uninitialized;
152 self.tile_config = None;
153 }
154}
155
156impl Default for LazySimdConfig {
157 fn default() -> Self {
158 Self::new()
159 }
160}
161
/// Loop-unrolling factor applied when iterating over data for a backend.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnrollFactor {
    /// No unrolling: one element per iteration.
    None,
    /// Unroll by 2.
    X2,
    /// Unroll by 4.
    X4,
    /// Unroll by 8.
    X8,
}
178
179impl UnrollFactor {
180 #[must_use]
182 pub fn value(&self) -> usize {
183 match self {
184 UnrollFactor::None => 1,
185 UnrollFactor::X2 => 2,
186 UnrollFactor::X4 => 4,
187 UnrollFactor::X8 => 8,
188 }
189 }
190
191 #[must_use]
193 pub fn for_backend(backend: ComputeBackend) -> Self {
194 match backend {
195 ComputeBackend::Avx512 => UnrollFactor::X8,
196 ComputeBackend::Avx2 => UnrollFactor::X4,
197 ComputeBackend::Sse2 | ComputeBackend::Neon => UnrollFactor::X2,
198 _ => UnrollFactor::None,
199 }
200 }
201}
202
/// Splits a range of `total` elements into full unrolled chunks plus a
/// scalar tail of leftover elements.
#[derive(Debug)]
pub struct UnrollTailIterator {
    // Total number of elements to cover.
    total: usize,
    // Start index of the next chunk handed out by `next_chunk`.
    position: usize,
    // Elements per full chunk (the unroll factor's value, >= 1).
    chunk_size: usize,
}
215
216impl UnrollTailIterator {
217 pub fn new(total: usize, factor: UnrollFactor) -> Self {
219 Self { total, position: 0, chunk_size: factor.value() }
220 }
221
222 #[must_use]
224 pub fn full_iterations(&self) -> usize {
225 self.total / self.chunk_size
226 }
227
228 #[must_use]
230 pub fn tail_size(&self) -> usize {
231 self.total % self.chunk_size
232 }
233
234 #[must_use]
236 pub fn has_tail(&self) -> bool {
237 self.tail_size() > 0
238 }
239
240 pub fn next_chunk(&mut self) -> Option<(usize, usize)> {
242 if self.position + self.chunk_size <= self.total {
243 let start = self.position;
244 self.position += self.chunk_size;
245 Some((start, start + self.chunk_size))
246 } else {
247 None
248 }
249 }
250
251 pub fn tail_range(&self) -> Option<(usize, usize)> {
253 let tail_start = self.full_iterations() * self.chunk_size;
254 if tail_start < self.total {
255 Some((tail_start, self.total))
256 } else {
257 None
258 }
259 }
260}
261
262pub fn unroll_tail_process<T, U, F, G>(
274 data: &[T],
275 factor: UnrollFactor,
276 mut process_chunk: F,
277 mut process_elem: G,
278) -> Vec<U>
279where
280 F: FnMut(&[T]) -> U,
281 G: FnMut(&T) -> U,
282{
283 let mut iter = UnrollTailIterator::new(data.len(), factor);
284 let mut results =
285 Vec::with_capacity(iter.full_iterations() + if iter.has_tail() { 1 } else { 0 });
286
287 while let Some((start, end)) = iter.next_chunk() {
289 results.push(process_chunk(&data[start..end]));
290 }
291
292 if let Some((start, end)) = iter.tail_range() {
294 for elem in &data[start..end] {
295 results.push(process_elem(elem));
296 }
297 }
298
299 results
300}
301
302#[cfg(test)]
303mod tests;