celestia_types/
row.rs

1//! Types related to rows
2//!
3//! Row in Celestia is understood as all the [`Share`]s in a particular
4//! row of the [`ExtendedDataSquare`].
5//!
6//! [`Share`]: crate::Share
7//! [`ExtendedDataSquare`]: crate::eds::ExtendedDataSquare
8
9use std::iter;
10
11use blockstore::block::CidError;
12use bytes::{Buf, BufMut, BytesMut};
13use celestia_proto::shwap::{row::HalfSide as RawHalfSide, Row as RawRow, Share as RawShare};
14use cid::CidGeneric;
15use multihash::Multihash;
16use prost::Message;
17use serde::Serialize;
18
19use crate::consts::appconsts::SHARE_SIZE;
20use crate::eds::ExtendedDataSquare;
21use crate::nmt::{Nmt, NmtExt};
22use crate::{DataAvailabilityHeader, Error, Result, Share};
23
/// Byte length of an [`EdsId`] inside a `multihash` digest: the `u64` block height.
const EDS_ID_SIZE: usize = std::mem::size_of::<u64>();
/// Byte length of a [`RowId`] inside a `multihash` digest: the [`EdsId`] bytes
/// followed by the `u16` row index.
pub(crate) const ROW_ID_SIZE: usize = EDS_ID_SIZE + std::mem::size_of::<u16>();
/// Code identifying the [`RowId`] hashing algorithm in `multihash`.
pub const ROW_ID_MULTIHASH_CODE: u64 = 0x7801;
/// Codec id used for the [`RowId`] in `Cid`s.
pub const ROW_ID_CODEC: u64 = 0x7800;
32
/// Identifies the EDS of a block by that block's height.
///
/// # Note
///
/// EdsId is excluded from shwap operating on top of bitswap due to possible
/// EDS sizes exceeding bitswap block limits.
#[derive(Clone, Copy, Debug, PartialEq)]
struct EdsId {
    height: u64,
}
43
44/// Represents particular row in a specific Data Square,
45#[derive(Debug, PartialEq, Clone, Copy)]
46pub struct RowId {
47    eds_id: EdsId,
48    index: u16,
49}
50
51/// Row together with the data
52#[derive(Clone, Debug, Serialize)]
53#[serde(into = "RawRow")]
54pub struct Row {
55    /// Shares contained in the row
56    pub shares: Vec<Share>,
57}
58
59impl Row {
60    /// Create Row with the given index from EDS
61    pub fn new(index: u16, eds: &ExtendedDataSquare) -> Result<Self> {
62        let shares = eds.row(index)?;
63
64        Ok(Row { shares })
65    }
66
67    /// Verify the row against roots from DAH
68    pub fn verify(&self, id: RowId, dah: &DataAvailabilityHeader) -> Result<()> {
69        let row = id.index;
70        let mut tree = Nmt::default();
71
72        for share in &self.shares {
73            tree.push_leaf(share.as_ref(), *share.namespace())
74                .map_err(Error::Nmt)?;
75        }
76
77        let Some(root) = dah.row_root(row) else {
78            return Err(Error::EdsIndexOutOfRange(row, 0));
79        };
80
81        if tree.root().hash() != root.hash() {
82            return Err(Error::RootMismatch);
83        }
84
85        Ok(())
86    }
87
88    /// Encode Row into the raw binary representation.
89    pub fn encode(&self, bytes: &mut BytesMut) {
90        let raw = RawRow::from(self.clone());
91
92        bytes.reserve(raw.encoded_len());
93        raw.encode(bytes).expect("capacity reserved");
94    }
95
96    /// Decode Row from the binary representation.
97    ///
98    /// # Errors
99    ///
100    /// This function will return an error if protobuf deserialization
101    /// fails and propagate errors from [`Row::from_raw`].
102    pub fn decode(id: RowId, buffer: &[u8]) -> Result<Self> {
103        let raw = RawRow::decode(buffer)?;
104        Self::from_raw(id, raw)
105    }
106
107    /// Recover Row from it's raw representation, reconstructing the missing half
108    /// using [`leopard_codec`].
109    ///
110    /// # Errors
111    ///
112    /// This function will propagate errors from [`leopard_codec`] and [`Share`] construction.
113    pub fn from_raw(id: RowId, row: RawRow) -> Result<Self> {
114        let data_shares = row.shares_half.len();
115
116        let shares = match row.half_side() {
117            RawHalfSide::Left => {
118                // We have original data, recompute parity shares
119                let mut shares: Vec<_> = row.shares_half.into_iter().map(|shr| shr.data).collect();
120                shares.resize(shares.len() * 2, vec![0; SHARE_SIZE]);
121                leopard_codec::encode(&mut shares, data_shares)?;
122                shares
123            }
124            RawHalfSide::Right => {
125                // We have parity data, recompute original shares
126                let mut shares: Vec<_> = iter::repeat(vec![])
127                    .take(data_shares)
128                    .chain(row.shares_half.into_iter().map(|shr| shr.data))
129                    .collect();
130                leopard_codec::reconstruct(&mut shares, data_shares)?;
131                shares
132            }
133        };
134
135        let row_index = id.index() as usize;
136        let shares = shares
137            .into_iter()
138            .enumerate()
139            .map(|(col_index, shr)| {
140                if row_index < data_shares && col_index < data_shares {
141                    Share::from_raw(&shr)
142                } else {
143                    Share::parity(&shr)
144                }
145            })
146            .collect::<Result<_>>()?;
147
148        Ok(Row { shares })
149    }
150}
151
152impl From<Row> for RawRow {
153    fn from(row: Row) -> RawRow {
154        // parity shares aren't transmitted over shwap, just data shares
155        let square_width = row.shares.len();
156        let shares_half = row
157            .shares
158            .into_iter()
159            .map(|shr| RawShare { data: shr.to_vec() })
160            .take(square_width / 2)
161            .collect();
162
163        RawRow {
164            shares_half,
165            half_side: RawHalfSide::Left.into(),
166        }
167    }
168}
169
170impl RowId {
171    /// Create a new [`RowId`] for the particular block.
172    ///
173    /// # Errors
174    ///
175    /// This function will return an error if the block height is invalid.
176    pub fn new(index: u16, height: u64) -> Result<Self> {
177        if height == 0 {
178            return Err(Error::ZeroBlockHeight);
179        }
180
181        Ok(Self {
182            index,
183            eds_id: EdsId { height },
184        })
185    }
186
187    /// A height of the block which contains the data.
188    pub fn block_height(&self) -> u64 {
189        self.eds_id.height
190    }
191
192    /// An index of the row in the [`ExtendedDataSquare`].
193    ///
194    /// [`ExtendedDataSquare`]: crate::eds::ExtendedDataSquare
195    pub fn index(&self) -> u16 {
196        self.index
197    }
198
199    pub(crate) fn encode(&self, bytes: &mut BytesMut) {
200        bytes.reserve(ROW_ID_SIZE);
201        bytes.put_u64(self.block_height());
202        bytes.put_u16(self.index);
203    }
204
205    pub(crate) fn decode(mut buffer: &[u8]) -> Result<Self, CidError> {
206        if buffer.len() != ROW_ID_SIZE {
207            return Err(CidError::InvalidMultihashLength(buffer.len()));
208        }
209
210        let height = buffer.get_u64();
211        let index = buffer.get_u16();
212
213        if height == 0 {
214            return Err(CidError::InvalidCid("Zero block height".to_string()));
215        }
216
217        Ok(Self {
218            eds_id: EdsId { height },
219            index,
220        })
221    }
222}
223
224impl<const S: usize> TryFrom<CidGeneric<S>> for RowId {
225    type Error = CidError;
226
227    fn try_from(cid: CidGeneric<S>) -> Result<Self, Self::Error> {
228        let codec = cid.codec();
229        if codec != ROW_ID_CODEC {
230            return Err(CidError::InvalidCidCodec(codec));
231        }
232
233        let hash = cid.hash();
234
235        let size = hash.size() as usize;
236        if size != ROW_ID_SIZE {
237            return Err(CidError::InvalidMultihashLength(size));
238        }
239
240        let code = hash.code();
241        if code != ROW_ID_MULTIHASH_CODE {
242            return Err(CidError::InvalidMultihashCode(code, ROW_ID_MULTIHASH_CODE));
243        }
244
245        RowId::decode(hash.digest())
246    }
247}
248
249impl From<RowId> for CidGeneric<ROW_ID_SIZE> {
250    fn from(row: RowId) -> Self {
251        let mut bytes = BytesMut::with_capacity(ROW_ID_SIZE);
252        row.encode(&mut bytes);
253        // length is correct, so unwrap is safe
254        let mh = Multihash::wrap(ROW_ID_MULTIHASH_CODE, &bytes[..]).unwrap();
255
256        CidGeneric::new_v1(ROW_ID_CODEC, mh)
257    }
258}
259
#[cfg(test)]
mod tests {
    use super::*;
    use crate::consts::appconsts::{AppVersion, SHARE_SIZE};
    use crate::test_utils::{generate_dummy_eds, generate_eds};
    use crate::Blob;

    /// RowId -> Cid -> RowId must give back the same id.
    #[test]
    fn round_trip_test() {
        let original = RowId::new(5, 100).unwrap();
        let cid = CidGeneric::from(original);

        let mh = cid.hash();
        assert_eq!(mh.code(), ROW_ID_MULTIHASH_CODE);
        assert_eq!(mh.size(), ROW_ID_SIZE as u8);

        let restored = RowId::try_from(cid).unwrap();
        assert_eq!(original, restored);
    }

    /// Rows inside the square can be created, rows outside of it are rejected.
    #[test]
    fn index_calculation() {
        let raw_shares = vec![vec![0; SHARE_SIZE]; 8 * 8];
        let eds =
            ExtendedDataSquare::new(raw_shares, "codec".to_string(), AppVersion::V2).unwrap();

        for in_range in [1, 7] {
            Row::new(in_range, &eds).unwrap();
        }

        let err = Row::new(8, &eds).unwrap_err();
        assert!(matches!(err, Error::EdsIndexOutOfRange(8, 0)));
        let err = Row::new(100, &eds).unwrap_err();
        assert!(matches!(err, Error::EdsIndexOutOfRange(100, 0)));
    }

    /// The encoded id must take exactly `ROW_ID_SIZE` bytes.
    #[test]
    fn row_id_size() {
        // Size MUST be 10 by the spec.
        assert_eq!(ROW_ID_SIZE, 10);

        let id = RowId::new(0, 1).unwrap();
        let mut encoded = BytesMut::new();
        id.encode(&mut encoded);
        assert_eq!(encoded.len(), ROW_ID_SIZE);
    }

    /// A hand-written serialized CID decodes into the expected id.
    #[test]
    fn from_buffer() {
        let serialized = [
            0x01, // CIDv1
            0x80, 0xF0, 0x01, // CID codec = 7800
            0x81, 0xF0, 0x01, // multihash code = 7801
            0x0A, // len = ROW_ID_SIZE = 10
            0, 0, 0, 0, 0, 0, 0, 64, // block height = 64
            0, 7, // row index = 7
        ];

        let cid = CidGeneric::<ROW_ID_SIZE>::read_bytes(serialized.as_ref()).unwrap();
        assert_eq!(cid.codec(), ROW_ID_CODEC);
        let multihash = cid.hash();
        assert_eq!(multihash.code(), ROW_ID_MULTIHASH_CODE);
        assert_eq!(multihash.size(), ROW_ID_SIZE as u8);
        let id = RowId::try_from(cid).unwrap();
        assert_eq!(id.index, 7);
        assert_eq!(id.block_height(), 64);
    }

    /// A well-formed CID carrying block height 0 must be rejected.
    #[test]
    fn zero_block_height() {
        let serialized = [
            0x01, // CIDv1
            0x80, 0xF0, 0x01, // CID codec = 7800
            0x81, 0xF0, 0x01, // code = 7801
            0x0A, // len = ROW_ID_SIZE = 10
            0, 0, 0, 0, 0, 0, 0, 0, // invalid block height = 0 !
            0, 7, // row index = 7
        ];

        let cid = CidGeneric::<ROW_ID_SIZE>::read_bytes(serialized.as_ref()).unwrap();
        assert_eq!(cid.codec(), ROW_ID_CODEC);
        let multihash = cid.hash();
        assert_eq!(multihash.code(), ROW_ID_MULTIHASH_CODE);
        assert_eq!(multihash.size(), ROW_ID_SIZE as u8);
        let err = RowId::try_from(cid).unwrap_err();
        assert_eq!(err, CidError::InvalidCid("Zero block height".to_string()));
    }

    /// An unexpected multihash code is reported together with the expected one.
    #[test]
    fn multihash_invalid_code() {
        let mh = Multihash::<ROW_ID_SIZE>::wrap(999, &[0; ROW_ID_SIZE]).unwrap();
        let cid = CidGeneric::<ROW_ID_SIZE>::new_v1(ROW_ID_CODEC, mh);
        let err = RowId::try_from(cid).unwrap_err();
        assert_eq!(
            err,
            CidError::InvalidMultihashCode(999, ROW_ID_MULTIHASH_CODE)
        );
    }

    /// An unexpected CID codec is rejected.
    #[test]
    fn cid_invalid_codec() {
        let mh = Multihash::<ROW_ID_SIZE>::wrap(ROW_ID_MULTIHASH_CODE, &[0; ROW_ID_SIZE]).unwrap();
        let cid = CidGeneric::<ROW_ID_SIZE>::new_v1(1234, mh);
        let err = RowId::try_from(cid).unwrap_err();
        assert_eq!(err, CidError::InvalidCidCodec(1234));
    }

    /// A row that went through encode + decode still verifies against the DAH.
    #[test]
    fn test_roundtrip_verify() {
        for _ in 0..5 {
            let width = 2 << (rand::random::<usize>() % 8);
            let eds = generate_dummy_eds(width, AppVersion::V2);
            let dah = DataAvailabilityHeader::from_eds(&eds);

            let index = rand::random::<u16>() % eds.square_width();
            let id = RowId {
                eds_id: EdsId { height: 1 },
                index,
            };

            let original = Row {
                shares: eds.row(index).unwrap(),
            };

            let mut encoded = BytesMut::new();
            original.encode(&mut encoded);
            let decoded = Row::decode(id, &encoded).unwrap();

            decoded.verify(id, &dah).unwrap();
        }
    }

    /// Blobs can be reconstructed from the shares of rows 1, 2 and 3.
    #[test]
    fn reconstruct_all() {
        for _ in 0..3 {
            let width = 8 << (rand::random::<usize>() % 6);
            let eds = generate_eds(width, AppVersion::V2);

            let rows: Vec<_> = (1..4)
                .map(|row_index| Row::new(row_index, &eds).unwrap())
                .collect();
            let all_shares = rows.iter().flat_map(|row| row.shares.iter());
            let blobs = Blob::reconstruct_all(all_shares, AppVersion::V2).unwrap();

            assert_eq!(blobs.len(), 2);
        }
    }
}