algorithms_fourth 0.1.10

用 Rust 实现《算法（第4版）》一书中的算法，作为学习 Rust 的实践。
Documentation
//! 基于三向单词查找树（TST）的符号表。
//! # 限制
//! 仅对由单字节（ASCII）字符组成的键有效。
//! # 使用
//! ```
//!         use algorithms_fourth::search::tst::TST;
//!         let mut st: TST<Option<Option<usize>>> = TST::default();
//!         st.put("by", None);
//!         st.put("sea", None);
//!         st.put("sells", None);
//!         st.put("she", None);
//!         st.put("shells", None);
//!         st.put("shore", None);
//!         st.put("the", None);
//!         println!("{:?}", st.keys());
//!         println!("{:?}", st.keys_with_prefix("b"));
//!         assert!(st.keys_with_prefix("b").contains(&"by".to_string()));
//!         let keys = st.keys_with_prefix("s");
//!         assert!(keys.contains(&"sea".to_string()));
//!         assert!(keys.contains(&"she".to_string()));
//!         assert!(keys.len() == 5);
//!         let keys = st.keys_with_prefix("shells");
//!         assert!(keys.contains(&"shells".to_string()));
//!         assert!(st.get("the").unwrap().val.is_some());
//!         assert!(st.get("b").unwrap().val.is_none());
//!         assert_eq!(st.keys_that_match("by"), vec!["by".to_string()]);
//!         assert_eq!(st.keys_that_match("b.").len(), 1);
//!         assert_eq!(st.keys_that_match("s..").len(), 2);
//!         assert_eq!(st.keys_that_match("s....").len(), 2);
//!         assert_eq!(st.longest_prefix_of("shellsaaa"), "shells".to_string());
//!         assert_eq!(st.longest_prefix_of("by"), "by".to_string());
//!         assert_eq!(st.longest_prefix_of("b"), "".to_string());
//!         st.delete("shells");
//!         assert!(st.keys_that_match("shells").is_empty());
//!         assert!(!st.keys().is_empty());
//!         let mut len = st.keys().len();
//!         for key in st.keys() {
//!             st.delete(&key);
//!             len -= 1;
//!             assert_eq!(st.keys().len(), len);
//!         }
//!         assert!(st.keys().is_empty());
//!
//! ```
/// Wildcard byte for `keys_that_match`: a `.` in the pattern matches any single byte.
const DOT: usize = b'.' as usize;
/// Owning, optional link to a child node; `None` marks an empty subtrie.
type Link<T> = Option<Box<Node<T>>>;
/// One node of the ternary search trie.
///
/// A node holds a single key byte `c`. `left` and `right` lead to nodes for smaller and
/// larger bytes at the same key position; `mid` leads to the next key position.
pub struct Node<T> {
    /// Key byte represented by this node, widened to `usize`.
    c: usize,
    /// Value of the key whose last byte is this node, if such a key is stored.
    pub val: Option<T>,
    /// Subtrie for bytes smaller than `c` at the same position.
    left: Link<T>,
    /// Subtrie for keys that continue after `c`.
    mid: Link<T>,
    /// Subtrie for bytes larger than `c` at the same position.
    right: Link<T>,
}
impl<T> Node<T> {
    /// Allocates an empty node for byte `c` (no value, no children) and returns it
    /// already wrapped as a [`Link`], ready to be stored in a parent slot.
    pub fn new(c: usize) -> Link<T> {
        let node = Node {
            c,
            val: None,
            left: None,
            mid: None,
            right: None,
        };
        Some(Box::new(node))
    }

    /// Borrows the smaller-byte sibling link; the `false` flag records that following
    /// this link does not consume the current byte.
    fn left_wrap(&self) -> (&Link<T>, bool) {
        let consumes_byte = false;
        (&self.left, consumes_byte)
    }

    /// Borrows the larger-byte sibling link; the `false` flag records that following
    /// this link does not consume the current byte.
    fn right_wrap(&self) -> (&Link<T>, bool) {
        let consumes_byte = false;
        (&self.right, consumes_byte)
    }

    /// Borrows the next-position link; the `true` flag records that following this link
    /// consumes the current byte, so callers extend their prefix with it.
    fn mid_wrap(&self) -> (&Link<T>, bool) {
        let consumes_byte = true;
        (&self.mid, consumes_byte)
    }

    /// Whether a value is stored at this node.
    fn has_val(&self) -> bool {
        matches!(self.val, Some(_))
    }
}
/// A string-keyed symbol table implemented as a ternary search trie (TST).
///
/// Keys are stored byte by byte: each node holds one byte, `left`/`right` hold nodes for
/// smaller/larger bytes at the same position and `mid` continues the key.
pub struct TST<T> {
    root: Link<T>,
}

// Implemented by hand: `#[derive(Default)]` would require `T: Default`, although an empty
// trie never has to construct a `T`.
impl<T> Default for TST<T> {
    fn default() -> Self {
        TST { root: None }
    }
}
impl<T> TST<T> {
    /// Returns the node reached by following `key` through the trie.
    ///
    /// The node is returned whether or not a value is stored under `key`, so this also
    /// tells whether `key` is a prefix of some stored key; inspect [`Node::val`] to tell
    /// the two cases apart. Returns `None` when `key` is neither a stored key nor a
    /// prefix of one, and for the empty key (the root node stands for a one-byte key,
    /// not for the empty prefix).
    pub fn get(&self, key: &str) -> Option<&Node<T>> {
        Self::get_in(&self.root, key, 0)
    }

    /// Returns the value stored under `key`, or `None` if there is none
    /// (always `None` for the empty key).
    pub fn get_val(&self, key: &str) -> Option<&T> {
        self.get(key).and_then(|node| node.val.as_ref())
    }

    /// Follows `key[d..]` down from `root` and returns the node holding the last byte of
    /// `key`, whether or not it stores a value.
    ///
    /// Returns `None` if the path runs off the trie or if there is no byte left to match
    /// (`d >= key.len()`).
    fn get_in<'a>(root: &'a Link<T>, key: &str, d: usize) -> Option<&'a Node<T>> {
        use std::cmp::Ordering;

        if d >= key.len() {
            return None;
        }
        let last = key.len() - 1;
        let mut d = d;
        let mut node = root.as_deref()?;
        loop {
            let next = match Self::char_at(key, d).cmp(&node.c) {
                Ordering::Less => &node.left,
                Ordering::Greater => &node.right,
                Ordering::Equal if d == last => return Some(node),
                Ordering::Equal => {
                    d += 1;
                    &node.mid
                }
            };
            node = next.as_deref()?;
        }
    }

    /// Returns byte `d` of `s`, widened to `usize`.
    ///
    /// Keys are handled byte by byte, so a multi-byte UTF-8 character occupies several
    /// consecutive positions.
    ///
    /// # Panics
    /// Panics if `d >= s.len()`.
    fn char_at(s: &str, d: usize) -> usize {
        usize::from(s.as_bytes()[d])
    }

    /// Inserts `val` under `key`, replacing any value previously stored there.
    ///
    /// # Panics
    /// Panics if `key` is empty. The check runs before the trie is touched, so a caught
    /// panic leaves the table intact.
    pub fn put(&mut self, key: &str, val: T) {
        assert!(!key.is_empty(), "TST keys must not be empty");
        self.root = Self::put_in(self.root.take(), key, val, 0)
    }

    /// Inserts `val` under `key` into the subtrie `root`, starting at byte `d`, and returns
    /// the (possibly newly created) subtrie.
    ///
    /// # Panics
    /// Panics if `d >= key.len()` (in particular for an empty `key`).
    pub fn put_in(root: Link<T>, key: &str, val: T, d: usize) -> Link<T> {
        use std::cmp::Ordering;

        let c = Self::char_at(key, d);
        let mut link = root.or_else(|| Node::new(c));
        if let Some(node) = link.as_mut() {
            match c.cmp(&node.c) {
                Ordering::Less => node.left = Self::put_in(node.left.take(), key, val, d),
                Ordering::Greater => node.right = Self::put_in(node.right.take(), key, val, d),
                Ordering::Equal if d + 1 == key.len() => node.val = Some(val),
                Ordering::Equal => node.mid = Self::put_in(node.mid.take(), key, val, d + 1),
            }
        }
        link
    }

    /// Returns every stored key in ascending byte-wise order.
    pub fn keys(&self) -> Vec<String> {
        self.keys_with_prefix("")
    }

    /// Returns every stored key that starts with `prefix`, in ascending byte-wise order.
    /// An empty `prefix` returns all keys.
    pub fn keys_with_prefix(&self, prefix: &str) -> Vec<String> {
        let mut q = Vec::new();
        let mut buf = prefix.as_bytes().to_vec();
        if prefix.is_empty() {
            Self::collect(self.root.as_deref(), &mut buf, &mut q);
        } else if let Some(node) = Self::get_in(&self.root, prefix, 0) {
            // `node` holds the last byte of `prefix`: the prefix itself sorts first,
            // then everything that continues it through `mid`.
            if node.has_val() {
                q.push(prefix.to_string());
            }
            Self::collect(node.mid.as_deref(), &mut buf, &mut q);
        }
        q
    }

    /// Returns every stored key matching `pat`, in ascending byte-wise order.
    ///
    /// A `.` in `pat` matches any single byte; every other byte must match exactly, and
    /// only keys exactly `pat.len()` bytes long can match.
    pub fn keys_that_match(&self, pat: &str) -> Vec<String> {
        let mut q = Vec::new();
        let mut buf = Vec::with_capacity(pat.len());
        Self::collect_by_match(self.root.as_deref(), &mut buf, pat, &mut q);
        q
    }

    /// Appends to `q`, in ascending order, the keys under `root` that match `pat`.
    ///
    /// `prefix` holds the bytes matched so far, so its length is the current position in
    /// `pat`; it is restored before returning.
    fn collect_by_match(
        root: Option<&Node<T>>,
        prefix: &mut Vec<u8>,
        pat: &str,
        q: &mut Vec<String>,
    ) {
        let node = match root {
            Some(node) => node,
            None => return,
        };
        let d = prefix.len();
        // Only reachable for an empty pattern: `mid` is entered only while bytes remain.
        if d >= pat.len() {
            return;
        }
        let p = Self::char_at(pat, d);
        let wildcard = p == DOT;
        if wildcard || p < node.c {
            Self::collect_by_match(node.left.as_deref(), prefix, pat, q);
        }
        if wildcard || p == node.c {
            // `c` originates from a `u8` (see `char_at`), so this cast is lossless.
            prefix.push(node.c as u8);
            if d + 1 == pat.len() {
                // The pattern length decides where a match ends; the value decides validity.
                if node.has_val() {
                    q.push(Self::key_from_bytes(prefix));
                }
            } else {
                Self::collect_by_match(node.mid.as_deref(), prefix, pat, q);
            }
            prefix.pop();
        }
        if wildcard || p > node.c {
            Self::collect_by_match(node.right.as_deref(), prefix, pat, q);
        }
    }

    /// Appends to `q`, in ascending order, every key stored under `root`, each prefixed by
    /// the bytes currently in `prefix`; `prefix` is restored before returning.
    fn collect(root: Option<&Node<T>>, prefix: &mut Vec<u8>, q: &mut Vec<String>) {
        if let Some(node) = root {
            // In-order walk: smaller siblings, this node's own key and its continuations
            // (the byte-consuming `mid` link), then larger siblings.
            for (link, consumes_byte) in [node.left_wrap(), node.mid_wrap(), node.right_wrap()] {
                if consumes_byte {
                    // `c` originates from a `u8` (see `char_at`), so this cast is lossless.
                    prefix.push(node.c as u8);
                    if node.has_val() {
                        q.push(Self::key_from_bytes(prefix));
                    }
                    Self::collect(link.as_deref(), prefix, q);
                    prefix.pop();
                } else {
                    Self::collect(link.as_deref(), prefix, q);
                }
            }
        }
    }

    /// Turns a byte path that ends at a value-holding node back into its key.
    ///
    /// Such a path spells out a key that was inserted as a `&str`, so it is valid UTF-8
    /// and the lossy conversion never substitutes anything; it merely avoids a panic path.
    fn key_from_bytes(bytes: &[u8]) -> String {
        String::from_utf8_lossy(bytes).into_owned()
    }

    /// Returns the longest stored key that is a prefix of `s`, or `""` if there is none.
    pub fn longest_prefix_of(&self, s: &str) -> String {
        let len = Self::search(self.root.as_deref(), s, 0, 0);
        // `len` is 0 or the end of a stored (valid UTF-8) key matching `s` byte for
        // byte, so it lies on a char boundary of `s`.
        s[..len].to_string()
    }

    /// Walks `s[d..]` from `root` and returns the byte length of the longest stored key
    /// that is a prefix of `s`, or `len` if no longer one is found on the way.
    fn search(root: Option<&Node<T>>, s: &str, d: usize, len: usize) -> usize {
        use std::cmp::Ordering;

        let (mut node, mut d, mut len) = (root, d, len);
        while let Some(n) = node {
            if d >= s.len() {
                break;
            }
            node = match Self::char_at(s, d).cmp(&n.c) {
                Ordering::Less => n.left.as_deref(),
                Ordering::Greater => n.right.as_deref(),
                Ordering::Equal => {
                    d += 1;
                    if n.has_val() {
                        len = d;
                    }
                    n.mid.as_deref()
                }
            };
        }
        len
    }

    /// Removes `key` and its value if present, pruning nodes that no longer lead to any
    /// value. Deleting an absent key (including the empty key) leaves the table unchanged.
    pub fn delete(&mut self, key: &str) {
        self.root = Self::delete_in(self.root.take(), key, 0);
    }

    /// Deletes `key[d..]` from the subtrie `root` and returns what remains of it.
    ///
    /// A node is kept while it still stores a value or has any child; otherwise it has
    /// become a useless leaf and is dropped.
    fn delete_in(root: Link<T>, key: &str, d: usize) -> Link<T> {
        use std::cmp::Ordering;

        let mut node = root?;
        if d < key.len() {
            match Self::char_at(key, d).cmp(&node.c) {
                Ordering::Less => node.left = Self::delete_in(node.left.take(), key, d),
                Ordering::Greater => node.right = Self::delete_in(node.right.take(), key, d),
                Ordering::Equal if d + 1 == key.len() => node.val = None,
                Ordering::Equal => node.mid = Self::delete_in(node.mid.take(), key, d + 1),
            }
        }
        let keep = node.has_val()
            || node.left.is_some()
            || node.mid.is_some()
            || node.right.is_some();
        if keep {
            Some(node)
        } else {
            None
        }
    }
}
#[cfg(test)]
mod test {
    use super::TST;

    /// Keys inserted into the table, in insertion order.
    const WORDS: [&str; 7] = ["by", "sea", "sells", "she", "shells", "shore", "the"];

    #[test]
    fn test() {
        let mut st: TST<Option<Option<usize>>> = TST::default();
        for word in WORDS {
            st.put(word, None);
        }
        println!("{:?}", st.keys());
        println!("{:?}", st.keys_with_prefix("b"));

        // Prefix queries.
        assert!(st.keys_with_prefix("b").contains(&"by".to_string()));
        let s_keys = st.keys_with_prefix("s");
        assert!(s_keys.contains(&"sea".to_string()));
        assert!(s_keys.contains(&"she".to_string()));
        assert_eq!(s_keys.len(), 5);
        assert!(st
            .keys_with_prefix("shells")
            .contains(&"shells".to_string()));

        // Node lookup: `get` also reaches prefix nodes that carry no value.
        assert!(st.get("the").unwrap().val.is_some());
        assert!(st.get("b").unwrap().val.is_none());

        // Wildcard matching.
        assert_eq!(st.keys_that_match("by"), vec!["by".to_string()]);
        for (pat, expected) in [("b.", 1), ("s..", 2), ("s....", 2)] {
            assert_eq!(st.keys_that_match(pat).len(), expected);
        }

        // Longest stored prefix.
        for (input, expected) in [("shellsaaa", "shells"), ("by", "by"), ("b", "")] {
            assert_eq!(st.longest_prefix_of(input), expected.to_string());
        }

        // Deletion, one key at a time, until the table is empty.
        st.delete("shells");
        assert!(st.keys_that_match("shells").is_empty());
        let remaining = st.keys();
        assert!(!remaining.is_empty());
        for (i, key) in remaining.iter().enumerate() {
            st.delete(key);
            assert_eq!(st.keys().len(), remaining.len() - i - 1);
        }
        assert!(st.keys().is_empty());
    }
}