use alloc::borrow::Cow;
use core::hash::{BuildHasher, Hash};
use std::collections::HashMap;
use ahash::HashMapExt;
use super::{FromRon, FromRonFields, expr_type_name};
use crate::{
ast::{AnonStructExpr, Expr, FieldsBody, StringExpr, StringKind, StructField},
error::{Error, ErrorKind, Result, Span},
};
/// Maximum number of fields accepted in a single struct body.
///
/// `AstMapAccess` records which fields have been consumed using one bit per field
/// index in a `u64` mask, so this limit must never exceed 64.
pub const MAX_FIELDS: usize = 64;

// Compile-time guard: raising `MAX_FIELDS` beyond the mask width would make the
// `1 << index` shifts on the consumed-field mask overflow.
const _: () = assert!(
    MAX_FIELDS <= u64::BITS as usize,
    "MAX_FIELDS must fit in the u64 consumed-field bitmask"
);
/// Name-based accessor over the fields of a parsed struct body.
///
/// Each accessor call that finds a field marks it as consumed, so leftover fields
/// can later be rejected (`deny_unknown_fields`), listed (`remaining_keys`) or
/// collected by a flattening implementation.
#[derive(Debug)]
pub struct AstMapAccess<'a> {
    /// The struct's fields, indexed by position.
    fields: &'a [StructField<'a>],
    /// Maps each field name to its position in `fields`.
    field_indices: ahash::HashMap<&'a str, usize>,
    /// Consumption bitmask: bit `i` is set once `fields[i]` has been taken by an
    /// accessor. The constructors reject more than `MAX_FIELDS` fields, so every
    /// index fits in this `u64`.
    consumed: u64,
    /// Span of the whole struct, used for errors not tied to one field (e.g. a
    /// missing required field).
    struct_span: Span,
    /// Optional struct name, attached to errors as their `outer` context.
    struct_name: Option<&'a str>,
}
impl<'a> AstMapAccess<'a> {
#[inline]
pub fn from_anon(s: &'a AnonStructExpr<'a>, struct_name: Option<&'a str>) -> Result<Self> {
if s.fields.len() > MAX_FIELDS {
return Err(Error::with_span(
ErrorKind::TooManyFields {
count: s.fields.len(),
limit: MAX_FIELDS,
},
s.span,
));
}
let mut field_indices = ahash::HashMap::with_capacity(s.fields.len());
for (i, field) in s.fields.iter().enumerate() {
let name = field.name.name.as_ref();
if field_indices.insert(name, i).is_some() {
return Err(Error::with_span(
ErrorKind::DuplicateField {
field: Cow::Owned(name.to_string()),
outer: struct_name.map(|s| Cow::Owned(s.to_string())),
},
field.name.span,
));
}
}
Ok(Self {
fields: &s.fields,
field_indices,
consumed: 0,
struct_span: s.span,
struct_name,
})
}
#[inline]
pub fn from_fields(
fields_body: &'a FieldsBody<'a>,
struct_name: Option<&'a str>,
struct_span: Span,
) -> Result<Self> {
if fields_body.fields.len() > MAX_FIELDS {
return Err(Error::with_span(
ErrorKind::TooManyFields {
count: fields_body.fields.len(),
limit: MAX_FIELDS,
},
struct_span,
));
}
let mut field_indices = ahash::HashMap::with_capacity(fields_body.fields.len());
for (i, field) in fields_body.fields.iter().enumerate() {
let name = field.name.name.as_ref();
if field_indices.insert(name, i).is_some() {
return Err(Error::with_span(
ErrorKind::DuplicateField {
field: Cow::Owned(name.to_string()),
outer: struct_name.map(|s| Cow::Owned(s.to_string())),
},
field.name.span,
));
}
}
Ok(Self {
fields: &fields_body.fields,
field_indices,
consumed: 0,
struct_span,
struct_name,
})
}
#[inline]
fn find_field(&self, name: &str) -> Option<(usize, &'a StructField<'a>)> {
self.field_indices
.get(name)
.map(|&idx| (idx, &self.fields[idx]))
}
#[inline]
pub fn required<T: FromRon>(&mut self, name: &'static str) -> Result<T> {
match self.find_field(name) {
Some((idx, field)) => {
self.consumed |= 1 << idx;
T::from_ast(&field.value)
}
None => Err(Error::with_span(
ErrorKind::MissingField {
field: Cow::Borrowed(name),
outer: self.struct_name.map(|s| Cow::Owned(s.to_string())),
},
self.struct_span,
)),
}
}
#[inline]
pub fn optional<T: FromRon>(&mut self, name: &'static str) -> Result<Option<T>> {
match self.find_field(name) {
Some((idx, field)) => {
self.consumed |= 1 << idx;
Ok(Some(T::from_ast(&field.value)?))
}
None => Ok(None),
}
}
#[inline]
pub fn with_default<T: FromRon + Default>(&mut self, name: &'static str) -> Result<T> {
match self.find_field(name) {
Some((idx, field)) => {
self.consumed |= 1 << idx;
T::from_ast(&field.value)
}
None => Ok(T::default()),
}
}
#[inline]
pub fn with_default_fn<T: FromRon, F: FnOnce() -> T>(
&mut self,
name: &'static str,
default_fn: F,
) -> Result<T> {
match self.find_field(name) {
Some((idx, field)) => {
self.consumed |= 1 << idx;
T::from_ast(&field.value)
}
None => Ok(default_fn()),
}
}
pub fn deny_unknown_fields(&self, expected: &'static [&'static str]) -> Result<()> {
for (i, field) in self.fields.iter().enumerate() {
if self.consumed & (1 << i) == 0 {
let name = field.name.name.as_ref();
return Err(Error::with_span(
ErrorKind::UnknownField {
field: Cow::Owned(name.to_string()),
expected,
outer: self.struct_name.map(|s| Cow::Owned(s.to_string())),
},
field.name.span,
));
}
}
Ok(())
}
pub fn remaining_keys(&self) -> impl Iterator<Item = &str> {
self.fields
.iter()
.enumerate()
.filter(|(i, _)| self.consumed & (1 << *i) == 0)
.map(|(_, f)| f.name.name.as_ref())
}
#[inline]
pub fn required_explicit<T: FromRon>(&mut self, name: &'static str) -> Result<Option<T>> {
match self.find_field(name) {
Some((idx, field)) => {
self.consumed |= 1 << idx;
match &field.value {
Expr::Option(opt) => match &opt.value {
None => Ok(None),
Some(inner) => Ok(Some(T::from_ast(&inner.expr)?)),
},
other => Err(Error::with_span(
ErrorKind::TypeMismatch {
expected: "Some(...) or None".to_string(),
found: expr_type_name(other).to_string(),
},
*other.span(),
)),
}
}
None => Err(Error::with_span(
ErrorKind::MissingField {
field: Cow::Borrowed(name),
outer: self.struct_name.map(|s| Cow::Owned(s.to_string())),
},
self.struct_span,
)),
}
}
#[inline]
pub fn with_default_explicit<T: FromRon>(&mut self, name: &'static str) -> Result<Option<T>> {
match self.find_field(name) {
Some((idx, field)) => {
self.consumed |= 1 << idx;
match &field.value {
Expr::Option(opt) => match &opt.value {
None => Ok(None),
Some(inner) => Ok(Some(T::from_ast(&inner.expr)?)),
},
other => Err(Error::with_span(
ErrorKind::TypeMismatch {
expected: "Some(...) or None".to_string(),
found: expr_type_name(other).to_string(),
},
*other.span(),
)),
}
}
None => Ok(None),
}
}
#[inline]
pub fn flatten<T: FromRonFields>(&mut self) -> Result<T> {
T::from_fields(self)
}
}
impl<T: FromRonFields> FromRonFields for Option<T> {
    /// Flattens an optional group of fields.
    ///
    /// Yields `None` when `T` touched no fields. That happens when
    /// `T::from_fields` succeeded without consuming anything, or failed with
    /// `MissingField` before consuming anything. Any other success is wrapped
    /// in `Some`, and any other error is returned unchanged.
    fn from_fields(access: &mut AstMapAccess<'_>) -> Result<Self> {
        let mask_before = access.consumed;
        let attempt = T::from_fields(access);
        let untouched = access.consumed == mask_before;
        match attempt {
            Ok(_) if untouched => Ok(None),
            Ok(value) => Ok(Some(value)),
            Err(err) if untouched && matches!(err.kind(), ErrorKind::MissingField { .. }) => {
                Ok(None)
            }
            Err(err) => Err(err),
        }
    }
}
impl<K, V, S> FromRonFields for HashMap<K, V, S>
where
    K: FromRon + Eq + Hash,
    V: FromRon,
    S: BuildHasher + Default,
{
    /// Gathers every field not yet consumed by `access` into a map.
    ///
    /// Each key comes from a synthesized regular string expression holding the
    /// field name (spanning that name), converted through `K::from_ast`. Each
    /// value comes from `V::from_ast`. Fields are marked consumed one at a time,
    /// just before their key and value are converted.
    fn from_fields(access: &mut AstMapAccess<'_>) -> Result<Self> {
        let fields = access.fields;
        let is_free = |mask: u64, idx: usize| mask & (1u64 << idx) == 0;
        let capacity = (0..fields.len())
            .filter(|&idx| is_free(access.consumed, idx))
            .count();
        let mut collected = HashMap::with_capacity_and_hasher(capacity, S::default());
        for (idx, field) in fields.iter().enumerate() {
            if is_free(access.consumed, idx) {
                access.consumed |= 1u64 << idx;
                let name_expr = Expr::String(StringExpr {
                    span: field.name.span,
                    raw: Cow::Owned(format!("\"{}\"", field.name.name)),
                    value: field.name.name.to_string(),
                    kind: StringKind::Regular,
                });
                let key = K::from_ast(&name_expr)?;
                let item = V::from_ast(&field.value)?;
                collected.insert(key, item);
            }
        }
        Ok(collected)
    }
}
#[cfg(test)]
mod tests {
    use core::fmt::Write;
    use super::*;
    use crate::ast::parse_document;

    /// Builds an anonymous-struct RON source with `count` integer fields named
    /// `{prefix}0`, `{prefix}1`, ..., each holding its own index as its value.
    fn numbered_fields(prefix: &str, count: usize) -> String {
        let mut body = String::new();
        for i in 0..count {
            if i > 0 {
                body.push_str(", ");
            }
            let _ = write!(body, "{prefix}{i}: {i}");
        }
        format!("({body})")
    }

    #[test]
    fn test_basic_field_access() {
        let ron = "(a: 1, b: 2, c: 3)";
        let doc = parse_document(ron).unwrap();
        if let Some(Expr::AnonStruct(s)) = &doc.value {
            let mut access = AstMapAccess::from_anon(s, Some("Test")).unwrap();
            assert_eq!(access.required::<i32>("a").unwrap(), 1);
            assert_eq!(access.required::<i32>("b").unwrap(), 2);
            assert_eq!(access.required::<i32>("c").unwrap(), 3);
            assert!(access.deny_unknown_fields(&["a", "b", "c"]).is_ok());
        } else {
            panic!("Expected anonymous struct");
        }
    }

    #[test]
    fn test_consumed_bitmask_tracking() {
        let ron = "(a: 1, b: 2, c: 3)";
        let doc = parse_document(ron).unwrap();
        if let Some(Expr::AnonStruct(s)) = &doc.value {
            let mut access = AstMapAccess::from_anon(s, Some("Test")).unwrap();
            assert_eq!(access.required::<i32>("a").unwrap(), 1);
            assert_eq!(access.required::<i32>("c").unwrap(), 3);
            let err = access.deny_unknown_fields(&["a", "c"]).unwrap_err();
            assert!(matches!(err.kind(), ErrorKind::UnknownField { .. }));
        } else {
            panic!("Expected anonymous struct");
        }
    }

    #[test]
    fn test_remaining_keys() {
        let ron = "(a: 1, b: 2, c: 3, d: 4)";
        let doc = parse_document(ron).unwrap();
        if let Some(Expr::AnonStruct(s)) = &doc.value {
            let mut access = AstMapAccess::from_anon(s, None).unwrap();
            let _ = access.required::<i32>("a").unwrap();
            let _ = access.required::<i32>("c").unwrap();
            let remaining: Vec<_> = access.remaining_keys().collect();
            assert_eq!(remaining.len(), 2);
            assert!(remaining.contains(&"b"));
            assert!(remaining.contains(&"d"));
        } else {
            panic!("Expected anonymous struct");
        }
    }

    #[test]
    fn test_duplicate_field_detection() {
        let ron = "(a: 1, b: 2, a: 3)";
        let doc = parse_document(ron).unwrap();
        if let Some(Expr::AnonStruct(s)) = &doc.value {
            let err = AstMapAccess::from_anon(s, Some("Test")).unwrap_err();
            match err.kind() {
                ErrorKind::DuplicateField { field, outer } => {
                    assert_eq!(field.as_ref(), "a");
                    assert_eq!(outer.as_deref(), Some("Test"));
                }
                _ => panic!("Expected DuplicateField error"),
            }
        } else {
            panic!("Expected anonymous struct");
        }
    }

    #[test]
    fn test_too_many_fields_error() {
        let ron = numbered_fields("field_", MAX_FIELDS + 1);
        let doc = parse_document(&ron).unwrap();
        if let Some(Expr::AnonStruct(s)) = &doc.value {
            let err = AstMapAccess::from_anon(s, Some("TooLarge")).unwrap_err();
            match err.kind() {
                ErrorKind::TooManyFields { count, limit } => {
                    assert_eq!(*count, MAX_FIELDS + 1);
                    assert_eq!(*limit, MAX_FIELDS);
                }
                _ => panic!("Expected TooManyFields error, got {:?}", err.kind()),
            }
        } else {
            panic!("Expected anonymous struct");
        }
    }

    #[test]
    fn test_exactly_64_fields_ok() {
        let ron = numbered_fields("f", 64);
        let doc = parse_document(&ron).unwrap();
        if let Some(Expr::AnonStruct(s)) = &doc.value {
            let access = AstMapAccess::from_anon(s, None).unwrap();
            assert_eq!(access.fields.len(), 64);
        } else {
            panic!("Expected anonymous struct");
        }
    }

    #[test]
    fn test_missing_field_error() {
        let ron = "(a: 1)";
        let doc = parse_document(ron).unwrap();
        if let Some(Expr::AnonStruct(s)) = &doc.value {
            let mut access = AstMapAccess::from_anon(s, Some("Test")).unwrap();
            let err = access.required::<i32>("missing").unwrap_err();
            match err.kind() {
                ErrorKind::MissingField { field, outer } => {
                    assert_eq!(field.as_ref(), "missing");
                    assert_eq!(outer.as_deref(), Some("Test"));
                }
                _ => panic!("Expected MissingField error"),
            }
        } else {
            panic!("Expected anonymous struct");
        }
    }

    #[test]
    fn test_optional_field() {
        let ron = "(a: 1)";
        let doc = parse_document(ron).unwrap();
        if let Some(Expr::AnonStruct(s)) = &doc.value {
            let mut access = AstMapAccess::from_anon(s, None).unwrap();
            assert_eq!(access.optional::<i32>("a").unwrap(), Some(1));
            assert_eq!(access.optional::<i32>("missing").unwrap(), None);
        } else {
            panic!("Expected anonymous struct");
        }
    }

    #[test]
    fn test_with_default() {
        let ron = "(a: 1)";
        let doc = parse_document(ron).unwrap();
        if let Some(Expr::AnonStruct(s)) = &doc.value {
            let mut access = AstMapAccess::from_anon(s, None).unwrap();
            assert_eq!(access.with_default::<i32>("a").unwrap(), 1);
            assert_eq!(access.with_default::<i32>("missing").unwrap(), 0);
        } else {
            panic!("Expected anonymous struct");
        }
    }

    #[test]
    fn test_with_default_fn() {
        let ron = "(a: 1)";
        let doc = parse_document(ron).unwrap();
        if let Some(Expr::AnonStruct(s)) = &doc.value {
            let mut access = AstMapAccess::from_anon(s, None).unwrap();
            assert_eq!(access.with_default_fn("a", || 7_i32).unwrap(), 1);
            assert_eq!(access.with_default_fn("missing", || 7_i32).unwrap(), 7);
            // The present field was consumed, so nothing is left over.
            assert!(access.deny_unknown_fields(&["a"]).is_ok());
        } else {
            panic!("Expected anonymous struct");
        }
    }

    #[test]
    fn test_high_index_field_consumed() {
        let ron = numbered_fields("f", 64);
        let doc = parse_document(&ron).unwrap();
        if let Some(Expr::AnonStruct(s)) = &doc.value {
            let mut access = AstMapAccess::from_anon(s, None).unwrap();
            assert_eq!(access.required::<i32>("f63").unwrap(), 63);
            assert_eq!(access.consumed & (1u64 << 63), 1u64 << 63);
            assert_eq!(access.consumed & (1u64 << 62), 0);
        } else {
            panic!("Expected anonymous struct");
        }
    }
}