use std::iter;
use blockstore::block::CidError;
use bytes::{Buf, BufMut, BytesMut};
use celestia_proto::shwap::{Row as RawRow, Share as RawShare, row::HalfSide as RawHalfSide};
use cid::CidGeneric;
use multihash::Multihash;
use prost::Message;
use serde::Serialize;
use crate::consts::appconsts::SHARE_SIZE;
use crate::eds::{EDS_ID_SIZE, EdsId, ExtendedDataSquare};
use crate::nmt::{Nmt, NmtExt};
use crate::{DataAvailabilityHeader, Error, Result, Share};
/// Size in bytes of an encoded [`RowId`]: the encoded `EdsId` followed by a `u16` row index.
pub const ROW_ID_SIZE: usize = EDS_ID_SIZE + std::mem::size_of::<u16>();
/// Multihash code used when a [`RowId`] is wrapped into a CID.
pub const ROW_ID_MULTIHASH_CODE: u64 = 0x7801;
/// CID codec used for [`RowId`] CIDs.
pub const ROW_ID_CODEC: u64 = 0x7800;
/// Identifies a single row of the extended data square of a given block.
///
/// Binary form (see [`RowId::encode`]): the encoded `EdsId` followed by the
/// row index as a big-endian `u16`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct RowId {
    /// Identifies the extended data square, i.e. the block height.
    eds_id: EdsId,
    /// Index of the row within the square.
    index: u16,
}
/// A complete row of an extended data square, with shares in column order.
///
/// Serializes through [`RawRow`], which keeps only the left half of the shares.
#[derive(Debug, Clone, Serialize)]
#[serde(into = "RawRow")]
pub struct Row {
    /// Every share of the row, left (original data) half first.
    pub shares: Vec<Share>,
}
impl Row {
pub fn new(index: u16, eds: &ExtendedDataSquare) -> Result<Self> {
let shares = eds.row(index)?;
Ok(Row { shares })
}
pub fn verify(&self, id: RowId, dah: &DataAvailabilityHeader) -> Result<()> {
let row = id.index;
let mut tree = Nmt::default();
for share in &self.shares {
tree.push_leaf(share.as_ref(), *share.namespace())
.map_err(Error::Nmt)?;
}
let Some(root) = dah.row_root(row) else {
return Err(Error::EdsIndexOutOfRange(row, 0));
};
if tree.root().hash() != root.hash() {
return Err(Error::RootMismatch);
}
Ok(())
}
pub fn encode(&self, bytes: &mut BytesMut) {
let raw = RawRow::from(self.clone());
bytes.reserve(raw.encoded_len());
raw.encode(bytes).expect("capacity reserved");
}
pub fn decode(id: RowId, buffer: &[u8]) -> Result<Self> {
let raw = RawRow::decode(buffer)?;
Self::from_raw(id, raw)
}
pub fn from_raw(id: RowId, row: RawRow) -> Result<Self> {
let data_shares = row.shares_half.len();
let shares = match row.half_side() {
RawHalfSide::Left => {
let mut shares: Vec<_> = row.shares_half.into_iter().map(|shr| shr.data).collect();
shares.resize(shares.len() * 2, vec![0; SHARE_SIZE]);
leopard_codec::encode(&mut shares, data_shares)?;
shares
}
RawHalfSide::Right => {
let mut shares: Vec<_> = iter::repeat_n(vec![], data_shares)
.chain(row.shares_half.into_iter().map(|shr| shr.data))
.collect();
leopard_codec::reconstruct(&mut shares, data_shares)?;
shares
}
};
let row_index = id.index() as usize;
let shares = shares
.into_iter()
.enumerate()
.map(|(col_index, shr)| {
if row_index < data_shares && col_index < data_shares {
Share::from_raw(&shr)
} else {
Share::parity(&shr)
}
})
.collect::<Result<_>>()?;
Ok(Row { shares })
}
}
impl From<Row> for RawRow {
    /// Converts a [`Row`] into its protobuf form.
    ///
    /// Only the left half of the shares is kept and the message is tagged
    /// [`RawHalfSide::Left`]. The remaining shares are dropped.
    fn from(row: Row) -> RawRow {
        let mut shares = row.shares;
        let half_width = shares.len() / 2;
        shares.truncate(half_width);

        let shares_half = shares
            .into_iter()
            .map(|share| RawShare { data: share.to_vec() })
            .collect();

        RawRow {
            half_side: RawHalfSide::Left.into(),
            shares_half,
        }
    }
}
impl RowId {
pub fn new(index: u16, height: u64) -> Result<Self> {
Ok(Self {
index,
eds_id: EdsId::new(height)?,
})
}
pub fn block_height(&self) -> u64 {
self.eds_id.block_height()
}
pub fn index(&self) -> u16 {
self.index
}
pub fn encode(&self, bytes: &mut BytesMut) {
bytes.reserve(ROW_ID_SIZE);
self.eds_id.encode(bytes);
bytes.put_u16(self.index);
}
pub fn decode(buffer: &[u8]) -> Result<Self> {
if buffer.len() != ROW_ID_SIZE {
return Err(Error::InvalidLength(buffer.len(), ROW_ID_SIZE));
}
let (eds_bytes, mut row_bytes) = buffer.split_at(EDS_ID_SIZE);
let eds_id = EdsId::decode(eds_bytes)?;
let index = row_bytes.get_u16();
Ok(Self { eds_id, index })
}
}
impl<const S: usize> TryFrom<CidGeneric<S>> for RowId {
    type Error = CidError;

    /// Parses a [`RowId`] out of a CID.
    ///
    /// The checks run in this order:
    /// 1. The CID codec must be [`ROW_ID_CODEC`].
    /// 2. The multihash digest must be exactly [`ROW_ID_SIZE`] bytes.
    /// 3. The multihash code must be [`ROW_ID_MULTIHASH_CODE`].
    ///
    /// The digest is then parsed with [`RowId::decode`]. A failure there is reported as
    /// [`CidError::InvalidCid`] carrying the error's message.
    fn try_from(cid: CidGeneric<S>) -> Result<Self, Self::Error> {
        let codec = cid.codec();
        if codec != ROW_ID_CODEC {
            return Err(CidError::InvalidCidCodec(codec));
        }

        let multihash = cid.hash();

        let digest_len = usize::from(multihash.size());
        if digest_len != ROW_ID_SIZE {
            return Err(CidError::InvalidMultihashLength(digest_len));
        }

        match multihash.code() {
            ROW_ID_MULTIHASH_CODE => RowId::decode(multihash.digest())
                .map_err(|err| CidError::InvalidCid(err.to_string())),
            other => Err(CidError::InvalidMultihashCode(other, ROW_ID_MULTIHASH_CODE)),
        }
    }
}
impl From<RowId> for CidGeneric<ROW_ID_SIZE> {
    /// Wraps the binary form of a [`RowId`] into a CIDv1.
    ///
    /// The multihash uses code [`ROW_ID_MULTIHASH_CODE`] and carries the raw `RowId`
    /// bytes as its digest. The CID codec is [`ROW_ID_CODEC`].
    fn from(row: RowId) -> Self {
        let mut bytes = BytesMut::with_capacity(ROW_ID_SIZE);
        row.encode(&mut bytes);

        // `RowId::encode` writes exactly `ROW_ID_SIZE` bytes, which is also the
        // multihash's digest capacity, so wrapping cannot fail.
        let mh = Multihash::wrap(ROW_ID_MULTIHASH_CODE, &bytes[..])
            .expect("RowId encoding always fits in a ROW_ID_SIZE multihash");

        CidGeneric::new_v1(ROW_ID_CODEC, mh)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Blob;
    use crate::consts::appconsts::SHARE_SIZE;
    use crate::test_utils::{generate_dummy_eds, generate_eds};

    // A RowId survives a round trip through its CID form, and the CID's
    // multihash carries the expected code and digest size.
    #[test]
    fn round_trip_test() {
        let row_id = RowId::new(5, 100).unwrap();
        let cid = CidGeneric::from(row_id);
        let multihash = cid.hash();
        assert_eq!(multihash.code(), ROW_ID_MULTIHASH_CODE);
        assert_eq!(multihash.size(), ROW_ID_SIZE as u8);
        let deserialized_row_id = RowId::try_from(cid).unwrap();
        assert_eq!(row_id, deserialized_row_id);
    }

    // In an 8x8 square, rows 1 and 7 are valid; rows 8 and 100 are rejected
    // with EdsIndexOutOfRange(index, 0).
    #[test]
    fn index_calculation() {
        let shares = vec![vec![0; SHARE_SIZE]; 8 * 8];
        let eds = ExtendedDataSquare::new(shares, "codec".to_string()).unwrap();
        Row::new(1, &eds).unwrap();
        Row::new(7, &eds).unwrap();
        let row_err = Row::new(8, &eds).unwrap_err();
        assert!(matches!(row_err, Error::EdsIndexOutOfRange(8, 0)));
        let row_err = Row::new(100, &eds).unwrap_err();
        assert!(matches!(row_err, Error::EdsIndexOutOfRange(100, 0)));
    }

    // An encoded RowId is exactly ROW_ID_SIZE (10) bytes long.
    #[test]
    fn row_id_size() {
        assert_eq!(ROW_ID_SIZE, 10);
        let row_id = RowId::new(0, 1).unwrap();
        let mut bytes = BytesMut::new();
        row_id.encode(&mut bytes);
        assert_eq!(bytes.len(), ROW_ID_SIZE);
    }

    // Parses a hand-written CID byte sequence into a RowId.
    #[test]
    fn from_buffer() {
        // CIDv1 layout:
        //   0x01             CID version 1
        //   0x80 0xF0 0x01   codec varint = 0x7800 (ROW_ID_CODEC)
        //   0x81 0xF0 0x01   multihash code varint = 0x7801 (ROW_ID_MULTIHASH_CODE)
        //   0x0A             digest length = 10 (ROW_ID_SIZE)
        //   8 bytes          big-endian block height = 64
        //   2 bytes          big-endian row index = 7
        let bytes = [
            0x01, 0x80, 0xF0, 0x01, 0x81, 0xF0, 0x01, 0x0A, 0, 0, 0, 0, 0, 0, 0, 64, 0, 7,
        ];
        let cid = CidGeneric::<ROW_ID_SIZE>::read_bytes(bytes.as_ref()).unwrap();
        assert_eq!(cid.codec(), ROW_ID_CODEC);
        let mh = cid.hash();
        assert_eq!(mh.code(), ROW_ID_MULTIHASH_CODE);
        assert_eq!(mh.size(), ROW_ID_SIZE as u8);
        let row_id = RowId::try_from(cid).unwrap();
        assert_eq!(row_id.index, 7);
        assert_eq!(row_id.block_height(), 64);
    }

    // Same layout as `from_buffer` but with block height 0. The CID header is
    // valid, so the failure comes from RowId::decode and surfaces as InvalidCid.
    #[test]
    fn zero_block_height() {
        let bytes = [
            0x01, 0x80, 0xF0, 0x01, 0x81, 0xF0, 0x01, 0x0A, 0, 0, 0, 0, 0, 0, 0, 0, 0, 7,
        ];
        let cid = CidGeneric::<ROW_ID_SIZE>::read_bytes(bytes.as_ref()).unwrap();
        assert_eq!(cid.codec(), ROW_ID_CODEC);
        let mh = cid.hash();
        assert_eq!(mh.code(), ROW_ID_MULTIHASH_CODE);
        assert_eq!(mh.size(), ROW_ID_SIZE as u8);
        let row_err = RowId::try_from(cid).unwrap_err();
        assert_eq!(
            row_err,
            CidError::InvalidCid("Invalid zero block height".to_string())
        );
    }

    // Correct codec and digest size but an unknown multihash code (999).
    #[test]
    fn multihash_invalid_code() {
        let multihash = Multihash::<ROW_ID_SIZE>::wrap(999, &[0; ROW_ID_SIZE]).unwrap();
        let cid = CidGeneric::<ROW_ID_SIZE>::new_v1(ROW_ID_CODEC, multihash);
        let row_err = RowId::try_from(cid).unwrap_err();
        assert_eq!(
            row_err,
            CidError::InvalidMultihashCode(999, ROW_ID_MULTIHASH_CODE)
        );
    }

    // A CID with the wrong codec (1234) is rejected before the multihash is inspected.
    #[test]
    fn cid_invalid_codec() {
        let multihash =
            Multihash::<ROW_ID_SIZE>::wrap(ROW_ID_MULTIHASH_CODE, &[0; ROW_ID_SIZE]).unwrap();
        let cid = CidGeneric::<ROW_ID_SIZE>::new_v1(1234, multihash);
        let row_err = RowId::try_from(cid).unwrap_err();
        assert_eq!(row_err, CidError::InvalidCidCodec(1234));
    }

    // Encode a random row of a random-width square (2 << 0..8, i.e. 2..=256),
    // decode it back from the left half, and verify it against the DAH.
    #[test]
    fn test_roundtrip_verify() {
        for _ in 0..5 {
            let eds = generate_dummy_eds(2 << (rand::random::<usize>() % 8));
            let dah = DataAvailabilityHeader::from_eds(&eds);
            let index = rand::random::<u16>() % eds.square_width();
            let id = RowId::new(index, 1).unwrap();
            let row = Row {
                shares: eds.row(index).unwrap(),
            };
            let mut buf = BytesMut::new();
            row.encode(&mut buf);
            let decoded = Row::decode(id, &buf).unwrap();
            decoded.verify(id, &dah).unwrap();
        }
    }

    // Reconstructs blobs from the shares of rows 1..=3.
    // NOTE(review): the expected count of 2 presumably reflects how
    // `generate_eds` lays out its blobs; confirm against test_utils.
    #[test]
    fn reconstruct_all() {
        for _ in 0..3 {
            let eds = generate_eds(8 << (rand::random::<usize>() % 6));
            let rows: Vec<_> = (1..4).map(|row| Row::new(row, &eds).unwrap()).collect();
            let blobs =
                Blob::reconstruct_all(rows.iter().flat_map(|row| row.shares.iter())).unwrap();
            assert_eq!(blobs.len(), 2);
        }
    }
}