1use std::ffi::{CStr, CString};
38use std::os::raw::c_int;
39
/// Raw bindings to the C interface of the marisa trie library.
///
/// Only declarations live here; the behavior of every function is defined by
/// the C library that gets linked in, which is not visible from this file.
mod ffi {
    // `FILE*` parameters are passed through as untyped pointers.
    use std::ffi::c_void as FILE;
    use std::marker::{PhantomData, PhantomPinned};
    use std::os::raw::{c_char, c_int};

    /// Marker for opaque foreign types.
    ///
    /// The raw pointer suppresses the automatic `Send`/`Sync` impls and
    /// `PhantomPinned` suppresses `Unpin`, because Rust knows nothing about
    /// the thread-safety or address stability of the C objects.
    type Opaque = PhantomData<(*mut u8, PhantomPinned)>;

    /// Opaque trie object; only ever handled through pointers.
    #[repr(C)]
    pub struct MarisaTrie {
        _private: [u8; 0],
        _marker: Opaque,
    }

    /// Opaque agent object; only ever handled through pointers.
    #[repr(C)]
    pub struct MarisaAgent {
        _private: [u8; 0],
        _marker: Opaque,
    }

    /// Opaque keyset object; only ever handled through pointers.
    #[repr(C)]
    pub struct MarisaKeyset {
        _private: [u8; 0],
        _marker: Opaque,
    }

    /// Identifier type used for keys in the C interface.
    pub type MarisaId = u32;

    /// Status codes that the C functions return as a `c_int`.
    ///
    /// NOTE(review): the numbering looks like `marisa_error_code` from
    /// marisa-trie's `base.h` -- confirm against the header actually linked.
    /// A `c_int` received from C must be range-checked before it is treated
    /// as a value of this enum; an out-of-range discriminant is undefined
    /// behavior.
    #[repr(C)]
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub enum MarisaError {
        Ok = 0,
        StateError = 1,
        NullError = 2,
        BoundError = 3,
        RangeError = 4,
        CodeError = 5,
        ResetError = 6,
        SizeError = 7,
        MemoryError = 8,
        IoError = 9,
        FormatError = 10,
    }

    extern "C" {
        // --- trie lifetime ---
        pub fn marisa_create() -> *mut MarisaTrie;
        pub fn marisa_destroy(trie: *mut MarisaTrie);

        // --- trie persistence (by file name, by `FILE*`, or mapped) ---
        pub fn marisa_open(trie: *mut MarisaTrie, filename: *const c_char) -> c_int;
        pub fn marisa_save(trie: *const MarisaTrie, filename: *const c_char) -> c_int;
        pub fn marisa_write(trie: *const MarisaTrie, file: *mut FILE) -> c_int;
        pub fn marisa_read(trie: *mut MarisaTrie, file: *mut FILE) -> c_int;
        pub fn marisa_map(trie: *mut MarisaTrie, filename: *const c_char) -> c_int;
        pub fn marisa_unmap(trie: *mut MarisaTrie) -> c_int;

        // --- trie construction from a keyset ---
        pub fn marisa_build(trie: *mut MarisaTrie, keyset: *mut MarisaKeyset) -> c_int;
        pub fn marisa_build_trie(
            trie: *mut MarisaTrie,
            keyset: *mut MarisaKeyset,
            trie_mode: c_int,
        ) -> c_int;

        // --- queries; keys are passed as pointer + explicit byte length ---
        pub fn marisa_lookup(
            trie: *const MarisaTrie,
            key: *const c_char,
            length: usize,
            id: *mut MarisaId,
        ) -> c_int;
        pub fn marisa_predictive_search(
            trie: *const MarisaTrie,
            ptr: *const c_char,
            length: usize,
            agent: *mut MarisaAgent,
        ) -> c_int;
        pub fn marisa_reverse_lookup(
            trie: *const MarisaTrie,
            id: MarisaId,
            agent: *mut MarisaAgent,
        ) -> c_int;
        pub fn marisa_common_prefix_search(
            trie: *const MarisaTrie,
            ptr: *const c_char,
            length: usize,
            agent: *mut MarisaAgent,
        ) -> c_int;

        // --- agent lifetime and accessors ---
        pub fn marisa_agent_create() -> *mut MarisaAgent;
        pub fn marisa_agent_destroy(agent: *mut MarisaAgent);
        pub fn marisa_agent_key(agent: *const MarisaAgent) -> *const c_char;
        pub fn marisa_agent_key_length(agent: *const MarisaAgent) -> usize;
        pub fn marisa_agent_id(agent: *const MarisaAgent) -> MarisaId;
        pub fn marisa_agent_next(agent: *mut MarisaAgent) -> c_int;

        // --- keyset lifetime and mutation ---
        pub fn marisa_keyset_create() -> *mut MarisaKeyset;
        pub fn marisa_keyset_destroy(keyset: *mut MarisaKeyset);
        pub fn marisa_keyset_push(
            keyset: *mut MarisaKeyset,
            key: *const c_char,
            length: usize,
        ) -> c_int;
        pub fn marisa_keyset_push_back(
            keyset: *mut MarisaKeyset,
            key: *const c_char,
            length: usize,
            id: MarisaId,
        ) -> c_int;
        pub fn marisa_keyset_reset(keyset: *mut MarisaKeyset);

        // --- diagnostics ---
        pub fn marisa_strerror(err: c_int) -> *const c_char;
        pub fn marisa_version() -> *const c_char;
    }
}
150
/// Errors surfaced by the safe wrappers in this file.
///
/// The first ten variants share their names with the non-`Ok` raw status
/// codes; the remaining ones exist only on the Rust side.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MarisaError {
    StateError,
    NullError,
    BoundError,
    RangeError,
    CodeError,
    ResetError,
    SizeError,
    MemoryError,
    IoError,
    FormatError,
    /// Rust-side variant; not constructed by any code visible in this file.
    KeyNotFound,
    /// Rust-side variant; not constructed by any code visible in this file.
    NoResults,
    /// Bytes handed back by the library were not valid UTF-8.
    Utf8Error(std::str::Utf8Error),
}

impl std::fmt::Display for MarisaError {
    /// Writes a short, fixed description of the error; the UTF-8 variant
    /// additionally appends the underlying decoding error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let description = match self {
            // The only variant with a payload is formatted on its own.
            Self::Utf8Error(cause) => return write!(f, "UTF-8 error: {}", cause),
            Self::StateError => "State error",
            Self::NullError => "Null pointer error",
            Self::BoundError => "Bound error",
            Self::RangeError => "Range error",
            Self::CodeError => "Code error",
            Self::ResetError => "Reset error",
            Self::SizeError => "Size error",
            Self::MemoryError => "Memory error",
            Self::IoError => "I/O error",
            Self::FormatError => "Format error",
            Self::KeyNotFound => "Key not found",
            Self::NoResults => "No results",
        };
        f.write_str(description)
    }
}

impl std::error::Error for MarisaError {}
192
193impl From<MarisaError> for ffi::MarisaError {
194 fn from(err: MarisaError) -> Self {
195 match err {
196 MarisaError::StateError => ffi::MarisaError::StateError,
197 MarisaError::NullError => ffi::MarisaError::NullError,
198 MarisaError::BoundError => ffi::MarisaError::BoundError,
199 MarisaError::RangeError => ffi::MarisaError::RangeError,
200 MarisaError::CodeError => ffi::MarisaError::CodeError,
201 MarisaError::ResetError => ffi::MarisaError::ResetError,
202 MarisaError::SizeError => ffi::MarisaError::SizeError,
203 MarisaError::MemoryError => ffi::MarisaError::MemoryError,
204 MarisaError::IoError => ffi::MarisaError::IoError,
205 MarisaError::FormatError => ffi::MarisaError::FormatError,
206 _ => ffi::MarisaError::FormatError,
207 }
208 }
209}
210
211impl From<ffi::MarisaError> for MarisaError {
212 fn from(err: ffi::MarisaError) -> Self {
213 match err {
214 ffi::MarisaError::Ok => unreachable!(),
215 ffi::MarisaError::StateError => MarisaError::StateError,
216 ffi::MarisaError::NullError => MarisaError::NullError,
217 ffi::MarisaError::BoundError => MarisaError::BoundError,
218 ffi::MarisaError::RangeError => MarisaError::RangeError,
219 ffi::MarisaError::CodeError => MarisaError::CodeError,
220 ffi::MarisaError::ResetError => MarisaError::ResetError,
221 ffi::MarisaError::SizeError => MarisaError::SizeError,
222 ffi::MarisaError::MemoryError => MarisaError::MemoryError,
223 ffi::MarisaError::IoError => MarisaError::IoError,
224 ffi::MarisaError::FormatError => MarisaError::FormatError,
225 }
226 }
227}
228
229pub struct Trie {
231 ptr: *mut ffi::MarisaTrie,
232}
233
234impl Trie {
235 pub fn new() -> Result<Self, MarisaError> {
237 let ptr = unsafe { ffi::marisa_create() };
238 if ptr.is_null() {
239 Err(MarisaError::MemoryError)
240 } else {
241 Ok(Trie { ptr })
242 }
243 }
244
245 pub fn build(keyset: &Keyset) -> Result<Self, MarisaError> {
247 let trie = Self::new()?;
248 let result = unsafe { ffi::marisa_build(trie.ptr, keyset.ptr) };
249 if result == ffi::MarisaError::Ok as c_int {
250 Ok(trie)
251 } else {
252 Err(MarisaError::from(unsafe {
253 std::mem::transmute::<c_int, ffi::MarisaError>(result)
254 }))
255 }
256 }
257
258 pub fn load(&mut self, filename: &str) -> Result<(), MarisaError> {
260 let c_filename = CString::new(filename).map_err(|_| MarisaError::FormatError)?;
261 let result = unsafe { ffi::marisa_open(self.ptr, c_filename.as_ptr()) };
262 if result == ffi::MarisaError::Ok as c_int {
263 Ok(())
264 } else {
265 Err(MarisaError::from(unsafe {
266 std::mem::transmute::<c_int, ffi::MarisaError>(result)
267 }))
268 }
269 }
270
271 pub fn save(&self, filename: &str) -> Result<(), MarisaError> {
273 let c_filename = CString::new(filename).map_err(|_| MarisaError::FormatError)?;
274 let result = unsafe { ffi::marisa_save(self.ptr, c_filename.as_ptr()) };
275 if result == ffi::MarisaError::Ok as c_int {
276 Ok(())
277 } else {
278 Err(MarisaError::from(unsafe {
279 std::mem::transmute::<c_int, ffi::MarisaError>(result)
280 }))
281 }
282 }
283
284 pub fn lookup(&self, key: &str) -> Option<u32> {
286 let c_key = CString::new(key).ok()?;
287 let mut id = 0;
288 let result = unsafe { ffi::marisa_lookup(self.ptr, c_key.as_ptr(), key.len(), &mut id) };
289 if result == ffi::MarisaError::Ok as c_int {
290 Some(id)
291 } else {
292 None
293 }
294 }
295
296 pub fn predictive_search(&self, prefix: &str, agent: &mut Agent) -> Result<bool, MarisaError> {
298 let c_prefix = CString::new(prefix).map_err(|_| MarisaError::FormatError)?;
299 let result = unsafe {
300 ffi::marisa_predictive_search(self.ptr, c_prefix.as_ptr(), prefix.len(), agent.ptr)
301 };
302 match result {
303 x if x == ffi::MarisaError::Ok as c_int => Ok(true),
304 x if x == ffi::MarisaError::FormatError as c_int => Ok(false),
305 _ => Err(MarisaError::from(unsafe {
306 std::mem::transmute::<c_int, ffi::MarisaError>(result)
307 })),
308 }
309 }
310
311 pub fn reverse_lookup(&self, id: u32, agent: &mut Agent) -> Result<(), MarisaError> {
313 let result = unsafe { ffi::marisa_reverse_lookup(self.ptr, id, agent.ptr) };
314 if result == ffi::MarisaError::Ok as c_int {
315 Ok(())
316 } else {
317 Err(MarisaError::from(unsafe {
318 std::mem::transmute::<c_int, ffi::MarisaError>(result)
319 }))
320 }
321 }
322
323 pub fn common_prefix_search(
325 &self,
326 prefix: &str,
327 agent: &mut Agent,
328 ) -> Result<bool, MarisaError> {
329 let c_prefix = CString::new(prefix).map_err(|_| MarisaError::FormatError)?;
330 let result = unsafe {
331 ffi::marisa_common_prefix_search(self.ptr, c_prefix.as_ptr(), prefix.len(), agent.ptr)
332 };
333 match result {
334 x if x == ffi::MarisaError::Ok as c_int => Ok(true),
335 x if x == ffi::MarisaError::FormatError as c_int => Ok(false),
336 _ => Err(MarisaError::from(unsafe {
337 std::mem::transmute::<c_int, ffi::MarisaError>(result)
338 })),
339 }
340 }
341}
342
343impl Drop for Trie {
344 fn drop(&mut self) {
345 if !self.ptr.is_null() {
346 unsafe { ffi::marisa_destroy(self.ptr) };
347 }
348 }
349}
350
351pub struct Keyset {
353 ptr: *mut ffi::MarisaKeyset,
354}
355
356impl Keyset {
357 pub fn new() -> Result<Self, MarisaError> {
359 let ptr = unsafe { ffi::marisa_keyset_create() };
360 if ptr.is_null() {
361 Err(MarisaError::MemoryError)
362 } else {
363 Ok(Keyset { ptr })
364 }
365 }
366
367 pub fn push(&mut self, key: &str) -> Result<(), MarisaError> {
369 let c_key = CString::new(key).map_err(|_| MarisaError::FormatError)?;
370 let result = unsafe { ffi::marisa_keyset_push(self.ptr, c_key.as_ptr(), key.len()) };
371 if result == ffi::MarisaError::Ok as c_int {
372 Ok(())
373 } else {
374 Err(MarisaError::from(unsafe {
375 std::mem::transmute::<c_int, ffi::MarisaError>(result)
376 }))
377 }
378 }
379
380 pub fn push_with_id(&mut self, key: &str, id: u32) -> Result<(), MarisaError> {
382 let c_key = CString::new(key).map_err(|_| MarisaError::FormatError)?;
383 let result =
384 unsafe { ffi::marisa_keyset_push_back(self.ptr, c_key.as_ptr(), key.len(), id) };
385 if result == ffi::MarisaError::Ok as c_int {
386 Ok(())
387 } else {
388 Err(MarisaError::from(unsafe {
389 std::mem::transmute::<c_int, ffi::MarisaError>(result)
390 }))
391 }
392 }
393
394 pub fn reset(&mut self) {
396 unsafe { ffi::marisa_keyset_reset(self.ptr) };
397 }
398}
399
400impl Drop for Keyset {
401 fn drop(&mut self) {
402 if !self.ptr.is_null() {
403 unsafe { ffi::marisa_keyset_destroy(self.ptr) };
404 }
405 }
406}
407
408pub struct Agent {
410 ptr: *mut ffi::MarisaAgent,
411}
412
413impl Agent {
414 pub fn new() -> Result<Self, MarisaError> {
416 let ptr = unsafe { ffi::marisa_agent_create() };
417 if ptr.is_null() {
418 Err(MarisaError::MemoryError)
419 } else {
420 Ok(Agent { ptr })
421 }
422 }
423
424 pub fn key(&self) -> Result<String, MarisaError> {
426 let c_str = unsafe { ffi::marisa_agent_key(self.ptr) };
427 if c_str.is_null() {
428 return Err(MarisaError::NullError);
429 }
430 let length = unsafe { ffi::marisa_agent_key_length(self.ptr) };
431 let slice = unsafe { std::slice::from_raw_parts(c_str as *const u8, length) };
432 std::str::from_utf8(slice)
433 .map(|s| s.to_string())
434 .map_err(MarisaError::Utf8Error)
435 }
436
437 pub fn id(&self) -> u32 {
439 unsafe { ffi::marisa_agent_id(self.ptr) }
440 }
441
442 pub fn next(&mut self) -> Result<bool, MarisaError> {
444 let result = unsafe { ffi::marisa_agent_next(self.ptr) };
445 match result {
446 x if x == ffi::MarisaError::Ok as c_int => Ok(true),
447 x if x == ffi::MarisaError::StateError as c_int => Ok(false),
448 _ => Err(MarisaError::from(unsafe {
449 std::mem::transmute::<c_int, ffi::MarisaError>(result)
450 })),
451 }
452 }
453}
454
455impl Drop for Agent {
456 fn drop(&mut self) {
457 if !self.ptr.is_null() {
458 unsafe { ffi::marisa_agent_destroy(self.ptr) };
459 }
460 }
461}
462
463pub fn version() -> String {
465 unsafe {
466 let c_str = ffi::marisa_version();
467 if c_str.is_null() {
468 "unknown".to_string()
469 } else {
470 CStr::from_ptr(c_str).to_string_lossy().to_string()
471 }
472 }
473}
474
475pub fn strerror(err: MarisaError) -> String {
477 let err_code = ffi::MarisaError::from(err) as c_int;
478 unsafe {
479 let c_str = ffi::marisa_strerror(err_code);
480 if c_str.is_null() {
481 "unknown error".to_string()
482 } else {
483 CStr::from_ptr(c_str).to_string_lossy().to_string()
484 }
485 }
486}
487
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a three-key trie and exercises exact lookup and predictive
    /// search through the safe wrappers.
    #[test]
    fn test_basic_functionality() {
        let mut keyset = Keyset::new().unwrap();
        keyset.push("hello").unwrap();
        keyset.push("world").unwrap();
        keyset.push("rust").unwrap();

        let trie = Trie::build(&keyset).unwrap();

        assert!(trie.lookup("hello").is_some());
        assert!(trie.lookup("world").is_some());
        assert!(trie.lookup("rust").is_some());
        assert!(trie.lookup("nonexistent").is_none());

        let mut agent = Agent::new().unwrap();
        assert!(trie.predictive_search("h", &mut agent).unwrap());
        // "hello" is the only inserted key that starts with "h", so it is the
        // only acceptable result for this prefix.
        assert_eq!(agent.key().unwrap(), "hello");
    }

    /// The library must report a non-empty version string.
    #[test]
    fn test_version() {
        let version = version();
        assert!(!version.is_empty());
        println!("Marisa version: {}", version);
    }
}