yada_mod 0.4.0

Yada is yet another double-array trie library aiming for fast search and compact data representation. This fork adds a tokenization function.
Documentation
pub mod builder;
pub mod unit;

use crate::unit::{Unit, UnitID, UNIT_SIZE};
use std::convert::TryInto;
use std::ops::Deref;

/// A double-array trie stored in a byte buffer.
///
/// The wrapped buffer `T` only has to dereference to `[u8]`, so owned
/// buffers (`Vec<u8>`) and borrowed ones (`&[u8]`) are both accepted.
/// The buffer is exposed as the public tuple field `0`.
#[derive(Clone)]
pub struct DoubleArray<T: Deref<Target = [u8]>>(pub T);

impl<T> DoubleArray<T>
where
    T: Deref<Target = [u8]>,
{
	/// Tokenizes `input` into dictionary matches and the unmatched text between them.
	///
	/// `input` is treated as words separated by single ASCII spaces. At the start
	/// of every word the longest stored key that ends on a word boundary (a space
	/// or the end of `input`) becomes a candidate. Overlapping candidates are then
	/// pruned, preferring the higher-valued ones, and every stretch of text not
	/// covered by a surviving candidate is emitted as a token with value `0`.
	///
	/// # Returns
	///
	/// Tokens ordered by start position. Each tuple is
	/// `(text, start, end, value)` where `start..end` is the byte range of `text`
	/// inside `input` and `value` is the value stored for the key, or `0` for
	/// unmatched text.
	pub fn get_all_tokens(&self, input: &str) -> Vec<(String, usize, usize, u32)> {
		type Token = (String, usize, usize, u32);

		/// Builds the zero-valued token for the unmatched text in `input[from..to]`.
		///
		/// `from` is either 0 or the end of the previous match (which sits on a
		/// separator space); `before_token` says whether `to` is the start of a
		/// match (so the byte before it is a separator) or the end of `input`.
		fn gap_token(input: &str, from: usize, to: usize, before_token: bool) -> Option<Token> {
			let gap = &input[from..to];
			if gap.is_empty() || gap == " " {
				return None;
			}
			// Only skip a leading separator when there is one: at the very start
			// of the input the gap begins with real text.
			let start = if from == 0 { 0 } else { from + 1 };
			let end = if before_token { to - 1 } else { to };
			Some((input[start..end].to_string(), start, end, 0))
		}

		let bytes = input.as_bytes();

		// Collect the longest match at the start of the input and after every space.
		let mut candidates: Vec<Token> = Vec::new();
		let word_starts = std::iter::once(0).chain(input.match_indices(' ').map(|(pos, _)| pos + 1));
		for start in word_starts {
			let (value, len) = match self.common_prefix_search(&input[start..]).last() {
				Some(hit) => hit,
				None => continue,
			};
			let end = start + len;
			let previous_end = candidates.last().map_or(0, |c| c.2);
			// Offsets are byte offsets, so the boundary test must look at bytes.
			// It also guarantees `end` is a char boundary before we slice.
			let on_word_boundary = bytes.get(end).map_or(true, |&b| b == b' ');
			if end > previous_end && on_word_boundary {
				candidates.push((input[start..end].to_string(), start, end, value));
			}
		}

		// Resolve overlaps inside every run of three neighbouring candidates.
		let mut doomed: Vec<usize> = Vec::new();
		for (i, w) in candidates.windows(3).enumerate() {
			let (left, mid, right) = (&w[0], &w[1], &w[2]);
			let left_hits_mid = left.2 > mid.1;
			let mid_hits_right = mid.2 > right.1;
			// Widened to u64 so the sum cannot overflow.
			let neighbours_outscore_mid =
				u64::from(mid.3) < u64::from(left.3) + u64::from(right.3);
			// Indices are pushed in non-decreasing order, so checking the last
			// entry is equivalent to searching the whole list for `i`.
			let left_already_doomed = doomed.last() == Some(&i);
			if neighbours_outscore_mid && left_hits_mid && mid_hits_right && !left_already_doomed {
				doomed.push(i + 1);
			} else if left_hits_mid {
				doomed.push(if left.3 >= mid.3 { i + 1 } else { i });
			}
		}
		doomed.dedup();
		for &index in doomed.iter().rev() {
			candidates.remove(index);
		}

		// The last pair is not the leading pair of any window; settle it here.
		let n = candidates.len();
		if n >= 2 && candidates[n - 1].1 < candidates[n - 2].2 {
			if candidates[n - 2].3 >= candidates[n - 1].3 {
				candidates.remove(n - 1);
			} else {
				candidates.remove(n - 2);
			}
		}

		// Removing a middle candidate can leave its neighbours overlapping.
		// Sweep once more so the survivors are strictly non-overlapping; the
		// gap slicing below relies on that.
		let mut kept: Vec<Token> = Vec::with_capacity(candidates.len());
		for cand in candidates {
			let clashes = kept.last().map_or(false, |prev| cand.1 < prev.2);
			if !clashes {
				kept.push(cand);
			} else if let Some(prev) = kept.last_mut() {
				if cand.3 > prev.3 {
					*prev = cand;
				}
			}
		}

		// Interleave the matches with the unmatched text around them.
		let mut tokens: Vec<Token> = Vec::with_capacity(kept.len() * 2 + 1);
		let mut cursor = 0;
		for tok in kept {
			tokens.extend(gap_token(input, cursor, tok.1, true));
			cursor = tok.2;
			tokens.push(tok);
		}
		tokens.extend(gap_token(input, cursor, input.len(), false));
		tokens
	}
	/// Creates a new `DoubleArray` with a byte slice.
    pub fn new(bytes: T) -> Self {
        Self { 0: bytes }
    }

    /// Looks up the value stored for exactly `key`.
    ///
    /// Returns `None` when `key` is not stored.
    ///
    /// # Panics
    ///
    /// Panics if `key` contains a NUL byte, or if an internal consistency
    /// assertion on the trie data fails during the lookup.
    pub fn exact_match_search<K>(&self, key: K) -> Option<u32>
    where
        K: AsRef<[u8]>,
    {
        let key_bytes: &[u8] = key.as_ref();
        self.exact_match_search_bytes(key_bytes)
    }

    /// Byte-slice implementation behind `exact_match_search`.
    ///
    /// Walks the trie from the root one key byte at a time and, if the final
    /// node carries a leaf, returns the value stored in that leaf.
    ///
    /// # Panics
    ///
    /// Panics if `key` contains a NUL byte, if a leaf unit is reached before
    /// the key is exhausted, or if the final unit is not a leaf holding a
    /// value below `1 << 31`.
    fn exact_match_search_bytes(&self, key: &[u8]) -> Option<u32> {
        // Start at the root unit.
        let mut current_id: UnitID = 0;
        let mut current = self.get_unit(current_id)?;

        for &byte in key {
            assert!(!current.is_leaf());
            // NUL is reserved for the transition to a leaf unit.
            assert_ne!(byte, 0);

            // Follow the transition labelled `byte`.
            current_id = (current.offset() ^ current_id as u32 ^ byte as u32) as UnitID;
            current = self.get_unit(current_id)?;

            // The slot belongs to this path only if its label matches.
            if current.label() as u8 != byte {
                return None;
            }
        }

        if !current.has_leaf() {
            return None;
        }

        // The leaf is reached through the NUL transition (XOR with 0).
        let leaf_id = (current.offset() ^ current_id as u32) as UnitID;
        let leaf = self.get_unit(leaf_id)?;
        assert!(leaf.is_leaf());
        assert!(leaf.value() < (1 << 31));

        Some(leaf.value())
    }

    /// Enumerates every stored key that is a prefix of `key`.
    ///
    /// Each item is `(value, length)`: the value stored for the matching key
    /// and that key's length in bytes. Matches are produced in order of
    /// increasing length.
    pub fn common_prefix_search<'b, K>(
        &'b self,
        key: &'b K,
    ) -> impl Iterator<Item = (u32, usize)> + 'b
    where
        K: AsRef<[u8]> + ?Sized,
    {
        let key_bytes: &'b [u8] = key.as_ref();
        self.common_prefix_search_bytes(key_bytes)
    }

    fn common_prefix_search_bytes<'b>(
        &'b self,
        key: &'b [u8],
    ) -> impl Iterator<Item = (u32, usize)> + 'b {
        CommonPrefixSearch {
            key,
            double_array: self,
            unit_id: 0,
            key_pos: 0,
        }
    }

    /// Reads the unit stored at position `index` of the underlying buffer.
    ///
    /// Units are `UNIT_SIZE` bytes wide and decoded as little-endian `u32`.
    /// Returns `None` when the unit does not lie entirely inside the buffer,
    /// so a truncated or corrupted buffer ends a search instead of reading
    /// out of bounds.
    fn get_unit(&self, index: usize) -> Option<Unit> {
        let start = index.checked_mul(UNIT_SIZE)?;
        let end = start.checked_add(UNIT_SIZE)?;
        // Checked access: the buffer is caller-supplied, so the index derived
        // from its contents cannot be trusted to be in range.
        let bytes: [u8; 4] = self.0.get(start..end)?.try_into().ok()?;
        Some(Unit::from_u32(u32::from_le_bytes(bytes)))
    }
}

/// Iterator over all stored keys that are prefixes of a search key.
///
/// Yields `(value, length)` pairs; see the `Iterator` implementation.
pub struct CommonPrefixSearch<'k, 'd, T: Deref<Target = [u8]>> {
    /// The search key being consumed byte by byte.
    key: &'k [u8],
    /// The trie being traversed.
    double_array: &'d DoubleArray<T>,
    /// Position of the unit reached so far.
    unit_id: UnitID,
    /// Number of key bytes consumed so far.
    key_pos: usize,
}

impl<T> Iterator for CommonPrefixSearch<'_, '_, T>
where
    T: Deref<Target = [u8]>,
{
    type Item = (u32, usize);

    /// Advances through the key until the next stored prefix is found.
    ///
    /// Returns `Some((value, length))` for that prefix, or `None` once the
    /// key is exhausted, a transition does not exist, or a unit cannot be
    /// read.
    fn next(&mut self) -> Option<Self::Item> {
        loop {
            if self.key_pos >= self.key.len() {
                return None;
            }

            let parent = self.double_array.get_unit(self.unit_id)?;

            let byte = *self.key.get(self.key_pos)?;
            self.key_pos += 1;

            // Follow the transition labelled `byte`.
            self.unit_id = (parent.offset() ^ self.unit_id as u32 ^ byte as u32) as UnitID;
            let child = self.double_array.get_unit(self.unit_id)?;
            if child.label() != byte as u32 {
                return None;
            }

            // Keep consuming bytes until a node that terminates a key.
            if !child.has_leaf() {
                continue;
            }

            let leaf_id = (child.offset() ^ self.unit_id as u32) as UnitID;
            let leaf = self.double_array.get_unit(leaf_id)?;
            return Some((leaf.value(), self.key_pos));
        }
    }
}

#[cfg(test)]
mod tests {
    use crate::builder::DoubleArrayBuilder;
    use crate::DoubleArray;

    /// Builds a trie from a small keyset and checks exact and prefix lookups.
    #[test]
    fn test_build_search() {
        let keyset = &[
            ("a".as_bytes(), 0),
            ("ab".as_bytes(), 1),
            ("aba".as_bytes(), 2),
            ("ac".as_bytes(), 3),
            ("acb".as_bytes(), 4),
            ("acc".as_bytes(), 5),
            ("ad".as_bytes(), 6),
            ("ba".as_bytes(), 7),
            ("bb".as_bytes(), 8),
            ("bc".as_bytes(), 9),
            ("c".as_bytes(), 10),
            ("caa".as_bytes(), 11),
        ];

        let built = DoubleArrayBuilder::build(keyset);
        assert!(built.is_some());

        let da = DoubleArray::new(built.unwrap());

        // Every inserted key must map back to its value.
        for (key, value) in keyset {
            assert_eq!(da.exact_match_search(key), Some(*value as u32));
        }

        // Keys that were never inserted must not be found.
        for absent in &["aa", "abc", "b", "ca"] {
            assert_eq!(da.exact_match_search(absent.as_bytes()), None);
        }

        // Query paired with the expected `(value, length)` matches.
        let prefix_cases: &[(&str, &[(u32, usize)])] = &[
            ("a", &[(0, 1)]),
            ("aa", &[(0, 1)]),
            ("abbb", &[(0, 1), (1, 2)]),
            ("abaa", &[(0, 1), (1, 2), (2, 3)]),
            ("caa", &[(10, 1), (11, 3)]),
            ("d", &[]),
        ];
        for (query, expected) in prefix_cases {
            let found: Vec<(u32, usize)> = da.common_prefix_search(query.as_bytes()).collect();
            assert_eq!(found, expected.to_vec());
        }
    }

    /// Checks that a cloned trie answers exact and prefix lookups like the original.
    #[test]
    fn test_clone_and_search() {
        let keyset = &[
            ("a".as_bytes(), 0),
            ("ab".as_bytes(), 1),
            ("aba".as_bytes(), 2),
            ("ac".as_bytes(), 3),
            ("acb".as_bytes(), 4),
            ("acc".as_bytes(), 5),
            ("ad".as_bytes(), 6),
            ("ba".as_bytes(), 7),
            ("bb".as_bytes(), 8),
            ("bc".as_bytes(), 9),
            ("c".as_bytes(), 10),
            ("caa".as_bytes(), 11),
        ];

        let built = DoubleArrayBuilder::build(keyset);
        assert!(built.is_some());

        let da_orig = DoubleArray::new(built.unwrap());
        let da = da_orig.clone();

        // Every inserted key must map back to its value.
        for (key, value) in keyset {
            assert_eq!(da.exact_match_search(key), Some(*value as u32));
        }

        // Keys that were never inserted must not be found.
        for absent in &["aa", "abc", "b", "ca"] {
            assert_eq!(da.exact_match_search(absent.as_bytes()), None);
        }

        // Query paired with the expected `(value, length)` matches.
        let prefix_cases: &[(&str, &[(u32, usize)])] = &[
            ("a", &[(0, 1)]),
            ("aa", &[(0, 1)]),
            ("abbb", &[(0, 1), (1, 2)]),
            ("abaa", &[(0, 1), (1, 2), (2, 3)]),
            ("caa", &[(10, 1), (11, 3)]),
            ("d", &[]),
        ];
        for (query, expected) in prefix_cases {
            let found: Vec<(u32, usize)> = da.common_prefix_search(query.as_bytes()).collect();
            assert_eq!(found, expected.to_vec());
        }
    }
}