specta-typescript 0.0.11

Export your Rust types to TypeScript
Documentation
use std::collections::HashSet;

use specta::{
    Types,
    datatype::{DataType, Fields, GenericReference, Primitive, Reference},
};

use crate::Error;

/// Checks that `key_ty` can be used as the key type of a map exported for serde_json.
///
/// `generics` supplies the bindings for any generic parameters referenced by `key_ty`.
/// `path` is forwarded unchanged into any `Error` that is produced.
///
/// # Errors
///
/// Returns `Error::invalid_map_key` describing the first unsupported key type found.
pub(crate) fn validate_map_key(
    key_ty: &DataType,
    types: &Types,
    generics: &[(GenericReference, DataType)],
    path: String,
) -> Result<(), Error> {
    // Both sets start empty: nothing is on the recursion stack yet.
    let mut named_in_progress = HashSet::new();
    let mut generics_in_progress = HashSet::new();

    validate_map_key_inner(
        key_ty,
        types,
        generics,
        path,
        &mut named_in_progress,
        &mut generics_in_progress,
    )
}

fn validate_map_key_inner(
    key_ty: &DataType,
    types: &Types,
    generics: &[(GenericReference, DataType)],
    path: String,
    visiting_named_refs: &mut HashSet<Reference>,
    visiting_generic_refs: &mut HashSet<(GenericReference, DataType)>,
) -> Result<(), Error> {
    match key_ty {
        DataType::Primitive(primitive) if primitive_is_valid_key(primitive.clone()) => Ok(()),
        DataType::Primitive(other) => Err(Error::invalid_map_key(
            path,
            invalid_primitive_reason(other.clone()),
        )),
        DataType::Enum(enm) => {
            for (variant_name, variant) in enm.variants() {
                match variant.fields() {
                    Fields::Unit => {}
                    Fields::Unnamed(unnamed) => {
                        let non_skipped = unnamed
                            .fields()
                            .iter()
                            .filter_map(|field| field.ty())
                            .count();
                        if non_skipped != 1 {
                            return Err(Error::invalid_map_key(
                                &path,
                                format!(
                                    "enum key variant '{variant_name}' must serialize as a newtype value"
                                ),
                            ));
                        }
                    }
                    Fields::Named(_) => {
                        return Err(Error::invalid_map_key(
                            &path,
                            format!(
                                "enum key variant '{variant_name}' serializes as a struct variant, which serde_json rejects"
                            ),
                        ));
                    }
                }
            }

            Ok(())
        }
        DataType::Struct(strct) => {
            let Fields::Unnamed(unnamed) = strct.fields() else {
                return Err(Error::invalid_map_key(
                    path,
                    "struct keys must serialize as a newtype struct to be valid serde_json map keys",
                ));
            };

            let mut non_skipped = unnamed.fields().iter().filter_map(|field| field.ty());
            let Some(inner_ty) = non_skipped.next() else {
                return Err(Error::invalid_map_key(
                    path,
                    "newtype struct map keys must have exactly one serializable field",
                ));
            };

            if non_skipped.next().is_some() {
                return Err(Error::invalid_map_key(
                    path,
                    "newtype struct map keys must have exactly one serializable field",
                ));
            }

            validate_map_key_inner(
                inner_ty,
                types,
                generics,
                path,
                visiting_named_refs,
                visiting_generic_refs,
            )
        }
        DataType::Reference(Reference::Named(reference)) => {
            let reference_key = Reference::Named(reference.clone());
            if !visiting_named_refs.insert(reference_key.clone()) {
                return Err(Error::invalid_map_key(
                    path,
                    "recursive map key reference cycle detected",
                ));
            }

            let result = if let Some(ndt) = reference.get(types) {
                let merged_generics = merged_generics(generics, reference.generics());
                validate_map_key_inner(
                    ndt.ty(),
                    types,
                    &merged_generics,
                    path,
                    visiting_named_refs,
                    visiting_generic_refs,
                )
            } else {
                Err(Error::invalid_map_key(
                    path,
                    format!("unresolved named map key reference {reference:?}"),
                ))
            };

            visiting_named_refs.remove(&reference_key);
            result
        }
        DataType::Reference(Reference::Generic(generic)) => {
            let Some((_, ty)) = generics.iter().find(|(candidate, _)| candidate == generic) else {
                return Ok(());
            };

            if matches!(ty, DataType::Reference(Reference::Generic(inner)) if inner == generic) {
                return Ok(());
            }

            let resolved = resolve_generics_in_datatype(ty, generics);
            let generic_state = (generic.clone(), resolved.clone());
            if !visiting_generic_refs.insert(generic_state.clone()) {
                return Ok(());
            }

            let result = validate_map_key_inner(
                &resolved,
                types,
                generics,
                path,
                visiting_named_refs,
                visiting_generic_refs,
            );
            visiting_generic_refs.remove(&generic_state);

            result
        }
        DataType::Reference(Reference::Opaque(_)) => Err(Error::invalid_map_key(
            path,
            "opaque references cannot be validated as serde_json map keys",
        )),
        DataType::Tuple(_) => Err(Error::invalid_map_key(
            path,
            "tuple keys are not supported by serde_json map key serialization",
        )),
        DataType::List(_) | DataType::Map(_) | DataType::Nullable(_) => {
            Err(Error::invalid_map_key(
                path,
                "collection, map, and nullable keys are not supported by serde_json map key serialization",
            ))
        }
    }
}

/// Reports whether `primitive` is one of the scalar types this validator accepts
/// as a serde_json map key: `bool`, any integer, `f32`/`f64`, `str`, or `char`.
fn primitive_is_valid_key(primitive: Primitive) -> bool {
    let is_integer = matches!(
        primitive,
        Primitive::i8
            | Primitive::i16
            | Primitive::i32
            | Primitive::i64
            | Primitive::i128
            | Primitive::isize
            | Primitive::u8
            | Primitive::u16
            | Primitive::u32
            | Primitive::u64
            | Primitive::u128
            | Primitive::usize
    );
    let is_float = matches!(primitive, Primitive::f32 | Primitive::f64);
    let is_textual = matches!(primitive, Primitive::str | Primitive::char);
    let is_boolean = matches!(primitive, Primitive::bool);

    is_boolean || is_integer || is_float || is_textual
}

/// Picks the human-readable reason reported when `primitive` is rejected as a map key.
///
/// `f16`/`f128` get a dedicated message; every other primitive gets a generic one.
fn invalid_primitive_reason(primitive: Primitive) -> &'static str {
    if matches!(primitive, Primitive::f16 | Primitive::f128) {
        "f16 and f128 keys are not supported by serde_json map key serialization"
    } else {
        "unsupported primitive key type for serde_json map key serialization"
    }
}

/// Builds the generic bindings in scope inside a referenced type.
///
/// The `child` bindings come first. Each bound type is resolved against `parent`,
/// because it is written in the parent's scope. These are followed by every `parent`
/// binding whose parameter is not redefined by `child`, in the parent's original order.
fn merged_generics(
    parent: &[(GenericReference, DataType)],
    child: &[(GenericReference, DataType)],
) -> Vec<(GenericReference, DataType)> {
    let mut merged = Vec::with_capacity(child.len() + parent.len());

    for (generic, dt) in child {
        merged.push((generic.clone(), resolve_generics_in_datatype(dt, parent)));
    }

    for entry in parent {
        let shadowed_by_child = child
            .iter()
            .any(|(child_generic, _)| child_generic == &entry.0);
        if !shadowed_by_child {
            merged.push(entry.clone());
        }
    }

    merged
}

fn resolve_generics_in_datatype(
    dt: &DataType,
    generics: &[(GenericReference, DataType)],
) -> DataType {
    fn resolve(
        dt: &DataType,
        generics: &[(GenericReference, DataType)],
        visiting: &mut Vec<GenericReference>,
    ) -> DataType {
        match dt {
            DataType::Primitive(_) => dt.clone(),
            DataType::List(list) => {
                let mut out = list.clone();
                out.set_ty(resolve(list.ty(), generics, visiting));
                DataType::List(out)
            }
            DataType::Map(map) => {
                let mut out = map.clone();
                out.set_key_ty(resolve(map.key_ty(), generics, visiting));
                out.set_value_ty(resolve(map.value_ty(), generics, visiting));
                DataType::Map(out)
            }
            DataType::Nullable(inner) => {
                DataType::Nullable(Box::new(resolve(inner, generics, visiting)))
            }
            DataType::Struct(strct) => {
                let mut out = strct.clone();
                match out.fields_mut() {
                    Fields::Unit => {}
                    Fields::Unnamed(unnamed) => {
                        for field in unnamed.fields_mut() {
                            if let Some(ty) = field.ty_mut() {
                                *ty = resolve(ty, generics, visiting);
                            }
                        }
                    }
                    Fields::Named(named) => {
                        for (_, field) in named.fields_mut() {
                            if let Some(ty) = field.ty_mut() {
                                *ty = resolve(ty, generics, visiting);
                            }
                        }
                    }
                }
                DataType::Struct(out)
            }
            DataType::Enum(enm) => {
                let mut out = enm.clone();
                for (_, variant) in out.variants_mut() {
                    match variant.fields_mut() {
                        Fields::Unit => {}
                        Fields::Unnamed(unnamed) => {
                            for field in unnamed.fields_mut() {
                                if let Some(ty) = field.ty_mut() {
                                    *ty = resolve(ty, generics, visiting);
                                }
                            }
                        }
                        Fields::Named(named) => {
                            for (_, field) in named.fields_mut() {
                                if let Some(ty) = field.ty_mut() {
                                    *ty = resolve(ty, generics, visiting);
                                }
                            }
                        }
                    }
                }
                DataType::Enum(out)
            }
            DataType::Tuple(tuple) => {
                let mut out = tuple.clone();
                for element in out.elements_mut() {
                    *element = resolve(element, generics, visiting);
                }
                DataType::Tuple(out)
            }
            DataType::Reference(Reference::Generic(generic)) => {
                if visiting.iter().any(|seen| seen == generic) {
                    return dt.clone();
                }

                if let Some((_, resolved)) =
                    generics.iter().find(|(candidate, _)| candidate == generic)
                {
                    if matches!(resolved, DataType::Reference(Reference::Generic(inner)) if inner == generic)
                    {
                        dt.clone()
                    } else {
                        visiting.push(generic.clone());
                        let out = resolve(resolved, generics, visiting);
                        visiting.pop();
                        out
                    }
                } else {
                    dt.clone()
                }
            }
            DataType::Reference(_) => dt.clone(),
        }
    }

    resolve(dt, generics, &mut Vec::new())
}