use eure_document::parse::union::has_explicit_variant_tag;
use eure_document::parse::{DocumentParser, ParseContext};
use crate::{SchemaNodeId, UnionSchema};
use super::SchemaValidator;
use super::context::ValidationContext;
use super::error::{ValidationError, ValidatorError, select_best_variant_match};
/// Validates a document node against a [`UnionSchema`].
///
/// All validation failures are recorded into `ctx` rather than returned.
pub struct UnionValidator<'ctx, 'doc, 'schema> {
    /// Context that receives recorded errors and per-variant candidate errors.
    pub ctx: &'ctx ValidationContext<'doc>,
    /// The union schema whose variants are registered with the union parser.
    pub schema: &'schema UnionSchema,
    /// Id of the union's own schema node; attached to union-level errors.
    pub schema_node_id: SchemaNodeId,
}
/// Per-variant flags handed from the union validator to `validate_variant`.
///
/// Derives `Debug`/`PartialEq` so the options can be inspected in diagnostics
/// and compared in tests; the type stays `Copy` so it can be moved into each
/// per-variant closure cheaply.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct VariantValidationOptions {
    /// When a variant fails, merge its errors straight into the parent context
    /// instead of storing them as candidate variant errors.
    propagate_errors: bool,
    /// The variant is listed in the union's `deny_untagged` set.
    requires_explicit_tag: bool,
    /// The node being validated carries an explicit variant tag.
    has_explicit_tag: bool,
}
impl<'a, 'doc, 's> DocumentParser<'doc> for UnionValidator<'a, 'doc, 's> {
    type Output = ();
    type Error = ValidatorError;

    /// Validates the node at `parse_ctx` against the union schema.
    ///
    /// Every variant of the schema is registered with the document's union
    /// parser. Variants in `unambiguous` are registered with
    /// `variant_unambiguous`; all others use `variant`. Each one is checked by
    /// `validate_variant`.
    ///
    /// Failures are recorded into `self.ctx` rather than returned, so this
    /// always returns `Ok(())`:
    /// - union/tag parse problems become `ValidationError::ParseError`;
    /// - a non-parse failure to set up the union becomes `InvalidVariantTag`;
    /// - when no variant matches, `NoVariantMatched` is recorded together with
    ///   the best candidate picked from the collected variant errors.
    ///
    /// Candidate variant errors accumulated in `self.ctx` are cleared or
    /// drained on every terminal path. This stops them from leaking into a
    /// later union validated with the same context.
    fn parse(&mut self, parse_ctx: &ParseContext<'doc>) -> Result<(), ValidatorError> {
        // All parse-error reports from this method share the same location fields.
        let record_parse_error = |error| {
            self.ctx.record_error(ValidationError::ParseError {
                path: self.ctx.path(),
                node_id: parse_ctx.node_id(),
                schema_node_id: self.schema_node_id,
                error,
            });
        };

        let mut builder = match parse_ctx.parse_union::<(), ValidatorError>() {
            Ok(builder) => builder,
            Err(e) => {
                if let Some(parse_error) = e.as_parse_error() {
                    record_parse_error(parse_error.clone());
                } else {
                    self.ctx.record_error(ValidationError::InvalidVariantTag {
                        tag: e.to_string(),
                        path: self.ctx.path(),
                        node_id: parse_ctx.node_id(),
                        schema_node_id: self.schema_node_id,
                    });
                }
                return Ok(());
            }
        };

        let has_explicit_tag =
            match has_explicit_variant_tag(self.ctx.document, parse_ctx.node_id()) {
                Ok(has_tag) => has_tag,
                Err(parse_error) => {
                    record_parse_error(parse_error);
                    return Ok(());
                }
            };

        for (name, &variant_schema_id) in &self.schema.variants {
            let ctx = self.ctx;
            let variant_name = name.clone();
            let options = VariantValidationOptions {
                // For an explicitly tagged node, a failing variant's errors are
                // merged into the context directly (see `validate_variant`)
                // instead of being kept only as candidates.
                propagate_errors: has_explicit_tag,
                requires_explicit_tag: self.schema.deny_untagged.contains(name),
                has_explicit_tag,
            };
            let validator = move |parse_ctx: &ParseContext<'_>| {
                validate_variant(ctx, parse_ctx, variant_schema_id, &variant_name, options)
            };
            builder = if self.schema.unambiguous.contains(name) {
                builder.variant_unambiguous(name, validator)
            } else {
                builder.variant(name, validator)
            };
        }

        match builder.parse() {
            Ok(()) => {
                self.ctx.clear_variant_errors();
            }
            Err(ValidatorError::InnerErrorsPropagated) => {
                // The concrete errors were already written into `self.ctx` by the
                // variant validator; only the now-irrelevant candidates remain.
                self.ctx.clear_variant_errors();
            }
            Err(e) => {
                if let Some(parse_error) = e.as_parse_error() {
                    self.ctx.clear_variant_errors();
                    record_parse_error(parse_error.clone());
                } else {
                    // Drains the candidates, so nothing is left behind here either.
                    let variant_errors = self.ctx.take_variant_errors();
                    let best_match = select_best_variant_match(variant_errors).map(Box::new);
                    self.ctx.record_error(ValidationError::NoVariantMatched {
                        path: self.ctx.path(),
                        best_match,
                        node_id: parse_ctx.node_id(),
                        schema_node_id: self.schema_node_id,
                    });
                }
            }
        }
        Ok(())
    }
}
/// Trial-validates the node at `parse_ctx` against a single union variant.
///
/// Validation runs in a separate `ValidationContext` built from
/// `ctx.fork_state()`, so `ctx` only changes when this function decides to
/// merge or record something. Outcomes:
/// - The variant matches and is usable: the trial state is merged into `ctx`
///   and `Ok(())` is returned.
/// - The variant matches but `requires_explicit_tag` is set while the node has
///   no explicit tag: `RequiresExplicitVariant` is recorded on `ctx` and
///   `InnerErrorsPropagated` is returned. The trial state is discarded.
/// - The variant fails, `propagate_errors` is set and the trial produced errors:
///   the trial state is merged into `ctx` and `InnerErrorsPropagated` is returned.
/// - Otherwise the variant fails: any trial errors are stored as candidate
///   variant errors and `InvalidVariantTag` ("type mismatch") is returned.
fn validate_variant<'doc>(
    ctx: &ValidationContext<'doc>,
    parse_ctx: &ParseContext<'doc>,
    schema_node_id: SchemaNodeId,
    variant_name: &str,
    options: VariantValidationOptions,
) -> Result<(), ValidatorError> {
    let trial_ctx = ValidationContext::with_state(ctx.document, ctx.schema, ctx.fork_state());
    let outcome = parse_ctx.parse_with(SchemaValidator {
        ctx: &trial_ctx,
        schema_node_id,
    });
    let matched = outcome.is_ok() && !trial_ctx.has_errors();
    let trial_state = trial_ctx.state.into_inner();

    if !matched {
        let produced_errors = !trial_state.errors.is_empty();
        if produced_errors && options.propagate_errors {
            ctx.merge_state(trial_state);
            return Err(ValidatorError::InnerErrorsPropagated);
        }
        if produced_errors {
            ctx.record_variant_errors(variant_name.to_string(), schema_node_id, trial_state.errors);
        }
        return Err(ValidatorError::InvalidVariantTag {
            tag: variant_name.to_string(),
            reason: "type mismatch".to_string(),
        });
    }

    // A structural match is still rejected when the variant demands an explicit tag.
    if options.requires_explicit_tag && !options.has_explicit_tag {
        ctx.record_error(ValidationError::RequiresExplicitVariant {
            variant: variant_name.to_string(),
            path: ctx.path(),
            node_id: parse_ctx.node_id(),
            schema_node_id,
        });
        return Err(ValidatorError::InnerErrorsPropagated);
    }

    ctx.merge_state(trial_state);
    Ok(())
}