axiom_circuit/input/
flatten.rs1use anyhow::Result;
2use axiom_codec::HiLo;
3use serde::{Deserialize, Serialize};
4
5use crate::{impl_input_flatten_for_fixed_array, impl_input_flatten_for_tuple};
6
7pub trait InputFlatten<T: Copy>: Sized {
8 const NUM_FE: usize;
9 fn flatten_vec(&self) -> Vec<T>;
10 fn unflatten(vec: Vec<T>) -> Result<Self>;
11}
12
13#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
14pub struct FixLenVec<T: Copy + Default, const N: usize>(pub Vec<T>);
15
16impl<T: Copy + Default, const N: usize> Default for FixLenVec<T, N> {
17 fn default() -> Self {
18 Self(vec![T::default(); N])
19 }
20}
21
22impl<T: Copy + Default, const N: usize> FixLenVec<T, N> {
23 pub fn new(vec: Vec<T>) -> anyhow::Result<Self> {
24 if vec.len() != N {
25 anyhow::bail!("Invalid input length: {} != {}", vec.len(), N);
26 }
27 Ok(FixLenVec(vec))
28 }
29
30 pub fn into_inner(self) -> Vec<T> {
31 self.0
32 }
33}
34
35impl<T: Copy + Default, const N: usize> From<Vec<T>> for FixLenVec<T, N> {
36 fn from(vec: Vec<T>) -> Self {
37 Self(vec)
38 }
39}
40
/// Returns early with an `anyhow` error when `$vec.len()` differs from
/// `<Self as InputFlatten<T>>::NUM_FE`.
///
/// Must be expanded inside an `impl ... InputFlatten<T> for ...` method that returns an
/// `anyhow::Result`, since it names `Self` and `T` from the surrounding impl.
macro_rules! check_input_length {
    ($vec:ident) => {{
        let (actual, expected) = ($vec.len(), <Self as InputFlatten<T>>::NUM_FE);
        if actual != expected {
            anyhow::bail!("Invalid input length: {} != {}", actual, expected);
        }
    }};
}
52
53impl<T: Copy + Default, const N: usize> InputFlatten<T> for FixLenVec<T, N> {
54 const NUM_FE: usize = N;
55 fn flatten_vec(&self) -> Vec<T> {
56 self.0.clone()
57 }
58 fn unflatten(vec: Vec<T>) -> Result<Self> {
59 check_input_length!(vec);
60 Ok(FixLenVec(vec))
61 }
62}
63
64impl<T: Copy> InputFlatten<T> for HiLo<T> {
65 const NUM_FE: usize = 2;
66 fn flatten_vec(&self) -> Vec<T> {
67 vec![self.hi(), self.lo()]
68 }
69 fn unflatten(vec: Vec<T>) -> Result<Self> {
70 check_input_length!(vec);
71 Ok(HiLo::from_hi_lo([vec[0], vec[1]]))
72 }
73}
74
75impl_input_flatten_for_tuple!(HiLo<T>, HiLo<T>);
76
77impl<T: Copy> InputFlatten<T> for T {
78 const NUM_FE: usize = 1;
79 fn flatten_vec(&self) -> Vec<T> {
80 vec![*self]
81 }
82 fn unflatten(vec: Vec<T>) -> Result<Self> {
83 check_input_length!(vec);
84 Ok(vec[0])
85 }
86}
87
88impl_input_flatten_for_fixed_array!(T);
89impl_input_flatten_for_fixed_array!(HiLo<T>);
90
/// Implements `InputFlatten<T>` for the pair `($type1, $type2)`.
///
/// The pair flattens to the elements of `self.0` followed by those of `self.1`, and
/// `NUM_FE` is the sum of the two components' `NUM_FE`.
///
/// The expansion names `InputFlatten` unqualified and uses the `anyhow` crate, so both
/// must be reachable where the macro is invoked. The length check is written out inline
/// rather than via the module-private `check_input_length!`, so this exported macro does
/// not depend on a macro that is only textually in scope inside this module.
#[macro_export]
macro_rules! impl_input_flatten_for_tuple {
    ($type1:ty, $type2:ty) => {
        impl<T: Copy> InputFlatten<T> for ($type1, $type2)
        where
            $type1: InputFlatten<T>,
            $type2: InputFlatten<T>,
        {
            // Fully qualified so the trait's element type is never ambiguous when a
            // component implements `InputFlatten` for more than one element type.
            const NUM_FE: usize =
                <$type1 as InputFlatten<T>>::NUM_FE + <$type2 as InputFlatten<T>>::NUM_FE;

            fn flatten_vec(&self) -> Vec<T> {
                let mut out = <$type1 as InputFlatten<T>>::flatten_vec(&self.0);
                out.extend(<$type2 as InputFlatten<T>>::flatten_vec(&self.1));
                out
            }

            fn unflatten(mut vec: Vec<T>) -> anyhow::Result<Self> {
                let expected = <Self as InputFlatten<T>>::NUM_FE;
                if vec.len() != expected {
                    anyhow::bail!("Invalid input length: {} != {}", vec.len(), expected);
                }
                // Split the owned vector in place instead of copying both halves.
                let second_part = vec.split_off(<$type1 as InputFlatten<T>>::NUM_FE);
                let first = <$type1 as InputFlatten<T>>::unflatten(vec)?;
                let second = <$type2 as InputFlatten<T>>::unflatten(second_part)?;
                Ok((first, second))
            }
        }
    };
}
116}
117
/// Implements `InputFlatten<T>` for every fixed-size array `[$type1; N]`.
///
/// The array flattens to the concatenation of its elements' flattened forms, and
/// `NUM_FE` is the element's `NUM_FE` times `N`.
///
/// The expansion names `InputFlatten` unqualified and uses the `anyhow` crate, so both
/// must be reachable where the macro is invoked. The length check is written out inline
/// rather than via the module-private `check_input_length!`, so this exported macro does
/// not depend on a macro that is only textually in scope inside this module.
#[macro_export]
macro_rules! impl_input_flatten_for_fixed_array {
    ($type1:ty) => {
        impl<T: Copy, const N: usize> InputFlatten<T> for [$type1; N]
        where
            $type1: InputFlatten<T>,
        {
            const NUM_FE: usize = <$type1 as InputFlatten<T>>::NUM_FE * N;

            fn flatten_vec(&self) -> Vec<T> {
                // Iterate the array by reference; no intermediate copy of the array.
                self.iter()
                    .flat_map(|item| <$type1 as InputFlatten<T>>::flatten_vec(item))
                    .collect()
            }

            fn unflatten(vec: Vec<T>) -> anyhow::Result<Self> {
                let expected = <Self as InputFlatten<T>>::NUM_FE;
                if vec.len() != expected {
                    anyhow::bail!("Invalid input length: {} != {}", vec.len(), expected);
                }
                let width = <$type1 as InputFlatten<T>>::NUM_FE;
                let mut elems = vec.into_iter();
                // Take exactly `width` elements per array slot. Unlike `chunks(width)`
                // this also works for `width == 0` and for `N == 0`. Element errors are
                // propagated instead of panicking.
                let items = (0..N)
                    .map(|_| {
                        let chunk: Vec<T> = elems.by_ref().take(width).collect();
                        <$type1 as InputFlatten<T>>::unflatten(chunk)
                    })
                    .collect::<anyhow::Result<Vec<$type1>>>()?;
                // `Vec<U>` -> `[U; N]` needs neither `U: Copy` nor a seed element, so an
                // empty array (`N == 0`) is handled without indexing.
                let array: [$type1; N] = ::core::convert::TryInto::try_into(items).map_err(
                    |items: Vec<$type1>| {
                        anyhow::anyhow!("Invalid array length: {} != {}", items.len(), N)
                    },
                )?;
                Ok(array)
            }
        }
    };
}
150}