1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
//! Ngram-based indexing of strings into a binary file.
//!
//! This library can be used to index lots of strings (each identified by a
//! numeric ID) into a file, and then perform a fuzzy search over those strings.
//!
//! Example:
//! ```
//! # use std::io::BufWriter;
//! # use std::fs::File;
//! # use std::path::Path;
//! # let path = Path::new("/tmp/test.db");
//! # use ngram_search::Ngrams;
//! // Build index
//! let mut builder = Ngrams::builder();
//! builder.add("spam", 0);
//! builder.add("ham", 1);
//! builder.add("mam", 2);
//!
//! // Write it to a file
//! let mut file = BufWriter::new(File::create(path).unwrap());
//! builder.write(&mut file).unwrap();
//!
//! // Search our index
//! let mut data = Ngrams::open(path).unwrap();
//! assert_eq!(
//!     data.search("ham", 0.24).unwrap(),
//!     vec![
//!         (1, 1.0), // "ham" is an exact match
//!         (2, 0.25), // "mam" is close
//!     ],
//! );
//! assert_eq!(
//!     data.search("spa", 0.2).unwrap(),
//!     vec![
//!         (0, 0.375), // "spam" is close (3 shared trigrams out of 8)
//!     ],
//! );
//! ```

use byteorder::{self, ReadBytesExt, WriteBytesExt};
use std::collections::HashMap;
use std::fs::File;
use std::io::{Seek, SeekFrom, Write};
use std::path::Path;
use unicode_normalization::UnicodeNormalization;

type Order = byteorder::BigEndian;

struct Leaf {
    id: u32,
    count: u8,
    total_ngrams: u8,
}

struct Branch {
    entries: Vec<Entry>,
    character: u32,
}

enum Entry {
    Branch(Branch),
    Leaf(Leaf),
}

#[cfg(feature = "mmap")]
type Reader = std::io::Cursor<memmap::Mmap>;
#[cfg(not(feature = "mmap"))]
type Reader = std::io::BufReader<File>;

/// Ngrams index of strings backed by a file.
pub struct Ngrams {
    reader: Reader,
}

fn with_trigrams<T, F: FnMut([char; 3]) -> Result<(), T>>(
    string: &str,
    mut f: F,
) -> Result<(), T> {
    // Normalize
    let string = string.to_lowercase();
    let mut chars = string.nfc();

    if string.len() == 0 {
        f(['$', '$', '$'])?;
    } else {
        let mut c1 = '$';
        let mut c2 = '$';
        while let Some(c3) = chars.next() {
            f([c1, c2, c3])?;
            c1 = c2;
            c2 = c3;
        }
        f([c1, c2, '$'])?;
        f([c2, '$', '$'])?;
    }

    Ok(())
}

impl Ngrams {
    /// Return a builder object used to build an index.
    pub fn builder() -> NgramsBuilder {
        Default::default()
    }

    /// Open an index from a file.
    pub fn open(path: &Path) -> std::io::Result<Ngrams> {
        let file = File::open(path)?;
        #[cfg(feature = "mmap")]
        let reader = {
            let data = unsafe { memmap::Mmap::map(&file) }?;
            std::io::Cursor::new(data)
        };
        #[cfg(not(feature = "mmap"))]
        let reader = std::io::BufReader::new(file);
        Ok(Ngrams { reader })
    }

    fn search_ngram(
        &mut self,
        trigram: &[char; 3],
    ) -> std::io::Result<Vec<Leaf>> {
        self.reader.seek(SeekFrom::Start(0))?;
        for character in trigram {
            let character = *character as u32;

            // Check that this is a branch
            if self.reader.read_u8()? != 1 {
                panic!("Invalid branch");
            }

            // Look for the character we need
            let size = self.reader.read_u32::<Order>()?;
            let mut found = None;
            for _ in 0..size {
                let c = self.reader.read_u32::<Order>()?;
                let p = self.reader.read_u32::<Order>()?;
                if c == character {
                    found = Some(p);
                    break;
                }
            }

            // Move down
            match found {
                Some(pos) => {
                    self.reader.seek(SeekFrom::Start(pos as u64))?;
                }
                None => return Ok(Vec::new()),
            }
        }

        // Read leaves
        if self.reader.read_u8()? != 2 {
            panic!("Invalid leaf record");
        }
        let mut leaves = Vec::new();
        let size = self.reader.read_u32::<Order>()?;
        for _ in 0..size {
            let id = self.reader.read_u32::<Order>()?;
            let count = self.reader.read_u8()?;
            let total_ngrams = self.reader.read_u8()?;
            leaves.push(Leaf {
                id,
                count,
                total_ngrams,
            });
        }

        Ok(leaves)
    }

    fn search_trigrams(
        &mut self,
        trigrams: &[([char; 3], u32)],
        threshold: f32,
    ) -> std::io::Result<Vec<(u32, f32)>> {
        let total_ngrams: u32 = trigrams.iter().map(|(_, c)| c).sum();

        // Look for all trigrams
        let hits = trigrams
            .iter()
            .map(|(trigram, count)| {
                (self.search_ngram(trigram).unwrap(), *count)
            })
            .collect::<Vec<_>>();

        // Build a list of results by merging all those hits together
        // (id, (nb_shared_ngrams, total_ngrams)
        let mut matches: Vec<(u32, f32)> = Vec::new();
        let mut positions = Vec::new();
        positions.resize(hits.len(), 0);
        loop {
            // Find the smallest next element and its count
            let mut smallest_id = None;
            let mut match_total_ngrams = 0;
            for i in 0..hits.len() {
                if positions[i] < hits[i].0.len() {
                    let leaf = &hits[i].0[positions[i]];
                    if let Some(s) = smallest_id {
                        if leaf.id < s {
                            smallest_id = Some(leaf.id);
                            match_total_ngrams = leaf.total_ngrams;
                        }
                    } else {
                        smallest_id = Some(leaf.id);
                        match_total_ngrams = leaf.total_ngrams;
                    }
                }
            }

            // No next element: we're done
            let smallest_id = match smallest_id {
                Some(s) => s,
                None => break,
            };

            // Compute the count and move forward in those Vecs
            let mut shared = 0;
            for i in 0..hits.len() {
                if positions[i] < hits[i].0.len() {
                    let leaf = &hits[i].0[positions[i]];
                    if leaf.id == smallest_id {
                        shared += hits[i].1.min(leaf.count as u32);
                        positions[i] += 1;
                    }
                }
            }

            // Compute score
            let allgrams = total_ngrams + match_total_ngrams as u32 - shared;
            let score = shared as f32 / allgrams as f32;

            // Threshold
            if score < threshold {
                continue;
            }

            // Store result
            matches.push((smallest_id, score));
        }

        // Sort results
        matches.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());

        Ok(matches)
    }

    /// Search for a string in the index.
    ///
    /// Returns a vector of pairs `(string id, score)` sorted by descending
    /// scores.
    pub fn search(
        &mut self,
        string: &str,
        threshold: f32,
    ) -> std::io::Result<Vec<(u32, f32)>> {
        let mut trigrams = HashMap::new();
        with_trigrams::<(), _>(string, |chars| {
            *trigrams.entry(chars).or_insert(0) += 1;
            Ok(())
        })
        .unwrap();
        let array = trigrams.into_iter().collect::<Vec<_>>();

        self.search_trigrams(&array, threshold)
    }
}

/// A builder object used to build the index file.
///
/// Note that the index will be held into memory during construction, and is
/// only written to disk when you call write(). Therefore you might need a lot
/// of memory for construction.
#[derive(Default)]
pub struct NgramsBuilder {
    data: Vec<Entry>,
}

impl NgramsBuilder {
    /// Record that string `id` contains `trigram`.
    ///
    /// Walks down the tree one character at a time, creating the missing
    /// branches, then inserts a leaf holding `id`, `count` (occurrences of
    /// the trigram in the string) and `total_ngrams` (number of trigrams in
    /// the string). Branches are kept sorted by character and leaves by id.
    ///
    /// # Panics
    ///
    /// Panics if the tree mixes branches and leaves in one list, which the
    /// builder never produces.
    fn add_trigram_chars(
        &mut self,
        trigram: &[char; 3],
        id: u32,
        count: u8,
        total_ngrams: u8,
    ) {
        let mut data = &mut self.data;
        for &character in trigram {
            let character = character as u32;

            // Branches are only ever inserted at the position returned by
            // `bisect_branches`, so the list is sorted and an existing
            // branch for `character` sits right before the insertion point.
            let insert_at = bisect_branches(data, character);
            let found = insert_at > 0
                && match &data[insert_at - 1] {
                    Entry::Branch(b) => b.character == character,
                    Entry::Leaf(_) => panic!("Found Leaf instead of Branch"),
                };

            // If we didn't find an entry, add it
            let idx = if found {
                insert_at - 1
            } else {
                data.insert(
                    insert_at,
                    Entry::Branch(Branch {
                        character,
                        entries: Vec::new(),
                    }),
                );
                insert_at
            };

            // Go down into that branch
            let branch = match &mut data[idx] {
                Entry::Branch(b) => b,
                Entry::Leaf(_) => {
                    unreachable!("entry was just found or inserted as a branch")
                }
            };
            data = &mut branch.entries;
        }

        // Now insert the leaf, sorted by id
        let idx = bisect_leaves(data, id);
        data.insert(
            idx,
            Entry::Leaf(Leaf {
                id,
                count,
                total_ngrams,
            }),
        );
    }

    /// Add a string to the index.
    ///
    /// The ID is what will be returned when searching for this string in the
    /// index, and should not be used multiple times.
    ///
    /// The index stores trigram counts in a single byte each. For strings
    /// with more than 255 trigrams (or repeating one trigram more than 255
    /// times) the stored counts saturate at 255, so the scores of such
    /// strings are approximate.
    pub fn add(&mut self, string: &str, id: u32) {
        // Count with wide integers: counting directly in `u8` would overflow
        // on long strings.
        let mut trigrams: HashMap<[char; 3], usize> = HashMap::new();
        let mut total_ngrams: usize = 0;
        with_trigrams::<(), _>(string, |chars| {
            *trigrams.entry(chars).or_insert(0) += 1;
            total_ngrams += 1;
            Ok(())
        })
        .expect("the counting closure never fails");

        let saturate = |n: usize| n.min(usize::from(u8::MAX)) as u8;
        let total_ngrams = saturate(total_ngrams);
        for (trigram, count) in trigrams {
            self.add_trigram_chars(
                &trigram,
                id,
                saturate(count),
                total_ngrams,
            );
        }
    }

    /// Write the index to a file.
    ///
    /// Records are appended at the end of `output` and refer to each other
    /// by absolute offset, while searching starts reading at offset 0. The
    /// output should therefore be empty when this is called.
    ///
    /// # Errors
    ///
    /// Returns any error reported while writing to or seeking in `output`.
    pub fn write<W: Write + Seek>(
        &self,
        output: &mut W,
    ) -> std::io::Result<()> {
        write_branch(&self.data, output).map(|_root_pos| ())
    }
}

/// Find where a branch for `character` belongs in a sorted list of branches.
///
/// Returns the index of the first branch whose character is greater than
/// `character` (`data.len()` if there is none). When a branch for
/// `character` already exists, the returned index is the one right after it.
///
/// # Panics
///
/// Panics if a probed entry is a leaf.
fn bisect_branches(data: &[Entry], character: u32) -> usize {
    let (mut lo, mut hi) = (0, data.len());
    while lo < hi {
        let mid = (lo + hi) / 2;
        match &data[mid] {
            // Everything from `mid` on is greater: look in the left half
            Entry::Branch(branch) if character < branch.character => hi = mid,
            // `mid` is smaller or equal: look in the right half
            Entry::Branch(_) => lo = mid + 1,
            Entry::Leaf(_) => panic!("Leaf in the branches"),
        }
    }
    lo
}

/// Find where a leaf for string `id` belongs in a list of leaves sorted by id.
///
/// Returns the index of the first leaf whose id is greater than `id`
/// (`data.len()` if there is none), so a leaf inserted there lands after any
/// leaf with the same id.
///
/// # Panics
///
/// Panics if a probed entry is a branch.
fn bisect_leaves(data: &[Entry], id: u32) -> usize {
    let (mut lo, mut hi) = (0, data.len());
    while lo < hi {
        let mid = (lo + hi) / 2;
        match &data[mid] {
            // Everything from `mid` on is greater: look in the left half
            Entry::Leaf(leaf) if id < leaf.id => hi = mid,
            // `mid` is smaller or equal: look in the right half
            Entry::Leaf(_) => lo = mid + 1,
            Entry::Branch(_) => panic!("Branch in the leaves"),
        }
    }
    lo
}

fn write_branch<W: Write + Seek>(
    entries: &[Entry],
    output: &mut W,
) -> std::io::Result<u64> {
    // Seek to end of stream, save position
    let pos = output.seek(SeekFrom::End(0))?;

    let is_branch = match entries.first() {
        None => panic!("Empty entry"),
        Some(Entry::Branch(_)) => true,
        Some(Entry::Leaf(_)) => false,
    };

    // Tag content
    output.write_all(&[if is_branch { 1u8 } else { 2u8 }])?;

    // Write length
    output.write_u32::<Order>(entries.len() as u32)?;

    let start = pos + 1 + 4;

    if is_branch {
        // Reserve space for our record
        let mut data = Vec::new();
        data.resize((4 + 4) * entries.len(), 0);
        output.write_all(&data)?;

        // Recursively write the entries at the end of the stream, each time
        // updating the entry in our record
        for (i, entry) in entries.iter().enumerate() {
            match entry {
                Entry::Branch(Branch {
                    entries: branch_entries,
                    character,
                }) => {
                    // Recursively write at the end
                    let branch_pos = write_branch(branch_entries, output)?;

                    // Update the entry in our record to point there
                    output
                        .seek(SeekFrom::Start(start + (4 + 4) * (i as u64)))?;
                    output.write_u32::<Order>(*character)?;
                    output.write_u32::<Order>(branch_pos as u32)?;
                }
                Entry::Leaf(_) => panic!("Leaf in a branch"),
            }
        }
    } else {
        // Write the leaves
        for entry in entries {
            match entry {
                Entry::Leaf(Leaf {
                    id,
                    count,
                    total_ngrams,
                }) => {
                    output.write_u32::<Order>(*id)?;
                    output.write_all(&[*count, *total_ngrams])?;
                }
                Entry::Branch(_) => panic!("Branch in a leaf"),
            }
        }
    }

    Ok(pos)
}