1use std::iter;
10
11use blockstore::block::CidError;
12use bytes::{Buf, BufMut, BytesMut};
13use celestia_proto::shwap::{row::HalfSide as RawHalfSide, Row as RawRow, Share as RawShare};
14use cid::CidGeneric;
15use multihash::Multihash;
16use prost::Message;
17use serde::Serialize;
18
19use crate::consts::appconsts::SHARE_SIZE;
20use crate::eds::ExtendedDataSquare;
21use crate::nmt::{Nmt, NmtExt};
22use crate::{DataAvailabilityHeader, Error, Result, Share};
23
/// Size in bytes of the block-height part of an encoded id (a `u64`).
const EDS_ID_SIZE: usize = std::mem::size_of::<u64>();
/// Size in bytes of an encoded `RowId`: the block height followed by a `u16` row index.
pub(crate) const ROW_ID_SIZE: usize = EDS_ID_SIZE + std::mem::size_of::<u16>();
/// Multihash code used for `RowId` digests.
pub const ROW_ID_MULTIHASH_CODE: u64 = 0x7801;
/// CID codec identifying a `RowId`.
pub const ROW_ID_CODEC: u64 = 0x7800;
32
/// Identifies an extended data square by the height of the block it belongs to.
#[derive(Clone, Copy, Debug, PartialEq)]
struct EdsId {
    /// Block height; the `RowId` constructors in this module reject 0.
    height: u64,
}
43
44#[derive(Debug, PartialEq, Clone, Copy)]
46pub struct RowId {
47 eds_id: EdsId,
48 index: u16,
49}
50
51#[derive(Clone, Debug, Serialize)]
53#[serde(into = "RawRow")]
54pub struct Row {
55 pub shares: Vec<Share>,
57}
58
59impl Row {
60 pub fn new(index: u16, eds: &ExtendedDataSquare) -> Result<Self> {
62 let shares = eds.row(index)?;
63
64 Ok(Row { shares })
65 }
66
67 pub fn verify(&self, id: RowId, dah: &DataAvailabilityHeader) -> Result<()> {
69 let row = id.index;
70 let mut tree = Nmt::default();
71
72 for share in &self.shares {
73 tree.push_leaf(share.as_ref(), *share.namespace())
74 .map_err(Error::Nmt)?;
75 }
76
77 let Some(root) = dah.row_root(row) else {
78 return Err(Error::EdsIndexOutOfRange(row, 0));
79 };
80
81 if tree.root().hash() != root.hash() {
82 return Err(Error::RootMismatch);
83 }
84
85 Ok(())
86 }
87
88 pub fn encode(&self, bytes: &mut BytesMut) {
90 let raw = RawRow::from(self.clone());
91
92 bytes.reserve(raw.encoded_len());
93 raw.encode(bytes).expect("capacity reserved");
94 }
95
96 pub fn decode(id: RowId, buffer: &[u8]) -> Result<Self> {
103 let raw = RawRow::decode(buffer)?;
104 Self::from_raw(id, raw)
105 }
106
107 pub fn from_raw(id: RowId, row: RawRow) -> Result<Self> {
114 let data_shares = row.shares_half.len();
115
116 let shares = match row.half_side() {
117 RawHalfSide::Left => {
118 let mut shares: Vec<_> = row.shares_half.into_iter().map(|shr| shr.data).collect();
120 shares.resize(shares.len() * 2, vec![0; SHARE_SIZE]);
121 leopard_codec::encode(&mut shares, data_shares)?;
122 shares
123 }
124 RawHalfSide::Right => {
125 let mut shares: Vec<_> = iter::repeat_n(vec![], data_shares)
127 .chain(row.shares_half.into_iter().map(|shr| shr.data))
128 .collect();
129 leopard_codec::reconstruct(&mut shares, data_shares)?;
130 shares
131 }
132 };
133
134 let row_index = id.index() as usize;
135 let shares = shares
136 .into_iter()
137 .enumerate()
138 .map(|(col_index, shr)| {
139 if row_index < data_shares && col_index < data_shares {
140 Share::from_raw(&shr)
141 } else {
142 Share::parity(&shr)
143 }
144 })
145 .collect::<Result<_>>()?;
146
147 Ok(Row { shares })
148 }
149}
150
151impl From<Row> for RawRow {
152 fn from(row: Row) -> RawRow {
153 let square_width = row.shares.len();
155 let shares_half = row
156 .shares
157 .into_iter()
158 .map(|shr| RawShare { data: shr.to_vec() })
159 .take(square_width / 2)
160 .collect();
161
162 RawRow {
163 shares_half,
164 half_side: RawHalfSide::Left.into(),
165 }
166 }
167}
168
169impl RowId {
170 pub fn new(index: u16, height: u64) -> Result<Self> {
176 if height == 0 {
177 return Err(Error::ZeroBlockHeight);
178 }
179
180 Ok(Self {
181 index,
182 eds_id: EdsId { height },
183 })
184 }
185
186 pub fn block_height(&self) -> u64 {
188 self.eds_id.height
189 }
190
191 pub fn index(&self) -> u16 {
195 self.index
196 }
197
198 pub(crate) fn encode(&self, bytes: &mut BytesMut) {
199 bytes.reserve(ROW_ID_SIZE);
200 bytes.put_u64(self.block_height());
201 bytes.put_u16(self.index);
202 }
203
204 pub(crate) fn decode(mut buffer: &[u8]) -> Result<Self, CidError> {
205 if buffer.len() != ROW_ID_SIZE {
206 return Err(CidError::InvalidMultihashLength(buffer.len()));
207 }
208
209 let height = buffer.get_u64();
210 let index = buffer.get_u16();
211
212 if height == 0 {
213 return Err(CidError::InvalidCid("Zero block height".to_string()));
214 }
215
216 Ok(Self {
217 eds_id: EdsId { height },
218 index,
219 })
220 }
221}
222
223impl<const S: usize> TryFrom<CidGeneric<S>> for RowId {
224 type Error = CidError;
225
226 fn try_from(cid: CidGeneric<S>) -> Result<Self, Self::Error> {
227 let codec = cid.codec();
228 if codec != ROW_ID_CODEC {
229 return Err(CidError::InvalidCidCodec(codec));
230 }
231
232 let hash = cid.hash();
233
234 let size = hash.size() as usize;
235 if size != ROW_ID_SIZE {
236 return Err(CidError::InvalidMultihashLength(size));
237 }
238
239 let code = hash.code();
240 if code != ROW_ID_MULTIHASH_CODE {
241 return Err(CidError::InvalidMultihashCode(code, ROW_ID_MULTIHASH_CODE));
242 }
243
244 RowId::decode(hash.digest())
245 }
246}
247
248impl From<RowId> for CidGeneric<ROW_ID_SIZE> {
249 fn from(row: RowId) -> Self {
250 let mut bytes = BytesMut::with_capacity(ROW_ID_SIZE);
251 row.encode(&mut bytes);
252 let mh = Multihash::wrap(ROW_ID_MULTIHASH_CODE, &bytes[..]).unwrap();
254
255 CidGeneric::new_v1(ROW_ID_CODEC, mh)
256 }
257}
258
#[cfg(test)]
mod tests {
    use super::*;
    use crate::consts::appconsts::{AppVersion, SHARE_SIZE};
    use crate::test_utils::{generate_dummy_eds, generate_eds};
    use crate::Blob;

    // A RowId converted to a CID and back must be unchanged, and the CID must
    // carry the row-id multihash code and digest size.
    #[test]
    fn round_trip_test() {
        let row_id = RowId::new(5, 100).unwrap();
        let cid = CidGeneric::from(row_id);

        let multihash = cid.hash();
        assert_eq!(multihash.code(), ROW_ID_MULTIHASH_CODE);
        assert_eq!(multihash.size(), ROW_ID_SIZE as u8);

        let deserialized_row_id = RowId::try_from(cid).unwrap();
        assert_eq!(row_id, deserialized_row_id);
    }

    // Rows 0..8 of an 8-wide square are accepted; indices 8 and above are
    // rejected with EdsIndexOutOfRange(index, 0).
    #[test]
    fn index_calculation() {
        // 64 zeroed shares; the assertions below show they form an 8-wide square.
        let shares = vec![vec![0; SHARE_SIZE]; 8 * 8];
        let eds = ExtendedDataSquare::new(shares, "codec".to_string(), AppVersion::V2).unwrap();

        Row::new(1, &eds).unwrap();
        Row::new(7, &eds).unwrap();
        let row_err = Row::new(8, &eds).unwrap_err();
        assert!(matches!(row_err, Error::EdsIndexOutOfRange(8, 0)));
        let row_err = Row::new(100, &eds).unwrap_err();
        assert!(matches!(row_err, Error::EdsIndexOutOfRange(100, 0)));
    }

    // The encoded RowId is exactly ROW_ID_SIZE (10) bytes.
    #[test]
    fn row_id_size() {
        assert_eq!(ROW_ID_SIZE, 10);

        let row_id = RowId::new(0, 1).unwrap();
        let mut bytes = BytesMut::new();
        row_id.encode(&mut bytes);
        assert_eq!(bytes.len(), ROW_ID_SIZE);
    }

    // Parse a hand-built binary CID and check every decoded field.
    #[test]
    fn from_buffer() {
        let bytes = [
            0x01, // CID version 1
            0x80, 0xF0, 0x01, // CID codec 0x7800 (varint)
            0x81, 0xF0, 0x01, // multihash code 0x7801 (varint)
            0x0A, // digest length = ROW_ID_SIZE = 10
            0, 0, 0, 0, 0, 0, 0, 64, // block height = 64 (big-endian u64)
            0, 7, // row index = 7 (big-endian u16)
        ];

        let cid = CidGeneric::<ROW_ID_SIZE>::read_bytes(bytes.as_ref()).unwrap();
        assert_eq!(cid.codec(), ROW_ID_CODEC);
        let mh = cid.hash();
        assert_eq!(mh.code(), ROW_ID_MULTIHASH_CODE);
        assert_eq!(mh.size(), ROW_ID_SIZE as u8);
        let row_id = RowId::try_from(cid).unwrap();
        assert_eq!(row_id.index, 7);
        assert_eq!(row_id.block_height(), 64);
    }

    // Same layout as `from_buffer`, but with block height 0, which must be rejected.
    #[test]
    fn zero_block_height() {
        let bytes = [
            0x01, // CID version 1
            0x80, 0xF0, 0x01, // CID codec 0x7800 (varint)
            0x81, 0xF0, 0x01, // multihash code 0x7801 (varint)
            0x0A, // digest length = ROW_ID_SIZE = 10
            0, 0, 0, 0, 0, 0, 0, 0, // block height = 0 (invalid)
            0, 7, // row index = 7
        ];

        let cid = CidGeneric::<ROW_ID_SIZE>::read_bytes(bytes.as_ref()).unwrap();
        assert_eq!(cid.codec(), ROW_ID_CODEC);
        let mh = cid.hash();
        assert_eq!(mh.code(), ROW_ID_MULTIHASH_CODE);
        assert_eq!(mh.size(), ROW_ID_SIZE as u8);
        let row_err = RowId::try_from(cid).unwrap_err();
        assert_eq!(
            row_err,
            CidError::InvalidCid("Zero block height".to_string())
        );
    }

    // A CID with the right codec and size but a foreign multihash code is rejected.
    #[test]
    fn multihash_invalid_code() {
        let multihash = Multihash::<ROW_ID_SIZE>::wrap(999, &[0; ROW_ID_SIZE]).unwrap();
        let cid = CidGeneric::<ROW_ID_SIZE>::new_v1(ROW_ID_CODEC, multihash);
        let row_err = RowId::try_from(cid).unwrap_err();
        assert_eq!(
            row_err,
            CidError::InvalidMultihashCode(999, ROW_ID_MULTIHASH_CODE)
        );
    }

    // A CID with a foreign codec is rejected before the multihash is inspected.
    #[test]
    fn cid_invalid_codec() {
        let multihash =
            Multihash::<ROW_ID_SIZE>::wrap(ROW_ID_MULTIHASH_CODE, &[0; ROW_ID_SIZE]).unwrap();
        let cid = CidGeneric::<ROW_ID_SIZE>::new_v1(1234, multihash);
        let row_err = RowId::try_from(cid).unwrap_err();
        assert_eq!(row_err, CidError::InvalidCidCodec(1234));
    }

    // For random square sizes and rows: encode -> decode rebuilds a row that
    // still verifies against the square's DataAvailabilityHeader.
    #[test]
    fn test_roundtrip_verify() {
        for _ in 0..5 {
            // Size argument is a random power of two in 2..=256
            // (NOTE(review): presumably the original square width; confirm in test_utils).
            let eds = generate_dummy_eds(2 << (rand::random::<usize>() % 8), AppVersion::V2);
            let dah = DataAvailabilityHeader::from_eds(&eds);

            let index = rand::random::<u16>() % eds.square_width();
            let id = RowId {
                eds_id: EdsId { height: 1 },
                index,
            };

            let row = Row {
                shares: eds.row(index).unwrap(),
            };

            let mut buf = BytesMut::new();
            row.encode(&mut buf);
            let decoded = Row::decode(id, &buf).unwrap();

            decoded.verify(id, &dah).unwrap();
        }
    }

    // Shares from rows 1..4 of a generated square are fed to Blob::reconstruct_all,
    // which is expected to find exactly 2 blobs
    // (NOTE(review): relies on how `generate_eds` lays out blobs; confirm).
    #[test]
    fn reconstruct_all() {
        for _ in 0..3 {
            let eds = generate_eds(8 << (rand::random::<usize>() % 6), AppVersion::V2);

            let rows: Vec<_> = (1..4).map(|row| Row::new(row, &eds).unwrap()).collect();
            let blobs = Blob::reconstruct_all(
                rows.iter().flat_map(|row| row.shares.iter()),
                AppVersion::V2,
            )
            .unwrap();

            assert_eq!(blobs.len(), 2);
        }
    }
}