1use std::iter;
10
11use blockstore::block::CidError;
12use bytes::{Buf, BufMut, BytesMut};
13use celestia_proto::shwap::{Row as RawRow, Share as RawShare, row::HalfSide as RawHalfSide};
14use cid::CidGeneric;
15use multihash::Multihash;
16use prost::Message;
17use serde::Serialize;
18
19use crate::consts::appconsts::SHARE_SIZE;
20use crate::eds::{EDS_ID_SIZE, EdsId, ExtendedDataSquare};
21use crate::nmt::{Nmt, NmtExt};
22use crate::{DataAvailabilityHeader, Error, Result, Share};
23
24pub const ROW_ID_SIZE: usize = EDS_ID_SIZE + 2;
26pub const ROW_ID_MULTIHASH_CODE: u64 = 0x7801;
28pub const ROW_ID_CODEC: u64 = 0x7800;
30
31#[derive(Debug, PartialEq, Clone, Copy)]
33pub struct RowId {
34 eds_id: EdsId,
35 index: u16,
36}
37
38#[derive(Clone, Debug, Serialize)]
40#[serde(into = "RawRow")]
41pub struct Row {
42 pub shares: Vec<Share>,
44}
45
46impl Row {
47 pub fn new(index: u16, eds: &ExtendedDataSquare) -> Result<Self> {
49 let shares = eds.row(index)?;
50
51 Ok(Row { shares })
52 }
53
54 pub fn verify(&self, id: RowId, dah: &DataAvailabilityHeader) -> Result<()> {
56 let row = id.index;
57 let mut tree = Nmt::default();
58
59 for share in &self.shares {
60 tree.push_leaf(share.as_ref(), *share.namespace())
61 .map_err(Error::Nmt)?;
62 }
63
64 let Some(root) = dah.row_root(row) else {
65 return Err(Error::EdsIndexOutOfRange(row, 0));
66 };
67
68 if tree.root().hash() != root.hash() {
69 return Err(Error::RootMismatch);
70 }
71
72 Ok(())
73 }
74
75 pub fn encode(&self, bytes: &mut BytesMut) {
77 let raw = RawRow::from(self.clone());
78
79 bytes.reserve(raw.encoded_len());
80 raw.encode(bytes).expect("capacity reserved");
81 }
82
83 pub fn decode(id: RowId, buffer: &[u8]) -> Result<Self> {
90 let raw = RawRow::decode(buffer)?;
91 Self::from_raw(id, raw)
92 }
93
94 pub fn from_raw(id: RowId, row: RawRow) -> Result<Self> {
101 let data_shares = row.shares_half.len();
102
103 let shares = match row.half_side() {
104 RawHalfSide::Left => {
105 let mut shares: Vec<_> = row.shares_half.into_iter().map(|shr| shr.data).collect();
107 shares.resize(shares.len() * 2, vec![0; SHARE_SIZE]);
108 leopard_codec::encode(&mut shares, data_shares)?;
109 shares
110 }
111 RawHalfSide::Right => {
112 let mut shares: Vec<_> = iter::repeat_n(vec![], data_shares)
114 .chain(row.shares_half.into_iter().map(|shr| shr.data))
115 .collect();
116 leopard_codec::reconstruct(&mut shares, data_shares)?;
117 shares
118 }
119 };
120
121 let row_index = id.index() as usize;
122 let shares = shares
123 .into_iter()
124 .enumerate()
125 .map(|(col_index, shr)| {
126 if row_index < data_shares && col_index < data_shares {
127 Share::from_raw(&shr)
128 } else {
129 Share::parity(&shr)
130 }
131 })
132 .collect::<Result<_>>()?;
133
134 Ok(Row { shares })
135 }
136}
137
138impl From<Row> for RawRow {
139 fn from(row: Row) -> RawRow {
140 let square_width = row.shares.len();
142 let shares_half = row
143 .shares
144 .into_iter()
145 .map(|shr| RawShare { data: shr.to_vec() })
146 .take(square_width / 2)
147 .collect();
148
149 RawRow {
150 shares_half,
151 half_side: RawHalfSide::Left.into(),
152 }
153 }
154}
155
156impl RowId {
157 pub fn new(index: u16, height: u64) -> Result<Self> {
163 Ok(Self {
164 index,
165 eds_id: EdsId::new(height)?,
166 })
167 }
168
169 pub fn block_height(&self) -> u64 {
171 self.eds_id.block_height()
172 }
173
174 pub fn index(&self) -> u16 {
178 self.index
179 }
180
181 pub fn encode(&self, bytes: &mut BytesMut) {
183 bytes.reserve(ROW_ID_SIZE);
184 self.eds_id.encode(bytes);
185 bytes.put_u16(self.index);
186 }
187
188 pub fn decode(buffer: &[u8]) -> Result<Self> {
190 if buffer.len() != ROW_ID_SIZE {
191 return Err(Error::InvalidLength(buffer.len(), ROW_ID_SIZE));
192 }
193
194 let (eds_bytes, mut row_bytes) = buffer.split_at(EDS_ID_SIZE);
195 let eds_id = EdsId::decode(eds_bytes)?;
196 let index = row_bytes.get_u16();
197
198 Ok(Self { eds_id, index })
199 }
200}
201
202impl<const S: usize> TryFrom<CidGeneric<S>> for RowId {
203 type Error = CidError;
204
205 fn try_from(cid: CidGeneric<S>) -> Result<Self, Self::Error> {
206 let codec = cid.codec();
207 if codec != ROW_ID_CODEC {
208 return Err(CidError::InvalidCidCodec(codec));
209 }
210
211 let hash = cid.hash();
212
213 let size = hash.size() as usize;
214 if size != ROW_ID_SIZE {
215 return Err(CidError::InvalidMultihashLength(size));
216 }
217
218 let code = hash.code();
219 if code != ROW_ID_MULTIHASH_CODE {
220 return Err(CidError::InvalidMultihashCode(code, ROW_ID_MULTIHASH_CODE));
221 }
222
223 RowId::decode(hash.digest()).map_err(|e| CidError::InvalidCid(e.to_string()))
224 }
225}
226
227impl From<RowId> for CidGeneric<ROW_ID_SIZE> {
228 fn from(row: RowId) -> Self {
229 let mut bytes = BytesMut::with_capacity(ROW_ID_SIZE);
230 row.encode(&mut bytes);
231 let mh = Multihash::wrap(ROW_ID_MULTIHASH_CODE, &bytes[..]).unwrap();
233
234 CidGeneric::new_v1(ROW_ID_CODEC, mh)
235 }
236}
237
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Blob;
    use crate::consts::appconsts::{AppVersion, SHARE_SIZE};
    use crate::test_utils::{generate_dummy_eds, generate_eds};

    /// Reads a row CID from `bytes` and asserts that its codec, multihash code
    /// and digest length all match the row-id constants.
    fn read_row_cid(bytes: &[u8]) -> CidGeneric<ROW_ID_SIZE> {
        let cid = CidGeneric::<ROW_ID_SIZE>::read_bytes(bytes).unwrap();
        assert_eq!(cid.codec(), ROW_ID_CODEC);

        let mh = cid.hash();
        assert_eq!(mh.code(), ROW_ID_MULTIHASH_CODE);
        assert_eq!(mh.size(), ROW_ID_SIZE as u8);

        cid
    }

    #[test]
    fn round_trip_test() {
        let original = RowId::new(5, 100).unwrap();
        let cid = CidGeneric::from(original);

        let mh = cid.hash();
        assert_eq!(mh.code(), ROW_ID_MULTIHASH_CODE);
        assert_eq!(mh.size(), ROW_ID_SIZE as u8);

        assert_eq!(RowId::try_from(cid).unwrap(), original);
    }

    #[test]
    fn index_calculation() {
        let shares = vec![vec![0; SHARE_SIZE]; 8 * 8];
        let eds = ExtendedDataSquare::new(shares, "codec".to_string(), AppVersion::V2).unwrap();

        for index in [1, 7] {
            Row::new(index, &eds).unwrap();
        }

        for index in [8, 100] {
            let err = Row::new(index, &eds).unwrap_err();
            assert!(matches!(err, Error::EdsIndexOutOfRange(i, 0) if i == index));
        }
    }

    #[test]
    fn row_id_size() {
        assert_eq!(ROW_ID_SIZE, 10);

        let mut buf = BytesMut::new();
        RowId::new(0, 1).unwrap().encode(&mut buf);
        assert_eq!(buf.len(), ROW_ID_SIZE);
    }

    #[test]
    fn from_buffer() {
        let bytes = [
            0x01, // CIDv1
            0x80, 0xF0, 0x01, // CID codec = 7800
            0x81, 0xF0, 0x01, // multihash code = 7801
            0x0A, // multihash length = ROW_ID_SIZE = 10
            0, 0, 0, 0, 0, 0, 0, 64, // block height = 64
            0, 7, // row index = 7
        ];

        let row_id = RowId::try_from(read_row_cid(&bytes)).unwrap();
        assert_eq!(row_id.index, 7);
        assert_eq!(row_id.block_height(), 64);
    }

    #[test]
    fn zero_block_height() {
        let bytes = [
            0x01, // CIDv1
            0x80, 0xF0, 0x01, // CID codec = 7800
            0x81, 0xF0, 0x01, // multihash code = 7801
            0x0A, // multihash length = ROW_ID_SIZE = 10
            0, 0, 0, 0, 0, 0, 0, 0, // invalid block height = 0
            0, 7, // row index = 7
        ];

        let cid = read_row_cid(&bytes);
        assert_eq!(
            RowId::try_from(cid).unwrap_err(),
            CidError::InvalidCid("Invalid zero block height".to_string())
        );
    }

    #[test]
    fn multihash_invalid_code() {
        let mh = Multihash::<ROW_ID_SIZE>::wrap(999, &[0; ROW_ID_SIZE]).unwrap();
        let cid = CidGeneric::<ROW_ID_SIZE>::new_v1(ROW_ID_CODEC, mh);

        assert_eq!(
            RowId::try_from(cid).unwrap_err(),
            CidError::InvalidMultihashCode(999, ROW_ID_MULTIHASH_CODE)
        );
    }

    #[test]
    fn cid_invalid_codec() {
        let mh =
            Multihash::<ROW_ID_SIZE>::wrap(ROW_ID_MULTIHASH_CODE, &[0; ROW_ID_SIZE]).unwrap();
        let cid = CidGeneric::<ROW_ID_SIZE>::new_v1(1234, mh);

        assert_eq!(
            RowId::try_from(cid).unwrap_err(),
            CidError::InvalidCidCodec(1234)
        );
    }

    #[test]
    fn test_roundtrip_verify() {
        for _ in 0..5 {
            let width = 2 << (rand::random::<usize>() % 8);
            let eds = generate_dummy_eds(width, AppVersion::V2);
            let dah = DataAvailabilityHeader::from_eds(&eds);

            let index = rand::random::<u16>() % eds.square_width();
            let id = RowId::new(index, 1).unwrap();
            let row = Row {
                shares: eds.row(index).unwrap(),
            };

            let mut encoded = BytesMut::new();
            row.encode(&mut encoded);

            Row::decode(id, &encoded)
                .unwrap()
                .verify(id, &dah)
                .unwrap();
        }
    }

    #[test]
    fn reconstruct_all() {
        for _ in 0..3 {
            let eds = generate_eds(8 << (rand::random::<usize>() % 6), AppVersion::V2);

            let rows: Vec<Row> = (1..4)
                .map(|index| Row::new(index, &eds).unwrap())
                .collect();
            let shares = rows.iter().flat_map(|row| row.shares.iter());

            let blobs = Blob::reconstruct_all(shares, AppVersion::V2).unwrap();
            assert_eq!(blobs.len(), 2);
        }
    }
}