use std::any::TypeId;
use std::collections::HashMap;
use std::fmt::Debug;
use once_cell::sync::OnceCell;
use sha2::{Digest, Sha256};
pub mod container;
mod primitive;
pub mod safe_string;
pub mod transaction_templates;
use borsh::{BorshDeserialize, BorshSerialize};
pub use container::Container;
use nmt_rs::simple_merkle::db::MemDb;
use nmt_rs::simple_merkle::tree::MerkleTree;
use nmt_rs::TmSha2Hasher;
pub use primitive::Primitive;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
mod schema_impls;
use thiserror::Error;
use transaction_templates::TransactionTemplateSet;
use crate::display::{Context as DisplayContext, DisplayVisitor, FormatError};
#[cfg(feature = "serde")]
use crate::json_to_borsh::{Context as EncodeContext, EncodeError, EncodeVisitor};
use crate::ty::byte_display::ByteParseError;
use crate::ty::{ContainerSerdeMetadata, LinkingScheme, Ty};
#[cfg(feature = "eip712")]
use crate::visitors::eip712::{Context as Eip712Context, Eip712Error, Eip712Visitor};
#[cfg(feature = "eip712")]
use alloy_dyn_abi::{Eip712Types, Error as AlloyEip712Error, PropertyDef, TypedData};
/// Errors produced while building, hashing, rendering, or encoding with a [`Schema`].
#[derive(Debug, Error)]
pub enum SchemaError {
    /// Error from the display/formatting layer (also used for unconsumed trailing input).
    #[error(transparent)]
    FormatError(#[from] FormatError),
    /// Borsh (de)serialization failure, e.g. while encoding schema parts for hashing.
    #[error(transparent)]
    BorshError(#[from] borsh::io::Error),
    /// Encoding a JSON value into bytes failed.
    #[cfg(feature = "serde")]
    #[error(transparent)]
    EncodeError(#[from] EncodeError),
    /// A `serde_json` operation failed.
    #[cfg(feature = "serde")]
    #[error(transparent)]
    JsonError(#[from] serde_json::Error),
    /// Building EIP-712 typed data from a value failed.
    #[cfg(feature = "eip712")]
    #[error(transparent)]
    Eip712Error(#[from] Eip712Error),
    /// `alloy_dyn_abi` rejected the generated EIP-712 typed data.
    #[cfg(feature = "eip712")]
    #[error(transparent)]
    AlloyEip712Error(#[from] AlloyEip712Error),
    /// Parsing bytes from their displayed form failed.
    // NOTE(review): the variant name suggests bech32-specific errors; `ByteParseError` may be broader — confirm.
    #[error(transparent)]
    Bech32Error(#[from] ByteParseError),
    /// The schema has no root registered for the requested rollup type.
    #[error("Rollup type {0:?} was missing from schema")]
    MissingRollupRoot(RollupRoots),
    /// No transaction template with the given name exists for the requested root.
    #[error("Template {0} not found in schema")]
    UnknownTemplate(String),
    /// A type index or root index was out of range.
    #[error("Index {0} not found in schema")]
    InvalidIndex(usize),
    /// The metadata hash was required but has not been set or computed.
    #[error("Metadata hash must be provided but was not initialized. The schema was not properly finalized, or the serialized schema was invalid.")]
    MetadataHashNotInitialized,
}
/// Linking scheme used by [`Schema`]: types refer to one another through [`Link`] values.
#[derive(Debug, Clone, PartialEq, Eq, BorshSerialize, BorshDeserialize)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct IndexLinking;
impl LinkingScheme for IndexLinking {
    type TypeLink = Link;
}
/// A reference from one schema type to another.
///
/// Variant order is part of the Borsh encoding; do not reorder variants.
#[derive(Debug, Clone, PartialEq, Eq, BorshSerialize, BorshDeserialize)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum Link {
    /// Index of the referenced type in the schema's type list.
    ByIndex(usize),
    /// A primitive referenced inline instead of through the type list.
    Immediate(Primitive),
    /// A child slot that has not been filled yet. Schema roots and templates must never
    /// contain one (registering such a root or filling such a template panics).
    Placeholder,
    /// An unfilled child slot carrying an index.
    // NOTE(review): the index's meaning is defined by the container/ty modules — confirm there.
    IndexedPlaceholder(usize),
}
/// A link returned while registering a type, tagged with whether the target still has
/// child placeholders left to fill.
#[derive(Debug, Clone, PartialEq, Eq)]
enum MaybePartialLink {
    /// The container was just added and its children still need to be linked.
    Partial(Link),
    /// The target is already fully described (previously known, childless, or a primitive).
    Complete(Link),
}
impl MaybePartialLink {
    /// Drops the completeness tag and yields the wrapped link.
    fn into_inner(self) -> Link {
        let (Self::Partial(wrapped) | Self::Complete(wrapped)) = self;
        wrapped
    }
}
/// A Merkle tree over the schema's Borsh-encoded types, paired with its root.
type MerkleTreeWithRoot = (MerkleTree<MemDb<[u8; 32]>, TmSha2Hasher>, [u8; 32]);

/// Lazily built cache of the schema's Merkle tree and its root.
#[derive(Default)]
struct ConstructedMerkleTree(OnceCell<MerkleTreeWithRoot>);

impl Debug for ConstructedMerkleTree {
    /// Prints only the cached root, or `None` if the tree has not been built yet.
    /// The tree itself is omitted so that the `Debug` output of `Schema` stays compact
    /// and well-formed. Previously this printed nothing at all.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_tuple("ConstructedMerkleTree")
            .field(&self.0.get().map(|(_, root)| root))
            .finish()
    }
}
/// The schema's metadata hash, set once, either computed or loaded from a Borsh encoding.
#[derive(Default)]
struct MetadataHash(OnceCell<[u8; 32]>);
impl Debug for MetadataHash {
    /// Formats as `Some([..])` once set, `None` otherwise.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let maybe_hash: Option<&[u8; 32]> = self.0.get();
        Debug::fmt(&maybe_hash, f)
    }
}
/// Writes the 32-byte hash. Refuses to serialize an unset hash, so a schema can only be
/// Borsh-encoded after its metadata hash exists.
impl BorshSerialize for MetadataHash {
    fn serialize<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<()> {
        match self.0.get() {
            Some(stored) => BorshSerialize::serialize(stored, writer),
            None => Err(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                "Cannot serialize Schema: metadata_hash not initialized. Call finalize() before serializing",
            )),
        }
    }
}
impl BorshDeserialize for MetadataHash {
fn deserialize_reader<R: std::io::Read>(reader: &mut R) -> std::io::Result<Self> {
let hash: [u8; 32] = BorshDeserialize::deserialize_reader(reader)?;
let metadata_hash = MetadataHash::default();
metadata_hash.0.set(hash).map_err(|_| {
std::io::Error::new(
std::io::ErrorKind::InvalidData,
"Failed to set metadata_hash in OnceCell during deserialization",
)
})?;
Ok(metadata_hash)
}
}
/// Identity used to deduplicate types while building a [`Schema`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ItemId(pub TypeId);
impl ItemId {
    /// Returns `T`'s identity: its `id_override()` if it provides one, otherwise `T`'s `TypeId`.
    pub fn of<T: 'static + UniversalWallet>() -> Self {
        match T::id_override() {
            Some(custom_id) => custom_id,
            None => Self(TypeId::of::<T>()),
        }
    }
}
/// The root types of a rollup schema.
///
/// Each discriminant is that root's position in the schema's list of root type indices.
/// It matches the registration order used by `Schema::of_rollup_types_with_chain_data`.
#[derive(Debug, Copy, Clone)]
pub enum RollupRoots {
    Transaction = 0,
    UnsignedTransaction = 1,
    RuntimeCall = 2,
    Address = 3,
}
/// Identifies the chain a schema belongs to.
///
/// Both fields are hashed into the schema's chain hash. With the `eip712` feature they also
/// populate the EIP-712 domain (`chain_id` as `chainId`, `chain_name` as `name`).
#[derive(Debug, Default, Clone, BorshSerialize, BorshDeserialize)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ChainData {
    pub chain_id: u64,
    pub chain_name: String,
}
/// A registry of types that refer to each other through [`Link`]s, plus hashes committing to it.
///
/// A schema is built by registering root types through [`UniversalWallet`] and is then
/// finalized, which computes and caches its metadata hash and chain hash.
/// Field order and the skip attributes below define the Borsh and serde encodings.
#[derive(Default, Debug, BorshSerialize, BorshDeserialize)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct Schema {
    // Registered types; `Link::ByIndex(i)` refers to `types[i]`.
    types: Vec<Ty<IndexLinking>>,
    // Indices into `types` of the root types, in registration order.
    root_type_indices: Vec<usize>,
    // Chain identity; hashed into the chain hash.
    chain_data: ChainData,
    // Hash committing to `templates` and `serde_metadata` (computed from them in `serde`
    // builds). Part of the Borsh encoding (serializing fails if unset), not of the JSON form.
    #[cfg_attr(feature = "serde", serde(skip))]
    extra_metadata_hash: MetadataHash,
    // Cached chain hash; never serialized.
    #[cfg_attr(feature = "serde", serde(skip))]
    #[borsh(skip)]
    chain_hash: OnceCell<[u8; 32]>,
    // One template set per `make_root_of` call, in registration order. Not Borsh-encoded.
    #[borsh(skip)]
    templates: Vec<TransactionTemplateSet>,
    // During construction, pushed alongside each entry of `types`. Not Borsh-encoded, so it is
    // empty after Borsh deserialization.
    #[borsh(skip)]
    serde_metadata: Vec<ContainerSerdeMetadata>,
    // Cached Merkle tree over `types`; never serialized.
    #[cfg_attr(feature = "serde", serde(skip))]
    #[borsh(skip)]
    merkle_tree: ConstructedMerkleTree,
    // Construction-time map from a type's identity to its index in `types`.
    #[cfg_attr(feature = "serde", serde(skip))]
    #[borsh(skip)]
    known_types: HashMap<ItemId, usize>,
    // Construction-time map from a container's identity to its count of unfilled child slots.
    #[cfg_attr(feature = "serde", serde(skip))]
    #[borsh(skip)]
    under_construction: HashMap<ItemId, usize>,
}
impl Schema {
pub fn of_single_type<T: UniversalWallet>() -> Result<Self, SchemaError> {
let mut schema = Self::default();
T::make_root_of(&mut schema);
schema.finalize()?;
Ok(schema)
}
pub fn of_rollup_types_with_chain_data<
Transaction: UniversalWallet,
UnsignedTransaction: UniversalWallet,
RuntimeCall: UniversalWallet,
Address: UniversalWallet,
>(
chain_data: ChainData,
) -> Result<Self, SchemaError> {
let mut schema = Schema {
chain_data,
..Self::default()
};
Transaction::make_root_of(&mut schema);
UnsignedTransaction::make_root_of(&mut schema);
RuntimeCall::make_root_of(&mut schema);
Address::make_root_of(&mut schema);
schema.finalize()?;
Ok(schema)
}
pub fn chain_data(&self) -> &ChainData {
&self.chain_data
}
#[cfg(not(feature = "serde"))]
pub fn metadata_hash(&self) -> Result<[u8; 32], SchemaError> {
self.extra_metadata_hash
.0
.get()
.copied()
.ok_or(SchemaError::MetadataHashNotInitialized)
}
#[cfg(feature = "serde")]
pub fn metadata_hash(&self) -> Result<[u8; 32], SchemaError> {
self.extra_metadata_hash
.0
.get_or_try_init(|| self.calculate_metadata_hash())
.copied()
}
#[cfg(feature = "serde")]
fn calculate_metadata_hash(&self) -> Result<[u8; 32], SchemaError> {
let mut hasher = Sha256::new();
hasher.update(&borsh::to_vec(&self.templates)?);
hasher.update(&borsh::to_vec(&self.serde_metadata)?);
Ok(hasher.finalize().into())
}
#[cfg(feature = "serde")]
pub fn from_json(input: &str) -> Result<Self, serde_json::Error> {
serde_json::from_str(input)
}
pub fn rollup_expected_index(&self, rollup_type: RollupRoots) -> Result<usize, SchemaError> {
self.root_type_indices
.get(rollup_type as usize)
.copied()
.ok_or(SchemaError::MissingRollupRoot(rollup_type))
}
pub fn display(&self, type_index: usize, input: &[u8]) -> Result<String, SchemaError> {
let mut output = String::new();
let input = &mut &input[..];
let mut visitor = DisplayVisitor::new(input, &mut output);
self.types
.get(type_index)
.ok_or(SchemaError::InvalidIndex(type_index))?
.visit(self, &mut visitor, DisplayContext::default())?;
if !visitor.has_displayed_whole_input() {
return Err(FormatError::UnusedInput.into());
}
Ok(output)
}
#[cfg(feature = "eip712")]
pub fn eip712_json(&self, type_index: usize, input: &[u8]) -> Result<String, SchemaError> {
let Some(typed_data) = self.eip712_get_typed_data_inner(type_index, input)? else {
return Ok(String::default());
};
Ok(serde_json::to_string(&typed_data)?)
}
#[cfg(feature = "eip712")]
pub fn eip712_signing_hash(
&self,
type_index: usize,
input: &[u8],
) -> Result<[u8; 32], SchemaError> {
let Some(typed_data) = self.eip712_get_typed_data_inner(type_index, input)? else {
return Ok(Default::default());
};
Ok(typed_data.eip712_signing_hash()?.into())
}
#[cfg(feature = "eip712")]
pub fn eip712_signing_digest(
&self,
type_index: usize,
input: &[u8],
) -> Result<[u8; 66], SchemaError> {
let Some(typed_data) = self.eip712_get_typed_data_inner(type_index, input)? else {
return Ok([0; 66]);
};
let mut buf = [0u8; 66];
buf[0] = 0x19;
buf[1] = 0x01;
buf[2..34].copy_from_slice(typed_data.domain.separator().as_slice());
buf[34..].copy_from_slice(typed_data.hash_struct()?.as_slice());
Ok(buf)
}
#[cfg(feature = "eip712")]
fn eip712_get_typed_data_inner(
&self,
type_index: usize,
input: &[u8],
) -> Result<Option<TypedData>, SchemaError> {
let mut out_types = Eip712Types::default();
let input = &mut &input[..];
let mut visitor = Eip712Visitor::new(input, &mut out_types);
let root_type = self
.types()
.get(type_index)
.ok_or(SchemaError::InvalidIndex(type_index))?;
let Some(visitor_return) = root_type.visit(self, &mut visitor, Eip712Context::default())?
else {
return Ok(None);
};
if !visitor.has_displayed_whole_input() {
return Err(FormatError::UnusedInput.into());
}
out_types.insert(
"EIP712Domain".to_string(),
vec![
PropertyDef::new("string", "name").unwrap(),
PropertyDef::new("uint256", "chainId").unwrap(),
PropertyDef::new("bytes32", "salt").unwrap(),
],
);
Ok(Some(TypedData {
domain: alloy_dyn_abi::Eip712Domain {
name: Some(self.chain_data.chain_name.clone().into()),
version: None,
chain_id: Some(alloy_primitives::U256::from(self.chain_data.chain_id)),
verifying_contract: None,
salt: Some(self.chain_hash()?.into()),
},
resolver: out_types.into(),
primary_type: visitor_return.unique_type_name,
message: visitor_return.json_value,
}))
}
#[cfg(feature = "serde")]
pub fn json_to_borsh(&self, type_index: usize, input: &str) -> Result<Vec<u8>, SchemaError> {
let mut output = Vec::new();
let mut visitor = EncodeVisitor::new(&mut output)?;
self.types
.get(type_index)
.ok_or(SchemaError::InvalidIndex(type_index))?
.visit(self, &mut visitor, EncodeContext::new(input, type_index)?)?;
Ok(output)
}
#[cfg(feature = "serde")]
pub fn fill_template_from_json(
&self,
root_index: usize,
template_name: &str,
input: &str,
) -> Result<Vec<u8>, SchemaError> {
fn serde_to_schema_err(e: serde_json::Error) -> SchemaError {
SchemaError::EncodeError(EncodeError::Json(e.to_string()))
}
let template = self
.templates
.get(root_index)
.ok_or(SchemaError::InvalidIndex(root_index))?
.0
.get(template_name)
.ok_or(SchemaError::UnknownTemplate(template_name.to_string()))?;
let mut input_map: serde_json::Map<String, serde_json::Value> =
serde_json::from_str::<serde_json::Map<String, serde_json::Value>>(input)
.map_err(serde_to_schema_err)?;
let mut output = template.preencoded_bytes().to_owned();
for (name, input) in template.inputs().iter().rev() {
let ty = match input.type_link() {
Link::ByIndex(i) => self.types.get(*i).expect("Template {name} contained an invalid link: {i}. This is a major bug with template generation."),
Link::Immediate(ty) => &ty.clone().into(),
Link::Placeholder | Link::IndexedPlaceholder(_) => panic!("Template {name} contained placeholder link. This is a major bug with template generation.")
};
let json_value = input_map.remove(name).ok_or(EncodeError::MissingType {
name: name.to_owned(),
})?;
let mut buf = Vec::new();
let mut visitor = EncodeVisitor::new(&mut buf)?;
ty.visit(
self,
&mut visitor,
EncodeContext::from_val(json_value, input.type_link()),
)?;
output.splice(input.offset()..input.offset(), buf);
}
if !input_map.is_empty() {
return Err(SchemaError::EncodeError(EncodeError::UnusedInput {
value: input_map.iter().next().unwrap().0.to_owned(),
}));
}
Ok(output)
}
#[cfg(feature = "serde")]
pub fn templates(&self, index: usize) -> Result<Vec<String>, SchemaError> {
Ok(self
.templates
.get(index)
.ok_or(SchemaError::InvalidIndex(index))?
.0
.keys()
.cloned()
.collect())
}
pub fn chain_hash(&self) -> Result<[u8; 32], SchemaError> {
self.chain_hash
.get_or_try_init(|| {
let merkle_root = self.merkle_root()?;
let mut hasher = Sha256::new();
hasher.update(&borsh::to_vec(&self.root_type_indices)?);
hasher.update(&borsh::to_vec(&self.chain_data)?);
let internal_data_hash: [u8; 32] = hasher.finalize().into();
let metadata_hash = self.metadata_hash()?;
let mut hasher = Sha256::new();
hasher.update(merkle_root);
hasher.update(internal_data_hash);
hasher.update(metadata_hash);
let chain_hash: [u8; 32] = hasher.finalize().into();
Ok(chain_hash)
})
.copied()
}
fn merkle_root(&self) -> Result<[u8; 32], SchemaError> {
let (_, root) = self.merkle_tree.0.get_or_try_init(|| {
let mut tree = MerkleTree::new();
for ty in &self.types {
tree.push_raw_leaf(&borsh::to_vec(ty)?)
}
let root = tree.root();
Ok::<_, SchemaError>((tree, root))
})?;
Ok(*root)
}
fn finalize(&self) -> Result<(), SchemaError> {
self.metadata_hash()?;
self.chain_hash()?;
Ok(())
}
pub fn types(&self) -> &[Ty<IndexLinking>] {
&self.types
}
pub fn serde_metadata(&self) -> &[ContainerSerdeMetadata] {
&self.serde_metadata
}
pub fn root_types(&self) -> &[usize] {
&self.root_type_indices
}
fn find_item_by_id(&self, item_id: &ItemId) -> Option<usize> {
self.known_types.get(item_id).copied()
}
fn link_child_to_parent(&mut self, parent: ItemId, child: Link) {
let idx = self.known_types.get(&parent).unwrap_or_else(|| panic!("Tried to link a child to a parent ({parent:?}) that the schema doesn't have. This is a bug in a hand-written schema."));
let remaining_children = *self.under_construction.get(&parent).unwrap_or_else(|| panic!("Tried to link too many children to parent ({parent:?}). This is a bug in a hand-written schema."));
if remaining_children == 1 {
self.under_construction.remove(&parent);
} else {
self.under_construction
.insert(parent, remaining_children - 1);
}
self.types[*idx].fill_next_placholder(child);
}
fn get_partial_link_to(
&mut self,
item: Item<IndexLinking>,
item_id: ItemId,
) -> MaybePartialLink {
match item {
Item::Container(c) => {
if let Some(location) = self.find_item_by_id(&item_id) {
MaybePartialLink::Complete(Link::ByIndex(location))
} else {
let num_children = c.num_children();
let serde_metadata = c.serde();
let location = self.types.len();
self.known_types.insert(item_id.clone(), location);
self.types.push(c.into());
self.serde_metadata.push(serde_metadata);
if num_children != 0 {
self.under_construction.insert(item_id, num_children);
MaybePartialLink::Partial(Link::ByIndex(location))
} else {
MaybePartialLink::Complete(Link::ByIndex(location))
}
}
}
Item::Atom(primitive) => MaybePartialLink::Complete(Link::Immediate(primitive)),
}
}
fn push_root_link(&mut self, link: Link) {
match link {
Link::ByIndex(i) => self.root_type_indices.push(i),
Link::Immediate(..) => {},
Link::Placeholder | Link::IndexedPlaceholder(_) => panic!("Attempted to register a placeholder link as a schema root - are you passing the right link?"),
}
}
}
/// The shape a [`UniversalWallet`] type scaffolds to.
pub enum Item<L: LinkingScheme> {
    /// A compound type, stored in the schema's type list; its child slots are filled afterwards.
    Container(Container<L>),
    /// A primitive, referenced inline through `Link::Immediate` rather than stored in the type list.
    Atom(Primitive),
}
/// Types that can describe themselves in a [`Schema`].
///
/// Implementors supply `scaffold` (the type's own shape) and `get_child_links` (links for
/// its children). The provided methods use these to register the type in a schema.
pub trait UniversalWallet: Sized + 'static {
    /// Returns links to this type's children. They are used, in order, to fill the child
    /// slots of the scaffolded container. The schema is passed mutably so that children can
    /// be registered as needed.
    fn get_child_links(schema: &mut Schema) -> Vec<Link>;

    /// Returns this type's shape: a container with unfilled child slots, or a primitive.
    fn scaffold() -> Item<IndexLinking>;

    /// Registers this container type (and, through `get_child_links`, its children) and
    /// returns a link to it. A type that is already registered is not re-registered.
    ///
    /// # Panics
    /// If `scaffold` returns a primitive.
    fn write_schema(schema: &mut Schema) -> Link {
        let shape = Self::scaffold();
        let id = ItemId::of::<Self>();
        let Item::Container(container) = shape else {
            panic!("Creating a schema for primitive root types is not supported. If this is necessary, wrap the primitive in a newtype struct. If you did not specify a primitive root type, this may be a bug in schema generation.");
        };
        match schema.get_partial_link_to(Item::Container(container), id.clone()) {
            MaybePartialLink::Complete(ready) => ready,
            pending => {
                // Newly added with open child slots: resolve each child and fill the slots in order.
                for child_link in Self::get_child_links(schema) {
                    schema.link_child_to_parent(id.clone(), child_link);
                }
                pending.into_inner()
            }
        }
    }

    /// Registers this type as a schema root and stores its transaction templates.
    ///
    /// # Panics
    /// If registration leaves any container with unfilled child slots.
    fn make_root_of(schema: &mut Schema) {
        let root_link = Self::write_schema(schema);
        assert!(
            schema.under_construction.is_empty(),
            "Schema generation left some types partially constructed. This is a bug in the schema. {schema:?}"
        );
        schema.push_root_link(root_link);
        let root_templates = Self::get_child_templates(schema);
        schema.templates.push(root_templates);
    }

    /// Returns the transaction templates stored when this type is registered as a root.
    /// Defaults to an empty set.
    fn get_child_templates(_schema: &mut Schema) -> TransactionTemplateSet {
        Default::default()
    }

    /// Returns a link to this type for use as a child: primitives are linked inline,
    /// containers are registered via `write_schema`.
    fn make_linkable(schema: &mut Schema) -> Link {
        if let Item::Atom(atom) = Self::scaffold() {
            Link::Immediate(atom)
        } else {
            Self::write_schema(schema)
        }
    }

    /// Optional identity overriding this type's `TypeId` for deduplication in a schema.
    fn id_override() -> Option<ItemId> {
        None
    }
}
/// Lets a type reuse another type's schema.
///
/// Every `T: OverrideSchema` implements [`UniversalWallet`] by delegating to `Output`.
pub trait OverrideSchema {
    /// The type whose schema description is used in place of the implementor's.
    type Output: UniversalWallet;
}
/// Describes every [`OverrideSchema`] type exactly as its `Output` type.
// NOTE(review): `get_child_templates` is not forwarded, so using an override type as a root
// stores an empty template set even if `Output` defines templates. Forwarding it would change
// the metadata and chain hashes of affected schemas — confirm whether that is intended.
impl<T: OverrideSchema + 'static> UniversalWallet for T {
    fn scaffold() -> Item<IndexLinking> {
        T::Output::scaffold()
    }
    fn get_child_links(schema: &mut Schema) -> Vec<Link> {
        T::Output::get_child_links(schema)
    }
    fn id_override() -> Option<ItemId> {
        T::Output::id_override()
    }
    fn make_linkable(schema: &mut Schema) -> Link {
        T::Output::make_linkable(schema)
    }
    fn write_schema(schema: &mut Schema) -> Link {
        T::Output::write_schema(schema)
    }
}