1use blockstore::block::CidError;
11use bytes::{Buf, BufMut, BytesMut};
12use celestia_proto::shwap::{Sample as RawSample, Share as RawShare};
13use cid::CidGeneric;
14use multihash::Multihash;
15use nmt_rs::nmt_proof::NamespaceProof as NmtNamespaceProof;
16use prost::Message;
17use serde::Serialize;
18
19use crate::eds::{AxisType, ExtendedDataSquare};
20use crate::nmt::NamespaceProof;
21use crate::row::{RowId, ROW_ID_SIZE};
22use crate::{bail_validation, DataAvailabilityHeader, Error, Result, Share};
23
/// Length in bytes of a binary-encoded [`SampleId`]: the encoded row id followed by
/// the column index as a big-endian `u16`.
const SAMPLE_ID_SIZE: usize = 12;
/// Multihash code used in the CID of a [`SampleId`].
pub const SAMPLE_ID_MULTIHASH_CODE: u64 = 0x7811;
/// CID codec used for a [`SampleId`].
pub const SAMPLE_ID_CODEC: u64 = 0x7810;
30
31#[derive(Debug, PartialEq, Clone, Copy)]
36pub struct SampleId {
37 row_id: RowId,
38 column_index: u16,
39}
40
41#[derive(Clone, Debug, Serialize)]
43#[serde(into = "RawSample")]
44pub struct Sample {
45 pub proof_type: AxisType,
47 pub share: Share,
49 pub proof: NamespaceProof,
51}
52
53impl Sample {
54 pub fn new(
97 row_index: u16,
98 column_index: u16,
99 proof_type: AxisType,
100 eds: &ExtendedDataSquare,
101 ) -> Result<Self> {
102 let share = eds.share(row_index, column_index)?.clone();
103
104 let range_proof = match proof_type {
105 AxisType::Row => eds
106 .row_nmt(row_index)?
107 .build_range_proof(usize::from(column_index)..usize::from(column_index) + 1),
108 AxisType::Col => eds
109 .column_nmt(column_index)?
110 .build_range_proof(usize::from(row_index)..usize::from(row_index) + 1),
111 };
112
113 let proof = NmtNamespaceProof::PresenceProof {
114 proof: range_proof,
115 ignore_max_ns: true,
116 };
117
118 Ok(Sample {
119 share,
120 proof: proof.into(),
121 proof_type,
122 })
123 }
124
125 pub fn verify(&self, id: SampleId, dah: &DataAvailabilityHeader) -> Result<()> {
127 let root = match self.proof_type {
128 AxisType::Row => dah
129 .row_root(id.row_index())
130 .ok_or(Error::EdsIndexOutOfRange(id.row_index(), 0))?,
131 AxisType::Col => dah
132 .column_root(id.column_index())
133 .ok_or(Error::EdsIndexOutOfRange(0, id.column_index()))?,
134 };
135
136 self.proof
137 .verify_range(&root, &[&self.share], *self.share.namespace())
138 .map_err(Error::RangeProofError)
139 }
140
141 pub fn encode(&self, bytes: &mut BytesMut) {
143 let raw = RawSample::from(self.clone());
144
145 bytes.reserve(raw.encoded_len());
146 raw.encode(bytes).expect("capacity reserved");
147 }
148
149 pub fn decode(id: SampleId, buffer: &[u8]) -> Result<Self> {
156 let raw = RawSample::decode(buffer)?;
157 Self::from_raw(id, raw)
158 }
159
160 pub fn from_raw(id: SampleId, sample: RawSample) -> Result<Self> {
167 let Some(proof) = sample.proof else {
168 return Err(Error::MissingProof);
169 };
170
171 let proof: NamespaceProof = proof.try_into()?;
172 let proof_type = AxisType::try_from(sample.proof_type)?;
173
174 if proof.is_of_absence() {
175 return Err(Error::WrongProofType);
176 }
177
178 let Some(share) = sample.share else {
179 bail_validation!("missing share");
180 };
181 let Some(square_size) = proof.total_leaves() else {
182 bail_validation!("proof must be for single leaf");
183 };
184
185 let row_index = id.row_index() as usize;
186 let col_index = id.column_index() as usize;
187 let share = if row_index < square_size / 2 && col_index < square_size / 2 {
188 Share::from_raw(&share.data)?
189 } else {
190 Share::parity(&share.data)?
191 };
192
193 Ok(Sample {
194 proof_type,
195 share,
196 proof,
197 })
198 }
199}
200
201impl From<Sample> for RawSample {
202 fn from(sample: Sample) -> RawSample {
203 RawSample {
204 share: Some(RawShare {
205 data: sample.share.to_vec(),
206 }),
207 proof: Some(sample.proof.into()),
208 proof_type: sample.proof_type as i32,
209 }
210 }
211}
212
213impl SampleId {
214 pub fn new(row_index: u16, column_index: u16, block_height: u64) -> Result<Self> {
234 if block_height == 0 {
235 return Err(Error::ZeroBlockHeight);
236 }
237
238 Ok(SampleId {
239 row_id: RowId::new(row_index, block_height)?,
240 column_index,
241 })
242 }
243
244 pub fn block_height(&self) -> u64 {
246 self.row_id.block_height()
247 }
248
249 pub fn row_index(&self) -> u16 {
253 self.row_id.index()
254 }
255
256 pub fn column_index(&self) -> u16 {
260 self.column_index
261 }
262
263 fn encode(&self, bytes: &mut BytesMut) {
264 bytes.reserve(SAMPLE_ID_SIZE);
265 self.row_id.encode(bytes);
266 bytes.put_u16(self.column_index);
267 }
268
269 fn decode(buffer: &[u8]) -> Result<Self, CidError> {
270 if buffer.len() != SAMPLE_ID_SIZE {
271 return Err(CidError::InvalidMultihashLength(buffer.len()));
272 }
273
274 let (row_bytes, mut col_bytes) = buffer.split_at(ROW_ID_SIZE);
275 let row_id = RowId::decode(row_bytes)?;
276 let column_index = col_bytes.get_u16();
277
278 Ok(SampleId {
279 row_id,
280 column_index,
281 })
282 }
283}
284
285impl<const S: usize> TryFrom<&CidGeneric<S>> for SampleId {
286 type Error = CidError;
287
288 fn try_from(cid: &CidGeneric<S>) -> Result<Self, Self::Error> {
289 let codec = cid.codec();
290 if codec != SAMPLE_ID_CODEC {
291 return Err(CidError::InvalidCidCodec(codec));
292 }
293
294 let hash = cid.hash();
295
296 let size = hash.size() as usize;
297 if size != SAMPLE_ID_SIZE {
298 return Err(CidError::InvalidMultihashLength(size));
299 }
300
301 let code = hash.code();
302 if code != SAMPLE_ID_MULTIHASH_CODE {
303 return Err(CidError::InvalidMultihashCode(
304 code,
305 SAMPLE_ID_MULTIHASH_CODE,
306 ));
307 }
308
309 SampleId::decode(hash.digest())
310 }
311}
312
313impl<const S: usize> TryFrom<&mut CidGeneric<S>> for SampleId {
314 type Error = CidError;
315
316 fn try_from(cid: &mut CidGeneric<S>) -> Result<Self, Self::Error> {
317 Self::try_from(&*cid)
318 }
319}
320
321impl<const S: usize> TryFrom<CidGeneric<S>> for SampleId {
322 type Error = CidError;
323
324 fn try_from(cid: CidGeneric<S>) -> Result<Self, Self::Error> {
325 Self::try_from(&cid)
326 }
327}
328
329impl From<SampleId> for CidGeneric<SAMPLE_ID_SIZE> {
330 fn from(sample_id: SampleId) -> Self {
331 let mut bytes = BytesMut::with_capacity(SAMPLE_ID_SIZE);
332 sample_id.encode(&mut bytes);
334
335 let mh = Multihash::wrap(SAMPLE_ID_MULTIHASH_CODE, &bytes[..]).unwrap();
336
337 CidGeneric::new_v1(SAMPLE_ID_CODEC, mh)
338 }
339}
340
#[cfg(test)]
mod tests {
    use super::*;
    use crate::consts::appconsts::AppVersion;
    use crate::test_utils::generate_dummy_eds;

    #[test]
    fn round_trip() {
        let original = SampleId::new(5, 10, 100).unwrap();
        let cid = CidGeneric::from(original);

        let mh = cid.hash();
        assert_eq!(mh.code(), SAMPLE_ID_MULTIHASH_CODE);
        assert_eq!(mh.size(), SAMPLE_ID_SIZE as u8);

        let decoded = SampleId::try_from(cid).unwrap();
        assert_eq!(original, decoded);
    }

    #[test]
    fn index_calculation() {
        let eds = generate_dummy_eds(8, AppVersion::V2);

        // Coordinates inside the square.
        for (row, col) in [(0, 0), (7, 6), (7, 7)] {
            Sample::new(row, col, AxisType::Row, &eds).unwrap();
        }

        // Coordinates outside the square report the offending position.
        let err = Sample::new(7, 8, AxisType::Row, &eds).unwrap_err();
        assert!(matches!(err, Error::EdsIndexOutOfRange(7, 8)));

        let err = Sample::new(12, 3, AxisType::Row, &eds).unwrap_err();
        assert!(matches!(err, Error::EdsIndexOutOfRange(12, 3)));
    }

    #[test]
    fn sample_id_size() {
        assert_eq!(SAMPLE_ID_SIZE, 12);

        let id = SampleId::new(0, 4, 1).unwrap();
        let mut encoded = BytesMut::new();
        id.encode(&mut encoded);
        assert_eq!(encoded.len(), SAMPLE_ID_SIZE);
    }

    #[test]
    fn from_buffer() {
        let raw_cid = [
            0x01, // CID version 1
            0x90, 0xF0, 0x01, // codec 0x7810 (varint)
            0x91, 0xF0, 0x01, // multihash code 0x7811 (varint)
            0x0C, // digest length: 12 bytes
            0, 0, 0, 0, 0, 0, 0, 64, // block height (u64, big-endian)
            0, 7, // row index (u16, big-endian)
            0, 5, // column index (u16, big-endian)
        ];

        let cid = CidGeneric::<SAMPLE_ID_SIZE>::read_bytes(raw_cid.as_ref()).unwrap();
        assert_eq!(cid.codec(), SAMPLE_ID_CODEC);

        let mh = cid.hash();
        assert_eq!(mh.code(), SAMPLE_ID_MULTIHASH_CODE);
        assert_eq!(mh.size(), SAMPLE_ID_SIZE as u8);

        let id = SampleId::try_from(cid).unwrap();
        assert_eq!(id.block_height(), 64);
        assert_eq!(id.row_index(), 7);
        assert_eq!(id.column_index(), 5);
    }

    #[test]
    fn multihash_invalid_code() {
        let bad_mh = Multihash::<SAMPLE_ID_SIZE>::wrap(888, &[0; SAMPLE_ID_SIZE]).unwrap();
        let cid = CidGeneric::<SAMPLE_ID_SIZE>::new_v1(SAMPLE_ID_CODEC, bad_mh);

        let err = SampleId::try_from(cid).unwrap_err();
        assert_eq!(
            err,
            CidError::InvalidMultihashCode(888, SAMPLE_ID_MULTIHASH_CODE)
        );
    }

    #[test]
    fn cid_invalid_codec() {
        let mh = Multihash::<SAMPLE_ID_SIZE>::wrap(SAMPLE_ID_MULTIHASH_CODE, &[0; SAMPLE_ID_SIZE])
            .unwrap();
        let cid = CidGeneric::<SAMPLE_ID_SIZE>::new_v1(4321, mh);

        let err = SampleId::try_from(cid).unwrap_err();
        assert!(matches!(err, CidError::InvalidCidCodec(4321)));
    }

    #[test]
    fn test_roundtrip_verify() {
        for _ in 0..5 {
            let width_shift = rand::random::<usize>() % 8;
            let eds = generate_dummy_eds(2 << width_shift, AppVersion::V2);
            let dah = DataAvailabilityHeader::from_eds(&eds);

            let width = eds.square_width();
            let row_index = rand::random::<u16>() % width;
            let col_index = rand::random::<u16>() % width;
            let proof_type = if rand::random() {
                AxisType::Row
            } else {
                AxisType::Col
            };

            let id = SampleId::new(row_index, col_index, 1).unwrap();
            let sample = Sample::new(row_index, col_index, proof_type, &eds).unwrap();

            let mut wire = BytesMut::new();
            sample.encode(&mut wire);
            let decoded = Sample::decode(id, &wire).unwrap();

            decoded.verify(id, &dah).unwrap();
        }
    }
}