use super::Segmentable;
use alloc::string::{String, ToString};
use alloc::vec;
use alloc::vec::Vec;
use core::marker::PhantomData;
use core::{fmt, mem};
/// Incrementally collects `(key, value)` pairs into a radix tree that can later be
/// compiled into a graph with `Builder::build`.
///
/// `Type` is a [`GraphType`] that decides how keys are validated and converted into
/// lookup keys.
#[derive(Debug)]
pub struct Builder<T, Type> {
    // Top-level nodes of the radix tree; `add` merges a new key into any sibling it
    // shares a non-empty prefix with.
    nodes: Vec<Node<T>>,
    ty: PhantomData<Type>,
}

// Written by hand rather than derived: `#[derive(Default)]` would require
// `T: Default` and `Type: Default`, yet an empty builder needs neither.
impl<T, Type> Default for Builder<T, Type> {
    fn default() -> Self {
        Self {
            nodes: Vec::new(),
            ty: PhantomData,
        }
    }
}
impl<'a, T, Type: GraphType<'a>> Builder<T, Type> {
    /// Creates an empty builder.
    pub fn new() -> Self {
        Self {
            nodes: Vec::new(),
            ty: PhantomData,
        }
    }

    /// Inserts `key` with its associated `value` into the radix tree.
    ///
    /// The key is first passed to `Type::validate`, which may rewrite it in place
    /// (for example [`IgnoreCase`] lowercases ASCII letters). The rewritten form is
    /// what gets stored and what is reported in errors.
    ///
    /// # Errors
    ///
    /// The rejected `value` is always handed back inside the error:
    ///
    /// * [`AddError::Empty`] if `key` is empty.
    /// * [`AddError::Invalid`] if `Type::validate` rejects the key.
    /// * [`AddError::Duplicate`] if a value is already stored under the same key. The
    ///   error carries the complete key, not just the fragment that collided.
    pub fn add(&mut self, mut key: String, value: T) -> Result<(), AddError<T>> {
        if key.is_empty() {
            return Err(AddError::Empty(value));
        }
        if !Type::validate(&mut key) {
            return Err(AddError::Invalid(key, value));
        }
        let mut node = Node {
            value: key,
            output: Some(value),
            children: Vec::new(),
        };
        // Key text already matched on the way down. `node.value` keeps only the
        // unmatched suffix, so this is needed to report the full key on a duplicate.
        let mut consumed = String::new();
        let mut siblings = &mut self.nodes;
        loop {
            // Look for a sibling whose key shares a non-empty prefix with the remaining key.
            let closest_node = siblings.iter_mut().enumerate().find_map(|(i, sibling)| {
                let common = prefix(&node.value, &sibling.value);
                if common.is_empty() {
                    None
                } else {
                    Some((i, common))
                }
            });
            let (index, common) = match closest_node {
                Some(result) => result,
                None => {
                    // Nothing overlaps: the remaining key becomes a new sibling here.
                    siblings.push(node);
                    return Ok(());
                }
            };
            if common == siblings[index].value || common == node.value {
                // One key is a prefix of the other (or both are equal).
                let common_len = common.len();
                if node.value == siblings[index].value {
                    // Same key: fill a value-less node, otherwise it is a duplicate.
                    if siblings[index].output.is_none() {
                        siblings[index].output = node.output;
                        return Ok(());
                    }
                    consumed.push_str(&node.value);
                    // `node` is still the node created above (a swap below is always
                    // followed by a push into an empty child list), so its output is Some.
                    let value = node
                        .output
                        .expect("the inserted node always carries its value");
                    return Err(AddError::Duplicate(consumed, value));
                }
                consumed.push_str(common);
                if common == node.value {
                    // The new key is shorter: it takes over the sibling's slot, and the
                    // old sibling is re-inserted beneath it on the next iteration.
                    mem::swap(&mut node, &mut siblings[index]);
                }
                siblings = &mut siblings[index].children;
                node.value = node.value[common_len..].to_string();
                continue;
            }
            // The keys diverge part-way: replace the sibling with a value-less node
            // holding the common prefix and hang both remaining suffixes beneath it.
            let mut sibling = siblings.swap_remove(index);
            let common = common.to_string();
            node.value = node.value[common.len()..].to_string();
            sibling.value = sibling.value[common.len()..].to_string();
            let prefix_node = Node {
                value: common,
                output: None,
                children: vec![sibling, node],
            };
            siblings.push(prefix_node);
            return Ok(());
        }
    }

    /// Compiles the collected keys into a `super::Graph` whose nodes are written to
    /// `node_buffer`.
    ///
    /// `node_buffer` is cleared first. The builder's own tree is also rewritten in
    /// place: sibling keys are cut to a common length (`shorten_children`, recursively
    /// via `Node::normalize`) and sorted by key.
    ///
    /// Buffer layout:
    /// * Index 0 holds a node with no inputs, no output and `amount: usize::MAX`.
    ///   NOTE(review): this is presumably the dead end reached through `default: 0`;
    ///   confirm against `super::Graph`.
    /// * Each top-level subtree is then appended by `Node::build`.
    /// * The root is pushed last, and its index is passed to `super::Graph::new`.
    pub fn build<'nodes>(
        &'a mut self,
        node_buffer: &'nodes mut Vec<super::Node<'a, Type::InputKey, Option<T>>>,
    ) -> super::Graph<'a, 'nodes, Type::InputKey, Option<T>>
    where
        T: Clone,
    {
        node_buffer.clear();
        shorten_children(&mut self.nodes);
        self.nodes.sort_unstable_by(|a, b| a.value.cmp(&b.value));
        for node in &mut self.nodes {
            node.normalize();
        }
        node_buffer.push(super::Node {
            inputs: crate::MaybeSlice::Slice(&[]),
            output: None,
            default: 0,
            amount: core::usize::MAX,
        });
        let initial_indices = self
            .nodes
            .iter()
            .map(|node| {
                let index = node.build::<Type>(node_buffer);
                let value = Type::key(&node.value);
                (value, index)
            })
            .collect::<Vec<_>>();
        // All top-level keys share one byte length after `shorten_children`, so the
        // first key's length is the root's `amount`; an empty builder falls back to 1.
        let amount = initial_indices.first().map_or(1, |(key, _)| key.len());
        let root = super::Node {
            inputs: crate::MaybeSlice::Vec(initial_indices),
            output: None,
            default: 0,
            amount,
        };
        node_buffer.push(root);
        let end = node_buffer.len() - 1;
        super::Graph::new(&*node_buffer, end)
    }
}
/// A node of the builder's radix tree.
#[derive(Debug)]
struct Node<T> {
/// Key fragment for this node, relative to its parent (`Builder::add` strips the
/// parent's prefix before descending).
value: String,
/// Value stored for the key that ends exactly at this node. This is `None` for nodes
/// created only to hold a prefix shared by their children, or for nodes whose output
/// was moved down by `shorten`.
output: Option<T>,
/// Child nodes; `Builder::add` never leaves two siblings sharing a non-empty prefix.
children: Vec<Node<T>>,
}
impl<T: Clone> Node<T> {
    /// Recursively prepares this subtree for `build`: each set of siblings is cut to a
    /// common key length with `shorten_children` and then ordered by key.
    fn normalize(&mut self) {
        shorten_children(&mut self.children);
        self.children.sort_by(|lhs, rhs| lhs.value.cmp(&rhs.value));
        self.children.iter_mut().for_each(Self::normalize);
    }

    /// Cuts this node's key down to its first `len` bytes.
    ///
    /// The cut-off tail becomes the only child and takes over this node's output and
    /// previous children. Nothing happens if the key is already `len` bytes or shorter.
    ///
    /// NOTE(review): `String::split_off` panics when `len` is not on a char boundary.
    /// `shorten_children` passes the shortest sibling's byte length, so under
    /// `Utf8Graph`, siblings like "a" and "é" (first characters of different UTF-8
    /// widths) reach that panic. Confirm how `Segmentable` is meant to handle this.
    // NOTE(review): presumably `mem::replace` is used instead of `mem::take` to keep an
    // older MSRV; confirm.
    #[allow(clippy::mem_replace_with_default)]
    fn shorten(&mut self, len: usize) {
        if self.value.len() <= len {
            return;
        }
        let tail = Node {
            value: self.value.split_off(len),
            output: self.output.take(),
            children: mem::replace(&mut self.children, vec![]),
        };
        self.children.push(tail);
    }

    /// Appends this node's subtree to `nodes` in post-order (all children first, then
    /// this node) and returns the index at which this node itself was stored.
    fn build<'a, 'nodes, Type: GraphType<'a>>(
        &'a self,
        nodes: &'nodes mut Vec<super::Node<'a, Type::InputKey, Option<T>>>,
    ) -> usize {
        let mut inputs = Vec::with_capacity(self.children.len());
        for child in &self.children {
            // Build the child first so its index is known, then derive its lookup key.
            let child_index = child.build::<Type>(nodes);
            inputs.push((Type::key(&child.value), child_index));
        }
        // `normalize` gives all children the same key length, so the first key's
        // length is this node's `amount`; leaves fall back to 1.
        let amount = match inputs.first() {
            Some((key, _)) => key.len(),
            None => 1,
        };
        let own_index = nodes.len();
        nodes.push(super::Node {
            inputs: crate::MaybeSlice::Vec(inputs),
            output: self.output.clone(),
            default: 0,
            amount,
        });
        own_index
    }
}
/// Cuts every node in `children` down to the byte length of the shortest key among
/// them (via `Node::shorten`), so all siblings end up with equally long keys.
/// An empty slice is left untouched.
fn shorten_children<T: Clone>(children: &mut [Node<T>]) {
    if let Some(shortest) = children.iter().map(|node| node.value.len()).min() {
        children.iter_mut().for_each(|node| node.shorten(shortest));
    }
}
/// Strategy that decides how keys are checked, normalised and turned into lookup keys
/// for a [`Builder`] and the graph it produces.
pub trait GraphType<'a> {
/// Key type stored in the compiled graph's node inputs (see `Builder::build`).
type InputKey: super::Segmentable + 'a;
/// Returns `false` to reject `input` as a key (`Builder::add` then reports
/// `AddError::Invalid`). May rewrite the key in place before it is stored; for example,
/// `IgnoreCase` lowercases ASCII letters.
fn validate(input: &mut str) -> bool;
/// Converts a stored key fragment into this graph's lookup key type.
fn key(input: &'a str) -> Self::InputKey;
}
/// [`GraphType`] for arbitrary UTF-8 keys: every key is accepted and looked up as a
/// `&str`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
pub struct Utf8Graph;

impl<'a> GraphType<'a> for Utf8Graph {
    type InputKey = &'a str;

    /// Accepts every key and leaves it unchanged.
    fn validate(_input: &mut str) -> bool {
        true
    }

    /// Uses the key fragment itself as the lookup key.
    fn key(input: &'a str) -> &'a str {
        input
    }
}
/// [`GraphType`] that accepts only ASCII keys and looks them up by their raw bytes.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
pub struct AsciiGraph;

impl<'a> GraphType<'a> for AsciiGraph {
    type InputKey = &'a [u8];

    /// Accepts the key only if every byte is ASCII. The key is not modified.
    fn validate(input: &mut str) -> bool {
        input.bytes().all(|byte| byte.is_ascii())
    }

    /// Views the key fragment as its underlying bytes.
    fn key(input: &'a str) -> &'a [u8] {
        input.as_bytes()
    }
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
pub struct IgnoreCase<Graph>(core::marker::PhantomData<Graph>);
impl<'a, G: GraphType<'a>> GraphType<'a> for IgnoreCase<G>
where
G::InputKey: AsRef<[u8]>,
{
type InputKey = super::CaseInsensitive<G::InputKey>;
fn validate(input: &mut str) -> bool {
input.make_ascii_lowercase();
G::validate(input)
}
fn key(input: &'a str) -> Self::InputKey {
super::CaseInsensitive(G::key(input))
}
}
/// Reasons `Builder::add` can reject a key. Every variant hands the rejected value
/// back to the caller.
#[derive(Debug)]
pub enum AddError<T> {
    /// The key was empty.
    Empty(T),
    /// `GraphType::validate` rejected the key (carried as it was after validation).
    Invalid(String, T),
    /// A value was already stored under this key.
    Duplicate(String, T),
}

impl<T: fmt::Display> fmt::Display for AddError<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Empty(value) => {
                write!(f, "Cannot add an empty key to the graph: {}", value)
            }
            Self::Invalid(key, value) => {
                write!(f, "Cannot add an invalid key to the graph: {} ({})", key, value)
            }
            Self::Duplicate(key, value) => {
                write!(f, "Cannot add a duplicate key to the graph: {} ({})", key, value)
            }
        }
    }
}
// `AddError` is a standard error type whenever the `std` feature is enabled; the
// message comes from its `Display` impl.
#[cfg(feature = "std")]
impl<T> std::error::Error for AddError<T> where T: fmt::Debug + fmt::Display {}
/// Returns the longest common prefix of `a` and `b`, compared character by character,
/// as a slice of `a`. Because whole `char`s are compared, the result always ends on a
/// char boundary.
fn prefix<'a>(a: &'a str, b: &str) -> &'a str {
    let end = a
        .char_indices()
        .zip(b.chars())
        .find(|&((_, from_a), from_b)| from_a != from_b)
        .map(|((byte_index, _), _)| byte_index)
        // No mismatch: the shorter string is a prefix of the longer one.
        .unwrap_or_else(|| a.len().min(b.len()));
    &a[..end]
}