use std::collections::HashMap;
use std::sync::Arc;
use rand::{Rng, RngExt};
use crate::{
corpus::PooledValue,
engine::{AethelError, ComposedValue, GenerationContext},
};
use super::{PoolRef, RuleKey};
/// Shared, thread-safe string transform stored by [`RuleExpr::Map`].
///
/// `eval_expr` calls it with the evaluated inner value and replaces that value
/// with the returned string; the inner provenance is left untouched.
pub type MapFn = Arc<dyn Fn(String) -> String + Send + Sync>;
/// Shared, thread-safe callback stored by [`RuleExpr::Custom`] (via [`CustomExpr`]).
///
/// `eval_expr` invokes it with the current [`GenerationContext`] and the RNG and
/// returns its result unchanged. Note that it is not passed the pool index.
pub type CustomExprFn = Arc<
dyn for<'a> Fn(&GenerationContext<'a>, &mut dyn Rng) -> Result<ComposedValue, AethelError>
+ Send
+ Sync,
>;
/// A user-defined expression: opaque logic plus the rule keys it depends on.
#[derive(Clone)]
pub struct CustomExpr {
/// Rule keys this logic depends on. `eval_expr` itself never reads this list;
/// presumably it is used elsewhere to order rule evaluation — TODO confirm.
pub dependencies: Vec<RuleKey>,
/// Callback invoked by `eval_expr`; its result is returned as-is.
pub logic: CustomExprFn,
}
/// A weighted choice among alternative expressions.
#[derive(Clone)]
pub struct WeightedExpr {
/// `(weight, expr)` pairs. `eval_expr` picks one child with probability
/// `weight / total_weight`; zero-weight entries are never chosen, and an
/// empty or all-zero list is reported as an error at evaluation time.
pub choices: Vec<(u32, RuleExpr)>,
}
/// A composable rule expression, evaluated by `eval_expr` into a `ComposedValue`.
#[derive(Clone)]
pub enum RuleExpr {
/// `(pool, amount, delimiter)`: draws `amount` entries (1 when `None`) from the
/// pool uniformly with replacement, joins their values with `delimiter`, and
/// concatenates their provenance in draw order.
Pick(PoolRef, Option<usize>, String),
/// Returns a clone of the value already stored in the context under this key.
Recall(RuleKey),
/// A literal string with no provenance.
Lit(String),
/// Evaluates each part in order, folding the results with `ComposedValue::merge`.
Join(Vec<RuleExpr>),
/// Evaluates `inner` when a roll from `rng.random::<f64>()` is below `p`;
/// otherwise yields an empty value. The roll is consumed either way.
Chance {
p: f64,
inner: Box<RuleExpr>,
},
/// Picks exactly one child, proportionally to its weight.
Weighted(WeightedExpr),
/// Evaluates `inner`, then rewrites its string value with `transform`.
Map {
inner: Box<RuleExpr>,
transform: MapFn,
},
/// Delegates entirely to user-supplied logic.
Custom(CustomExpr),
/// Evaluates `condition`; if its value is the empty string the result is empty,
/// otherwise `inner` is evaluated. The condition's own value and provenance are
/// discarded (but any RNG draws it made still happened).
When {
condition: Box<RuleExpr>,
inner: Box<RuleExpr>,
},
}
/// Builds a [`RuleExpr::Pick`] that draws `amount` entries (one when `None`)
/// from `pool` and joins them with `delimiter`.
pub fn pick(pool: PoolRef, amount: Option<usize>, delimiter: impl Into<String>) -> RuleExpr {
    let delimiter: String = delimiter.into();
    RuleExpr::Pick(pool, amount, delimiter)
}
/// Builds a [`RuleExpr::Recall`] that reuses the value previously stored under `key`.
pub fn recall(key: impl Into<RuleKey>) -> RuleExpr {
    let key: RuleKey = key.into();
    RuleExpr::Recall(key)
}
/// Builds a [`RuleExpr::Lit`] that always yields `text` with no provenance.
pub fn lit(text: impl Into<String>) -> RuleExpr {
    let owned: String = text.into();
    RuleExpr::Lit(owned)
}
/// Builds a [`RuleExpr::Join`] that evaluates `parts` in order and merges them.
pub fn join(parts: impl IntoIterator<Item = RuleExpr>) -> RuleExpr {
    let collected: Vec<RuleExpr> = parts.into_iter().collect();
    RuleExpr::Join(collected)
}
/// Builds a [`RuleExpr::Chance`] that evaluates `inner` with the given
/// `probability` and otherwise yields an empty value.
pub fn chance(probability: f64, inner: RuleExpr) -> RuleExpr {
    let inner = Box::new(inner);
    RuleExpr::Chance {
        p: probability,
        inner,
    }
}
/// Builds a [`RuleExpr::Weighted`] from `(weight, expr)` pairs; one child is
/// chosen proportionally to its weight at evaluation time.
pub fn weighted(choices: impl IntoIterator<Item = (u32, RuleExpr)>) -> RuleExpr {
    let choices: Vec<(u32, RuleExpr)> = choices.into_iter().collect();
    let expr = WeightedExpr { choices };
    RuleExpr::Weighted(expr)
}
pub fn map<F>(inner: RuleExpr, transform: F) -> RuleExpr
where
F: Fn(String) -> String + Send + Sync + 'static,
{
RuleExpr::Map {
inner: Box::new(inner),
transform: Arc::new(transform),
}
}
pub fn custom<K, F>(dependencies: impl IntoIterator<Item = K>, logic: F) -> RuleExpr
where
K: Into<RuleKey>,
F: for<'a> Fn(&GenerationContext<'a>, &mut dyn Rng) -> Result<ComposedValue, AethelError>
+ Send
+ Sync
+ 'static,
{
RuleExpr::Custom(CustomExpr {
dependencies: dependencies.into_iter().map(Into::into).collect(),
logic: Arc::new(logic),
})
}
/// Builds a [`RuleExpr::When`] that evaluates `inner` only when `condition`
/// produces a non-empty string.
pub fn when(condition: RuleExpr, inner: RuleExpr) -> RuleExpr {
    let condition = Box::new(condition);
    let inner = Box::new(inner);
    RuleExpr::When { condition, inner }
}
/// Evaluates a single [`RuleExpr`] into a [`ComposedValue`].
///
/// * `expr` – the expression to evaluate (recursively for composite variants).
/// * `ctx` – values already generated; read by `Recall` and passed to `Custom`.
/// * `pool_index` – pooled values available to `Pick`, keyed by pool reference.
/// * `rng` – source of randomness for `Pick`, `Chance` and `Weighted`.
///
/// # Errors
///
/// * `AethelError::PoolNotFound` if a `Pick` pool is absent from `pool_index`.
/// * `AethelError::Custom` if a `Pick` pool is empty, or if a `Weighted`
///   expression's weights sum to zero or overflow `u32`.
/// * `AethelError::MissingDependency` if a `Recall` key is not in `ctx`.
/// * Any error returned by a `Custom` callback or by a nested expression.
pub(crate) fn eval_expr(
    expr: &RuleExpr,
    ctx: &GenerationContext<'_>,
    pool_index: &HashMap<PoolRef, Vec<PooledValue>>,
    rng: &mut dyn Rng,
) -> Result<ComposedValue, AethelError> {
    match expr {
        RuleExpr::Pick(pool_ref, amount, delimiter) => {
            let values = pool_index
                .get(pool_ref)
                .ok_or_else(|| AethelError::PoolNotFound {
                    section: pool_ref.section().to_string(),
                    field: pool_ref.field().to_string(),
                })?;
            if values.is_empty() {
                return Err(AethelError::Custom("pool is empty".to_string()));
            }
            let amount = amount.unwrap_or(1);
            // Draw with replacement. Borrow the pooled entries instead of cloning
            // each `PooledValue` wholesale: only `value` and `provenance` are used.
            let mut selected: Vec<&PooledValue> = Vec::with_capacity(amount);
            for _ in 0..amount {
                let idx = rng.random_range(0..values.len());
                selected.push(&values[idx]);
            }
            Ok(ComposedValue {
                value: selected
                    .iter()
                    .map(|v| v.value.clone())
                    .collect::<Vec<_>>()
                    .join(delimiter),
                provenance: selected
                    .iter()
                    .flat_map(|v| v.provenance.clone())
                    .collect(),
            })
        }
        RuleExpr::Recall(key) => ctx
            .get(key)
            .cloned()
            .ok_or_else(|| AethelError::MissingDependency(key.as_str().to_string())),
        RuleExpr::Lit(text) => Ok(ComposedValue {
            value: text.clone(),
            provenance: Vec::new(),
        }),
        RuleExpr::Join(parts) => {
            let mut result = empty_composed_value();
            for part in parts {
                let next = eval_expr(part, ctx, pool_index, rng)?;
                result = result.merge(next);
            }
            Ok(result)
        }
        RuleExpr::Chance { p, inner } => {
            // The roll is drawn unconditionally so RNG consumption does not depend on `p`.
            let roll = rng.random::<f64>();
            if roll < *p {
                eval_expr(inner, ctx, pool_index, rng)
            } else {
                Ok(empty_composed_value())
            }
        }
        RuleExpr::Weighted(weighted) => {
            // Checked sum: a plain `sum()` panics on overflow in debug builds and
            // silently wraps in release, which skews or breaks the distribution.
            let total_weight = weighted
                .choices
                .iter()
                .try_fold(0u32, |acc, (weight, _)| acc.checked_add(*weight))
                .ok_or_else(|| {
                    AethelError::Custom("weighted choice total weight overflows u32".to_string())
                })?;
            if total_weight == 0 {
                return Err(AethelError::Custom(
                    "weighted choice has a total weight of 0".to_string(),
                ));
            }
            let mut roll = rng.random_range(0..total_weight);
            for (weight, child) in &weighted.choices {
                if roll < *weight {
                    return eval_expr(child, ctx, pool_index, rng);
                }
                roll -= *weight;
            }
            // Defensive: `roll < total_weight` guarantees a child is selected above.
            Err(AethelError::Custom(
                "mathematical error in weighted expression".to_string(),
            ))
        }
        RuleExpr::Map { inner, transform } => {
            let mut composed = eval_expr(inner, ctx, pool_index, rng)?;
            composed.value = transform(composed.value);
            Ok(composed)
        }
        RuleExpr::Custom(custom) => (custom.logic)(ctx, rng),
        RuleExpr::When { condition, inner } => {
            // The condition is only used as a truthiness test (non-empty string);
            // its value and provenance are discarded.
            let condition_value = eval_expr(condition, ctx, pool_index, rng)?;
            if condition_value.value.is_empty() {
                Ok(empty_composed_value())
            } else {
                eval_expr(inner, ctx, pool_index, rng)
            }
        }
    }
}

/// Returns a `ComposedValue` with an empty string and no provenance; used as the
/// `Join` seed and as the result of a `Chance`/`When` branch that does not fire.
fn empty_composed_value() -> ComposedValue {
    ComposedValue {
        value: String::new(),
        provenance: Vec::new(),
    }
}