use crate::ion_path::{IonPath, IonPathElement};
use crate::result::ValidationResult;
use crate::system::TypeStore;
use crate::type_reference::{TypeReference, VariablyOccurringTypeRef};
use crate::types::TypeValidator;
use crate::violation::{Violation, ViolationCode};
use crate::IonSchemaElement;
use ion_rs::Element;
use std::cmp::Ordering;
use std::collections::HashSet;
use std::fmt::{Debug, Display, Formatter};
use std::hash::{Hash, Hasher};
use std::mem::swap;
use std::ops::RangeInclusive;
use std::vec;
/// Numeric identifier of an NFA state; it is also the state's index in the state list.
type StateId = usize;
/// A state id paired with how many times that state has been visited in a row.
type StateVisitCount = (StateId, usize);
/// The next input item: `Some(element)`, or `None` once the sequence is exhausted.
type ElementOrEndOfSequence<'elem> = Option<&'elem Element>;
/// A nondeterministic finite automaton that checks whether a sequence of elements
/// satisfies an `ordered_elements` constraint. It tracks a set of active states
/// while consuming the input.
#[derive(Debug, Clone, PartialEq)]
pub struct OrderedElementsNfa {
    /// All states. The initial state is at index 0, followed by one intermediate state
    /// per `ordered_elements` entry (in order), with the final state last.
    states: Vec<State>,
    /// `edges[i]` is the inclusive range of state ids reachable from state `i`.
    /// The final state has no entry.
    edges: Vec<RangeInclusive<usize>>,
    /// The (final state id, visit count) pair; once it is reached, the input is accepted.
    terminal_state: StateVisitCount,
}
impl OrderedElementsNfa {
    /// Builds the automaton for an `ordered_elements` constraint.
    ///
    /// State ids are assigned as follows:
    /// - state `0` is the initial state;
    /// - states `1..=n` correspond, in order, to the `n` entries of `intermediate_states`;
    /// - state `n + 1` is the final state.
    ///
    /// Each entry carries a type reference with its occurrence range, plus an optional
    /// description used in violation messages. When the description is `None`, the
    /// placeholder `<ORDERED_ELEMENT[i]>` is used instead, where `i` is the zero-based
    /// entry index.
    pub fn new(intermediate_states: Vec<(VariablyOccurringTypeRef, Option<String>)>) -> Self {
        let mut states = vec![State::Initial];
        intermediate_states
            .into_iter()
            .enumerate()
            .for_each(|(i, (var_type_ref, description))| {
                let description =
                    description.unwrap_or_else(|| format!("<ORDERED_ELEMENT[{}]>", i));
                let (min_visits, max_visits) = var_type_ref.occurs_range().inclusive_endpoints();
                let state = IntermediateState {
                    type_ref: var_type_ref.type_ref(),
                    min_visits,
                    max_visits,
                    description,
                };
                states.push(State::Intermediate(i + 1, state))
            });

        // Id of the final state. The final state is pushed only after the edges are
        // computed, so the edge loop below visits just the initial and intermediate states.
        let max_id = states.len();

        // The targets reachable from state `i` form one contiguous id range:
        // - it starts at `i` itself (a self-loop) when the state may be visited more than
        //   once, and at `i + 1` otherwise;
        // - it ends at the first later state that cannot be skipped (minimum occurrence
        //   count above zero), or at the final state if every later state is optional.
        let mut edges = Vec::with_capacity(max_id);
        for (i, s) in states.iter().enumerate() {
            let min_transition = if s.can_reenter(1) { i } else { i + 1 };
            let max_transition = (i + 1..max_id)
                .find(|&k| !states[k].can_exit(0))
                .unwrap_or(max_id);
            edges.push(min_transition..=max_transition)
        }
        states.push(State::Final(max_id));

        // Entering the final state (always with a visit count of 1) means acceptance.
        let terminal_state: StateVisitCount = (max_id, 1usize);
        OrderedElementsNfa {
            states,
            edges,
            terminal_state,
        }
    }

    /// Runs the automaton over the elements yielded by `iter`.
    ///
    /// Returns `Ok(())` when the sequence is accepted. Otherwise, returns a single
    /// `ordered_elements` violation for the first index at which no transition was
    /// possible. That violation contains one nested violation per failed transition.
    ///
    /// While an element is being processed, its index is pushed onto `ion_path`.
    /// The index is popped again before this method returns, on both success and failure.
    pub fn matches<'a, I: Iterator<Item = &'a Element>>(
        &self,
        mut iter: I,
        type_store: &'a TypeStore,
        ion_path: &mut IonPath,
    ) -> ValidationResult {
        // Every active state is tracked together with its consecutive visit count.
        let mut current_state_set: HashSet<StateVisitCount> = HashSet::new();
        let mut new_states: HashSet<StateVisitCount> = HashSet::new();
        let mut input_index = 0;
        current_state_set.insert((0usize, 1usize));
        loop {
            // `None` (end of input) is also fed through the automaton; only the final
            // state accepts it.
            let element: ElementOrEndOfSequence = iter.next();
            let mut invalid_transitions: HashSet<TraversalError> = HashSet::new();
            ion_path.push(IonPathElement::Index(input_index));
            for &(from_state_id, num_visits) in &current_state_set {
                let from_state = &self.states[from_state_id];
                let edges = match self.edges.get(from_state_id) {
                    Some(edges) => edges.clone(),
                    None => {
                        // Only the final state has no outgoing edges. Record the failure
                        // for this state, but keep exploring the other active states.
                        invalid_transitions
                            .insert(TraversalError::CannotExitState(from_state_id));
                        continue;
                    }
                };
                let can_reenter = from_state.can_reenter(num_visits);
                let can_exit = from_state.can_exit(num_visits);
                for to_state_id in edges {
                    let to_state: &State = &self.states[to_state_id];
                    let is_loop = to_state_id == from_state_id;
                    // Targets are visited in ascending id order. Once a non-loop edge is
                    // blocked because the state cannot be exited, every remaining edge is
                    // blocked too.
                    if !is_loop && !can_exit {
                        invalid_transitions
                            .insert(TraversalError::CannotExitState(from_state_id));
                        break;
                    }
                    match to_state.can_enter(element, type_store, ion_path) {
                        Err(violation) => {
                            invalid_transitions
                                .insert(TraversalError::CannotEnterState(to_state_id, violation));
                        }
                        Ok(_) if is_loop && !can_reenter => {
                            invalid_transitions
                                .insert(TraversalError::CannotReEnterState(to_state_id));
                        }
                        Ok(_) => {
                            let new_num_visits = if is_loop { num_visits + 1 } else { 1 };
                            new_states.insert((to_state_id, new_num_visits));
                        }
                    }
                }
            }
            if new_states.is_empty() {
                // `build_violation` pops the index pushed above.
                return Err(self.build_violation(element, ion_path, invalid_transitions));
            }
            ion_path.pop();
            if new_states.contains(&self.terminal_state) {
                return Ok(());
            }
            // Reuse both allocations: the cleared old set becomes the next scratch set.
            current_state_set.clear();
            swap(&mut current_state_set, &mut new_states);
            input_index += 1;
        }
    }

    /// Builds the violation reported when no transition can consume `event`.
    ///
    /// The index of `event` is expected to be on top of `ion_path`. Nested violations
    /// are created (one per entry of `invalid_transitions`, sorted) while that index is
    /// still on the path. The index is then popped, and the outer violation is created
    /// with the remaining path.
    fn build_violation(
        &self,
        event: ElementOrEndOfSequence,
        ion_path: &mut IonPath,
        invalid_transitions: HashSet<TraversalError>,
    ) -> Violation {
        let mut reasons: Vec<_> = invalid_transitions.into_iter().collect();
        reasons.sort();
        let reasons = reasons
            .into_iter()
            .map(|it| match it {
                TraversalError::CannotExitState(s) => Violation::new(
                    "ordered_elements",
                    ViolationCode::ElementMismatched,
                    format!("{}: min occurs not reached", &self.states[s]),
                    ion_path,
                ),
                TraversalError::CannotReEnterState(s) => Violation::new(
                    "ordered_elements",
                    ViolationCode::ElementMismatched,
                    format!("{}: max occurs already reached", &self.states[s]),
                    ion_path,
                ),
                TraversalError::CannotEnterState(s, v) => Violation::with_violations(
                    "ordered_elements",
                    ViolationCode::ElementMismatched,
                    format!("{}: does not match type", &self.states[s]),
                    ion_path,
                    vec![v],
                ),
            })
            .collect();
        let index = ion_path
            .pop()
            .expect("matches() pushes the element index before calling build_violation");
        Violation::with_violations(
            "ordered_elements",
            ViolationCode::ElementMismatched,
            format!(
                "input does not match ordered_elements at index {}: {}",
                index,
                event
                    .map(Element::to_string)
                    .unwrap_or_else(|| "<END_OF_INPUT>".to_string())
            ),
            ion_path,
            reasons,
        )
    }
}
/// One `ordered_elements` entry as an NFA state. It holds the type each consumed element
/// must match, how many times in a row the state may be visited, and a label for messages.
#[derive(Debug, Clone, PartialEq)]
struct IntermediateState {
    /// Type that every element consumed by this state must satisfy.
    type_ref: TypeReference,
    /// Minimum number of consecutive visits (inclusive) before the state may be exited.
    min_visits: usize,
    /// Maximum number of consecutive visits (inclusive).
    max_visits: usize,
    /// Label used when this state is named in violation messages.
    description: String,
}
/// A node of the ordered-elements NFA.
#[derive(Debug, Clone, PartialEq)]
enum State {
    /// Start state with implicit id 0. No transition ever targets it.
    Initial,
    /// State for one `ordered_elements` entry, tagged with its id.
    Intermediate(StateId, IntermediateState),
    /// Accepting state. It can only be entered at the end of the input sequence.
    Final(StateId),
}
impl State {
    /// Returns this state's id. The initial state always has id `0`.
    fn id(&self) -> StateId {
        match self {
            State::Initial => 0usize,
            State::Intermediate(id, _) | State::Final(id) => *id,
        }
    }

    /// Returns whether a state already visited `num_visits` times in a row may be
    /// visited again.
    ///
    /// For intermediate states this holds while the visit count is below the maximum.
    /// The initial and final states can never be re-entered.
    fn can_reenter(&self, num_visits: usize) -> bool {
        match self {
            State::Initial => false,
            State::Intermediate(_, s) => num_visits < s.max_visits,
            State::Final(_) => false,
        }
    }

    /// Returns whether the state may be left after `num_visits` consecutive visits.
    ///
    /// The initial state can always be left and the final state never can. An
    /// intermediate state can be left only once its minimum visit count is reached.
    fn can_exit(&self, num_visits: usize) -> bool {
        match self {
            State::Initial => true,
            State::Intermediate(_, s) => num_visits >= s.min_visits,
            State::Final(_) => false,
        }
    }

    /// Checks whether `element` can be consumed by entering this state. A `None`
    /// element means end of sequence.
    ///
    /// An intermediate state requires an element that validates against its type. The
    /// final state requires end of sequence.
    ///
    /// # Panics
    /// Panics if called on `State::Initial`, which no transition targets.
    fn can_enter(
        &self,
        element: Option<&Element>,
        type_store: &TypeStore,
        ion_path: &mut IonPath,
    ) -> ValidationResult {
        match self {
            State::Initial => unreachable!("There are no transitions to the initial state."),
            State::Intermediate(_, s) => match element {
                Some(el) => s
                    .type_ref
                    .validate(&IonSchemaElement::from(el), type_store, ion_path),
                None => Err(Violation::new(
                    "ordered_elements",
                    ViolationCode::ElementMismatched,
                    "expected another element; found <END OF SEQUENCE>",
                    ion_path,
                )),
            },
            State::Final(_) => match element {
                Some(el) => Err(Violation::new(
                    "ordered_elements",
                    ViolationCode::ElementMismatched,
                    format!("expected <END OF SEQUENCE>; found: {}", el),
                    ion_path,
                )),
                None => Ok(()),
            },
        }
    }
}
impl Display for State {
    /// Writes the name used for this state in violation messages. Intermediate states
    /// use their entry's description; the start and end states use fixed markers.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let string = match self {
            State::Initial => "<START OF SEQUENCE>",
            State::Intermediate(_, s) => &s.description,
            State::Final(_) => "<END OF SEQUENCE>",
        };
        f.write_str(string)
    }
}
/// Why a particular transition could not be taken while matching. These are collected
/// per input index and turned into nested violations when matching fails.
#[derive(Debug)]
enum TraversalError {
    /// The element (or end of sequence) did not satisfy the target state. Carries the
    /// violation produced while checking it.
    CannotEnterState(StateId, Violation),
    /// The state could not be left, e.g. its minimum visit count was not yet reached.
    CannotExitState(StateId),
    /// The state's self-loop was blocked because its maximum visit count was reached.
    CannotReEnterState(StateId),
}
/// Partial ordering is total here; it simply delegates to the `Ord` implementation.
impl PartialOrd for TraversalError {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        let ordering = Ord::cmp(self, other);
        Some(ordering)
    }
}
impl Ord for TraversalError {
fn cmp(&self, other: &Self) -> Ordering {
let self_id = match self {
TraversalError::CannotEnterState(id, _)
| TraversalError::CannotExitState(id)
| TraversalError::CannotReEnterState(id) => id,
};
let other_id = match other {
TraversalError::CannotEnterState(id, _)
| TraversalError::CannotExitState(id)
| TraversalError::CannotReEnterState(id) => id,
};
self_id.cmp(other_id)
}
}
impl Eq for TraversalError {}
impl PartialEq for TraversalError {
fn eq(&self, other: &Self) -> bool {
match (self, other) {
(
TraversalError::CannotExitState(self_id),
TraversalError::CannotExitState(other_id),
) => self_id == other_id,
(
TraversalError::CannotReEnterState(self_id),
TraversalError::CannotReEnterState(other_id),
) => self_id == other_id,
(
TraversalError::CannotEnterState(self_id, _),
TraversalError::CannotEnterState(other_id, _),
) => self_id == other_id,
(_, _) => false,
}
}
}
/// Hashes only the state id, scaled by a per-kind multiplier. This is consistent with
/// `PartialEq`, which also ignores the `Violation` attached to `CannotEnterState`.
impl Hash for TraversalError {
    fn hash<H: Hasher>(&self, state: &mut H) {
        let (id, multiplier) = match self {
            TraversalError::CannotEnterState(id, _) => (*id, 503),
            TraversalError::CannotExitState(id) => (*id, 307),
            TraversalError::CannotReEnterState(id) => (*id, 107),
        };
        state.write_usize(id * multiplier);
    }
}