use std::collections::HashMap;
use rand::rngs::SmallRng;
use rand::seq::SliceRandom;
use crate::native::bignum::BigUint;
use crate::native::core::{ChoiceKind, ChoiceNode, ChoiceValue, Status};
/// Hashable stand-in for a [`ChoiceValue`], used as the key of a node's child map.
///
/// Floats are stored as a raw `u64` bit pattern rather than as a float. Two float choices
/// therefore share a key only when their bits are identical.
#[derive(Clone, Eq, Hash, PartialEq)]
enum ChoiceValueKey {
    /// Key for an integer choice.
    Integer(i128),
    /// Key for a boolean choice.
    Boolean(bool),
    /// Key for a float choice: its raw bit pattern.
    Float(u64),
}
impl From<&ChoiceValue> for ChoiceValueKey {
    /// Builds the child-map key for `value`.
    ///
    /// Integers and booleans are carried over unchanged. A float is keyed by the bit pattern
    /// returned by its `to_bits()`, so two floats map to the same key only if they are
    /// bit-for-bit identical.
    fn from(value: &ChoiceValue) -> Self {
        match *value {
            ChoiceValue::Integer(n) => Self::Integer(n),
            ChoiceValue::Boolean(b) => Self::Boolean(b),
            ChoiceValue::Float(ref f) => Self::Float(f.to_bits()),
        }
    }
}
/// One node in the tree of choice sequences recorded so far.
///
/// A node stands for one position in a choice sequence, reached by a particular prefix of
/// earlier choice values. Its children are keyed by the value chosen at that position.
#[derive(Default)]
pub(crate) struct DataTreeNode {
    /// Kind of the choice made at this position. `record_tree` fills it in on the first
    /// visit and panics if a later visit reports a different kind.
    kind: Option<ChoiceKind>,
    /// Subtrees for each value chosen at this position so far.
    children: HashMap<ChoiceValueKey, Box<DataTreeNode>>,
    /// Status of a run that ended at this node. `record_tree` only stores statuses that are
    /// `>= Status::Invalid`.
    conclusion: Option<Status>,
    /// Cached "nothing left to explore below here" flag. It is set by `check_exhausted`, and
    /// `record_tree` can also force it through `kill_depths`.
    pub(crate) is_exhausted: bool,
}
impl Drop for DataTreeNode {
fn drop(&mut self) {
let mut stack: Vec<Box<DataTreeNode>> =
self.children.drain().map(|(_, child)| child).collect();
while let Some(mut node) = stack.pop() {
stack.extend(node.children.drain().map(|(_, child)| child));
}
}
}
impl DataTreeNode {
fn check_exhausted(&mut self) -> bool {
if self.is_exhausted {
return true;
}
if self.conclusion.is_some() {
self.is_exhausted = true;
return true;
}
if let Some(ref kind) = self.kind {
let max_c = kind.max_children();
if BigUint::from(self.children.len() as u64) >= max_c {
let all_exhausted = self.children.values_mut().all(|c| c.check_exhausted());
if all_exhausted {
self.is_exhausted = true;
return true;
}
}
}
false
}
}
/// Records one finished run (the choices in `nodes` and its final `status`) in the tree
/// rooted at `tree_root`.
///
/// - Walks the path of children keyed by each choice's value, creating missing nodes.
/// - Stores each position's `ChoiceKind` the first time that position is seen.
/// - If `status >= Status::Invalid`, stores `status` as the conclusion of the node reached
///   after the last choice.
/// - For every `depth` in `kill_depths`, marks the node at that depth of the path as
///   exhausted. Depth 0 is `tree_root`; depth `i` is the node reached after `i` choices.
///   Depths past the end of the path are ignored.
/// - Finally re-evaluates exhaustion for every node on the path, from the deepest node up
///   to the root, so newly exhausted subtrees propagate upwards.
///
/// # Panics
///
/// Panics if a position previously recorded with one `ChoiceKind` is reached again, via the
/// same prefix of values, with a different kind.
pub(crate) fn record_tree(
    tree_root: &mut DataTreeNode,
    nodes: &[ChoiceNode],
    status: Status,
    kill_depths: &[usize],
) {
    // Raw pointers to every node on the path, root first. The later passes revisit
    // ancestors after descending, which a single chain of `&mut` borrows cannot express.
    // NOTE(review): soundness relies on the access order below (in particular the
    // leaf-to-root final pass); worth confirming with `cargo miri test`.
    let mut path: Vec<*mut DataTreeNode> = Vec::with_capacity(nodes.len() + 1);
    path.push(tree_root as *mut _);
    for first in nodes {
        let parent_ptr = *path.last().unwrap();
        // SAFETY: `parent_ptr` is either `tree_root`, which is exclusively borrowed for this
        // whole call, or the heap contents of a `Box` owned by the tree. Nodes are only
        // inserted here, never removed, and a `Box`'s pointee does not move when the map
        // holding it rehashes, so the pointee is live. The borrows from the previous
        // iteration are no longer used.
        let node = unsafe { &mut *parent_ptr };
        match &node.kind {
            Some(expected_kind) if *expected_kind != first.kind => {
                panic!(
                    "Your data generation is non-deterministic: at the same choice \
position with the same prefix, the schema changed from {:?} to {:?}. \
This usually means a generator depends on global mutable state.",
                    expected_kind, first.kind
                );
            }
            None => {
                // First visit to this position: remember which kind of choice happens here.
                node.kind = Some(first.kind.clone());
            }
            _ => {}
        }
        let key = ChoiceValueKey::from(&first.value);
        let child = node
            .children
            .entry(key)
            .or_insert_with(|| Box::new(DataTreeNode::default()));
        path.push(child.as_mut() as *mut _);
    }
    // Only statuses at or above `Invalid` are stored as a conclusion of the final node.
    if status >= Status::Invalid {
        // SAFETY: `path` always holds at least the root, and every pointer in it is valid
        // (see above). No other reference to the leaf is live.
        let leaf = unsafe { &mut **path.last().unwrap() };
        leaf.conclusion = Some(status);
    }
    for &depth in kill_depths {
        if depth < path.len() {
            // SAFETY: `depth` was bounds-checked against `path`, and its pointer is valid.
            // The reference only writes `is_exhausted` and ends with this iteration.
            let node = unsafe { &mut *path[depth] };
            node.is_exhausted = true;
        }
    }
    // Deepest node first, so each parent sees its children's up-to-date flags.
    while let Some(p) = path.pop() {
        // SAFETY: the pointer is valid (see above). Pointers to this node's descendants
        // have already been popped and are never used again, so `check_exhausted`
        // re-borrowing the children through `values_mut` does not overlap with them.
        let node = unsafe { &mut *p };
        node.check_exhausted();
    }
}
/// Bound passed to `ChoiceKind::enumerate` when `pick_non_exhausted_value` falls back from
/// random sampling to enumeration.
const ENUMERATION_CAP: u64 = 1024;

/// Picks a value for a choice of `kind` that does not lead into an exhausted subtree.
///
/// A value is acceptable when it has no entry in `children` yet, or when its child is not
/// marked `is_exhausted`.
///
/// The search runs in two stages:
/// 1. Draw up to 10 values with `kind.random_value(rng)` and return the first acceptable
///    one.
/// 2. If none is found, enumerate candidates with `kind.enumerate(ENUMERATION_CAP)`,
///    keep the acceptable ones, shuffle them with `rng`, and return the first.
///
/// Returns `None` when `enumerate` returns `None`, or when no enumerated value is
/// acceptable. NOTE(review): presumably `enumerate` yields `None` for kinds with more than
/// `ENUMERATION_CAP` values; confirm against `ChoiceKind`.
fn pick_non_exhausted_value(
    kind: &ChoiceKind,
    children: &HashMap<ChoiceValueKey, Box<DataTreeNode>>,
    rng: &mut SmallRng,
) -> Option<ChoiceValue> {
    // Anything that does not map to an already-exhausted child is still worth trying.
    let still_open = |candidate: &ChoiceValue| -> bool {
        !matches!(
            children.get(&ChoiceValueKey::from(candidate)),
            Some(child) if child.is_exhausted
        )
    };

    // Cheap path: a few random draws usually land on something open.
    for _attempt in 0..10 {
        let candidate = kind.random_value(rng);
        if still_open(&candidate) {
            return Some(candidate);
        }
    }

    // Slow path: bounded enumeration, then a random pick among the open values.
    let mut open_values: Vec<ChoiceValue> = kind
        .enumerate(ENUMERATION_CAP)?
        .into_iter()
        .filter(|candidate| still_open(candidate))
        .collect();
    if open_values.is_empty() {
        return None;
    }
    open_values.shuffle(rng);
    open_values.into_iter().next()
}
/// Walks the tree from `tree_root` and returns a prefix of choice values that leads towards
/// unexplored territory.
///
/// At each node whose kind is known, a value is chosen with `pick_non_exhausted_value` and
/// appended to the prefix. The walk then descends into that value's child if the child
/// exists and is not exhausted. It stops when:
/// - the chosen value has no child yet (the prefix is novel);
/// - the current node has no recorded kind;
/// - no non-exhausted value can be picked.
///
/// Returns an empty prefix when the whole tree is already exhausted.
pub(crate) fn generate_novel_prefix(
    tree_root: &DataTreeNode,
    rng: &mut SmallRng,
) -> Vec<ChoiceValue> {
    if tree_root.is_exhausted {
        return Vec::new();
    }
    let mut prefix = Vec::new();
    let mut current = tree_root;
    while let Some(ref kind) = current.kind {
        // Fixed: this argument was a mis-encoded `¤t.children` (HTML entity mangling of
        // `&current.children`), which did not compile.
        let Some(value) = pick_non_exhausted_value(kind, &current.children, rng) else {
            break;
        };
        let key = ChoiceValueKey::from(&value);
        let next = current.children.get(&key);
        prefix.push(value);
        match next {
            // Keep descending only through subtrees that still have something to explore.
            Some(child) if !child.is_exhausted => current = child,
            // Either an unseen value (novel from here on) or a dead end: stop.
            _ => break,
        }
    }
    prefix
}
/// Derives a namespaced database key: `database_key`, then a `.` separator, then `sub`.
///
/// For example, `sub_key(b"abc", b"x")` yields `b"abc.x"`.
pub(crate) fn sub_key(database_key: &[u8], sub: &[u8]) -> Vec<u8> {
    [database_key, b".".as_slice(), sub].concat()
}
#[cfg(test)]
#[path = "../../tests/embedded/native/data_tree_tests.rs"]
mod tests;