Skip to main content

celestia_types/
row.rs

1//! Types related to rows
2//!
3//! Row in Celestia is understood as all the [`Share`]s in a particular
4//! row of the [`ExtendedDataSquare`].
5//!
6//! [`Share`]: crate::Share
7//! [`ExtendedDataSquare`]: crate::eds::ExtendedDataSquare
8
9use std::iter;
10
11use blockstore::block::CidError;
12use bytes::{Buf, BufMut, BytesMut};
13use celestia_proto::shwap::{Row as RawRow, Share as RawShare, row::HalfSide as RawHalfSide};
14use cid::CidGeneric;
15use multihash::Multihash;
16use prost::Message;
17use serde::Serialize;
18
19use crate::consts::appconsts::SHARE_SIZE;
20use crate::eds::{EDS_ID_SIZE, EdsId, ExtendedDataSquare};
21use crate::nmt::{Nmt, NmtExt};
22use crate::{DataAvailabilityHeader, Error, Result, Share};
23
/// Number of bytes needed to represent [`RowId`] in `multihash`.
///
/// This is the encoded [`EdsId`] followed by the 2 bytes of the `u16` row index.
pub const ROW_ID_SIZE: usize = EDS_ID_SIZE + 2;
/// The code of the [`RowId`] hashing algorithm in `multihash`.
pub const ROW_ID_MULTIHASH_CODE: u64 = 0x7801;
/// The id of the codec used for the [`RowId`] in `Cid`s.
pub const ROW_ID_CODEC: u64 = 0x7800;
30
31/// Represents particular row in a specific Data Square,
32#[derive(Debug, PartialEq, Clone, Copy)]
33pub struct RowId {
34    eds_id: EdsId,
35    index: u16,
36}
37
38/// Row together with the data
39#[derive(Clone, Debug, Serialize)]
40#[serde(into = "RawRow")]
41pub struct Row {
42    /// Shares contained in the row
43    pub shares: Vec<Share>,
44}
45
46impl Row {
47    /// Create Row with the given index from EDS
48    pub fn new(index: u16, eds: &ExtendedDataSquare) -> Result<Self> {
49        let shares = eds.row(index)?;
50
51        Ok(Row { shares })
52    }
53
54    /// Verify the row against roots from DAH
55    pub fn verify(&self, id: RowId, dah: &DataAvailabilityHeader) -> Result<()> {
56        let row = id.index;
57        let mut tree = Nmt::default();
58
59        for share in &self.shares {
60            tree.push_leaf(share.as_ref(), *share.namespace())
61                .map_err(Error::Nmt)?;
62        }
63
64        let Some(root) = dah.row_root(row) else {
65            return Err(Error::EdsIndexOutOfRange(row, 0));
66        };
67
68        if tree.root().hash() != root.hash() {
69            return Err(Error::RootMismatch);
70        }
71
72        Ok(())
73    }
74
75    /// Encode Row into the raw binary representation.
76    pub fn encode(&self, bytes: &mut BytesMut) {
77        let raw = RawRow::from(self.clone());
78
79        bytes.reserve(raw.encoded_len());
80        raw.encode(bytes).expect("capacity reserved");
81    }
82
83    /// Decode Row from the binary representation.
84    ///
85    /// # Errors
86    ///
87    /// This function will return an error if protobuf deserialization
88    /// fails and propagate errors from [`Row::from_raw`].
89    pub fn decode(id: RowId, buffer: &[u8]) -> Result<Self> {
90        let raw = RawRow::decode(buffer)?;
91        Self::from_raw(id, raw)
92    }
93
94    /// Recover Row from it's raw representation, reconstructing the missing half
95    /// using [`leopard_codec`].
96    ///
97    /// # Errors
98    ///
99    /// This function will propagate errors from [`leopard_codec`] and [`Share`] construction.
100    pub fn from_raw(id: RowId, row: RawRow) -> Result<Self> {
101        let data_shares = row.shares_half.len();
102
103        let shares = match row.half_side() {
104            RawHalfSide::Left => {
105                // We have original data, recompute parity shares
106                let mut shares: Vec<_> = row.shares_half.into_iter().map(|shr| shr.data).collect();
107                shares.resize(shares.len() * 2, vec![0; SHARE_SIZE]);
108                leopard_codec::encode(&mut shares, data_shares)?;
109                shares
110            }
111            RawHalfSide::Right => {
112                // We have parity data, recompute original shares
113                let mut shares: Vec<_> = iter::repeat_n(vec![], data_shares)
114                    .chain(row.shares_half.into_iter().map(|shr| shr.data))
115                    .collect();
116                leopard_codec::reconstruct(&mut shares, data_shares)?;
117                shares
118            }
119        };
120
121        let row_index = id.index() as usize;
122        let shares = shares
123            .into_iter()
124            .enumerate()
125            .map(|(col_index, shr)| {
126                if row_index < data_shares && col_index < data_shares {
127                    Share::from_raw(&shr)
128                } else {
129                    Share::parity(&shr)
130                }
131            })
132            .collect::<Result<_>>()?;
133
134        Ok(Row { shares })
135    }
136}
137
138impl From<Row> for RawRow {
139    fn from(row: Row) -> RawRow {
140        // parity shares aren't transmitted over shwap, just data shares
141        let square_width = row.shares.len();
142        let shares_half = row
143            .shares
144            .into_iter()
145            .map(|shr| RawShare { data: shr.to_vec() })
146            .take(square_width / 2)
147            .collect();
148
149        RawRow {
150            shares_half,
151            half_side: RawHalfSide::Left.into(),
152        }
153    }
154}
155
156impl RowId {
157    /// Create a new [`RowId`] for the particular block.
158    ///
159    /// # Errors
160    ///
161    /// This function will return an error if the block height is invalid.
162    pub fn new(index: u16, height: u64) -> Result<Self> {
163        Ok(Self {
164            index,
165            eds_id: EdsId::new(height)?,
166        })
167    }
168
169    /// A height of the block which contains the data.
170    pub fn block_height(&self) -> u64 {
171        self.eds_id.block_height()
172    }
173
174    /// An index of the row in the [`ExtendedDataSquare`].
175    ///
176    /// [`ExtendedDataSquare`]: crate::eds::ExtendedDataSquare
177    pub fn index(&self) -> u16 {
178        self.index
179    }
180
181    /// Encode row id into the byte representation.
182    pub fn encode(&self, bytes: &mut BytesMut) {
183        bytes.reserve(ROW_ID_SIZE);
184        self.eds_id.encode(bytes);
185        bytes.put_u16(self.index);
186    }
187
188    /// Decode row id from the byte representation.
189    pub fn decode(buffer: &[u8]) -> Result<Self> {
190        if buffer.len() != ROW_ID_SIZE {
191            return Err(Error::InvalidLength(buffer.len(), ROW_ID_SIZE));
192        }
193
194        let (eds_bytes, mut row_bytes) = buffer.split_at(EDS_ID_SIZE);
195        let eds_id = EdsId::decode(eds_bytes)?;
196        let index = row_bytes.get_u16();
197
198        Ok(Self { eds_id, index })
199    }
200}
201
202impl<const S: usize> TryFrom<CidGeneric<S>> for RowId {
203    type Error = CidError;
204
205    fn try_from(cid: CidGeneric<S>) -> Result<Self, Self::Error> {
206        let codec = cid.codec();
207        if codec != ROW_ID_CODEC {
208            return Err(CidError::InvalidCidCodec(codec));
209        }
210
211        let hash = cid.hash();
212
213        let size = hash.size() as usize;
214        if size != ROW_ID_SIZE {
215            return Err(CidError::InvalidMultihashLength(size));
216        }
217
218        let code = hash.code();
219        if code != ROW_ID_MULTIHASH_CODE {
220            return Err(CidError::InvalidMultihashCode(code, ROW_ID_MULTIHASH_CODE));
221        }
222
223        RowId::decode(hash.digest()).map_err(|e| CidError::InvalidCid(e.to_string()))
224    }
225}
226
227impl From<RowId> for CidGeneric<ROW_ID_SIZE> {
228    fn from(row: RowId) -> Self {
229        let mut bytes = BytesMut::with_capacity(ROW_ID_SIZE);
230        row.encode(&mut bytes);
231        // length is correct, so unwrap is safe
232        let mh = Multihash::wrap(ROW_ID_MULTIHASH_CODE, &bytes[..]).unwrap();
233
234        CidGeneric::new_v1(ROW_ID_CODEC, mh)
235    }
236}
237
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Blob;
    use crate::consts::appconsts::{AppVersion, SHARE_SIZE};
    use crate::test_utils::{generate_dummy_eds, generate_eds};

    /// Reads a CID from `bytes` and checks that it is shaped like a row id CID.
    fn read_row_cid(bytes: &[u8]) -> CidGeneric<ROW_ID_SIZE> {
        let cid = CidGeneric::<ROW_ID_SIZE>::read_bytes(bytes).unwrap();
        assert_eq!(cid.codec(), ROW_ID_CODEC);

        let mh = cid.hash();
        assert_eq!(mh.code(), ROW_ID_MULTIHASH_CODE);
        assert_eq!(mh.size(), ROW_ID_SIZE as u8);

        cid
    }

    #[test]
    fn round_trip_test() {
        let original = RowId::new(5, 100).unwrap();
        let cid: CidGeneric<ROW_ID_SIZE> = original.into();

        let mh = cid.hash();
        assert_eq!(mh.code(), ROW_ID_MULTIHASH_CODE);
        assert_eq!(mh.size(), ROW_ID_SIZE as u8);

        let restored = RowId::try_from(cid).unwrap();
        assert_eq!(original, restored);
    }

    #[test]
    fn index_calculation() {
        let shares = vec![vec![0; SHARE_SIZE]; 8 * 8];
        let eds = ExtendedDataSquare::new(shares, "codec".to_string(), AppVersion::V2).unwrap();

        // rows inside of the square
        for index in [1, 7] {
            Row::new(index, &eds).unwrap();
        }

        // rows outside of the square
        let row_err = Row::new(8, &eds).unwrap_err();
        assert!(matches!(row_err, Error::EdsIndexOutOfRange(8, 0)));
        let row_err = Row::new(100, &eds).unwrap_err();
        assert!(matches!(row_err, Error::EdsIndexOutOfRange(100, 0)));
    }

    #[test]
    fn row_id_size() {
        // Size MUST be 10 by the spec.
        assert_eq!(ROW_ID_SIZE, 10);

        let mut encoded = BytesMut::new();
        RowId::new(0, 1).unwrap().encode(&mut encoded);

        assert_eq!(encoded.len(), ROW_ID_SIZE);
    }

    #[test]
    fn from_buffer() {
        let bytes = [
            0x01, // CIDv1
            0x80, 0xF0, 0x01, // CID codec = 7800
            0x81, 0xF0, 0x01, // multihash code = 7801
            0x0A, // len = ROW_ID_SIZE = 10
            0, 0, 0, 0, 0, 0, 0, 64, // block height = 64
            0, 7, // row index = 7
        ];

        let cid = read_row_cid(&bytes);
        let row_id = RowId::try_from(cid).unwrap();

        assert_eq!(row_id.index, 7);
        assert_eq!(row_id.block_height(), 64);
    }

    #[test]
    fn zero_block_height() {
        let bytes = [
            0x01, // CIDv1
            0x80, 0xF0, 0x01, // CID codec = 7800
            0x81, 0xF0, 0x01, // code = 7801
            0x0A, // len = ROW_ID_SIZE = 10
            0, 0, 0, 0, 0, 0, 0, 0, // invalid block height = 0 !
            0, 7, // row index = 7
        ];

        let cid = read_row_cid(&bytes);
        let row_err = RowId::try_from(cid).unwrap_err();

        assert_eq!(
            row_err,
            CidError::InvalidCid("Invalid zero block height".to_string())
        );
    }

    #[test]
    fn multihash_invalid_code() {
        let digest = [0; ROW_ID_SIZE];
        let multihash = Multihash::<ROW_ID_SIZE>::wrap(999, &digest).unwrap();
        let cid = CidGeneric::<ROW_ID_SIZE>::new_v1(ROW_ID_CODEC, multihash);

        let row_err = RowId::try_from(cid).unwrap_err();

        assert_eq!(
            row_err,
            CidError::InvalidMultihashCode(999, ROW_ID_MULTIHASH_CODE)
        );
    }

    #[test]
    fn cid_invalid_codec() {
        let digest = [0; ROW_ID_SIZE];
        let multihash = Multihash::<ROW_ID_SIZE>::wrap(ROW_ID_MULTIHASH_CODE, &digest).unwrap();
        let cid = CidGeneric::<ROW_ID_SIZE>::new_v1(1234, multihash);

        let row_err = RowId::try_from(cid).unwrap_err();

        assert_eq!(row_err, CidError::InvalidCidCodec(1234));
    }

    #[test]
    fn test_roundtrip_verify() {
        for _ in 0..5 {
            let eds = generate_dummy_eds(2 << (rand::random::<usize>() % 8), AppVersion::V2);
            let dah = DataAvailabilityHeader::from_eds(&eds);

            let index = rand::random::<u16>() % eds.square_width();
            let id = RowId::new(index, 1).unwrap();
            let row = Row::new(index, &eds).unwrap();

            // encode -> decode must give a row that still matches the DAH
            let mut encoded = BytesMut::new();
            row.encode(&mut encoded);
            let decoded = Row::decode(id, &encoded).unwrap();

            decoded.verify(id, &dah).unwrap();
        }
    }

    #[test]
    fn reconstruct_all() {
        for _ in 0..3 {
            let eds = generate_eds(8 << (rand::random::<usize>() % 6), AppVersion::V2);

            let mut rows = Vec::new();
            for index in 1..4 {
                rows.push(Row::new(index, &eds).unwrap());
            }

            let all_shares = rows.iter().flat_map(|row| row.shares.iter());
            let blobs = Blob::reconstruct_all(all_shares, AppVersion::V2).unwrap();

            assert_eq!(blobs.len(), 2);
        }
    }
}
382}