// axiom_circuit/input/flatten.rs

1use anyhow::Result;
2use axiom_codec::HiLo;
3use serde::{Deserialize, Serialize};
4
5use crate::{impl_input_flatten_for_fixed_array, impl_input_flatten_for_tuple};
6
7pub trait InputFlatten<T: Copy>: Sized {
8    const NUM_FE: usize;
9    fn flatten_vec(&self) -> Vec<T>;
10    fn unflatten(vec: Vec<T>) -> Result<Self>;
11}
12
13#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
14pub struct FixLenVec<T: Copy + Default, const N: usize>(pub Vec<T>);
15
16impl<T: Copy + Default, const N: usize> Default for FixLenVec<T, N> {
17    fn default() -> Self {
18        Self(vec![T::default(); N])
19    }
20}
21
22impl<T: Copy + Default, const N: usize> FixLenVec<T, N> {
23    pub fn new(vec: Vec<T>) -> anyhow::Result<Self> {
24        if vec.len() != N {
25            anyhow::bail!("Invalid input length: {} != {}", vec.len(), N);
26        }
27        Ok(FixLenVec(vec))
28    }
29
30    pub fn into_inner(self) -> Vec<T> {
31        self.0
32    }
33}
34
35impl<T: Copy + Default, const N: usize> From<Vec<T>> for FixLenVec<T, N> {
36    fn from(vec: Vec<T>) -> Self {
37        Self(vec)
38    }
39}
40
// Early-returns an `Err` from the enclosing function when `$vec` does not hold exactly
// `<Self as InputFlatten<T>>::NUM_FE` elements.
//
// The expansion names `Self` and a type parameter spelled `T`, so it is only usable inside
// an `InputFlatten<T>` impl whose element type parameter is literally called `T`.
macro_rules! check_input_length {
    ($vec:ident) => {{
        let expected_len = <Self as InputFlatten<T>>::NUM_FE;
        anyhow::ensure!(
            $vec.len() == expected_len,
            "Invalid input length: {} != {}",
            $vec.len(),
            expected_len
        );
    }};
}
52
53impl<T: Copy + Default, const N: usize> InputFlatten<T> for FixLenVec<T, N> {
54    const NUM_FE: usize = N;
55    fn flatten_vec(&self) -> Vec<T> {
56        self.0.clone()
57    }
58    fn unflatten(vec: Vec<T>) -> Result<Self> {
59        check_input_length!(vec);
60        Ok(FixLenVec(vec))
61    }
62}
63
64impl<T: Copy> InputFlatten<T> for HiLo<T> {
65    const NUM_FE: usize = 2;
66    fn flatten_vec(&self) -> Vec<T> {
67        vec![self.hi(), self.lo()]
68    }
69    fn unflatten(vec: Vec<T>) -> Result<Self> {
70        check_input_length!(vec);
71        Ok(HiLo::from_hi_lo([vec[0], vec[1]]))
72    }
73}
74
75impl_input_flatten_for_tuple!(HiLo<T>, HiLo<T>);
76
77impl<T: Copy> InputFlatten<T> for T {
78    const NUM_FE: usize = 1;
79    fn flatten_vec(&self) -> Vec<T> {
80        vec![*self]
81    }
82    fn unflatten(vec: Vec<T>) -> Result<Self> {
83        check_input_length!(vec);
84        Ok(vec[0])
85    }
86}
87
88impl_input_flatten_for_fixed_array!(T);
89impl_input_flatten_for_fixed_array!(HiLo<T>);
90
/// Implements [`InputFlatten`] for the pair `($type1, $type2)`.
///
/// The pair flattens to the elements of the first component followed by the elements of
/// the second. The expansion names `InputFlatten`, `anyhow` and `check_input_length!`
/// unqualified, so all three must be in scope where the macro is invoked.
/// NOTE(review): `check_input_length!` is a file-local `macro_rules!`. Invocations from
/// other modules or crates probably fail to resolve it — confirm before relying on the
/// export.
#[macro_export]
macro_rules! impl_input_flatten_for_tuple {
    ($type1:ty, $type2:ty) => {
        impl<T: Copy> InputFlatten<T> for ($type1, $type2)
        where
            $type1: InputFlatten<T>,
            $type2: InputFlatten<T>,
        {
            const NUM_FE: usize =
                <$type1 as InputFlatten<T>>::NUM_FE + <$type2 as InputFlatten<T>>::NUM_FE;

            fn flatten_vec(&self) -> Vec<T> {
                let (left, right) = self;
                <$type1 as InputFlatten<T>>::flatten_vec(left)
                    .into_iter()
                    .chain(<$type2 as InputFlatten<T>>::flatten_vec(right))
                    .collect()
            }

            fn unflatten(mut vec: Vec<T>) -> anyhow::Result<Self> {
                check_input_length!(vec);
                // After `split_off`, `vec` keeps the first component's slots and `tail`
                // owns the remaining ones, so no extra copies are made.
                let tail = vec.split_off(<$type1 as InputFlatten<T>>::NUM_FE);
                let left = <$type1 as InputFlatten<T>>::unflatten(vec)?;
                let right = <$type2 as InputFlatten<T>>::unflatten(tail)?;
                Ok((left, right))
            }
        }
    };
}
117
/// Implements [`InputFlatten`] for fixed-size arrays `[$type1; N]`.
///
/// An array flattens to the concatenation of its elements' flattenings, so it occupies
/// `N * <$type1 as InputFlatten<T>>::NUM_FE` slots. The expansion names `InputFlatten`,
/// `anyhow` and `check_input_length!` unqualified, so all three must be in scope where the
/// macro is invoked.
///
/// `unflatten` returns an error (instead of panicking) when an element fails to
/// unflatten. It also handles `N == 0` and zero-width element types, and it does not
/// require `$type1: Copy`.
#[macro_export]
macro_rules! impl_input_flatten_for_fixed_array {
    ($type1:ty) => {
        impl<T: Copy, const N: usize> InputFlatten<T> for [$type1; N]
        where
            $type1: InputFlatten<T>,
        {
            const NUM_FE: usize = <$type1 as InputFlatten<T>>::NUM_FE * N;

            fn flatten_vec(&self) -> Vec<T> {
                // Iterate by reference; no need to clone the whole array first.
                self.iter()
                    .flat_map(<$type1 as InputFlatten<T>>::flatten_vec)
                    .collect()
            }

            fn unflatten(vec: Vec<T>) -> anyhow::Result<Self> {
                check_input_length!(vec);
                let width = <$type1 as InputFlatten<T>>::NUM_FE;
                // The length check guarantees `vec.len() == width * N`, so every slice is in
                // bounds. Indexing by element (instead of `chunks(width)`) also stays
                // well-defined when `width == 0`, where `chunks` would panic.
                let items = (0..N)
                    .map(|i| {
                        <$type1 as InputFlatten<T>>::unflatten(
                            vec[i * width..(i + 1) * width].to_vec(),
                        )
                    })
                    .collect::<anyhow::Result<Vec<$type1>>>()?;
                // `items` has exactly `N` entries, so this conversion cannot fail in
                // practice. It also works for `N == 0` and for non-`Copy` element types.
                let count = items.len();
                <[$type1; N] as ::std::convert::TryFrom<Vec<$type1>>>::try_from(items).map_err(
                    |_| anyhow::anyhow!("expected {} array elements, got {}", N, count),
                )
            }
        }
    };
}