celestia_types/
row.rs

//! Types related to rows.
//!
//! A row in Celestia consists of all the [`Share`]s in a particular
//! row of the [`ExtendedDataSquare`].
//!
//! [`Share`]: crate::Share
//! [`ExtendedDataSquare`]: crate::eds::ExtendedDataSquare
8
9use std::iter;
10
11use blockstore::block::CidError;
12use bytes::{Buf, BufMut, BytesMut};
13use celestia_proto::shwap::{row::HalfSide as RawHalfSide, Row as RawRow, Share as RawShare};
14use cid::CidGeneric;
15use multihash::Multihash;
16use prost::Message;
17use serde::Serialize;
18
19use crate::consts::appconsts::SHARE_SIZE;
20use crate::eds::ExtendedDataSquare;
21use crate::nmt::{Nmt, NmtExt};
22use crate::{DataAvailabilityHeader, Error, Result, Share};
23
/// Number of bytes needed to represent [`EdsId`] in `multihash`: the `u64` block height.
const EDS_ID_SIZE: usize = std::mem::size_of::<u64>();
/// Number of bytes needed to represent [`RowId`] in `multihash`: the [`EdsId`]
/// followed by the `u16` row index.
pub(crate) const ROW_ID_SIZE: usize = EDS_ID_SIZE + std::mem::size_of::<u16>();
/// The code of the [`RowId`] hashing algorithm in `multihash`.
pub const ROW_ID_MULTIHASH_CODE: u64 = 0x7801;
/// The id of the codec used for the [`RowId`] in `Cid`s.
pub const ROW_ID_CODEC: u64 = 0x7800;
32
/// Represents the EDS of the block at a specific height.
///
/// # Note
///
/// EdsId is excluded from shwap operating on top of bitswap due to possible
/// EDS sizes exceeding bitswap block limits.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
struct EdsId {
    height: u64,
}

/// Represents a particular row in a specific data square.
///
/// Equality and hashing take both the block height and the row index into
/// account, so ids can be used as keys in hash-based collections.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
pub struct RowId {
    eds_id: EdsId,
    index: u16,
}
50
51/// Row together with the data
52#[derive(Clone, Debug, Serialize)]
53#[serde(into = "RawRow")]
54pub struct Row {
55    /// Shares contained in the row
56    pub shares: Vec<Share>,
57}
58
59impl Row {
60    /// Create Row with the given index from EDS
61    pub fn new(index: u16, eds: &ExtendedDataSquare) -> Result<Self> {
62        let shares = eds.row(index)?;
63
64        Ok(Row { shares })
65    }
66
67    /// Verify the row against roots from DAH
68    pub fn verify(&self, id: RowId, dah: &DataAvailabilityHeader) -> Result<()> {
69        let row = id.index;
70        let mut tree = Nmt::default();
71
72        for share in &self.shares {
73            tree.push_leaf(share.as_ref(), *share.namespace())
74                .map_err(Error::Nmt)?;
75        }
76
77        let Some(root) = dah.row_root(row) else {
78            return Err(Error::EdsIndexOutOfRange(row, 0));
79        };
80
81        if tree.root().hash() != root.hash() {
82            return Err(Error::RootMismatch);
83        }
84
85        Ok(())
86    }
87
88    /// Encode Row into the raw binary representation.
89    pub fn encode(&self, bytes: &mut BytesMut) {
90        let raw = RawRow::from(self.clone());
91
92        bytes.reserve(raw.encoded_len());
93        raw.encode(bytes).expect("capacity reserved");
94    }
95
96    /// Decode Row from the binary representation.
97    ///
98    /// # Errors
99    ///
100    /// This function will return an error if protobuf deserialization
101    /// fails and propagate errors from [`Row::from_raw`].
102    pub fn decode(id: RowId, buffer: &[u8]) -> Result<Self> {
103        let raw = RawRow::decode(buffer)?;
104        Self::from_raw(id, raw)
105    }
106
107    /// Recover Row from it's raw representation, reconstructing the missing half
108    /// using [`leopard_codec`].
109    ///
110    /// # Errors
111    ///
112    /// This function will propagate errors from [`leopard_codec`] and [`Share`] construction.
113    pub fn from_raw(id: RowId, row: RawRow) -> Result<Self> {
114        let data_shares = row.shares_half.len();
115
116        let shares = match row.half_side() {
117            RawHalfSide::Left => {
118                // We have original data, recompute parity shares
119                let mut shares: Vec<_> = row.shares_half.into_iter().map(|shr| shr.data).collect();
120                shares.resize(shares.len() * 2, vec![0; SHARE_SIZE]);
121                leopard_codec::encode(&mut shares, data_shares)?;
122                shares
123            }
124            RawHalfSide::Right => {
125                // We have parity data, recompute original shares
126                let mut shares: Vec<_> = iter::repeat_n(vec![], data_shares)
127                    .chain(row.shares_half.into_iter().map(|shr| shr.data))
128                    .collect();
129                leopard_codec::reconstruct(&mut shares, data_shares)?;
130                shares
131            }
132        };
133
134        let row_index = id.index() as usize;
135        let shares = shares
136            .into_iter()
137            .enumerate()
138            .map(|(col_index, shr)| {
139                if row_index < data_shares && col_index < data_shares {
140                    Share::from_raw(&shr)
141                } else {
142                    Share::parity(&shr)
143                }
144            })
145            .collect::<Result<_>>()?;
146
147        Ok(Row { shares })
148    }
149}
150
151impl From<Row> for RawRow {
152    fn from(row: Row) -> RawRow {
153        // parity shares aren't transmitted over shwap, just data shares
154        let square_width = row.shares.len();
155        let shares_half = row
156            .shares
157            .into_iter()
158            .map(|shr| RawShare { data: shr.to_vec() })
159            .take(square_width / 2)
160            .collect();
161
162        RawRow {
163            shares_half,
164            half_side: RawHalfSide::Left.into(),
165        }
166    }
167}
168
169impl RowId {
170    /// Create a new [`RowId`] for the particular block.
171    ///
172    /// # Errors
173    ///
174    /// This function will return an error if the block height is invalid.
175    pub fn new(index: u16, height: u64) -> Result<Self> {
176        if height == 0 {
177            return Err(Error::ZeroBlockHeight);
178        }
179
180        Ok(Self {
181            index,
182            eds_id: EdsId { height },
183        })
184    }
185
186    /// A height of the block which contains the data.
187    pub fn block_height(&self) -> u64 {
188        self.eds_id.height
189    }
190
191    /// An index of the row in the [`ExtendedDataSquare`].
192    ///
193    /// [`ExtendedDataSquare`]: crate::eds::ExtendedDataSquare
194    pub fn index(&self) -> u16 {
195        self.index
196    }
197
198    pub(crate) fn encode(&self, bytes: &mut BytesMut) {
199        bytes.reserve(ROW_ID_SIZE);
200        bytes.put_u64(self.block_height());
201        bytes.put_u16(self.index);
202    }
203
204    pub(crate) fn decode(mut buffer: &[u8]) -> Result<Self, CidError> {
205        if buffer.len() != ROW_ID_SIZE {
206            return Err(CidError::InvalidMultihashLength(buffer.len()));
207        }
208
209        let height = buffer.get_u64();
210        let index = buffer.get_u16();
211
212        if height == 0 {
213            return Err(CidError::InvalidCid("Zero block height".to_string()));
214        }
215
216        Ok(Self {
217            eds_id: EdsId { height },
218            index,
219        })
220    }
221}
222
223impl<const S: usize> TryFrom<CidGeneric<S>> for RowId {
224    type Error = CidError;
225
226    fn try_from(cid: CidGeneric<S>) -> Result<Self, Self::Error> {
227        let codec = cid.codec();
228        if codec != ROW_ID_CODEC {
229            return Err(CidError::InvalidCidCodec(codec));
230        }
231
232        let hash = cid.hash();
233
234        let size = hash.size() as usize;
235        if size != ROW_ID_SIZE {
236            return Err(CidError::InvalidMultihashLength(size));
237        }
238
239        let code = hash.code();
240        if code != ROW_ID_MULTIHASH_CODE {
241            return Err(CidError::InvalidMultihashCode(code, ROW_ID_MULTIHASH_CODE));
242        }
243
244        RowId::decode(hash.digest())
245    }
246}
247
248impl From<RowId> for CidGeneric<ROW_ID_SIZE> {
249    fn from(row: RowId) -> Self {
250        let mut bytes = BytesMut::with_capacity(ROW_ID_SIZE);
251        row.encode(&mut bytes);
252        // length is correct, so unwrap is safe
253        let mh = Multihash::wrap(ROW_ID_MULTIHASH_CODE, &bytes[..]).unwrap();
254
255        CidGeneric::new_v1(ROW_ID_CODEC, mh)
256    }
257}
258
#[cfg(test)]
mod tests {
    use super::*;
    use crate::consts::appconsts::{AppVersion, SHARE_SIZE};
    use crate::test_utils::{generate_dummy_eds, generate_eds};
    use crate::Blob;

    #[test]
    fn round_trip_test() {
        let row_id = RowId::new(5, 100).unwrap();
        let cid = CidGeneric::from(row_id);

        let multihash = cid.hash();
        assert_eq!(multihash.code(), ROW_ID_MULTIHASH_CODE);
        assert_eq!(multihash.size(), ROW_ID_SIZE as u8);

        let deserialized_row_id = RowId::try_from(cid).unwrap();
        assert_eq!(row_id, deserialized_row_id);
    }

    #[test]
    fn index_calculation() {
        let shares = vec![vec![0; SHARE_SIZE]; 8 * 8];
        let eds = ExtendedDataSquare::new(shares, "codec".to_string(), AppVersion::V2).unwrap();

        Row::new(1, &eds).unwrap();
        Row::new(7, &eds).unwrap();
        let row_err = Row::new(8, &eds).unwrap_err();
        assert!(matches!(row_err, Error::EdsIndexOutOfRange(8, 0)));
        let row_err = Row::new(100, &eds).unwrap_err();
        assert!(matches!(row_err, Error::EdsIndexOutOfRange(100, 0)));
    }

    #[test]
    fn row_id_size() {
        // Size MUST be 10 by the spec.
        assert_eq!(ROW_ID_SIZE, 10);

        let row_id = RowId::new(0, 1).unwrap();
        let mut bytes = BytesMut::new();
        row_id.encode(&mut bytes);
        assert_eq!(bytes.len(), ROW_ID_SIZE);
    }

    #[test]
    fn from_buffer() {
        let bytes = [
            0x01, // CIDv1
            0x80, 0xF0, 0x01, // CID codec = 7800
            0x81, 0xF0, 0x01, // multihash code = 7801
            0x0A, // len = ROW_ID_SIZE = 10
            0, 0, 0, 0, 0, 0, 0, 64, // block height = 64
            0, 7, // row index = 7
        ];

        let cid = CidGeneric::<ROW_ID_SIZE>::read_bytes(bytes.as_ref()).unwrap();
        assert_eq!(cid.codec(), ROW_ID_CODEC);
        let mh = cid.hash();
        assert_eq!(mh.code(), ROW_ID_MULTIHASH_CODE);
        assert_eq!(mh.size(), ROW_ID_SIZE as u8);
        let row_id = RowId::try_from(cid).unwrap();
        assert_eq!(row_id.index, 7);
        assert_eq!(row_id.block_height(), 64);
    }

    #[test]
    fn zero_block_height() {
        let bytes = [
            0x01, // CIDv1
            0x80, 0xF0, 0x01, // CID codec = 7800
            0x81, 0xF0, 0x01, // code = 7801
            0x0A, // len = ROW_ID_SIZE = 10
            0, 0, 0, 0, 0, 0, 0, 0, // invalid block height = 0 !
            0, 7, // row index = 7
        ];

        let cid = CidGeneric::<ROW_ID_SIZE>::read_bytes(bytes.as_ref()).unwrap();
        assert_eq!(cid.codec(), ROW_ID_CODEC);
        let mh = cid.hash();
        assert_eq!(mh.code(), ROW_ID_MULTIHASH_CODE);
        assert_eq!(mh.size(), ROW_ID_SIZE as u8);
        let row_err = RowId::try_from(cid).unwrap_err();
        assert_eq!(
            row_err,
            CidError::InvalidCid("Zero block height".to_string())
        );
    }

    #[test]
    fn multihash_invalid_code() {
        let multihash = Multihash::<ROW_ID_SIZE>::wrap(999, &[0; ROW_ID_SIZE]).unwrap();
        let cid = CidGeneric::<ROW_ID_SIZE>::new_v1(ROW_ID_CODEC, multihash);
        let row_err = RowId::try_from(cid).unwrap_err();
        assert_eq!(
            row_err,
            CidError::InvalidMultihashCode(999, ROW_ID_MULTIHASH_CODE)
        );
    }

    #[test]
    fn cid_invalid_codec() {
        let multihash =
            Multihash::<ROW_ID_SIZE>::wrap(ROW_ID_MULTIHASH_CODE, &[0; ROW_ID_SIZE]).unwrap();
        let cid = CidGeneric::<ROW_ID_SIZE>::new_v1(1234, multihash);
        let row_err = RowId::try_from(cid).unwrap_err();
        assert_eq!(row_err, CidError::InvalidCidCodec(1234));
    }

    #[test]
    fn test_roundtrip_verify() {
        for _ in 0..5 {
            let eds = generate_dummy_eds(2 << (rand::random::<usize>() % 8), AppVersion::V2);
            let dah = DataAvailabilityHeader::from_eds(&eds);

            let index = rand::random::<u16>() % eds.square_width();
            let id = RowId {
                eds_id: EdsId { height: 1 },
                index,
            };

            let row = Row {
                shares: eds.row(index).unwrap(),
            };

            let mut buf = BytesMut::new();
            row.encode(&mut buf);
            let decoded = Row::decode(id, &buf).unwrap();

            decoded.verify(id, &dah).unwrap();
        }
    }

    /// Decoding a row from only its right half must recover the full row
    /// (exercises the `RawHalfSide::Right` branch of `Row::from_raw`).
    #[test]
    fn test_reconstruct_from_right_half() {
        for _ in 0..3 {
            let eds = generate_dummy_eds(2 << (rand::random::<usize>() % 6), AppVersion::V2);
            let dah = DataAvailabilityHeader::from_eds(&eds);

            let index = rand::random::<u16>() % eds.square_width();
            let id = RowId::new(index, 1).unwrap();
            let shares = eds.row(index).unwrap();

            let half = shares.len() / 2;
            let raw = RawRow {
                shares_half: shares[half..]
                    .iter()
                    .map(|shr| RawShare { data: shr.to_vec() })
                    .collect(),
                half_side: RawHalfSide::Right.into(),
            };

            let row = Row::from_raw(id, raw).unwrap();
            row.verify(id, &dah).unwrap();

            assert_eq!(row.shares.len(), shares.len());
            for (decoded, original) in row.shares.iter().zip(&shares) {
                assert_eq!(decoded.to_vec(), original.to_vec());
            }
        }
    }

    /// Verifying against a row index past the square must report that index.
    #[test]
    fn verify_out_of_range_row() {
        let eds = generate_dummy_eds(4, AppVersion::V2);
        let dah = DataAvailabilityHeader::from_eds(&eds);
        let width = eds.square_width();

        let row = Row::new(0, &eds).unwrap();
        let id = RowId::new(width, 1).unwrap();

        let err = row.verify(id, &dah).unwrap_err();
        assert!(matches!(err, Error::EdsIndexOutOfRange(idx, 0) if idx == width));
    }

    #[test]
    fn reconstruct_all() {
        for _ in 0..3 {
            let eds = generate_eds(8 << (rand::random::<usize>() % 6), AppVersion::V2);

            let rows: Vec<_> = (1..4).map(|row| Row::new(row, &eds).unwrap()).collect();
            let blobs = Blob::reconstruct_all(
                rows.iter().flat_map(|row| row.shares.iter()),
                AppVersion::V2,
            )
            .unwrap();

            assert_eq!(blobs.len(), 2);
        }
    }
}