/// Byte value of `.`, the pattern byte that `TST::keys_that_match` treats as a
/// wildcard for any single key byte.
const DOT: usize = b'.' as usize;
/// An owned, optional pointer to a child node; `None` marks an absent subtree.
type Link<T> = Option<Box<Node<T>>>;
/// One node of the ternary search trie; it stands for a single key byte.
pub struct Node<T> {
/// Value stored for the key that ends at this node; `None` when no key ends here.
pub val: Option<T>,
/// The key byte this node represents, widened to `usize`.
c: usize,
/// Subtree for keys whose byte at this position is smaller than `c`.
left: Link<T>,
/// Subtree for keys that match `c` here and continue with the next byte position.
mid: Link<T>,
/// Subtree for keys whose byte at this position is greater than `c`.
right: Link<T>,
}
impl<T> Node<T> {
pub fn new(c: usize) -> Link<T> {
Some(Box::new(Node {
c,
val: None,
left: None,
mid: None,
right: None,
}))
}
fn left_wrap(&self) -> (&Link<T>, bool) {
(&self.left, false)
}
fn right_wrap(&self) -> (&Link<T>, bool) {
(&self.right, false)
}
fn mid_wrap(&self) -> (&Link<T>, bool) {
(&self.mid, true)
}
fn has_val(&self) -> bool {
self.val.is_some()
}
}
/// Ternary search trie mapping string keys to values of type `T`.
///
/// Keys are handled byte by byte; each node holds one byte and up to three
/// children (smaller byte, next position, greater byte).
pub struct TST<T> {
    /// Root of the trie; `None` for a freshly created trie.
    root: Link<T>,
}

/// Written by hand rather than derived: `#[derive(Default)]` would add a
/// `T: Default` bound, although an empty trie never needs to create a `T`.
impl<T> Default for TST<T> {
    /// Creates an empty trie.
    fn default() -> Self {
        TST { root: None }
    }
}
impl<T> TST<T> {
    /// Returns the node at which `key` ends, if the trie contains that path.
    ///
    /// A node is returned for every stored key and also for every proper
    /// prefix of a stored key; inspect the node's `val` to tell them apart.
    /// The empty key is never stored, so `get("")` is always `None`.
    pub fn get<'a>(&'a self, key: &str) -> Option<&'a Node<T>> {
        Self::get_in(&self.root, key, 0)
    }

    /// Returns a reference to the value stored under `key`, or `None` when
    /// `key` is empty or has no value associated with it.
    pub fn get_val<'a>(&'a self, key: &str) -> Option<&'a T> {
        self.get(key).and_then(|node| node.val.as_ref())
    }

    /// Walks down from `root`, matching `key` from byte offset `d` onwards,
    /// and returns the node that matches the last byte of `key`.
    ///
    /// Returns `None` when there is nothing left to match (`d >= key.len()`)
    /// or when the path is missing.
    fn get_in<'a>(root: &'a Link<T>, key: &str, d: usize) -> Option<&'a Node<T>> {
        use std::cmp::Ordering;

        if d >= key.len() {
            return None;
        }
        let last = key.len() - 1;
        let mut depth = d;
        let mut link = root;
        while let Some(node) = link.as_deref() {
            match Self::char_at(key, depth).cmp(&node.c) {
                Ordering::Less => link = &node.left,
                Ordering::Greater => link = &node.right,
                Ordering::Equal if depth == last => return Some(node),
                Ordering::Equal => {
                    depth += 1;
                    link = &node.mid;
                }
            }
        }
        None
    }

    /// Byte of `s` at offset `d`, widened to `usize` (the type nodes store).
    fn char_at(s: &str, d: usize) -> usize {
        usize::from(s.as_bytes()[d])
    }

    /// Rebuilds a key from the bytes collected on the way to a valued node.
    ///
    /// Keys only enter the trie as `&str`, so the bytes leading to a valued
    /// node are valid UTF-8; the lossy conversion merely avoids a panic path.
    fn key_from(path: &[u8]) -> String {
        String::from_utf8_lossy(path).into_owned()
    }

    /// Associates `val` with `key`, replacing any previous value.
    ///
    /// # Panics
    ///
    /// Panics if `key` is empty; the trie cannot represent the empty key.
    pub fn put(&mut self, key: &str, val: T) {
        self.root = Self::put_in(self.root.take(), key, val, 0);
    }

    /// Inserts the part of `key` starting at byte offset `d` into the subtree
    /// `root` and returns the updated subtree.
    ///
    /// # Panics
    ///
    /// Panics if `d` is not a valid byte offset into `key` (which includes
    /// every call with an empty `key`).
    pub fn put_in(root: Link<T>, key: &str, val: T, d: usize) -> Link<T> {
        use std::cmp::Ordering;

        assert!(
            d < key.len(),
            "TST keys must be non-empty; byte offset {} is out of range for {:?}",
            d,
            key
        );
        let c = Self::char_at(key, d);
        let mut link = root.or_else(|| Node::new(c));
        if let Some(node) = link.as_deref_mut() {
            match c.cmp(&node.c) {
                Ordering::Less => node.left = Self::put_in(node.left.take(), key, val, d),
                Ordering::Greater => node.right = Self::put_in(node.right.take(), key, val, d),
                // Last byte of the key: this is where the value lives.
                Ordering::Equal if d + 1 == key.len() => node.val = Some(val),
                Ordering::Equal => node.mid = Self::put_in(node.mid.take(), key, val, d + 1),
            }
        }
        link
    }

    /// Returns every key stored in the trie.
    ///
    /// Keys come out in traversal order (a node's own key, then its left,
    /// right and middle subtrees), which is not sorted.
    pub fn keys(&self) -> Vec<String> {
        self.keys_with_prefix("")
    }

    /// Returns every stored key that starts with `prefix` (including `prefix`
    /// itself when it is a key). An empty `prefix` yields all keys.
    pub fn keys_with_prefix(&self, prefix: &str) -> Vec<String> {
        let mut found = Vec::new();
        let mut path = prefix.as_bytes().to_vec();
        if prefix.is_empty() {
            Self::collect(self.root.as_deref(), &mut path, &mut found);
        } else if let Some(node) = Self::get_in(&self.root, prefix, 0) {
            if node.has_val() {
                found.push(prefix.to_string());
            }
            Self::collect(node.mid.as_deref(), &mut path, &mut found);
        }
        found
    }

    /// Returns every stored key that has the same byte length as `pat` and
    /// agrees with it at every position, where a `.` in `pat` matches any
    /// single byte.
    ///
    /// Matching is byte-wise: a character that takes several bytes in UTF-8
    /// needs one `.` per byte.
    pub fn keys_that_match(&self, pat: &str) -> Vec<String> {
        let mut found = Vec::new();
        let mut path = Vec::with_capacity(pat.len());
        Self::collect_by_match(self.root.as_deref(), &mut path, pat.as_bytes(), &mut found);
        found
    }

    /// Recursive worker for `keys_that_match`.
    ///
    /// `path` holds the bytes matched so far, so `path.len()` is the pattern
    /// position that `node` has to match. `path` is restored before returning.
    fn collect_by_match(
        node: Option<&Node<T>>,
        path: &mut Vec<u8>,
        pat: &[u8],
        found: &mut Vec<String>,
    ) {
        let node = match node {
            Some(node) => node,
            None => return,
        };
        let depth = path.len();
        if depth >= pat.len() {
            return;
        }
        let wanted = usize::from(pat[depth]);
        let wildcard = wanted == DOT;
        let is_last = depth + 1 == pat.len();
        let matches_here = wildcard || wanted == node.c;
        // Node bytes originate from `u8` key bytes, so narrowing is lossless.
        let byte = node.c as u8;

        if matches_here && is_last && node.has_val() {
            path.push(byte);
            found.push(Self::key_from(path.as_slice()));
            path.pop();
        }
        // Siblings compete for the same pattern position; `path` is unchanged.
        if wildcard || wanted < node.c {
            Self::collect_by_match(node.left.as_deref(), path, pat, found);
        }
        if wildcard || wanted > node.c {
            Self::collect_by_match(node.right.as_deref(), path, pat, found);
        }
        // The middle child continues the key, so this node's byte joins the path.
        if matches_here && !is_last {
            path.push(byte);
            Self::collect_by_match(node.mid.as_deref(), path, pat, found);
            path.pop();
        }
    }

    /// Appends to `found` every key stored in the subtree `node`.
    ///
    /// `path` holds the key bytes that lead to this subtree and is restored
    /// before returning.
    fn collect(node: Option<&Node<T>>, path: &mut Vec<u8>, found: &mut Vec<String>) {
        let node = match node {
            Some(node) => node,
            None => return,
        };
        // Node bytes originate from `u8` key bytes, so narrowing is lossless.
        let byte = node.c as u8;
        if node.has_val() {
            path.push(byte);
            found.push(Self::key_from(path.as_slice()));
            path.pop();
        }
        for (child, consumes_byte) in [node.left_wrap(), node.right_wrap(), node.mid_wrap()] {
            if consumes_byte {
                path.push(byte);
            }
            Self::collect(child.as_deref(), path, found);
            if consumes_byte {
                path.pop();
            }
        }
    }

    /// Returns the longest stored key that is a prefix of `s`, or an empty
    /// string when no stored key is a prefix of `s`.
    pub fn longest_prefix_of(&self, s: &str) -> String {
        let len = Self::search(self.root.as_deref(), s, 0, 0);
        // `len` is the byte length of a stored key that `s` starts with, so it
        // always falls on a character boundary of `s`.
        s[..len].to_string()
    }

    /// Follows `s` from byte offset `d` through the subtree `root` and returns
    /// the byte length of the longest key seen on the way, starting from the
    /// best length known so far, `len`.
    fn search(root: Option<&Node<T>>, s: &str, d: usize, len: usize) -> usize {
        use std::cmp::Ordering;

        let mut best = len;
        let mut depth = d;
        let mut current = root;
        while let Some(node) = current {
            if depth >= s.len() {
                break;
            }
            match Self::char_at(s, depth).cmp(&node.c) {
                Ordering::Less => current = node.left.as_deref(),
                Ordering::Greater => current = node.right.as_deref(),
                Ordering::Equal => {
                    depth += 1;
                    if node.has_val() {
                        best = depth;
                    }
                    current = node.mid.as_deref();
                }
            }
        }
        best
    }

    /// Removes `key` and its value from the trie. Unknown keys and the empty
    /// key are ignored.
    pub fn delete(&mut self, key: &str) {
        self.root = Self::delete_in(self.root.take(), key, 0);
    }

    /// Removes the part of `key` starting at byte offset `d` from the subtree
    /// `root` and returns the updated subtree.
    ///
    /// Nodes on the search path that end up with neither a value nor any
    /// child are dropped.
    fn delete_in(root: Link<T>, key: &str, d: usize) -> Link<T> {
        use std::cmp::Ordering;

        let mut node = root?;
        if d < key.len() {
            match Self::char_at(key, d).cmp(&node.c) {
                Ordering::Less => node.left = Self::delete_in(node.left.take(), key, d),
                Ordering::Greater => node.right = Self::delete_in(node.right.take(), key, d),
                Ordering::Equal if d + 1 == key.len() => node.val = None,
                Ordering::Equal => node.mid = Self::delete_in(node.mid.take(), key, d + 1),
            }
        }
        let still_needed = node.has_val()
            || node.left.is_some()
            || node.mid.is_some()
            || node.right.is_some();
        if still_needed {
            Some(node)
        } else {
            None
        }
    }
}
#[cfg(test)]
mod test {
    use super::TST;

    /// Walks the public API end to end: insertion, key enumeration, prefix
    /// and wildcard queries, longest-prefix lookup and deletion.
    #[test]
    fn test() {
        let mut st: TST<Option<Option<usize>>> = TST::default();
        for word in ["by", "sea", "sells", "she", "shells", "shore", "the"] {
            st.put(word, None);
        }

        println!("{:?}", st.keys());
        println!("{:?}", st.keys_with_prefix("b"));

        // Prefix queries.
        let owned = |s: &str| s.to_string();
        assert!(st.keys_with_prefix("b").contains(&owned("by")));
        let under_s = st.keys_with_prefix("s");
        assert!(under_s.contains(&owned("sea")));
        assert!(under_s.contains(&owned("she")));
        assert_eq!(under_s.len(), 5);
        let exact = st.keys_with_prefix("shells");
        assert!(exact.contains(&owned("shells")));

        // Node lookup: a stored key carries a value, a bare prefix does not.
        assert!(st.get("the").unwrap().val.is_some());
        assert!(st.get("b").unwrap().val.is_none());

        // Wildcard matching.
        assert_eq!(st.keys_that_match("by"), vec![owned("by")]);
        for (pattern, expected) in [("b.", 1), ("s..", 2), ("s....", 2)] {
            assert_eq!(st.keys_that_match(pattern).len(), expected);
        }

        // Longest stored key that prefixes the query.
        assert_eq!(st.longest_prefix_of("shellsaaa"), owned("shells"));
        assert_eq!(st.longest_prefix_of("by"), owned("by"));
        assert_eq!(st.longest_prefix_of("b"), owned(""));

        // Deleting one key, then every remaining key one at a time.
        st.delete("shells");
        assert!(st.keys_that_match("shells").is_empty());
        assert!(!st.keys().is_empty());
        let remaining = st.keys();
        for (removed, key) in remaining.iter().enumerate() {
            st.delete(key);
            assert_eq!(st.keys().len(), remaining.len() - removed - 1);
        }
        assert!(st.keys().is_empty());
    }
}