aethellib 0.9.6

Composable text generation primitives over target-specific TOML corpora with provenance tracking.
Documentation
//! typed expression model and combinator helpers for generation plans.

use std::collections::HashMap;
use std::sync::Arc;

use rand::{Rng, RngExt};

use crate::{
    corpus::PooledValue,
    engine::{AethelError, ComposedValue, GenerationContext},
};

use super::{PoolRef, RuleKey};

/// shared, thread-safe string transform that [`RuleExpr::Map`] applies to the value produced
/// by its child expression.
pub type MapFn = Arc<dyn Fn(String) -> String + Sync + Send>;

/// shared, thread-safe callback behind a user-defined expression ([`RuleExpr::Custom`]).
///
/// it receives the current generation context and the random source, and returns either the
/// composed value or an error that the evaluator hands back to its caller.
pub type CustomExprFn = Arc<
    dyn for<'ctx> Fn(&GenerationContext<'ctx>, &mut dyn Rng) -> Result<ComposedValue, AethelError>
        + Sync
        + Send,
>;

/// user-defined expression: a callback plus the plan dependencies it declares up front.
#[derive(Clone)]
pub struct CustomExpr {
    /// keys of the rule results the callback reads from the generation context.
    pub dependencies: Vec<RuleKey>,
    /// callback invoked at evaluation time to produce the composed value.
    pub logic: CustomExprFn,
}

/// set of weighted branches from which exactly one is chosen during rule evaluation.
#[derive(Clone)]
pub struct WeightedExpr {
    /// `(weight, expression)` pairs; a branch's chance is its weight over the sum of weights.
    pub choices: Vec<(u32, RuleExpr)>,
}

/// typed rule expression evaluated by plan-based execution.
///
/// values are normally built with this module's constructor functions ([`pick`], [`recall`],
/// [`lit`], [`join`], [`chance`], [`weighted`], [`map`], [`custom`], [`when`]).
#[derive(Clone)]
pub enum RuleExpr {
    /// draws random values from a corpus pool: `(pool, amount, delimiter)`.
    ///
    /// `amount` is the number of draws (`None` means a single draw) and `delimiter` is
    /// inserted between the drawn values.
    Pick(PoolRef, Option<usize>, String),
    /// reads the result of a previously generated rule by its key.
    Recall(RuleKey),
    /// yields a fixed literal string with no provenance.
    Lit(String),
    /// evaluates each child expression in order and concatenates their values.
    Join(Vec<RuleExpr>),
    /// evaluates `inner` with probability `p`, otherwise yields an empty value.
    Chance {
        /// probability of evaluating `inner`, expected in `[0.0, 1.0]`.
        p: f64,
        /// expression evaluated when the probability roll succeeds.
        inner: Box<RuleExpr>,
    },
    /// evaluates exactly one branch, chosen by weighted random selection.
    Weighted(WeightedExpr),
    /// rewrites the child's evaluated string while keeping its provenance.
    Map {
        /// expression evaluated before the transform is applied.
        inner: Box<RuleExpr>,
        /// transform applied to the child's string value.
        transform: MapFn,
    },
    /// runs user-provided expression logic.
    Custom(CustomExpr),
    /// evaluates `inner` only when `condition` yields a non-empty value.
    When {
        /// expression whose output gates whether `inner` runs.
        condition: Box<RuleExpr>,
        /// expression evaluated when `condition` yields a non-empty value.
        inner: Box<RuleExpr>,
    },
}

/// builds a [`RuleExpr::Pick`] that draws from `pool`.
///
/// `amount` is the number of draws (the evaluator treats `None` as one draw), and the drawn
/// values are separated by `delimiter` in the output.
pub fn pick(pool: PoolRef, amount: Option<usize>, delimiter: impl Into<String>) -> RuleExpr {
    let delimiter: String = delimiter.into();
    RuleExpr::Pick(pool, amount, delimiter)
}

/// builds a [`RuleExpr::Recall`] that reads the already-generated result stored under `key`.
pub fn recall(key: impl Into<RuleKey>) -> RuleExpr {
    let key: RuleKey = key.into();
    RuleExpr::Recall(key)
}

/// builds a [`RuleExpr::Lit`] that always yields `text`.
pub fn lit(text: impl Into<String>) -> RuleExpr {
    let text: String = text.into();
    RuleExpr::Lit(text)
}

/// builds a [`RuleExpr::Join`] whose `parts` are evaluated in order and combined into one value.
pub fn join(parts: impl IntoIterator<Item = RuleExpr>) -> RuleExpr {
    let parts: Vec<RuleExpr> = parts.into_iter().collect();
    RuleExpr::Join(parts)
}

/// builds a [`RuleExpr::Chance`] that evaluates `inner` with the given `probability`.
///
/// when the roll fails the expression yields an empty value instead.
pub fn chance(probability: f64, inner: RuleExpr) -> RuleExpr {
    let inner = Box::new(inner);
    RuleExpr::Chance {
        p: probability,
        inner,
    }
}

/// builds a [`RuleExpr::Weighted`] from `(weight, expression)` pairs.
///
/// at evaluation time exactly one branch runs, chosen with probability proportional to its
/// weight.
pub fn weighted(choices: impl IntoIterator<Item = (u32, RuleExpr)>) -> RuleExpr {
    let choices: Vec<(u32, RuleExpr)> = choices.into_iter().collect();
    RuleExpr::Weighted(WeightedExpr { choices })
}

/// creates a typed map expression.
pub fn map<F>(inner: RuleExpr, transform: F) -> RuleExpr
where
    F: Fn(String) -> String + Send + Sync + 'static,
{
    RuleExpr::Map {
        inner: Box::new(inner),
        transform: Arc::new(transform),
    }
}

/// creates a typed user-defined expression with declared dependencies.
pub fn custom<K, F>(dependencies: impl IntoIterator<Item = K>, logic: F) -> RuleExpr
where
    K: Into<RuleKey>,
    F: for<'a> Fn(&GenerationContext<'a>, &mut dyn Rng) -> Result<ComposedValue, AethelError>
        + Send
        + Sync
        + 'static,
{
    RuleExpr::Custom(CustomExpr {
        dependencies: dependencies.into_iter().map(Into::into).collect(),
        logic: Arc::new(logic),
    })
}

/// builds a [`RuleExpr::When`] that always evaluates `condition` and evaluates `inner` only if
/// the condition's value is non-empty; otherwise the result is empty.
pub fn when(condition: RuleExpr, inner: RuleExpr) -> RuleExpr {
    let condition = Box::new(condition);
    let inner = Box::new(inner);
    RuleExpr::When { condition, inner }
}

pub(crate) fn eval_expr(
    expr: &RuleExpr,
    ctx: &GenerationContext<'_>,
    pool_index: &HashMap<PoolRef, Vec<PooledValue>>,
    rng: &mut dyn Rng,
) -> Result<ComposedValue, AethelError> {
    match expr {
        RuleExpr::Pick(pool_ref, amount, delimiter) => {
            let values = pool_index
                .get(pool_ref)
                .ok_or_else(|| AethelError::PoolNotFound {
                    section: pool_ref.section().to_string(),
                    field: pool_ref.field().to_string(),
                })?;

            if values.is_empty() {
                return Err(AethelError::Custom("pool is empty".to_string()));
            }

            let amount = amount.unwrap_or(1);
            let mut selected_values = Vec::new();
            for _ in 0..amount {
                let idx = rng.random_range(0..values.len());
                let selected = &values[idx];
                selected_values.push(selected.clone());
            }

            Ok(ComposedValue {
                value: selected_values
                    .iter()
                    .map(|v| v.value.clone())
                    .collect::<Vec<_>>()
                    .join(delimiter),
                provenance: selected_values
                    .iter()
                    .flat_map(|v| v.provenance.clone())
                    .collect(),
            })
        }
        RuleExpr::Recall(key) => ctx
            .get(key)
            .cloned()
            .ok_or_else(|| AethelError::MissingDependency(key.as_str().to_string())),
        RuleExpr::Lit(text) => Ok(ComposedValue {
            value: text.clone(),
            provenance: Vec::new(),
        }),
        RuleExpr::Join(parts) => {
            let mut result = ComposedValue {
                value: String::new(),
                provenance: Vec::new(),
            };

            for part in parts {
                let next = eval_expr(part, ctx, pool_index, rng)?;
                result = result.merge(next);
            }

            Ok(result)
        }
        RuleExpr::Chance { p, inner } => {
            let roll = rng.random::<f64>();
            if roll < *p {
                eval_expr(inner, ctx, pool_index, rng)
            } else {
                Ok(ComposedValue {
                    value: String::new(),
                    provenance: Vec::new(),
                })
            }
        }
        RuleExpr::Weighted(weighted) => {
            let total_weight: u32 = weighted.choices.iter().map(|(w, _)| *w).sum();
            if total_weight == 0 {
                return Err(AethelError::Custom(
                    "weighted choice has a total weight of 0".to_string(),
                ));
            }

            let mut roll = rng.random_range(0..total_weight);
            for (weight, child) in &weighted.choices {
                if roll < *weight {
                    return eval_expr(child, ctx, pool_index, rng);
                }
                roll -= *weight;
            }

            Err(AethelError::Custom(
                "mathematical error in weighted expression".to_string(),
            ))
        }
        RuleExpr::Map { inner, transform } => {
            let mut composed = eval_expr(inner, ctx, pool_index, rng)?;
            composed.value = transform(composed.value);
            Ok(composed)
        }
        RuleExpr::Custom(custom) => (custom.logic)(ctx, rng),
        RuleExpr::When { condition, inner } => {
            let condition_value = eval_expr(condition, ctx, pool_index, rng)?;
            if condition_value.value.is_empty() {
                Ok(ComposedValue {
                    value: String::new(),
                    provenance: Vec::new(),
                })
            } else {
                eval_expr(inner, ctx, pool_index, rng)
            }
        }
    }
}