1use std::iter;
10
11use blockstore::block::CidError;
12use bytes::{Buf, BufMut, BytesMut};
13use celestia_proto::shwap::{row::HalfSide as RawHalfSide, Row as RawRow, Share as RawShare};
14use cid::CidGeneric;
15use multihash::Multihash;
16use prost::Message;
17use serde::Serialize;
18
19use crate::consts::appconsts::SHARE_SIZE;
20use crate::eds::ExtendedDataSquare;
21use crate::nmt::{Nmt, NmtExt};
22use crate::{DataAvailabilityHeader, Error, Result, Share};
23
/// Number of bytes in an encoded [`EdsId`]: a single `u64` block height.
const EDS_ID_SIZE: usize = std::mem::size_of::<u64>();
/// Number of bytes in an encoded [`RowId`]: the [`EdsId`] followed by a `u16` row index.
pub(crate) const ROW_ID_SIZE: usize = EDS_ID_SIZE + std::mem::size_of::<u16>();
/// Multihash code carried by CIDs that identify a row.
pub const ROW_ID_MULTIHASH_CODE: u64 = 0x7801;
/// CID codec carried by CIDs that identify a row.
pub const ROW_ID_CODEC: u64 = 0x7800;
32
/// Identifies an extended data square by the height of the block it belongs to.
///
/// `Eq` and `Hash` are derived so that identifiers can be used as map/set keys.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
struct EdsId {
    height: u64,
}

/// Identifies one row of an extended data square: the square (by block height)
/// plus the index of the row inside it.
///
/// `Eq` and `Hash` are derived so that identifiers can be used as map/set keys.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
pub struct RowId {
    eds_id: EdsId,
    index: u16,
}
50
51#[derive(Clone, Debug, Serialize)]
53#[serde(into = "RawRow")]
54pub struct Row {
55 pub shares: Vec<Share>,
57}
58
59impl Row {
60 pub fn new(index: u16, eds: &ExtendedDataSquare) -> Result<Self> {
62 let shares = eds.row(index)?;
63
64 Ok(Row { shares })
65 }
66
67 pub fn verify(&self, id: RowId, dah: &DataAvailabilityHeader) -> Result<()> {
69 let row = id.index;
70 let mut tree = Nmt::default();
71
72 for share in &self.shares {
73 tree.push_leaf(share.as_ref(), *share.namespace())
74 .map_err(Error::Nmt)?;
75 }
76
77 let Some(root) = dah.row_root(row) else {
78 return Err(Error::EdsIndexOutOfRange(row, 0));
79 };
80
81 if tree.root().hash() != root.hash() {
82 return Err(Error::RootMismatch);
83 }
84
85 Ok(())
86 }
87
88 pub fn encode(&self, bytes: &mut BytesMut) {
90 let raw = RawRow::from(self.clone());
91
92 bytes.reserve(raw.encoded_len());
93 raw.encode(bytes).expect("capacity reserved");
94 }
95
96 pub fn decode(id: RowId, buffer: &[u8]) -> Result<Self> {
103 let raw = RawRow::decode(buffer)?;
104 Self::from_raw(id, raw)
105 }
106
107 pub fn from_raw(id: RowId, row: RawRow) -> Result<Self> {
114 let data_shares = row.shares_half.len();
115
116 let shares = match row.half_side() {
117 RawHalfSide::Left => {
118 let mut shares: Vec<_> = row.shares_half.into_iter().map(|shr| shr.data).collect();
120 shares.resize(shares.len() * 2, vec![0; SHARE_SIZE]);
121 leopard_codec::encode(&mut shares, data_shares)?;
122 shares
123 }
124 RawHalfSide::Right => {
125 let mut shares: Vec<_> = iter::repeat(vec![])
127 .take(data_shares)
128 .chain(row.shares_half.into_iter().map(|shr| shr.data))
129 .collect();
130 leopard_codec::reconstruct(&mut shares, data_shares)?;
131 shares
132 }
133 };
134
135 let row_index = id.index() as usize;
136 let shares = shares
137 .into_iter()
138 .enumerate()
139 .map(|(col_index, shr)| {
140 if row_index < data_shares && col_index < data_shares {
141 Share::from_raw(&shr)
142 } else {
143 Share::parity(&shr)
144 }
145 })
146 .collect::<Result<_>>()?;
147
148 Ok(Row { shares })
149 }
150}
151
152impl From<Row> for RawRow {
153 fn from(row: Row) -> RawRow {
154 let square_width = row.shares.len();
156 let shares_half = row
157 .shares
158 .into_iter()
159 .map(|shr| RawShare { data: shr.to_vec() })
160 .take(square_width / 2)
161 .collect();
162
163 RawRow {
164 shares_half,
165 half_side: RawHalfSide::Left.into(),
166 }
167 }
168}
169
170impl RowId {
171 pub fn new(index: u16, height: u64) -> Result<Self> {
177 if height == 0 {
178 return Err(Error::ZeroBlockHeight);
179 }
180
181 Ok(Self {
182 index,
183 eds_id: EdsId { height },
184 })
185 }
186
187 pub fn block_height(&self) -> u64 {
189 self.eds_id.height
190 }
191
192 pub fn index(&self) -> u16 {
196 self.index
197 }
198
199 pub(crate) fn encode(&self, bytes: &mut BytesMut) {
200 bytes.reserve(ROW_ID_SIZE);
201 bytes.put_u64(self.block_height());
202 bytes.put_u16(self.index);
203 }
204
205 pub(crate) fn decode(mut buffer: &[u8]) -> Result<Self, CidError> {
206 if buffer.len() != ROW_ID_SIZE {
207 return Err(CidError::InvalidMultihashLength(buffer.len()));
208 }
209
210 let height = buffer.get_u64();
211 let index = buffer.get_u16();
212
213 if height == 0 {
214 return Err(CidError::InvalidCid("Zero block height".to_string()));
215 }
216
217 Ok(Self {
218 eds_id: EdsId { height },
219 index,
220 })
221 }
222}
223
224impl<const S: usize> TryFrom<CidGeneric<S>> for RowId {
225 type Error = CidError;
226
227 fn try_from(cid: CidGeneric<S>) -> Result<Self, Self::Error> {
228 let codec = cid.codec();
229 if codec != ROW_ID_CODEC {
230 return Err(CidError::InvalidCidCodec(codec));
231 }
232
233 let hash = cid.hash();
234
235 let size = hash.size() as usize;
236 if size != ROW_ID_SIZE {
237 return Err(CidError::InvalidMultihashLength(size));
238 }
239
240 let code = hash.code();
241 if code != ROW_ID_MULTIHASH_CODE {
242 return Err(CidError::InvalidMultihashCode(code, ROW_ID_MULTIHASH_CODE));
243 }
244
245 RowId::decode(hash.digest())
246 }
247}
248
249impl From<RowId> for CidGeneric<ROW_ID_SIZE> {
250 fn from(row: RowId) -> Self {
251 let mut bytes = BytesMut::with_capacity(ROW_ID_SIZE);
252 row.encode(&mut bytes);
253 let mh = Multihash::wrap(ROW_ID_MULTIHASH_CODE, &bytes[..]).unwrap();
255
256 CidGeneric::new_v1(ROW_ID_CODEC, mh)
257 }
258}
259
#[cfg(test)]
mod tests {
    use super::*;
    use crate::consts::appconsts::{AppVersion, SHARE_SIZE};
    use crate::test_utils::{generate_dummy_eds, generate_eds};
    use crate::Blob;

    /// A `RowId` converted to a CID and back is unchanged, and the CID carries
    /// the row-id multihash code and digest size.
    #[test]
    fn round_trip_test() {
        let row_id = RowId::new(5, 100).unwrap();
        let cid = CidGeneric::from(row_id);

        let multihash = cid.hash();
        assert_eq!(multihash.code(), ROW_ID_MULTIHASH_CODE);
        assert_eq!(multihash.size(), ROW_ID_SIZE as u8);

        let deserialized_row_id = RowId::try_from(cid).unwrap();
        assert_eq!(row_id, deserialized_row_id);
    }

    /// `Row::new` accepts indices inside the square and rejects the rest.
    #[test]
    fn index_calculation() {
        // 64 shares form an 8x8 extended square, so valid row indices are 0..8.
        let shares = vec![vec![0; SHARE_SIZE]; 8 * 8];
        let eds = ExtendedDataSquare::new(shares, "codec".to_string(), AppVersion::V2).unwrap();

        Row::new(1, &eds).unwrap();
        Row::new(7, &eds).unwrap();
        let row_err = Row::new(8, &eds).unwrap_err();
        assert!(matches!(row_err, Error::EdsIndexOutOfRange(8, 0)));
        let row_err = Row::new(100, &eds).unwrap_err();
        assert!(matches!(row_err, Error::EdsIndexOutOfRange(100, 0)));
    }

    /// The binary form of a `RowId` is exactly `ROW_ID_SIZE` bytes.
    #[test]
    fn row_id_size() {
        assert_eq!(ROW_ID_SIZE, 10);

        let row_id = RowId::new(0, 1).unwrap();
        let mut bytes = BytesMut::new();
        row_id.encode(&mut bytes);
        assert_eq!(bytes.len(), ROW_ID_SIZE);
    }

    /// A hand-written CID byte string parses into the expected `RowId`.
    #[test]
    fn from_buffer() {
        let bytes = [
            0x01, // CID version 1
            0x80, 0xF0, 0x01, // codec 0x7800 (ROW_ID_CODEC) as an unsigned varint
            0x81, 0xF0, 0x01, // multihash code 0x7801 (ROW_ID_MULTIHASH_CODE) as an unsigned varint
            0x0A, // digest length: 10 bytes (ROW_ID_SIZE)
            0, 0, 0, 0, 0, 0, 0, 64, // block height 64, big-endian u64
            0, 7, // row index 7, big-endian u16
        ];

        let cid = CidGeneric::<ROW_ID_SIZE>::read_bytes(bytes.as_ref()).unwrap();
        assert_eq!(cid.codec(), ROW_ID_CODEC);
        let mh = cid.hash();
        assert_eq!(mh.code(), ROW_ID_MULTIHASH_CODE);
        assert_eq!(mh.size(), ROW_ID_SIZE as u8);
        let row_id = RowId::try_from(cid).unwrap();
        assert_eq!(row_id.index, 7);
        assert_eq!(row_id.block_height(), 64);
    }

    /// A CID whose encoded block height is zero is rejected.
    #[test]
    fn zero_block_height() {
        let bytes = [
            0x01, // CID version 1
            0x80, 0xF0, 0x01, // codec 0x7800 (ROW_ID_CODEC) as an unsigned varint
            0x81, 0xF0, 0x01, // multihash code 0x7801 (ROW_ID_MULTIHASH_CODE) as an unsigned varint
            0x0A, // digest length: 10 bytes (ROW_ID_SIZE)
            0, 0, 0, 0, 0, 0, 0, 0, // block height 0 (invalid), big-endian u64
            0, 7, // row index 7, big-endian u16
        ];

        let cid = CidGeneric::<ROW_ID_SIZE>::read_bytes(bytes.as_ref()).unwrap();
        assert_eq!(cid.codec(), ROW_ID_CODEC);
        let mh = cid.hash();
        assert_eq!(mh.code(), ROW_ID_MULTIHASH_CODE);
        assert_eq!(mh.size(), ROW_ID_SIZE as u8);
        let row_err = RowId::try_from(cid).unwrap_err();
        assert_eq!(
            row_err,
            CidError::InvalidCid("Zero block height".to_string())
        );
    }

    /// A CID with the right codec but a foreign multihash code is rejected.
    #[test]
    fn multihash_invalid_code() {
        let multihash = Multihash::<ROW_ID_SIZE>::wrap(999, &[0; ROW_ID_SIZE]).unwrap();
        let cid = CidGeneric::<ROW_ID_SIZE>::new_v1(ROW_ID_CODEC, multihash);
        let row_err = RowId::try_from(cid).unwrap_err();
        assert_eq!(
            row_err,
            CidError::InvalidMultihashCode(999, ROW_ID_MULTIHASH_CODE)
        );
    }

    /// A CID with a foreign codec is rejected before the multihash is inspected.
    #[test]
    fn cid_invalid_codec() {
        let multihash =
            Multihash::<ROW_ID_SIZE>::wrap(ROW_ID_MULTIHASH_CODE, &[0; ROW_ID_SIZE]).unwrap();
        let cid = CidGeneric::<ROW_ID_SIZE>::new_v1(1234, multihash);
        let row_err = RowId::try_from(cid).unwrap_err();
        assert_eq!(row_err, CidError::InvalidCidCodec(1234));
    }

    /// Encoding a row (left half only) and decoding it again yields a row that
    /// still verifies against the square's DAH.
    #[test]
    fn test_roundtrip_verify() {
        for _ in 0..5 {
            let eds = generate_dummy_eds(2 << (rand::random::<usize>() % 8), AppVersion::V2);
            let dah = DataAvailabilityHeader::from_eds(&eds);

            let index = rand::random::<u16>() % eds.square_width();
            let id = RowId {
                eds_id: EdsId { height: 1 },
                index,
            };

            let row = Row {
                shares: eds.row(index).unwrap(),
            };

            let mut buf = BytesMut::new();
            row.encode(&mut buf);
            let decoded = Row::decode(id, &buf).unwrap();

            decoded.verify(id, &dah).unwrap();
        }
    }

    /// A row received as its right half must have its left half reconstructed
    /// by `Row::from_raw` and still verify against the square's DAH.
    /// `Row::encode` only ever emits the left half, so this path needs its own test.
    #[test]
    fn test_roundtrip_verify_right_half() {
        for _ in 0..5 {
            let eds = generate_dummy_eds(2 << (rand::random::<usize>() % 8), AppVersion::V2);
            let dah = DataAvailabilityHeader::from_eds(&eds);

            let index = rand::random::<u16>() % eds.square_width();
            let id = RowId {
                eds_id: EdsId { height: 1 },
                index,
            };

            let shares = eds.row(index).unwrap();
            let half_width = shares.len() / 2;
            let raw = RawRow {
                shares_half: shares[half_width..]
                    .iter()
                    .map(|shr| RawShare { data: shr.to_vec() })
                    .collect(),
                half_side: RawHalfSide::Right.into(),
            };

            let decoded = Row::from_raw(id, raw).unwrap();

            assert_eq!(decoded.shares.len(), shares.len());
            decoded.verify(id, &dah).unwrap();
        }
    }

    /// Blobs can be reassembled from the shares of rows 1..4 of a generated square.
    #[test]
    fn reconstruct_all() {
        for _ in 0..3 {
            let eds = generate_eds(8 << (rand::random::<usize>() % 6), AppVersion::V2);

            let rows: Vec<_> = (1..4).map(|row| Row::new(row, &eds).unwrap()).collect();
            let blobs = Blob::reconstruct_all(
                rows.iter().flat_map(|row| row.shares.iter()),
                AppVersion::V2,
            )
            .unwrap();

            // NOTE(review): relies on `generate_eds` laying out exactly two blobs
            // in these rows — confirm against `test_utils`.
            assert_eq!(blobs.len(), 2);
        }
    }
}