use crate::ast::{
CaseLabel, ConstExpr, ConstrTypeDecl, Definition, LiteralKind, Specification, SwitchTypeSpec,
TypeDecl, UnionDcl, UnionDef,
};
use crate::errors::Span;
/// A problem found while validating an IDL `union` definition.
///
/// Each variant records the source [`Span`] the problem was reported at.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum UnionValidationError {
    /// The switch type is not an accepted discriminator; `kind` is its `Debug` rendering.
    InvalidDiscriminator { kind: String, span: Span },
    /// More than one `default` label occurs in the union.
    DuplicateDefault { span: Span },
    /// The same value label (compared by its rendered text) occurs more than once.
    DuplicateCaseLabel { label: String, span: Span },
    /// A value label is not compatible with the discriminator kind.
    LabelTypeMismatch {
        discriminator: String,
        label: String,
        span: Span,
    },
    /// A case carries no labels at all.
    MissingCaseLabel { span: Span },
    /// Two cases declare an element with the same declarator name.
    DuplicateElementDeclarator { name: String, span: Span },
    /// A `default` branch can never be taken because explicit labels already cover every
    /// discriminator value (currently only detected for `boolean` discriminators).
    DefaultLabelRedundant { discriminator: String, span: Span },
}
/// Validates every union definition found at the top level of `spec` or nested inside
/// modules, returning all problems in the order they were encountered.
///
/// An empty vector means every visited union passed validation.
#[must_use]
pub fn validate_unions(spec: &Specification) -> Vec<UnionValidationError> {
    let mut found: Vec<UnionValidationError> = Vec::new();
    spec.definitions
        .iter()
        .for_each(|definition| walk_def(definition, &mut found));
    found
}
/// Visits one definition: union definitions are validated, modules are descended into
/// recursively, and every other kind of definition is ignored.
fn walk_def(d: &Definition, errs: &mut Vec<UnionValidationError>) {
    if let Definition::Type(TypeDecl::Constr(ConstrTypeDecl::Union(UnionDcl::Def(union_def)))) =
        d
    {
        validate_union(union_def, errs);
    } else if let Definition::Module(module) = d {
        module
            .definitions
            .iter()
            .for_each(|inner| walk_def(inner, errs));
    }
}
pub fn validate_union(u: &UnionDef, errs: &mut Vec<UnionValidationError>) {
let disc_kind = check_discriminator(&u.switch_type, errs);
let mut default_seen = false;
let mut default_span: Option<Span> = None;
let mut seen_labels: Vec<String> = Vec::new();
let mut seen_declarators: Vec<String> = Vec::new();
let mut bool_value_labels: Vec<bool> = Vec::new();
for case in &u.cases {
let decl_name = case.element.declarator.name().text.clone();
if seen_declarators.iter().any(|n| n == &decl_name) {
errs.push(UnionValidationError::DuplicateElementDeclarator {
name: decl_name.clone(),
span: case.element.span,
});
} else {
seen_declarators.push(decl_name);
}
if case.labels.is_empty() {
errs.push(UnionValidationError::MissingCaseLabel { span: case.span });
continue;
}
for label in &case.labels {
match label {
CaseLabel::Default => {
if default_seen {
errs.push(UnionValidationError::DuplicateDefault { span: case.span });
}
default_seen = true;
default_span = Some(case.span);
}
CaseLabel::Value(expr) => {
let raw = const_expr_str(expr);
if seen_labels.iter().any(|l| l == &raw) {
errs.push(UnionValidationError::DuplicateCaseLabel {
label: raw.clone(),
span: case.span,
});
} else {
seen_labels.push(raw.clone());
}
if let Some(ref disc) = disc_kind {
if !label_matches_disc(expr, disc) {
errs.push(UnionValidationError::LabelTypeMismatch {
discriminator: disc.clone(),
label: raw.clone(),
span: case.span,
});
}
if disc == "boolean" {
if let ConstExpr::Literal(l) = expr {
if matches!(l.kind, LiteralKind::Boolean) {
let v = l.raw == "TRUE" || l.raw == "true";
if !bool_value_labels.contains(&v) {
bool_value_labels.push(v);
}
}
} else if let ConstExpr::Scoped(s) = expr {
if let Some(p) = s.parts.last() {
let v = matches!(p.text.as_str(), "TRUE" | "true");
if !bool_value_labels.contains(&v) {
bool_value_labels.push(v);
}
}
}
}
}
}
}
}
}
if default_seen {
if let Some(ref disc) = disc_kind {
if disc == "boolean" && bool_value_labels.len() == 2 {
errs.push(UnionValidationError::DefaultLabelRedundant {
discriminator: disc.clone(),
span: default_span.unwrap_or(u.span),
});
}
}
}
}
/// Classifies a union's switch type into a discriminator kind name.
///
/// Returns `"integer"`, `"char"`, `"boolean"`, `"octet"`, or `"enum"` (for scoped names).
/// Any other switch type is reported as `InvalidDiscriminator` and yields `None`.
fn check_discriminator(s: &SwitchTypeSpec, errs: &mut Vec<UnionValidationError>) -> Option<String> {
    let kind = match s {
        SwitchTypeSpec::Integer(_) => "integer",
        SwitchTypeSpec::Char => "char",
        SwitchTypeSpec::Boolean => "boolean",
        SwitchTypeSpec::Octet => "octet",
        SwitchTypeSpec::Scoped(_) => "enum",
        // NOTE(review): presumably unreachable today; kept so any future `SwitchTypeSpec`
        // variant is reported instead of silently accepted.
        #[allow(unreachable_patterns)]
        unsupported => {
            // NOTE(review): no span is available on `SwitchTypeSpec` here, hence SYNTHETIC.
            errs.push(UnionValidationError::InvalidDiscriminator {
                kind: format!("{unsupported:?}"),
                span: Span::SYNTHETIC,
            });
            return None;
        }
    };
    Some(kind.to_string())
}
/// Renders a constant expression as a canonical-ish string used to compare case labels.
///
/// Literals keep their raw source text, scoped names are joined with `::`, and operator
/// expressions are fully parenthesised with the operator in its `Debug` form. No constant
/// folding happens, so equal values with different spellings render differently.
fn const_expr_str(e: &ConstExpr) -> String {
    match e {
        ConstExpr::Literal(lit) => lit.raw.clone(),
        ConstExpr::Scoped(name) => {
            let segments: Vec<&str> = name.parts.iter().map(|seg| seg.text.as_str()).collect();
            segments.join("::")
        }
        ConstExpr::Unary { op, operand, .. } => {
            let inner = const_expr_str(operand);
            format!("({op:?} {inner})")
        }
        ConstExpr::Binary { op, lhs, rhs, .. } => {
            let left = const_expr_str(lhs);
            let right = const_expr_str(rhs);
            format!("({left} {op:?} {right})")
        }
    }
}
/// Reports whether a case-label expression is compatible with discriminator kind `disc`
/// (one of the names produced by `check_discriminator`).
fn label_matches_disc(expr: &ConstExpr, disc: &str) -> bool {
    match expr {
        // This pass performs no name resolution, so scoped names (enumerators, named
        // constants, `TRUE`/`FALSE` if parsed as names) are accepted for every kind.
        ConstExpr::Scoped(_) => true,
        ConstExpr::Literal(lit) => match disc {
            "integer" | "octet" => matches!(lit.kind, LiteralKind::Integer),
            "char" => matches!(lit.kind, LiteralKind::Char),
            "boolean" => matches!(lit.kind, LiteralKind::Boolean),
            // e.g. an `enum` discriminator never accepts a literal label.
            _ => false,
        },
        // Operator expressions are only allowed for integral discriminators.
        ConstExpr::Unary { .. } | ConstExpr::Binary { .. } => {
            matches!(disc, "integer" | "octet")
        }
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
    // End-to-end checks: each test parses a small IDL snippet and runs `validate_unions`.
    use super::*;
    use crate::config::ParserConfig;
    use crate::parser::parse;

    /// Parses `src` with the default configuration, panicking if it is not valid IDL.
    fn parse_to_ast(src: &str) -> Specification {
        parse(src, &ParserConfig::default()).expect("parse ok")
    }

    #[test]
    fn long_discriminator_with_int_labels_ok() {
        let ast = parse_to_ast(
            "union U switch (long) { case 1: long a; case 2: long b; default: long c; };",
        );
        let errs = validate_unions(&ast);
        assert!(errs.is_empty(), "got {errs:?}");
    }

    #[test]
    fn duplicate_default_branch_errors() {
        let ast = parse_to_ast("union U switch (long) { default: long a; default: long b; };");
        let errs = validate_unions(&ast);
        assert!(
            errs.iter()
                .any(|e| matches!(e, UnionValidationError::DuplicateDefault { .. }))
        );
    }

    #[test]
    fn duplicate_case_label_errors() {
        let ast = parse_to_ast("union U switch (long) { case 1: long a; case 1: long b; };");
        let errs = validate_unions(&ast);
        assert!(
            errs.iter()
                .any(|e| matches!(e, UnionValidationError::DuplicateCaseLabel { .. }))
        );
    }

    #[test]
    fn boolean_discriminator_with_int_label_is_mismatch() {
        let ast = parse_to_ast("union U switch (boolean) { case 1: long a; default: long b; };");
        let errs = validate_unions(&ast);
        assert!(
            errs.iter()
                .any(|e| matches!(e, UnionValidationError::LabelTypeMismatch { .. }))
        );
    }

    #[test]
    fn boolean_discriminator_with_bool_labels_ok() {
        let ast =
            parse_to_ast("union U switch (boolean) { case TRUE: long a; case FALSE: long b; };");
        let errs = validate_unions(&ast);
        assert!(errs.is_empty(), "got {errs:?}");
    }

    #[test]
    fn char_discriminator_with_char_labels_ok() {
        let ast = parse_to_ast("union U switch (char) { case 'a': long x; case 'b': long y; };");
        let errs = validate_unions(&ast);
        assert!(errs.is_empty(), "got {errs:?}");
    }

    #[test]
    fn union_with_duplicate_element_declarator_errors() {
        let ast = parse_to_ast(
            "union U switch (long) { case 1: long a; case 2: long a; default: long b; };",
        );
        let errs = validate_unions(&ast);
        assert!(
            errs.iter()
                .any(|e| matches!(e, UnionValidationError::DuplicateElementDeclarator { .. })),
            "got {errs:?}"
        );
    }

    #[test]
    fn union_default_redundant_for_full_boolean_coverage_errors() {
        let ast = parse_to_ast(
            "union U switch (boolean) { case TRUE: long a; case FALSE: long b; default: long c; };",
        );
        let errs = validate_unions(&ast);
        assert!(
            errs.iter()
                .any(|e| matches!(e, UnionValidationError::DefaultLabelRedundant { .. })),
            "got {errs:?}"
        );
    }

    #[test]
    fn union_default_required_for_partial_int_coverage_ok() {
        let ast = parse_to_ast(
            "union U switch (long) { case 1: long a; case 2: long b; default: long c; };",
        );
        let errs = validate_unions(&ast);
        assert!(
            !errs
                .iter()
                .any(|e| matches!(e, UnionValidationError::DefaultLabelRedundant { .. })),
            "got {errs:?}"
        );
    }

    /// A union that covers only part of its discriminator's range and has no `default`
    /// branch is valid IDL and must not produce any error.
    #[test]
    fn union_without_default_and_partial_octet_coverage_ok() {
        let ast = parse_to_ast(
            "union U switch (octet) { case 1: long a; case 2: long b; case 3: long c; };",
        );
        let errs = validate_unions(&ast);
        assert!(
            errs.is_empty(),
            "partial octet coverage without a default branch must be accepted; errs={errs:?}"
        );
    }
}