Skip to main content

hekate_program/
expander.rs

1// SPDX-License-Identifier: Apache-2.0
2// This file is part of the hekate project.
3// Copyright (C) 2026 Andrei Kochergin <andrei@oumuamua.dev>
4// Copyright (C) 2026 Oumuamua Labs <info@oumuamua.dev>. All rights reserved.
5//
6// Licensed under the Apache License, Version 2.0 (the "License");
7// you may not use this file except in compliance with the License.
8// You may obtain a copy of the License at
9//
10//     http://www.apache.org/licenses/LICENSE-2.0
11//
12// Unless required by applicable law or agreed to in writing, software
13// distributed under the License is distributed on an "AS IS" BASIS,
14// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15// See the License for the specific language governing permissions and
16// limitations under the License.
17
18use alloc::vec::Vec;
19use core::iter::repeat_n;
20use hekate_core::errors::Error;
21use hekate_core::poly::PolyVariant;
22use hekate_core::trace::{ColumnType, Trace, TraceColumn, TraceCompatibleField};
23use hekate_math::{Bit, Block8, Block16, Block32, Block64, Flat};
24
25/// Serializable expansion step descriptor.
26#[derive(Clone, Copy, Debug)]
27pub enum ExpansionEntry {
28    ExpandBits {
29        count: usize,
30        storage: ColumnType,
31    },
32    PassThrough {
33        count: usize,
34        storage: ColumnType,
35    },
36    ControlBits {
37        count: usize,
38    },
39    ReusePassThrough {
40        phy_col_start: usize,
41        count: usize,
42        storage: ColumnType,
43    },
44}
45
46/// Physical-to-virtual column mapping rule.
47#[derive(Clone, Copy, Debug)]
48enum EntryKind {
49    /// N physical columns to N ×
50    /// bit_width virtual Bit columns.
51    ExpandBits { count: usize, storage: ColumnType },
52
53    /// N physical columns to N virtual
54    /// columns of the same type.
55    PassThrough { count: usize, storage: ColumnType },
56
57    /// N physical Bit columns
58    /// to N virtual Bit columns.
59    ControlBits { count: usize },
60}
61
62impl EntryKind {
63    fn count(&self) -> usize {
64        match self {
65            Self::ExpandBits { count, .. }
66            | Self::PassThrough { count, .. }
67            | Self::ControlBits { count } => *count,
68        }
69    }
70
71    fn storage(&self) -> ColumnType {
72        match self {
73            Self::ExpandBits { storage, .. } | Self::PassThrough { storage, .. } => *storage,
74            Self::ControlBits { .. } => ColumnType::Bit,
75        }
76    }
77}
78
79/// Pre-computed expansion entry
80/// with frozen byte/column offsets.
81#[derive(Clone, Copy, Debug)]
82struct CompiledEntry {
83    /// Physical column index,
84    /// relative to `phy_start_idx`.
85    phy_col_start: usize,
86
87    /// Byte offset in the committed row.
88    byte_offset: usize,
89    kind: EntryKind,
90
91    /// True if this entry reuses physical
92    /// columns declared by a prior entry.
93    reuse: bool,
94}
95
96/// Declarative physical->virtual
97/// column expander for chiplets.
98///
99/// Built once per chiplet, generates
100/// `virtual_layout()`, `parse_row()`,
101/// and `expand_variants()` from the
102/// same packing specification.
103#[derive(Clone, Debug)]
104pub struct VirtualExpander {
105    entries: Vec<CompiledEntry>,
106    num_virtual: usize,
107    num_physical: usize,
108    physical_row_bytes: usize,
109    virtual_layout: Vec<ColumnType>,
110    error: Option<Error>,
111}
112
113impl VirtualExpander {
114    pub fn new() -> Self {
115        Self {
116            entries: Vec::new(),
117            num_virtual: 0,
118            num_physical: 0,
119            physical_row_bytes: 0,
120            virtual_layout: Vec::new(),
121            error: None,
122        }
123    }
124
125    /// Finalize the builder. Returns `Err` if any
126    /// builder step recorded a validation error.
127    pub fn build(self) -> Result<Self, Error> {
128        match self.error {
129            Some(e) => Err(e),
130            None => Ok(self),
131        }
132    }
133
134    /// N physical columns of `storage` type
135    /// to N × bit_width virtual Bit columns.
136    pub fn expand_bits(mut self, count: usize, storage: ColumnType) -> Self {
137        if self.error.is_some() {
138            return self;
139        }
140
141        let bits_per = match expand_bit_width(storage) {
142            Ok(v) => v,
143            Err(e) => {
144                self.error = Some(e);
145                return self;
146            }
147        };
148
149        let byte_offset = self.physical_row_bytes;
150        let phy_col_start = self.num_physical;
151
152        self.entries.push(CompiledEntry {
153            phy_col_start,
154            byte_offset,
155            kind: EntryKind::ExpandBits { count, storage },
156            reuse: false,
157        });
158
159        let virt_count = count * bits_per;
160        self.virtual_layout
161            .extend(repeat_n(ColumnType::Bit, virt_count));
162
163        self.num_virtual += virt_count;
164        self.num_physical += count;
165        self.physical_row_bytes += count * storage.byte_size();
166
167        self
168    }
169
170    /// N physical columns pass through
171    /// 1:1 as virtual columns.
172    pub fn pass_through(mut self, count: usize, storage: ColumnType) -> Self {
173        let byte_offset = self.physical_row_bytes;
174        let phy_col_start = self.num_physical;
175
176        self.entries.push(CompiledEntry {
177            phy_col_start,
178            byte_offset,
179            kind: EntryKind::PassThrough { count, storage },
180            reuse: false,
181        });
182
183        self.virtual_layout.extend(repeat_n(storage, count));
184
185        self.num_virtual += count;
186        self.num_physical += count;
187        self.physical_row_bytes += count * storage.byte_size();
188
189        self
190    }
191
192    /// N physical Bit columns pass through 1:1.
193    pub fn control_bits(mut self, count: usize) -> Self {
194        let byte_offset = self.physical_row_bytes;
195        let phy_col_start = self.num_physical;
196
197        self.entries.push(CompiledEntry {
198            phy_col_start,
199            byte_offset,
200            kind: EntryKind::ControlBits { count },
201            reuse: false,
202        });
203
204        self.virtual_layout.extend(repeat_n(ColumnType::Bit, count));
205
206        self.num_virtual += count;
207        self.num_physical += count;
208        self.physical_row_bytes += count;
209
210        self
211    }
212
213    /// Emit pass-through for columns
214    /// already declared by a prior
215    /// fresh entry. Does not advance
216    /// the physical cursor.
217    pub fn reuse_pass_through(mut self, phy_col_start: usize, count: usize) -> Self {
218        if self.error.is_some() {
219            return self;
220        }
221
222        if phy_col_start + count > self.num_physical {
223            self.error = Some(Error::Protocol {
224                protocol: "virtual_expand",
225                message: "reuse_pass_through: range exceeds declared physical columns",
226            });
227            return self;
228        }
229
230        let (byte_offset, storage) = match self.find_phy_source(phy_col_start, count) {
231            Ok(v) => v,
232            Err(e) => {
233                self.error = Some(e);
234                return self;
235            }
236        };
237
238        self.entries.push(CompiledEntry {
239            phy_col_start,
240            byte_offset,
241            kind: EntryKind::PassThrough { count, storage },
242            reuse: true,
243        });
244
245        self.virtual_layout.extend(repeat_n(storage, count));
246
247        self.num_virtual += count;
248
249        self
250    }
251
252    #[inline]
253    pub fn num_virtual_columns(&self) -> usize {
254        self.num_virtual
255    }
256
257    #[inline]
258    pub fn num_physical_columns(&self) -> usize {
259        self.num_physical
260    }
261
262    #[inline]
263    pub fn physical_row_bytes(&self) -> usize {
264        self.physical_row_bytes
265    }
266
267    #[inline]
268    pub fn virtual_layout(&self) -> &[ColumnType] {
269        &self.virtual_layout
270    }
271
272    /// Verifier-side:
273    /// parse committed physical row bytes
274    /// into virtual field elements.
275    pub fn parse_row<F: TraceCompatibleField>(
276        &self,
277        bytes: &[u8],
278        res: &mut Vec<Flat<F>>,
279    ) -> Result<(), Error> {
280        if bytes.len() != self.physical_row_bytes {
281            return Err(Error::Protocol {
282                protocol: "virtual_expand",
283                message: "parse_row: byte slice length mismatch",
284            });
285        }
286
287        res.reserve(self.num_virtual);
288
289        for entry in &self.entries {
290            let off = entry.byte_offset;
291            match entry.kind {
292                EntryKind::ExpandBits { count, storage } => {
293                    let bsz = storage.byte_size();
294                    let bits = expand_bit_width(storage)?;
295
296                    for i in 0..count {
297                        let start = off + i * bsz;
298                        for bit_idx in 0..bits {
299                            let bit = parse_tower_bit(storage, &bytes[start..start + bsz], bit_idx);
300                            res.push(Flat::from_raw(F::from(Bit::from(bit))));
301                        }
302                    }
303                }
304                EntryKind::PassThrough { count, storage } => {
305                    let bsz = storage.byte_size();
306                    for i in 0..count {
307                        let start = off + i * bsz;
308                        res.push(storage.parse_from_bytes(&bytes[start..start + bsz]));
309                    }
310                }
311                EntryKind::ControlBits { count } => {
312                    for i in 0..count {
313                        res.push(Flat::from_raw(F::from(Bit::from(bytes[off + i] & 1))));
314                    }
315                }
316            }
317        }
318
319        Ok(())
320    }
321
322    /// Prover-side:
323    /// expand physical `ColumnTrace`
324    /// into virtual `PolyVariant`s.
325    pub fn expand_variants<'a, F, T: Trace + ?Sized>(
326        &self,
327        trace: &'a T,
328        phy_start_idx: usize,
329    ) -> Result<Vec<PolyVariant<'a, F>>, Error>
330    where
331        F: TraceCompatibleField + 'static,
332    {
333        let columns = trace.columns();
334
335        let mut variants = Vec::with_capacity(self.num_virtual);
336        for entry in &self.entries {
337            let base = phy_start_idx + entry.phy_col_start;
338            match entry.kind {
339                EntryKind::ExpandBits { count, storage } => {
340                    let bits = expand_bit_width(storage)?;
341                    for i in 0..count {
342                        let col = columns.get(base + i).ok_or(Error::Protocol {
343                            protocol: "virtual_expand",
344                            message: "missing physical column for ExpandBits",
345                        })?;
346
347                        for bit_idx in 0..bits {
348                            variants.push(expand_packed_bit(col, storage, bit_idx)?);
349                        }
350                    }
351                }
352                EntryKind::PassThrough { count, storage } => {
353                    for i in 0..count {
354                        let col = columns.get(base + i).ok_or(Error::Protocol {
355                            protocol: "virtual_expand",
356                            message: "missing physical column for PassThrough",
357                        })?;
358
359                        variants.push(expand_pass_through(col, storage)?);
360                    }
361                }
362                EntryKind::ControlBits { count } => {
363                    for i in 0..count {
364                        let col = columns.get(base + i).ok_or(Error::Protocol {
365                            protocol: "virtual_expand",
366                            message: "missing physical column for ControlBits",
367                        })?;
368                        let data = col.as_bit_slice().ok_or(Error::Protocol {
369                            protocol: "virtual_expand",
370                            message: "control column must be Bit",
371                        })?;
372
373                        variants.push(PolyVariant::BitSlice(data));
374                    }
375                }
376            }
377        }
378
379        Ok(variants)
380    }
381
382    /// Wire-format serialization descriptor.
383    pub fn expansion_entries(&self) -> Vec<ExpansionEntry> {
384        self.entries
385            .iter()
386            .map(|e| match (e.kind, e.reuse) {
387                (EntryKind::PassThrough { count, storage }, true) => {
388                    ExpansionEntry::ReusePassThrough {
389                        phy_col_start: e.phy_col_start,
390                        count,
391                        storage,
392                    }
393                }
394                (EntryKind::ExpandBits { count, storage }, _) => {
395                    ExpansionEntry::ExpandBits { count, storage }
396                }
397                (EntryKind::PassThrough { count, storage }, false) => {
398                    ExpansionEntry::PassThrough { count, storage }
399                }
400                (EntryKind::ControlBits { count }, _) => ExpansionEntry::ControlBits { count },
401            })
402            .collect()
403    }
404
405    // Fresh entries have phy_col_start == running_phy;
406    // reuse entries point backward.
407    fn find_phy_source(
408        &self,
409        target_start: usize,
410        target_count: usize,
411    ) -> Result<(usize, ColumnType), Error> {
412        let mut running_phy = 0usize;
413        for entry in &self.entries {
414            if entry.phy_col_start != running_phy {
415                continue;
416            }
417
418            let entry_count = entry.kind.count();
419            let entry_end = running_phy + entry_count;
420
421            if target_start >= running_phy && target_start + target_count <= entry_end {
422                let storage = entry.kind.storage();
423                let offset_in_entry = target_start - running_phy;
424
425                return Ok((
426                    entry.byte_offset + offset_in_entry * storage.byte_size(),
427                    storage,
428                ));
429            }
430
431            running_phy = entry_end;
432        }
433
434        Err(Error::Protocol {
435            protocol: "virtual_expand",
436            message: "reuse_pass_through: source columns not found in any fresh entry",
437        })
438    }
439}
440
441impl Default for VirtualExpander {
442    fn default() -> Self {
443        Self::new()
444    }
445}
446
447fn expand_bit_width(storage: ColumnType) -> Result<usize, Error> {
448    match storage {
449        ColumnType::B8 => Ok(8),
450        ColumnType::B16 => Ok(16),
451        ColumnType::B32 => Ok(32),
452        ColumnType::B64 => Ok(64),
453        _ => Err(Error::Protocol {
454            protocol: "virtual_expand",
455            message: "ExpandBits requires B8/B16/B32/B64",
456        }),
457    }
458}
459
460/// Tower-basis bit extraction from LE bytes.
461fn parse_tower_bit(storage: ColumnType, bytes: &[u8], bit_idx: usize) -> u8 {
462    match storage {
463        ColumnType::B8 => Flat::from_raw(Block8(bytes[0])).tower_bit(bit_idx),
464        ColumnType::B16 => {
465            let mut arr = [0u8; 2];
466            arr.copy_from_slice(bytes);
467
468            Flat::from_raw(Block16(u16::from_le_bytes(arr))).tower_bit(bit_idx)
469        }
470        ColumnType::B32 => {
471            let mut arr = [0u8; 4];
472            arr.copy_from_slice(bytes);
473
474            Flat::from_raw(Block32(u32::from_le_bytes(arr))).tower_bit(bit_idx)
475        }
476        ColumnType::B64 => {
477            let mut arr = [0u8; 8];
478            arr.copy_from_slice(bytes);
479
480            Flat::from_raw(Block64(u64::from_le_bytes(arr))).tower_bit(bit_idx)
481        }
482        _ => unreachable!(),
483    }
484}
485
486fn expand_packed_bit<F: TraceCompatibleField + 'static>(
487    col: &'_ TraceColumn,
488    storage: ColumnType,
489    bit_idx: usize,
490) -> Result<PolyVariant<'_, F>, Error> {
491    match storage {
492        ColumnType::B8 => {
493            let data = col.as_b8_slice().ok_or(Error::Protocol {
494                protocol: "virtual_expand",
495                message: "ExpandBits B8: column type mismatch",
496            })?;
497
498            Ok(PolyVariant::PackedBitB8 { data, bit_idx })
499        }
500        ColumnType::B16 => {
501            let data = col.as_b16_slice().ok_or(Error::Protocol {
502                protocol: "virtual_expand",
503                message: "ExpandBits B16: column type mismatch",
504            })?;
505
506            Ok(PolyVariant::PackedBitB16 { data, bit_idx })
507        }
508        ColumnType::B32 => {
509            let data = col.as_b32_slice().ok_or(Error::Protocol {
510                protocol: "virtual_expand",
511                message: "ExpandBits B32: column type mismatch",
512            })?;
513
514            Ok(PolyVariant::PackedBitB32 { data, bit_idx })
515        }
516        ColumnType::B64 => {
517            let data = col.as_b64_slice().ok_or(Error::Protocol {
518                protocol: "virtual_expand",
519                message: "ExpandBits B64: column type mismatch",
520            })?;
521
522            Ok(PolyVariant::PackedBitB64 { data, bit_idx })
523        }
524        _ => unreachable!(),
525    }
526}
527
528fn expand_pass_through<F: TraceCompatibleField + 'static>(
529    col: &TraceColumn,
530    storage: ColumnType,
531) -> Result<PolyVariant<'_, F>, Error> {
532    match storage {
533        ColumnType::Bit => {
534            let data = col.as_bit_slice().ok_or(Error::Protocol {
535                protocol: "virtual_expand",
536                message: "PassThrough Bit: column type mismatch",
537            })?;
538
539            Ok(PolyVariant::BitSlice(data))
540        }
541        ColumnType::B8 => {
542            let data = col.as_b8_slice().ok_or(Error::Protocol {
543                protocol: "virtual_expand",
544                message: "PassThrough B8: column type mismatch",
545            })?;
546
547            Ok(PolyVariant::B8Slice(data))
548        }
549        ColumnType::B16 => {
550            let data = col.as_b16_slice().ok_or(Error::Protocol {
551                protocol: "virtual_expand",
552                message: "PassThrough B16: column type mismatch",
553            })?;
554
555            Ok(PolyVariant::B16Slice(data))
556        }
557        ColumnType::B32 => {
558            let data = col.as_b32_slice().ok_or(Error::Protocol {
559                protocol: "virtual_expand",
560                message: "PassThrough B32: column type mismatch",
561            })?;
562
563            Ok(PolyVariant::B32Slice(data))
564        }
565        ColumnType::B64 => {
566            let data = col.as_b64_slice().ok_or(Error::Protocol {
567                protocol: "virtual_expand",
568                message: "PassThrough B64: column type mismatch",
569            })?;
570
571            Ok(PolyVariant::B64Slice(data))
572        }
573        ColumnType::B128 => {
574            let data = col.as_b128_slice().ok_or(Error::Protocol {
575                protocol: "virtual_expand",
576                message: "PassThrough B128: column type mismatch",
577            })?;
578
579            Ok(PolyVariant::B128Slice(data))
580        }
581    }
582}
583
#[cfg(test)]
mod tests {
    use super::*;
    use hekate_core::trace::TraceBuilder;
    use hekate_math::Block128;

    /// True when every column type in `cols` equals `expected`.
    fn all_of(cols: &[ColumnType], expected: ColumnType) -> bool {
        cols.iter().all(|&ty| ty == expected)
    }

    // 2 expanded B32 + 13 B32 + 1 B128 + 4 control bits.
    #[test]
    fn ram_layout() {
        let expander = VirtualExpander::new()
            .expand_bits(2, ColumnType::B32)
            .pass_through(13, ColumnType::B32)
            .pass_through(1, ColumnType::B128)
            .control_bits(4)
            .build()
            .unwrap();

        assert_eq!(expander.num_virtual_columns(), 82);
        assert_eq!(expander.num_physical_columns(), 20);
        assert_eq!(expander.physical_row_bytes(), 80);

        let layout = expander.virtual_layout();
        assert_eq!(layout.len(), 82);
        assert!(all_of(&layout[..64], ColumnType::Bit));
        assert!(all_of(&layout[64..77], ColumnType::B32));
        assert_eq!(layout[77], ColumnType::B128);
        assert!(all_of(&layout[78..82], ColumnType::Bit));
    }

    // 25 + 1 expanded B64, the first 25 reused
    // as pass-through, then 2 control bits.
    #[test]
    fn keccak_layout() {
        let expander = VirtualExpander::new()
            .expand_bits(25, ColumnType::B64)
            .expand_bits(1, ColumnType::B64)
            .reuse_pass_through(0, 25)
            .control_bits(2)
            .build()
            .unwrap();

        assert_eq!(expander.num_virtual_columns(), 1691);
        assert_eq!(expander.num_physical_columns(), 28);
        assert_eq!(expander.physical_row_bytes(), 210);

        let layout = expander.virtual_layout();
        assert_eq!(layout.len(), 1691);
        assert!(all_of(&layout[..1600], ColumnType::Bit));
        assert!(all_of(&layout[1600..1664], ColumnType::Bit));
        assert!(all_of(&layout[1664..1689], ColumnType::B64));
        assert!(all_of(&layout[1689..1691], ColumnType::Bit));
    }

    // Reusing a sub-range of a fresh entry adds virtual
    // columns only; physical counters stay unchanged.
    #[test]
    fn reuse_partial_range() {
        let expander = VirtualExpander::new()
            .expand_bits(10, ColumnType::B32)
            .reuse_pass_through(3, 4)
            .build()
            .unwrap();

        assert_eq!(expander.num_virtual_columns(), 324);
        assert_eq!(expander.num_physical_columns(), 10);
        assert_eq!(expander.physical_row_bytes(), 40);

        let reused = &expander.virtual_layout()[320..324];
        assert_eq!(reused.len(), 4);
        assert!(all_of(reused, ColumnType::B32));
    }

    // Columns 3..8 do not exist when only 5 were declared.
    #[test]
    fn reuse_exceeds_declared() {
        let built = VirtualExpander::new()
            .expand_bits(5, ColumnType::B32)
            .reuse_pass_through(3, 5)
            .build();

        assert!(built.is_err());
    }

    #[test]
    fn expand_rejects_bit() {
        let built = VirtualExpander::new()
            .expand_bits(1, ColumnType::Bit)
            .build();

        assert!(built.is_err());
    }

    #[test]
    fn expand_rejects_b128() {
        let built = VirtualExpander::new()
            .expand_bits(1, ColumnType::B128)
            .build();

        assert!(built.is_err());
    }

    #[test]
    fn empty_expander() {
        let expander = VirtualExpander::new();

        assert_eq!(expander.num_virtual_columns(), 0);
        assert_eq!(expander.num_physical_columns(), 0);
        assert_eq!(expander.physical_row_bytes(), 0);
        assert!(expander.virtual_layout().is_empty());
    }

    // Row layout: 4 bytes expanded B32, 4 bytes
    // pass-through B32, 1 byte control bit.
    #[test]
    fn parse_row_b32_roundtrip() {
        let expander = VirtualExpander::new()
            .expand_bits(1, ColumnType::B32)
            .pass_through(1, ColumnType::B32)
            .control_bits(1)
            .build()
            .unwrap();

        let expanded_val: u32 = 0xDEAD_BEEF;
        let pass_val: u32 = 0x1234_5678;

        let mut row = Vec::new();
        row.extend_from_slice(&expanded_val.to_le_bytes());
        row.extend_from_slice(&pass_val.to_le_bytes());
        row.push(1);

        let mut parsed: Vec<Flat<Block128>> = Vec::new();
        expander.parse_row(&row, &mut parsed).unwrap();

        assert_eq!(parsed.len(), 34);

        // The first 32 elements are the bits of the expanded value.
        for (bit_idx, elem) in parsed.iter().enumerate().take(32) {
            let expected = Flat::from_raw(Block32(expanded_val)).tower_bit(bit_idx);
            let got = elem.tower_bit(0);
            assert_eq!(got, expected, "bit {bit_idx} mismatch");
        }

        let expected_pass = <Block128 as hekate_math::FlatPromote<Block32>>::promote_flat(
            Flat::from_raw(Block32(pass_val)),
        );
        let pass = parsed[32];
        assert_eq!(pass, expected_pass);

        let ctrl = parsed[33].tower_bit(0);
        assert_eq!(ctrl, 1);
    }

    #[test]
    fn expand_variants_b32() {
        let expander = VirtualExpander::new()
            .expand_bits(1, ColumnType::B32)
            .pass_through(1, ColumnType::B32)
            .control_bits(1)
            .build()
            .unwrap();

        let physical_layout = [ColumnType::B32, ColumnType::B32, ColumnType::Bit];
        let num_vars = 2;

        let mut builder = TraceBuilder::new(&physical_layout, num_vars).unwrap();
        builder.set_b32(0, 0, Block32(0xAAAA_BBBB)).unwrap();
        builder.set_b32(1, 0, Block32(0x1111_2222)).unwrap();
        builder.set_bit(2, 0, Bit(1)).unwrap();

        let trace = builder.build();

        let variants: Vec<PolyVariant<'_, Block128>> = expander.expand_variants(&trace, 0).unwrap();

        assert_eq!(variants.len(), 34);

        // One packed-bit view per bit of the expanded column, in bit order.
        for (expected_idx, variant) in variants.iter().enumerate().take(32) {
            assert!(matches!(
                variant,
                PolyVariant::PackedBitB32 { bit_idx, .. } if *bit_idx == expected_idx
            ));
        }

        assert!(matches!(variants[32], PolyVariant::B32Slice(_)));
        assert!(matches!(variants[33], PolyVariant::BitSlice(_)));
    }
}