1use alloc::vec::Vec;
19use core::iter::repeat_n;
20use hekate_core::errors::Error;
21use hekate_core::poly::PolyVariant;
22use hekate_core::trace::{ColumnType, Trace, TraceColumn, TraceCompatibleField};
23use hekate_math::{Bit, Block8, Block16, Block32, Block64, Flat};
24
25#[derive(Clone, Copy, Debug)]
27pub enum ExpansionEntry {
28 ExpandBits {
29 count: usize,
30 storage: ColumnType,
31 },
32 PassThrough {
33 count: usize,
34 storage: ColumnType,
35 },
36 ControlBits {
37 count: usize,
38 },
39 ReusePassThrough {
40 phy_col_start: usize,
41 count: usize,
42 storage: ColumnType,
43 },
44}
45
46#[derive(Clone, Copy, Debug)]
48enum EntryKind {
49 ExpandBits { count: usize, storage: ColumnType },
52
53 PassThrough { count: usize, storage: ColumnType },
56
57 ControlBits { count: usize },
60}
61
62impl EntryKind {
63 fn count(&self) -> usize {
64 match self {
65 Self::ExpandBits { count, .. }
66 | Self::PassThrough { count, .. }
67 | Self::ControlBits { count } => *count,
68 }
69 }
70
71 fn storage(&self) -> ColumnType {
72 match self {
73 Self::ExpandBits { storage, .. } | Self::PassThrough { storage, .. } => *storage,
74 Self::ControlBits { .. } => ColumnType::Bit,
75 }
76 }
77}
78
79#[derive(Clone, Copy, Debug)]
82struct CompiledEntry {
83 phy_col_start: usize,
86
87 byte_offset: usize,
89 kind: EntryKind,
90
91 reuse: bool,
94}
95
96#[derive(Clone, Debug)]
104pub struct VirtualExpander {
105 entries: Vec<CompiledEntry>,
106 num_virtual: usize,
107 num_physical: usize,
108 physical_row_bytes: usize,
109 virtual_layout: Vec<ColumnType>,
110 error: Option<Error>,
111}
112
113impl VirtualExpander {
114 pub fn new() -> Self {
115 Self {
116 entries: Vec::new(),
117 num_virtual: 0,
118 num_physical: 0,
119 physical_row_bytes: 0,
120 virtual_layout: Vec::new(),
121 error: None,
122 }
123 }
124
125 pub fn build(self) -> Result<Self, Error> {
128 match self.error {
129 Some(e) => Err(e),
130 None => Ok(self),
131 }
132 }
133
134 pub fn expand_bits(mut self, count: usize, storage: ColumnType) -> Self {
137 if self.error.is_some() {
138 return self;
139 }
140
141 let bits_per = match expand_bit_width(storage) {
142 Ok(v) => v,
143 Err(e) => {
144 self.error = Some(e);
145 return self;
146 }
147 };
148
149 let byte_offset = self.physical_row_bytes;
150 let phy_col_start = self.num_physical;
151
152 self.entries.push(CompiledEntry {
153 phy_col_start,
154 byte_offset,
155 kind: EntryKind::ExpandBits { count, storage },
156 reuse: false,
157 });
158
159 let virt_count = count * bits_per;
160 self.virtual_layout
161 .extend(repeat_n(ColumnType::Bit, virt_count));
162
163 self.num_virtual += virt_count;
164 self.num_physical += count;
165 self.physical_row_bytes += count * storage.byte_size();
166
167 self
168 }
169
170 pub fn pass_through(mut self, count: usize, storage: ColumnType) -> Self {
173 let byte_offset = self.physical_row_bytes;
174 let phy_col_start = self.num_physical;
175
176 self.entries.push(CompiledEntry {
177 phy_col_start,
178 byte_offset,
179 kind: EntryKind::PassThrough { count, storage },
180 reuse: false,
181 });
182
183 self.virtual_layout.extend(repeat_n(storage, count));
184
185 self.num_virtual += count;
186 self.num_physical += count;
187 self.physical_row_bytes += count * storage.byte_size();
188
189 self
190 }
191
192 pub fn control_bits(mut self, count: usize) -> Self {
194 let byte_offset = self.physical_row_bytes;
195 let phy_col_start = self.num_physical;
196
197 self.entries.push(CompiledEntry {
198 phy_col_start,
199 byte_offset,
200 kind: EntryKind::ControlBits { count },
201 reuse: false,
202 });
203
204 self.virtual_layout.extend(repeat_n(ColumnType::Bit, count));
205
206 self.num_virtual += count;
207 self.num_physical += count;
208 self.physical_row_bytes += count;
209
210 self
211 }
212
213 pub fn reuse_pass_through(mut self, phy_col_start: usize, count: usize) -> Self {
218 if self.error.is_some() {
219 return self;
220 }
221
222 if phy_col_start + count > self.num_physical {
223 self.error = Some(Error::Protocol {
224 protocol: "virtual_expand",
225 message: "reuse_pass_through: range exceeds declared physical columns",
226 });
227 return self;
228 }
229
230 let (byte_offset, storage) = match self.find_phy_source(phy_col_start, count) {
231 Ok(v) => v,
232 Err(e) => {
233 self.error = Some(e);
234 return self;
235 }
236 };
237
238 self.entries.push(CompiledEntry {
239 phy_col_start,
240 byte_offset,
241 kind: EntryKind::PassThrough { count, storage },
242 reuse: true,
243 });
244
245 self.virtual_layout.extend(repeat_n(storage, count));
246
247 self.num_virtual += count;
248
249 self
250 }
251
252 #[inline]
253 pub fn num_virtual_columns(&self) -> usize {
254 self.num_virtual
255 }
256
257 #[inline]
258 pub fn num_physical_columns(&self) -> usize {
259 self.num_physical
260 }
261
262 #[inline]
263 pub fn physical_row_bytes(&self) -> usize {
264 self.physical_row_bytes
265 }
266
267 #[inline]
268 pub fn virtual_layout(&self) -> &[ColumnType] {
269 &self.virtual_layout
270 }
271
272 pub fn parse_row<F: TraceCompatibleField>(
276 &self,
277 bytes: &[u8],
278 res: &mut Vec<Flat<F>>,
279 ) -> Result<(), Error> {
280 if bytes.len() != self.physical_row_bytes {
281 return Err(Error::Protocol {
282 protocol: "virtual_expand",
283 message: "parse_row: byte slice length mismatch",
284 });
285 }
286
287 res.reserve(self.num_virtual);
288
289 for entry in &self.entries {
290 let off = entry.byte_offset;
291 match entry.kind {
292 EntryKind::ExpandBits { count, storage } => {
293 let bsz = storage.byte_size();
294 let bits = expand_bit_width(storage)?;
295
296 for i in 0..count {
297 let start = off + i * bsz;
298 for bit_idx in 0..bits {
299 let bit = parse_tower_bit(storage, &bytes[start..start + bsz], bit_idx);
300 res.push(Flat::from_raw(F::from(Bit::from(bit))));
301 }
302 }
303 }
304 EntryKind::PassThrough { count, storage } => {
305 let bsz = storage.byte_size();
306 for i in 0..count {
307 let start = off + i * bsz;
308 res.push(storage.parse_from_bytes(&bytes[start..start + bsz]));
309 }
310 }
311 EntryKind::ControlBits { count } => {
312 for i in 0..count {
313 res.push(Flat::from_raw(F::from(Bit::from(bytes[off + i] & 1))));
314 }
315 }
316 }
317 }
318
319 Ok(())
320 }
321
322 pub fn expand_variants<'a, F, T: Trace + ?Sized>(
326 &self,
327 trace: &'a T,
328 phy_start_idx: usize,
329 ) -> Result<Vec<PolyVariant<'a, F>>, Error>
330 where
331 F: TraceCompatibleField + 'static,
332 {
333 let columns = trace.columns();
334
335 let mut variants = Vec::with_capacity(self.num_virtual);
336 for entry in &self.entries {
337 let base = phy_start_idx + entry.phy_col_start;
338 match entry.kind {
339 EntryKind::ExpandBits { count, storage } => {
340 let bits = expand_bit_width(storage)?;
341 for i in 0..count {
342 let col = columns.get(base + i).ok_or(Error::Protocol {
343 protocol: "virtual_expand",
344 message: "missing physical column for ExpandBits",
345 })?;
346
347 for bit_idx in 0..bits {
348 variants.push(expand_packed_bit(col, storage, bit_idx)?);
349 }
350 }
351 }
352 EntryKind::PassThrough { count, storage } => {
353 for i in 0..count {
354 let col = columns.get(base + i).ok_or(Error::Protocol {
355 protocol: "virtual_expand",
356 message: "missing physical column for PassThrough",
357 })?;
358
359 variants.push(expand_pass_through(col, storage)?);
360 }
361 }
362 EntryKind::ControlBits { count } => {
363 for i in 0..count {
364 let col = columns.get(base + i).ok_or(Error::Protocol {
365 protocol: "virtual_expand",
366 message: "missing physical column for ControlBits",
367 })?;
368 let data = col.as_bit_slice().ok_or(Error::Protocol {
369 protocol: "virtual_expand",
370 message: "control column must be Bit",
371 })?;
372
373 variants.push(PolyVariant::BitSlice(data));
374 }
375 }
376 }
377 }
378
379 Ok(variants)
380 }
381
382 pub fn expansion_entries(&self) -> Vec<ExpansionEntry> {
384 self.entries
385 .iter()
386 .map(|e| match (e.kind, e.reuse) {
387 (EntryKind::PassThrough { count, storage }, true) => {
388 ExpansionEntry::ReusePassThrough {
389 phy_col_start: e.phy_col_start,
390 count,
391 storage,
392 }
393 }
394 (EntryKind::ExpandBits { count, storage }, _) => {
395 ExpansionEntry::ExpandBits { count, storage }
396 }
397 (EntryKind::PassThrough { count, storage }, false) => {
398 ExpansionEntry::PassThrough { count, storage }
399 }
400 (EntryKind::ControlBits { count }, _) => ExpansionEntry::ControlBits { count },
401 })
402 .collect()
403 }
404
405 fn find_phy_source(
408 &self,
409 target_start: usize,
410 target_count: usize,
411 ) -> Result<(usize, ColumnType), Error> {
412 let mut running_phy = 0usize;
413 for entry in &self.entries {
414 if entry.phy_col_start != running_phy {
415 continue;
416 }
417
418 let entry_count = entry.kind.count();
419 let entry_end = running_phy + entry_count;
420
421 if target_start >= running_phy && target_start + target_count <= entry_end {
422 let storage = entry.kind.storage();
423 let offset_in_entry = target_start - running_phy;
424
425 return Ok((
426 entry.byte_offset + offset_in_entry * storage.byte_size(),
427 storage,
428 ));
429 }
430
431 running_phy = entry_end;
432 }
433
434 Err(Error::Protocol {
435 protocol: "virtual_expand",
436 message: "reuse_pass_through: source columns not found in any fresh entry",
437 })
438 }
439}
440
441impl Default for VirtualExpander {
442 fn default() -> Self {
443 Self::new()
444 }
445}
446
447fn expand_bit_width(storage: ColumnType) -> Result<usize, Error> {
448 match storage {
449 ColumnType::B8 => Ok(8),
450 ColumnType::B16 => Ok(16),
451 ColumnType::B32 => Ok(32),
452 ColumnType::B64 => Ok(64),
453 _ => Err(Error::Protocol {
454 protocol: "virtual_expand",
455 message: "ExpandBits requires B8/B16/B32/B64",
456 }),
457 }
458}
459
460fn parse_tower_bit(storage: ColumnType, bytes: &[u8], bit_idx: usize) -> u8 {
462 match storage {
463 ColumnType::B8 => Flat::from_raw(Block8(bytes[0])).tower_bit(bit_idx),
464 ColumnType::B16 => {
465 let mut arr = [0u8; 2];
466 arr.copy_from_slice(bytes);
467
468 Flat::from_raw(Block16(u16::from_le_bytes(arr))).tower_bit(bit_idx)
469 }
470 ColumnType::B32 => {
471 let mut arr = [0u8; 4];
472 arr.copy_from_slice(bytes);
473
474 Flat::from_raw(Block32(u32::from_le_bytes(arr))).tower_bit(bit_idx)
475 }
476 ColumnType::B64 => {
477 let mut arr = [0u8; 8];
478 arr.copy_from_slice(bytes);
479
480 Flat::from_raw(Block64(u64::from_le_bytes(arr))).tower_bit(bit_idx)
481 }
482 _ => unreachable!(),
483 }
484}
485
486fn expand_packed_bit<F: TraceCompatibleField + 'static>(
487 col: &'_ TraceColumn,
488 storage: ColumnType,
489 bit_idx: usize,
490) -> Result<PolyVariant<'_, F>, Error> {
491 match storage {
492 ColumnType::B8 => {
493 let data = col.as_b8_slice().ok_or(Error::Protocol {
494 protocol: "virtual_expand",
495 message: "ExpandBits B8: column type mismatch",
496 })?;
497
498 Ok(PolyVariant::PackedBitB8 { data, bit_idx })
499 }
500 ColumnType::B16 => {
501 let data = col.as_b16_slice().ok_or(Error::Protocol {
502 protocol: "virtual_expand",
503 message: "ExpandBits B16: column type mismatch",
504 })?;
505
506 Ok(PolyVariant::PackedBitB16 { data, bit_idx })
507 }
508 ColumnType::B32 => {
509 let data = col.as_b32_slice().ok_or(Error::Protocol {
510 protocol: "virtual_expand",
511 message: "ExpandBits B32: column type mismatch",
512 })?;
513
514 Ok(PolyVariant::PackedBitB32 { data, bit_idx })
515 }
516 ColumnType::B64 => {
517 let data = col.as_b64_slice().ok_or(Error::Protocol {
518 protocol: "virtual_expand",
519 message: "ExpandBits B64: column type mismatch",
520 })?;
521
522 Ok(PolyVariant::PackedBitB64 { data, bit_idx })
523 }
524 _ => unreachable!(),
525 }
526}
527
528fn expand_pass_through<F: TraceCompatibleField + 'static>(
529 col: &TraceColumn,
530 storage: ColumnType,
531) -> Result<PolyVariant<'_, F>, Error> {
532 match storage {
533 ColumnType::Bit => {
534 let data = col.as_bit_slice().ok_or(Error::Protocol {
535 protocol: "virtual_expand",
536 message: "PassThrough Bit: column type mismatch",
537 })?;
538
539 Ok(PolyVariant::BitSlice(data))
540 }
541 ColumnType::B8 => {
542 let data = col.as_b8_slice().ok_or(Error::Protocol {
543 protocol: "virtual_expand",
544 message: "PassThrough B8: column type mismatch",
545 })?;
546
547 Ok(PolyVariant::B8Slice(data))
548 }
549 ColumnType::B16 => {
550 let data = col.as_b16_slice().ok_or(Error::Protocol {
551 protocol: "virtual_expand",
552 message: "PassThrough B16: column type mismatch",
553 })?;
554
555 Ok(PolyVariant::B16Slice(data))
556 }
557 ColumnType::B32 => {
558 let data = col.as_b32_slice().ok_or(Error::Protocol {
559 protocol: "virtual_expand",
560 message: "PassThrough B32: column type mismatch",
561 })?;
562
563 Ok(PolyVariant::B32Slice(data))
564 }
565 ColumnType::B64 => {
566 let data = col.as_b64_slice().ok_or(Error::Protocol {
567 protocol: "virtual_expand",
568 message: "PassThrough B64: column type mismatch",
569 })?;
570
571 Ok(PolyVariant::B64Slice(data))
572 }
573 ColumnType::B128 => {
574 let data = col.as_b128_slice().ok_or(Error::Protocol {
575 protocol: "virtual_expand",
576 message: "PassThrough B128: column type mismatch",
577 })?;
578
579 Ok(PolyVariant::B128Slice(data))
580 }
581 }
582}
583
584#[cfg(test)]
585mod tests {
586 use super::*;
587 use hekate_core::trace::TraceBuilder;
588 use hekate_math::Block128;
589
/// Mixed layout: expanded words, pass-through words, one B128 and control bits.
#[test]
fn ram_layout() {
    let expander = VirtualExpander::new()
        .expand_bits(2, ColumnType::B32)
        .pass_through(13, ColumnType::B32)
        .pass_through(1, ColumnType::B128)
        .control_bits(4)
        .build()
        .unwrap();

    // 2 * 32 expanded bits + 13 + 1 pass-through + 4 control bits.
    assert_eq!(expander.num_virtual_columns(), 82);
    assert_eq!(expander.num_physical_columns(), 20);
    assert_eq!(expander.physical_row_bytes(), 80);

    let layout = expander.virtual_layout();
    assert_eq!(layout.len(), 82);

    let expanded = &layout[..64];
    let words = &layout[64..77];
    let control = &layout[78..82];

    assert!(expanded.iter().all(|ty| *ty == ColumnType::Bit));
    assert!(words.iter().all(|ty| *ty == ColumnType::B32));
    assert_eq!(layout[77], ColumnType::B128);
    assert!(control.iter().all(|ty| *ty == ColumnType::Bit));
}
611
/// Expanded B64 columns that are also re-exposed whole via reuse.
#[test]
fn keccak_layout() {
    let expander = VirtualExpander::new()
        .expand_bits(25, ColumnType::B64)
        .expand_bits(1, ColumnType::B64)
        .reuse_pass_through(0, 25)
        .control_bits(2)
        .build()
        .unwrap();

    // 25 * 64 + 64 expanded bits, 25 reused columns, 2 control bits.
    assert_eq!(expander.num_virtual_columns(), 1691);
    // The reuse entry adds no physical columns or bytes.
    assert_eq!(expander.num_physical_columns(), 28);
    assert_eq!(expander.physical_row_bytes(), 210);

    let layout = expander.virtual_layout();
    assert_eq!(layout.len(), 1691);

    let is_bit = |ty: &ColumnType| *ty == ColumnType::Bit;
    assert!(layout[..1600].iter().all(is_bit));
    assert!(layout[1600..1664].iter().all(is_bit));
    assert!(layout[1664..1689].iter().all(|ty| *ty == ColumnType::B64));
    assert!(layout[1689..1691].iter().all(is_bit));
}
633
/// Reusing a sub-range of an entry only adds virtual columns.
#[test]
fn reuse_partial_range() {
    let expander = VirtualExpander::new()
        .expand_bits(10, ColumnType::B32)
        .reuse_pass_through(3, 4)
        .build()
        .unwrap();

    assert_eq!(expander.num_virtual_columns(), 324);
    assert_eq!(expander.num_physical_columns(), 10);
    assert_eq!(expander.physical_row_bytes(), 40);

    let reused = &expander.virtual_layout()[320..324];
    assert_eq!(reused.len(), 4);
    assert!(reused.iter().all(|ty| *ty == ColumnType::B32));
}
650
/// Columns 3..8 are requested while only 5 are declared.
#[test]
fn reuse_exceeds_declared() {
    let built = VirtualExpander::new()
        .expand_bits(5, ColumnType::B32)
        .reuse_pass_through(3, 5)
        .build();

    assert!(built.is_err());
}
659
/// `Bit` storage cannot be bit-expanded.
#[test]
fn expand_rejects_bit() {
    let built = VirtualExpander::new()
        .expand_bits(1, ColumnType::Bit)
        .build();

    assert!(built.is_err());
}
667
/// `B128` storage cannot be bit-expanded.
#[test]
fn expand_rejects_b128() {
    let built = VirtualExpander::new()
        .expand_bits(1, ColumnType::B128)
        .build();

    assert!(built.is_err());
}
675
/// A freshly created expander declares nothing.
#[test]
fn empty_expander() {
    let expander = VirtualExpander::new();

    assert!(expander.virtual_layout().is_empty());
    assert_eq!(expander.num_virtual_columns(), 0);
    assert_eq!(expander.num_physical_columns(), 0);
    assert_eq!(expander.physical_row_bytes(), 0);
}
684
/// `parse_row` yields the expanded bits, then the pass-through value, then
/// the control bit.
#[test]
fn parse_row_b32_roundtrip() {
    let expander = VirtualExpander::new()
        .expand_bits(1, ColumnType::B32)
        .pass_through(1, ColumnType::B32)
        .control_bits(1)
        .build()
        .unwrap();

    let expanded_word: u32 = 0xDEAD_BEEF;
    let passed_word: u32 = 0x1234_5678;

    // Row encoding: two little-endian words followed by one control byte.
    let mut row = Vec::with_capacity(9);
    row.extend_from_slice(&expanded_word.to_le_bytes());
    row.extend_from_slice(&passed_word.to_le_bytes());
    row.push(1);

    let mut parsed: Vec<Flat<Block128>> = Vec::new();
    expander.parse_row(&row, &mut parsed).unwrap();

    assert_eq!(parsed.len(), 34);

    let reference = Flat::from_raw(Block32(expanded_word));
    for (bit_idx, elem) in parsed[..32].iter().enumerate() {
        let expected = reference.tower_bit(bit_idx);
        let got = elem.tower_bit(0);
        assert_eq!(got, expected, "bit {bit_idx} mismatch");
    }

    let promoted = <Block128 as hekate_math::FlatPromote<Block32>>::promote_flat(
        Flat::from_raw(Block32(passed_word)),
    );
    assert_eq!(parsed[32], promoted);

    assert_eq!(parsed[33].tower_bit(0), 1);
}
724
/// `expand_variants` yields packed-bit variants first, then the pass-through
/// slice, then the control-bit slice.
#[test]
fn expand_variants_b32() {
    let expander = VirtualExpander::new()
        .expand_bits(1, ColumnType::B32)
        .pass_through(1, ColumnType::B32)
        .control_bits(1)
        .build()
        .unwrap();

    let num_vars = 2;
    let physical_layout = [ColumnType::B32, ColumnType::B32, ColumnType::Bit];

    let mut builder = TraceBuilder::new(&physical_layout, num_vars).unwrap();
    builder.set_b32(0, 0, Block32(0xAAAA_BBBB)).unwrap();
    builder.set_b32(1, 0, Block32(0x1111_2222)).unwrap();
    builder.set_bit(2, 0, Bit(1)).unwrap();

    let trace = builder.build();

    let variants: Vec<PolyVariant<'_, Block128>> = expander.expand_variants(&trace, 0).unwrap();
    assert_eq!(variants.len(), 34);

    for (expected_bit, variant) in variants[..32].iter().enumerate() {
        assert!(matches!(
            variant,
            PolyVariant::PackedBitB32 { bit_idx, .. } if *bit_idx == expected_bit
        ));
    }

    assert!(matches!(variants[32], PolyVariant::B32Slice(_)));
    assert!(matches!(variants[33], PolyVariant::BitSlice(_)));
}
755}