use std::cmp::Ordering;
use std::fmt::Display;
use bytes::{Buf, BufMut, BytesMut};
use serde::{Deserialize, Serialize};
use crate::consts::appconsts::SHARE_SIZE;
use crate::consts::data_availability_header::MIN_EXTENDED_SQUARE_WIDTH;
use crate::nmt::{NS_SIZE, Namespace, Nmt, NmtExt};
use crate::row_namespace_data::{RowNamespaceData, RowNamespaceDataId};
use crate::{DataAvailabilityHeader, Error, InfoByte, Result, Share, bail_validation};
/// Size in bytes of an encoded [`EdsId`]: the identifier is just its `u64` block height.
pub const EDS_ID_SIZE: usize = std::mem::size_of::<u64>();
/// Identifier of a block's extended data square, which is simply the block height.
///
/// Values are created through [`EdsId::new`] or [`EdsId::decode`], both of which reject a
/// height of `0`.
#[derive(Debug, PartialEq, Clone, Copy)]
pub struct EdsId {
/// Height of the block this EDS belongs to.
height: u64,
}
/// Selects one of the two axes of a square: a row or a column.
///
/// The discriminants are the same integers accepted by the `TryFrom<i32>` impl.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[repr(u8)]
pub enum AxisType {
    /// A row of the square (discriminant `0`).
    Row = 0,
    /// A column of the square (discriminant `1`).
    Col = 1,
}
impl TryFrom<i32> for AxisType {
type Error = Error;
fn try_from(value: i32) -> Result<Self, Self::Error> {
match value {
0 => Ok(AxisType::Row),
1 => Ok(AxisType::Col),
n => Err(Error::InvalidAxis(n)),
}
}
}
impl Display for AxisType {
    /// Writes the human-readable axis name: `"Row"` or `"Column"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            AxisType::Row => "Row",
            AxisType::Col => "Column",
        };
        f.write_str(label)
    }
}
/// An extended data square (EDS): a square of shares whose top-left quadrant is the
/// original data square (ODS) and whose remaining shares are parity shares.
///
/// [`ExtendedDataSquare::new`] validates the shares (dimensions and namespace ordering).
/// Serialization is performed by first converting into [`RawExtendedDataSquare`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
#[serde(into = "RawExtendedDataSquare")]
pub struct ExtendedDataSquare {
/// All shares of the square, flattened in row-major order.
data_square: Vec<Share>,
/// Codec name, stored verbatim as passed to `new` (`"Leopard"` when built by `from_ods`).
codec: String,
/// Number of shares per row and per column; validated to be a power of two.
square_width: u16,
}
impl ExtendedDataSquare {
pub fn new(shares: Vec<Vec<u8>>, codec: String) -> Result<Self> {
const MIN_SHARES: usize = MIN_EXTENDED_SQUARE_WIDTH * MIN_EXTENDED_SQUARE_WIDTH;
if shares.len() < MIN_SHARES {
bail_validation!(
"shares len ({}) < MIN_SHARES ({})",
shares.len(),
MIN_SHARES
);
}
let square_width = f64::sqrt(shares.len() as f64) as usize;
if square_width * square_width != shares.len() {
return Err(Error::EdsInvalidDimentions);
}
let square_width = u16::try_from(square_width).map_err(|_| Error::EdsInvalidDimentions)?;
if square_width.count_ones() != 1 {
return Err(Error::EdsInvalidDimentions);
}
let check_share = |row, col, prev_ns: Option<Namespace>, axis| -> Result<Share> {
let idx = flatten_index(row, col, square_width);
let share = if is_ods_square(row, col, square_width) {
Share::from_raw(&shares[idx])?
} else {
Share::parity(&shares[idx])?
};
if prev_ns.is_some_and(|prev_ns| share.namespace() < prev_ns) {
let axis_idx = match axis {
AxisType::Row => row,
AxisType::Col => col,
};
bail_validation!("Shares of {axis} {axis_idx} are not sorted by their namespace");
}
Ok(share)
};
for col in 0..square_width {
let mut prev_ns = None;
for row in 0..square_width {
let share = check_share(row, col, prev_ns, AxisType::Col)?;
prev_ns = Some(share.namespace());
}
}
let mut data_square = Vec::with_capacity(shares.len());
for row in 0..square_width {
let mut prev_ns = None;
for col in 0..square_width {
let share = check_share(row, col, prev_ns, AxisType::Row)?;
prev_ns = Some(share.namespace());
data_square.push(share);
}
}
let eds = ExtendedDataSquare {
data_square,
codec,
square_width,
};
Ok(eds)
}
pub fn from_raw(raw_eds: RawExtendedDataSquare) -> Result<Self> {
ExtendedDataSquare::new(raw_eds.data_square, raw_eds.codec)
}
pub fn empty() -> ExtendedDataSquare {
let ods = vec![
[
Namespace::TAIL_PADDING.as_bytes(),
&[InfoByte::new(0, true).unwrap().as_u8()],
&[0; SHARE_SIZE - NS_SIZE - 1],
]
.concat(),
];
ExtendedDataSquare::from_ods(ods).expect("invalid EDS")
}
pub fn from_ods(mut ods_shares: Vec<Vec<u8>>) -> Result<ExtendedDataSquare> {
let ods_width = f64::sqrt(ods_shares.len() as f64) as usize;
if ods_width * ods_width != ods_shares.len() {
return Err(Error::EdsInvalidDimentions);
}
let eds_width = ods_width * 2;
let mut eds_shares = Vec::with_capacity(eds_width * eds_width);
for _ in 0..ods_width {
eds_shares.extend(ods_shares.drain(..ods_width));
for _ in 0..ods_width {
eds_shares.push(vec![0; SHARE_SIZE]);
}
}
eds_shares.resize(eds_width * eds_width, vec![0; SHARE_SIZE]);
for row in eds_shares.chunks_mut(eds_width).take(ods_width) {
leopard_codec::encode(row, ods_width)?;
}
for col in 0..ods_width {
let mut col: Vec<_> = eds_shares.iter_mut().skip(col).step_by(eds_width).collect();
leopard_codec::encode(&mut col, ods_width)?;
}
for row in eds_shares.chunks_mut(eds_width).skip(ods_width) {
leopard_codec::encode(row, ods_width)?;
}
ExtendedDataSquare::new(eds_shares, "Leopard".to_string())
}
pub fn data_square(&self) -> &[Share] {
&self.data_square
}
pub fn codec(&self) -> &str {
self.codec.as_str()
}
pub fn share(&self, row: u16, column: u16) -> Result<&Share> {
let index = usize::from(row) * usize::from(self.square_width) + usize::from(column);
self.data_square
.get(index)
.ok_or(Error::EdsIndexOutOfRange(row, column))
}
#[cfg(any(test, feature = "test-utils"))]
pub(crate) fn share_mut(&mut self, row: u16, column: u16) -> Result<&mut Share> {
let index = flatten_index(row, column, self.square_width);
self.data_square
.get_mut(index)
.ok_or(Error::EdsIndexOutOfRange(row, column))
}
pub fn row(&self, index: u16) -> Result<Vec<Share>> {
self.axis(AxisType::Row, index)
}
pub fn row_nmt(&self, index: u16) -> Result<Nmt> {
self.axis_nmt(AxisType::Row, index)
}
pub fn column(&self, index: u16) -> Result<Vec<Share>> {
self.axis(AxisType::Col, index)
}
pub fn column_nmt(&self, index: u16) -> Result<Nmt> {
self.axis_nmt(AxisType::Col, index)
}
pub fn axis(&self, axis: AxisType, index: u16) -> Result<Vec<Share>> {
(0..self.square_width)
.map(|i| {
let (row, col) = match axis {
AxisType::Row => (index, i),
AxisType::Col => (i, index),
};
self.share(row, col).map(ToOwned::to_owned)
})
.collect()
}
pub fn axis_nmt(&self, axis: AxisType, index: u16) -> Result<Nmt> {
let mut tree = Nmt::default();
for i in 0..self.square_width {
let (row, col) = match axis {
AxisType::Row => (index, i),
AxisType::Col => (i, index),
};
let share = self.share(row, col)?;
tree.push_leaf(share.as_ref(), *share.namespace())
.map_err(Error::Nmt)?;
}
Ok(tree)
}
pub fn square_width(&self) -> u16 {
self.square_width
}
pub fn get_namespace_data(
&self,
namespace: Namespace,
dah: &DataAvailabilityHeader,
height: u64,
) -> Result<Vec<(RowNamespaceDataId, RowNamespaceData)>> {
let mut rows = Vec::new();
for row in 0..self.square_width {
if !dah.row_contains(row, namespace)? {
continue;
}
let mut shares = Vec::with_capacity(self.square_width.into());
for col in 0..self.square_width {
let share = self.share(row, col)?;
match share.namespace().cmp(&namespace) {
Ordering::Less => {}
Ordering::Equal => shares.push(share.clone()),
Ordering::Greater => break,
}
}
let proof = self.row_nmt(row)?.get_namespace_proof(*namespace);
let id = RowNamespaceDataId::new(namespace, row, height)?;
let data = RowNamespaceData {
proof: proof.into(),
shares,
};
rows.push((id, data))
}
Ok(rows)
}
}
impl EdsId {
pub fn new(height: u64) -> Result<Self> {
if height == 0 {
return Err(Error::ZeroBlockHeight);
}
Ok(EdsId { height })
}
pub fn block_height(&self) -> u64 {
self.height
}
pub fn encode(&self, bytes: &mut BytesMut) {
bytes.reserve(EDS_ID_SIZE);
bytes.put_u64(self.height);
}
pub fn decode(mut buffer: &[u8]) -> Result<Self> {
if buffer.len() != EDS_ID_SIZE {
return Err(Error::InvalidLength(buffer.len(), EDS_ID_SIZE));
}
let height = buffer.get_u64();
EdsId::new(height)
}
}
/// Raw, unvalidated representation of an [`ExtendedDataSquare`], used for (de)serialization.
///
/// Use [`ExtendedDataSquare::from_raw`] to turn it into a validated EDS.
#[derive(Serialize, Deserialize, Clone)]
pub struct RawExtendedDataSquare {
/// All shares as plain bytes, flattened in row-major order.
/// (De)serialized through tendermint_proto's `vec_base64string` helper.
#[serde(with = "tendermint_proto::serializers::bytes::vec_base64string")]
pub data_square: Vec<Vec<u8>>,
/// Codec name of the square.
pub codec: String,
}
impl From<ExtendedDataSquare> for RawExtendedDataSquare {
fn from(eds: ExtendedDataSquare) -> RawExtendedDataSquare {
RawExtendedDataSquare {
data_square: eds
.data_square
.into_iter()
.map(|shr| shr.to_vec())
.collect(),
codec: eds.codec,
}
}
}
/// Reports whether `(row, column)` lies in the original data square.
///
/// The ODS is the top-left quadrant of a `square_width`-wide EDS. A square of width `0` or
/// `1` therefore has no ODS cells.
pub(crate) fn is_ods_square(row: u16, column: u16, square_width: u16) -> bool {
    let half_width = square_width / 2;
    row.max(column) < half_width
}
/// Converts a `(row, col)` coordinate into a row-major index of a `square_width`-wide square.
///
/// The coordinates are widened to `usize` before multiplying, so the product cannot overflow
/// `u16`. No bounds checking is performed.
fn flatten_index(row: u16, col: u16, square_width: u16) -> usize {
    let (row, col, width) = (usize::from(row), usize::from(col), usize::from(square_width));
    row * width + col
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::generate_eds;
    use crate::{Blob, ExtendedHeader};

    /// Loads the sample EDS and its matching DAH from `test_data/shwap_samples`.
    fn sample_eds_and_dah() -> (ExtendedDataSquare, DataAvailabilityHeader) {
        let eds_json = include_str!("../test_data/shwap_samples/eds.json");
        let raw_eds: RawExtendedDataSquare = serde_json::from_str(eds_json).unwrap();
        let eds = ExtendedDataSquare::from_raw(raw_eds).unwrap();

        let dah_json = include_str!("../test_data/shwap_samples/dah.json");
        let dah: DataAvailabilityHeader = serde_json::from_str(dah_json).unwrap();

        (eds, dah)
    }

    #[test]
    fn axis_type_serialization() {
        for (axis, expected) in [(AxisType::Row, 0u8), (AxisType::Col, 1u8)] {
            assert_eq!(axis as u8, expected);
        }
    }

    #[test]
    fn axis_type_deserialization() {
        assert_eq!(AxisType::try_from(0).unwrap(), AxisType::Row);
        assert_eq!(AxisType::try_from(1).unwrap(), AxisType::Col);

        for invalid in [2, 99] {
            let err = AxisType::try_from(invalid).unwrap_err();
            assert!(matches!(err, Error::InvalidAxis(n) if n == invalid));
        }
    }

    #[test]
    fn get_namespaced_data() {
        let (eds, dah) = sample_eds_and_dah();
        let height = 45577;

        // A namespace present in a single row.
        let rows = eds
            .get_namespace_data(Namespace::new_v0(&[1, 170]).unwrap(), &dah, height)
            .unwrap();
        assert_eq!(rows.len(), 1);
        let (id, row) = &rows[0];
        row.verify(*id, &dah).unwrap();
        assert_eq!(row.shares.len(), 2);

        // A namespace spanning two rows.
        let rows = eds
            .get_namespace_data(Namespace::new_v0(&[1, 187]).unwrap(), &dah, height)
            .unwrap();
        let share_counts: Vec<usize> = rows.iter().map(|(_, row)| row.shares.len()).collect();
        assert_eq!(share_counts, [1, 4]);

        for (id, row) in rows {
            row.verify(id, &dah).unwrap();
        }
    }

    #[test]
    fn nmt_roots() {
        let (eds, dah) = sample_eds_and_dah();
        let width = usize::from(eds.square_width());

        assert_eq!(dah.row_roots().len(), width);
        assert_eq!(dah.column_roots().len(), width);

        for (i, root) in dah.row_roots().iter().enumerate() {
            let index = i as u16;
            assert_eq!(*root, eds.row_nmt(index).unwrap().root());
            assert_eq!(*root, eds.axis_nmt(AxisType::Row, index).unwrap().root());
        }

        for (i, root) in dah.column_roots().iter().enumerate() {
            let index = i as u16;
            assert_eq!(*root, eds.column_nmt(index).unwrap().root());
            assert_eq!(*root, eds.axis_nmt(AxisType::Col, index).unwrap().root());
        }
    }

    #[test]
    fn ods_square() {
        // In a 4x4 EDS only the top-left 2x2 quadrant belongs to the ODS.
        let ods_cells = [(0, 0), (0, 1), (1, 0), (1, 1)];

        for row in 0..4u16 {
            for col in 0..4u16 {
                let expected = ods_cells.contains(&(row, col));
                assert_eq!(is_ods_square(row, col, 4), expected, "row {row}, col {col}");
            }
        }
    }

    #[test]
    fn get_row_and_col() {
        let raw_share = |x: u8, y: u8| {
            [
                Namespace::new_v0(&[x, y]).unwrap().as_bytes(),
                &[0u8; SHARE_SIZE - NS_SIZE][..],
            ]
            .concat()
        };

        // In a 4x4 EDS only the top-left 2x2 quadrant holds original shares; the rest is parity.
        let expected_share = |row: u8, col: u8| {
            let raw = raw_share(row, col);
            if row < 2 && col < 2 {
                Share::from_raw(&raw).unwrap()
            } else {
                Share::parity(&raw).unwrap()
            }
        };

        let mut shares = Vec::with_capacity(16);
        for row in 0..4 {
            for col in 0..4 {
                shares.push(raw_share(row, col));
            }
        }
        let eds = ExtendedDataSquare::new(shares, "fake".to_string()).unwrap();

        for i in 0..4u8 {
            let index = u16::from(i);
            let expected_row: Vec<Share> = (0..4).map(|col| expected_share(i, col)).collect();
            let expected_col: Vec<Share> = (0..4).map(|row| expected_share(row, i)).collect();

            assert_eq!(eds.row(index).unwrap(), expected_row);
            assert_eq!(eds.axis(AxisType::Row, index).unwrap(), expected_row);
            assert_eq!(eds.column(index).unwrap(), expected_col);
            assert_eq!(eds.axis(AxisType::Col, index).unwrap(), expected_col);
        }
    }

    #[test]
    fn validation() {
        // Empty / undersized inputs and malformed shares are rejected.
        ExtendedDataSquare::new(vec![], "fake".to_string()).unwrap_err();
        ExtendedDataSquare::new(vec![vec![]], "fake".to_string()).unwrap_err();
        ExtendedDataSquare::new(vec![vec![]; 4], "fake".to_string()).unwrap_err();
        ExtendedDataSquare::new(vec![vec![0u8; SHARE_SIZE]; 4], "fake".to_string()).unwrap();
        // 6 shares do not form a square.
        ExtendedDataSquare::new(vec![vec![0u8; SHARE_SIZE]; 6], "fake".to_string()).unwrap_err();
        ExtendedDataSquare::new(vec![vec![0u8; SHARE_SIZE]; 16], "fake".to_string()).unwrap();

        let share = |n| {
            [
                Namespace::new_v0(&[n]).unwrap().as_bytes(),
                &[0u8; SHARE_SIZE - NS_SIZE][..],
            ]
            .concat()
        };

        // A 1x1 ODS can be extended.
        ExtendedDataSquare::from_ods(vec![share(0)]).unwrap();
        // Rows [1, 2] / [1, 3]: every row and column is sorted.
        ExtendedDataSquare::from_ods(vec![share(1), share(2), share(1), share(3)]).unwrap();
        // Rows [1, 2] / [1, 1]: the second column goes from 2 down to 1.
        ExtendedDataSquare::from_ods(vec![share(1), share(2), share(1), share(1)]).unwrap_err();
        // Rows [1, 1] / [2, 1]: the second row goes from 2 down to 1.
        ExtendedDataSquare::from_ods(vec![share(1), share(1), share(2), share(1)]).unwrap_err();

        // A 6x6 square is rejected: 6 is not a power of two.
        ExtendedDataSquare::new(vec![share(1); 6 * 6], "fake".to_string()).unwrap_err();
    }

    #[test]
    fn empty_block_eds() {
        let genesis_json = include_str!("../test_data/chain1/extended_header_block_1.json");
        let genesis: ExtendedHeader = serde_json::from_str(genesis_json).unwrap();

        let dah = DataAvailabilityHeader::from_eds(&ExtendedDataSquare::empty());

        assert_eq!(dah, genesis.dah);
    }

    #[test]
    fn reconstruct_all() {
        // Random size argument picked from {8, 16, 32, 64, 128, 256}.
        let eds = generate_eds(8 << (rand::random::<usize>() % 6));

        let blobs = Blob::reconstruct_all(eds.data_square()).unwrap();

        // NOTE(review): the expected count mirrors how `generate_eds` lays out blobs —
        // confirm against `test_utils` if that helper changes.
        let expected = usize::from(eds.square_width()) / 2 - 2;
        assert_eq!(blobs.len(), expected);
    }
}