marisa_ffi/
lib.rs

1//! # Marisa FFI Rust Bindings
2//!
3//! This crate provides Rust bindings for the libmarisa library, a space-efficient trie data structure.
4//!
5//! ## Features
6//!
7//! - Lookup: Check whether a string exists in the dictionary
8//! - Reverse lookup: Restore a key from its ID
9//! - Common prefix search: Find keys from prefixes of a given string
10//! - Predictive search: Find keys starting with a given string
11//!
12//! ## Example
13//!
14//! ```rust
15//! use marisa_ffi::{Trie, Keyset, Agent};
16//!
17//! // Create a keyset and add some keys
18//! let mut keyset = Keyset::new().unwrap();
19//! keyset.push("hello").unwrap();
20//! keyset.push("world").unwrap();
21//!
22//! // Build the trie
23//! let trie = Trie::build(&keyset).unwrap();
24//!
25//! // Lookup a key
26//! if let Some(id) = trie.lookup("hello") {
27//!     println!("Found 'hello' with ID: {}", id);
28//! }
29//!
30//! // Predictive search
31//! let mut agent = Agent::new().unwrap();
32//! if trie.predictive_search("h", &mut agent).unwrap() {
33//!     println!("Found key: {}", agent.key().unwrap());
34//! }
35//! ```
36
37use std::ffi::{CStr, CString};
38use std::os::raw::c_int;
39
/// Raw declarations for the C interface this crate wraps.
///
/// Nothing in here is safe to call directly; the public wrapper types own the
/// handles and perform the null / status-code checks.
mod ffi {
    use std::os::raw::c_char;
    use std::os::raw::c_int;

    // Stand-in for C's `FILE`; only ever used behind a raw pointer.
    use std::ffi::c_void as FILE;

    /// Identifier type exchanged with the C side for keys.
    pub type MarisaId = u32;

    /// Opaque trie handle; only used behind raw pointers.
    #[repr(C)]
    pub struct MarisaTrie {
        _opaque: [u8; 0],
    }

    /// Opaque agent handle; only used behind raw pointers.
    #[repr(C)]
    pub struct MarisaAgent {
        _opaque: [u8; 0],
    }

    /// Opaque keyset handle; only used behind raw pointers.
    #[repr(C)]
    pub struct MarisaKeyset {
        _opaque: [u8; 0],
    }

    /// Status codes as the wrappers compare them against returned `c_int` values.
    #[repr(C)]
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub enum MarisaError {
        Ok = 0,
        StateError = 1,
        NullError = 2,
        BoundError = 3,
        RangeError = 4,
        CodeError = 5,
        ResetError = 6,
        SizeError = 7,
        MemoryError = 8,
        IoError = 9,
        FormatError = 10,
    }

    extern "C" {
        // ---- Trie lifecycle ----
        pub fn marisa_create() -> *mut MarisaTrie;
        pub fn marisa_destroy(trie: *mut MarisaTrie);

        // ---- Trie persistence ----
        pub fn marisa_open(trie: *mut MarisaTrie, path: *const c_char) -> c_int;
        pub fn marisa_save(trie: *const MarisaTrie, path: *const c_char) -> c_int;
        pub fn marisa_write(trie: *const MarisaTrie, stream: *mut FILE) -> c_int;
        pub fn marisa_read(trie: *mut MarisaTrie, stream: *mut FILE) -> c_int;
        pub fn marisa_map(trie: *mut MarisaTrie, path: *const c_char) -> c_int;
        pub fn marisa_unmap(trie: *mut MarisaTrie) -> c_int;

        // ---- Trie construction ----
        pub fn marisa_build(trie: *mut MarisaTrie, keyset: *mut MarisaKeyset) -> c_int;
        pub fn marisa_build_trie(
            trie: *mut MarisaTrie,
            keyset: *mut MarisaKeyset,
            trie_mode: c_int,
        ) -> c_int;

        // ---- Queries ----
        pub fn marisa_lookup(
            trie: *const MarisaTrie,
            key: *const c_char,
            len: usize,
            id: *mut MarisaId,
        ) -> c_int;
        pub fn marisa_reverse_lookup(
            trie: *const MarisaTrie,
            id: MarisaId,
            agent: *mut MarisaAgent,
        ) -> c_int;
        pub fn marisa_common_prefix_search(
            trie: *const MarisaTrie,
            query: *const c_char,
            len: usize,
            agent: *mut MarisaAgent,
        ) -> c_int;
        pub fn marisa_predictive_search(
            trie: *const MarisaTrie,
            query: *const c_char,
            len: usize,
            agent: *mut MarisaAgent,
        ) -> c_int;

        // ---- Agent ----
        pub fn marisa_agent_create() -> *mut MarisaAgent;
        pub fn marisa_agent_destroy(agent: *mut MarisaAgent);
        pub fn marisa_agent_key(agent: *const MarisaAgent) -> *const c_char;
        pub fn marisa_agent_key_length(agent: *const MarisaAgent) -> usize;
        pub fn marisa_agent_id(agent: *const MarisaAgent) -> MarisaId;
        pub fn marisa_agent_next(agent: *mut MarisaAgent) -> c_int;

        // ---- Keyset ----
        pub fn marisa_keyset_create() -> *mut MarisaKeyset;
        pub fn marisa_keyset_destroy(keyset: *mut MarisaKeyset);
        pub fn marisa_keyset_reset(keyset: *mut MarisaKeyset);
        pub fn marisa_keyset_push(
            keyset: *mut MarisaKeyset,
            key: *const c_char,
            len: usize,
        ) -> c_int;
        pub fn marisa_keyset_push_back(
            keyset: *mut MarisaKeyset,
            key: *const c_char,
            len: usize,
            id: MarisaId,
        ) -> c_int;

        // ---- Utilities ----
        pub fn marisa_strerror(err: c_int) -> *const c_char;
        pub fn marisa_version() -> *const c_char;
    }
}
150
151// Use specific imports instead of wildcard
152
/// Error type for Marisa operations.
///
/// The variants from `StateError` through `FormatError` each have a same-named
/// counterpart among the non-zero status codes of the private `ffi::MarisaError`
/// enum. `KeyNotFound`, `NoResults` and `Utf8Error` exist only on the Rust side.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MarisaError {
    /// Counterpart of status code 1.
    StateError,
    /// Counterpart of status code 2; also returned by `Agent::key` when the C
    /// side hands back a null key pointer.
    NullError,
    /// Counterpart of status code 3.
    BoundError,
    /// Counterpart of status code 4.
    RangeError,
    /// Counterpart of status code 5.
    CodeError,
    /// Counterpart of status code 6.
    ResetError,
    /// Counterpart of status code 7.
    SizeError,
    /// Counterpart of status code 8; also returned by the `new` constructors
    /// when the C side returns a null handle.
    MemoryError,
    /// Counterpart of status code 9.
    IoError,
    /// Counterpart of status code 10; also used by the wrappers when a string
    /// argument cannot be converted to a `CString`.
    FormatError,
    /// Wrapper-only variant, displayed as "Key not found".
    KeyNotFound,
    /// Wrapper-only variant, displayed as "No results".
    NoResults,
    /// The key bytes reported by an agent were not valid UTF-8.
    Utf8Error(std::str::Utf8Error),
}
170
171impl std::fmt::Display for MarisaError {
172    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
173        match self {
174            MarisaError::StateError => write!(f, "State error"),
175            MarisaError::NullError => write!(f, "Null pointer error"),
176            MarisaError::BoundError => write!(f, "Bound error"),
177            MarisaError::RangeError => write!(f, "Range error"),
178            MarisaError::CodeError => write!(f, "Code error"),
179            MarisaError::ResetError => write!(f, "Reset error"),
180            MarisaError::SizeError => write!(f, "Size error"),
181            MarisaError::MemoryError => write!(f, "Memory error"),
182            MarisaError::IoError => write!(f, "I/O error"),
183            MarisaError::FormatError => write!(f, "Format error"),
184            MarisaError::KeyNotFound => write!(f, "Key not found"),
185            MarisaError::NoResults => write!(f, "No results"),
186            MarisaError::Utf8Error(e) => write!(f, "UTF-8 error: {}", e),
187        }
188    }
189}
190
191impl std::error::Error for MarisaError {}
192
193impl From<MarisaError> for ffi::MarisaError {
194    fn from(err: MarisaError) -> Self {
195        match err {
196            MarisaError::StateError => ffi::MarisaError::StateError,
197            MarisaError::NullError => ffi::MarisaError::NullError,
198            MarisaError::BoundError => ffi::MarisaError::BoundError,
199            MarisaError::RangeError => ffi::MarisaError::RangeError,
200            MarisaError::CodeError => ffi::MarisaError::CodeError,
201            MarisaError::ResetError => ffi::MarisaError::ResetError,
202            MarisaError::SizeError => ffi::MarisaError::SizeError,
203            MarisaError::MemoryError => ffi::MarisaError::MemoryError,
204            MarisaError::IoError => ffi::MarisaError::IoError,
205            MarisaError::FormatError => ffi::MarisaError::FormatError,
206            _ => ffi::MarisaError::FormatError,
207        }
208    }
209}
210
211impl From<ffi::MarisaError> for MarisaError {
212    fn from(err: ffi::MarisaError) -> Self {
213        match err {
214            ffi::MarisaError::Ok => unreachable!(),
215            ffi::MarisaError::StateError => MarisaError::StateError,
216            ffi::MarisaError::NullError => MarisaError::NullError,
217            ffi::MarisaError::BoundError => MarisaError::BoundError,
218            ffi::MarisaError::RangeError => MarisaError::RangeError,
219            ffi::MarisaError::CodeError => MarisaError::CodeError,
220            ffi::MarisaError::ResetError => MarisaError::ResetError,
221            ffi::MarisaError::SizeError => MarisaError::SizeError,
222            ffi::MarisaError::MemoryError => MarisaError::MemoryError,
223            ffi::MarisaError::IoError => MarisaError::IoError,
224            ffi::MarisaError::FormatError => MarisaError::FormatError,
225        }
226    }
227}
228
229/// A Marisa trie for efficient string storage and lookup
230pub struct Trie {
231    ptr: *mut ffi::MarisaTrie,
232}
233
234impl Trie {
235    /// Create a new empty trie
236    pub fn new() -> Result<Self, MarisaError> {
237        let ptr = unsafe { ffi::marisa_create() };
238        if ptr.is_null() {
239            Err(MarisaError::MemoryError)
240        } else {
241            Ok(Trie { ptr })
242        }
243    }
244
245    /// Build a trie from a keyset
246    pub fn build(keyset: &Keyset) -> Result<Self, MarisaError> {
247        let trie = Self::new()?;
248        let result = unsafe { ffi::marisa_build(trie.ptr, keyset.ptr) };
249        if result == ffi::MarisaError::Ok as c_int {
250            Ok(trie)
251        } else {
252            Err(MarisaError::from(unsafe {
253                std::mem::transmute::<c_int, ffi::MarisaError>(result)
254            }))
255        }
256    }
257
258    /// Load a trie from a file
259    pub fn load(&mut self, filename: &str) -> Result<(), MarisaError> {
260        let c_filename = CString::new(filename).map_err(|_| MarisaError::FormatError)?;
261        let result = unsafe { ffi::marisa_open(self.ptr, c_filename.as_ptr()) };
262        if result == ffi::MarisaError::Ok as c_int {
263            Ok(())
264        } else {
265            Err(MarisaError::from(unsafe {
266                std::mem::transmute::<c_int, ffi::MarisaError>(result)
267            }))
268        }
269    }
270
271    /// Save the trie to a file
272    pub fn save(&self, filename: &str) -> Result<(), MarisaError> {
273        let c_filename = CString::new(filename).map_err(|_| MarisaError::FormatError)?;
274        let result = unsafe { ffi::marisa_save(self.ptr, c_filename.as_ptr()) };
275        if result == ffi::MarisaError::Ok as c_int {
276            Ok(())
277        } else {
278            Err(MarisaError::from(unsafe {
279                std::mem::transmute::<c_int, ffi::MarisaError>(result)
280            }))
281        }
282    }
283
284    /// Lookup a key and return its ID if found
285    pub fn lookup(&self, key: &str) -> Option<u32> {
286        let c_key = CString::new(key).ok()?;
287        let mut id = 0;
288        let result = unsafe { ffi::marisa_lookup(self.ptr, c_key.as_ptr(), key.len(), &mut id) };
289        if result == ffi::MarisaError::Ok as c_int {
290            Some(id)
291        } else {
292            None
293        }
294    }
295
296    /// Perform predictive search starting with the given prefix
297    pub fn predictive_search(&self, prefix: &str, agent: &mut Agent) -> Result<bool, MarisaError> {
298        let c_prefix = CString::new(prefix).map_err(|_| MarisaError::FormatError)?;
299        let result = unsafe {
300            ffi::marisa_predictive_search(self.ptr, c_prefix.as_ptr(), prefix.len(), agent.ptr)
301        };
302        match result {
303            x if x == ffi::MarisaError::Ok as c_int => Ok(true),
304            x if x == ffi::MarisaError::FormatError as c_int => Ok(false),
305            _ => Err(MarisaError::from(unsafe {
306                std::mem::transmute::<c_int, ffi::MarisaError>(result)
307            })),
308        }
309    }
310
311    /// Perform reverse lookup to get a key by its ID
312    pub fn reverse_lookup(&self, id: u32, agent: &mut Agent) -> Result<(), MarisaError> {
313        let result = unsafe { ffi::marisa_reverse_lookup(self.ptr, id, agent.ptr) };
314        if result == ffi::MarisaError::Ok as c_int {
315            Ok(())
316        } else {
317            Err(MarisaError::from(unsafe {
318                std::mem::transmute::<c_int, ffi::MarisaError>(result)
319            }))
320        }
321    }
322
323    /// Perform common prefix search
324    pub fn common_prefix_search(
325        &self,
326        prefix: &str,
327        agent: &mut Agent,
328    ) -> Result<bool, MarisaError> {
329        let c_prefix = CString::new(prefix).map_err(|_| MarisaError::FormatError)?;
330        let result = unsafe {
331            ffi::marisa_common_prefix_search(self.ptr, c_prefix.as_ptr(), prefix.len(), agent.ptr)
332        };
333        match result {
334            x if x == ffi::MarisaError::Ok as c_int => Ok(true),
335            x if x == ffi::MarisaError::FormatError as c_int => Ok(false),
336            _ => Err(MarisaError::from(unsafe {
337                std::mem::transmute::<c_int, ffi::MarisaError>(result)
338            })),
339        }
340    }
341}
342
343impl Drop for Trie {
344    fn drop(&mut self) {
345        if !self.ptr.is_null() {
346            unsafe { ffi::marisa_destroy(self.ptr) };
347        }
348    }
349}
350
351/// A keyset for building tries
352pub struct Keyset {
353    ptr: *mut ffi::MarisaKeyset,
354}
355
356impl Keyset {
357    /// Create a new empty keyset
358    pub fn new() -> Result<Self, MarisaError> {
359        let ptr = unsafe { ffi::marisa_keyset_create() };
360        if ptr.is_null() {
361            Err(MarisaError::MemoryError)
362        } else {
363            Ok(Keyset { ptr })
364        }
365    }
366
367    /// Add a key to the keyset
368    pub fn push(&mut self, key: &str) -> Result<(), MarisaError> {
369        let c_key = CString::new(key).map_err(|_| MarisaError::FormatError)?;
370        let result = unsafe { ffi::marisa_keyset_push(self.ptr, c_key.as_ptr(), key.len()) };
371        if result == ffi::MarisaError::Ok as c_int {
372            Ok(())
373        } else {
374            Err(MarisaError::from(unsafe {
375                std::mem::transmute::<c_int, ffi::MarisaError>(result)
376            }))
377        }
378    }
379
380    /// Add a key with a specific ID to the keyset
381    pub fn push_with_id(&mut self, key: &str, id: u32) -> Result<(), MarisaError> {
382        let c_key = CString::new(key).map_err(|_| MarisaError::FormatError)?;
383        let result =
384            unsafe { ffi::marisa_keyset_push_back(self.ptr, c_key.as_ptr(), key.len(), id) };
385        if result == ffi::MarisaError::Ok as c_int {
386            Ok(())
387        } else {
388            Err(MarisaError::from(unsafe {
389                std::mem::transmute::<c_int, ffi::MarisaError>(result)
390            }))
391        }
392    }
393
394    /// Reset the keyset
395    pub fn reset(&mut self) {
396        unsafe { ffi::marisa_keyset_reset(self.ptr) };
397    }
398}
399
400impl Drop for Keyset {
401    fn drop(&mut self) {
402        if !self.ptr.is_null() {
403            unsafe { ffi::marisa_keyset_destroy(self.ptr) };
404        }
405    }
406}
407
408/// An agent for search operations
409pub struct Agent {
410    ptr: *mut ffi::MarisaAgent,
411}
412
413impl Agent {
414    /// Create a new agent
415    pub fn new() -> Result<Self, MarisaError> {
416        let ptr = unsafe { ffi::marisa_agent_create() };
417        if ptr.is_null() {
418            Err(MarisaError::MemoryError)
419        } else {
420            Ok(Agent { ptr })
421        }
422    }
423
424    /// Get the current key as a string
425    pub fn key(&self) -> Result<String, MarisaError> {
426        let c_str = unsafe { ffi::marisa_agent_key(self.ptr) };
427        if c_str.is_null() {
428            return Err(MarisaError::NullError);
429        }
430        let length = unsafe { ffi::marisa_agent_key_length(self.ptr) };
431        let slice = unsafe { std::slice::from_raw_parts(c_str as *const u8, length) };
432        std::str::from_utf8(slice)
433            .map(|s| s.to_string())
434            .map_err(MarisaError::Utf8Error)
435    }
436
437    /// Get the current key ID
438    pub fn id(&self) -> u32 {
439        unsafe { ffi::marisa_agent_id(self.ptr) }
440    }
441
442    /// Move to the next result
443    pub fn next(&mut self) -> Result<bool, MarisaError> {
444        let result = unsafe { ffi::marisa_agent_next(self.ptr) };
445        match result {
446            x if x == ffi::MarisaError::Ok as c_int => Ok(true),
447            x if x == ffi::MarisaError::StateError as c_int => Ok(false),
448            _ => Err(MarisaError::from(unsafe {
449                std::mem::transmute::<c_int, ffi::MarisaError>(result)
450            })),
451        }
452    }
453}
454
455impl Drop for Agent {
456    fn drop(&mut self) {
457        if !self.ptr.is_null() {
458            unsafe { ffi::marisa_agent_destroy(self.ptr) };
459        }
460    }
461}
462
463/// Get the version string of the Marisa library
464pub fn version() -> String {
465    unsafe {
466        let c_str = ffi::marisa_version();
467        if c_str.is_null() {
468            "unknown".to_string()
469        } else {
470            CStr::from_ptr(c_str).to_string_lossy().to_string()
471        }
472    }
473}
474
475/// Get a human-readable error message
476pub fn strerror(err: MarisaError) -> String {
477    let err_code = ffi::MarisaError::from(err) as c_int;
478    unsafe {
479        let c_str = ffi::marisa_strerror(err_code);
480        if c_str.is_null() {
481            "unknown error".to_string()
482        } else {
483            CStr::from_ptr(c_str).to_string_lossy().to_string()
484        }
485    }
486}
487
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a small trie and exercises lookup and predictive search.
    #[test]
    fn test_basic_functionality() {
        // Create a keyset and add some keys
        let mut keyset = Keyset::new().unwrap();
        keyset.push("hello").unwrap();
        keyset.push("world").unwrap();
        keyset.push("rust").unwrap();

        // Build the trie
        let trie = Trie::build(&keyset).unwrap();

        // Test lookup
        assert!(trie.lookup("hello").is_some());
        assert!(trie.lookup("world").is_some());
        assert!(trie.lookup("rust").is_some());
        assert!(trie.lookup("nonexistent").is_none());

        // Test predictive search
        let mut agent = Agent::new().unwrap();
        assert!(trie.predictive_search("h", &mut agent).unwrap());
        // "hello" is the only inserted key that starts with "h"; accepting any
        // other key here would hide a wrong search result.
        assert_eq!(agent.key().unwrap(), "hello");
    }

    /// The library must report a non-empty version string.
    #[test]
    fn test_version() {
        let version = version();
        assert!(!version.is_empty());
        println!("Marisa version: {}", version);
    }
}
522}