1use blockstore::block::CidError;
11use bytes::{Buf, BufMut, BytesMut};
12use celestia_proto::shwap::Share as RawShare;
13use cid::CidGeneric;
14use multihash::Multihash;
15use nmt_rs::nmt_proof::NamespaceProof as NmtNamespaceProof;
16use prost::Message;
17use serde::Serialize;
18
19use crate::eds::{AxisType, ExtendedDataSquare};
20use crate::nmt::NamespaceProof;
21use crate::row::{ROW_ID_SIZE, RowId};
22use crate::{DataAvailabilityHeader, Error, Result, Share, bail_validation};
23
24pub use celestia_proto::shwap::Sample as RawSample;
25
/// Size in bytes of an encoded [`SampleId`]: its encoded [`RowId`] followed by the
/// big-endian `u16` column index.
const SAMPLE_ID_SIZE: usize = 12;
/// Multihash code of the multihash embedded in a [`SampleId`] CID; the multihash digest is
/// the encoded [`SampleId`] itself.
pub const SAMPLE_ID_MULTIHASH_CODE: u64 = 0x7811;
/// CID codec identifying a [`SampleId`] CID.
pub const SAMPLE_ID_CODEC: u64 = 0x7810;
32
/// Identifier of a single [`Sample`]: the share at `column_index` within the row identified
/// by `row_id` (which carries both the row index and the block height).
///
/// Encodes to exactly [`SAMPLE_ID_SIZE`] bytes and converts to/from a CID with codec
/// [`SAMPLE_ID_CODEC`].
#[derive(Debug, PartialEq, Clone, Copy)]
pub struct SampleId {
    // Row of the sample together with the block height (see `RowId`).
    row_id: RowId,
    // Column index of the sample within the extended data square.
    column_index: u16,
}
42
/// A single share of an [`ExtendedDataSquare`] together with an NMT inclusion proof against
/// either its row root or its column root.
///
/// Serializes through its protobuf form, [`RawSample`].
#[derive(Clone, Debug, Serialize)]
#[serde(into = "RawSample")]
pub struct Sample {
    /// Axis (row or column) whose root `proof` is verified against.
    pub proof_type: AxisType,
    /// The sampled share.
    pub share: Share,
    /// Inclusion proof of `share` in the root of the axis selected by `proof_type`.
    pub proof: NamespaceProof,
}
54
55impl Sample {
56 pub fn new(
99 row_index: u16,
100 column_index: u16,
101 proof_type: AxisType,
102 eds: &ExtendedDataSquare,
103 ) -> Result<Self> {
104 let share = eds.share(row_index, column_index)?.clone();
105
106 let range_proof = match proof_type {
107 AxisType::Row => eds
108 .row_nmt(row_index)?
109 .build_range_proof(usize::from(column_index)..usize::from(column_index) + 1),
110 AxisType::Col => eds
111 .column_nmt(column_index)?
112 .build_range_proof(usize::from(row_index)..usize::from(row_index) + 1),
113 };
114
115 let proof = NmtNamespaceProof::PresenceProof {
116 proof: range_proof,
117 ignore_max_ns: true,
118 };
119
120 Ok(Sample {
121 share,
122 proof: proof.into(),
123 proof_type,
124 })
125 }
126
127 pub fn verify(&self, id: SampleId, dah: &DataAvailabilityHeader) -> Result<()> {
129 let root = match self.proof_type {
130 AxisType::Row => dah
131 .row_root(id.row_index())
132 .ok_or(Error::EdsIndexOutOfRange(id.row_index(), 0))?,
133 AxisType::Col => dah
134 .column_root(id.column_index())
135 .ok_or(Error::EdsIndexOutOfRange(0, id.column_index()))?,
136 };
137
138 self.proof
139 .verify_range(&root, &[&self.share], *self.share.namespace())
140 .map_err(Error::RangeProofError)
141 }
142
143 pub fn encode(&self, bytes: &mut BytesMut) {
145 let raw = RawSample::from(self.clone());
146
147 bytes.reserve(raw.encoded_len());
148 raw.encode(bytes).expect("capacity reserved");
149 }
150
151 pub fn decode(id: SampleId, buffer: &[u8]) -> Result<Self> {
158 let raw = RawSample::decode(buffer)?;
159 Self::from_raw(id, raw)
160 }
161
162 pub fn from_raw(id: SampleId, sample: RawSample) -> Result<Self> {
169 let Some(proof) = sample.proof else {
170 return Err(Error::MissingProof);
171 };
172
173 let proof: NamespaceProof = proof.try_into()?;
174 let proof_type = AxisType::try_from(sample.proof_type)?;
175
176 if proof.is_of_absence() {
177 return Err(Error::WrongProofType);
178 }
179
180 let Some(share) = sample.share else {
181 bail_validation!("missing share");
182 };
183 let Some(square_size) = proof.total_leaves() else {
184 bail_validation!("proof must be for single leaf");
185 };
186
187 let row_index = id.row_index() as usize;
188 let col_index = id.column_index() as usize;
189 let share = if row_index < square_size / 2 && col_index < square_size / 2 {
190 Share::from_raw(&share.data)?
191 } else {
192 Share::parity(&share.data)?
193 };
194
195 Ok(Sample {
196 proof_type,
197 share,
198 proof,
199 })
200 }
201}
202
203impl From<Sample> for RawSample {
204 fn from(sample: Sample) -> RawSample {
205 RawSample {
206 share: Some(RawShare {
207 data: sample.share.to_vec(),
208 }),
209 proof: Some(sample.proof.into()),
210 proof_type: sample.proof_type as i32,
211 }
212 }
213}
214
215impl SampleId {
216 pub fn new(row_index: u16, column_index: u16, block_height: u64) -> Result<Self> {
236 if block_height == 0 {
237 return Err(Error::ZeroBlockHeight);
238 }
239
240 Ok(SampleId {
241 row_id: RowId::new(row_index, block_height)?,
242 column_index,
243 })
244 }
245
246 pub fn block_height(&self) -> u64 {
248 self.row_id.block_height()
249 }
250
251 pub fn row_index(&self) -> u16 {
255 self.row_id.index()
256 }
257
258 pub fn column_index(&self) -> u16 {
262 self.column_index
263 }
264
265 pub fn encode(&self, bytes: &mut BytesMut) {
267 bytes.reserve(SAMPLE_ID_SIZE);
268 self.row_id.encode(bytes);
269 bytes.put_u16(self.column_index);
270 }
271
272 pub fn decode(buffer: &[u8]) -> Result<Self> {
274 if buffer.len() != SAMPLE_ID_SIZE {
275 return Err(Error::InvalidLength(buffer.len(), SAMPLE_ID_SIZE));
276 }
277
278 let (row_bytes, mut col_bytes) = buffer.split_at(ROW_ID_SIZE);
279 let row_id = RowId::decode(row_bytes)?;
280 let column_index = col_bytes.get_u16();
281
282 Ok(SampleId {
283 row_id,
284 column_index,
285 })
286 }
287}
288
289impl<const S: usize> TryFrom<&CidGeneric<S>> for SampleId {
290 type Error = CidError;
291
292 fn try_from(cid: &CidGeneric<S>) -> Result<Self, Self::Error> {
293 let codec = cid.codec();
294 if codec != SAMPLE_ID_CODEC {
295 return Err(CidError::InvalidCidCodec(codec));
296 }
297
298 let hash = cid.hash();
299
300 let size = hash.size() as usize;
301 if size != SAMPLE_ID_SIZE {
302 return Err(CidError::InvalidMultihashLength(size));
303 }
304
305 let code = hash.code();
306 if code != SAMPLE_ID_MULTIHASH_CODE {
307 return Err(CidError::InvalidMultihashCode(
308 code,
309 SAMPLE_ID_MULTIHASH_CODE,
310 ));
311 }
312
313 SampleId::decode(hash.digest()).map_err(|e| CidError::InvalidCid(e.to_string()))
314 }
315}
316
317impl<const S: usize> TryFrom<&mut CidGeneric<S>> for SampleId {
318 type Error = CidError;
319
320 fn try_from(cid: &mut CidGeneric<S>) -> Result<Self, Self::Error> {
321 Self::try_from(&*cid)
322 }
323}
324
325impl<const S: usize> TryFrom<CidGeneric<S>> for SampleId {
326 type Error = CidError;
327
328 fn try_from(cid: CidGeneric<S>) -> Result<Self, Self::Error> {
329 Self::try_from(&cid)
330 }
331}
332
333impl From<SampleId> for CidGeneric<SAMPLE_ID_SIZE> {
334 fn from(sample_id: SampleId) -> Self {
335 let mut bytes = BytesMut::with_capacity(SAMPLE_ID_SIZE);
336 sample_id.encode(&mut bytes);
338
339 let mh = Multihash::wrap(SAMPLE_ID_MULTIHASH_CODE, &bytes[..]).unwrap();
340
341 CidGeneric::new_v1(SAMPLE_ID_CODEC, mh)
342 }
343}
344
#[cfg(test)]
mod tests {
    use super::*;
    use crate::consts::appconsts::AppVersion;
    use crate::test_utils::generate_dummy_eds;

    // SampleId -> CID -> SampleId must be lossless, and the CID must carry the sample
    // multihash code and a SAMPLE_ID_SIZE digest.
    #[test]
    fn round_trip() {
        let sample_id = SampleId::new(5, 10, 100).unwrap();
        let cid = CidGeneric::from(sample_id);

        let multihash = cid.hash();
        assert_eq!(multihash.code(), SAMPLE_ID_MULTIHASH_CODE);
        assert_eq!(multihash.size(), SAMPLE_ID_SIZE as u8);

        let deserialized_sample_id = SampleId::try_from(cid).unwrap();
        assert_eq!(sample_id, deserialized_sample_id);
    }

    // Sampling up to the last row/column succeeds. Going one past the edge of this
    // 8-wide square is rejected, and the error reports the offending indices.
    #[test]
    fn index_calculation() {
        let eds = generate_dummy_eds(8, AppVersion::V2);

        Sample::new(0, 0, AxisType::Row, &eds).unwrap();
        Sample::new(7, 6, AxisType::Row, &eds).unwrap();
        Sample::new(7, 7, AxisType::Row, &eds).unwrap();

        let sample_err = Sample::new(7, 8, AxisType::Row, &eds).unwrap_err();
        assert!(matches!(sample_err, Error::EdsIndexOutOfRange(7, 8)));

        let sample_err = Sample::new(12, 3, AxisType::Row, &eds).unwrap_err();
        assert!(matches!(sample_err, Error::EdsIndexOutOfRange(12, 3)));
    }

    // The encoded id length must match the SAMPLE_ID_SIZE constant.
    #[test]
    fn sample_id_size() {
        assert_eq!(SAMPLE_ID_SIZE, 12);

        let sample_id = SampleId::new(0, 4, 1).unwrap();
        let mut bytes = BytesMut::new();
        sample_id.encode(&mut bytes);
        assert_eq!(bytes.len(), SAMPLE_ID_SIZE);
    }

    // Parses a hand-built binary CID and checks every decoded field.
    #[test]
    fn from_buffer() {
        let bytes = [
            0x01, // CID version 1
            0x90, 0xF0, 0x01, // codec 0x7810 (unsigned varint)
            0x91, 0xF0, 0x01, // multihash code 0x7811 (unsigned varint)
            0x0C, // digest length: 12 bytes
            0, 0, 0, 0, 0, 0, 0, 64, // block height 64 (big-endian u64)
            0, 7, // row index 7
            0, 5, // column index 5
        ];

        let cid = CidGeneric::<SAMPLE_ID_SIZE>::read_bytes(bytes.as_ref()).unwrap();
        assert_eq!(cid.codec(), SAMPLE_ID_CODEC);
        let mh = cid.hash();
        assert_eq!(mh.code(), SAMPLE_ID_MULTIHASH_CODE);
        assert_eq!(mh.size(), SAMPLE_ID_SIZE as u8);
        let sample_id = SampleId::try_from(cid).unwrap();
        assert_eq!(sample_id.block_height(), 64);
        assert_eq!(sample_id.row_index(), 7);
        assert_eq!(sample_id.column_index(), 5);
    }

    // A CID with the right codec but a foreign multihash code must be rejected.
    #[test]
    fn multihash_invalid_code() {
        let multihash = Multihash::<SAMPLE_ID_SIZE>::wrap(888, &[0; SAMPLE_ID_SIZE]).unwrap();
        let cid = CidGeneric::<SAMPLE_ID_SIZE>::new_v1(SAMPLE_ID_CODEC, multihash);
        let code_err = SampleId::try_from(cid).unwrap_err();
        assert_eq!(
            code_err,
            CidError::InvalidMultihashCode(888, SAMPLE_ID_MULTIHASH_CODE)
        );
    }

    // A CID with the right multihash but a foreign codec must be rejected.
    #[test]
    fn cid_invalid_codec() {
        let multihash =
            Multihash::<SAMPLE_ID_SIZE>::wrap(SAMPLE_ID_MULTIHASH_CODE, &[0; SAMPLE_ID_SIZE])
                .unwrap();
        let cid = CidGeneric::<SAMPLE_ID_SIZE>::new_v1(4321, multihash);
        let codec_err = SampleId::try_from(cid).unwrap_err();
        assert!(matches!(codec_err, CidError::InvalidCidCodec(4321)));
    }

    // End to end: build, encode, decode and verify a random sample. The square widths
    // tried are 2, 4, ..., 256, and both proof axes are exercised.
    #[test]
    fn test_roundtrip_verify() {
        for _ in 0..5 {
            let eds = generate_dummy_eds(2 << (rand::random::<usize>() % 8), AppVersion::V2);
            let dah = DataAvailabilityHeader::from_eds(&eds);

            let row_index = rand::random::<u16>() % eds.square_width();
            let col_index = rand::random::<u16>() % eds.square_width();
            let proof_type = if rand::random() {
                AxisType::Row
            } else {
                AxisType::Col
            };

            let id = SampleId::new(row_index, col_index, 1).unwrap();
            let sample = Sample::new(row_index, col_index, proof_type, &eds).unwrap();

            let mut buf = BytesMut::new();
            sample.encode(&mut buf);
            let decoded = Sample::decode(id, &buf).unwrap();

            decoded.verify(id, &dah).unwrap();
        }
    }
}