Skip to main content

marisa/
key.rs

//! Key type for trie operations.
//!
//! Ported from: include/marisa/key.h
//!
//! This is the public API `Key` type, distinct from `grimoire::trie::Key`,
//! which is used internally.
7
8use std::fmt;
9
/// Four bytes of storage shared by a key's ID and its weight.
///
/// Both interpretations occupy the same memory, so only one of them is
/// meaningful at any given time.
#[derive(Copy, Clone)]
union KeyUnion {
    /// Weight, for keys that carry a weight.
    weight: f32,
    /// ID, for keys used in indexed access.
    id: u32,
}
18
19impl Default for KeyUnion {
20    fn default() -> Self {
21        KeyUnion { id: 0 }
22    }
23}
24
25/// Key represents a dictionary key with its string and metadata.
26///
27/// A key contains:
28/// - A string (ptr + length)
29/// - Either an ID or a weight (union)
30#[derive(Clone)]
31pub struct Key {
32    /// Pointer to key string data (borrowed).
33    ptr: Option<*const u8>,
34    /// Length of key string.
35    length: u32,
36    /// Union holding either ID or weight.
37    union: KeyUnion,
38}
39
40// Manual Debug implementation since raw pointers don't implement Debug
41impl fmt::Debug for Key {
42    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
43        f.debug_struct("Key")
44            .field("ptr", &self.ptr.map(|_| "..."))
45            .field("length", &self.length)
46            .field("id_or_weight", &unsafe { self.union.id })
47            .finish()
48    }
49}
50
51impl Default for Key {
52    fn default() -> Self {
53        Self::new()
54    }
55}
56
57impl Key {
58    /// Creates a new empty key.
59    pub fn new() -> Self {
60        Key {
61            ptr: None,
62            length: 0,
63            union: KeyUnion::default(),
64        }
65    }
66
67    /// Returns the character at the specified index.
68    ///
69    /// # Panics
70    ///
71    /// Panics if index is out of bounds.
72    pub fn get(&self, i: usize) -> u8 {
73        assert!((i as u32) < self.length, "Index out of bounds");
74        if let Some(ptr) = self.ptr {
75            unsafe { *ptr.add(i) }
76        } else {
77            panic!("Key has no string data");
78        }
79    }
80
81    /// Sets the key from a string slice.
82    pub fn set_str(&mut self, s: &str) {
83        assert!(s.len() <= u32::MAX as usize, "String too long");
84        self.ptr = Some(s.as_ptr());
85        self.length = s.len() as u32;
86    }
87
88    /// Sets the key from a byte slice.
89    pub fn set_bytes(&mut self, bytes: &[u8]) {
90        assert!(bytes.len() <= u32::MAX as usize, "Bytes too long");
91        if bytes.is_empty() {
92            self.ptr = None;
93            self.length = 0;
94        } else {
95            self.ptr = Some(bytes.as_ptr());
96            self.length = bytes.len() as u32;
97        }
98    }
99
100    /// Sets the key ID.
101    pub fn set_id(&mut self, id: usize) {
102        assert!(id <= u32::MAX as usize, "ID too large");
103        self.union = KeyUnion { id: id as u32 };
104    }
105
106    /// Sets the key weight.
107    pub fn set_weight(&mut self, weight: f32) {
108        self.union = KeyUnion { weight };
109    }
110
111    /// Returns the key as a byte slice.
112    ///
113    /// Returns an empty slice if no string is set.
114    pub fn as_bytes(&self) -> &[u8] {
115        if let Some(ptr) = self.ptr {
116            unsafe { std::slice::from_raw_parts(ptr, self.length as usize) }
117        } else {
118            &[]
119        }
120    }
121
122    /// Returns the key string as a str reference.
123    ///
124    /// # Panics
125    ///
126    /// Panics if the key contains invalid UTF-8.
127    pub fn as_str(&self) -> &str {
128        std::str::from_utf8(self.as_bytes()).expect("Invalid UTF-8 in key")
129    }
130
131    /// Returns a pointer to the key data.
132    pub fn ptr(&self) -> Option<*const u8> {
133        self.ptr
134    }
135
136    /// Returns the length of the key string.
137    pub fn length(&self) -> usize {
138        self.length as usize
139    }
140
141    /// Returns the key ID.
142    ///
143    /// # Safety
144    ///
145    /// This accesses the union as an ID. The caller must ensure
146    /// that set_id() was called more recently than set_weight().
147    pub fn id(&self) -> usize {
148        unsafe { self.union.id as usize }
149    }
150
151    /// Returns the key weight.
152    ///
153    /// # Safety
154    ///
155    /// This accesses the union as a weight. The caller must ensure
156    /// that set_weight() was called more recently than set_id().
157    pub fn weight(&self) -> f32 {
158        unsafe { self.union.weight }
159    }
160
161    /// Clears the key to empty state.
162    pub fn clear(&mut self) {
163        *self = Key::new();
164    }
165
166    /// Swaps with another key.
167    pub fn swap(&mut self, other: &mut Key) {
168        std::mem::swap(self, other);
169    }
170}
171
172// Safety: Key only holds a pointer that must remain valid for its lifetime.
173// The user is responsible for ensuring the borrowed data outlives the Key.
174unsafe impl Send for Key {}
175unsafe impl Sync for Key {}
176
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_key_new() {
        let key = Key::new();
        assert_eq!(key.length(), 0);
        assert_eq!(key.id(), 0);
        assert_eq!(key.as_bytes(), &[]);
    }

    #[test]
    fn test_key_default() {
        let key = Key::default();
        assert_eq!(key.length(), 0);
    }

    #[test]
    fn test_key_set_str() {
        let s = "hello";
        let mut key = Key::new();
        key.set_str(s);

        assert_eq!(key.length(), 5);
        assert_eq!(key.as_str(), "hello");
        assert_eq!(key.as_bytes(), b"hello");
    }

    #[test]
    fn test_key_set_empty_str() {
        let mut key = Key::new();
        key.set_str("");

        assert_eq!(key.length(), 0);
        assert_eq!(key.as_bytes(), &[]);
        assert_eq!(key.as_str(), "");
    }

    #[test]
    fn test_key_set_bytes() {
        let bytes = b"world";
        let mut key = Key::new();
        key.set_bytes(bytes);

        assert_eq!(key.length(), 5);
        assert_eq!(key.as_bytes(), b"world");
    }

    #[test]
    fn test_key_set_empty_bytes() {
        let mut key = Key::new();
        key.set_str("test");
        assert!(key.ptr().is_some());

        key.set_bytes(&[]);

        assert_eq!(key.length(), 0);
        assert_eq!(key.as_bytes(), &[]);
        // An empty slice resets the pointer rather than keeping a stale one.
        assert!(key.ptr().is_none());
    }

    #[test]
    fn test_key_set_bytes_preserves_id() {
        let mut key = Key::new();
        key.set_id(7);
        key.set_bytes(b"abc");

        assert_eq!(key.id(), 7);
        assert_eq!(key.as_bytes(), b"abc");
    }

    #[test]
    fn test_key_get() {
        let s = "test";
        let mut key = Key::new();
        key.set_str(s);

        assert_eq!(key.get(0), b't');
        assert_eq!(key.get(1), b'e');
        assert_eq!(key.get(2), b's');
        assert_eq!(key.get(3), b't');
    }

    #[test]
    #[should_panic(expected = "Index out of bounds")]
    fn test_key_get_out_of_bounds() {
        let s = "test";
        let mut key = Key::new();
        key.set_str(s);
        key.get(4);
    }

    #[test]
    fn test_key_set_id() {
        let mut key = Key::new();
        key.set_id(42);
        assert_eq!(key.id(), 42);
    }

    // `usize` is wider than `u32` only on 64-bit targets.
    #[cfg(target_pointer_width = "64")]
    #[test]
    #[should_panic(expected = "ID too large")]
    fn test_key_set_id_too_large() {
        let mut key = Key::new();
        key.set_id(u32::MAX as usize + 1);
    }

    #[test]
    fn test_key_set_weight() {
        let mut key = Key::new();
        key.set_weight(3.15);
        assert!((key.weight() - 3.15).abs() < 0.001);
    }

    #[test]
    fn test_key_id_weight_union() {
        let mut key = Key::new();

        // Set as ID
        key.set_id(100);
        assert_eq!(key.id(), 100);

        // Overwrite with weight
        key.set_weight(2.5);
        assert!((key.weight() - 2.5).abs() < 0.001);
    }

    #[test]
    fn test_key_id_reads_weight_bits() {
        // ID and weight share storage, so `id()` after `set_weight` yields the
        // weight's raw bit pattern.
        let mut key = Key::new();
        key.set_weight(1.0);
        assert_eq!(key.id(), 1.0f32.to_bits() as usize);
    }

    #[test]
    fn test_key_clear() {
        let mut key = Key::new();
        key.set_str("test");
        key.set_id(10);

        key.clear();

        assert_eq!(key.length(), 0);
        assert_eq!(key.id(), 0);
        assert_eq!(key.as_bytes(), &[]);
    }

    #[test]
    fn test_key_clone_shares_data() {
        let s = "clone";
        let mut k1 = Key::new();
        k1.set_str(s);
        k1.set_id(3);

        let k2 = k1.clone();

        // Cloning copies the pointer, not the bytes.
        assert_eq!(k2.ptr(), k1.ptr());
        assert_eq!(k2.as_str(), "clone");
        assert_eq!(k2.id(), 3);
    }

    #[test]
    fn test_key_swap() {
        let s1 = "hello";
        let s2 = "world";

        let mut k1 = Key::new();
        k1.set_str(s1);
        k1.set_id(1);

        let mut k2 = Key::new();
        k2.set_str(s2);
        k2.set_id(2);

        k1.swap(&mut k2);

        assert_eq!(k1.as_str(), "world");
        assert_eq!(k1.id(), 2);
        assert_eq!(k2.as_str(), "hello");
        assert_eq!(k2.id(), 1);
    }

    #[test]
    fn test_key_with_unicode() {
        let s = "こんにちは";
        let mut key = Key::new();
        key.set_str(s);

        assert_eq!(key.length(), s.len());
        assert_eq!(key.as_str(), "こんにちは");
    }

    #[test]
    fn test_key_debug() {
        let mut key = Key::new();
        assert_eq!(
            format!("{:?}", key),
            "Key { ptr: None, length: 0, id_or_weight: 0 }"
        );

        key.set_str("ab");
        key.set_id(5);
        assert_eq!(
            format!("{:?}", key),
            "Key { ptr: Some(\"...\"), length: 2, id_or_weight: 5 }"
        );
    }

    // Note: Cannot safely test set_str with length > u32::MAX
    // as creating such a string would require invalid operations.
}