1use alloc::vec::Vec;
19use core::iter::repeat_n;
20use hekate_core::errors::Error;
21use hekate_core::poly::PolyVariant;
22use hekate_core::trace::{ColumnType, Trace, TraceColumn, TraceCompatibleField};
23use hekate_math::{Bit, Block8, Block16, Block32, Block64, Flat};
24
/// Public description of one step of a [`VirtualExpander`], as returned by
/// [`VirtualExpander::expansion_entries`].
///
/// "Fresh" variants (`ExpandBits`, `PassThrough`, `ControlBits`) each declare
/// `count` new physical columns placed right after the previously declared
/// ones. "Reuse" variants re-read physical columns that an earlier fresh step
/// already declared and add no physical columns of their own.
#[derive(Clone, Copy, Debug)]
pub enum ExpansionEntry {
    /// `count` fresh physical columns of type `storage` (B8/B16/B32/B64).
    /// Every bit of each column becomes its own virtual `Bit` column.
    ExpandBits {
        count: usize,
        storage: ColumnType,
    },
    /// `count` fresh physical columns of type `storage`, each exposed
    /// unchanged as one virtual column of the same type.
    PassThrough {
        count: usize,
        storage: ColumnType,
    },
    /// `count` fresh physical `Bit` columns, each exposed as one virtual
    /// `Bit` column.
    ControlBits {
        count: usize,
    },
    /// Re-exposes `count` already-declared physical columns, starting at
    /// `phy_col_start` (relative to the expander's first physical column),
    /// as pass-through virtual columns of type `storage`.
    ReusePassThrough {
        phy_col_start: usize,
        count: usize,
        storage: ColumnType,
    },
    /// Bit-expands `count` already-declared physical columns, starting at
    /// `phy_col_start` (relative to the expander's first physical column).
    ReuseExpandBits {
        phy_col_start: usize,
        count: usize,
        storage: ColumnType,
    },
}
50
/// Internal classification of what a compiled entry does with its physical
/// columns. Whether the columns are fresh or reused is tracked separately in
/// `CompiledEntry::reuse`.
#[derive(Clone, Copy, Debug)]
enum EntryKind {
    /// Split each of `count` packed `storage` columns into one virtual `Bit`
    /// column per bit.
    ExpandBits { count: usize, storage: ColumnType },

    /// Expose `count` `storage` columns unchanged.
    PassThrough { count: usize, storage: ColumnType },

    /// Expose `count` `Bit` columns; in a packed row each takes one byte and
    /// only its low bit is read.
    ControlBits { count: usize },
}
66
67impl EntryKind {
68 fn count(&self) -> usize {
69 match self {
70 Self::ExpandBits { count, .. }
71 | Self::PassThrough { count, .. }
72 | Self::ControlBits { count } => *count,
73 }
74 }
75
76 fn storage(&self) -> ColumnType {
77 match self {
78 Self::ExpandBits { storage, .. } | Self::PassThrough { storage, .. } => *storage,
79 Self::ControlBits { .. } => ColumnType::Bit,
80 }
81 }
82}
83
/// Resolved form of one builder step, as stored inside a [`VirtualExpander`].
#[derive(Clone, Copy, Debug)]
struct CompiledEntry {
    /// Index of the first physical column this entry reads, relative to the
    /// expander's first physical column (`expand_variants` adds
    /// `phy_start_idx` to it).
    phy_col_start: usize,

    /// Byte offset of that first column inside a packed physical row, as
    /// consumed by `parse_row`.
    byte_offset: usize,
    /// What the entry does with its columns, and how many it reads.
    kind: EntryKind,

    /// `true` when the entry re-reads columns declared by an earlier fresh
    /// entry instead of declaring new physical columns.
    reuse: bool,
}
100
/// Builder-style description of how a contiguous run of physical trace
/// columns is expanded into virtual columns.
///
/// Steps are declared in order. Fresh steps take consecutive physical columns
/// and consecutive bytes of a packed row. Reuse steps re-read columns that were
/// already declared. Builder methods do not fail immediately; the first error
/// is stored and reported by `build`.
#[derive(Clone, Debug)]
pub struct VirtualExpander {
    /// Compiled steps in declaration order; virtual columns are produced in
    /// this order.
    entries: Vec<CompiledEntry>,
    /// Total number of virtual columns; always equal to `virtual_layout.len()`.
    num_virtual: usize,
    /// Number of physical columns declared by fresh steps (reuse steps add none).
    num_physical: usize,
    /// Length in bytes of one packed physical row accepted by `parse_row`.
    physical_row_bytes: usize,
    /// Column type of every virtual column, in output order.
    virtual_layout: Vec<ColumnType>,
    /// First error recorded by a builder method; returned by `build`.
    error: Option<Error>,
}
117
118impl VirtualExpander {
119 pub fn new() -> Self {
120 Self {
121 entries: Vec::new(),
122 num_virtual: 0,
123 num_physical: 0,
124 physical_row_bytes: 0,
125 virtual_layout: Vec::new(),
126 error: None,
127 }
128 }
129
130 pub fn build(self) -> Result<Self, Error> {
133 match self.error {
134 Some(e) => Err(e),
135 None => Ok(self),
136 }
137 }
138
139 pub fn expand_bits(mut self, count: usize, storage: ColumnType) -> Self {
142 if self.error.is_some() {
143 return self;
144 }
145
146 let bits_per = match expand_bit_width(storage) {
147 Ok(v) => v,
148 Err(e) => {
149 self.error = Some(e);
150 return self;
151 }
152 };
153
154 let byte_offset = self.physical_row_bytes;
155 let phy_col_start = self.num_physical;
156
157 self.entries.push(CompiledEntry {
158 phy_col_start,
159 byte_offset,
160 kind: EntryKind::ExpandBits { count, storage },
161 reuse: false,
162 });
163
164 let virt_count = count * bits_per;
165 self.virtual_layout
166 .extend(repeat_n(ColumnType::Bit, virt_count));
167
168 self.num_virtual += virt_count;
169 self.num_physical += count;
170 self.physical_row_bytes += count * storage.byte_size();
171
172 self
173 }
174
175 pub fn pass_through(mut self, count: usize, storage: ColumnType) -> Self {
178 let byte_offset = self.physical_row_bytes;
179 let phy_col_start = self.num_physical;
180
181 self.entries.push(CompiledEntry {
182 phy_col_start,
183 byte_offset,
184 kind: EntryKind::PassThrough { count, storage },
185 reuse: false,
186 });
187
188 self.virtual_layout.extend(repeat_n(storage, count));
189
190 self.num_virtual += count;
191 self.num_physical += count;
192 self.physical_row_bytes += count * storage.byte_size();
193
194 self
195 }
196
197 pub fn control_bits(mut self, count: usize) -> Self {
199 let byte_offset = self.physical_row_bytes;
200 let phy_col_start = self.num_physical;
201
202 self.entries.push(CompiledEntry {
203 phy_col_start,
204 byte_offset,
205 kind: EntryKind::ControlBits { count },
206 reuse: false,
207 });
208
209 self.virtual_layout.extend(repeat_n(ColumnType::Bit, count));
210
211 self.num_virtual += count;
212 self.num_physical += count;
213 self.physical_row_bytes += count;
214
215 self
216 }
217
218 pub fn reuse_pass_through(mut self, phy_col_start: usize, count: usize) -> Self {
222 if self.error.is_some() {
223 return self;
224 }
225
226 if phy_col_start + count > self.num_physical {
227 self.error = Some(Error::Protocol {
228 protocol: "virtual_expand",
229 message: "reuse_pass_through: range exceeds declared physical columns",
230 });
231 return self;
232 }
233
234 let (byte_offset, storage) = match self.find_phy_source(phy_col_start, count) {
235 Ok(v) => v,
236 Err(e) => {
237 self.error = Some(e);
238 return self;
239 }
240 };
241
242 self.entries.push(CompiledEntry {
243 phy_col_start,
244 byte_offset,
245 kind: EntryKind::PassThrough { count, storage },
246 reuse: true,
247 });
248
249 self.virtual_layout.extend(repeat_n(storage, count));
250
251 self.num_virtual += count;
252
253 self
254 }
255
256 pub fn reuse_expand_bits(mut self, phy_col_start: usize, count: usize) -> Self {
260 if self.error.is_some() {
261 return self;
262 }
263
264 if phy_col_start + count > self.num_physical {
265 self.error = Some(Error::Protocol {
266 protocol: "virtual_expand",
267 message: "reuse_expand_bits: range exceeds declared physical columns",
268 });
269 return self;
270 }
271
272 let (byte_offset, storage) = match self.find_phy_source(phy_col_start, count) {
273 Ok(v) => v,
274 Err(e) => {
275 self.error = Some(e);
276 return self;
277 }
278 };
279
280 let bits_per = match expand_bit_width(storage) {
281 Ok(v) => v,
282 Err(e) => {
283 self.error = Some(e);
284 return self;
285 }
286 };
287
288 self.entries.push(CompiledEntry {
289 phy_col_start,
290 byte_offset,
291 kind: EntryKind::ExpandBits { count, storage },
292 reuse: true,
293 });
294
295 let virt_count = count * bits_per;
296 self.virtual_layout
297 .extend(repeat_n(ColumnType::Bit, virt_count));
298
299 self.num_virtual += virt_count;
300
301 self
302 }
303
304 #[inline]
305 pub fn num_virtual_columns(&self) -> usize {
306 self.num_virtual
307 }
308
309 #[inline]
310 pub fn num_physical_columns(&self) -> usize {
311 self.num_physical
312 }
313
314 #[inline]
315 pub fn physical_row_bytes(&self) -> usize {
316 self.physical_row_bytes
317 }
318
319 #[inline]
320 pub fn virtual_layout(&self) -> &[ColumnType] {
321 &self.virtual_layout
322 }
323
324 pub fn parse_row<F: TraceCompatibleField>(
328 &self,
329 bytes: &[u8],
330 res: &mut Vec<Flat<F>>,
331 ) -> Result<(), Error> {
332 if bytes.len() != self.physical_row_bytes {
333 return Err(Error::Protocol {
334 protocol: "virtual_expand",
335 message: "parse_row: byte slice length mismatch",
336 });
337 }
338
339 res.reserve(self.num_virtual);
340
341 for entry in &self.entries {
342 let off = entry.byte_offset;
343 match entry.kind {
344 EntryKind::ExpandBits { count, storage } => {
345 let bsz = storage.byte_size();
346 let bits = expand_bit_width(storage)?;
347
348 for i in 0..count {
349 let start = off + i * bsz;
350 for bit_idx in 0..bits {
351 let bit = parse_tower_bit(storage, &bytes[start..start + bsz], bit_idx);
352 res.push(Flat::from_raw(F::from(Bit::from(bit))));
353 }
354 }
355 }
356 EntryKind::PassThrough { count, storage } => {
357 let bsz = storage.byte_size();
358 for i in 0..count {
359 let start = off + i * bsz;
360 res.push(storage.parse_from_bytes(&bytes[start..start + bsz]));
361 }
362 }
363 EntryKind::ControlBits { count } => {
364 for i in 0..count {
365 res.push(Flat::from_raw(F::from(Bit::from(bytes[off + i] & 1))));
366 }
367 }
368 }
369 }
370
371 Ok(())
372 }
373
374 pub fn expand_variants<'a, F, T: Trace + ?Sized>(
378 &self,
379 trace: &'a T,
380 phy_start_idx: usize,
381 ) -> Result<Vec<PolyVariant<'a, F>>, Error>
382 where
383 F: TraceCompatibleField + 'static,
384 {
385 let columns = trace.columns();
386
387 let mut variants = Vec::with_capacity(self.num_virtual);
388 for entry in &self.entries {
389 let base = phy_start_idx + entry.phy_col_start;
390 match entry.kind {
391 EntryKind::ExpandBits { count, storage } => {
392 let bits = expand_bit_width(storage)?;
393 for i in 0..count {
394 let col = columns.get(base + i).ok_or(Error::Protocol {
395 protocol: "virtual_expand",
396 message: "missing physical column for ExpandBits",
397 })?;
398
399 for bit_idx in 0..bits {
400 variants.push(expand_packed_bit(col, storage, bit_idx)?);
401 }
402 }
403 }
404 EntryKind::PassThrough { count, storage } => {
405 for i in 0..count {
406 let col = columns.get(base + i).ok_or(Error::Protocol {
407 protocol: "virtual_expand",
408 message: "missing physical column for PassThrough",
409 })?;
410
411 variants.push(expand_pass_through(col, storage)?);
412 }
413 }
414 EntryKind::ControlBits { count } => {
415 for i in 0..count {
416 let col = columns.get(base + i).ok_or(Error::Protocol {
417 protocol: "virtual_expand",
418 message: "missing physical column for ControlBits",
419 })?;
420 let data = col.as_bit_slice().ok_or(Error::Protocol {
421 protocol: "virtual_expand",
422 message: "control column must be Bit",
423 })?;
424
425 variants.push(PolyVariant::BitSlice(data));
426 }
427 }
428 }
429 }
430
431 Ok(variants)
432 }
433
434 pub fn expansion_entries(&self) -> Vec<ExpansionEntry> {
436 self.entries
437 .iter()
438 .map(|e| match (e.kind, e.reuse) {
439 (EntryKind::PassThrough { count, storage }, true) => {
440 ExpansionEntry::ReusePassThrough {
441 phy_col_start: e.phy_col_start,
442 count,
443 storage,
444 }
445 }
446 (EntryKind::ExpandBits { count, storage }, true) => {
447 ExpansionEntry::ReuseExpandBits {
448 phy_col_start: e.phy_col_start,
449 count,
450 storage,
451 }
452 }
453 (EntryKind::ExpandBits { count, storage }, false) => {
454 ExpansionEntry::ExpandBits { count, storage }
455 }
456 (EntryKind::PassThrough { count, storage }, false) => {
457 ExpansionEntry::PassThrough { count, storage }
458 }
459 (EntryKind::ControlBits { count }, _) => ExpansionEntry::ControlBits { count },
460 })
461 .collect()
462 }
463
464 fn find_phy_source(
467 &self,
468 target_start: usize,
469 target_count: usize,
470 ) -> Result<(usize, ColumnType), Error> {
471 let mut running_phy = 0usize;
472 for entry in &self.entries {
473 if entry.phy_col_start != running_phy {
474 continue;
475 }
476
477 let entry_count = entry.kind.count();
478 let entry_end = running_phy + entry_count;
479
480 if target_start >= running_phy && target_start + target_count <= entry_end {
481 let storage = entry.kind.storage();
482 let offset_in_entry = target_start - running_phy;
483
484 return Ok((
485 entry.byte_offset + offset_in_entry * storage.byte_size(),
486 storage,
487 ));
488 }
489
490 running_phy = entry_end;
491 }
492
493 Err(Error::Protocol {
494 protocol: "virtual_expand",
495 message: "reuse: source columns not found in any single fresh entry",
496 })
497 }
498}
499
500impl Default for VirtualExpander {
501 fn default() -> Self {
502 Self::new()
503 }
504}
505
506fn expand_bit_width(storage: ColumnType) -> Result<usize, Error> {
507 match storage {
508 ColumnType::B8 => Ok(8),
509 ColumnType::B16 => Ok(16),
510 ColumnType::B32 => Ok(32),
511 ColumnType::B64 => Ok(64),
512 _ => Err(Error::Protocol {
513 protocol: "virtual_expand",
514 message: "ExpandBits requires B8/B16/B32/B64",
515 }),
516 }
517}
518
519fn parse_tower_bit(storage: ColumnType, bytes: &[u8], bit_idx: usize) -> u8 {
521 match storage {
522 ColumnType::B8 => Flat::from_raw(Block8(bytes[0])).tower_bit(bit_idx),
523 ColumnType::B16 => {
524 let mut arr = [0u8; 2];
525 arr.copy_from_slice(bytes);
526
527 Flat::from_raw(Block16(u16::from_le_bytes(arr))).tower_bit(bit_idx)
528 }
529 ColumnType::B32 => {
530 let mut arr = [0u8; 4];
531 arr.copy_from_slice(bytes);
532
533 Flat::from_raw(Block32(u32::from_le_bytes(arr))).tower_bit(bit_idx)
534 }
535 ColumnType::B64 => {
536 let mut arr = [0u8; 8];
537 arr.copy_from_slice(bytes);
538
539 Flat::from_raw(Block64(u64::from_le_bytes(arr))).tower_bit(bit_idx)
540 }
541 _ => unreachable!(),
542 }
543}
544
545fn expand_packed_bit<F: TraceCompatibleField + 'static>(
546 col: &'_ TraceColumn,
547 storage: ColumnType,
548 bit_idx: usize,
549) -> Result<PolyVariant<'_, F>, Error> {
550 match storage {
551 ColumnType::B8 => {
552 let data = col.as_b8_slice().ok_or(Error::Protocol {
553 protocol: "virtual_expand",
554 message: "ExpandBits B8: column type mismatch",
555 })?;
556
557 Ok(PolyVariant::PackedBitB8 { data, bit_idx })
558 }
559 ColumnType::B16 => {
560 let data = col.as_b16_slice().ok_or(Error::Protocol {
561 protocol: "virtual_expand",
562 message: "ExpandBits B16: column type mismatch",
563 })?;
564
565 Ok(PolyVariant::PackedBitB16 { data, bit_idx })
566 }
567 ColumnType::B32 => {
568 let data = col.as_b32_slice().ok_or(Error::Protocol {
569 protocol: "virtual_expand",
570 message: "ExpandBits B32: column type mismatch",
571 })?;
572
573 Ok(PolyVariant::PackedBitB32 { data, bit_idx })
574 }
575 ColumnType::B64 => {
576 let data = col.as_b64_slice().ok_or(Error::Protocol {
577 protocol: "virtual_expand",
578 message: "ExpandBits B64: column type mismatch",
579 })?;
580
581 Ok(PolyVariant::PackedBitB64 { data, bit_idx })
582 }
583 _ => unreachable!(),
584 }
585}
586
587fn expand_pass_through<F: TraceCompatibleField + 'static>(
588 col: &TraceColumn,
589 storage: ColumnType,
590) -> Result<PolyVariant<'_, F>, Error> {
591 match storage {
592 ColumnType::Bit => {
593 let data = col.as_bit_slice().ok_or(Error::Protocol {
594 protocol: "virtual_expand",
595 message: "PassThrough Bit: column type mismatch",
596 })?;
597
598 Ok(PolyVariant::BitSlice(data))
599 }
600 ColumnType::B8 => {
601 let data = col.as_b8_slice().ok_or(Error::Protocol {
602 protocol: "virtual_expand",
603 message: "PassThrough B8: column type mismatch",
604 })?;
605
606 Ok(PolyVariant::B8Slice(data))
607 }
608 ColumnType::B16 => {
609 let data = col.as_b16_slice().ok_or(Error::Protocol {
610 protocol: "virtual_expand",
611 message: "PassThrough B16: column type mismatch",
612 })?;
613
614 Ok(PolyVariant::B16Slice(data))
615 }
616 ColumnType::B32 => {
617 let data = col.as_b32_slice().ok_or(Error::Protocol {
618 protocol: "virtual_expand",
619 message: "PassThrough B32: column type mismatch",
620 })?;
621
622 Ok(PolyVariant::B32Slice(data))
623 }
624 ColumnType::B64 => {
625 let data = col.as_b64_slice().ok_or(Error::Protocol {
626 protocol: "virtual_expand",
627 message: "PassThrough B64: column type mismatch",
628 })?;
629
630 Ok(PolyVariant::B64Slice(data))
631 }
632 ColumnType::B128 => {
633 let data = col.as_b128_slice().ok_or(Error::Protocol {
634 protocol: "virtual_expand",
635 message: "PassThrough B128: column type mismatch",
636 })?;
637
638 Ok(PolyVariant::B128Slice(data))
639 }
640 }
641}
642
#[cfg(test)]
mod tests {
    use super::*;
    use hekate_core::trace::TraceBuilder;
    use hekate_math::Block128;

    // 2 expanded B32 (64 bits) + 13 B32 + 1 B128 + 4 control bits.
    #[test]
    fn ram_layout() {
        let e = VirtualExpander::new()
            .expand_bits(2, ColumnType::B32)
            .pass_through(13, ColumnType::B32)
            .pass_through(1, ColumnType::B128)
            .control_bits(4)
            .build()
            .unwrap();

        // Virtual: 64 + 13 + 1 + 4 = 82. Physical: 2 + 13 + 1 + 4 = 20.
        // Bytes: 2*4 + 13*4 + 16 + 4*1 = 80.
        assert_eq!(e.num_virtual_columns(), 82);
        assert_eq!(e.num_physical_columns(), 20);
        assert_eq!(e.physical_row_bytes(), 80);

        let layout = e.virtual_layout();
        assert_eq!(layout.len(), 82);
        assert!(layout[..64].iter().all(|&t| t == ColumnType::Bit));
        assert!(layout[64..77].iter().all(|&t| t == ColumnType::B32));
        assert_eq!(layout[77], ColumnType::B128);
        assert!(layout[78..82].iter().all(|&t| t == ColumnType::Bit));
    }

    // 25 B64 columns bit-expanded, 1 more B64 expanded, the first 25 reused
    // as pass-through, then 2 control bits.
    #[test]
    fn keccak_layout() {
        let e = VirtualExpander::new()
            .expand_bits(25, ColumnType::B64)
            .expand_bits(1, ColumnType::B64)
            .reuse_pass_through(0, 25)
            .control_bits(2)
            .build()
            .unwrap();

        // Virtual: 1600 + 64 + 25 + 2 = 1691. The reuse adds no physical
        // columns, so physical = 25 + 1 + 2 = 28 and bytes = 26*8 + 2 = 210.
        assert_eq!(e.num_virtual_columns(), 1691);
        assert_eq!(e.num_physical_columns(), 28);
        assert_eq!(e.physical_row_bytes(), 210);

        let layout = e.virtual_layout();
        assert_eq!(layout.len(), 1691);
        assert!(layout[..1600].iter().all(|&t| t == ColumnType::Bit));
        assert!(layout[1600..1664].iter().all(|&t| t == ColumnType::Bit));
        assert!(layout[1664..1689].iter().all(|&t| t == ColumnType::B64));
        assert!(layout[1689..1691].iter().all(|&t| t == ColumnType::Bit));
    }

    // Reusing a sub-range (columns 3..7) of one fresh entry.
    #[test]
    fn reuse_partial_range() {
        let e = VirtualExpander::new()
            .expand_bits(10, ColumnType::B32)
            .reuse_pass_through(3, 4)
            .build()
            .unwrap();

        // 10*32 expanded bits + 4 reused B32 columns.
        assert_eq!(e.num_virtual_columns(), 324);
        assert_eq!(e.num_physical_columns(), 10);
        assert_eq!(e.physical_row_bytes(), 40);

        let layout = e.virtual_layout();
        assert_eq!(layout[320..324].len(), 4);
        assert!(layout[320..324].iter().all(|&t| t == ColumnType::B32));
    }

    // 3 + 5 = 8 > 5 declared physical columns.
    #[test]
    fn reuse_exceeds_declared() {
        let result = VirtualExpander::new()
            .expand_bits(5, ColumnType::B32)
            .reuse_pass_through(3, 5)
            .build();
        assert!(result.is_err());
    }

    // Pass-through columns can later be bit-expanded via reuse.
    #[test]
    fn reuse_expand_bits_from_pass_through() {
        let e = VirtualExpander::new()
            .pass_through(4, ColumnType::B64)
            .reuse_expand_bits(0, 4)
            .build()
            .unwrap();

        assert_eq!(e.num_physical_columns(), 4);
        assert_eq!(e.physical_row_bytes(), 32);
        assert_eq!(e.num_virtual_columns(), 4 + 256);

        let layout = e.virtual_layout();
        assert!(layout[0..4].iter().all(|&t| t == ColumnType::B64));
        assert!(layout[4..260].iter().all(|&t| t == ColumnType::Bit));
    }

    // 2 + 4 = 6 > 4 declared physical columns.
    #[test]
    fn reuse_expand_bits_exceeds_declared() {
        let result = VirtualExpander::new()
            .pass_through(4, ColumnType::B64)
            .reuse_expand_bits(2, 4)
            .build();
        assert!(result.is_err());
    }

    // B128 is not a bit-expandable storage type, even when reached via reuse.
    #[test]
    fn reuse_expand_bits_rejects_b128_source() {
        let result = VirtualExpander::new()
            .pass_through(1, ColumnType::B128)
            .reuse_expand_bits(0, 1)
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn expand_rejects_bit() {
        let result = VirtualExpander::new()
            .expand_bits(1, ColumnType::Bit)
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn expand_rejects_b128() {
        let result = VirtualExpander::new()
            .expand_bits(1, ColumnType::B128)
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn empty_expander() {
        let e = VirtualExpander::new();
        assert_eq!(e.num_virtual_columns(), 0);
        assert_eq!(e.num_physical_columns(), 0);
        assert_eq!(e.physical_row_bytes(), 0);
        assert!(e.virtual_layout().is_empty());
    }

    // Packed row: [4 bytes expanded B32][4 bytes pass-through B32][1 control byte].
    #[test]
    fn parse_row_b32_roundtrip() {
        let expander = VirtualExpander::new()
            .expand_bits(1, ColumnType::B32)
            .pass_through(1, ColumnType::B32)
            .control_bits(1)
            .build()
            .unwrap();

        let val: u32 = 0xDEAD_BEEF;
        let pass_val: u32 = 0x1234_5678;

        let mut bytes = Vec::new();
        bytes.extend_from_slice(&val.to_le_bytes());
        bytes.extend_from_slice(&pass_val.to_le_bytes());
        bytes.push(1);

        let mut res: Vec<Flat<Block128>> = Vec::new();
        expander.parse_row(&bytes, &mut res).unwrap();

        // 32 expanded bits + 1 pass-through + 1 control bit.
        assert_eq!(res.len(), 34);

        for (bit_idx, elem) in res.iter().enumerate().take(32) {
            let expected = Flat::from_raw(Block32(val)).tower_bit(bit_idx);
            let got = elem.tower_bit(0);
            assert_eq!(got, expected, "bit {bit_idx} mismatch");
        }

        let pass = res[32];
        assert_eq!(
            pass,
            <Block128 as hekate_math::FlatPromote<Block32>>::promote_flat(Flat::from_raw(Block32(
                pass_val
            )))
        );

        let ctrl = res[33].tower_bit(0);
        assert_eq!(ctrl, 1);
    }

    // The trace layout mirrors the expander's physical columns: B32, B32, Bit.
    #[test]
    fn expand_variants_b32() {
        let expander = VirtualExpander::new()
            .expand_bits(1, ColumnType::B32)
            .pass_through(1, ColumnType::B32)
            .control_bits(1)
            .build()
            .unwrap();

        let layout = [ColumnType::B32, ColumnType::B32, ColumnType::Bit];
        let num_vars = 2;

        let mut tb = TraceBuilder::new(&layout, num_vars).unwrap();
        tb.set_b32(0, 0, Block32(0xAAAA_BBBB)).unwrap();
        tb.set_b32(1, 0, Block32(0x1111_2222)).unwrap();
        tb.set_bit(2, 0, Bit(1)).unwrap();

        let trace = tb.build();

        let variants: Vec<PolyVariant<'_, Block128>> = expander.expand_variants(&trace, 0).unwrap();

        assert_eq!(variants.len(), 34);

        for (i, v) in variants.iter().enumerate().take(32) {
            assert!(matches!(v, PolyVariant::PackedBitB32 { bit_idx, .. } if *bit_idx == i));
        }

        assert!(matches!(variants[32], PolyVariant::B32Slice(_)));
        assert!(matches!(variants[33], PolyVariant::BitSlice(_)));
    }
}
849}