celestia_types/
sample.rs

1//! Types related to samples.
2//!
//! A sample in Celestia is a single [`Share`] located at a given
//! index in a particular [`row`] of the [`ExtendedDataSquare`].
5//!
6//! [`row`]: crate::row
7//! [`Share`]: crate::Share
8//! [`ExtendedDataSquare`]: crate::eds::ExtendedDataSquare
9
10use blockstore::block::CidError;
11use bytes::{Buf, BufMut, BytesMut};
12use celestia_proto::shwap::{Sample as RawSample, Share as RawShare};
13use cid::CidGeneric;
14use multihash::Multihash;
15use nmt_rs::nmt_proof::NamespaceProof as NmtNamespaceProof;
16use prost::Message;
17use serde::Serialize;
18
19use crate::eds::{AxisType, ExtendedDataSquare};
20use crate::nmt::NamespaceProof;
21use crate::row::{RowId, ROW_ID_SIZE};
22use crate::{bail_validation, DataAvailabilityHeader, Error, Result, Share};
23
24/// Number of bytes needed to represent [`SampleId`] in `multihash`.
25const SAMPLE_ID_SIZE: usize = 12;
26/// The code of the [`SampleId`] hashing algorithm in `multihash`.
27pub const SAMPLE_ID_MULTIHASH_CODE: u64 = 0x7811;
28/// The id of codec used for the [`SampleId`] in `Cid`s.
29pub const SAMPLE_ID_CODEC: u64 = 0x7810;
30
31/// Identifies a particular [`Share`] located in the [`ExtendedDataSquare`].
32///
33/// [`Share`]: crate::Share
34/// [`ExtendedDataSquare`]: crate::eds::ExtendedDataSquare
35#[derive(Debug, PartialEq, Clone, Copy)]
36pub struct SampleId {
37    row_id: RowId,
38    column_index: u16,
39}
40
41/// Represents Sample, with proof of its inclusion
42#[derive(Clone, Debug, Serialize)]
43#[serde(into = "RawSample")]
44pub struct Sample {
45    /// Indication whether proving was done row or column-wise
46    pub proof_type: AxisType,
47    /// Share that is being sampled
48    pub share: Share,
49    /// Proof of the inclusion of the share
50    pub proof: NamespaceProof,
51}
52
53impl Sample {
54    /// Create a new [`Sample`] for the given index of the [`ExtendedDataSquare`] in a block.
55    ///
56    /// `row_index` and `column_index` specifies the [`Share`] position in EDS.
57    /// `proof_type` determines whether proof of inclusion of the [`Share`] should be
58    /// constructed for its row or column.
59    ///
60    /// # Errors
61    ///
62    /// This function will return an error, if:
63    ///
64    /// - `row_index`/`column_index` falls outside the provided [`ExtendedDataSquare`].
65    /// - [`ExtendedDataSquare`] is incorrect (either data shares don't have their namespace
66    ///   prefixed, or [`Share`]s aren't namespace ordered)
67    /// - Block height is zero
68    ///
69    /// # Example
70    ///
71    /// ```no_run
72    /// use celestia_types::AxisType;
73    /// use celestia_types::sample::{Sample, SampleId};
74    /// # use celestia_types::{ExtendedDataSquare, ExtendedHeader};
75    /// #
76    /// # fn get_extended_data_square(height: u64) -> ExtendedDataSquare {
77    /// #    unimplemented!()
78    /// # }
79    /// #
80    /// # fn get_extended_header(height: u64) -> ExtendedHeader {
81    /// #    unimplemented!()
82    /// # }
83    ///
84    /// let block_height = 15;
85    /// let eds = get_extended_data_square(block_height);
86    /// let header = get_extended_header(block_height);
87    ///
88    /// let sample_id = SampleId::new(2, 3, block_height).unwrap();
89    /// let sample = Sample::new(2, 3, AxisType::Row, &eds).unwrap();
90    ///
91    /// sample.verify(sample_id, &header.dah).unwrap();
92    /// ```
93    ///
94    /// [`Share`]: crate::Share
95    /// [`ExtendedDataSquare`]: crate::eds::ExtendedDataSquare
96    pub fn new(
97        row_index: u16,
98        column_index: u16,
99        proof_type: AxisType,
100        eds: &ExtendedDataSquare,
101    ) -> Result<Self> {
102        let share = eds.share(row_index, column_index)?.clone();
103
104        let range_proof = match proof_type {
105            AxisType::Row => eds
106                .row_nmt(row_index)?
107                .build_range_proof(usize::from(column_index)..usize::from(column_index) + 1),
108            AxisType::Col => eds
109                .column_nmt(column_index)?
110                .build_range_proof(usize::from(row_index)..usize::from(row_index) + 1),
111        };
112
113        let proof = NmtNamespaceProof::PresenceProof {
114            proof: range_proof,
115            ignore_max_ns: true,
116        };
117
118        Ok(Sample {
119            share,
120            proof: proof.into(),
121            proof_type,
122        })
123    }
124
125    /// verify sample with root hash from ExtendedHeader
126    pub fn verify(&self, id: SampleId, dah: &DataAvailabilityHeader) -> Result<()> {
127        let root = match self.proof_type {
128            AxisType::Row => dah
129                .row_root(id.row_index())
130                .ok_or(Error::EdsIndexOutOfRange(id.row_index(), 0))?,
131            AxisType::Col => dah
132                .column_root(id.column_index())
133                .ok_or(Error::EdsIndexOutOfRange(0, id.column_index()))?,
134        };
135
136        self.proof
137            .verify_range(&root, &[&self.share], *self.share.namespace())
138            .map_err(Error::RangeProofError)
139    }
140
141    /// Encode Sample into the raw binary representation.
142    pub fn encode(&self, bytes: &mut BytesMut) {
143        let raw = RawSample::from(self.clone());
144
145        bytes.reserve(raw.encoded_len());
146        raw.encode(bytes).expect("capacity reserved");
147    }
148
149    /// Decode Sample from the binary representation.
150    ///
151    /// # Errors
152    ///
153    /// This function will return an error if protobuf deserialization
154    /// fails and propagate errors from [`Sample::from_raw`].
155    pub fn decode(id: SampleId, buffer: &[u8]) -> Result<Self> {
156        let raw = RawSample::decode(buffer)?;
157        Self::from_raw(id, raw)
158    }
159
160    /// Recover Sample from it's raw representation.
161    ///
162    /// # Errors
163    ///
164    /// This function will return error if proof is missing or invalid shares are not in
165    /// the expected namespace, and will propagate errors from [`Share`] construction.
166    pub fn from_raw(id: SampleId, sample: RawSample) -> Result<Self> {
167        let Some(proof) = sample.proof else {
168            return Err(Error::MissingProof);
169        };
170
171        let proof: NamespaceProof = proof.try_into()?;
172        let proof_type = AxisType::try_from(sample.proof_type)?;
173
174        if proof.is_of_absence() {
175            return Err(Error::WrongProofType);
176        }
177
178        let Some(share) = sample.share else {
179            bail_validation!("missing share");
180        };
181        let Some(square_size) = proof.total_leaves() else {
182            bail_validation!("proof must be for single leaf");
183        };
184
185        let row_index = id.row_index() as usize;
186        let col_index = id.column_index() as usize;
187        let share = if row_index < square_size / 2 && col_index < square_size / 2 {
188            Share::from_raw(&share.data)?
189        } else {
190            Share::parity(&share.data)?
191        };
192
193        Ok(Sample {
194            proof_type,
195            share,
196            proof,
197        })
198    }
199}
200
201impl From<Sample> for RawSample {
202    fn from(sample: Sample) -> RawSample {
203        RawSample {
204            share: Some(RawShare {
205                data: sample.share.to_vec(),
206            }),
207            proof: Some(sample.proof.into()),
208            proof_type: sample.proof_type as i32,
209        }
210    }
211}
212
213impl SampleId {
214    /// Create a new [`SampleId`] for the given `row_index` and `column_index` of the
215    /// [`ExtendedDataSquare`] in a block.
216    ///
217    /// # Errors
218    ///
219    /// This function will return an error if the block height is zero.
220    ///
221    /// # Example
222    ///
223    /// ```no_run
224    /// use celestia_types::sample::SampleId;
225    ///
226    /// // Consider a 64th share of EDS with block height of 15
227    /// let header_height = 15;
228    /// SampleId::new(2, 1, header_height).unwrap();
229    /// ```
230    ///
231    /// [`Share`]: crate::Share
232    /// [`ExtendedDataSquare`]: crate::eds::ExtendedDataSquare
233    pub fn new(row_index: u16, column_index: u16, block_height: u64) -> Result<Self> {
234        if block_height == 0 {
235            return Err(Error::ZeroBlockHeight);
236        }
237
238        Ok(SampleId {
239            row_id: RowId::new(row_index, block_height)?,
240            column_index,
241        })
242    }
243
244    /// A height of the block which contains the sample.
245    pub fn block_height(&self) -> u64 {
246        self.row_id.block_height()
247    }
248
249    /// Row index of the [`ExtendedDataSquare`] that sample is located on.
250    ///
251    /// [`ExtendedDataSquare`]: crate::eds::ExtendedDataSquare
252    pub fn row_index(&self) -> u16 {
253        self.row_id.index()
254    }
255
256    /// Column index of the [`ExtendedDataSquare`] that sample is located on.
257    ///
258    /// [`ExtendedDataSquare`]: crate::eds::ExtendedDataSquare
259    pub fn column_index(&self) -> u16 {
260        self.column_index
261    }
262
263    fn encode(&self, bytes: &mut BytesMut) {
264        bytes.reserve(SAMPLE_ID_SIZE);
265        self.row_id.encode(bytes);
266        bytes.put_u16(self.column_index);
267    }
268
269    fn decode(buffer: &[u8]) -> Result<Self, CidError> {
270        if buffer.len() != SAMPLE_ID_SIZE {
271            return Err(CidError::InvalidMultihashLength(buffer.len()));
272        }
273
274        let (row_bytes, mut col_bytes) = buffer.split_at(ROW_ID_SIZE);
275        let row_id = RowId::decode(row_bytes)?;
276        let column_index = col_bytes.get_u16();
277
278        Ok(SampleId {
279            row_id,
280            column_index,
281        })
282    }
283}
284
285impl<const S: usize> TryFrom<&CidGeneric<S>> for SampleId {
286    type Error = CidError;
287
288    fn try_from(cid: &CidGeneric<S>) -> Result<Self, Self::Error> {
289        let codec = cid.codec();
290        if codec != SAMPLE_ID_CODEC {
291            return Err(CidError::InvalidCidCodec(codec));
292        }
293
294        let hash = cid.hash();
295
296        let size = hash.size() as usize;
297        if size != SAMPLE_ID_SIZE {
298            return Err(CidError::InvalidMultihashLength(size));
299        }
300
301        let code = hash.code();
302        if code != SAMPLE_ID_MULTIHASH_CODE {
303            return Err(CidError::InvalidMultihashCode(
304                code,
305                SAMPLE_ID_MULTIHASH_CODE,
306            ));
307        }
308
309        SampleId::decode(hash.digest())
310    }
311}
312
313impl<const S: usize> TryFrom<&mut CidGeneric<S>> for SampleId {
314    type Error = CidError;
315
316    fn try_from(cid: &mut CidGeneric<S>) -> Result<Self, Self::Error> {
317        Self::try_from(&*cid)
318    }
319}
320
321impl<const S: usize> TryFrom<CidGeneric<S>> for SampleId {
322    type Error = CidError;
323
324    fn try_from(cid: CidGeneric<S>) -> Result<Self, Self::Error> {
325        Self::try_from(&cid)
326    }
327}
328
329impl From<SampleId> for CidGeneric<SAMPLE_ID_SIZE> {
330    fn from(sample_id: SampleId) -> Self {
331        let mut bytes = BytesMut::with_capacity(SAMPLE_ID_SIZE);
332        // length is correct, so unwrap is safe
333        sample_id.encode(&mut bytes);
334
335        let mh = Multihash::wrap(SAMPLE_ID_MULTIHASH_CODE, &bytes[..]).unwrap();
336
337        CidGeneric::new_v1(SAMPLE_ID_CODEC, mh)
338    }
339}
340
#[cfg(test)]
mod tests {
    use super::*;
    use crate::consts::appconsts::AppVersion;
    use crate::test_utils::generate_dummy_eds;

    /// A `SampleId` survives the conversion to a CID and back.
    #[test]
    fn round_trip() {
        let original = SampleId::new(5, 10, 100).unwrap();
        let cid = CidGeneric::from(original);

        let mh = cid.hash();
        assert_eq!(mh.code(), SAMPLE_ID_MULTIHASH_CODE);
        assert_eq!(mh.size(), SAMPLE_ID_SIZE as u8);

        let restored = SampleId::try_from(cid).unwrap();
        assert_eq!(original, restored);
    }

    /// Samples can be built inside the square and are rejected outside of it.
    #[test]
    fn index_calculation() {
        let eds = generate_dummy_eds(8, AppVersion::V2);

        for (row, col) in [(0, 0), (7, 6), (7, 7)] {
            Sample::new(row, col, AxisType::Row, &eds).unwrap();
        }

        let column_err = Sample::new(7, 8, AxisType::Row, &eds).unwrap_err();
        assert!(matches!(column_err, Error::EdsIndexOutOfRange(7, 8)));

        let row_err = Sample::new(12, 3, AxisType::Row, &eds).unwrap_err();
        assert!(matches!(row_err, Error::EdsIndexOutOfRange(12, 3)));
    }

    /// The encoded id has exactly the size mandated by the spec.
    #[test]
    fn sample_id_size() {
        // Size MUST be 12 by the spec.
        assert_eq!(SAMPLE_ID_SIZE, 12);

        let id = SampleId::new(0, 4, 1).unwrap();
        let mut encoded = BytesMut::new();
        id.encode(&mut encoded);
        assert_eq!(encoded.len(), SAMPLE_ID_SIZE);
    }

    /// A hand-written CID byte string decodes into the expected id.
    #[test]
    fn from_buffer() {
        let bytes = [
            0x01, // CIDv1
            0x90, 0xF0, 0x01, // CID codec = 7810
            0x91, 0xF0, 0x01, // multihash code = 7811
            0x0C, // len = SAMPLE_ID_SIZE = 12
            0, 0, 0, 0, 0, 0, 0, 64, // block height = 64
            0, 7, // row index = 7
            0, 5, // column index = 5
        ];

        let cid = CidGeneric::<SAMPLE_ID_SIZE>::read_bytes(bytes.as_ref()).unwrap();
        assert_eq!(cid.codec(), SAMPLE_ID_CODEC);

        let multihash = cid.hash();
        assert_eq!(multihash.code(), SAMPLE_ID_MULTIHASH_CODE);
        assert_eq!(multihash.size(), SAMPLE_ID_SIZE as u8);

        let id = SampleId::try_from(cid).unwrap();
        assert_eq!(id.block_height(), 64);
        assert_eq!(id.row_index(), 7);
        assert_eq!(id.column_index(), 5);
    }

    /// A CID with an unexpected multihash code is rejected.
    #[test]
    fn multihash_invalid_code() {
        let mh = Multihash::<SAMPLE_ID_SIZE>::wrap(888, &[0; SAMPLE_ID_SIZE]).unwrap();
        let cid = CidGeneric::<SAMPLE_ID_SIZE>::new_v1(SAMPLE_ID_CODEC, mh);

        let err = SampleId::try_from(cid).unwrap_err();
        assert_eq!(
            err,
            CidError::InvalidMultihashCode(888, SAMPLE_ID_MULTIHASH_CODE)
        );
    }

    /// A CID with an unexpected codec is rejected.
    #[test]
    fn cid_invalid_codec() {
        let mh = Multihash::<SAMPLE_ID_SIZE>::wrap(SAMPLE_ID_MULTIHASH_CODE, &[0; SAMPLE_ID_SIZE])
            .unwrap();
        let cid = CidGeneric::<SAMPLE_ID_SIZE>::new_v1(4321, mh);

        let err = SampleId::try_from(cid).unwrap_err();
        assert!(matches!(err, CidError::InvalidCidCodec(4321)));
    }

    /// A sample built from a random position still verifies after an
    /// encode / decode round trip.
    #[test]
    fn test_roundtrip_verify() {
        for _ in 0..5 {
            let square_width = 2 << (rand::random::<usize>() % 8);
            let eds = generate_dummy_eds(square_width, AppVersion::V2);
            let dah = DataAvailabilityHeader::from_eds(&eds);

            let row_index = rand::random::<u16>() % eds.square_width();
            let col_index = rand::random::<u16>() % eds.square_width();
            let proof_type = match rand::random() {
                true => AxisType::Row,
                false => AxisType::Col,
            };

            let id = SampleId::new(row_index, col_index, 1).unwrap();
            let sample = Sample::new(row_index, col_index, proof_type, &eds).unwrap();

            let mut encoded = BytesMut::new();
            sample.encode(&mut encoded);

            let decoded = Sample::decode(id, &encoded).unwrap();
            decoded.verify(id, &dah).unwrap();
        }
    }
}