1use blockstore::block::CidError;
11use bytes::{Buf, BufMut, BytesMut};
12use celestia_proto::shwap::{Sample as RawSample, Share as RawShare};
13use cid::CidGeneric;
14use multihash::Multihash;
15use nmt_rs::nmt_proof::NamespaceProof as NmtNamespaceProof;
16use prost::Message;
17use serde::Serialize;
18
19use crate::eds::{AxisType, ExtendedDataSquare};
20use crate::nmt::NamespaceProof;
21use crate::row::{RowId, ROW_ID_SIZE};
22use crate::{bail_validation, DataAvailabilityHeader, Error, Result, Share};
23
/// Length in bytes of an encoded [`SampleId`]: an 8-byte block height and a
/// 2-byte row index (together the row id), followed by a 2-byte column index.
const SAMPLE_ID_SIZE: usize = std::mem::size_of::<u64>() + 2 * std::mem::size_of::<u16>();
/// Multihash code under which an encoded [`SampleId`] is wrapped inside a CID.
pub const SAMPLE_ID_MULTIHASH_CODE: u64 = 0x7811;
/// CID codec identifying a CID that carries a [`SampleId`].
pub const SAMPLE_ID_CODEC: u64 = 0x7810;
30
31#[derive(Debug, PartialEq, Clone, Copy)]
36pub struct SampleId {
37 row_id: RowId,
38 column_index: u16,
39}
40
41#[derive(Clone, Debug, Serialize)]
43#[serde(into = "RawSample")]
44pub struct Sample {
45 pub proof_type: AxisType,
47 pub share: Share,
49 pub proof: NamespaceProof,
51}
52
53impl Sample {
54 pub fn new(
97 row_index: u16,
98 column_index: u16,
99 proof_type: AxisType,
100 eds: &ExtendedDataSquare,
101 ) -> Result<Self> {
102 let share = eds.share(row_index, column_index)?.clone();
103
104 let range_proof = match proof_type {
105 AxisType::Row => eds
106 .row_nmt(row_index)?
107 .build_range_proof(usize::from(column_index)..usize::from(column_index) + 1),
108 AxisType::Col => eds
109 .column_nmt(column_index)?
110 .build_range_proof(usize::from(row_index)..usize::from(row_index) + 1),
111 };
112
113 let proof = NmtNamespaceProof::PresenceProof {
114 proof: range_proof,
115 ignore_max_ns: true,
116 };
117
118 Ok(Sample {
119 share,
120 proof: proof.into(),
121 proof_type,
122 })
123 }
124
125 pub fn verify(&self, id: SampleId, dah: &DataAvailabilityHeader) -> Result<()> {
127 let root = match self.proof_type {
128 AxisType::Row => dah
129 .row_root(id.row_index())
130 .ok_or(Error::EdsIndexOutOfRange(id.row_index(), 0))?,
131 AxisType::Col => dah
132 .column_root(id.column_index())
133 .ok_or(Error::EdsIndexOutOfRange(0, id.column_index()))?,
134 };
135
136 self.proof
137 .verify_range(&root, &[&self.share], *self.share.namespace())
138 .map_err(Error::RangeProofError)
139 }
140
141 pub fn encode(&self, bytes: &mut BytesMut) {
143 let raw = RawSample::from(self.clone());
144
145 bytes.reserve(raw.encoded_len());
146 raw.encode(bytes).expect("capacity reserved");
147 }
148
149 pub fn decode(id: SampleId, buffer: &[u8]) -> Result<Self> {
156 let raw = RawSample::decode(buffer)?;
157 Self::from_raw(id, raw)
158 }
159
160 pub fn from_raw(id: SampleId, sample: RawSample) -> Result<Self> {
167 let Some(proof) = sample.proof else {
168 return Err(Error::MissingProof);
169 };
170
171 let proof: NamespaceProof = proof.try_into()?;
172 let proof_type = AxisType::try_from(sample.proof_type)?;
173
174 if proof.is_of_absence() {
175 return Err(Error::WrongProofType);
176 }
177
178 let Some(share) = sample.share else {
179 bail_validation!("missing share");
180 };
181 let Some(square_size) = proof.total_leaves() else {
182 bail_validation!("proof must be for single leaf");
183 };
184
185 let row_index = id.row_index() as usize;
186 let col_index = id.column_index() as usize;
187 let share = if row_index < square_size / 2 && col_index < square_size / 2 {
188 Share::from_raw(&share.data)?
189 } else {
190 Share::parity(&share.data)?
191 };
192
193 Ok(Sample {
194 proof_type,
195 share,
196 proof,
197 })
198 }
199}
200
201impl From<Sample> for RawSample {
202 fn from(sample: Sample) -> RawSample {
203 RawSample {
204 share: Some(RawShare {
205 data: sample.share.to_vec(),
206 }),
207 proof: Some(sample.proof.into()),
208 proof_type: sample.proof_type as i32,
209 }
210 }
211}
212
213impl SampleId {
214 pub fn new(row_index: u16, column_index: u16, block_height: u64) -> Result<Self> {
234 if block_height == 0 {
235 return Err(Error::ZeroBlockHeight);
236 }
237
238 Ok(SampleId {
239 row_id: RowId::new(row_index, block_height)?,
240 column_index,
241 })
242 }
243
244 pub fn block_height(&self) -> u64 {
246 self.row_id.block_height()
247 }
248
249 pub fn row_index(&self) -> u16 {
253 self.row_id.index()
254 }
255
256 pub fn column_index(&self) -> u16 {
260 self.column_index
261 }
262
263 fn encode(&self, bytes: &mut BytesMut) {
264 bytes.reserve(SAMPLE_ID_SIZE);
265 self.row_id.encode(bytes);
266 bytes.put_u16(self.column_index);
267 }
268
269 fn decode(buffer: &[u8]) -> Result<Self, CidError> {
270 if buffer.len() != SAMPLE_ID_SIZE {
271 return Err(CidError::InvalidMultihashLength(buffer.len()));
272 }
273
274 let (row_bytes, mut col_bytes) = buffer.split_at(ROW_ID_SIZE);
275 let row_id = RowId::decode(row_bytes)?;
276 let column_index = col_bytes.get_u16();
277
278 Ok(SampleId {
279 row_id,
280 column_index,
281 })
282 }
283}
284
285impl<const S: usize> TryFrom<CidGeneric<S>> for SampleId {
286 type Error = CidError;
287
288 fn try_from(cid: CidGeneric<S>) -> Result<Self, Self::Error> {
289 let codec = cid.codec();
290 if codec != SAMPLE_ID_CODEC {
291 return Err(CidError::InvalidCidCodec(codec));
292 }
293
294 let hash = cid.hash();
295
296 let size = hash.size() as usize;
297 if size != SAMPLE_ID_SIZE {
298 return Err(CidError::InvalidMultihashLength(size));
299 }
300
301 let code = hash.code();
302 if code != SAMPLE_ID_MULTIHASH_CODE {
303 return Err(CidError::InvalidMultihashCode(
304 code,
305 SAMPLE_ID_MULTIHASH_CODE,
306 ));
307 }
308
309 SampleId::decode(hash.digest())
310 }
311}
312
313impl From<SampleId> for CidGeneric<SAMPLE_ID_SIZE> {
314 fn from(sample_id: SampleId) -> Self {
315 let mut bytes = BytesMut::with_capacity(SAMPLE_ID_SIZE);
316 sample_id.encode(&mut bytes);
318
319 let mh = Multihash::wrap(SAMPLE_ID_MULTIHASH_CODE, &bytes[..]).unwrap();
320
321 CidGeneric::new_v1(SAMPLE_ID_CODEC, mh)
322 }
323}
324
#[cfg(test)]
mod tests {
    use super::*;
    use crate::consts::appconsts::AppVersion;
    use crate::test_utils::generate_dummy_eds;

    /// A `SampleId` survives conversion to a CID and back, and the CID's multihash
    /// carries the sample-id multihash code and digest size.
    #[test]
    fn round_trip() {
        let sample_id = SampleId::new(5, 10, 100).unwrap();
        let cid = CidGeneric::from(sample_id);

        let multihash = cid.hash();
        assert_eq!(multihash.code(), SAMPLE_ID_MULTIHASH_CODE);
        assert_eq!(multihash.size(), SAMPLE_ID_SIZE as u8);

        let deserialized_sample_id = SampleId::try_from(cid).unwrap();
        assert_eq!(sample_id, deserialized_sample_id);
    }

    /// `Sample::new` accepts indices up to 7 on this dummy square and rejects larger
    /// ones with `EdsIndexOutOfRange` carrying the requested (row, column).
    #[test]
    fn index_calculation() {
        let eds = generate_dummy_eds(8, AppVersion::V2);

        Sample::new(0, 0, AxisType::Row, &eds).unwrap();
        Sample::new(7, 6, AxisType::Row, &eds).unwrap();
        Sample::new(7, 7, AxisType::Row, &eds).unwrap();

        // Column index past the last valid one.
        let sample_err = Sample::new(7, 8, AxisType::Row, &eds).unwrap_err();
        assert!(matches!(sample_err, Error::EdsIndexOutOfRange(7, 8)));

        // Row index past the last valid one.
        let sample_err = Sample::new(12, 3, AxisType::Row, &eds).unwrap_err();
        assert!(matches!(sample_err, Error::EdsIndexOutOfRange(12, 3)));
    }

    /// `SampleId::encode` produces exactly `SAMPLE_ID_SIZE` bytes.
    #[test]
    fn sample_id_size() {
        // Pin the wire size so an accidental change is caught.
        assert_eq!(SAMPLE_ID_SIZE, 12);

        let sample_id = SampleId::new(0, 4, 1).unwrap();
        let mut bytes = BytesMut::new();
        sample_id.encode(&mut bytes);
        assert_eq!(bytes.len(), SAMPLE_ID_SIZE);
    }

    /// Parses a hand-built binary CID and checks every decoded field.
    #[test]
    fn from_buffer() {
        let bytes = [
            0x01, // CID version 1
            0x90, 0xF0, 0x01, // CID codec: varint of 0x7810 (SAMPLE_ID_CODEC)
            0x91, 0xF0, 0x01, // multihash code: varint of 0x7811 (SAMPLE_ID_MULTIHASH_CODE)
            0x0C, // multihash digest length: 12 (SAMPLE_ID_SIZE)
            0, 0, 0, 0, 0, 0, 0, 64, // block height 64, big-endian u64
            0, 7, // row index 7
            0, 5, // column index 5
        ];

        let cid = CidGeneric::<SAMPLE_ID_SIZE>::read_bytes(bytes.as_ref()).unwrap();
        assert_eq!(cid.codec(), SAMPLE_ID_CODEC);
        let mh = cid.hash();
        assert_eq!(mh.code(), SAMPLE_ID_MULTIHASH_CODE);
        assert_eq!(mh.size(), SAMPLE_ID_SIZE as u8);
        let sample_id = SampleId::try_from(cid).unwrap();
        assert_eq!(sample_id.block_height(), 64);
        assert_eq!(sample_id.row_index(), 7);
        assert_eq!(sample_id.column_index(), 5);
    }

    /// A CID with the right codec but a foreign multihash code is rejected.
    #[test]
    fn multihash_invalid_code() {
        let multihash = Multihash::<SAMPLE_ID_SIZE>::wrap(888, &[0; SAMPLE_ID_SIZE]).unwrap();
        let cid = CidGeneric::<SAMPLE_ID_SIZE>::new_v1(SAMPLE_ID_CODEC, multihash);
        let code_err = SampleId::try_from(cid).unwrap_err();
        assert_eq!(
            code_err,
            CidError::InvalidMultihashCode(888, SAMPLE_ID_MULTIHASH_CODE)
        );
    }

    /// A CID with the right multihash but a foreign codec is rejected.
    #[test]
    fn cid_invalid_codec() {
        let multihash =
            Multihash::<SAMPLE_ID_SIZE>::wrap(SAMPLE_ID_MULTIHASH_CODE, &[0; SAMPLE_ID_SIZE])
                .unwrap();
        let cid = CidGeneric::<SAMPLE_ID_SIZE>::new_v1(4321, multihash);
        let codec_err = SampleId::try_from(cid).unwrap_err();
        assert!(matches!(codec_err, CidError::InvalidCidCodec(4321)));
    }

    /// Samples built from random squares, positions, and proof axes survive a
    /// protobuf encode/decode round trip and verify against the square's header.
    #[test]
    fn test_roundtrip_verify() {
        for _ in 0..5 {
            // Size argument is a power of two between 2 (2 << 0) and 256 (2 << 7).
            let eds = generate_dummy_eds(2 << (rand::random::<usize>() % 8), AppVersion::V2);
            let dah = DataAvailabilityHeader::from_eds(&eds);

            let row_index = rand::random::<u16>() % eds.square_width();
            let col_index = rand::random::<u16>() % eds.square_width();
            let proof_type = if rand::random() {
                AxisType::Row
            } else {
                AxisType::Col
            };

            let id = SampleId::new(row_index, col_index, 1).unwrap();
            let sample = Sample::new(row_index, col_index, proof_type, &eds).unwrap();

            let mut buf = BytesMut::new();
            sample.encode(&mut buf);
            let decoded = Sample::decode(id, &buf).unwrap();

            decoded.verify(id, &dah).unwrap();
        }
    }
}