celestia_types/
sample.rs

1//! Types related to samples.
2//!
//! A sample in Celestia is a single [`Share`] located at a given
//! index in a particular [`row`] of the [`ExtendedDataSquare`].
5//!
6//! [`row`]: crate::row
7//! [`Share`]: crate::Share
8//! [`ExtendedDataSquare`]: crate::eds::ExtendedDataSquare
9
10use blockstore::block::CidError;
11use bytes::{Buf, BufMut, BytesMut};
12use celestia_proto::shwap::Share as RawShare;
13use cid::CidGeneric;
14use multihash::Multihash;
15use nmt_rs::nmt_proof::NamespaceProof as NmtNamespaceProof;
16use prost::Message;
17use serde::Serialize;
18
19use crate::eds::{AxisType, ExtendedDataSquare};
20use crate::nmt::NamespaceProof;
21use crate::row::{ROW_ID_SIZE, RowId};
22use crate::{DataAvailabilityHeader, Error, Result, Share, bail_validation};
23
24pub use celestia_proto::shwap::Sample as RawSample;
25
/// Length in bytes of an encoded [`SampleId`], which is also the length of its
/// `multihash` digest: block height (8) + row index (2) + column index (2).
const SAMPLE_ID_SIZE: usize = 8 + 2 + 2;
/// `multihash` code marking a digest as an encoded [`SampleId`].
pub const SAMPLE_ID_MULTIHASH_CODE: u64 = 0x7811;
/// Codec identifier used by `Cid`s that carry a [`SampleId`].
pub const SAMPLE_ID_CODEC: u64 = 0x7810;
32
33/// Identifies a particular [`Share`] located in the [`ExtendedDataSquare`].
34///
35/// [`Share`]: crate::Share
36/// [`ExtendedDataSquare`]: crate::eds::ExtendedDataSquare
37#[derive(Debug, PartialEq, Clone, Copy)]
38pub struct SampleId {
39    row_id: RowId,
40    column_index: u16,
41}
42
43/// Represents Sample, with proof of its inclusion
44#[derive(Clone, Debug, Serialize)]
45#[serde(into = "RawSample")]
46pub struct Sample {
47    /// Indication whether proving was done row or column-wise
48    pub proof_type: AxisType,
49    /// Share that is being sampled
50    pub share: Share,
51    /// Proof of the inclusion of the share
52    pub proof: NamespaceProof,
53}
54
55impl Sample {
56    /// Create a new [`Sample`] for the given index of the [`ExtendedDataSquare`] in a block.
57    ///
58    /// `row_index` and `column_index` specifies the [`Share`] position in EDS.
59    /// `proof_type` determines whether proof of inclusion of the [`Share`] should be
60    /// constructed for its row or column.
61    ///
62    /// # Errors
63    ///
64    /// This function will return an error, if:
65    ///
66    /// - `row_index`/`column_index` falls outside the provided [`ExtendedDataSquare`].
67    /// - [`ExtendedDataSquare`] is incorrect (either data shares don't have their namespace
68    ///   prefixed, or [`Share`]s aren't namespace ordered)
69    /// - Block height is zero
70    ///
71    /// # Example
72    ///
73    /// ```no_run
74    /// use celestia_types::AxisType;
75    /// use celestia_types::sample::{Sample, SampleId};
76    /// # use celestia_types::{ExtendedDataSquare, ExtendedHeader};
77    /// #
78    /// # fn get_extended_data_square(height: u64) -> ExtendedDataSquare {
79    /// #    unimplemented!()
80    /// # }
81    /// #
82    /// # fn get_extended_header(height: u64) -> ExtendedHeader {
83    /// #    unimplemented!()
84    /// # }
85    ///
86    /// let block_height = 15;
87    /// let eds = get_extended_data_square(block_height);
88    /// let header = get_extended_header(block_height);
89    ///
90    /// let sample_id = SampleId::new(2, 3, block_height).unwrap();
91    /// let sample = Sample::new(2, 3, AxisType::Row, &eds).unwrap();
92    ///
93    /// sample.verify(sample_id, &header.dah).unwrap();
94    /// ```
95    ///
96    /// [`Share`]: crate::Share
97    /// [`ExtendedDataSquare`]: crate::eds::ExtendedDataSquare
98    pub fn new(
99        row_index: u16,
100        column_index: u16,
101        proof_type: AxisType,
102        eds: &ExtendedDataSquare,
103    ) -> Result<Self> {
104        let share = eds.share(row_index, column_index)?.clone();
105
106        let range_proof = match proof_type {
107            AxisType::Row => eds
108                .row_nmt(row_index)?
109                .build_range_proof(usize::from(column_index)..usize::from(column_index) + 1),
110            AxisType::Col => eds
111                .column_nmt(column_index)?
112                .build_range_proof(usize::from(row_index)..usize::from(row_index) + 1),
113        };
114
115        let proof = NmtNamespaceProof::PresenceProof {
116            proof: range_proof,
117            ignore_max_ns: true,
118        };
119
120        Ok(Sample {
121            share,
122            proof: proof.into(),
123            proof_type,
124        })
125    }
126
127    /// verify sample with root hash from ExtendedHeader
128    pub fn verify(&self, id: SampleId, dah: &DataAvailabilityHeader) -> Result<()> {
129        let root = match self.proof_type {
130            AxisType::Row => dah
131                .row_root(id.row_index())
132                .ok_or(Error::EdsIndexOutOfRange(id.row_index(), 0))?,
133            AxisType::Col => dah
134                .column_root(id.column_index())
135                .ok_or(Error::EdsIndexOutOfRange(0, id.column_index()))?,
136        };
137
138        self.proof
139            .verify_range(&root, &[&self.share], *self.share.namespace())
140            .map_err(Error::RangeProofError)
141    }
142
143    /// Encode Sample into the raw binary representation.
144    pub fn encode(&self, bytes: &mut BytesMut) {
145        let raw = RawSample::from(self.clone());
146
147        bytes.reserve(raw.encoded_len());
148        raw.encode(bytes).expect("capacity reserved");
149    }
150
151    /// Decode Sample from the binary representation.
152    ///
153    /// # Errors
154    ///
155    /// This function will return an error if protobuf deserialization
156    /// fails and propagate errors from [`Sample::from_raw`].
157    pub fn decode(id: SampleId, buffer: &[u8]) -> Result<Self> {
158        let raw = RawSample::decode(buffer)?;
159        Self::from_raw(id, raw)
160    }
161
162    /// Recover Sample from it's raw representation.
163    ///
164    /// # Errors
165    ///
166    /// This function will return error if proof is missing or invalid shares are not in
167    /// the expected namespace, and will propagate errors from [`Share`] construction.
168    pub fn from_raw(id: SampleId, sample: RawSample) -> Result<Self> {
169        let Some(proof) = sample.proof else {
170            return Err(Error::MissingProof);
171        };
172
173        let proof: NamespaceProof = proof.try_into()?;
174        let proof_type = AxisType::try_from(sample.proof_type)?;
175
176        if proof.is_of_absence() {
177            return Err(Error::WrongProofType);
178        }
179
180        let Some(share) = sample.share else {
181            bail_validation!("missing share");
182        };
183        let Some(square_size) = proof.total_leaves() else {
184            bail_validation!("proof must be for single leaf");
185        };
186
187        let row_index = id.row_index() as usize;
188        let col_index = id.column_index() as usize;
189        let share = if row_index < square_size / 2 && col_index < square_size / 2 {
190            Share::from_raw(&share.data)?
191        } else {
192            Share::parity(&share.data)?
193        };
194
195        Ok(Sample {
196            proof_type,
197            share,
198            proof,
199        })
200    }
201}
202
203impl From<Sample> for RawSample {
204    fn from(sample: Sample) -> RawSample {
205        RawSample {
206            share: Some(RawShare {
207                data: sample.share.to_vec(),
208            }),
209            proof: Some(sample.proof.into()),
210            proof_type: sample.proof_type as i32,
211        }
212    }
213}
214
215impl SampleId {
216    /// Create a new [`SampleId`] for the given `row_index` and `column_index` of the
217    /// [`ExtendedDataSquare`] in a block.
218    ///
219    /// # Errors
220    ///
221    /// This function will return an error if the block height is zero.
222    ///
223    /// # Example
224    ///
225    /// ```no_run
226    /// use celestia_types::sample::SampleId;
227    ///
228    /// // Consider a 64th share of EDS with block height of 15
229    /// let header_height = 15;
230    /// SampleId::new(2, 1, header_height).unwrap();
231    /// ```
232    ///
233    /// [`Share`]: crate::Share
234    /// [`ExtendedDataSquare`]: crate::eds::ExtendedDataSquare
235    pub fn new(row_index: u16, column_index: u16, block_height: u64) -> Result<Self> {
236        if block_height == 0 {
237            return Err(Error::ZeroBlockHeight);
238        }
239
240        Ok(SampleId {
241            row_id: RowId::new(row_index, block_height)?,
242            column_index,
243        })
244    }
245
246    /// A height of the block which contains the sample.
247    pub fn block_height(&self) -> u64 {
248        self.row_id.block_height()
249    }
250
251    /// Row index of the [`ExtendedDataSquare`] that sample is located on.
252    ///
253    /// [`ExtendedDataSquare`]: crate::eds::ExtendedDataSquare
254    pub fn row_index(&self) -> u16 {
255        self.row_id.index()
256    }
257
258    /// Column index of the [`ExtendedDataSquare`] that sample is located on.
259    ///
260    /// [`ExtendedDataSquare`]: crate::eds::ExtendedDataSquare
261    pub fn column_index(&self) -> u16 {
262        self.column_index
263    }
264
265    /// Encode sample id into the byte representation.
266    pub fn encode(&self, bytes: &mut BytesMut) {
267        bytes.reserve(SAMPLE_ID_SIZE);
268        self.row_id.encode(bytes);
269        bytes.put_u16(self.column_index);
270    }
271
272    /// Decode sample id from the byte representation.
273    pub fn decode(buffer: &[u8]) -> Result<Self> {
274        if buffer.len() != SAMPLE_ID_SIZE {
275            return Err(Error::InvalidLength(buffer.len(), SAMPLE_ID_SIZE));
276        }
277
278        let (row_bytes, mut col_bytes) = buffer.split_at(ROW_ID_SIZE);
279        let row_id = RowId::decode(row_bytes)?;
280        let column_index = col_bytes.get_u16();
281
282        Ok(SampleId {
283            row_id,
284            column_index,
285        })
286    }
287}
288
289impl<const S: usize> TryFrom<&CidGeneric<S>> for SampleId {
290    type Error = CidError;
291
292    fn try_from(cid: &CidGeneric<S>) -> Result<Self, Self::Error> {
293        let codec = cid.codec();
294        if codec != SAMPLE_ID_CODEC {
295            return Err(CidError::InvalidCidCodec(codec));
296        }
297
298        let hash = cid.hash();
299
300        let size = hash.size() as usize;
301        if size != SAMPLE_ID_SIZE {
302            return Err(CidError::InvalidMultihashLength(size));
303        }
304
305        let code = hash.code();
306        if code != SAMPLE_ID_MULTIHASH_CODE {
307            return Err(CidError::InvalidMultihashCode(
308                code,
309                SAMPLE_ID_MULTIHASH_CODE,
310            ));
311        }
312
313        SampleId::decode(hash.digest()).map_err(|e| CidError::InvalidCid(e.to_string()))
314    }
315}
316
317impl<const S: usize> TryFrom<&mut CidGeneric<S>> for SampleId {
318    type Error = CidError;
319
320    fn try_from(cid: &mut CidGeneric<S>) -> Result<Self, Self::Error> {
321        Self::try_from(&*cid)
322    }
323}
324
325impl<const S: usize> TryFrom<CidGeneric<S>> for SampleId {
326    type Error = CidError;
327
328    fn try_from(cid: CidGeneric<S>) -> Result<Self, Self::Error> {
329        Self::try_from(&cid)
330    }
331}
332
333impl From<SampleId> for CidGeneric<SAMPLE_ID_SIZE> {
334    fn from(sample_id: SampleId) -> Self {
335        let mut bytes = BytesMut::with_capacity(SAMPLE_ID_SIZE);
336        // length is correct, so unwrap is safe
337        sample_id.encode(&mut bytes);
338
339        let mh = Multihash::wrap(SAMPLE_ID_MULTIHASH_CODE, &bytes[..]).unwrap();
340
341        CidGeneric::new_v1(SAMPLE_ID_CODEC, mh)
342    }
343}
344
#[cfg(test)]
mod tests {
    use super::*;
    use crate::consts::appconsts::AppVersion;
    use crate::test_utils::generate_dummy_eds;

    #[test]
    fn round_trip() {
        let sample_id = SampleId::new(5, 10, 100).unwrap();
        let cid = CidGeneric::from(sample_id);

        let multihash = cid.hash();
        assert_eq!(multihash.code(), SAMPLE_ID_MULTIHASH_CODE);
        assert_eq!(multihash.size(), SAMPLE_ID_SIZE as u8);

        let deserialized_sample_id = SampleId::try_from(cid).unwrap();
        assert_eq!(sample_id, deserialized_sample_id);
    }

    #[test]
    fn index_calculation() {
        let eds = generate_dummy_eds(8, AppVersion::V2);

        Sample::new(0, 0, AxisType::Row, &eds).unwrap();
        Sample::new(7, 6, AxisType::Row, &eds).unwrap();
        Sample::new(7, 7, AxisType::Row, &eds).unwrap();
        // The last share must also be provable along its column.
        Sample::new(7, 7, AxisType::Col, &eds).unwrap();

        let sample_err = Sample::new(7, 8, AxisType::Row, &eds).unwrap_err();
        assert!(matches!(sample_err, Error::EdsIndexOutOfRange(7, 8)));

        let sample_err = Sample::new(12, 3, AxisType::Row, &eds).unwrap_err();
        assert!(matches!(sample_err, Error::EdsIndexOutOfRange(12, 3)));
    }

    #[test]
    fn sample_id_size() {
        // Size MUST be 12 by the spec.
        assert_eq!(SAMPLE_ID_SIZE, 12);

        let sample_id = SampleId::new(0, 4, 1).unwrap();
        let mut bytes = BytesMut::new();
        sample_id.encode(&mut bytes);
        assert_eq!(bytes.len(), SAMPLE_ID_SIZE);
    }

    #[test]
    fn from_buffer() {
        let bytes = [
            0x01, // CIDv1
            0x90, 0xF0, 0x01, // CID codec = 7810
            0x91, 0xF0, 0x01, // multihash code = 7811
            0x0C, // len = SAMPLE_ID_SIZE = 12
            0, 0, 0, 0, 0, 0, 0, 64, // block height = 64
            0, 7, // row index = 7
            0, 5, // column index = 5
        ];

        let cid = CidGeneric::<SAMPLE_ID_SIZE>::read_bytes(bytes.as_ref()).unwrap();
        assert_eq!(cid.codec(), SAMPLE_ID_CODEC);
        let mh = cid.hash();
        assert_eq!(mh.code(), SAMPLE_ID_MULTIHASH_CODE);
        assert_eq!(mh.size(), SAMPLE_ID_SIZE as u8);
        let sample_id = SampleId::try_from(cid).unwrap();
        assert_eq!(sample_id.block_height(), 64);
        assert_eq!(sample_id.row_index(), 7);
        assert_eq!(sample_id.column_index(), 5);
    }

    #[test]
    fn multihash_invalid_code() {
        let multihash = Multihash::<SAMPLE_ID_SIZE>::wrap(888, &[0; SAMPLE_ID_SIZE]).unwrap();
        let cid = CidGeneric::<SAMPLE_ID_SIZE>::new_v1(SAMPLE_ID_CODEC, multihash);
        let code_err = SampleId::try_from(cid).unwrap_err();
        assert_eq!(
            code_err,
            CidError::InvalidMultihashCode(888, SAMPLE_ID_MULTIHASH_CODE)
        );
    }

    #[test]
    fn cid_invalid_codec() {
        let multihash =
            Multihash::<SAMPLE_ID_SIZE>::wrap(SAMPLE_ID_MULTIHASH_CODE, &[0; SAMPLE_ID_SIZE])
                .unwrap();
        let cid = CidGeneric::<SAMPLE_ID_SIZE>::new_v1(4321, multihash);
        let codec_err = SampleId::try_from(cid).unwrap_err();
        assert_eq!(codec_err, CidError::InvalidCidCodec(4321));
    }

    #[test]
    fn test_roundtrip_verify() {
        for _ in 0..5 {
            let eds = generate_dummy_eds(2 << (rand::random::<usize>() % 8), AppVersion::V2);
            let dah = DataAvailabilityHeader::from_eds(&eds);

            let row_index = rand::random::<u16>() % eds.square_width();
            let col_index = rand::random::<u16>() % eds.square_width();
            let proof_type = if rand::random() {
                AxisType::Row
            } else {
                AxisType::Col
            };

            let id = SampleId::new(row_index, col_index, 1).unwrap();
            let sample = Sample::new(row_index, col_index, proof_type, &eds).unwrap();

            let mut buf = BytesMut::new();
            sample.encode(&mut buf);
            let decoded = Sample::decode(id, &buf).unwrap();

            decoded.verify(id, &dah).unwrap();
        }
    }
}