1use blockstore::block::CidError;
11use bytes::{Buf, BufMut, BytesMut};
12use celestia_proto::shwap::Share as RawShare;
13use cid::CidGeneric;
14use multihash::Multihash;
15use nmt_rs::nmt_proof::NamespaceProof as NmtNamespaceProof;
16use prost::Message;
17use serde::Serialize;
18
19use crate::eds::{AxisType, ExtendedDataSquare};
20use crate::nmt::NamespaceProof;
21use crate::row::{ROW_ID_SIZE, RowId};
22use crate::{DataAvailabilityHeader, Error, Result, Share, bail_validation};
23
24pub use celestia_proto::shwap::Sample as RawSample;
25
/// Number of bytes in the binary form of a [`SampleId`]: the encoded [`RowId`]
/// followed by the column index as a 2-byte integer.
const SAMPLE_ID_SIZE: usize = 12;
/// Multihash code used for the multihash that carries a binary-encoded [`SampleId`].
pub const SAMPLE_ID_MULTIHASH_CODE: u64 = 0x7811;
/// CID codec that marks a CID as referring to a [`SampleId`].
pub const SAMPLE_ID_CODEC: u64 = 0x7810;
32
/// Identifier of a single sample (share) in a block's extended data square.
///
/// Made of a [`RowId`] (which carries the block height and row index) and the
/// index of the column within that row.
#[derive(Debug, PartialEq, Clone, Copy)]
pub struct SampleId {
    // Block height and row index of the sample.
    row_id: RowId,
    // Column index of the sample within its row.
    column_index: u16,
}
42
/// A single share of the extended data square together with an inclusion proof
/// against either its row root or its column root.
///
/// Serialized with serde by first converting into [`RawSample`].
#[derive(Clone, Debug, Serialize)]
#[serde(into = "RawSample")]
pub struct Sample {
    /// Axis (row or column) along which `proof` was built.
    pub proof_type: AxisType,
    /// The sampled share.
    pub share: Share,
    /// Proof that `share` is included under the row or column root selected by `proof_type`.
    pub proof: NamespaceProof,
}
54
55impl Sample {
56 pub fn new(
99 row_index: u16,
100 column_index: u16,
101 proof_type: AxisType,
102 eds: &ExtendedDataSquare,
103 ) -> Result<Self> {
104 let share = eds.share(row_index, column_index)?.clone();
105
106 let range_proof = match proof_type {
107 AxisType::Row => eds
108 .row_nmt(row_index)?
109 .build_range_proof(usize::from(column_index)..usize::from(column_index) + 1),
110 AxisType::Col => eds
111 .column_nmt(column_index)?
112 .build_range_proof(usize::from(row_index)..usize::from(row_index) + 1),
113 };
114
115 let proof = NmtNamespaceProof::PresenceProof {
116 proof: range_proof,
117 ignore_max_ns: true,
118 };
119
120 Ok(Sample {
121 share,
122 proof: proof.into(),
123 proof_type,
124 })
125 }
126
127 pub fn verify(&self, id: SampleId, dah: &DataAvailabilityHeader) -> Result<()> {
129 let root = match self.proof_type {
130 AxisType::Row => dah
131 .row_root(id.row_index())
132 .ok_or(Error::EdsIndexOutOfRange(id.row_index(), 0))?,
133 AxisType::Col => dah
134 .column_root(id.column_index())
135 .ok_or(Error::EdsIndexOutOfRange(0, id.column_index()))?,
136 };
137
138 self.proof
139 .verify_range(&root, &[&self.share], *self.share.namespace())
140 .map_err(Error::RangeProofError)
141 }
142
143 pub fn encode(&self, bytes: &mut BytesMut) {
145 let raw = RawSample::from(self.clone());
146
147 bytes.reserve(raw.encoded_len());
148 raw.encode(bytes).expect("capacity reserved");
149 }
150
151 pub fn decode(id: SampleId, buffer: &[u8]) -> Result<Self> {
158 let raw = RawSample::decode(buffer)?;
159 Self::from_raw(id, raw)
160 }
161
162 pub fn from_raw(id: SampleId, sample: RawSample) -> Result<Self> {
169 let Some(proof) = sample.proof else {
170 return Err(Error::MissingProof);
171 };
172
173 let proof: NamespaceProof = proof.try_into()?;
174 let proof_type = AxisType::try_from(sample.proof_type)?;
175
176 if proof.is_of_absence() {
177 return Err(Error::WrongProofType);
178 }
179
180 let Some(share) = sample.share else {
181 bail_validation!("missing share");
182 };
183 let Some(square_size) = proof.total_leaves() else {
184 bail_validation!("proof must be for single leaf");
185 };
186
187 let row_index = id.row_index() as usize;
188 let col_index = id.column_index() as usize;
189 let share = if row_index < square_size / 2 && col_index < square_size / 2 {
190 Share::from_raw(&share.data)?
191 } else {
192 Share::parity(&share.data)?
193 };
194
195 Ok(Sample {
196 proof_type,
197 share,
198 proof,
199 })
200 }
201}
202
203impl From<Sample> for RawSample {
204 fn from(sample: Sample) -> RawSample {
205 RawSample {
206 share: Some(RawShare {
207 data: sample.share.to_vec(),
208 }),
209 proof: Some(sample.proof.into()),
210 proof_type: sample.proof_type as i32,
211 }
212 }
213}
214
215impl SampleId {
216 pub fn new(row_index: u16, column_index: u16, block_height: u64) -> Result<Self> {
236 if block_height == 0 {
237 return Err(Error::ZeroBlockHeight);
238 }
239
240 Ok(SampleId {
241 row_id: RowId::new(row_index, block_height)?,
242 column_index,
243 })
244 }
245
246 pub fn block_height(&self) -> u64 {
248 self.row_id.block_height()
249 }
250
251 pub fn row_index(&self) -> u16 {
255 self.row_id.index()
256 }
257
258 pub fn column_index(&self) -> u16 {
262 self.column_index
263 }
264
265 fn encode(&self, bytes: &mut BytesMut) {
266 bytes.reserve(SAMPLE_ID_SIZE);
267 self.row_id.encode(bytes);
268 bytes.put_u16(self.column_index);
269 }
270
271 fn decode(buffer: &[u8]) -> Result<Self, CidError> {
272 if buffer.len() != SAMPLE_ID_SIZE {
273 return Err(CidError::InvalidMultihashLength(buffer.len()));
274 }
275
276 let (row_bytes, mut col_bytes) = buffer.split_at(ROW_ID_SIZE);
277 let row_id = RowId::decode(row_bytes)?;
278 let column_index = col_bytes.get_u16();
279
280 Ok(SampleId {
281 row_id,
282 column_index,
283 })
284 }
285}
286
287impl<const S: usize> TryFrom<&CidGeneric<S>> for SampleId {
288 type Error = CidError;
289
290 fn try_from(cid: &CidGeneric<S>) -> Result<Self, Self::Error> {
291 let codec = cid.codec();
292 if codec != SAMPLE_ID_CODEC {
293 return Err(CidError::InvalidCidCodec(codec));
294 }
295
296 let hash = cid.hash();
297
298 let size = hash.size() as usize;
299 if size != SAMPLE_ID_SIZE {
300 return Err(CidError::InvalidMultihashLength(size));
301 }
302
303 let code = hash.code();
304 if code != SAMPLE_ID_MULTIHASH_CODE {
305 return Err(CidError::InvalidMultihashCode(
306 code,
307 SAMPLE_ID_MULTIHASH_CODE,
308 ));
309 }
310
311 SampleId::decode(hash.digest())
312 }
313}
314
315impl<const S: usize> TryFrom<&mut CidGeneric<S>> for SampleId {
316 type Error = CidError;
317
318 fn try_from(cid: &mut CidGeneric<S>) -> Result<Self, Self::Error> {
319 Self::try_from(&*cid)
320 }
321}
322
323impl<const S: usize> TryFrom<CidGeneric<S>> for SampleId {
324 type Error = CidError;
325
326 fn try_from(cid: CidGeneric<S>) -> Result<Self, Self::Error> {
327 Self::try_from(&cid)
328 }
329}
330
331impl From<SampleId> for CidGeneric<SAMPLE_ID_SIZE> {
332 fn from(sample_id: SampleId) -> Self {
333 let mut bytes = BytesMut::with_capacity(SAMPLE_ID_SIZE);
334 sample_id.encode(&mut bytes);
336
337 let mh = Multihash::wrap(SAMPLE_ID_MULTIHASH_CODE, &bytes[..]).unwrap();
338
339 CidGeneric::new_v1(SAMPLE_ID_CODEC, mh)
340 }
341}
342
#[cfg(test)]
mod tests {
    use super::*;
    use crate::consts::appconsts::AppVersion;
    use crate::test_utils::generate_dummy_eds;

    // Converting a SampleId into a CID and back yields the same id, and the CID's
    // multihash carries the expected code and digest size.
    #[test]
    fn round_trip() {
        let sample_id = SampleId::new(5, 10, 100).unwrap();
        let cid = CidGeneric::from(sample_id);

        let multihash = cid.hash();
        assert_eq!(multihash.code(), SAMPLE_ID_MULTIHASH_CODE);
        assert_eq!(multihash.size(), SAMPLE_ID_SIZE as u8);

        let deserialized_sample_id = SampleId::try_from(cid).unwrap();
        assert_eq!(sample_id, deserialized_sample_id);
    }

    // `Sample::new` accepts coordinates inside the square and reports the offending
    // (row, column) pair for coordinates outside it. (7, 7) succeeds and (7, 8) fails,
    // so the square generated here is 8 shares wide.
    #[test]
    fn index_calculation() {
        let eds = generate_dummy_eds(8, AppVersion::V2);

        Sample::new(0, 0, AxisType::Row, &eds).unwrap();
        Sample::new(7, 6, AxisType::Row, &eds).unwrap();
        Sample::new(7, 7, AxisType::Row, &eds).unwrap();

        let sample_err = Sample::new(7, 8, AxisType::Row, &eds).unwrap_err();
        assert!(matches!(sample_err, Error::EdsIndexOutOfRange(7, 8)));

        let sample_err = Sample::new(12, 3, AxisType::Row, &eds).unwrap_err();
        assert!(matches!(sample_err, Error::EdsIndexOutOfRange(12, 3)));
    }

    // The binary form of a SampleId is exactly SAMPLE_ID_SIZE (12) bytes.
    #[test]
    fn sample_id_size() {
        assert_eq!(SAMPLE_ID_SIZE, 12);

        let sample_id = SampleId::new(0, 4, 1).unwrap();
        let mut bytes = BytesMut::new();
        sample_id.encode(&mut bytes);
        assert_eq!(bytes.len(), SAMPLE_ID_SIZE);
    }

    // Parses a hand-built CID with this byte layout:
    // - 0x01: CID version 1
    // - 0x90 0xF0 0x01: codec 0x7810 (SAMPLE_ID_CODEC) as an unsigned varint
    // - 0x91 0xF0 0x01: multihash code 0x7811 (SAMPLE_ID_MULTIHASH_CODE) as an unsigned varint
    // - 0x0C: digest length, 12 bytes
    // - digest: block height 64 (8 bytes, big-endian), then row index 7 and
    //   column index 5 (2 bytes each, big-endian)
    #[test]
    fn from_buffer() {
        let bytes = [
            0x01, 0x90, 0xF0, 0x01, 0x91, 0xF0, 0x01, 0x0C, 0, 0, 0, 0, 0, 0, 0, 64, 0, 7, 0, 5, ];

        let cid = CidGeneric::<SAMPLE_ID_SIZE>::read_bytes(bytes.as_ref()).unwrap();
        assert_eq!(cid.codec(), SAMPLE_ID_CODEC);
        let mh = cid.hash();
        assert_eq!(mh.code(), SAMPLE_ID_MULTIHASH_CODE);
        assert_eq!(mh.size(), SAMPLE_ID_SIZE as u8);
        let sample_id = SampleId::try_from(cid).unwrap();
        assert_eq!(sample_id.block_height(), 64);
        assert_eq!(sample_id.row_index(), 7);
        assert_eq!(sample_id.column_index(), 5);
    }

    // A CID with the right codec but an unexpected multihash code (888) is rejected
    // with InvalidMultihashCode(found, expected).
    #[test]
    fn multihash_invalid_code() {
        let multihash = Multihash::<SAMPLE_ID_SIZE>::wrap(888, &[0; SAMPLE_ID_SIZE]).unwrap();
        let cid = CidGeneric::<SAMPLE_ID_SIZE>::new_v1(SAMPLE_ID_CODEC, multihash);
        let code_err = SampleId::try_from(cid).unwrap_err();
        assert_eq!(
            code_err,
            CidError::InvalidMultihashCode(888, SAMPLE_ID_MULTIHASH_CODE)
        );
    }

    // A CID with an unexpected codec (4321) is rejected with InvalidCidCodec.
    #[test]
    fn cid_invalid_codec() {
        let multihash =
            Multihash::<SAMPLE_ID_SIZE>::wrap(SAMPLE_ID_MULTIHASH_CODE, &[0; SAMPLE_ID_SIZE])
                .unwrap();
        let cid = CidGeneric::<SAMPLE_ID_SIZE>::new_v1(4321, multihash);
        let codec_err = SampleId::try_from(cid).unwrap_err();
        assert!(matches!(codec_err, CidError::InvalidCidCodec(4321)));
    }

    // Each iteration picks a random size argument (a power of two from 2 to 256), random
    // coordinates and a random proof axis. The sample built from the EDS must survive a
    // protobuf encode/decode round trip and still verify against the EDS's
    // DataAvailabilityHeader.
    #[test]
    fn test_roundtrip_verify() {
        for _ in 0..5 {
            let eds = generate_dummy_eds(2 << (rand::random::<usize>() % 8), AppVersion::V2);
            let dah = DataAvailabilityHeader::from_eds(&eds);

            let row_index = rand::random::<u16>() % eds.square_width();
            let col_index = rand::random::<u16>() % eds.square_width();
            let proof_type = if rand::random() {
                AxisType::Row
            } else {
                AxisType::Col
            };

            let id = SampleId::new(row_index, col_index, 1).unwrap();
            let sample = Sample::new(row_index, col_index, proof_type, &eds).unwrap();

            let mut buf = BytesMut::new();
            sample.encode(&mut buf);
            let decoded = Sample::decode(id, &buf).unwrap();

            decoded.verify(id, &dah).unwrap();
        }
    }
}