use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use std::fs::{File, OpenOptions};
use std::time::Instant;
use crate::mmap_slice::{MmapSlice, MmapSliceMut};
use crate::par_quicksort::par_sort_unstable_by_key;
use crate::table::SuffixTable;
/// Python-visible index that wraps a [`SuffixTable`] whose two backing
/// buffers are both `MmapSlice`s rather than in-memory vectors.
#[pyclass]
pub struct MemmapIndex {
// First type parameter is the text storage (`u16` elements), second is the
// table storage (`u64` elements); this is the order in which they are handed
// to `SuffixTable::from_parts`.
table: SuffixTable<MmapSlice<u16>, MmapSlice<u64>>,
}
#[pymethods]
impl MemmapIndex {
#[new]
pub fn new(_py: Python, text_path: String, table_path: String) -> PyResult<Self> {
let text_file = File::open(&text_path)?;
let table_file = File::open(&table_path)?;
Ok(MemmapIndex {
table: SuffixTable::from_parts(
MmapSlice::new(&text_file)?,
MmapSlice::new(&table_file)?,
),
})
}
#[staticmethod]
pub fn build(text_path: String, table_path: String, verbose: bool) -> PyResult<Self> {
let text_mmap = MmapSlice::new(&File::open(&text_path)?)?;
let table_file = OpenOptions::new()
.create(true)
.read(true)
.write(true)
.open(&table_path)?;
let table_size = text_mmap.len() * 8;
table_file.set_len(table_size as u64)?;
println!("Writing indices to disk...");
let start = Instant::now();
let mut table_mmap = MmapSliceMut::<u64>::new(&table_file)?;
table_mmap
.iter_mut()
.enumerate()
.for_each(|(i, x)| *x = i as u64);
assert_eq!(table_mmap.len(), text_mmap.len());
println!("Time elapsed: {:?}", start.elapsed());
let start = Instant::now();
let scale = (text_mmap.len() as f64) / 5e9; let stack_size = scale.log2().max(1.0) * 8e6;
rayon::ThreadPoolBuilder::new()
.stack_size(stack_size as usize)
.build()
.unwrap()
.install(|| {
println!("Sorting indices...");
par_sort_unstable_by_key(
table_mmap.as_slice_mut(),
|&i| &text_mmap[i as usize..],
verbose,
);
});
println!("Time elapsed: {:?}", start.elapsed());
let table_mmap = MmapSlice::new(&table_file)?;
Ok(MemmapIndex {
table: SuffixTable::from_parts(text_mmap, table_mmap),
})
}
pub fn contains(&self, query: Vec<u16>) -> bool {
self.table.contains(&query)
}
pub fn count(&self, query: Vec<u16>) -> usize {
self.table.positions(&query).len()
}
pub fn positions(&self, query: Vec<u16>) -> Vec<u64> {
self.table.positions(&query).to_vec()
}
pub fn count_next(&self, query: Vec<u16>, vocab: Option<u16>) -> Vec<usize> {
self.table.count_next(&query, vocab)
}
pub fn batch_count_next(&self, queries: Vec<Vec<u16>>, vocab: Option<u16>) -> Vec<Vec<usize>> {
self.table.batch_count_next(&queries, vocab)
}
pub fn sample(&self, query: Vec<u16>, n: usize, k: usize) -> Result<Vec<u16>, PyErr> {
self.table
.sample(&query, n, k)
.map_err(|error| PyValueError::new_err(error.to_string()))
}
pub fn batch_sample(
&self,
query: Vec<u16>,
n: usize,
k: usize,
num_samples: usize,
) -> Result<Vec<Vec<u16>>, PyErr> {
self.table
.batch_sample(&query, n, k, num_samples)
.map_err(|error| PyValueError::new_err(error.to_string()))
}
pub fn is_sorted(&self) -> bool {
self.table.is_sorted()
}
}