use std::collections::HashSet;
use thiserror::Error;
use crate::{
internal::{skip_fields, skip_fields_named},
DataType, EnumRepr, EnumType, EnumVariants, GenericType, List, LiteralType, Map, PrimitiveType,
SpectaID, StructFields, TypeMap,
};
#[derive(Error, Debug, PartialEq)]
pub enum SerdeError {
#[error("A map key must be a 'string' or 'number' type")]
InvalidMapKey,
#[error("#[specta(tag = \"...\")] cannot be used with tuple variants")]
InvalidInternallyTaggedEnum,
#[error("the usage of #[specta(skip)] means the type can't be serialized")]
InvalidUsageOfSkip,
}
pub(crate) fn is_valid_ty(dt: &DataType, type_map: &TypeMap) -> Result<(), SerdeError> {
is_valid_ty_internal(dt, type_map, &mut Default::default())
}
/// Recursively checks that `dt`, and everything it contains, passes the serde-compatibility
/// rules.
///
/// The checks are:
/// - map keys are validated with `is_valid_map_key`;
/// - enums are validated with `validate_enum`;
/// - the non-skipped fields of structs and enum variants, tuple elements, `Result` halves,
///   list elements, nullable inner types and reference generics are all visited.
///
/// `checked_references` records every named type whose definition has already been visited.
/// Each named type is therefore inspected at most once per traversal, which is what makes
/// recursive types terminate.
///
/// # Errors
///
/// Returns the first [`SerdeError`] found.
///
/// # Panics
///
/// Panics if a [`DataType::Reference`] points at a type that is missing from `type_map`.
fn is_valid_ty_internal(
    dt: &DataType,
    type_map: &TypeMap,
    checked_references: &mut HashSet<SpectaID>,
) -> Result<(), SerdeError> {
    match dt {
        // Must stay on the `_internal` path so `checked_references` is shared. Restarting
        // through `is_valid_ty` would reset it and recurse forever on self-referential types
        // such as `struct Node { next: Option<Box<Node>> }`.
        DataType::Nullable(ty) => is_valid_ty_internal(ty, type_map, checked_references)?,
        // List elements may themselves contain maps (e.g. `Vec<HashMap<K, V>>`), so their
        // keys need checking too.
        DataType::List(list) => is_valid_ty_internal(&list.ty, type_map, checked_references)?,
        DataType::Map(ty) => {
            is_valid_map_key(ty.key_ty(), type_map)?;
            is_valid_ty_internal(ty.value_ty(), type_map, checked_references)?;
        }
        DataType::Struct(ty) => match ty.fields() {
            StructFields::Unit => {}
            StructFields::Unnamed(ty) => {
                for (_, ty) in skip_fields(ty.fields()) {
                    is_valid_ty_internal(ty, type_map, checked_references)?;
                }
            }
            StructFields::Named(ty) => {
                for (_, (_, ty)) in skip_fields_named(ty.fields()) {
                    is_valid_ty_internal(ty, type_map, checked_references)?;
                }
            }
        },
        DataType::Enum(ty) => {
            validate_enum(ty, type_map)?;
            for (_variant_name, variant) in ty.variants().iter() {
                match &variant.inner {
                    EnumVariants::Unit => {}
                    EnumVariants::Named(variant) => {
                        for (_, (_, ty)) in skip_fields_named(variant.fields()) {
                            is_valid_ty_internal(ty, type_map, checked_references)?;
                        }
                    }
                    EnumVariants::Unnamed(variant) => {
                        for (_, ty) in skip_fields(variant.fields()) {
                            is_valid_ty_internal(ty, type_map, checked_references)?;
                        }
                    }
                }
            }
        }
        DataType::Tuple(ty) => {
            for ty in ty.elements() {
                is_valid_ty_internal(ty, type_map, checked_references)?;
            }
        }
        DataType::Result(ty) => {
            is_valid_ty_internal(&ty.0, type_map, checked_references)?;
            is_valid_ty_internal(&ty.1, type_map, checked_references)?;
        }
        DataType::Reference(ty) => {
            for (_, generic) in ty.generics() {
                is_valid_ty_internal(generic, type_map, checked_references)?;
            }
            // `insert` returns `false` when this named type was already visited, so each
            // definition is walked once and recursive types terminate.
            if checked_references.insert(ty.sid) {
                // A missing entry means the type map is incomplete. This is treated as an
                // invariant violation (panic) rather than a user-facing `SerdeError`.
                #[allow(clippy::panic)]
                let named = type_map
                    .get(ty.sid)
                    .unwrap_or_else(|| panic!("Type '{}' was never populated.", ty.sid.type_name));
                is_valid_ty_internal(&named.inner, type_map, checked_references)?;
            }
        }
        _ => {}
    }
    Ok(())
}
/// Checks whether `key_ty` may be used as the key type of a map.
///
/// Accepted keys:
/// - `Any`;
/// - integer, float, `String` and `char` primitives, plus literals of those kinds;
/// - enums in which every variant is either a unit variant, or a tuple variant with at most
///   one field when the enum is untagged;
/// - references, whose definition is resolved (with the reference's generics substituted)
///   and checked recursively.
///
/// # Errors
///
/// Returns [`SerdeError::InvalidMapKey`] for any other key type.
///
/// # Panics
///
/// Panics if a referenced type is missing from `type_map`.
fn is_valid_map_key(key_ty: &DataType, type_map: &TypeMap) -> Result<(), SerdeError> {
    // Named types are validated against their definition after generic substitution.
    if let DataType::Reference(r) = key_ty {
        let named = type_map.get(r.sid).expect("Type was never populated");
        let resolved = resolve_generics(named.inner.clone(), &r.generics);
        return is_valid_map_key(&resolved, type_map);
    }

    let accepted = match key_ty {
        DataType::Any => true,
        DataType::Primitive(p) => matches!(
            p,
            PrimitiveType::i8
                | PrimitiveType::i16
                | PrimitiveType::i32
                | PrimitiveType::i64
                | PrimitiveType::i128
                | PrimitiveType::isize
                | PrimitiveType::u8
                | PrimitiveType::u16
                | PrimitiveType::u32
                | PrimitiveType::u64
                | PrimitiveType::u128
                | PrimitiveType::usize
                | PrimitiveType::f32
                | PrimitiveType::f64
                | PrimitiveType::String
                | PrimitiveType::char
        ),
        DataType::Literal(l) => matches!(
            l,
            LiteralType::i8(_)
                | LiteralType::i16(_)
                | LiteralType::i32(_)
                | LiteralType::u8(_)
                | LiteralType::u16(_)
                | LiteralType::u32(_)
                | LiteralType::f32(_)
                | LiteralType::f64(_)
                | LiteralType::String(_)
                | LiteralType::char(_)
        ),
        DataType::Enum(e) => e.variants.iter().all(|(_, variant)| match &variant.inner {
            EnumVariants::Unit => true,
            EnumVariants::Unnamed(item) => {
                item.fields.len() <= 1 && e.repr == EnumRepr::Untagged
            }
            EnumVariants::Named(_) => false,
        }),
        _ => false,
    };

    if accepted {
        Ok(())
    } else {
        Err(SerdeError::InvalidMapKey)
    }
}
/// Validates enum-level rules that do not depend on individual field types.
///
/// # Errors
///
/// - [`SerdeError::InvalidUsageOfSkip`] if the enum has variants but all of them are skipped.
///   An enum with no variants at all is accepted.
/// - Any error from `validate_internally_tag_enum` when the enum uses internal tagging.
fn validate_enum(e: &EnumType, type_map: &TypeMap) -> Result<(), SerdeError> {
    let variants = e.variants();
    let every_variant_skipped = !variants.is_empty() && variants.iter().all(|(_, v)| v.skip);
    if every_variant_skipped {
        return Err(SerdeError::InvalidUsageOfSkip);
    }

    match e.repr() {
        EnumRepr::Internal { .. } => validate_internally_tag_enum(e, type_map),
        _ => Ok(()),
    }
}
fn validate_internally_tag_enum(e: &EnumType, type_map: &TypeMap) -> Result<(), SerdeError> {
for (_variant_name, variant) in &e.variants {
match &variant.inner {
EnumVariants::Unit => {}
EnumVariants::Named(_) => {}
EnumVariants::Unnamed(item) => {
let mut fields = skip_fields(item.fields());
let Some(first_field) = fields.next() else {
continue;
};
if fields.next().is_some() {
return Err(SerdeError::InvalidInternallyTaggedEnum);
}
validate_internally_tag_enum_datatype(first_field.1, type_map)?;
}
}
}
Ok(())
}
fn validate_internally_tag_enum_datatype(
ty: &DataType,
type_map: &TypeMap,
) -> Result<(), SerdeError> {
match ty {
DataType::Any => return Err(SerdeError::InvalidInternallyTaggedEnum),
DataType::Map(_) => {}
DataType::Struct(_) => {}
DataType::Enum(ty) => match ty.repr {
EnumRepr::Untagged => validate_internally_tag_enum(ty, type_map)?,
EnumRepr::External => {}
EnumRepr::Internal { .. } => {}
EnumRepr::Adjacent { .. } => {}
},
DataType::Tuple(ty) if ty.elements.is_empty() => {}
DataType::Result(_) => {}
DataType::Reference(ty) => {
let ty = type_map.get(ty.sid).expect("Type was never populated");
validate_internally_tag_enum_datatype(&ty.inner, type_map)?;
}
_ => return Err(SerdeError::InvalidInternallyTaggedEnum),
}
Ok(())
}
/// Returns `dt` with every [`DataType::Generic`] replaced by its concrete type from `generics`.
///
/// Substitution descends into:
/// - list elements, nullable inner types, and map keys and values;
/// - struct and enum fields, tuple elements and `Result` halves;
/// - the generic arguments of references.
///
/// The definitions behind references are not expanded.
///
/// If a generic has no entry in `generics`, it is replaced with the value converted (via
/// `Into<DataType>`) from the message "Generic type `<name>` was referenced but not found".
/// NOTE(review): this is a placeholder, not an error. Callers are not told that resolution
/// failed.
fn resolve_generics(mut dt: DataType, generics: &[(GenericType, DataType)]) -> DataType {
    match dt {
        DataType::Primitive(_) | DataType::Literal(_) | DataType::Any | DataType::Unknown => dt,
        DataType::List(v) => DataType::List(List {
            ty: Box::new(resolve_generics(*v.ty, generics)),
            length: v.length,
            unique: v.unique,
        }),
        DataType::Nullable(v) => DataType::Nullable(Box::new(resolve_generics(*v, generics))),
        DataType::Map(v) => DataType::Map(Map {
            key_ty: Box::new(resolve_generics(*v.key_ty, generics)),
            value_ty: Box::new(resolve_generics(*v.value_ty, generics)),
        }),
        DataType::Struct(ref mut v) => match &mut v.fields {
            StructFields::Unit => dt,
            StructFields::Unnamed(f) => {
                for field in f.fields.iter_mut() {
                    field.ty = field.ty.take().map(|v| resolve_generics(v, generics));
                }
                dt
            }
            StructFields::Named(f) => {
                for (_, field) in f.fields.iter_mut() {
                    field.ty = field.ty.take().map(|v| resolve_generics(v, generics));
                }
                dt
            }
        },
        DataType::Enum(ref mut v) => {
            for (_, v) in v.variants.iter_mut() {
                match &mut v.inner {
                    EnumVariants::Unit => {}
                    EnumVariants::Named(f) => {
                        for (_, field) in f.fields.iter_mut() {
                            field.ty = field.ty.take().map(|v| resolve_generics(v, generics));
                        }
                    }
                    EnumVariants::Unnamed(f) => {
                        for field in f.fields.iter_mut() {
                            field.ty = field.ty.take().map(|v| resolve_generics(v, generics));
                        }
                    }
                }
            }
            dt
        }
        DataType::Tuple(ref mut v) => {
            for ty in v.elements.iter_mut() {
                // Move the element out (leaving a cheap unit placeholder) instead of deep-cloning
                // the whole subtree.
                let owned = std::mem::replace(ty, DataType::Any);
                *ty = resolve_generics(owned, generics);
            }
            dt
        }
        DataType::Result(result) => DataType::Result(Box::new({
            let (ok, err) = *result;
            (
                resolve_generics(ok, generics),
                resolve_generics(err, generics),
            )
        })),
        DataType::Reference(ref mut r) => {
            for (_, generic) in r.generics.iter_mut() {
                let owned = std::mem::replace(generic, DataType::Any);
                *generic = resolve_generics(owned, generics);
            }
            dt
        }
        DataType::Generic(g) => generics
            .iter()
            .find(|(name, _)| name == &g)
            .map(|(_, ty)| ty.clone())
            // TODO: surface unresolved generics as an error instead of a placeholder value.
            .unwrap_or_else(|| format!("Generic type `{g}` was referenced but not found").into()),
    }
}