use crate::agent::Agent;
use crate::base::{NodeOrder, TailMode};
use crate::grimoire::io::{Reader, Writer};
use crate::grimoire::trie::louds_trie::LoudsTrie;
use crate::keyset::Keyset;
/// Public handle around an optional, heap-allocated [`LoudsTrie`].
///
/// A handle starts out without a trie (see `Trie::new`). One is installed by
/// `build`, `mmap`, `map`, `read` or `load`, and removed again by `clear`.
pub struct Trie {
/// The wrapped trie; `None` while nothing has been built or loaded.
trie: Option<Box<LoudsTrie>>,
}
/// `Trie::default()` produces the same unbuilt handle as `Trie::new()`.
impl Default for Trie {
    /// Returns a handle that holds no trie.
    fn default() -> Self {
        Self { trie: None }
    }
}
impl Trie {
pub fn new() -> Self {
Trie { trie: None }
}
pub fn build(&mut self, keyset: &mut Keyset, config_flags: i32) {
let mut temp = Box::new(LoudsTrie::new());
temp.build(keyset, config_flags);
self.trie = Some(temp);
}
pub fn mmap(&mut self, filename: &str) -> std::io::Result<()> {
let mut temp = Box::new(LoudsTrie::new());
temp.mmap(filename)?;
self.trie = Some(temp);
Ok(())
}
pub fn map(&mut self, data: &'static [u8]) -> std::io::Result<()> {
let mut temp = Box::new(LoudsTrie::new());
temp.map(data)?;
self.trie = Some(temp);
Ok(())
}
pub fn load(&mut self, filename: &str) -> std::io::Result<()> {
let mut reader = Reader::open(filename)?;
self.read(&mut reader)
}
pub fn read(&mut self, reader: &mut Reader) -> std::io::Result<()> {
let mut temp = Box::new(LoudsTrie::new());
temp.read(reader)?;
self.trie = Some(temp);
Ok(())
}
pub fn save(&self, filename: &str) -> std::io::Result<()> {
if self.trie.is_none() {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
"Cannot save empty trie (not built)",
));
}
let mut writer = Writer::open(filename)?;
self.write(&mut writer)
}
pub fn write(&self, writer: &mut Writer) -> std::io::Result<()> {
match self.trie.as_ref() {
Some(trie) => trie.write(writer),
None => Err(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
"Cannot write empty trie (not built)",
)),
}
}
pub fn lookup(&self, agent: &mut Agent) -> bool {
let trie = self.trie.as_ref().expect("Trie not built");
if !agent.has_state() {
agent
.init_state()
.expect("Failed to initialize agent state");
}
trie.lookup(agent)
}
pub fn reverse_lookup(&self, agent: &mut Agent) {
let trie = self.trie.as_ref().expect("Trie not built");
if !agent.has_state() {
agent
.init_state()
.expect("Failed to initialize agent state");
}
trie.reverse_lookup(agent);
}
pub fn common_prefix_search(&self, agent: &mut Agent) -> bool {
let trie = self.trie.as_ref().expect("Trie not built");
if !agent.has_state() {
agent
.init_state()
.expect("Failed to initialize agent state");
}
trie.common_prefix_search(agent)
}
pub fn predictive_search(&self, agent: &mut Agent) -> bool {
let trie = self.trie.as_ref().expect("Trie not built");
if !agent.has_state() {
agent
.init_state()
.expect("Failed to initialize agent state");
}
trie.predictive_search(agent)
}
pub fn num_tries(&self) -> usize {
let trie = self.trie.as_ref().expect("Trie not built");
trie.num_tries()
}
pub fn num_keys(&self) -> usize {
let trie = self.trie.as_ref().expect("Trie not built");
trie.num_keys()
}
pub fn num_nodes(&self) -> usize {
let trie = self.trie.as_ref().expect("Trie not built");
trie.num_nodes()
}
pub fn tail_mode(&self) -> TailMode {
let trie = self.trie.as_ref().expect("Trie not built");
trie.tail_mode()
}
pub fn node_order(&self) -> NodeOrder {
let trie = self.trie.as_ref().expect("Trie not built");
trie.node_order()
}
pub fn empty(&self) -> bool {
let trie = self.trie.as_ref().expect("Trie not built");
trie.empty()
}
pub fn size(&self) -> usize {
let trie = self.trie.as_ref().expect("Trie not built");
trie.size()
}
pub fn total_size(&self) -> usize {
let trie = self.trie.as_ref().expect("Trie not built");
trie.total_size()
}
pub fn io_size(&self) -> usize {
let trie = self.trie.as_ref().expect("Trie not built");
trie.io_size()
}
pub fn clear(&mut self) {
self.trie = None;
}
pub fn swap(&mut self, other: &mut Trie) {
std::mem::swap(&mut self.trie, &mut other.trie);
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed handle holds no trie.
    #[test]
    fn test_trie_new() {
        let handle = Trie::new();
        assert!(handle.trie.is_none());
    }

    /// Building from three keys reports three keys.
    #[test]
    fn test_trie_build() {
        let mut keys = Keyset::new();
        for &word in &["apple", "banana", "cherry"] {
            let _ = keys.push_back_str(word);
        }

        let mut dict = Trie::new();
        dict.build(&mut keys, 0);

        assert_eq!(dict.num_keys(), 3);
    }

    /// Exact lookup finds stored keys (including a key that is a prefix of
    /// another) and rejects an absent key.
    #[test]
    fn test_trie_lookup() {
        let mut keys = Keyset::new();
        for &word in &["app", "apple"] {
            let _ = keys.push_back_str(word);
        }

        let mut dict = Trie::new();
        dict.build(&mut keys, 0);

        let mut agent = Agent::new();

        agent.set_query_str("app");
        assert!(dict.lookup(&mut agent), "Should find 'app'");
        println!(
            "Found app: id={}, str={:?}",
            agent.key().id(),
            String::from_utf8_lossy(agent.key().as_bytes())
        );

        agent.set_query_str("apple");
        assert!(dict.lookup(&mut agent), "Should find 'apple'");
        println!(
            "Found apple: id={}, str={:?}",
            agent.key().id(),
            String::from_utf8_lossy(agent.key().as_bytes())
        );

        agent.set_query_str("banana");
        assert!(!dict.lookup(&mut agent), "Should not find 'banana'");
    }

    /// Reverse lookup of id 0 yields a non-empty key.
    #[test]
    fn test_trie_reverse_lookup() {
        let mut keys = Keyset::new();
        for &word in &["a", "b"] {
            let _ = keys.push_back_str(word);
        }

        let mut dict = Trie::new();
        dict.build(&mut keys, 0);

        let mut agent = Agent::new();
        agent.set_query_id(0);
        dict.reverse_lookup(&mut agent);

        assert!(agent.key().length() > 0);
    }

    /// Common-prefix search enumerates every stored prefix of the query.
    #[test]
    fn test_trie_common_prefix_search() {
        // Case 1: every key is a prefix of the query, so all three are reported.
        {
            let mut keys = Keyset::new();
            for &word in &["a", "ab", "abc"] {
                let _ = keys.push_back_str(word);
            }

            let mut dict = Trie::new();
            dict.build(&mut keys, 0);

            let mut agent = Agent::new();
            agent.set_query_str("abc");

            let mut matches = 0;
            while dict.common_prefix_search(&mut agent) {
                matches += 1;
                // Safety valve so a misbehaving search cannot loop forever.
                if matches > 10 {
                    break;
                }
            }

            assert_eq!(
                matches, 3,
                "Expected 3 matches (a, ab, abc) but got {}",
                matches
            );
        }

        // Case 2: only "app" is a prefix of "application"; "apple" is not.
        {
            let mut keys = Keyset::new();
            for &word in &["app", "apple"] {
                let _ = keys.push_back_str(word);
            }

            let mut dict = Trie::new();
            dict.build(&mut keys, 0);

            let mut agent = Agent::new();
            agent.set_query_str("application");

            assert!(dict.common_prefix_search(&mut agent));
            assert_eq!(std::str::from_utf8(agent.key().as_bytes()).unwrap(), "app");
            assert!(!dict.common_prefix_search(&mut agent));
        }
    }

    /// Predictive search terminates and reports at most three keys.
    #[test]
    fn test_trie_predictive_search() {
        let mut keys = Keyset::new();
        for &word in &["a", "ab", "ac"] {
            let _ = keys.push_back_str(word);
        }

        let mut dict = Trie::new();
        dict.build(&mut keys, 0);

        let mut agent = Agent::new();
        agent.set_query_str("a");

        let mut matches = 0;
        while dict.predictive_search(&mut agent) {
            matches += 1;
            // Safety valve so a misbehaving search cannot loop forever.
            if matches > 10 {
                break;
            }
        }

        // NOTE(review): this is only an upper bound, so zero matches would also
        // pass; presumably exactly 3 is intended — confirm before tightening.
        assert!(matches <= 3);
    }

    /// `clear` returns a built handle to the unbuilt state.
    #[test]
    fn test_trie_clear() {
        let mut keys = Keyset::new();
        let _ = keys.push_back_str("test");

        let mut dict = Trie::new();
        dict.build(&mut keys, 0);
        dict.clear();

        assert!(dict.trie.is_none());
    }

    /// `swap` exchanges the tries held by two handles.
    #[test]
    fn test_trie_swap() {
        let mut single_key = Keyset::new();
        let _ = single_key.push_back_str("apple");
        let mut first = Trie::new();
        first.build(&mut single_key, 0);

        let mut two_keys = Keyset::new();
        for &word in &["banana", "cherry"] {
            let _ = two_keys.push_back_str(word);
        }
        let mut second = Trie::new();
        second.build(&mut two_keys, 0);

        first.swap(&mut second);

        assert_eq!(first.num_keys(), 2);
        assert_eq!(second.num_keys(), 1);
    }

    /// A trie built from one key is not empty.
    #[test]
    fn test_trie_empty() {
        let mut keys = Keyset::new();
        let _ = keys.push_back_str("test");

        let mut dict = Trie::new();
        dict.build(&mut keys, 0);

        assert!(!dict.empty());
    }

    /// Size accessors report non-zero values for a built trie.
    #[test]
    fn test_trie_sizes() {
        let mut keys = Keyset::new();
        let _ = keys.push_back_str("test");

        let mut dict = Trie::new();
        dict.build(&mut keys, 0);

        assert!(dict.total_size() > 0);
        assert!(dict.io_size() > 0);
    }

    /// Writing to an in-memory buffer and reading it back preserves the keys.
    #[test]
    fn test_trie_write_read() {
        use crate::grimoire::io::{Reader, Writer};

        let mut keys = Keyset::new();
        for &word in &["app", "apple", "application"] {
            keys.push_back_str(word).unwrap();
        }

        let mut original = Trie::new();
        original.build(&mut keys, 0);

        let mut sink = Writer::from_vec(Vec::new());
        original.write(&mut sink).unwrap();
        let bytes = sink.into_inner().unwrap();

        let mut source = Reader::from_bytes(&bytes);
        let mut restored = Trie::new();
        restored.read(&mut source).unwrap();

        assert_eq!(restored.num_keys(), 3);
        assert_eq!(restored.num_nodes(), original.num_nodes());

        let mut agent = Agent::new();
        agent.init_state().unwrap();
        for &word in &["app", "apple", "application"] {
            agent.set_query_str(word);
            assert!(restored.lookup(&mut agent));
        }
    }

    /// Saving to a temporary file and loading it back preserves the keys.
    #[test]
    fn test_trie_save_load() {
        use std::fs;
        use tempfile::NamedTempFile;

        let mut keys = Keyset::new();
        for &word in &["hello", "world"] {
            keys.push_back_str(word).unwrap();
        }

        let mut original = Trie::new();
        original.build(&mut keys, 0);

        // `scratch` must stay alive for as long as `path` is used.
        let scratch = NamedTempFile::new().unwrap();
        let path = scratch.path().to_str().unwrap();
        original.save(path).unwrap();

        let on_disk = fs::metadata(path).unwrap();
        assert!(on_disk.len() > 0);

        let mut restored = Trie::new();
        restored.load(path).unwrap();
        assert_eq!(restored.num_keys(), 2);

        let mut agent = Agent::new();
        agent.init_state().unwrap();
        for &word in &["hello", "world"] {
            agent.set_query_str(word);
            assert!(restored.lookup(&mut agent));
        }
    }

    /// `write` on an unbuilt handle fails with `InvalidInput`.
    #[test]
    fn test_trie_write_empty_error() {
        use crate::grimoire::io::Writer;

        let unbuilt = Trie::new();
        let mut sink = Writer::from_vec(Vec::new());

        let outcome = unbuilt.write(&mut sink);

        assert!(outcome.is_err());
        assert_eq!(outcome.unwrap_err().kind(), std::io::ErrorKind::InvalidInput);
    }

    /// `save` on an unbuilt handle fails with `InvalidInput`.
    #[test]
    fn test_trie_save_empty_error() {
        use tempfile::NamedTempFile;

        let unbuilt = Trie::new();
        let scratch = NamedTempFile::new().unwrap();
        let path = scratch.path().to_str().unwrap();

        let outcome = unbuilt.save(path);

        assert!(outcome.is_err());
        assert_eq!(outcome.unwrap_err().kind(), std::io::ErrorKind::InvalidInput);
    }

    /// Reading 100 zero bytes is rejected with `InvalidData`.
    #[test]
    fn test_trie_read_invalid_header() {
        use crate::grimoire::io::Reader;

        let zeroes = vec![0u8; 100];
        let mut source = Reader::from_bytes(&zeroes);
        let mut dict = Trie::new();

        let outcome = dict.read(&mut source);

        assert!(outcome.is_err());
        assert_eq!(outcome.unwrap_err().kind(), std::io::ErrorKind::InvalidData);
    }

    /// A trie opened through `mmap` answers lookups like the one that was saved.
    #[test]
    fn test_trie_mmap() {
        use tempfile::NamedTempFile;

        let mut keys = Keyset::new();
        for &word in &["apple", "application", "apply"] {
            keys.push_back_str(word).unwrap();
        }

        let mut original = Trie::new();
        original.build(&mut keys, 0);

        // `scratch` must stay alive for as long as `path` is used.
        let scratch = NamedTempFile::new().unwrap();
        let path = scratch.path().to_str().unwrap();
        original.save(path).unwrap();

        let mut mapped = Trie::new();
        mapped.mmap(path).unwrap();

        assert_eq!(mapped.num_keys(), 3);
        assert_eq!(mapped.num_nodes(), original.num_nodes());

        let mut agent = Agent::new();

        // For these two keys the returned key bytes are checked as well.
        for &word in &["apple", "application"] {
            agent.set_query_str(word);
            assert!(mapped.lookup(&mut agent));
            assert_eq!(std::str::from_utf8(agent.key().as_bytes()).unwrap(), word);
        }

        agent.set_query_str("apply");
        assert!(mapped.lookup(&mut agent));

        agent.set_query_str("banana");
        assert!(!mapped.lookup(&mut agent));
    }

    /// `load` and `mmap` of the same file give identical lookup results.
    #[test]
    fn test_trie_mmap_vs_load_equivalence() {
        use tempfile::NamedTempFile;

        let mut keys = Keyset::new();
        for &word in &["test1", "test2", "test3"] {
            keys.push_back_str(word).unwrap();
        }

        let mut original = Trie::new();
        original.build(&mut keys, 0);

        // `scratch` must stay alive for as long as `path` is used.
        let scratch = NamedTempFile::new().unwrap();
        let path = scratch.path().to_str().unwrap();
        original.save(path).unwrap();

        let mut loaded = Trie::new();
        loaded.load(path).unwrap();

        let mut mapped = Trie::new();
        mapped.mmap(path).unwrap();

        assert_eq!(loaded.num_keys(), mapped.num_keys());
        assert_eq!(loaded.num_nodes(), mapped.num_nodes());

        let probes = ["test1", "test2", "test3", "nonexistent"];
        for key in &probes {
            // A fresh agent per trie and per key keeps the two lookups independent.
            let mut loaded_agent = Agent::new();
            let mut mapped_agent = Agent::new();
            loaded_agent.set_query_str(key);
            mapped_agent.set_query_str(key);

            let loaded_hit = loaded.lookup(&mut loaded_agent);
            let mapped_hit = mapped.lookup(&mut mapped_agent);

            assert_eq!(
                loaded_hit, mapped_hit,
                "Lookup result mismatch for key: {}",
                key
            );

            if loaded_hit {
                assert_eq!(
                    loaded_agent.key().as_bytes(),
                    mapped_agent.key().as_bytes(),
                    "Key bytes mismatch for key: {}",
                    key
                );
                assert_eq!(
                    loaded_agent.key().id(),
                    mapped_agent.key().id(),
                    "Key ID mismatch for key: {}",
                    key
                );
            }
        }
    }

    /// `mmap` of a missing path fails with `NotFound`.
    #[test]
    fn test_trie_mmap_file_not_found() {
        let mut dict = Trie::new();

        let outcome = dict.mmap("/nonexistent/file.marisa");

        assert!(outcome.is_err());
        assert_eq!(outcome.unwrap_err().kind(), std::io::ErrorKind::NotFound);
    }
}