ekzg_serialization/
lib.rs

1pub mod constants;
2pub mod errors;
3pub mod types;
4
5use bls12_381::{G1Point, Scalar};
6use constants::{
7    BYTES_PER_BLOB, BYTES_PER_CELL, BYTES_PER_FIELD_ELEMENT, BYTES_PER_G1_POINT,
8    CELLS_PER_EXT_BLOB, FIELD_ELEMENTS_PER_CELL,
9};
10use types::*;
11
12pub use crate::errors::Error as SerializationError;
13
14/// Deserializes a byte slice into a vector of `Scalar`s.
15///
16/// The input must be a multiple of the scalar size (32 bytes).
17fn deserialize_bytes_to_scalars(bytes: &[u8]) -> Result<Vec<Scalar>, SerializationError> {
18    // Check that the bytes are a multiple of the scalar size
19    if bytes.len() % BYTES_PER_FIELD_ELEMENT != 0 {
20        return Err(SerializationError::ScalarHasInvalidLength {
21            length: bytes.len(),
22            bytes: bytes.to_vec(),
23        });
24    }
25
26    bytes
27        .chunks_exact(BYTES_PER_FIELD_ELEMENT)
28        .map(deserialize_bytes_to_scalar)
29        .collect()
30}
31
32/// Deserializes a blob into a vector of `Scalar`s.
33///
34/// The blob must be exactly `BYTES_PER_BLOB` long (4096 field elements).
35/// Returns an error if the length is incorrect or parsing fails.
36pub fn deserialize_blob_to_scalars(blob_bytes: &[u8]) -> Result<Vec<Scalar>, SerializationError> {
37    if blob_bytes.len() != BYTES_PER_BLOB {
38        return Err(SerializationError::BlobHasInvalidLength {
39            length: blob_bytes.len(),
40            bytes: blob_bytes.to_vec(),
41        });
42    }
43    deserialize_bytes_to_scalars(blob_bytes)
44}
45
46/// Deserializes a 32-byte slice into a single `Scalar`.
47///
48/// This expects the input to be exactly 32 bytes.
49/// Fails if the bytes do not correspond to a valid field element.
50pub fn deserialize_bytes_to_scalar(scalar_bytes: &[u8]) -> Result<Scalar, SerializationError> {
51    let bytes32 = scalar_bytes.try_into().expect("infallible: expected blob chunks to be exactly {SCALAR_SERIALIZED_SIZE} bytes, since blob was a multiple of {SCALAR_SERIALIZED_SIZE");
52
53    // Convert the CtOption into Option
54    let option_scalar: Option<Scalar> = Scalar::from_bytes_be(bytes32).into();
55    option_scalar.map_or_else(
56        || {
57            Err(SerializationError::CouldNotDeserializeScalar {
58                bytes: scalar_bytes.to_vec(),
59            })
60        },
61        Ok,
62    )
63}
64
65/// Converts a compressed G1 point (48 bytes) to a `G1Point`.
66///
67/// Returns an error if the length is incorrect or the bytes are invalid.
68/// Wraps the `from_compressed` function from the BLS12-381 crate.
69pub fn deserialize_compressed_g1(point_bytes: &[u8]) -> Result<G1Point, SerializationError> {
70    let Ok(point_bytes) = point_bytes.try_into() else {
71        return Err(SerializationError::G1PointHasInvalidLength {
72            length: point_bytes.len(),
73            bytes: point_bytes.to_vec(),
74        });
75    };
76
77    let opt_g1: Option<G1Point> = Option::from(G1Point::from_compressed(point_bytes));
78    opt_g1.ok_or_else(|| SerializationError::CouldNotDeserializeG1Point {
79        bytes: point_bytes.to_vec(),
80    })
81}
82
83/// Serializes a G1 point into its compressed representation.
84pub fn serialize_g1_compressed(point: &G1Point) -> [u8; BYTES_PER_G1_POINT] {
85    point.to_compressed()
86}
87
88/// Deserializes a list of compressed G1 point byte slices.
89///
90/// Returns a vector of `G1Point`s or fails on the first invalid point.
91/// Each input slice must be exactly 48 bytes.
92pub fn deserialize_compressed_g1_points(
93    points: Vec<&[u8; BYTES_PER_G1_POINT]>,
94) -> Result<Vec<G1Point>, SerializationError> {
95    points
96        .into_iter()
97        .map(|point| deserialize_compressed_g1(point))
98        .collect()
99}
100
101/// Serializes a slice of `Scalar`s into a byte vector representing a cell.
102///
103/// The input must be exactly `FIELD_ELEMENTS_PER_CELL` elements long.
104/// Produces a flat byte array suitable for storage or transmission.
105pub(crate) fn serialize_scalars_to_cell(scalars: &[Scalar]) -> Vec<u8> {
106    assert_eq!(
107        scalars.len(),
108        FIELD_ELEMENTS_PER_CELL,
109        "must have exactly {FIELD_ELEMENTS_PER_CELL} scalars to serialize to a cell"
110    );
111
112    scalars.iter().flat_map(Scalar::to_bytes_be).collect()
113}
114
115/// Deserializes a vector of cell byte slices into vectors of `Scalar`s.
116///
117/// Each cell must be `BYTES_PER_CELL` bytes long.
118/// Returns an error if parsing any cell fails.
119pub fn deserialize_cells(
120    cells: Vec<&[u8; BYTES_PER_CELL]>,
121) -> Result<Vec<Vec<Scalar>>, SerializationError> {
122    cells
123        .into_iter()
124        .map(|c| deserialize_bytes_to_scalars(c))
125        .collect()
126}
127
128/// Serializes both cells and corresponding proofs into flat output formats.
129///
130/// Converts evaluation sets to `Cell`s and G1 points to `KZGProof`s.
131/// Expects exactly `CELLS_PER_EXT_BLOB` items in both inputs.
132pub fn serialize_cells_and_proofs(
133    coset_evaluations: &[Vec<Scalar>],
134    proofs: &[G1Point],
135) -> ([Cell; CELLS_PER_EXT_BLOB], [KZGProof; CELLS_PER_EXT_BLOB]) {
136    (
137        serialize_cells(coset_evaluations),
138        std::array::from_fn(|i| proofs[i].to_compressed()),
139    )
140}
141
142/// Serializes a list of evaluation sets into an array of `Cell`s.
143///
144/// Each set must contain exactly `FIELD_ELEMENTS_PER_CELL` scalars.
145/// Returns a fixed-size array with length `CELLS_PER_EXT_BLOB`.
146pub fn serialize_cells(coset_evaluations: &[Vec<Scalar>]) -> [Cell; CELLS_PER_EXT_BLOB] {
147    // Serialize the evaluation sets into `Cell`s.
148    std::array::from_fn(|i| {
149        let evals = &coset_evaluations[i];
150        let bytes = serialize_scalars_to_cell(evals);
151        bytes
152            .into_boxed_slice()
153            .try_into()
154            .expect("infallible: serialized cell must be BYTES_PER_CELL long")
155    })
156}
157
158/// Serialization methods that are used for the trusted setup
159pub mod trusted_setup {
160    use bls12_381::{G1Point, G2Point};
161
162    /// An enum used to specify whether to check that the points are in the correct subgroup
163    #[derive(Debug, Copy, Clone)]
164    pub enum SubgroupCheck {
165        /// Enforce subgroup membership checks during deserialization.
166        Check,
167        /// Skip subgroup checks (use only when inputs are trusted).
168        NoCheck,
169    }
170
171    /// Deserialize G1 points from hex strings without checking that the element
172    /// is in the correct subgroup.
173    pub fn deserialize_g1_points<T: AsRef<str>>(
174        g1_points_hex_str: &[T],
175        check: SubgroupCheck,
176    ) -> Vec<G1Point> {
177        g1_points_hex_str
178            .iter()
179            .map(|hex_str| {
180                let hex_str = hex_str
181                    .as_ref()
182                    .strip_prefix("0x")
183                    .expect("expected hex points to be prefixed with `0x`");
184
185                let bytes = hex::decode(hex_str)
186                    .expect("trusted setup has malformed g1 points")
187                    .try_into()
188                    .expect("expected 48 bytes for G1 point");
189
190                match check {
191                    SubgroupCheck::Check => G1Point::from_compressed(&bytes),
192                    SubgroupCheck::NoCheck => G1Point::from_compressed_unchecked(&bytes),
193                }
194                .expect("invalid g1 point")
195            })
196            .collect()
197    }
198
199    /// Deserialize G2 points from hex strings without checking that the element
200    /// is in the correct subgroup.
201    pub fn deserialize_g2_points<T: AsRef<str>>(
202        g2_points_hex_str: &[T],
203        subgroup_check: SubgroupCheck,
204    ) -> Vec<G2Point> {
205        g2_points_hex_str
206            .iter()
207            .map(|hex_str| {
208                let hex_str = hex_str
209                    .as_ref()
210                    .strip_prefix("0x")
211                    .expect("expected hex points to be prefixed with `0x`");
212
213                let bytes: [u8; 96] = hex::decode(hex_str)
214                    .expect("trusted setup has malformed g2 points")
215                    .try_into()
216                    .expect("expected 96 bytes for G2 point");
217
218                match subgroup_check {
219                    SubgroupCheck::Check => G2Point::from_compressed(&bytes),
220                    SubgroupCheck::NoCheck => G2Point::from_compressed_unchecked(&bytes),
221                }
222                .expect("invalid g2 point")
223            })
224            .collect()
225    }
226}
227
#[cfg(test)]
mod tests {
    use bls12_381::{traits::*, G1Point, G1Projective, Scalar};
    use rand::thread_rng;

    use super::*;
    use crate::constants::FIELD_ELEMENTS_PER_BLOB;

    /// Returns a randomly generated scalar field element.
    fn random_scalar() -> Scalar {
        Scalar::random(thread_rng())
    }

    /// Returns a random scalar serialized to `BYTES_PER_FIELD_ELEMENT` big-endian bytes.
    fn scalar_bytes() -> [u8; BYTES_PER_FIELD_ELEMENT] {
        random_scalar().to_bytes_be()
    }

    /// Constructs a valid blob by repeating a random scalar `FIELD_ELEMENTS_PER_BLOB` times.
    fn valid_blob() -> Vec<u8> {
        scalar_bytes().repeat(FIELD_ELEMENTS_PER_BLOB)
    }

    /// Constructs a valid cell by repeating a random scalar `FIELD_ELEMENTS_PER_CELL` times.
    fn valid_cell() -> [u8; BYTES_PER_CELL] {
        scalar_bytes()
            .repeat(FIELD_ELEMENTS_PER_CELL)
            .try_into()
            .unwrap()
    }

    /// Returns `CELLS_PER_EXT_BLOB` evaluation sets, each of `FIELD_ELEMENTS_PER_CELL` random scalars.
    fn random_evaluations() -> Vec<Vec<Scalar>> {
        (0..CELLS_PER_EXT_BLOB)
            .map(|_| {
                (0..FIELD_ELEMENTS_PER_CELL)
                    .map(|_| random_scalar())
                    .collect::<Vec<_>>()
            })
            .collect()
    }

    /// Returns the G1 generator as a known-valid point.
    fn generator() -> G1Point {
        G1Point::from(G1Projective::generator())
    }

    #[test]
    fn test_deserialize_scalar_valid() {
        let bytes = scalar_bytes();
        let scalar = deserialize_bytes_to_scalar(&bytes).unwrap();
        assert_eq!(scalar.to_bytes_be(), bytes);
    }

    #[test]
    #[should_panic]
    fn test_deserialize_scalar_invalid_length() {
        let bytes = vec![1u8; 31]; // invalid
        let _ = deserialize_bytes_to_scalar(&bytes);
    }

    #[test]
    fn test_deserialize_scalar_non_canonical() {
        // All-0xff bytes exceed the scalar field modulus, so they are not a valid field element.
        let bytes = [0xffu8; BYTES_PER_FIELD_ELEMENT];
        assert!(matches!(
            deserialize_bytes_to_scalar(&bytes),
            Err(SerializationError::CouldNotDeserializeScalar { .. })
        ));
    }

    #[test]
    fn test_deserialize_blob_to_scalars_valid() {
        let blob = valid_blob();
        let scalars = deserialize_blob_to_scalars(&blob).unwrap();
        assert_eq!(scalars.len(), FIELD_ELEMENTS_PER_BLOB);
    }

    #[test]
    fn test_deserialize_blob_to_scalars_invalid_length() {
        let blob = vec![0u8; BYTES_PER_BLOB - 1];
        assert!(matches!(
            deserialize_blob_to_scalars(&blob),
            Err(SerializationError::BlobHasInvalidLength { .. })
        ));
    }

    #[test]
    fn test_deserialize_bytes_to_scalars_valid() {
        let cell = valid_cell();
        let scalars = deserialize_bytes_to_scalars(&cell).unwrap();
        assert_eq!(scalars.len(), FIELD_ELEMENTS_PER_CELL);
    }

    #[test]
    fn test_deserialize_bytes_to_scalars_invalid_length() {
        let bytes = vec![0u8; BYTES_PER_FIELD_ELEMENT + 1];
        assert!(matches!(
            deserialize_bytes_to_scalars(&bytes),
            Err(SerializationError::ScalarHasInvalidLength { .. })
        ));
    }

    #[test]
    fn test_deserialize_cells_valid() {
        let cells = [valid_cell(), valid_cell()];
        let scalars = deserialize_cells(cells.iter().collect()).unwrap();
        assert_eq!(scalars.len(), 2);
        assert!(scalars.iter().all(|s| s.len() == FIELD_ELEMENTS_PER_CELL));
    }

    #[test]
    fn test_serialize_scalars_to_cell_and_back() {
        let scalars: Vec<_> = (0..FIELD_ELEMENTS_PER_CELL)
            .map(|_| random_scalar())
            .collect();
        let cell_bytes = serialize_scalars_to_cell(&scalars);
        let scalars_back = deserialize_bytes_to_scalars(&cell_bytes).unwrap();
        assert_eq!(scalars, scalars_back);
    }

    #[test]
    fn test_serialize_deserialize_g1_point() {
        let point = generator();
        let compressed = point.to_compressed();
        let decompressed = deserialize_compressed_g1(&compressed).unwrap();
        assert_eq!(decompressed, point);
    }

    #[test]
    fn test_deserialize_compressed_g1_invalid_length() {
        let bad_bytes = vec![0u8; 47];
        assert!(matches!(
            deserialize_compressed_g1(&bad_bytes),
            Err(SerializationError::G1PointHasInvalidLength { .. })
        ));
    }

    #[test]
    fn test_deserialize_compressed_g1_invalid_encoding() {
        // The compression flag (top bit) is unset, so this is not a valid compressed encoding.
        let bad_bytes = [0u8; BYTES_PER_G1_POINT];
        assert!(matches!(
            deserialize_compressed_g1(&bad_bytes),
            Err(SerializationError::CouldNotDeserializeG1Point { .. })
        ));
    }

    #[test]
    fn test_deserialize_compressed_g1_points_round_trip() {
        let point = generator();
        let compressed = serialize_g1_compressed(&point);
        let points = deserialize_compressed_g1_points(vec![&compressed, &compressed]).unwrap();
        assert_eq!(points.len(), 2);
        assert!(points.iter().all(|p| *p == point));
    }

    #[test]
    fn test_coset_evaluations_to_cells() {
        let evaluations = random_evaluations();
        let cells = serialize_cells(&evaluations);
        assert_eq!(cells.len(), CELLS_PER_EXT_BLOB);
        for cell in &cells {
            assert_eq!(cell.len(), BYTES_PER_CELL);
        }
    }

    #[test]
    fn test_serialize_cells_round_trip() {
        let evaluations = random_evaluations();
        let cells = serialize_cells(&evaluations);
        for (cell, evals) in cells.iter().zip(&evaluations) {
            let scalars_back = deserialize_bytes_to_scalars(&cell[..]).unwrap();
            assert_eq!(&scalars_back, evals);
        }
    }

    #[test]
    fn test_serialize_cells_and_proofs() {
        let evaluations = random_evaluations();
        let proofs: Vec<_> = (0..CELLS_PER_EXT_BLOB).map(|_| generator()).collect();

        let (cells, proofs) = serialize_cells_and_proofs(&evaluations, &proofs);
        assert_eq!(cells.len(), CELLS_PER_EXT_BLOB);
        assert_eq!(proofs.len(), CELLS_PER_EXT_BLOB);

        let expected_proof = serialize_g1_compressed(&generator());
        assert!(proofs.iter().all(|proof| *proof == expected_proof));
    }
}