use std::collections::HashMap;
use crate::dom::{Document, NodeId, NodeKind};
use crate::error::ValidationError;
use super::types::{IdentityConstraint, IdentityConstraintKind, XsdValidator};
impl XsdValidator {
    /// Validates the identity constraints (`xs:key`, `xs:unique`, `xs:keyref`)
    /// declared for `context_node`, appending a [`ValidationError`] to `errors`
    /// for every violation found.
    ///
    /// Key and unique constraints are processed first. For each one, the selector
    /// picks the target nodes and the fields are evaluated on every target. The
    /// resulting value tuples (QName-normalised via `idc_normalize_qname`) are
    /// collected into a table keyed by constraint name.
    ///
    /// Keyref constraints are processed afterwards: every complete keyref tuple
    /// must equal some tuple in the table named by its `refer`. A keyref whose
    /// `refer` is absent, or names a constraint not in `constraints`, is skipped
    /// (only a debug message is logged).
    ///
    /// Reported errors:
    /// - any kind: a field that selects more than one node
    ///   (XSD 1.0 `cvc-identity-constraint.3`); that target is then ignored;
    /// - key: a field that yields no value;
    /// - key/unique: a tuple equal to one already seen (it is not stored);
    /// - keyref: a complete tuple with no equal tuple in the referred table.
    pub(super) fn evaluate_identity_constraints(
        &self,
        doc: &Document,
        context_node: NodeId,
        constraints: &[IdentityConstraint],
        errors: &mut Vec<ValidationError>,
    ) {
        // Constraint name -> distinct value tuples of its key/unique targets.
        let mut key_tables: HashMap<String, Vec<Vec<String>>> = HashMap::new();

        for constraint in constraints {
            if constraint.kind == IdentityConstraintKind::KeyRef {
                continue;
            }
            let is_key = constraint.kind == IdentityConstraintKind::Key;
            let kind_str = match constraint.kind {
                IdentityConstraintKind::Key => "Key",
                IdentityConstraintKind::Unique => "Unique",
                _ => "Constraint",
            };
            let selected = idc_select_nodes(doc, context_node, &constraint.selector);
            debug_log!(
                "identity constraint '{}' ({:?}): selector='{}' selected {} nodes",
                constraint.name,
                constraint.kind,
                constraint.selector,
                selected.len()
            );
            let mut tuples: Vec<Vec<String>> = Vec::new();
            for &sel_node in &selected {
                let mut values: Vec<(String, Option<NodeId>)> = Vec::new();
                let mut all_present = true;
                let mut multiplicity_error = false;
                for field_xpath in &constraint.fields {
                    let (value, match_count, source_node) =
                        idc_evaluate_field(doc, sel_node, field_xpath);
                    if match_count > 1 {
                        errors.push(ValidationError {
                            message: format!(
                                "{} '{}': field '{}' selects {} nodes for element '{}' (must select at most one)",
                                kind_str,
                                constraint.name,
                                field_xpath,
                                match_count,
                                idc_element_name(doc, sel_node)
                            ),
                            line: Some(doc.node_line(sel_node)),
                            column: Some(doc.node_column(sel_node)),
                        });
                        multiplicity_error = true;
                        break;
                    }
                    // Keep evaluating after a missing value: a later field may
                    // still trigger the multiplicity error above.
                    match value {
                        Some(value) => values.push((value, source_node)),
                        None => all_present = false,
                    }
                }
                if multiplicity_error {
                    continue;
                }
                if !all_present {
                    // A key requires every field; for unique, a target with a
                    // missing field is simply not part of the compared set.
                    if is_key {
                        errors.push(ValidationError {
                            message: format!(
                                "Key '{}': field value missing for element '{}'",
                                constraint.name,
                                idc_element_name(doc, sel_node)
                            ),
                            line: Some(doc.node_line(sel_node)),
                            column: Some(doc.node_column(sel_node)),
                        });
                    }
                    continue;
                }
                let tuple = idc_key_tuple(doc, values);
                if tuples
                    .iter()
                    .any(|existing| idc_tuples_equal(existing, &tuple))
                {
                    errors.push(ValidationError {
                        message: format!(
                            "{} '{}': duplicate value {:?}",
                            kind_str, constraint.name, tuple
                        ),
                        line: Some(doc.node_line(sel_node)),
                        column: Some(doc.node_column(sel_node)),
                    });
                } else {
                    tuples.push(tuple);
                }
            }
            key_tables.insert(constraint.name.clone(), tuples);
        }

        for constraint in constraints {
            if constraint.kind != IdentityConstraintKind::KeyRef {
                continue;
            }
            let refer_name = match &constraint.refer {
                Some(name) => name,
                None => continue,
            };
            let referred_tuples = match key_tables.get(refer_name) {
                Some(tuples) => tuples,
                None => {
                    debug_log!(
                        "keyref '{}' refers to '{}' which was not found in this scope",
                        constraint.name,
                        refer_name
                    );
                    continue;
                }
            };
            let selected = idc_select_nodes(doc, context_node, &constraint.selector);
            debug_log!(
                "keyref '{}': selector='{}' selected {} nodes, referred key '{}' has {} tuples",
                constraint.name,
                constraint.selector,
                selected.len(),
                refer_name,
                referred_tuples.len()
            );
            for &sel_node in &selected {
                let mut values: Vec<(String, Option<NodeId>)> = Vec::new();
                let mut all_present = true;
                let mut multiplicity_error = false;
                for field_xpath in &constraint.fields {
                    let (value, match_count, source_node) =
                        idc_evaluate_field(doc, sel_node, field_xpath);
                    if match_count > 1 {
                        errors.push(ValidationError {
                            message: format!(
                                "KeyRef '{}': field '{}' selects {} nodes for element '{}' (must select at most one)",
                                constraint.name,
                                field_xpath,
                                match_count,
                                idc_element_name(doc, sel_node)
                            ),
                            line: Some(doc.node_line(sel_node)),
                            column: Some(doc.node_column(sel_node)),
                        });
                        multiplicity_error = true;
                        break;
                    }
                    match value {
                        Some(value) => values.push((value, source_node)),
                        None => all_present = false,
                    }
                }
                // Incomplete keyref tuples are not checked against the table.
                if multiplicity_error || !all_present {
                    continue;
                }
                let tuple = idc_key_tuple(doc, values);
                let found = referred_tuples
                    .iter()
                    .any(|key_tuple| idc_tuples_equal(key_tuple, &tuple));
                if !found {
                    errors.push(ValidationError {
                        message: format!(
                            "KeyRef '{}': no matching key value {:?} in referred constraint '{}'",
                            constraint.name, tuple, refer_name
                        ),
                        line: Some(doc.node_line(sel_node)),
                        column: Some(doc.node_column(sel_node)),
                    });
                }
            }
        }
    }
}

/// Local name of element `node` for error messages, or `"?"` if it is not an element.
fn idc_element_name(doc: &Document, node: NodeId) -> &str {
    doc.element(node)
        .map(|e| &*e.name.local_name)
        .unwrap_or("?")
}

/// Turns evaluated `(value, source_node)` field results into a comparable tuple,
/// expanding QName-looking values against the namespaces in scope at their source.
fn idc_key_tuple(doc: &Document, values: Vec<(String, Option<NodeId>)>) -> Vec<String> {
    values
        .into_iter()
        .map(|(value, source)| match source {
            Some(source) => idc_normalize_qname(doc, source, &value),
            None => value,
        })
        .collect()
}

/// Two tuples are equal when they have the same arity and every position is
/// equal under `idc_values_equal`.
fn idc_tuples_equal(a: &[String], b: &[String]) -> bool {
    a.len() == b.len()
        && a
            .iter()
            .zip(b.iter())
            .all(|(x, y)| idc_values_equal(x, y))
}
/// Evaluates an identity-constraint selector (`xs:selector/@xpath`) relative to
/// `context` and returns the selected nodes without duplicates, in the order
/// they were first found.
///
/// The selector may be a union of paths separated by `|`. Each path is an
/// optional `.//` (or `./`) prefix followed by `/`-separated steps, as split by
/// `idc_parse_path`.
///
/// - Without `.//`, each step descends exactly one child-element level.
/// - With `.//`, the steps must match the tail of the path from `context` down
///   to a descendant element.
/// - A `.` step is the XPath self step and never changes the node set. So `.`
///   alone selects `context` itself, and `.//.` selects `context` plus all of
///   its descendant elements.
fn idc_select_nodes(doc: &Document, context: NodeId, selector: &str) -> Vec<NodeId> {
    // Appends `node` unless an earlier path of the union already selected it.
    fn push_unique(results: &mut Vec<NodeId>, node: NodeId) {
        if !results.contains(&node) {
            results.push(node);
        }
    }

    let mut results = Vec::new();
    for path in selector.split('|').map(str::trim) {
        if path.is_empty() {
            continue;
        }
        let (descendant, steps) = idc_parse_path(path);
        // Self steps are no-ops; dropping them stops `.` from consuming a tree level.
        let steps: Vec<String> = steps
            .into_iter()
            .filter(|step| step.as_str() != ".")
            .collect();

        if descendant {
            let mut descendants = Vec::new();
            idc_collect_descendants(doc, context, &mut descendants);
            if steps.is_empty() {
                // `.//.` is descendant-or-self.
                push_unique(&mut results, context);
                for desc in descendants {
                    push_unique(&mut results, desc);
                }
            } else {
                for desc in descendants {
                    if idc_match_steps(doc, context, desc, &steps, 0) {
                        push_unique(&mut results, desc);
                    }
                }
            }
        } else if steps.is_empty() {
            // `.` selects the context node itself.
            push_unique(&mut results, context);
        } else {
            // Walk one child level per step, keeping the elements that match it.
            let mut level = vec![context];
            for step in &steps {
                let mut next_level = Vec::new();
                for &parent in &level {
                    for child in doc.children(parent) {
                        if matches!(doc.node_kind(child), Some(NodeKind::Element(_)))
                            && idc_step_matches(doc, child, step)
                        {
                            next_level.push(child);
                        }
                    }
                }
                level = next_level;
            }
            for node in level {
                push_unique(&mut results, node);
            }
        }
    }
    results
}
/// Splits one selector/field path into `(descendant, steps)`.
///
/// A leading `.//` sets `descendant`; otherwise a leading `./` is dropped. The
/// remainder is split on `/` and each step is trimmed. An empty remainder yields
/// a single empty step.
fn idc_parse_path(path: &str) -> (bool, Vec<String>) {
    let path = path.trim();
    // `.//` must be tested before `./`, since the former also starts with the latter.
    let (descendant, rest) = if let Some(rest) = path.strip_prefix(".//") {
        (true, rest)
    } else {
        (false, path.strip_prefix("./").unwrap_or(path))
    };
    let steps = rest
        .split('/')
        .map(|step| step.trim().to_owned())
        .collect();
    (descendant, steps)
}
/// Reports whether `node` satisfies a single selector/field step.
///
/// - `.` matches any node.
/// - `*` matches any element.
/// - Any other step is a name test: it matches an element whose local name
///   equals the step's local part. A `prefix:` on the test is ignored, so
///   matching is not namespace-aware.
fn idc_step_matches(doc: &Document, node: NodeId, step: &str) -> bool {
    match step {
        "." => true,
        "*" => doc.element(node).is_some(),
        name_test => match doc.element(node) {
            Some(elem) => {
                let local = match name_test.split_once(':') {
                    Some((_prefix, local)) => local,
                    None => name_test,
                };
                elem.name.local_name == local
            }
            None => false,
        },
    }
}
/// Appends every element below `node` to `result` in document (pre-)order.
///
/// Non-element children are skipped and their subtrees are not visited.
fn idc_collect_descendants(doc: &Document, node: NodeId, result: &mut Vec<NodeId>) {
    for child in doc.children(node) {
        if !matches!(doc.node_kind(child), Some(NodeKind::Element(_))) {
            continue;
        }
        // Record the element before descending into it.
        result.push(child);
        idc_collect_descendants(doc, child, result);
    }
}
/// Reports whether `target`, a node below `context`, is matched by `steps`.
///
/// The steps are anchored at the bottom of the path from `context` to
/// `target`: the last step must match `target`, the one before it `target`'s
/// parent, and so on. This is how a `.//step1/.../stepN` selector is matched.
///
/// Returns `false` when:
/// - `steps` is empty;
/// - `target` is `context` itself;
/// - `target` is not below `context`;
/// - the path is shorter than `steps`.
///
/// `_step_idx` is unused.
fn idc_match_steps(
    doc: &Document,
    context: NodeId,
    target: NodeId,
    steps: &[String],
    _step_idx: usize,
) -> bool {
    if steps.is_empty() {
        return false;
    }
    // `target` and its ancestors, nearest first, stopping just below `context`.
    let mut lineage = Vec::new();
    let mut node = target;
    while node != context {
        lineage.push(node);
        node = match doc.parent(node) {
            Some(parent) => parent,
            // Reached the root without meeting `context`.
            None => return false,
        };
    }
    lineage.len() >= steps.len()
        && lineage
            .iter()
            .zip(steps.iter().rev())
            .all(|(&ancestor, step)| idc_step_matches(doc, ancestor, step))
}
fn idc_evaluate_field(
doc: &Document,
node: NodeId,
field: &str,
) -> (Option<String>, usize, Option<NodeId>) {
let field = field.trim();
if field == "." {
let text = doc.text_content_deep(node);
let trimmed = text.trim();
if trimmed.is_empty() {
return (None, 1, Some(node));
}
return (Some(trimmed.to_string()), 1, Some(node));
}
if let Some(attr_name) = field.strip_prefix('@') {
let attr_name = if let Some(pipe) = attr_name.find('|') {
&attr_name[..pipe]
} else {
attr_name
};
if let Some(elem) = doc.element(node) {
let mut count = 0;
let mut value = None;
for attr in &elem.attributes {
if attr.name.local_name == attr_name {
count += 1;
if value.is_none() {
value = Some(attr.value.to_string());
}
}
}
if count > 0 {
return (value, count, Some(node));
}
}
return (None, 0, Some(node));
}
let parts: Vec<&str> = field.split('/').collect();
let mut current_nodes = vec![node];
for part in &parts {
let part = part.trim();
let mut next_nodes = Vec::new();
for &cn in ¤t_nodes {
for child in doc.children(cn) {
if let Some(NodeKind::Element(_)) = doc.node_kind(child) {
if idc_step_matches(doc, child, part) {
next_nodes.push(child);
}
}
}
}
current_nodes = next_nodes;
if current_nodes.is_empty() {
return (None, 0, None);
}
}
let match_count = current_nodes.len();
if let Some(&result_node) = current_nodes.first() {
let text = doc.text_content_deep(result_node);
let trimmed = text.trim();
if trimmed.is_empty() {
return (Some(String::new()), match_count, Some(result_node));
}
(Some(trimmed.to_string()), match_count, Some(result_node))
} else {
(None, 0, None)
}
}
/// Expands a `prefix:local` value to `{namespace-uri}local` using the
/// namespace declarations in scope at `source_node`.
///
/// The value is trimmed first. It is returned unchanged (trimmed) when:
/// - it has no colon;
/// - either side of the first colon is empty;
/// - the prefix is not declared on `source_node` or its ancestors.
///
/// NOTE: this applies to any value containing a colon, whatever the field's
/// declared type.
fn idc_normalize_qname(doc: &Document, source_node: NodeId, value: &str) -> String {
    let value = value.trim();
    let expanded = value
        .split_once(':')
        .filter(|(prefix, local)| !prefix.is_empty() && !local.is_empty())
        .and_then(|(prefix, local)| {
            idc_resolve_prefix(doc, source_node, prefix)
                .map(|ns_uri| format!("{{{}}}{}", ns_uri, local))
        });
    expanded.unwrap_or_else(|| value.to_string())
}
/// Finds the namespace URI bound to `prefix` at `node`.
///
/// The search walks from `node` up through its ancestors and checks each
/// element's `namespace_declarations`; the nearest declaration wins. Returns
/// `None` if neither `node` nor any ancestor declares the prefix.
fn idc_resolve_prefix(doc: &Document, node: NodeId, prefix: &str) -> Option<String> {
    std::iter::successors(Some(node), |&n| doc.parent(n))
        .filter_map(|n| doc.element(n))
        .find_map(|elem| {
            elem.namespace_declarations
                .iter()
                .find(|(p, _)| &**p == prefix)
                .map(|(_, uri)| uri.to_string())
        })
}
/// Compares two identity-constraint values.
///
/// The values are equal when they are identical strings, or when both parse as
/// decimals with the same canonical form (e.g. `"1.0"` and `"01"`).
///
/// NOTE: the comparison is type-agnostic, so string-typed values that happen
/// to look numeric are also compared numerically.
fn idc_values_equal(a: &str, b: &str) -> bool {
    a == b
        || matches!(
            (idc_parse_decimal(a), idc_parse_decimal(b)),
            (Some(x), Some(y)) if x == y
        )
}
/// Parses `s` as an `xs:decimal` literal and returns its canonical string form,
/// so that different spellings of the same number compare equal as strings.
///
/// Canonicalisation drops the `+` sign, leading integer zeros and trailing
/// fraction zeros, and collapses any zero (including `-0`) to `"0"`.
///
/// The input is trimmed first. Returns `None` when it is not a decimal literal:
/// - it is empty or a bare sign;
/// - it contains characters other than ASCII digits, one `.` and a leading sign;
/// - it has no digits at all (`"."`, `"-."`).
fn idc_parse_decimal(s: &str) -> Option<String> {
    let s = s.trim();
    // The sign byte is ASCII, so slicing after it stays on a char boundary.
    let (negative, unsigned) = match s.as_bytes().first().copied() {
        Some(b'-') => (true, &s[1..]),
        Some(b'+') => (false, &s[1..]),
        _ => (false, s),
    };
    let (int_part, frac_part) = unsigned.split_once('.').unwrap_or((unsigned, ""));
    let all_digits = |part: &str| part.bytes().all(|b| b.is_ascii_digit());
    // XSD decimals need at least one digit on one side of the point.
    if !all_digits(int_part) || !all_digits(frac_part) || (int_part.is_empty() && frac_part.is_empty())
    {
        return None;
    }

    let int_normalized = int_part.trim_start_matches('0');
    let int_normalized = if int_normalized.is_empty() {
        "0"
    } else {
        int_normalized
    };
    let frac_normalized = frac_part.trim_end_matches('0');
    if int_normalized == "0" && frac_normalized.is_empty() {
        // Zero has a single canonical form regardless of sign.
        return Some("0".to_string());
    }
    let sign = if negative { "-" } else { "" };
    if frac_normalized.is_empty() {
        Some(format!("{}{}", sign, int_normalized))
    } else {
        Some(format!("{}{}.{}", sign, int_normalized, frac_normalized))
    }
}