celestia_types/
sample.rs

1//! Types related to samples.
2//!
//! A sample in Celestia is understood as a single [`Share`] located at an
//! index in a particular [`row`] of the [`ExtendedDataSquare`].
5//!
6//! [`row`]: crate::row
7//! [`Share`]: crate::Share
8//! [`ExtendedDataSquare`]: crate::eds::ExtendedDataSquare
9
10use blockstore::block::CidError;
11use bytes::{Buf, BufMut, BytesMut};
12use celestia_proto::shwap::{Sample as RawSample, Share as RawShare};
13use cid::CidGeneric;
14use multihash::Multihash;
15use nmt_rs::nmt_proof::NamespaceProof as NmtNamespaceProof;
16use prost::Message;
17use serde::Serialize;
18
19use crate::eds::{AxisType, ExtendedDataSquare};
20use crate::nmt::NamespaceProof;
21use crate::row::{RowId, ROW_ID_SIZE};
22use crate::{bail_validation, DataAvailabilityHeader, Error, Result, Share};
23
/// Number of bytes needed to represent [`SampleId`] in `multihash`.
///
/// The encoded id consists of the encoded [`RowId`] (`ROW_ID_SIZE` bytes)
/// followed by the 2-byte column index.
const SAMPLE_ID_SIZE: usize = 12;
/// The code of the [`SampleId`] hashing algorithm in `multihash`.
///
/// A `Cid` is only accepted as a [`SampleId`] when its multihash carries this code.
pub const SAMPLE_ID_MULTIHASH_CODE: u64 = 0x7811;
/// The id of codec used for the [`SampleId`] in `Cid`s.
///
/// A `Cid` is only accepted as a [`SampleId`] when it carries this codec.
pub const SAMPLE_ID_CODEC: u64 = 0x7810;
30
/// Identifies a particular [`Share`] located in the [`ExtendedDataSquare`].
///
/// The position is described by a [`RowId`] (block height and row index)
/// together with the index of the column within that row.
///
/// [`Share`]: crate::Share
/// [`ExtendedDataSquare`]: crate::eds::ExtendedDataSquare
#[derive(Debug, PartialEq, Clone, Copy)]
pub struct SampleId {
    /// Block height and row index of the sampled share.
    row_id: RowId,
    /// Index of the column the sampled share is located in.
    column_index: u16,
}
40
/// Represents a sample: a single [`Share`] together with the proof of its inclusion.
///
/// Serialization with `serde` goes through the conversion into `RawSample`.
#[derive(Clone, Debug, Serialize)]
#[serde(into = "RawSample")]
pub struct Sample {
    /// Indication whether proving was done row or column-wise
    pub proof_type: AxisType,
    /// Share that is being sampled
    pub share: Share,
    /// Proof of the inclusion of the share
    pub proof: NamespaceProof,
}
52
53impl Sample {
54    /// Create a new [`Sample`] for the given index of the [`ExtendedDataSquare`] in a block.
55    ///
56    /// `row_index` and `column_index` specifies the [`Share`] position in EDS.
57    /// `proof_type` determines whether proof of inclusion of the [`Share`] should be
58    /// constructed for its row or column.
59    ///
60    /// # Errors
61    ///
62    /// This function will return an error, if:
63    ///
64    /// - `row_index`/`column_index` falls outside the provided [`ExtendedDataSquare`].
65    /// - [`ExtendedDataSquare`] is incorrect (either data shares don't have their namespace
66    ///   prefixed, or [`Share`]s aren't namespace ordered)
67    /// - Block height is zero
68    ///
69    /// # Example
70    ///
71    /// ```no_run
72    /// use celestia_types::AxisType;
73    /// use celestia_types::sample::{Sample, SampleId};
74    /// # use celestia_types::{ExtendedDataSquare, ExtendedHeader};
75    /// #
76    /// # fn get_extended_data_square(height: u64) -> ExtendedDataSquare {
77    /// #    unimplemented!()
78    /// # }
79    /// #
80    /// # fn get_extended_header(height: u64) -> ExtendedHeader {
81    /// #    unimplemented!()
82    /// # }
83    ///
84    /// let block_height = 15;
85    /// let eds = get_extended_data_square(block_height);
86    /// let header = get_extended_header(block_height);
87    ///
88    /// let sample_id = SampleId::new(2, 3, block_height).unwrap();
89    /// let sample = Sample::new(2, 3, AxisType::Row, &eds).unwrap();
90    ///
91    /// sample.verify(sample_id, &header.dah).unwrap();
92    /// ```
93    ///
94    /// [`Share`]: crate::Share
95    /// [`ExtendedDataSquare`]: crate::eds::ExtendedDataSquare
96    pub fn new(
97        row_index: u16,
98        column_index: u16,
99        proof_type: AxisType,
100        eds: &ExtendedDataSquare,
101    ) -> Result<Self> {
102        let share = eds.share(row_index, column_index)?.clone();
103
104        let range_proof = match proof_type {
105            AxisType::Row => eds
106                .row_nmt(row_index)?
107                .build_range_proof(usize::from(column_index)..usize::from(column_index) + 1),
108            AxisType::Col => eds
109                .column_nmt(column_index)?
110                .build_range_proof(usize::from(row_index)..usize::from(row_index) + 1),
111        };
112
113        let proof = NmtNamespaceProof::PresenceProof {
114            proof: range_proof,
115            ignore_max_ns: true,
116        };
117
118        Ok(Sample {
119            share,
120            proof: proof.into(),
121            proof_type,
122        })
123    }
124
125    /// verify sample with root hash from ExtendedHeader
126    pub fn verify(&self, id: SampleId, dah: &DataAvailabilityHeader) -> Result<()> {
127        let root = match self.proof_type {
128            AxisType::Row => dah
129                .row_root(id.row_index())
130                .ok_or(Error::EdsIndexOutOfRange(id.row_index(), 0))?,
131            AxisType::Col => dah
132                .column_root(id.column_index())
133                .ok_or(Error::EdsIndexOutOfRange(0, id.column_index()))?,
134        };
135
136        self.proof
137            .verify_range(&root, &[&self.share], *self.share.namespace())
138            .map_err(Error::RangeProofError)
139    }
140
141    /// Encode Sample into the raw binary representation.
142    pub fn encode(&self, bytes: &mut BytesMut) {
143        let raw = RawSample::from(self.clone());
144
145        bytes.reserve(raw.encoded_len());
146        raw.encode(bytes).expect("capacity reserved");
147    }
148
149    /// Decode Sample from the binary representation.
150    ///
151    /// # Errors
152    ///
153    /// This function will return an error if protobuf deserialization
154    /// fails and propagate errors from [`Sample::from_raw`].
155    pub fn decode(id: SampleId, buffer: &[u8]) -> Result<Self> {
156        let raw = RawSample::decode(buffer)?;
157        Self::from_raw(id, raw)
158    }
159
160    /// Recover Sample from it's raw representation.
161    ///
162    /// # Errors
163    ///
164    /// This function will return error if proof is missing or invalid shares are not in
165    /// the expected namespace, and will propagate errors from [`Share`] construction.
166    pub fn from_raw(id: SampleId, sample: RawSample) -> Result<Self> {
167        let Some(proof) = sample.proof else {
168            return Err(Error::MissingProof);
169        };
170
171        let proof: NamespaceProof = proof.try_into()?;
172        let proof_type = AxisType::try_from(sample.proof_type)?;
173
174        if proof.is_of_absence() {
175            return Err(Error::WrongProofType);
176        }
177
178        let Some(share) = sample.share else {
179            bail_validation!("missing share");
180        };
181        let Some(square_size) = proof.total_leaves() else {
182            bail_validation!("proof must be for single leaf");
183        };
184
185        let row_index = id.row_index() as usize;
186        let col_index = id.column_index() as usize;
187        let share = if row_index < square_size / 2 && col_index < square_size / 2 {
188            Share::from_raw(&share.data)?
189        } else {
190            Share::parity(&share.data)?
191        };
192
193        Ok(Sample {
194            proof_type,
195            share,
196            proof,
197        })
198    }
199}
200
201impl From<Sample> for RawSample {
202    fn from(sample: Sample) -> RawSample {
203        RawSample {
204            share: Some(RawShare {
205                data: sample.share.to_vec(),
206            }),
207            proof: Some(sample.proof.into()),
208            proof_type: sample.proof_type as i32,
209        }
210    }
211}
212
213impl SampleId {
214    /// Create a new [`SampleId`] for the given `row_index` and `column_index` of the
215    /// [`ExtendedDataSquare`] in a block.
216    ///
217    /// # Errors
218    ///
219    /// This function will return an error if the block height is zero.
220    ///
221    /// # Example
222    ///
223    /// ```no_run
224    /// use celestia_types::sample::SampleId;
225    ///
226    /// // Consider a 64th share of EDS with block height of 15
227    /// let header_height = 15;
228    /// SampleId::new(2, 1, header_height).unwrap();
229    /// ```
230    ///
231    /// [`Share`]: crate::Share
232    /// [`ExtendedDataSquare`]: crate::eds::ExtendedDataSquare
233    pub fn new(row_index: u16, column_index: u16, block_height: u64) -> Result<Self> {
234        if block_height == 0 {
235            return Err(Error::ZeroBlockHeight);
236        }
237
238        Ok(SampleId {
239            row_id: RowId::new(row_index, block_height)?,
240            column_index,
241        })
242    }
243
244    /// A height of the block which contains the sample.
245    pub fn block_height(&self) -> u64 {
246        self.row_id.block_height()
247    }
248
249    /// Row index of the [`ExtendedDataSquare`] that sample is located on.
250    ///
251    /// [`ExtendedDataSquare`]: crate::eds::ExtendedDataSquare
252    pub fn row_index(&self) -> u16 {
253        self.row_id.index()
254    }
255
256    /// Column index of the [`ExtendedDataSquare`] that sample is located on.
257    ///
258    /// [`ExtendedDataSquare`]: crate::eds::ExtendedDataSquare
259    pub fn column_index(&self) -> u16 {
260        self.column_index
261    }
262
263    fn encode(&self, bytes: &mut BytesMut) {
264        bytes.reserve(SAMPLE_ID_SIZE);
265        self.row_id.encode(bytes);
266        bytes.put_u16(self.column_index);
267    }
268
269    fn decode(buffer: &[u8]) -> Result<Self, CidError> {
270        if buffer.len() != SAMPLE_ID_SIZE {
271            return Err(CidError::InvalidMultihashLength(buffer.len()));
272        }
273
274        let (row_bytes, mut col_bytes) = buffer.split_at(ROW_ID_SIZE);
275        let row_id = RowId::decode(row_bytes)?;
276        let column_index = col_bytes.get_u16();
277
278        Ok(SampleId {
279            row_id,
280            column_index,
281        })
282    }
283}
284
285impl<const S: usize> TryFrom<CidGeneric<S>> for SampleId {
286    type Error = CidError;
287
288    fn try_from(cid: CidGeneric<S>) -> Result<Self, Self::Error> {
289        let codec = cid.codec();
290        if codec != SAMPLE_ID_CODEC {
291            return Err(CidError::InvalidCidCodec(codec));
292        }
293
294        let hash = cid.hash();
295
296        let size = hash.size() as usize;
297        if size != SAMPLE_ID_SIZE {
298            return Err(CidError::InvalidMultihashLength(size));
299        }
300
301        let code = hash.code();
302        if code != SAMPLE_ID_MULTIHASH_CODE {
303            return Err(CidError::InvalidMultihashCode(
304                code,
305                SAMPLE_ID_MULTIHASH_CODE,
306            ));
307        }
308
309        SampleId::decode(hash.digest())
310    }
311}
312
313impl From<SampleId> for CidGeneric<SAMPLE_ID_SIZE> {
314    fn from(sample_id: SampleId) -> Self {
315        let mut bytes = BytesMut::with_capacity(SAMPLE_ID_SIZE);
316        // length is correct, so unwrap is safe
317        sample_id.encode(&mut bytes);
318
319        let mh = Multihash::wrap(SAMPLE_ID_MULTIHASH_CODE, &bytes[..]).unwrap();
320
321        CidGeneric::new_v1(SAMPLE_ID_CODEC, mh)
322    }
323}
324
#[cfg(test)]
mod tests {
    use super::*;
    use crate::consts::appconsts::AppVersion;
    use crate::test_utils::generate_dummy_eds;

    /// A `SampleId` survives the conversion into a CID and back.
    #[test]
    fn round_trip() {
        let original = SampleId::new(5, 10, 100).unwrap();
        let cid = CidGeneric::from(original);

        let mh = cid.hash();
        assert_eq!(mh.code(), SAMPLE_ID_MULTIHASH_CODE);
        assert_eq!(mh.size(), SAMPLE_ID_SIZE as u8);

        let restored = SampleId::try_from(cid).unwrap();
        assert_eq!(original, restored);
    }

    /// Indices inside the square are accepted, the ones outside are rejected.
    #[test]
    fn index_calculation() {
        let eds = generate_dummy_eds(8, AppVersion::V2);

        for (row, column) in [(0, 0), (7, 6), (7, 7)] {
            Sample::new(row, column, AxisType::Row, &eds).unwrap();
        }

        let column_err = Sample::new(7, 8, AxisType::Row, &eds).unwrap_err();
        assert!(matches!(column_err, Error::EdsIndexOutOfRange(7, 8)));

        let row_err = Sample::new(12, 3, AxisType::Row, &eds).unwrap_err();
        assert!(matches!(row_err, Error::EdsIndexOutOfRange(12, 3)));
    }

    /// The encoded id has exactly the advertised size.
    #[test]
    fn sample_id_size() {
        // Size MUST be 12 by the spec.
        assert_eq!(SAMPLE_ID_SIZE, 12);

        let mut encoded = BytesMut::new();
        SampleId::new(0, 4, 1).unwrap().encode(&mut encoded);
        assert_eq!(encoded.len(), SAMPLE_ID_SIZE);
    }

    /// A hand-crafted binary CID decodes into the expected `SampleId`.
    #[test]
    fn from_buffer() {
        let raw_cid: &[u8] = &[
            0x01, // CIDv1
            0x90, 0xF0, 0x01, // CID codec = 7810
            0x91, 0xF0, 0x01, // multihash code = 7811
            0x0C, // len = SAMPLE_ID_SIZE = 12
            0, 0, 0, 0, 0, 0, 0, 64, // block height = 64
            0, 7, // row index = 7
            0, 5, // column index = 5
        ];

        let cid = CidGeneric::<SAMPLE_ID_SIZE>::read_bytes(raw_cid).unwrap();
        assert_eq!(cid.codec(), SAMPLE_ID_CODEC);

        let multihash = cid.hash();
        assert_eq!(multihash.code(), SAMPLE_ID_MULTIHASH_CODE);
        assert_eq!(multihash.size(), SAMPLE_ID_SIZE as u8);

        let id = SampleId::try_from(cid).unwrap();
        assert_eq!(id.block_height(), 64);
        assert_eq!(id.row_index(), 7);
        assert_eq!(id.column_index(), 5);
    }

    /// A CID with a foreign multihash code is rejected.
    #[test]
    fn multihash_invalid_code() {
        let digest = [0; SAMPLE_ID_SIZE];
        let mh = Multihash::<SAMPLE_ID_SIZE>::wrap(888, &digest).unwrap();
        let cid = CidGeneric::<SAMPLE_ID_SIZE>::new_v1(SAMPLE_ID_CODEC, mh);

        let err = SampleId::try_from(cid).unwrap_err();
        assert_eq!(
            err,
            CidError::InvalidMultihashCode(888, SAMPLE_ID_MULTIHASH_CODE)
        );
    }

    /// A CID with a foreign codec is rejected.
    #[test]
    fn cid_invalid_codec() {
        let digest = [0; SAMPLE_ID_SIZE];
        let mh = Multihash::<SAMPLE_ID_SIZE>::wrap(SAMPLE_ID_MULTIHASH_CODE, &digest).unwrap();
        let cid = CidGeneric::<SAMPLE_ID_SIZE>::new_v1(4321, mh);

        let err = SampleId::try_from(cid).unwrap_err();
        assert!(matches!(err, CidError::InvalidCidCodec(4321)));
    }

    /// A sample of a random share survives encoding and decoding and still
    /// verifies against the data availability header.
    #[test]
    fn test_roundtrip_verify() {
        for _ in 0..5 {
            let shift = rand::random::<usize>() % 8;
            let eds = generate_dummy_eds(2 << shift, AppVersion::V2);
            let dah = DataAvailabilityHeader::from_eds(&eds);

            let row = rand::random::<u16>() % eds.square_width();
            let column = rand::random::<u16>() % eds.square_width();
            let axis = match rand::random::<bool>() {
                true => AxisType::Row,
                false => AxisType::Col,
            };

            let id = SampleId::new(row, column, 1).unwrap();
            let sample = Sample::new(row, column, axis, &eds).unwrap();

            let mut encoded = BytesMut::new();
            sample.encode(&mut encoded);

            let decoded = Sample::decode(id, &encoded).unwrap();
            decoded.verify(id, &dah).unwrap();
        }
    }
}