use std::fmt;
use std::sync::{Arc, OnceLock};
/// The set of admissible values for one numeric parameter of a [`CommandBuilder`].
#[derive(Debug, Clone, PartialEq)]
pub enum ParamDomain {
    /// Any real value in the closed interval `[min, max]`.
    Continuous { min: f64, max: f64 },
    /// One of an explicit list of values.
    Discrete(Vec<f64>),
    /// Any integer in the closed interval `[min, max]` (values are still carried as `f64`).
    Int { min: i64, max: i64 },
}

impl ParamDomain {
    /// Projects `value` onto this domain.
    ///
    /// * `Continuous` — clamped into `[min, max]`; a NaN `value` stays NaN.
    /// * `Int` — rounded with [`f64::round`] (halves away from zero), then clamped.
    ///   A NaN `value` becomes `0` before clamping, because `as i64` saturates NaN to zero.
    /// * `Discrete` — snapped to the nearest listed value, with the earliest value
    ///   winning ties. An empty list returns `value` unchanged.
    ///
    /// # Panics
    ///
    /// Panics if `min > max` for `Continuous`/`Int`, or if a `Continuous` bound is NaN.
    /// This is the contract of [`f64::clamp`] and [`Ord::clamp`].
    pub fn clamp(&self, value: f64) -> f64 {
        match self {
            ParamDomain::Continuous { min, max } => value.clamp(*min, *max),
            ParamDomain::Int { min, max } => (value.round() as i64).clamp(*min, *max) as f64,
            ParamDomain::Discrete(values) => {
                let Some((&first, rest)) = values.split_first() else {
                    return value;
                };
                // Strict `<` keeps the earliest candidate when distances are equal.
                let (nearest, _) = rest.iter().fold(
                    (first, (first - value).abs()),
                    |(best, best_d), &v| {
                        let d = (v - value).abs();
                        if d < best_d {
                            (v, d)
                        } else {
                            (best, best_d)
                        }
                    },
                );
                nearest
            }
        }
    }

    /// A representative "centre" value of the domain.
    ///
    /// * `Continuous` — `(min + max) / 2`, computed so it stays finite even for bounds
    ///   near `±f64::MAX`.
    /// * `Int` — `(min + max) / 2`, computed in `i128` so the sum cannot overflow `i64`.
    ///   The result may be non-integral (e.g. `2.5` for `0..=5`); pass it through
    ///   [`ParamDomain::clamp`] to snap it into the domain.
    /// * `Discrete` — the first listed value (not a median), or `0.0` when the list is empty.
    pub fn midpoint(&self) -> f64 {
        match self {
            ParamDomain::Continuous { min, max } => {
                let sum = *min + *max;
                if sum.is_finite() {
                    0.5 * sum
                } else {
                    // `min + max` overflowed (or a bound is non-finite): halve before adding.
                    // For infinite/NaN bounds this gives the same result as `0.5 * sum`.
                    0.5 * *min + 0.5 * *max
                }
            }
            // Widen to i128: `min + max` overflows i64 when both bounds are large.
            ParamDomain::Int { min, max } => ((i128::from(*min) + i128::from(*max)) as f64) * 0.5,
            ParamDomain::Discrete(values) => values.first().copied().unwrap_or(0.0),
        }
    }
}
/// Builds a concrete command `C` from a list of numeric parameter values.
///
/// Implementors are stored as `Arc<dyn CommandBuilder<C>>` inside
/// [`CommandTree::Parametric`]. The `Send + Sync` supertraits make that `Arc`
/// shareable across threads, and the `Debug` supertrait backs the default
/// [`CommandBuilder::describe`].
pub trait CommandBuilder<C>: fmt::Debug + Send + Sync {
    /// Describes the parameter space of this builder.
    ///
    /// NOTE(review): presumably one [`ParamDomain`] per entry of the `values`
    /// slice passed to [`CommandBuilder::build`]. Nothing in this file enforces
    /// that, so confirm with implementors and callers.
    fn domains(&self) -> Vec<ParamDomain>;
    /// Builds a command from the parameter `values`.
    ///
    /// NOTE(review): whether `values` must already be clamped into `domains()`
    /// is not established here — confirm with callers.
    fn build(&self, values: &[f64]) -> C;
    /// Human-readable description of the builder; defaults to its `Debug` output.
    /// `CommandTree`'s `Debug` impl uses it to render `Parametric` nodes.
    fn describe(&self) -> String {
        format!("{:?}", self)
    }
}
/// A tree of commands of type `C`.
///
/// Interior nodes are either [`CommandTree::Layer`] (children stored eagerly) or
/// [`CommandTree::LazyLayer`] (children produced on demand). Children are keyed by
/// a `String` and held in `Arc`s, so one subtree can be referenced from several
/// parents.
pub enum CommandTree<C> {
    /// A node with no command and no children.
    Empty,
    /// A single concrete command.
    Leaf(C),
    /// A family of commands produced by `builder` from numeric parameters.
    /// Leaf traversals (`for_each_leaf`, `leaf_count`, `flatten`) skip it.
    Parametric {
        label: String,
        builder: Arc<dyn CommandBuilder<C>>,
    },
    /// An interior node whose `(key, child)` pairs are stored eagerly.
    Layer {
        label: String,
        children: Vec<(String, Arc<CommandTree<C>>)>,
    },
    /// An interior node whose `(key, child)` pairs are produced by `expand` the
    /// first time `children()` is called, then memoized in `cache`.
    /// `CommandTree::lazy_layer` builds one with an empty cache.
    LazyLayer {
        label: String,
        /// Produces the children; invoked through `cache.get_or_init` in `children()`.
        expand: Arc<dyn Fn() -> Vec<(String, Arc<CommandTree<C>>)> + Send + Sync>,
        /// Holds the result of `expand` once it has run.
        cache: OnceLock<Vec<(String, Arc<CommandTree<C>>)>>,
    },
}
impl<C: 'static> CommandTree<C> {
pub fn lazy_layer<F>(label: impl Into<String>, expand: F) -> Self
where
F: Fn() -> Vec<(String, Arc<CommandTree<C>>)> + Send + Sync + 'static,
{
CommandTree::LazyLayer {
label: label.into(),
expand: Arc::new(expand),
cache: OnceLock::new(),
}
}
}
impl<C: fmt::Debug> fmt::Debug for CommandTree<C> {
    /// Formats a single node without recursing into its children.
    ///
    /// * `Layer` nodes report their label and child count.
    /// * `Parametric` nodes show their label and the builder's `describe()`.
    /// * `LazyLayer` nodes are never forced. They report whether expansion has
    ///   happened and, if so, how many children it produced.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            CommandTree::Empty => f.write_str("Empty"),
            CommandTree::Leaf(c) => write!(f, "Leaf({:?})", c),
            CommandTree::Parametric { label, builder } => {
                write!(f, "Parametric({:?} -> {})", label, builder.describe())
            }
            CommandTree::Layer { label, children } => f
                .debug_struct("Layer")
                .field("label", label)
                .field("child_count", &children.len())
                .finish(),
            CommandTree::LazyLayer { label, cache, .. } => {
                // Read the cache once so `expanded` and `child_count` always agree,
                // even if another thread completes the expansion mid-format.
                let expanded = cache.get();
                f.debug_struct("LazyLayer")
                    .field("label", label)
                    .field("expanded", &expanded.is_some())
                    .field("child_count", &expanded.map(Vec::len))
                    .finish()
            }
        }
    }
}
impl<C> CommandTree<C> {
    /// Returns `true` only for [`CommandTree::Empty`].
    pub fn is_empty(&self) -> bool {
        matches!(self, CommandTree::Empty)
    }

    /// Returns the keyed children of a layer node, or `None` for `Empty`,
    /// `Leaf` and `Parametric` nodes.
    ///
    /// For a [`CommandTree::LazyLayer`] this forces expansion. `expand` runs on
    /// the first call and its result is stored in the node's `OnceLock`, so later
    /// calls return the same slice without running `expand` again.
    pub fn children(&self) -> Option<&[(String, Arc<CommandTree<C>>)]> {
        match self {
            CommandTree::Layer { children, .. } => Some(children.as_slice()),
            CommandTree::LazyLayer { expand, cache, .. } => {
                Some(cache.get_or_init(|| expand()).as_slice())
            }
            CommandTree::Empty | CommandTree::Leaf(_) | CommandTree::Parametric { .. } => None,
        }
    }

    /// Returns the node's label; `Empty` and `Leaf` nodes have none.
    pub fn label(&self) -> Option<&str> {
        match self {
            CommandTree::Layer { label, .. }
            | CommandTree::LazyLayer { label, .. }
            | CommandTree::Parametric { label, .. } => Some(label.as_str()),
            CommandTree::Empty | CommandTree::Leaf(_) => None,
        }
    }

    /// Counts the `Leaf` nodes reachable from this node.
    ///
    /// `Empty` and `Parametric` nodes contribute zero, and every lazy layer reached
    /// is expanded. A subtree referenced through several `Arc`s is counted once per
    /// reference.
    pub fn leaf_count(&self) -> usize {
        match self {
            CommandTree::Leaf(_) => 1,
            // Layers (eager or lazy) sum over their children; other nodes have none.
            _ => self.children().map_or(0, |children| {
                children.iter().map(|(_, child)| child.leaf_count()).sum::<usize>()
            }),
        }
    }

    /// Calls `f` on every leaf command, depth-first, visiting children in stored order.
    ///
    /// Lazy layers are expanded as they are reached.
    pub fn for_each_leaf<F: FnMut(&C)>(&self, mut f: F) {
        fn walk<C, F: FnMut(&C)>(node: &CommandTree<C>, f: &mut F) {
            if let CommandTree::Leaf(command) = node {
                f(command);
            } else if let Some(children) = node.children() {
                for (_, child) in children {
                    walk(child, f);
                }
            }
        }
        walk(self, &mut f);
    }

    /// Returns the first leaf command (depth-first, in child order) for which
    /// `pred` returns `true`.
    ///
    /// The search stops at the first match, so lazy layers beyond it are not
    /// expanded. `pred` may be stateful (`FnMut`).
    pub fn find_leaf<F: FnMut(&C) -> bool>(&self, mut pred: F) -> Option<&C> {
        fn walk<'a, C, F: FnMut(&C) -> bool>(
            node: &'a CommandTree<C>,
            pred: &mut F,
        ) -> Option<&'a C> {
            if let CommandTree::Leaf(command) = node {
                return pred(command).then_some(command);
            }
            node.children()?
                .iter()
                .find_map(|(_, child)| walk(child, &mut *pred))
        }
        walk(self, &mut pred)
    }

    /// Looks up the direct child stored under `key`; with duplicate keys, the first wins.
    ///
    /// A lazy layer is expanded by this call, but the returned child is not.
    pub fn child(&self, key: &str) -> Option<&CommandTree<C>> {
        self.children()?
            .iter()
            .find(|(k, _)| k == key)
            .map(|(_, child)| child.as_ref())
    }
}
impl<C: Clone> CommandTree<C> {
    /// Collects clones of every leaf command, in [`CommandTree::for_each_leaf`] order.
    pub fn flatten(&self) -> Vec<C> {
        let mut out = Vec::new();
        self.for_each_leaf(|c| out.push(c.clone()));
        out
    }

    /// Returns a clone of the leaf with the highest `score`, or `None` if there are no leaves.
    ///
    /// Ties keep the earliest leaf in traversal order. A NaN score ranks below every
    /// non-NaN score, so one NaN cannot mask the real maximum. If every score is NaN,
    /// the first leaf is returned.
    pub fn argmax<F: FnMut(&C) -> f64>(&self, mut score: F) -> Option<C> {
        let mut best: Option<(C, f64)> = None;
        self.for_each_leaf(|c| {
            let s = score(c);
            let better = match &best {
                None => true,
                // Every comparison against NaN is false. Without the second clause,
                // a NaN score on the first leaf would never be replaced.
                Some((_, b)) => s > *b || (b.is_nan() && !s.is_nan()),
            };
            if better {
                best = Some((c.clone(), s));
            }
        });
        best.map(|(c, _)| c)
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for `ParamDomain`, `CommandBuilder` and `CommandTree`. The
    // `lazy_layer_*` tests count `expand` invocations with atomic counters to pin
    // down exactly when a `LazyLayer` is (and is not) expanded.
    use super::*;

    // Command type used by the tree tests.
    type C = i32;

    // Builds an `Arc`-wrapped `Leaf` node.
    fn leaf(c: C) -> Arc<CommandTree<C>> { Arc::new(CommandTree::Leaf(c)) }

    // Builds an `Arc`-wrapped eager `Layer` from `(key, child)` pairs.
    fn layer(label: &str, children: Vec<(&str, Arc<CommandTree<C>>)>) -> Arc<CommandTree<C>> {
        Arc::new(CommandTree::Layer {
            label: label.into(),
            children: children.into_iter().map(|(k, v)| (k.into(), v)).collect(),
        })
    }

    #[test]
    fn empty_is_empty() {
        let t: CommandTree<C> = CommandTree::Empty;
        assert!(t.is_empty());
        assert_eq!(t.leaf_count(), 0);
        assert!(t.flatten().is_empty());
    }

    #[test]
    fn leaf_counts_itself() {
        let t = CommandTree::Leaf(42);
        assert!(!t.is_empty());
        assert_eq!(t.leaf_count(), 1);
        assert_eq!(t.flatten(), vec![42]);
    }

    #[test]
    fn layer_aggregates_children() {
        // Leaves are counted through nested layers. `flatten` output is sorted
        // before comparing, so this test does not pin traversal order.
        let tree = layer("root", vec![
            ("a", leaf(1)),
            ("b", layer("sub", vec![
                ("x", leaf(2)),
                ("y", leaf(3)),
            ])),
        ]);
        assert_eq!(tree.leaf_count(), 3);
        let mut flat = tree.flatten();
        flat.sort();
        assert_eq!(flat, vec![1, 2, 3]);
    }

    #[test]
    fn find_leaf_walks_tree() {
        let tree = layer("root", vec![
            ("a", leaf(1)),
            ("b", leaf(7)),
            ("c", leaf(3)),
        ]);
        assert_eq!(tree.find_leaf(|&x| x == 7), Some(&7));
        assert_eq!(tree.find_leaf(|&x| x == 99), None);
    }

    #[test]
    fn argmax_picks_highest_score() {
        let tree = layer("root", vec![
            ("a", leaf(1)),
            ("b", leaf(5)),
            ("c", leaf(3)),
        ]);
        let best = tree.argmax(|&c| c as f64);
        assert_eq!(best, Some(5));
    }

    #[test]
    fn child_lookup_by_key() {
        let tree = layer("root", vec![
            ("attack", leaf(10)),
            ("move", leaf(20)),
        ]);
        assert!(matches!(tree.child("attack"), Some(CommandTree::Leaf(10))));
        assert!(matches!(tree.child("move"), Some(CommandTree::Leaf(20))));
        assert!(tree.child("missing").is_none());
    }

    #[test]
    fn param_domain_clamps() {
        let d = ParamDomain::Continuous { min: 0.0, max: 10.0 };
        assert_eq!(d.clamp(-5.0), 0.0);
        assert_eq!(d.clamp(5.0), 5.0);
        assert_eq!(d.clamp(15.0), 10.0);
        assert_eq!(d.midpoint(), 5.0);
    }

    #[test]
    fn int_domain_rounds_and_clamps() {
        let d = ParamDomain::Int { min: 0, max: 5 };
        assert_eq!(d.clamp(2.7), 3.0);
        assert_eq!(d.clamp(-100.0), 0.0);
        assert_eq!(d.clamp(100.0), 5.0);
    }

    #[test]
    fn discrete_domain_snaps_to_nearest() {
        let d = ParamDomain::Discrete(vec![-1.0, 0.0, 1.0]);
        assert_eq!(d.clamp(-0.9), -1.0);
        assert_eq!(d.clamp(0.4), 0.0);
        assert_eq!(d.clamp(0.6), 1.0);
        // `midpoint` of a `Discrete` domain is its first listed value.
        assert_eq!(d.midpoint(), -1.0);
    }

    // Minimal builder: one parameter in [0, 1], scaled by the wrapped factor.
    #[derive(Debug)]
    struct ScaleBy(f64);

    impl CommandBuilder<f64> for ScaleBy {
        fn domains(&self) -> Vec<ParamDomain> {
            vec![ParamDomain::Continuous { min: 0.0, max: 1.0 }]
        }
        fn build(&self, values: &[f64]) -> f64 {
            values[0] * self.0
        }
        fn describe(&self) -> String {
            format!("ScaleBy({})", self.0)
        }
    }

    #[test]
    fn parametric_leaf_builds_command() {
        let tree: CommandTree<f64> = CommandTree::Parametric {
            label: "scale".into(),
            builder: Arc::new(ScaleBy(10.0)),
        };
        if let CommandTree::Parametric { builder, .. } = &tree {
            let domains = builder.domains();
            assert_eq!(domains.len(), 1);
            let v = builder.build(&[0.5]);
            assert!((v - 5.0).abs() < 1e-9);
        } else {
            panic!("expected Parametric");
        }
    }

    #[test]
    fn debug_format_is_useful() {
        // `Debug` for `Parametric` shows the label and the builder's `describe()`
        // output; `10.0_f64` displays as "10".
        let tree: CommandTree<f64> = CommandTree::Parametric {
            label: "scale".into(),
            builder: Arc::new(ScaleBy(10.0)),
        };
        let s = format!("{:?}", tree);
        assert!(s.contains("scale"));
        assert!(s.contains("ScaleBy(10)"));
    }

    #[test]
    fn structural_sharing_works() {
        // The same `Arc` is stored under two keys: both entries point to one
        // allocation, yet `leaf_count` counts it once per reference.
        let shared = leaf(42);
        let tree = layer("root", vec![
            ("a", shared.clone()),
            ("b", shared.clone()),
        ]);
        if let CommandTree::Layer { children, .. } = &*tree {
            assert!(Arc::ptr_eq(&children[0].1, &children[1].1));
        } else {
            panic!();
        }
        assert_eq!(tree.leaf_count(), 2);
    }

    #[test]
    fn lazy_layer_expands_on_access() {
        // `expand` must not run at construction. It must run on the first
        // `children()` call and never again afterwards.
        use std::sync::atomic::{AtomicUsize, Ordering};
        let calls = Arc::new(AtomicUsize::new(0));
        let calls_clone = calls.clone();
        let tree: CommandTree<C> = CommandTree::lazy_layer("root", move || {
            calls_clone.fetch_add(1, Ordering::SeqCst);
            vec![("a".into(), leaf(1)), ("b".into(), leaf(2))]
        });
        assert_eq!(calls.load(Ordering::SeqCst), 0);
        let children = tree.children().unwrap();
        assert_eq!(children.len(), 2);
        assert_eq!(calls.load(Ordering::SeqCst), 1);
        let _ = tree.children().unwrap();
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn lazy_layer_flatten_works() {
        let tree: CommandTree<C> = CommandTree::lazy_layer("root", || {
            vec![("a".into(), leaf(1)), ("b".into(), leaf(2)), ("c".into(), leaf(3))]
        });
        let mut flat = tree.flatten();
        flat.sort();
        assert_eq!(flat, vec![1, 2, 3]);
        assert_eq!(tree.leaf_count(), 3);
    }

    #[test]
    fn lazy_layer_nested_inside_layer() {
        // Looking up children of an eager parent must not expand a lazy child.
        // Only calling `children()` on the lazy node itself does.
        use std::sync::atomic::{AtomicUsize, Ordering};
        let calls = Arc::new(AtomicUsize::new(0));
        let calls_clone = calls.clone();
        let lazy = Arc::new(CommandTree::<C>::lazy_layer("expensive", move || {
            calls_clone.fetch_add(1, Ordering::SeqCst);
            vec![("deep".into(), leaf(99))]
        }));
        let tree = layer("root", vec![
            ("cheap", leaf(1)),
            ("expensive", lazy),
        ]);
        let cheap = tree.child("cheap").unwrap();
        assert!(matches!(cheap, CommandTree::Leaf(1)));
        assert_eq!(calls.load(Ordering::SeqCst), 0);
        let expensive_ref = tree.child("expensive").unwrap();
        assert!(matches!(expensive_ref, CommandTree::LazyLayer { .. }));
        assert_eq!(calls.load(Ordering::SeqCst), 0);
        let _ = expensive_ref.children();
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn lazy_layer_child_lookup_forces_only_this_layer() {
        // `child()` on the outer lazy layer expands it. The inner lazy layer it
        // returns stays unexpanded until its own `children()` is called.
        use std::sync::atomic::{AtomicUsize, Ordering};
        let outer = Arc::new(AtomicUsize::new(0));
        let inner = Arc::new(AtomicUsize::new(0));
        let outer_c = outer.clone();
        let inner_c = inner.clone();
        let tree: CommandTree<C> = CommandTree::lazy_layer("outer", move || {
            outer_c.fetch_add(1, Ordering::SeqCst);
            let inner_c2 = inner_c.clone();
            let inner_tree = Arc::new(CommandTree::<C>::lazy_layer("inner", move || {
                inner_c2.fetch_add(1, Ordering::SeqCst);
                vec![("x".into(), leaf(7))]
            }));
            vec![("nested".into(), inner_tree)]
        });
        let nested = tree.child("nested").unwrap();
        assert_eq!(outer.load(Ordering::SeqCst), 1);
        assert_eq!(inner.load(Ordering::SeqCst), 0);
        let _ = nested.children();
        assert_eq!(inner.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn lazy_layer_debug_marks_expansion_state() {
        // Formatting with `Debug` must not itself force expansion.
        let tree: CommandTree<C> = CommandTree::lazy_layer("root", || vec![
            ("a".into(), leaf(1)),
        ]);
        let s = format!("{:?}", tree);
        assert!(s.contains("LazyLayer"));
        assert!(s.contains("expanded: false"));
        let _ = tree.children();
        let s = format!("{:?}", tree);
        assert!(s.contains("expanded: true"));
        assert!(s.contains("child_count: Some(1)"));
    }
}