use crate::{
error::{Error, ErrorKind, Kind, Location},
EncodeAsFields, EncodeAsType, Field, FieldIter,
};
use alloc::collections::BTreeMap;
use alloc::{string::ToString, vec::Vec};
use scale_info::{PortableRegistry, TypeDef};
/// A sequence of optionally-named values that can be encoded against a type looked up
/// in a [`PortableRegistry`].
///
/// The struct itself places no bounds on `Vals`. The [`EncodeAsType`] and [`EncodeAsFields`]
/// impls for it require `Vals` to be a cloneable [`ExactSizeIterator`] yielding
/// `(Option<&str>, &dyn EncodeAsType)` pairs. Each pair is an optional field name and the
/// value to encode. The names, when present, are used to line values up with named target
/// fields.
pub struct Composite<Vals>(pub Vals);
impl<'a, Vals> EncodeAsType for Composite<Vals>
where
Vals: ExactSizeIterator<Item = (Option<&'a str>, &'a dyn EncodeAsType)> + Clone,
{
fn encode_as_type_to(
&self,
type_id: u32,
types: &PortableRegistry,
out: &mut Vec<u8>,
) -> Result<(), Error> {
let mut vals_iter = self.0.clone();
let vals_iter_len = vals_iter.len();
let type_id = skip_through_single_unnamed_fields(type_id, types);
let ty = types
.resolve(type_id)
.ok_or_else(|| Error::new(ErrorKind::TypeNotFound(type_id)))?;
match &ty.type_def {
TypeDef::Tuple(tuple) => {
if vals_iter_len == 1 {
return vals_iter
.next()
.unwrap()
.1
.encode_as_type_to(type_id, types, out);
}
let mut fields = tuple.fields.iter().map(|f| Field::unnamed(f.id));
self.encode_as_fields_to(&mut fields, types, out)
}
TypeDef::Composite(composite) => {
let is_named_vals = vals_iter.clone().any(|(name, _)| name.is_some());
if !is_named_vals && vals_iter_len == 1 {
return vals_iter
.next()
.unwrap()
.1
.encode_as_type_to(type_id, types, out);
}
let mut fields = composite
.fields
.iter()
.map(|f| Field::new(f.ty.id, f.name.as_deref()));
self.encode_as_fields_to(&mut fields, types, out)
}
_ => {
if vals_iter_len == 1 {
return vals_iter
.next()
.unwrap()
.1
.encode_as_type_to(type_id, types, out);
}
Err(Error::new(ErrorKind::WrongShape {
actual: Kind::Tuple,
expected: type_id,
}))
}
}
}
}
impl<'a, Vals> EncodeAsFields for Composite<Vals>
where
Vals: ExactSizeIterator<Item = (Option<&'a str>, &'a dyn EncodeAsType)> + Clone,
{
fn encode_as_fields_to(
&self,
fields: &mut dyn FieldIter<'_>,
types: &PortableRegistry,
out: &mut Vec<u8>,
) -> Result<(), Error> {
let vals_iter = self.0.clone();
let fields = smallvec::SmallVec::<[_; 16]>::from_iter(fields);
let is_named = {
let is_target_named = fields.iter().any(|f| f.name().is_some());
let is_source_named = vals_iter.clone().any(|(name, _)| name.is_some());
is_target_named && is_source_named
};
if is_named {
let source_fields_by_name: BTreeMap<&str, &dyn EncodeAsType> = vals_iter
.map(|(name, val)| (name.unwrap_or(""), val))
.collect();
for field in fields {
let name = field.name().unwrap_or("");
let Some(value) = source_fields_by_name.get(name) else {
return Err(Error::new(ErrorKind::CannotFindField { name: name.to_string() }))
};
value
.encode_as_type_to(field.id(), types, out)
.map_err(|e| e.at_field(name.to_string()))?;
}
Ok(())
} else {
let fields_len = fields.len();
if fields_len != vals_iter.len() {
return Err(Error::new(ErrorKind::WrongLength {
actual_len: vals_iter.len(),
expected_len: fields_len,
}));
}
for (idx, (field, (name, val))) in fields.iter().zip(vals_iter).enumerate() {
val.encode_as_type_to(field.id(), types, out).map_err(|e| {
let loc = if let Some(name) = name {
Location::field(name.to_string())
} else {
Location::idx(idx)
};
e.at(loc)
})?;
}
Ok(())
}
}
}
/// Follows a chain of "wrapper" types and returns the id of the first type in the chain
/// that is not a wrapper.
///
/// A wrapper is a single-element tuple, or a composite with exactly one *unnamed* field.
/// Composites whose single field is named are not skipped, which leaves room for named
/// input values to be matched against that field name. If `type_id`, or any id reached
/// along the chain, cannot be resolved in `types`, that id is returned unchanged.
///
/// The walk is iterative and remembers the wrapper ids it has stepped through. Plain
/// recursion would never terminate on a registry in which a wrapper contains itself,
/// directly or via other wrappers; such a registry can come from externally supplied
/// metadata, and the unbounded recursion would overflow the stack. When a cycle is
/// detected, the last wrapper id reached before the repeat is returned.
fn skip_through_single_unnamed_fields(type_id: u32, types: &PortableRegistry) -> u32 {
    let mut current = type_id;
    // `Vec::new` does not allocate, so the common non-wrapper case stays allocation-free.
    let mut visited: Vec<u32> = Vec::new();
    loop {
        let Some(ty) = types.resolve(current) else {
            return current;
        };
        let inner = match &ty.type_def {
            TypeDef::Tuple(tuple) if tuple.fields.len() == 1 => tuple.fields[0].id,
            TypeDef::Composite(composite)
                if composite.fields.len() == 1 && composite.fields[0].name.is_none() =>
            {
                composite.fields[0].ty.id
            }
            _ => return current,
        };
        visited.push(current);
        // Wrapper chains are short in practice, so a linear scan is sufficient here.
        if visited.contains(&inner) {
            return current;
        }
        current = inner;
    }
}