aethellib 0.9.6

Composable text generation primitives over target-specific TOML corpora with provenance tracking.
Documentation
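//! Plan validation for typed rule graphs.
//!
//! This module validates structural and semantic constraints before
//! compilation: duplicate rule keys, missing dependencies, dependency
//! cycles, unresolved pool references, and out-of-range numeric parameters.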

use std::collections::{HashMap, HashSet};

use super::{
    combinators::RuleExpr,
    error::{PlanError, PlanErrorReport},
    plan::{PlanBuilder, PlanNode, PoolRef, RuleKey, ValidatedPlan},
};

/// Checks that `value` is an acceptable pool-reference identifier.
///
/// `kind` is copied verbatim into the `kind` field of any error produced.
///
/// # Errors
///
/// Returns [`PlanError::InvalidIdentifier`] when `value` is empty, or when it
/// contains any character other than an ASCII letter, an ASCII digit, `_`,
/// `-`, or `.`.
pub(crate) fn validate_pool_ref(value: &str, kind: &str) -> Result<(), PlanError> {
    let invalid = |reason: &str| PlanError::InvalidIdentifier {
        kind: kind.to_string(),
        value: value.to_string(),
        reason: reason.to_string(),
    };

    if value.is_empty() {
        return Err(invalid("must not be empty"));
    }

    let has_forbidden_char = value
        .chars()
        .any(|ch| !(ch.is_ascii_alphanumeric() || matches!(ch, '_' | '-' | '.')));
    if has_forbidden_char {
        return Err(invalid("allowed characters are ascii letters, digits, '_', '-', and '.'"));
    }

    Ok(())
}

pub(crate) fn validate_plan(
    builder: PlanBuilder<'_>,
) -> Result<ValidatedPlan<'_>, PlanErrorReport> {
    let mut errors = Vec::new();

    errors.extend(check_duplicates(&builder.nodes));
    errors.extend(check_missing_dependencies(&builder.nodes));
    errors.extend(check_cycles(&builder.nodes));
    errors.extend(check_pool_refs(builder.corpus, &builder.nodes));
    errors.extend(check_numeric_constraints(&builder.nodes));

    if errors.is_empty() {
        Ok(ValidatedPlan {
            corpus: builder.corpus,
            nodes: builder.nodes,
        })
    } else {
        Err(PlanErrorReport { errors })
    }
}

/// Reports rule keys that are declared more than once.
///
/// The first declaration of a key is accepted. Every later declaration of
/// the same key produces its own [`PlanError::DuplicateRuleKey`], in
/// declaration order.
fn check_duplicates(nodes: &[PlanNode]) -> Vec<PlanError> {
    // Borrow the keys; ownership is not needed just to detect repeats.
    let mut first_seen: HashSet<&RuleKey> = HashSet::with_capacity(nodes.len());

    nodes
        .iter()
        .filter_map(|node| {
            let is_repeat = !first_seen.insert(&node.key);
            is_repeat.then(|| PlanError::DuplicateRuleKey(node.key.as_str().to_string()))
        })
        .collect()
}

/// Reports every dependency that names a rule key not declared in `nodes`.
///
/// Dependencies are gathered per node with [`collect_dependencies`]. One
/// [`PlanError::MissingDependency`] is emitted per offending occurrence, so a
/// node that references the same undeclared key twice yields two errors.
fn check_missing_dependencies(nodes: &[PlanNode]) -> Vec<PlanError> {
    let declared: HashSet<&RuleKey> = nodes.iter().map(|node| &node.key).collect();
    let mut errors = Vec::new();
    // One scratch buffer, cleared per node, instead of a fresh Vec each time.
    let mut deps: Vec<RuleKey> = Vec::new();

    for node in nodes {
        deps.clear();
        collect_dependencies(&node.expr, &mut deps);

        errors.extend(
            deps.iter()
                .filter(|dep| !declared.contains(dep))
                .map(|dep| PlanError::MissingDependency {
                    from: node.key.as_str().to_string(),
                    to: dep.as_str().to_string(),
                }),
        );
    }

    errors
}

/// Detects a dependency cycle among `nodes`.
///
/// Builds an index-based adjacency list from the dependencies gathered by
/// [`collect_dependencies`]. Dependencies that name an undeclared key are
/// skipped here, since they are reported by [`check_missing_dependencies`].
/// It then runs [`dfs_cycle`] from every not-yet-visited node in
/// declaration order.
///
/// At most one error is returned: the first cycle found, given as the rule
/// keys along the cycle.
fn check_cycles(nodes: &[PlanNode]) -> Vec<PlanError> {
    // If a key is declared more than once, the last declaration's index wins.
    let index_of: HashMap<&RuleKey, usize> = nodes
        .iter()
        .enumerate()
        .map(|(position, node)| (&node.key, position))
        .collect();

    let adjacency: Vec<Vec<usize>> = nodes
        .iter()
        .map(|node| {
            let mut deps = Vec::new();
            collect_dependencies(&node.expr, &mut deps);
            deps.iter()
                .filter_map(|dep| index_of.get(dep).copied())
                .collect()
        })
        .collect();

    let count = nodes.len();
    let mut visiting = vec![false; count];
    let mut visited = vec![false; count];
    let mut path: Vec<usize> = Vec::new();

    let cycle = (0..count).find_map(|root| {
        if visited[root] {
            None
        } else {
            dfs_cycle(root, &adjacency, &mut visiting, &mut visited, &mut path)
        }
    });

    match cycle {
        Some(indices) => {
            let keys: Vec<String> = indices
                .into_iter()
                .map(|index| nodes[index].key.as_str().to_string())
                .collect();
            vec![PlanError::CycleDetected(keys)]
        }
        None => Vec::new(),
    }
}

/// Depth-first search from `node` looking for a cycle in `adjacency`.
///
/// `adjacency[i]` lists the indices that node `i` has edges to. The state
/// buffers are shared across successive calls from `check_cycles`:
/// - `visited[i]` is set once node `i` has been entered and is never cleared;
/// - `visiting[i]` is true only while node `i` is on the current DFS path;
/// - `stack` holds the current DFS path, root first.
///
/// Returns the first cycle found, as a copy of the path slice. The slice
/// starts at the node the back edge re-enters and ends at the node owning
/// that edge. When a cycle is found, the buffers are left mid-search. When no
/// cycle is found, every node entered here is popped from `stack` and has its
/// `visiting` flag cleared.
///
/// The search keeps an explicit frame stack instead of recursing, so very
/// long dependency chains cannot overflow the thread's call stack. Edges are
/// examined in the same order a recursive DFS would use, so the reported
/// cycle is the same.
fn dfs_cycle(
    node: usize,
    adjacency: &[Vec<usize>],
    visiting: &mut [bool],
    visited: &mut [bool],
    stack: &mut Vec<usize>,
) -> Option<Vec<usize>> {
    // Each frame is (node, index of the next outgoing edge to examine).
    let mut frames: Vec<(usize, usize)> = vec![(node, 0)];
    visiting[node] = true;
    visited[node] = true;
    stack.push(node);

    while let Some(frame) = frames.last_mut() {
        let (current, edge) = *frame;
        let Some(&next) = adjacency[current].get(edge) else {
            // Every edge of `current` has been explored: leave the node.
            frames.pop();
            stack.pop();
            visiting[current] = false;
            continue;
        };
        frame.1 += 1;

        if !visited[next] {
            // Tree edge: descend, exactly where a recursive call would.
            visiting[next] = true;
            visited[next] = true;
            stack.push(next);
            frames.push((next, 0));
            continue;
        }

        if !visiting[next] {
            // Edge into an already finished node: it cannot close a cycle.
            continue;
        }

        // Back edge into the current path. The cycle runs from `next`'s
        // position on the path to the top of the stack.
        if let Some(start) = stack.iter().position(|&idx| idx == next) {
            return Some(stack[start..].to_vec());
        }
    }

    None
}

fn check_pool_refs(corpus: &crate::corpus::Corpus, nodes: &[PlanNode]) -> Vec<PlanError> {
    let mut refs = Vec::new();
    for node in nodes {
        collect_pool_refs(&node.expr, &mut refs);
    }

    let mut errors = Vec::new();
    for pool_ref in refs {
        if corpus
            .pooled_values_for_field_section(pool_ref.field(), pool_ref.section())
            .is_none()
        {
            errors.push(PlanError::PoolRefNotFound {
                section: pool_ref.section().to_string(),
                field: pool_ref.field().to_string(),
            });
        }
    }

    errors
}

/// Collects numeric-parameter violations from every node, in node order.
///
/// The per-expression checks are done by [`collect_numeric_errors`]. Each
/// error is tagged with the key of the node it was found in.
fn check_numeric_constraints(nodes: &[PlanNode]) -> Vec<PlanError> {
    nodes.iter().fold(Vec::new(), |mut found, node| {
        collect_numeric_errors(&node.key, &node.expr, &mut found);
        found
    })
}

/// Appends to `out` every rule key that `expr` depends on, in depth-first,
/// left-to-right order.
///
/// `Recall` contributes its key and `Custom` contributes its declared
/// `dependencies`. The composite expressions `Join`, `Chance`, `Weighted`,
/// `Map` and `When` are searched recursively, with a `When` condition visited
/// before its body. `Pick` and `Lit` contribute nothing. Duplicates are kept.
pub(crate) fn collect_dependencies(expr: &RuleExpr, out: &mut Vec<RuleKey>) {
    match expr {
        RuleExpr::Pick(..) | RuleExpr::Lit(_) => {}
        RuleExpr::Recall(key) => out.push(key.clone()),
        RuleExpr::Custom(custom) => out.extend(custom.dependencies.iter().cloned()),
        RuleExpr::Chance { inner, .. } => collect_dependencies(inner, out),
        RuleExpr::Map { inner, .. } => collect_dependencies(inner, out),
        RuleExpr::When { condition, inner } => {
            collect_dependencies(condition, out);
            collect_dependencies(inner, out);
        }
        RuleExpr::Join(parts) => parts
            .iter()
            .for_each(|part| collect_dependencies(part, out)),
        RuleExpr::Weighted(weighted) => weighted
            .choices
            .iter()
            .for_each(|(_, choice)| collect_dependencies(choice, out)),
    }
}

/// Appends to `out` every pool reference made by a `Pick` inside `expr`, in
/// depth-first, left-to-right order.
///
/// The composite expressions `Join`, `Chance`, `Weighted`, `Map` and `When`
/// are searched recursively, with a `When` condition visited before its body.
/// `Recall`, `Lit` and `Custom` contribute nothing here. Duplicates are kept.
pub(crate) fn collect_pool_refs(expr: &RuleExpr, out: &mut Vec<PoolRef>) {
    match expr {
        RuleExpr::Recall(_) | RuleExpr::Lit(_) | RuleExpr::Custom(_) => {}
        RuleExpr::Pick(pool_ref, ..) => out.push(pool_ref.clone()),
        RuleExpr::Chance { inner, .. } => collect_pool_refs(inner, out),
        RuleExpr::Map { inner, .. } => collect_pool_refs(inner, out),
        RuleExpr::When { condition, inner } => {
            collect_pool_refs(condition, out);
            collect_pool_refs(inner, out);
        }
        RuleExpr::Join(parts) => parts
            .iter()
            .for_each(|part| collect_pool_refs(part, out)),
        RuleExpr::Weighted(weighted) => weighted
            .choices
            .iter()
            .for_each(|(_, choice)| collect_pool_refs(choice, out)),
    }
}

/// Recursively checks the numeric parameters inside `expr` and appends any
/// violations to `out`, each tagged with `rule_key`.
///
/// - `Chance { p, .. }`: `p` must lie in `[0.0, 1.0]`. NaN is rejected too,
///   because NaN is not contained in the range.
/// - `Weighted`: at least one choice must have a non-zero weight, so an empty
///   choice list is also reported as [`PlanError::WeightedTotalZero`].
///
/// The search descends into `Chance`, `Weighted`, `Join`, `Map` and `When`
/// (condition before body). `Pick`, `Recall`, `Lit` and `Custom` have no
/// numeric parameters checked here.
fn collect_numeric_errors(rule_key: &RuleKey, expr: &RuleExpr, out: &mut Vec<PlanError>) {
    match expr {
        RuleExpr::Chance { p, inner } => {
            if !(0.0..=1.0).contains(p) {
                out.push(PlanError::InvalidChanceProbability {
                    rule: rule_key.as_str().to_string(),
                    value: *p,
                });
            }
            collect_numeric_errors(rule_key, inner, out);
        }
        RuleExpr::Weighted(weighted) => {
            // Test "every weight is zero" instead of summing the weights.
            // Summing large weights in an integer accumulator overflows: it
            // panics in debug builds and wraps in release, and a wrapped total
            // can land on exactly 0 and raise a spurious error.
            if weighted.choices.iter().all(|(weight, _)| *weight == 0) {
                out.push(PlanError::WeightedTotalZero {
                    rule: rule_key.as_str().to_string(),
                });
            }
            // NOTE(review): weights whose total exceeds the weight type's max
            // are accepted here; confirm the sampler sums them in a wider type.
            for (_, choice) in &weighted.choices {
                collect_numeric_errors(rule_key, choice, out);
            }
        }
        RuleExpr::Join(parts) => {
            for part in parts {
                collect_numeric_errors(rule_key, part, out);
            }
        }
        RuleExpr::Map { inner, .. } => collect_numeric_errors(rule_key, inner, out),
        RuleExpr::When { condition, inner } => {
            collect_numeric_errors(rule_key, condition, out);
            collect_numeric_errors(rule_key, inner, out);
        }
        RuleExpr::Pick(..) | RuleExpr::Recall(_) | RuleExpr::Lit(_) | RuleExpr::Custom(_) => {}
    }
}