1use matchy_match_mode::MatchMode;
34use rayon::prelude::*;
35use rustc_hash::FxHashMap;
36use std::mem;
37use xxhash_rust::xxh3::xxh3_128;
38
39pub mod validation;
40
41pub use validation::{validate_literal_hash, LiteralHashStats, LiteralHashValidationResult};
42
/// Errors produced while building or parsing the serialized literal hash format.
#[derive(Debug, Clone)]
pub enum LiteralHashError {
    /// The buffer or inputs are malformed (bad magic, wrong version,
    /// truncated tables, out-of-range sizes); the payload describes the
    /// specific problem.
    InvalidFormat(String),
}
47
48impl std::fmt::Display for LiteralHashError {
49 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50 match self {
51 Self::InvalidFormat(msg) => write!(f, "Invalid literal hash format: {msg}"),
52 }
53 }
54}
55
// Marker impl: the `Debug` derive and `Display` impl satisfy the supertraits.
impl std::error::Error for LiteralHashError {}
57
/// Magic bytes identifying a serialized literal hash table.
pub const LITERAL_HASH_MAGIC: &[u8; 4] = b"LHSH";
/// Current on-disk format version; `from_buffer` rejects any other value.
pub const MATCHY_LITERAL_HASH_VERSION: u32 = 3;

/// Byte size of the fixed header (`LiteralHashHeader` is 32 bytes, repr(C)).
const HEADER_SIZE: usize = 32;
/// Sentinel hash words (all bits set) marking an unoccupied `HashEntry` slot.
const EMPTY_HASH_LO: u64 = 0xFFFF_FFFF_FFFF_FFFF;
const EMPTY_HASH_HI: u32 = 0xFFFF_FFFF;
64
/// On-disk header of a serialized literal hash table (32 bytes, little-endian).
///
/// Field order mirrors the byte layout written by `LiteralHashBuilder::build`
/// and decoded by `LiteralHash::from_buffer`.
#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct LiteralHashHeader {
    /// File magic; always `LHSH`.
    pub magic: [u8; 4],
    /// Format version; must equal `MATCHY_LITERAL_HASH_VERSION`.
    pub version: u32,
    /// Number of occupied (non-empty) slots across all shards.
    pub entry_count: u32,
    /// Total number of table slots across all shards.
    pub table_size: u32,
    /// Number of shards; the builder always emits `1 << shard_bits`.
    pub num_shards: u32,
    /// log2 of `num_shards`.
    pub shard_bits: u32,
    /// Byte offset of the pattern-id -> data-offset mapping section.
    pub mappings_offset: u32,
    /// Byte offset of the first hash-table slot (8-byte aligned).
    pub table_offset: u32,
}
77
/// One slot of the open-addressed hash table (16 bytes, repr(C)).
#[repr(C)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct HashEntry {
    /// Low 64 bits of the pattern's XXH3-128 hash.
    pub hash_lo: u64,
    /// Bits 64..96 of the pattern's XXH3-128 hash.
    pub hash_hi: u32,
    /// Identifier of the pattern stored in this slot.
    pub pattern_id: u32,
}
85
86impl HashEntry {
87 const fn empty() -> Self {
88 Self {
89 hash_lo: EMPTY_HASH_LO,
90 hash_hi: EMPTY_HASH_HI,
91 pattern_id: 0,
92 }
93 }
94
95 const fn is_empty(&self) -> bool {
96 self.hash_lo == EMPTY_HASH_LO && self.hash_hi == EMPTY_HASH_HI
97 }
98}
99
/// Intermediate build product: the probe table for a single shard.
struct Shard {
    /// Open-addressed slots; length is a power of two (or 0 for an empty shard).
    table: Vec<HashEntry>,
    /// Index of this shard, used to place its offset during assembly.
    shard_id: usize,
}
104
105pub struct LiteralHashBuilder {
106 patterns: Vec<(u64, u32, u32)>, mode: MatchMode,
108}
109
impl LiteralHashBuilder {
    /// Creates a builder that normalizes and hashes patterns under `mode`.
    #[must_use]
    pub fn new(mode: MatchMode) -> Self {
        Self {
            patterns: Vec::new(),
            mode,
        }
    }

    /// Hashes `pattern` (lower-cased first in case-insensitive mode) and
    /// records it under `pattern_id`. Duplicated hashes are deduplicated
    /// later, during shard construction.
    pub fn add_pattern(&mut self, pattern: &str, pattern_id: u32) {
        let normalized = match self.mode {
            MatchMode::CaseSensitive => pattern.to_string(),
            MatchMode::CaseInsensitive => pattern.to_lowercase(),
        };
        let (hash_lo, hash_hi) = compute_hash(&normalized);
        self.patterns.push((hash_lo, hash_hi, pattern_id));
    }

    /// Serializes all added patterns plus the `(pattern_id, data_offset)`
    /// mapping table into the binary layout:
    /// header | shard offsets | hash tables | mappings.
    ///
    /// Returns an empty buffer when no patterns were added.
    ///
    /// # Errors
    /// Returns `LiteralHashError::InvalidFormat` if any size or offset
    /// overflows `u32`.
    pub fn build(self, pattern_data_offsets: &[(u32, u32)]) -> Result<Vec<u8>, LiteralHashError> {
        if self.patterns.is_empty() {
            return Ok(Vec::new());
        }

        // Scale shard count with pattern volume: 16, 32, or 64 shards.
        let shard_bits: u32 = if self.patterns.len() < 10_000 {
            4
        } else if self.patterns.len() < 100_000 {
            5
        } else {
            6
        };
        let num_shards = 1 << shard_bits;

        // Bucket every pattern into its shard by hash.
        let mut shard_buckets: Vec<Vec<(u64, u32, u32)>> =
            (0..num_shards).map(|_| Vec::new()).collect();

        for (hash_lo, hash_hi, pattern_id) in self.patterns {
            let shard_id = hash_to_shard(hash_lo, num_shards);
            shard_buckets[shard_id].push((hash_lo, hash_hi, pattern_id));
        }

        // Build shard tables in parallel, batched so at most `parallelism`
        // shards are in flight per chunk. Order is preserved: rayon's
        // collect keeps slice order, and chunks are appended in sequence.
        let parallelism = std::thread::available_parallelism()
            .map(std::num::NonZero::get)
            .unwrap_or(8);
        let batch_size = parallelism.min(num_shards).max(1);
        let mut shards = Vec::with_capacity(num_shards);

        for chunk_start in (0..num_shards).step_by(batch_size) {
            let chunk_end = (chunk_start + batch_size).min(num_shards);

            let mut chunk: Vec<Shard> = shard_buckets[chunk_start..chunk_end]
                .par_iter_mut()
                .enumerate()
                .map(|(i, entries)| {
                    let shard_id = chunk_start + i;
                    // take() frees each bucket as soon as its shard is built.
                    let entries_vec = std::mem::take(entries);
                    build_shard(shard_id, &entries_vec)
                })
                .collect();

            shards.append(&mut chunk);
        }

        let table_size: usize = shards.iter().map(|s| s.table.len()).sum();

        // Prefix-sum shard sizes into num_shards + 1 offsets; the extra
        // final entry lets readers compute the last shard's extent.
        let mut shard_offsets = vec![0u32; num_shards + 1];
        let mut offset = 0u32;
        for shard in &shards {
            shard_offsets[shard.shard_id] = offset;
            offset += u32::try_from(shard.table.len()).map_err(|_| {
                LiteralHashError::InvalidFormat("Shard size exceeds u32::MAX".into())
            })?;
        }
        shard_offsets[num_shards] = offset;

        // Compute section layout. The table start is rounded up to 8 bytes
        // so every 16-byte HashEntry stays 8-byte aligned.
        let shard_table_size = (num_shards + 1) * 4;
        let table_start = align_to_8(HEADER_SIZE + shard_table_size);
        let table_bytes = table_size * mem::size_of::<HashEntry>();
        let mappings_offset = table_start + table_bytes;
        let mappings_size = 4 + pattern_data_offsets.len() * 8;

        let total_size = mappings_offset + mappings_size;
        let mut buffer = vec![0u8; total_size];

        // Occupied-slot count may be below pattern count if hashes deduped.
        let entry_count = shards
            .iter()
            .map(|s| s.table.iter().filter(|e| !e.is_empty()).count())
            .sum::<usize>();

        // Header: magic, version, entry_count, table_size, num_shards,
        // shard_bits, mappings_offset, table_offset — all little-endian u32.
        buffer[0..4].copy_from_slice(LITERAL_HASH_MAGIC);
        buffer[4..8].copy_from_slice(&MATCHY_LITERAL_HASH_VERSION.to_le_bytes());
        buffer[8..12].copy_from_slice(
            &u32::try_from(entry_count)
                .map_err(|_| {
                    LiteralHashError::InvalidFormat("Entry count exceeds u32::MAX".into())
                })?
                .to_le_bytes(),
        );
        buffer[12..16].copy_from_slice(
            &u32::try_from(table_size)
                .map_err(|_| LiteralHashError::InvalidFormat("Table size exceeds u32::MAX".into()))?
                .to_le_bytes(),
        );
        buffer[16..20].copy_from_slice(
            &u32::try_from(num_shards)
                .map_err(|_| {
                    LiteralHashError::InvalidFormat("Shard count exceeds u32::MAX".into())
                })?
                .to_le_bytes(),
        );
        buffer[20..24].copy_from_slice(&shard_bits.to_le_bytes());
        buffer[24..28].copy_from_slice(
            &u32::try_from(mappings_offset)
                .map_err(|_| {
                    LiteralHashError::InvalidFormat("Mappings offset exceeds u32::MAX".into())
                })?
                .to_le_bytes(),
        );
        buffer[28..32].copy_from_slice(
            &u32::try_from(table_start)
                .map_err(|_| {
                    LiteralHashError::InvalidFormat("Table offset exceeds u32::MAX".into())
                })?
                .to_le_bytes(),
        );

        // Shard offset table, immediately after the header.
        let mut pos = HEADER_SIZE;
        for off in &shard_offsets {
            buffer[pos..pos + 4].copy_from_slice(&off.to_le_bytes());
            pos += 4;
        }

        // Hash tables: each entry is hash_lo (8) | hash_hi (4) | pattern_id (4).
        pos = table_start;
        for shard in &shards {
            for entry in &shard.table {
                buffer[pos..pos + 8].copy_from_slice(&entry.hash_lo.to_le_bytes());
                buffer[pos + 8..pos + 12].copy_from_slice(&entry.hash_hi.to_le_bytes());
                buffer[pos + 12..pos + 16].copy_from_slice(&entry.pattern_id.to_le_bytes());
                pos += 16;
            }
        }

        // Mapping section: count, then (pattern_id, data_offset) pairs.
        pos = mappings_offset;
        let pattern_count = u32::try_from(pattern_data_offsets.len()).map_err(|_| {
            LiteralHashError::InvalidFormat("Pattern count exceeds u32::MAX".into())
        })?;
        buffer[pos..pos + 4].copy_from_slice(&pattern_count.to_le_bytes());
        pos += 4;
        for (pattern_id, data_offset) in pattern_data_offsets {
            buffer[pos..pos + 4].copy_from_slice(&pattern_id.to_le_bytes());
            buffer[pos + 4..pos + 8].copy_from_slice(&data_offset.to_le_bytes());
            pos += 8;
        }

        Ok(buffer)
    }

    /// Number of patterns added so far.
    #[must_use]
    pub fn pattern_count(&self) -> usize {
        self.patterns.len()
    }
}
275
impl Default for LiteralHashBuilder {
    /// Defaults to case-sensitive matching.
    fn default() -> Self {
        Self::new(MatchMode::CaseSensitive)
    }
}
281
/// Zero-copy reader over a serialized literal hash table buffer.
#[derive(Clone)]
pub struct LiteralHash<'a> {
    /// The entire serialized table, including the header.
    buffer: &'a [u8],
    /// Header decoded from the first 32 bytes.
    header: LiteralHashHeader,
    /// Byte offset of the first table slot (copy of `header.table_offset`).
    table_start: usize,
    /// `num_shards + 1` cumulative slot offsets; shard `i` occupies
    /// slots `[offsets[i], offsets[i + 1])`.
    shard_offsets: Vec<u32>,
    /// Case handling applied to queries before hashing; must match the
    /// mode the table was built with for lookups to succeed.
    mode: MatchMode,
}
290
291impl<'a> LiteralHash<'a> {
292 pub fn from_buffer(buffer: &'a [u8], mode: MatchMode) -> Result<Self, LiteralHashError> {
293 if buffer.len() < HEADER_SIZE {
294 return Err(LiteralHashError::InvalidFormat(
295 "Buffer too small for header".into(),
296 ));
297 }
298
299 let magic = &buffer[0..4];
300 if magic != LITERAL_HASH_MAGIC {
301 return Err(LiteralHashError::InvalidFormat(format!(
302 "Invalid magic: expected {LITERAL_HASH_MAGIC:?}, got {magic:?}"
303 )));
304 }
305
306 let version = u32::from_le_bytes(buffer[4..8].try_into().unwrap());
307 if version != MATCHY_LITERAL_HASH_VERSION {
308 return Err(LiteralHashError::InvalidFormat(format!(
309 "Unsupported version: {version} (expected {MATCHY_LITERAL_HASH_VERSION})"
310 )));
311 }
312
313 let entry_count = u32::from_le_bytes(buffer[8..12].try_into().unwrap());
314 let table_size = u32::from_le_bytes(buffer[12..16].try_into().unwrap());
315 let num_shards = u32::from_le_bytes(buffer[16..20].try_into().unwrap());
316 if num_shards > 256 {
317 return Err(LiteralHashError::InvalidFormat(format!(
318 "num_shards {num_shards} exceeds maximum 256"
319 )));
320 }
321 let shard_bits = u32::from_le_bytes(buffer[20..24].try_into().unwrap());
322 let mappings_offset = u32::from_le_bytes(buffer[24..28].try_into().unwrap());
323 let table_offset = u32::from_le_bytes(buffer[28..32].try_into().unwrap());
324
325 let header = LiteralHashHeader {
326 magic: *LITERAL_HASH_MAGIC,
327 version,
328 entry_count,
329 table_size,
330 num_shards,
331 shard_bits,
332 mappings_offset,
333 table_offset,
334 };
335
336 let shard_table_start = HEADER_SIZE;
338 let mut shard_offsets = Vec::with_capacity(num_shards as usize + 1);
339 for i in 0..=num_shards as usize {
340 let off_pos = shard_table_start + i * 4;
341 if off_pos + 4 > buffer.len() {
342 return Err(LiteralHashError::InvalidFormat(
343 "Shard offset table truncated".into(),
344 ));
345 }
346 let off = u32::from_le_bytes(buffer[off_pos..off_pos + 4].try_into().unwrap());
347 shard_offsets.push(off);
348 }
349
350 let table_start = header.table_offset as usize;
351
352 Ok(Self {
353 buffer,
354 header,
355 table_start,
356 shard_offsets,
357 mode,
358 })
359 }
360
361 #[must_use]
362 pub fn mode(&self) -> MatchMode {
363 self.mode
364 }
365
366 #[must_use]
367 pub fn lookup(&self, query: &str) -> Option<u32> {
368 let (query_lo, query_hi) = match self.mode {
369 MatchMode::CaseSensitive => compute_hash(query),
370 MatchMode::CaseInsensitive => compute_hash(&query.to_lowercase()),
371 };
372
373 let num_shards = self.header.num_shards as usize;
374 let shard_id = hash_to_shard(query_lo, num_shards);
375
376 let shard_start = self.shard_offsets[shard_id] as usize;
377 let shard_end = self.shard_offsets[shard_id + 1] as usize;
378 let shard_capacity = shard_end - shard_start;
379
380 if shard_capacity == 0 {
381 return None;
382 }
383
384 let shard_mask = shard_capacity - 1;
385 let base_slot = hash_to_slot(query_lo, shard_mask);
386 let entry_size = mem::size_of::<HashEntry>();
387
388 for i in 0..shard_capacity {
389 let slot = (base_slot + i) & shard_mask;
390 let global_slot = shard_start + slot;
391 let entry_offset = self.table_start + global_slot * entry_size;
392
393 if entry_offset + entry_size > self.buffer.len() {
394 return None;
395 }
396
397 let hash_lo = u64::from_le_bytes(
398 self.buffer[entry_offset..entry_offset + 8]
399 .try_into()
400 .unwrap(),
401 );
402
403 if hash_lo == EMPTY_HASH_LO {
405 let hash_hi = u32::from_le_bytes(
406 self.buffer[entry_offset + 8..entry_offset + 12]
407 .try_into()
408 .unwrap(),
409 );
410 if hash_hi == EMPTY_HASH_HI {
411 return None; }
413 }
414
415 if hash_lo == query_lo {
416 let hash_hi = u32::from_le_bytes(
417 self.buffer[entry_offset + 8..entry_offset + 12]
418 .try_into()
419 .unwrap(),
420 );
421 if hash_hi == query_hi {
422 let pattern_id = u32::from_le_bytes(
423 self.buffer[entry_offset + 12..entry_offset + 16]
424 .try_into()
425 .unwrap(),
426 );
427 return Some(pattern_id);
428 }
429 }
430 }
431
432 None
433 }
434
435 #[must_use]
436 pub fn get_data_offset(&self, pattern_id: u32) -> Option<u32> {
437 let mappings_offset = self.header.mappings_offset as usize;
438
439 if mappings_offset + 4 > self.buffer.len() {
440 return None;
441 }
442
443 let count = u32::from_le_bytes(
444 self.buffer[mappings_offset..mappings_offset + 4]
445 .try_into()
446 .ok()?,
447 );
448
449 let mappings_data_start = mappings_offset + 4;
450
451 for i in 0..count {
452 let offset = mappings_data_start + (i as usize) * 8;
453 if offset + 8 > self.buffer.len() {
454 return None;
455 }
456
457 let pid = u32::from_le_bytes(self.buffer[offset..offset + 4].try_into().ok()?);
458 if pid == pattern_id {
459 return Some(u32::from_le_bytes(
460 self.buffer[offset + 4..offset + 8].try_into().ok()?,
461 ));
462 }
463 }
464
465 None
466 }
467
468 #[must_use]
469 pub fn entry_count(&self) -> u32 {
470 self.header.entry_count
471 }
472
473 #[must_use]
474 pub fn table_size(&self) -> u32 {
475 self.header.table_size
476 }
477}
478
479fn build_shard(shard_id: usize, entries: &[(u64, u32, u32)]) -> Shard {
480 if entries.is_empty() {
481 return Shard {
482 table: Vec::new(),
483 shard_id,
484 };
485 }
486
487 let needed = (entries.len() * 10).div_ceil(6);
488 let capacity = needed.next_power_of_two().max(16);
489 let mask = capacity - 1;
490
491 let mut map: FxHashMap<(u64, u32), u32> = FxHashMap::default();
493 for (hash_lo, hash_hi, pattern_id) in entries {
494 map.insert((*hash_lo, *hash_hi), *pattern_id);
495 }
496
497 let mut table = vec![HashEntry::empty(); capacity];
498
499 for ((hash_lo, hash_hi), pattern_id) in map {
500 let mut pos = hash_to_slot(hash_lo, mask);
501 let mut probes = 0;
502
503 while !table[pos].is_empty() {
504 pos = (pos + 1) & mask;
505 probes += 1;
506 debug_assert!(probes < capacity, "hash table unexpectedly full");
507 }
508
509 table[pos] = HashEntry {
510 hash_lo,
511 hash_hi,
512 pattern_id,
513 };
514 }
515
516 Shard { table, shard_id }
517}
518
519#[inline]
522fn compute_hash(s: &str) -> (u64, u32) {
523 let full = xxh3_128(s.as_bytes());
524 let lo = (full & 0xFFFF_FFFF_FFFF_FFFF) as u64;
525 let hi = ((full >> 64) & 0xFFFF_FFFF) as u32;
526
527 if lo == EMPTY_HASH_LO && hi == EMPTY_HASH_HI {
529 (lo ^ 1, hi ^ 1)
530 } else {
531 (lo, hi)
532 }
533}
534
/// Rounds `val` up to the next multiple of 8 (identity when already aligned).
#[inline]
const fn align_to_8(val: usize) -> usize {
    (val + 7) / 8 * 8
}
539
/// Maps a hash to its shard index by reducing the low 64 bits modulo the
/// shard count. Callers guarantee `num_shards > 0`.
#[inline]
fn hash_to_shard(hash_lo: u64, num_shards: usize) -> usize {
    let reduced = hash_lo % num_shards as u64;
    // The remainder is strictly below num_shards, so it fits in usize.
    #[allow(clippy::cast_possible_truncation)]
    let shard = reduced as usize;
    shard
}
545
/// Picks the starting probe slot from the upper 32 bits of the hash, masked
/// to the table capacity. Using different bits than `hash_to_shard` keeps
/// shard choice and slot choice independent.
#[inline]
fn hash_to_slot(hash_lo: u64, mask: usize) -> usize {
    let upper = hash_lo >> 32;
    #[allow(clippy::cast_possible_truncation)]
    let slot = upper as usize;
    slot & mask
}
551
#[cfg(test)]
mod tests {
    use super::*;

    // Build -> serialize -> parse round trip with exact lookups and mappings.
    #[test]
    fn test_basic_hash_table() {
        let mut builder = LiteralHashBuilder::new(MatchMode::CaseSensitive);
        builder.add_pattern("test1", 0);
        builder.add_pattern("test2", 1);
        builder.add_pattern("test3", 2);

        let pattern_data = vec![(0, 100), (1, 200), (2, 300)];
        let bytes = builder.build(&pattern_data).unwrap();

        let hash = LiteralHash::from_buffer(&bytes, MatchMode::CaseSensitive).unwrap();
        assert_eq!(hash.lookup("test1"), Some(0));
        assert_eq!(hash.lookup("test2"), Some(1));
        assert_eq!(hash.lookup("test3"), Some(2));
        assert_eq!(hash.lookup("test4"), None);

        assert_eq!(hash.get_data_offset(0), Some(100));
        assert_eq!(hash.get_data_offset(1), Some(200));
        assert_eq!(hash.get_data_offset(2), Some(300));
    }

    // 100 patterns exercise intra-shard probing; every one must resolve.
    #[test]
    fn test_hash_collisions() {
        let mut builder = LiteralHashBuilder::new(MatchMode::CaseSensitive);
        for i in 0..100 {
            builder.add_pattern(&format!("pattern_{i}"), i);
        }

        let pattern_data: Vec<_> = (0..100).map(|i| (i, i * 10)).collect();
        let bytes = builder.build(&pattern_data).unwrap();

        let hash = LiteralHash::from_buffer(&bytes, MatchMode::CaseSensitive).unwrap();
        for i in 0..100 {
            assert_eq!(hash.lookup(&format!("pattern_{i}")), Some(i));
        }
    }

    // Case-insensitive mode lower-cases both patterns and queries.
    #[test]
    fn test_case_insensitive() {
        let mut builder = LiteralHashBuilder::new(MatchMode::CaseInsensitive);
        builder.add_pattern("Example.Com", 0);
        builder.add_pattern("TEST", 1);

        let pattern_data = vec![(0, 100), (1, 200)];
        let bytes = builder.build(&pattern_data).unwrap();

        let hash = LiteralHash::from_buffer(&bytes, MatchMode::CaseInsensitive).unwrap();
        assert_eq!(hash.lookup("example.com"), Some(0));
        assert_eq!(hash.lookup("EXAMPLE.COM"), Some(0));
        assert_eq!(hash.lookup("test"), Some(1));
        assert_eq!(hash.lookup("TeSt"), Some(1));
    }

    // A builder with no patterns serializes to an empty buffer.
    #[test]
    fn test_empty_table() {
        let builder = LiteralHashBuilder::new(MatchMode::CaseSensitive);
        let bytes = builder.build(&[]).unwrap();
        assert!(bytes.is_empty());
    }

    // 10k entries take the larger shard configuration path in build().
    #[test]
    fn test_large_table() {
        let mut builder = LiteralHashBuilder::new(MatchMode::CaseSensitive);
        for i in 0..10_000 {
            builder.add_pattern(&format!("entry_{i}"), i);
        }

        let pattern_data: Vec<_> = (0..10_000).map(|i| (i, i * 4)).collect();
        let bytes = builder.build(&pattern_data).unwrap();

        let hash = LiteralHash::from_buffer(&bytes, MatchMode::CaseSensitive).unwrap();

        assert_eq!(hash.lookup("entry_0"), Some(0));
        assert_eq!(hash.lookup("entry_5000"), Some(5000));
        assert_eq!(hash.lookup("entry_9999"), Some(9999));
        assert_eq!(hash.lookup("entry_10000"), None);
    }

    // A header with version 2 (current is 3) must be rejected.
    #[test]
    fn test_version_mismatch() {
        let mut buffer = vec![0u8; 128];
        buffer[0..4].copy_from_slice(b"LHSH");
        buffer[4..8].copy_from_slice(&2u32.to_le_bytes());

        let result = LiteralHash::from_buffer(&buffer, MatchMode::CaseSensitive);
        match result {
            Err(LiteralHashError::InvalidFormat(msg)) => {
                assert!(msg.contains("Unsupported version"), "got: {msg}");
            }
            Ok(_) => panic!("expected version mismatch error"),
        }
    }

    // num_shards above the 256 cap must be rejected at parse time.
    #[test]
    fn test_num_shards_limit() {
        let mut buffer = vec![0u8; 128];
        buffer[0..4].copy_from_slice(b"LHSH");
        buffer[4..8].copy_from_slice(&3u32.to_le_bytes());
        buffer[16..20].copy_from_slice(&1000u32.to_le_bytes());

        let result = LiteralHash::from_buffer(&buffer, MatchMode::CaseSensitive);
        assert!(
            matches!(result, Err(LiteralHashError::InvalidFormat(msg)) if msg.contains("exceeds maximum"))
        );
    }

    // Unknown pattern ids yield None from the mapping scan.
    #[test]
    fn test_get_data_offset_not_found() {
        let mut builder = LiteralHashBuilder::new(MatchMode::CaseSensitive);
        builder.add_pattern("test", 0);
        let bytes = builder.build(&[(0, 100)]).unwrap();
        let hash = LiteralHash::from_buffer(&bytes, MatchMode::CaseSensitive).unwrap();

        assert_eq!(hash.get_data_offset(999), None);
    }

    // compute_hash must never emit the reserved empty-slot sentinel pair.
    #[test]
    fn test_empty_marker_not_returned_by_hash() {
        for i in 0..1000 {
            let (lo, hi) = compute_hash(&format!("test_string_{i}"));
            assert!(
                !(lo == EMPTY_HASH_LO && hi == EMPTY_HASH_HI),
                "compute_hash returned empty marker for input {i}"
            );
        }
    }

    // Layout invariants: compile-time struct sizes plus runtime 8-byte
    // alignment of the table offset and of every entry within the buffer.
    #[test]
    fn test_binary_format_alignment() {
        const _: () = assert!(mem::size_of::<LiteralHashHeader>() == 32);
        const _: () = assert!(mem::size_of::<HashEntry>() == 16);
        const _: () = assert!(mem::align_of::<HashEntry>() == 8);

        let mut builder = LiteralHashBuilder::new(MatchMode::CaseSensitive);
        for i in 0..100 {
            builder.add_pattern(&format!("pattern_{i}"), i);
        }
        let pattern_data: Vec<_> = (0..100).map(|i| (i, i * 10)).collect();
        let bytes = builder.build(&pattern_data).unwrap();

        let table_offset = u32::from_le_bytes(bytes[28..32].try_into().unwrap()) as usize;
        assert!(
            table_offset.is_multiple_of(8),
            "table_offset {table_offset} not 8-byte aligned"
        );

        let table_size = u32::from_le_bytes(bytes[12..16].try_into().unwrap()) as usize;
        for i in 0..table_size {
            let entry_offset = table_offset + i * mem::size_of::<HashEntry>();
            assert!(
                entry_offset.is_multiple_of(8),
                "entry {i} at offset {entry_offset} not 8-byte aligned"
            );
        }

        let hash = LiteralHash::from_buffer(&bytes, MatchMode::CaseSensitive).unwrap();
        for i in 0..100 {
            assert_eq!(hash.lookup(&format!("pattern_{i}")), Some(i));
        }
    }
}