use std::cmp::Ordering;
use std::fmt::Display;
use serde::{Deserialize, Deserializer, Serialize};
use crate::consts::appconsts::SHARE_SIZE;
use crate::consts::data_availability_header::{
MAX_EXTENDED_SQUARE_WIDTH, MIN_EXTENDED_SQUARE_WIDTH,
};
use crate::namespaced_data::{NamespacedData, NamespacedDataId};
use crate::nmt::{Namespace, NamespacedSha2Hasher, Nmt, NmtExt, NS_SIZE};
use crate::{bail_validation, DataAvailabilityHeader, Error, InfoByte, Result};
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[repr(u8)]
pub enum AxisType {
Row = 0,
Col,
}
impl TryFrom<i32> for AxisType {
type Error = Error;
fn try_from(value: i32) -> Result<Self, Self::Error> {
match value {
0 => Ok(AxisType::Row),
1 => Ok(AxisType::Col),
n => Err(Error::InvalidAxis(n)),
}
}
}
impl Display for AxisType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
AxisType::Row => write!(f, "Row"),
AxisType::Col => write!(f, "Column"),
}
}
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct ExtendedDataSquare {
#[serde(with = "celestia_tendermint_proto::serializers::bytes::vec_base64string")]
data_square: Vec<Vec<u8>>,
codec: String,
#[serde(skip)]
square_width: u16,
}
impl ExtendedDataSquare {
pub fn new(shares: Vec<Vec<u8>>, codec: String) -> Result<Self> {
const MIN_SHARES: usize = MIN_EXTENDED_SQUARE_WIDTH * MIN_EXTENDED_SQUARE_WIDTH;
const MAX_SHARES: usize = MAX_EXTENDED_SQUARE_WIDTH * MAX_EXTENDED_SQUARE_WIDTH;
if shares.len() < MIN_SHARES {
bail_validation!(
"shares len ({}) < MIN_SHARES ({})",
shares.len(),
MIN_SHARES
);
}
if shares.len() > MAX_SHARES {
bail_validation!(
"shares len ({}) > MAX_SHARES ({})",
shares.len(),
MAX_SHARES
);
}
let square_width = f64::sqrt(shares.len() as f64) as usize;
if square_width * square_width != shares.len() {
return Err(Error::EdsInvalidDimentions);
}
let square_width = u16::try_from(square_width).map_err(|_| Error::EdsInvalidDimentions)?;
if square_width.count_ones() != 1 {
return Err(Error::EdsInvalidDimentions);
}
let eds = ExtendedDataSquare {
data_square: shares,
codec,
square_width,
};
let check_share = |row, col, prev_ns: Option<Namespace>, axis| -> Result<Namespace> {
let share = eds.share(row, col)?;
if share.len() != SHARE_SIZE {
bail_validation!("share len ({}) != SHARE_SIZE ({})", share.len(), SHARE_SIZE);
}
let ns = if is_ods_square(row, col, eds.square_width()) {
Namespace::from_raw(&share[..NS_SIZE])?
} else {
Namespace::PARITY_SHARE
};
if prev_ns.map_or(false, |prev_ns| ns < prev_ns) {
let axis_idx = match axis {
AxisType::Row => row,
AxisType::Col => col,
};
bail_validation!("Shares of {axis} {axis_idx} are not sorted by their namespace");
}
Ok(ns)
};
for row in 0..eds.square_width() {
let mut prev_ns = None;
for col in 0..eds.square_width() {
prev_ns = Some(check_share(row, col, prev_ns, AxisType::Row)?);
}
}
for col in 0..eds.square_width() {
let mut prev_ns = None;
for row in 0..eds.square_width() {
prev_ns = Some(check_share(row, col, prev_ns, AxisType::Col)?);
}
}
Ok(eds)
}
pub fn empty() -> ExtendedDataSquare {
let ods = vec![[
Namespace::TAIL_PADDING.as_bytes(),
&[InfoByte::new(0, true).unwrap().as_u8()],
&[0; SHARE_SIZE - NS_SIZE - 1],
]
.concat()];
ExtendedDataSquare::from_ods(ods).expect("invalid EDS")
}
pub fn from_ods(mut ods_shares: Vec<Vec<u8>>) -> Result<ExtendedDataSquare> {
let ods_width = f64::sqrt(ods_shares.len() as f64) as usize;
if ods_width * ods_width != ods_shares.len() {
return Err(Error::EdsInvalidDimentions);
}
let eds_width = ods_width * 2;
let mut eds_shares = Vec::with_capacity(eds_width * eds_width);
for _ in 0..ods_width {
eds_shares.extend(ods_shares.drain(..ods_width));
for _ in 0..ods_width {
eds_shares.push(vec![0; SHARE_SIZE]);
}
}
eds_shares.resize(eds_width * eds_width, vec![0; SHARE_SIZE]);
for row in eds_shares.chunks_mut(eds_width).take(ods_width) {
leopard_codec::encode(row, ods_width)?;
}
for col in 0..ods_width {
let mut col: Vec<_> = eds_shares.iter_mut().skip(col).step_by(eds_width).collect();
leopard_codec::encode(&mut col, ods_width)?;
}
for row in eds_shares.chunks_mut(eds_width).skip(ods_width) {
leopard_codec::encode(row, ods_width)?;
}
ExtendedDataSquare::new(eds_shares, "Leopard".to_string())
}
pub fn data_square(&self) -> &[Vec<u8>] {
&self.data_square
}
pub fn codec(&self) -> &str {
self.codec.as_str()
}
pub fn share(&self, row: u16, column: u16) -> Result<&[u8]> {
let index = usize::from(row) * usize::from(self.square_width) + usize::from(column);
self.data_square
.get(index)
.map(Vec::as_slice)
.ok_or(Error::EdsIndexOutOfRange(row, column))
}
#[cfg(any(test, feature = "test-utils"))]
pub(crate) fn share_mut(&mut self, row: u16, column: u16) -> Result<&mut [u8]> {
let index = usize::from(row) * usize::from(self.square_width) + usize::from(column);
self.data_square
.get_mut(index)
.map(Vec::as_mut_slice)
.ok_or(Error::EdsIndexOutOfRange(row, column))
}
pub fn row(&self, index: u16) -> Result<Vec<Vec<u8>>> {
self.axis(AxisType::Row, index)
}
pub fn row_nmt(&self, index: u16) -> Result<Nmt> {
self.axis_nmt(AxisType::Row, index)
}
pub fn column(&self, index: u16) -> Result<Vec<Vec<u8>>> {
self.axis(AxisType::Col, index)
}
pub fn column_nmt(&self, index: u16) -> Result<Nmt> {
self.axis_nmt(AxisType::Col, index)
}
pub fn axis(&self, axis: AxisType, index: u16) -> Result<Vec<Vec<u8>>> {
(0..self.square_width)
.map(|i| {
let (row, col) = match axis {
AxisType::Row => (index, i),
AxisType::Col => (i, index),
};
self.share(row, col).map(ToOwned::to_owned)
})
.collect()
}
pub fn axis_nmt(&self, axis: AxisType, index: u16) -> Result<Nmt> {
let mut tree = Nmt::default();
for i in 0..self.square_width {
let (row, col) = match axis {
AxisType::Row => (index, i),
AxisType::Col => (i, index),
};
let share = self.share(row, col)?;
let ns = if is_ods_square(col, row, self.square_width) {
Namespace::from_raw(&share[..NS_SIZE])?
} else {
Namespace::PARITY_SHARE
};
tree.push_leaf(share, *ns).map_err(Error::Nmt)?;
}
Ok(tree)
}
pub fn square_width(&self) -> u16 {
self.square_width
}
pub fn get_namespaced_data(
&self,
namespace: Namespace,
dah: &DataAvailabilityHeader,
height: u64,
) -> Result<Vec<NamespacedData>> {
let mut data = Vec::new();
for row in 0..self.square_width {
let Some(row_root) = dah.row_root(row) else {
break;
};
if !row_root.contains::<NamespacedSha2Hasher>(*namespace) {
continue;
}
let mut shares = Vec::with_capacity(self.square_width.into());
for col in 0..self.square_width {
let share = self.share(row, col)?;
let ns = if is_ods_square(row, col, self.square_width) {
Namespace::from_raw(&share[..NS_SIZE])?
} else {
Namespace::PARITY_SHARE
};
match ns.cmp(&namespace) {
Ordering::Less => {}
Ordering::Equal => shares.push(share.to_owned()),
Ordering::Greater => break,
}
}
let proof = self.row_nmt(row)?.get_namespace_proof(*namespace);
let id = NamespacedDataId::new(namespace, row, height)?;
data.push(NamespacedData {
id,
proof: proof.into(),
shares,
})
}
Ok(data)
}
}
#[derive(Deserialize)]
struct RawExtendedDataSquare {
#[serde(with = "celestia_tendermint_proto::serializers::bytes::vec_base64string")]
pub data_square: Vec<Vec<u8>>,
pub codec: String,
}
impl<'de> Deserialize<'de> for ExtendedDataSquare {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let eds = RawExtendedDataSquare::deserialize(deserializer)?;
let share_number = eds.data_square.len();
ExtendedDataSquare::new(eds.data_square, eds.codec).map_err(|_| {
<D::Error as serde::de::Error>::invalid_length(
share_number,
&"number of shares must be a perfect square",
)
})
}
}
pub(crate) fn is_ods_square(row: u16, column: u16, square_width: u16) -> bool {
let ods_width = square_width / 2;
row < ods_width && column < ods_width
}
#[cfg(test)]
mod tests {
use super::*;
use crate::ExtendedHeader;
#[test]
fn axis_type_serialization() {
assert_eq!(AxisType::Row as u8, 0);
assert_eq!(AxisType::Col as u8, 1);
}
#[test]
fn axis_type_deserialization() {
assert_eq!(AxisType::try_from(0).unwrap(), AxisType::Row);
assert_eq!(AxisType::try_from(1).unwrap(), AxisType::Col);
let axis_type_err = AxisType::try_from(2).unwrap_err();
assert!(matches!(axis_type_err, Error::InvalidAxis(2)));
let axis_type_err = AxisType::try_from(99).unwrap_err();
assert!(matches!(axis_type_err, Error::InvalidAxis(99)));
}
#[test]
fn get_namespaced_data() {
let eds_json = include_str!("../test_data/shwap_samples/eds.json");
let eds: ExtendedDataSquare = serde_json::from_str(eds_json).unwrap();
let dah_json = include_str!("../test_data/shwap_samples/dah.json");
let dah: DataAvailabilityHeader = serde_json::from_str(dah_json).unwrap();
let height = 45577;
let rows = eds
.get_namespaced_data(Namespace::new_v0(&[1, 170]).unwrap(), &dah, height)
.unwrap();
assert_eq!(rows.len(), 1);
let row = &rows[0];
row.verify(&dah).unwrap();
assert_eq!(row.shares.len(), 2);
let rows = eds
.get_namespaced_data(Namespace::new_v0(&[1, 187]).unwrap(), &dah, height)
.unwrap();
assert_eq!(rows.len(), 2);
assert_eq!(rows[0].shares.len(), 1);
assert_eq!(rows[1].shares.len(), 4);
for row in rows {
row.verify(&dah).unwrap();
}
}
#[test]
fn nmt_roots() {
let eds_json = include_str!("../test_data/shwap_samples/eds.json");
let eds: ExtendedDataSquare = serde_json::from_str(eds_json).unwrap();
let dah_json = include_str!("../test_data/shwap_samples/dah.json");
let dah: DataAvailabilityHeader = serde_json::from_str(dah_json).unwrap();
assert_eq!(dah.row_roots().len(), usize::from(eds.square_width()));
assert_eq!(dah.column_roots().len(), usize::from(eds.square_width()));
for (i, root) in dah.row_roots().iter().enumerate() {
let mut tree = eds.row_nmt(i as u16).unwrap();
assert_eq!(*root, tree.root());
let mut tree = eds.axis_nmt(AxisType::Row, i as u16).unwrap();
assert_eq!(*root, tree.root());
}
for (i, root) in dah.column_roots().iter().enumerate() {
let mut tree = eds.column_nmt(i as u16).unwrap();
assert_eq!(*root, tree.root());
let mut tree = eds.axis_nmt(AxisType::Col, i as u16).unwrap();
assert_eq!(*root, tree.root());
}
}
#[test]
fn ods_square() {
assert!(is_ods_square(0, 0, 4));
assert!(is_ods_square(0, 1, 4));
assert!(is_ods_square(1, 0, 4));
assert!(is_ods_square(1, 1, 4));
assert!(!is_ods_square(0, 2, 4));
assert!(!is_ods_square(0, 3, 4));
assert!(!is_ods_square(1, 2, 4));
assert!(!is_ods_square(1, 3, 4));
assert!(!is_ods_square(2, 0, 4));
assert!(!is_ods_square(2, 1, 4));
assert!(!is_ods_square(3, 0, 4));
assert!(!is_ods_square(3, 1, 4));
assert!(!is_ods_square(2, 2, 4));
assert!(!is_ods_square(2, 3, 4));
assert!(!is_ods_square(3, 2, 4));
assert!(!is_ods_square(3, 3, 4));
}
#[test]
fn get_row_and_col() {
let share = |x, y| {
[
Namespace::new_v0(&[x, y]).unwrap().as_bytes(),
&[0u8; SHARE_SIZE - NS_SIZE][..],
]
.concat()
};
#[rustfmt::skip]
let shares = vec![
share(0, 0), share(0, 1), share(0, 2), share(0, 3),
share(1, 0), share(1, 1), share(1, 2), share(1, 3),
share(2, 0), share(2, 1), share(2, 2), share(2, 3),
share(3, 0), share(3, 1), share(3, 2), share(3, 3),
];
let eds = ExtendedDataSquare::new(shares, "fake".to_string()).unwrap();
assert_eq!(
eds.row(0).unwrap(),
vec![share(0, 0), share(0, 1), share(0, 2), share(0, 3)]
);
assert_eq!(
eds.row(1).unwrap(),
vec![share(1, 0), share(1, 1), share(1, 2), share(1, 3)]
);
assert_eq!(
eds.row(2).unwrap(),
vec![share(2, 0), share(2, 1), share(2, 2), share(2, 3)]
);
assert_eq!(
eds.row(3).unwrap(),
vec![share(3, 0), share(3, 1), share(3, 2), share(3, 3)]
);
assert_eq!(
eds.axis(AxisType::Row, 0).unwrap(),
vec![share(0, 0), share(0, 1), share(0, 2), share(0, 3)]
);
assert_eq!(
eds.axis(AxisType::Row, 1).unwrap(),
vec![share(1, 0), share(1, 1), share(1, 2), share(1, 3)]
);
assert_eq!(
eds.axis(AxisType::Row, 2).unwrap(),
vec![share(2, 0), share(2, 1), share(2, 2), share(2, 3)]
);
assert_eq!(
eds.axis(AxisType::Row, 3).unwrap(),
vec![share(3, 0), share(3, 1), share(3, 2), share(3, 3)]
);
assert_eq!(
eds.column(0).unwrap(),
vec![share(0, 0), share(1, 0), share(2, 0), share(3, 0)]
);
assert_eq!(
eds.column(1).unwrap(),
vec![share(0, 1), share(1, 1), share(2, 1), share(3, 1)]
);
assert_eq!(
eds.column(2).unwrap(),
vec![share(0, 2), share(1, 2), share(2, 2), share(3, 2)]
);
assert_eq!(
eds.column(3).unwrap(),
vec![share(0, 3), share(1, 3), share(2, 3), share(3, 3)]
);
assert_eq!(
eds.axis(AxisType::Col, 0).unwrap(),
vec![share(0, 0), share(1, 0), share(2, 0), share(3, 0)]
);
assert_eq!(
eds.axis(AxisType::Col, 1).unwrap(),
vec![share(0, 1), share(1, 1), share(2, 1), share(3, 1)]
);
assert_eq!(
eds.axis(AxisType::Col, 2).unwrap(),
vec![share(0, 2), share(1, 2), share(2, 2), share(3, 2)]
);
assert_eq!(
eds.axis(AxisType::Col, 3).unwrap(),
vec![share(0, 3), share(1, 3), share(2, 3), share(3, 3)]
);
}
#[test]
fn validation() {
ExtendedDataSquare::new(vec![], "fake".to_string()).unwrap_err();
ExtendedDataSquare::new(vec![vec![]], "fake".to_string()).unwrap_err();
ExtendedDataSquare::new(vec![vec![]; 4], "fake".to_string()).unwrap_err();
ExtendedDataSquare::new(vec![vec![0u8; SHARE_SIZE]; 4], "fake".to_string()).unwrap();
ExtendedDataSquare::new(vec![vec![0u8; SHARE_SIZE]; 6], "fake".to_string()).unwrap_err();
ExtendedDataSquare::new(vec![vec![0u8; SHARE_SIZE]; 16], "fake".to_string()).unwrap();
let share = |n| {
[
Namespace::new_v0(&[n]).unwrap().as_bytes(),
&[0u8; SHARE_SIZE - NS_SIZE][..],
]
.concat()
};
ExtendedDataSquare::from_ods(vec![
share(0), ])
.unwrap();
ExtendedDataSquare::from_ods(vec![
share(1),
share(2),
share(1),
share(3),
])
.unwrap();
ExtendedDataSquare::from_ods(vec![
share(1),
share(2),
share(1),
share(1), ])
.unwrap_err();
ExtendedDataSquare::from_ods(vec![
share(1),
share(1),
share(2),
share(1), ])
.unwrap_err();
ExtendedDataSquare::new(vec![share(1); 6 * 6], "fake".to_string()).unwrap_err();
let square_width = MAX_EXTENDED_SQUARE_WIDTH * 2;
ExtendedDataSquare::new(vec![share(1); square_width.pow(2)], "fake".to_string())
.unwrap_err();
}
#[test]
fn empty_block_eds() {
let s = include_str!("../test_data/chain1/extended_header_block_1.json");
let genesis: ExtendedHeader = serde_json::from_str(s).unwrap();
let eds = ExtendedDataSquare::empty();
let dah = DataAvailabilityHeader::from_eds(&eds);
assert_eq!(dah, genesis.dah);
}
}