/// Dimensions of a 2-D grid together with the node count it was derived from.
///
/// Instances are produced by [`GridConfig::from_target_nodes`],
/// [`GridConfig::to_rectangular`], or taken from the predefined constants on
/// `TestGrids`. The type is a plain `Copy` value and supports equality and
/// hashing so configurations can be compared directly or used as map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct GridConfig {
    /// Node count that was requested when the configuration was built.
    /// `width * height` may differ from this value.
    pub target_nodes: usize,
    /// Number of columns in the grid.
    pub width: usize,
    /// Number of rows in the grid.
    pub height: usize,
    /// Stride value stored with the grid. Every constructor in this file sets
    /// it to a power of two that is at least as large as the longer side.
    // NOTE(review): presumably the row pitch used when indexing node storage;
    // confirm against the code that consumes it.
    pub stride_y: usize,
}
13
14impl GridConfig {
15 pub fn from_target_nodes(target_nodes: usize) -> Self {
18 let dim = isqrt(target_nodes);
20 let stride_y = dim.next_power_of_two();
21
22 Self {
23 target_nodes,
24 width: dim,
25 height: dim,
26 stride_y,
27 }
28 }
29
30 pub fn actual_nodes(&self) -> usize {
32 self.width * self.height
33 }
34
35 pub fn to_rectangular(&self, aspect_ratio: f64) -> Self {
38 let nodes = self.target_nodes;
39
40 let val = (nodes as f64 / aspect_ratio) as usize;
42 let h = isqrt(val).max(1);
43
44 let w = nodes.div_ceil(h);
46
47 let h = h.max(1);
49 let w = w.max(1);
50
51 let stride_y = w.max(h).next_power_of_two();
52
53 Self {
54 target_nodes: nodes,
55 width: w,
56 height: h,
57 stride_y,
58 }
59 }
60}
61
/// Six probabilities in ascending order, from 0.1% up to 6%.
// NOTE(review): going by the name these are error probabilities for sweeps;
// the consumers are not visible in this file.
pub const ERROR_PROBS: [f64; 6] = [1e-3, 3e-3, 6e-3, 1e-2, 3e-2, 6e-2];
64
/// Integer square root: the largest `s` such that `s * s <= n`.
///
/// Uses Newton's iteration starting from `ceil(n / 2)` and stops as soon as
/// the next estimate no longer decreases.
#[must_use]
pub fn isqrt(n: usize) -> usize {
    // 0 and 1 are their own square roots.
    if n <= 1 {
        return n;
    }

    let mut estimate = n.div_ceil(2);
    loop {
        let refined = (estimate + n / estimate) / 2;
        if refined >= estimate {
            return estimate;
        }
        estimate = refined;
    }
}
80
81pub struct TestGrids;
83
84impl TestGrids {
85 pub const TINY: GridConfig = GridConfig {
87 target_nodes: 289,
88 width: 17,
89 height: 17,
90 stride_y: 32,
91 };
92
93 pub const SMALL: GridConfig = GridConfig {
95 target_nodes: 500,
96 width: 22,
97 height: 22,
98 stride_y: 32,
99 };
100
101 pub const MEDIUM: GridConfig = GridConfig {
103 target_nodes: 1024,
104 width: 32,
105 height: 32,
106 stride_y: 32,
107 };
108
109 pub const LARGE: GridConfig = GridConfig {
111 target_nodes: 4096,
112 width: 64,
113 height: 64,
114 stride_y: 64,
115 };
116
117 pub const LARGE_PLUS: GridConfig = GridConfig {
119 target_nodes: 5000,
120 width: 71,
121 height: 71,
122 stride_y: 128,
123 };
124
125 pub const XLARGE: GridConfig = GridConfig {
127 target_nodes: 100_000,
128 width: 316,
129 height: 316,
130 stride_y: 512,
131 };
132
133 pub const XXLARGE: GridConfig = GridConfig {
135 target_nodes: 131_072,
136 width: 362,
137 height: 362,
138 stride_y: 512,
139 };
140
141 pub const fn all() -> [GridConfig; 7] {
143 [
144 Self::TINY,
145 Self::SMALL,
146 Self::MEDIUM,
147 Self::LARGE,
148 Self::LARGE_PLUS,
149 Self::XLARGE,
150 Self::XXLARGE,
151 ]
152 }
153
154 pub const fn defaults() -> [GridConfig; 4] {
156 [
157 Self::TINY,
158 Self::SMALL,
159 Self::MEDIUM,
160 Self::LARGE,
161 ]
162 }
163}
164
#[cfg(test)]
mod tests {
    use super::*;

    /// A perfect-square target maps to exactly that square, stride included.
    #[test]
    fn test_grid_config_from_target() {
        let GridConfig {
            width,
            height,
            stride_y,
            ..
        } = GridConfig::from_target_nodes(1024);
        assert_eq!(width, 32);
        assert_eq!(height, 32);
        assert_eq!(stride_y, 32);
    }

    /// Spot-checks node counts of two predefined grids.
    #[test]
    fn test_predefined_configs() {
        let cases = [(TestGrids::TINY, 17, 289), (TestGrids::MEDIUM, 32, 1024)];
        for (grid, side, total) in cases {
            assert_eq!(grid.actual_nodes(), side * side);
            assert_eq!(grid.actual_nodes(), total);
        }
    }

    /// Small inputs, perfect squares, and values just above perfect squares.
    #[test]
    fn test_isqrt_edge_cases() {
        let expected = [(0, 0), (1, 1), (4, 2), (5, 2), (9, 3), (10, 3), (100, 10)];
        for (input, root) in expected {
            assert_eq!(isqrt(input), root);
        }
    }

    /// Aspect ratios above, below, and equal to one shape the grid accordingly.
    #[test]
    fn test_to_rectangular() {
        let base = GridConfig::from_target_nodes(1000);

        let wide = base.to_rectangular(2.0);
        assert!(wide.width > wide.height, "width should be greater for ratio > 1");
        assert!(wide.stride_y.is_power_of_two());

        let tall = base.to_rectangular(0.5);
        assert!(tall.height > tall.width, "height should be greater for ratio < 1");
        assert!(tall.stride_y.is_power_of_two());

        let squarish = base.to_rectangular(1.0);
        let gap = squarish.width.abs_diff(squarish.height);
        assert!(gap <= 2, "should be approximately square");
    }

    /// `all()` lists every predefined grid in declaration order.
    #[test]
    fn test_all_grids() {
        let all = TestGrids::all();
        assert_eq!(all.len(), 7);

        let named = [
            TestGrids::TINY,
            TestGrids::SMALL,
            TestGrids::MEDIUM,
            TestGrids::LARGE,
            TestGrids::LARGE_PLUS,
            TestGrids::XLARGE,
            TestGrids::XXLARGE,
        ];
        for (listed, expected) in all.iter().zip(named.iter()) {
            assert_eq!(listed.actual_nodes(), expected.actual_nodes());
        }
    }

    /// `defaults()` spans `TINY` through `LARGE`.
    #[test]
    fn test_defaults_grids() {
        let defaults = TestGrids::defaults();
        assert_eq!(defaults.len(), 4);

        let [first, .., last] = defaults;
        assert_eq!(first.actual_nodes(), TestGrids::TINY.actual_nodes());
        assert_eq!(last.actual_nodes(), TestGrids::LARGE.actual_nodes());
    }

    /// Remaining predefined grids have the expected size, and every grid has a
    /// power-of-two stride that covers its longer side.
    #[test]
    fn test_all_predefined_grids() {
        let sides = [
            (TestGrids::SMALL, 22),
            (TestGrids::LARGE, 64),
            (TestGrids::LARGE_PLUS, 71),
            (TestGrids::XLARGE, 316),
            (TestGrids::XXLARGE, 362),
        ];
        for (grid, side) in sides {
            assert_eq!(grid.actual_nodes(), side * side);
        }

        for config in TestGrids::all() {
            assert!(config.stride_y.is_power_of_two());
            let longest_side = config.width.max(config.height);
            assert!(config.stride_y >= longest_side);
        }
    }
}
257
/// Kani proof harnesses for `isqrt` and `GridConfig::from_target_nodes`.
///
/// Every harness drives the data-dependent loop in `isqrt` with a symbolic
/// input of at most 1_000_000, so each one carries the same explicit
/// `#[kani::unwind(33)]` bound. Without a bound the model checker has no limit
/// for unrolling that loop.
#[cfg(kani)]
mod kani_proofs {
    use super::*;

    /// `isqrt(n)` is the floor square root: `s^2 <= n < (s + 1)^2`.
    #[kani::proof]
    #[kani::unwind(33)]
    fn verify_isqrt_invariant() {
        let n: usize = kani::any();
        kani::assume(n <= 1_000_000);

        let s = isqrt(n);

        // An overflowing square can never be <= n, so `None` fails the check.
        let s_squared = s.checked_mul(s);
        kani::assert(
            s_squared.map(|sq| sq <= n).unwrap_or(false),
            "isqrt(n)^2 must be <= n",
        );

        // An overflowing square is necessarily > n, so `None` passes the check.
        let s_plus_1 = s + 1;
        let s_plus_1_squared = s_plus_1.checked_mul(s_plus_1);
        kani::assert(
            s_plus_1_squared.map(|sq| sq > n).unwrap_or(true),
            "(isqrt(n)+1)^2 must be > n",
        );
    }

    /// A non-zero target always yields strictly positive dimensions and stride.
    #[kani::proof]
    #[kani::unwind(33)]
    fn verify_grid_dimensions_positive() {
        let target: usize = kani::any();
        kani::assume(target >= 1 && target <= 1_000_000);

        let config = GridConfig::from_target_nodes(target);

        kani::assert(config.width >= 1, "width must be positive");
        kani::assert(config.height >= 1, "height must be positive");
        kani::assert(config.stride_y >= 1, "stride_y must be positive");
    }

    /// The stride chosen by `from_target_nodes` is always a power of two.
    #[kani::proof]
    #[kani::unwind(33)]
    fn verify_stride_power_of_two() {
        let target: usize = kani::any();
        kani::assume(target >= 1 && target <= 1_000_000);

        let config = GridConfig::from_target_nodes(target);

        kani::assert(
            config.stride_y.is_power_of_two(),
            "stride_y must be power of two",
        );
    }

    /// The stride chosen by `from_target_nodes` covers both grid dimensions.
    #[kani::proof]
    #[kani::unwind(33)]
    fn verify_stride_covers_dimensions() {
        let target: usize = kani::any();
        kani::assume(target >= 1 && target <= 1_000_000);

        let config = GridConfig::from_target_nodes(target);

        kani::assert(
            config.stride_y >= config.width,
            "stride_y must be >= width",
        );
        kani::assert(
            config.stride_y >= config.height,
            "stride_y must be >= height",
        );
    }
}