celestia_types/
sample.rs

1//! Types related to samples.
2//!
3//! Sample in Celestia is understood as a single [`Share`] located at an
4//! index in the particular [`row`] of the [`ExtendedDataSquare`].
5//!
6//! [`row`]: crate::row
7//! [`Share`]: crate::Share
8//! [`ExtendedDataSquare`]: crate::eds::ExtendedDataSquare
9
10use blockstore::block::CidError;
11use bytes::{Buf, BufMut, BytesMut};
12use celestia_proto::shwap::Share as RawShare;
13use cid::CidGeneric;
14use multihash::Multihash;
15use nmt_rs::nmt_proof::NamespaceProof as NmtNamespaceProof;
16use prost::Message;
17use serde::Serialize;
18
19use crate::eds::{AxisType, ExtendedDataSquare};
20use crate::nmt::NamespaceProof;
21use crate::row::{ROW_ID_SIZE, RowId};
22use crate::{DataAvailabilityHeader, Error, Result, Share, bail_validation};
23
24pub use celestia_proto::shwap::Sample as RawSample;
25
/// Length in bytes of an encoded [`SampleId`], i.e. of the `multihash` digest carrying it.
const SAMPLE_ID_SIZE: usize = 12;
/// `multihash` code marking a digest as an encoded [`SampleId`].
pub const SAMPLE_ID_MULTIHASH_CODE: u64 = 0x7811;
/// `Cid` codec identifier used for [`SampleId`]s.
pub const SAMPLE_ID_CODEC: u64 = 0x7810;
32
33/// Identifies a particular [`Share`] located in the [`ExtendedDataSquare`].
34///
35/// [`Share`]: crate::Share
36/// [`ExtendedDataSquare`]: crate::eds::ExtendedDataSquare
37#[derive(Debug, PartialEq, Clone, Copy)]
38pub struct SampleId {
39    row_id: RowId,
40    column_index: u16,
41}
42
43/// Represents Sample, with proof of its inclusion
44#[derive(Clone, Debug, Serialize)]
45#[serde(into = "RawSample")]
46pub struct Sample {
47    /// Indication whether proving was done row or column-wise
48    pub proof_type: AxisType,
49    /// Share that is being sampled
50    pub share: Share,
51    /// Proof of the inclusion of the share
52    pub proof: NamespaceProof,
53}
54
55impl Sample {
56    /// Create a new [`Sample`] for the given index of the [`ExtendedDataSquare`] in a block.
57    ///
58    /// `row_index` and `column_index` specifies the [`Share`] position in EDS.
59    /// `proof_type` determines whether proof of inclusion of the [`Share`] should be
60    /// constructed for its row or column.
61    ///
62    /// # Errors
63    ///
64    /// This function will return an error, if:
65    ///
66    /// - `row_index`/`column_index` falls outside the provided [`ExtendedDataSquare`].
67    /// - [`ExtendedDataSquare`] is incorrect (either data shares don't have their namespace
68    ///   prefixed, or [`Share`]s aren't namespace ordered)
69    /// - Block height is zero
70    ///
71    /// # Example
72    ///
73    /// ```no_run
74    /// use celestia_types::AxisType;
75    /// use celestia_types::sample::{Sample, SampleId};
76    /// # use celestia_types::{ExtendedDataSquare, ExtendedHeader};
77    /// #
78    /// # fn get_extended_data_square(height: u64) -> ExtendedDataSquare {
79    /// #    unimplemented!()
80    /// # }
81    /// #
82    /// # fn get_extended_header(height: u64) -> ExtendedHeader {
83    /// #    unimplemented!()
84    /// # }
85    ///
86    /// let block_height = 15;
87    /// let eds = get_extended_data_square(block_height);
88    /// let header = get_extended_header(block_height);
89    ///
90    /// let sample_id = SampleId::new(2, 3, block_height).unwrap();
91    /// let sample = Sample::new(2, 3, AxisType::Row, &eds).unwrap();
92    ///
93    /// sample.verify(sample_id, &header.dah).unwrap();
94    /// ```
95    ///
96    /// [`Share`]: crate::Share
97    /// [`ExtendedDataSquare`]: crate::eds::ExtendedDataSquare
98    pub fn new(
99        row_index: u16,
100        column_index: u16,
101        proof_type: AxisType,
102        eds: &ExtendedDataSquare,
103    ) -> Result<Self> {
104        let share = eds.share(row_index, column_index)?.clone();
105
106        let range_proof = match proof_type {
107            AxisType::Row => eds
108                .row_nmt(row_index)?
109                .build_range_proof(usize::from(column_index)..usize::from(column_index) + 1),
110            AxisType::Col => eds
111                .column_nmt(column_index)?
112                .build_range_proof(usize::from(row_index)..usize::from(row_index) + 1),
113        };
114
115        let proof = NmtNamespaceProof::PresenceProof {
116            proof: range_proof,
117            ignore_max_ns: true,
118        };
119
120        Ok(Sample {
121            share,
122            proof: proof.into(),
123            proof_type,
124        })
125    }
126
127    /// verify sample with root hash from ExtendedHeader
128    pub fn verify(&self, id: SampleId, dah: &DataAvailabilityHeader) -> Result<()> {
129        let root = match self.proof_type {
130            AxisType::Row => dah
131                .row_root(id.row_index())
132                .ok_or(Error::EdsIndexOutOfRange(id.row_index(), 0))?,
133            AxisType::Col => dah
134                .column_root(id.column_index())
135                .ok_or(Error::EdsIndexOutOfRange(0, id.column_index()))?,
136        };
137
138        self.proof
139            .verify_range(&root, &[&self.share], *self.share.namespace())
140            .map_err(Error::RangeProofError)
141    }
142
143    /// Encode Sample into the raw binary representation.
144    pub fn encode(&self, bytes: &mut BytesMut) {
145        let raw = RawSample::from(self.clone());
146
147        bytes.reserve(raw.encoded_len());
148        raw.encode(bytes).expect("capacity reserved");
149    }
150
151    /// Decode Sample from the binary representation.
152    ///
153    /// # Errors
154    ///
155    /// This function will return an error if protobuf deserialization
156    /// fails and propagate errors from [`Sample::from_raw`].
157    pub fn decode(id: SampleId, buffer: &[u8]) -> Result<Self> {
158        let raw = RawSample::decode(buffer)?;
159        Self::from_raw(id, raw)
160    }
161
162    /// Recover Sample from it's raw representation.
163    ///
164    /// # Errors
165    ///
166    /// This function will return error if proof is missing or invalid shares are not in
167    /// the expected namespace, and will propagate errors from [`Share`] construction.
168    pub fn from_raw(id: SampleId, sample: RawSample) -> Result<Self> {
169        let Some(proof) = sample.proof else {
170            return Err(Error::MissingProof);
171        };
172
173        let proof: NamespaceProof = proof.try_into()?;
174        let proof_type = AxisType::try_from(sample.proof_type)?;
175
176        if proof.is_of_absence() {
177            return Err(Error::WrongProofType);
178        }
179
180        let Some(share) = sample.share else {
181            bail_validation!("missing share");
182        };
183        let Some(square_size) = proof.total_leaves() else {
184            bail_validation!("proof must be for single leaf");
185        };
186
187        let row_index = id.row_index() as usize;
188        let col_index = id.column_index() as usize;
189        let share = if row_index < square_size / 2 && col_index < square_size / 2 {
190            Share::from_raw(&share.data)?
191        } else {
192            Share::parity(&share.data)?
193        };
194
195        Ok(Sample {
196            proof_type,
197            share,
198            proof,
199        })
200    }
201}
202
203impl From<Sample> for RawSample {
204    fn from(sample: Sample) -> RawSample {
205        RawSample {
206            share: Some(RawShare {
207                data: sample.share.to_vec(),
208            }),
209            proof: Some(sample.proof.into()),
210            proof_type: sample.proof_type as i32,
211        }
212    }
213}
214
215impl SampleId {
216    /// Create a new [`SampleId`] for the given `row_index` and `column_index` of the
217    /// [`ExtendedDataSquare`] in a block.
218    ///
219    /// # Errors
220    ///
221    /// This function will return an error if the block height is zero.
222    ///
223    /// # Example
224    ///
225    /// ```no_run
226    /// use celestia_types::sample::SampleId;
227    ///
228    /// // Consider a 64th share of EDS with block height of 15
229    /// let header_height = 15;
230    /// SampleId::new(2, 1, header_height).unwrap();
231    /// ```
232    ///
233    /// [`Share`]: crate::Share
234    /// [`ExtendedDataSquare`]: crate::eds::ExtendedDataSquare
235    pub fn new(row_index: u16, column_index: u16, block_height: u64) -> Result<Self> {
236        if block_height == 0 {
237            return Err(Error::ZeroBlockHeight);
238        }
239
240        Ok(SampleId {
241            row_id: RowId::new(row_index, block_height)?,
242            column_index,
243        })
244    }
245
246    /// A height of the block which contains the sample.
247    pub fn block_height(&self) -> u64 {
248        self.row_id.block_height()
249    }
250
251    /// Row index of the [`ExtendedDataSquare`] that sample is located on.
252    ///
253    /// [`ExtendedDataSquare`]: crate::eds::ExtendedDataSquare
254    pub fn row_index(&self) -> u16 {
255        self.row_id.index()
256    }
257
258    /// Column index of the [`ExtendedDataSquare`] that sample is located on.
259    ///
260    /// [`ExtendedDataSquare`]: crate::eds::ExtendedDataSquare
261    pub fn column_index(&self) -> u16 {
262        self.column_index
263    }
264
265    fn encode(&self, bytes: &mut BytesMut) {
266        bytes.reserve(SAMPLE_ID_SIZE);
267        self.row_id.encode(bytes);
268        bytes.put_u16(self.column_index);
269    }
270
271    fn decode(buffer: &[u8]) -> Result<Self, CidError> {
272        if buffer.len() != SAMPLE_ID_SIZE {
273            return Err(CidError::InvalidMultihashLength(buffer.len()));
274        }
275
276        let (row_bytes, mut col_bytes) = buffer.split_at(ROW_ID_SIZE);
277        let row_id = RowId::decode(row_bytes)?;
278        let column_index = col_bytes.get_u16();
279
280        Ok(SampleId {
281            row_id,
282            column_index,
283        })
284    }
285}
286
287impl<const S: usize> TryFrom<&CidGeneric<S>> for SampleId {
288    type Error = CidError;
289
290    fn try_from(cid: &CidGeneric<S>) -> Result<Self, Self::Error> {
291        let codec = cid.codec();
292        if codec != SAMPLE_ID_CODEC {
293            return Err(CidError::InvalidCidCodec(codec));
294        }
295
296        let hash = cid.hash();
297
298        let size = hash.size() as usize;
299        if size != SAMPLE_ID_SIZE {
300            return Err(CidError::InvalidMultihashLength(size));
301        }
302
303        let code = hash.code();
304        if code != SAMPLE_ID_MULTIHASH_CODE {
305            return Err(CidError::InvalidMultihashCode(
306                code,
307                SAMPLE_ID_MULTIHASH_CODE,
308            ));
309        }
310
311        SampleId::decode(hash.digest())
312    }
313}
314
315impl<const S: usize> TryFrom<&mut CidGeneric<S>> for SampleId {
316    type Error = CidError;
317
318    fn try_from(cid: &mut CidGeneric<S>) -> Result<Self, Self::Error> {
319        Self::try_from(&*cid)
320    }
321}
322
323impl<const S: usize> TryFrom<CidGeneric<S>> for SampleId {
324    type Error = CidError;
325
326    fn try_from(cid: CidGeneric<S>) -> Result<Self, Self::Error> {
327        Self::try_from(&cid)
328    }
329}
330
331impl From<SampleId> for CidGeneric<SAMPLE_ID_SIZE> {
332    fn from(sample_id: SampleId) -> Self {
333        let mut bytes = BytesMut::with_capacity(SAMPLE_ID_SIZE);
334        // length is correct, so unwrap is safe
335        sample_id.encode(&mut bytes);
336
337        let mh = Multihash::wrap(SAMPLE_ID_MULTIHASH_CODE, &bytes[..]).unwrap();
338
339        CidGeneric::new_v1(SAMPLE_ID_CODEC, mh)
340    }
341}
342
#[cfg(test)]
mod tests {
    use super::*;
    use crate::consts::appconsts::AppVersion;
    use crate::test_utils::generate_dummy_eds;

    #[test]
    fn round_trip() {
        let original = SampleId::new(5, 10, 100).unwrap();
        let cid = CidGeneric::from(original);

        let mh = cid.hash();
        assert_eq!(mh.code(), SAMPLE_ID_MULTIHASH_CODE);
        assert_eq!(mh.size(), SAMPLE_ID_SIZE as u8);

        let decoded = SampleId::try_from(cid).unwrap();
        assert_eq!(original, decoded);
    }

    #[test]
    fn index_calculation() {
        let eds = generate_dummy_eds(8, AppVersion::V2);

        // Positions inside the 8x8 square must be accepted.
        for (row, col) in [(0, 0), (7, 6), (7, 7)] {
            Sample::new(row, col, AxisType::Row, &eds).unwrap();
        }

        let err = Sample::new(7, 8, AxisType::Row, &eds).unwrap_err();
        assert!(matches!(err, Error::EdsIndexOutOfRange(7, 8)));

        let err = Sample::new(12, 3, AxisType::Row, &eds).unwrap_err();
        assert!(matches!(err, Error::EdsIndexOutOfRange(12, 3)));
    }

    #[test]
    fn sample_id_size() {
        // The spec mandates exactly 12 bytes.
        assert_eq!(SAMPLE_ID_SIZE, 12);

        let mut buf = BytesMut::new();
        SampleId::new(0, 4, 1).unwrap().encode(&mut buf);
        assert_eq!(buf.len(), SAMPLE_ID_SIZE);
    }

    #[test]
    fn from_buffer() {
        let raw = [
            0x01, // CIDv1
            0x90, 0xF0, 0x01, // CID codec = 7810
            0x91, 0xF0, 0x01, // multihash code = 7811
            0x0C, // len = SAMPLE_ID_SIZE = 12
            0, 0, 0, 0, 0, 0, 0, 64, // block height = 64
            0, 7, // row index = 7
            0, 5, // column index = 5
        ];

        let cid = CidGeneric::<SAMPLE_ID_SIZE>::read_bytes(raw.as_ref()).unwrap();
        assert_eq!(cid.codec(), SAMPLE_ID_CODEC);

        let mh = cid.hash();
        assert_eq!(mh.code(), SAMPLE_ID_MULTIHASH_CODE);
        assert_eq!(mh.size(), SAMPLE_ID_SIZE as u8);

        let id = SampleId::try_from(cid).unwrap();
        assert_eq!(id.block_height(), 64);
        assert_eq!(id.row_index(), 7);
        assert_eq!(id.column_index(), 5);
    }

    #[test]
    fn multihash_invalid_code() {
        let mh = Multihash::<SAMPLE_ID_SIZE>::wrap(888, &[0; SAMPLE_ID_SIZE]).unwrap();
        let cid = CidGeneric::<SAMPLE_ID_SIZE>::new_v1(SAMPLE_ID_CODEC, mh);

        assert_eq!(
            SampleId::try_from(cid).unwrap_err(),
            CidError::InvalidMultihashCode(888, SAMPLE_ID_MULTIHASH_CODE)
        );
    }

    #[test]
    fn cid_invalid_codec() {
        let mh = Multihash::<SAMPLE_ID_SIZE>::wrap(SAMPLE_ID_MULTIHASH_CODE, &[0; SAMPLE_ID_SIZE])
            .unwrap();
        let cid = CidGeneric::<SAMPLE_ID_SIZE>::new_v1(4321, mh);

        let err = SampleId::try_from(cid).unwrap_err();
        assert!(matches!(err, CidError::InvalidCidCodec(4321)));
    }

    #[test]
    fn test_roundtrip_verify() {
        for _ in 0..5 {
            let width_shift = rand::random::<usize>() % 8;
            let eds = generate_dummy_eds(2 << width_shift, AppVersion::V2);
            let dah = DataAvailabilityHeader::from_eds(&eds);

            let width = eds.square_width();
            let row_index = rand::random::<u16>() % width;
            let col_index = rand::random::<u16>() % width;
            let proof_type = if rand::random() {
                AxisType::Row
            } else {
                AxisType::Col
            };

            let id = SampleId::new(row_index, col_index, 1).unwrap();
            let sample = Sample::new(row_index, col_index, proof_type, &eds).unwrap();

            let mut encoded = BytesMut::new();
            sample.encode(&mut encoded);

            Sample::decode(id, &encoded)
                .unwrap()
                .verify(id, &dah)
                .unwrap();
        }
    }
}