use std::collections::{HashSet, VecDeque};
use std::fs::File;
use numpy::{PyArray, PyArray1, PyArray2};
use pyo3::class::*;
use pyo3::exceptions::{PyIndexError, PyKeyError};
use pyo3::prelude::*;
use pyo3::types::{PyAny, PyBytes, PyDict, PyList, PyTuple, PyType};
use pyo3::{create_exception, wrap_pyfunction};
use crate::distance::{distance, minmer_matrix};
use crate::errors::FinchResult;
use crate::filtering::FilterParams;
use crate::serialization::{write_finch_file, Sketch as SketchRs};
use crate::sketch_schemes::{KmerCount, SketchParams};
use crate::{bail, open_sketch_file, sketch_files as rs_sketch_files};
// Declares `FinchError`, a Python exception type in the `finch` module that
// derives from Python's base `Exception`. The bindings in this file raise it
// when an underlying Rust operation returns an error.
create_exception!(finch, FinchError, pyo3::exceptions::PyException);
/// Unwraps a `Result` inside a function that returns `PyResult`.
///
/// On `Err`, the error's `Display` text becomes the message of a Python
/// `FinchError`, which is returned early from the enclosing function via `?`.
macro_rules! py_try {
    ($call:expr) => {
        $call.map_err(|err| PyErr::new::<FinchError, _>(err.to_string()))?
    };
}
/// Merges the hashes of `other` into `sketch`, in place.
///
/// The two hash lists are walked in lockstep, always advancing the side with
/// the smaller `hash`. A hash found on both sides is emitted once, keeping the
/// `kmer` and `label` from `sketch` and summing `count` and `extra_count`.
/// NOTE(review): this walk presumes both `hashes` vectors are sorted ascending
/// by `hash` — confirm against the sketching code.
///
/// The walk stops as soon as either list is exhausted, so the remaining tail
/// of the longer list is *not* copied into the result.
/// NOTE(review): that looks deliberate for bottom-k sketches, but it means
/// merging with an empty sketch yields no hashes — confirm this is intended.
///
/// After merging, the result is clipped depending on `size` and on the scale
/// reported by `sketch.sketch_params.hash_info().3`:
/// * both set: keep leading hashes while the hash is within the scaled maximum
///   or fewer than `size` hashes have been kept;
/// * only scale: keep leading hashes within the scaled maximum;
/// * only `size`: truncate to `size` hashes;
/// * neither: keep everything.
///
/// # Errors
///
/// Returns an error when `check_compatibility` reports a mismatching sketch
/// parameter. In that case `sketch` is left completely unmodified.
fn merge_sketches(sketch: &mut SketchRs, other: &SketchRs, size: Option<usize>) -> FinchResult<()> {
    // Validate before touching any field so a failed merge cannot leave
    // `sketch` with updated totals but stale hashes.
    if let Some((name, v1, v2)) = sketch
        .sketch_params
        .check_compatibility(&other.sketch_params)
    {
        bail!(
            "First sketch has {} {}, but second sketch has {0} {}",
            name,
            v1,
            v2,
        );
    }

    sketch.seq_length += other.seq_length;
    sketch.num_valid_kmers += other.num_valid_kmers;

    let sketch1 = &sketch.hashes;
    let sketch2 = &other.hashes;
    let mut new_hashes = Vec::with_capacity(sketch1.len() + sketch2.len());
    let (mut i, mut j) = (0, 0);
    while (i < sketch1.len()) && (j < sketch2.len()) {
        if sketch1[i].hash < sketch2[j].hash {
            new_hashes.push(sketch1[i].clone());
            i += 1;
        } else if sketch2[j].hash < sketch1[i].hash {
            new_hashes.push(sketch2[j].clone());
            j += 1;
        } else {
            // Same hash on both sides: emit one entry with combined counts.
            new_hashes.push(KmerCount {
                hash: sketch1[i].hash,
                kmer: sketch1[i].kmer.clone(),
                count: sketch1[i].count + sketch2[j].count,
                extra_count: sketch1[i].extra_count + sketch2[j].extra_count,
                label: sketch1[i].label.clone(),
            });
            i += 1;
            j += 1;
        }
    }

    // Largest hash value allowed by the scale, if the sketch is scaled.
    // `(1. / sc) as u64` is 0 for any scale above 1 (or a negative/NaN scale),
    // which would make the integer division panic; clamp the divisor to 1 so
    // such a scale simply keeps every hash.
    let max_hash = sketch
        .sketch_params
        .hash_info()
        .3
        .map(|sc| u64::max_value() / ((1. / sc) as u64).max(1));

    match (size, max_hash) {
        (Some(s), Some(max_hash)) => {
            new_hashes = new_hashes
                .into_iter()
                .enumerate()
                .take_while(|(ix, h)| (h.hash <= max_hash) || (*ix < s))
                .map(|(_, h)| h)
                .collect();
        }
        (None, Some(max_hash)) => {
            new_hashes = new_hashes
                .into_iter()
                .take_while(|h| h.hash <= max_hash)
                .collect();
        }
        (Some(s), None) => {
            new_hashes.truncate(s);
        }
        (None, None) => {
            // No clipping requested and no scale: keep the full merge.
        }
    }
    sketch.hashes = new_hashes;
    Ok(())
}
/// Python-exposed container holding a list of sketches.
#[pyclass]
pub struct Multisketch {
/// The underlying Rust sketches held by this collection.
pub sketches: Vec<SketchRs>,
}
#[pymethods]
impl Multisketch {
#[classmethod]
pub fn open(_cls: &PyType, filename: &str) -> PyResult<Multisketch> {
Ok(Multisketch {
sketches: py_try!(open_sketch_file(filename)),
})
}
#[classmethod]
pub fn from_sketches(_cls: &PyType, sketches: Vec<PyRef<Sketch>>) -> PyResult<Multisketch> {
let sketches = sketches.iter().map(|s| s.s.clone()).collect();
Ok(Multisketch { sketches })
}
pub fn save(&self, filename: &str) -> PyResult<()> {
let mut out = File::create(&filename)
.map_err(|_| PyErr::new::<FinchError, _>(format!("Could not create {}", filename)))?;
py_try!(write_finch_file(&mut out, &self.sketches));
Ok(())
}
pub fn add(&mut self, sketch: &Sketch) -> PyResult<()> {
self.sketches.push(sketch.s.clone());
Ok(())
}
pub fn best_match(&self, query: &Sketch) -> PyResult<(usize, Sketch)> {
let mut best_sketch: usize = 0;
let mut max_containment: f64 = 0.;
for (ix, sketch) in self.sketches.iter().enumerate() {
let dist = py_try!(distance(&query.s, &sketch, false));
if dist.containment > max_containment {
max_containment = dist.containment;
best_sketch = ix;
}
}
Ok((best_sketch, self.sketches[best_sketch].clone().into()))
}
pub fn filter_to_matches(&mut self, query: &Sketch, threshold: f64) -> PyResult<()> {
let mut filtered_sketches = Vec::new();
for sketch in &self.sketches {
let dist = py_try!(distance(&query.s, &sketch, false));
if dist.containment >= threshold {
filtered_sketches.push(sketch.clone());
}
}
self.sketches = filtered_sketches;
Ok(())
}
pub fn filter_to_names(&mut self, names: &PyList) -> PyResult<()> {
let sketch_names: Vec<&str> = names.extract()?;
let name_set: HashSet<&str> = sketch_names.into_iter().collect();
self.sketches
.retain(|s| name_set.contains::<str>(s.name.as_ref()));
Ok(())
}
}
#[pyproto]
impl PyIterProtocol for Multisketch {
fn __iter__(slf: PyRefMut<Self>) -> PyResult<SketchIter> {
let sketches = slf.sketches.iter().map(|s| s.clone().into()).collect();
Ok(SketchIter { sketches })
}
}
/// Iterator object returned by `Multisketch.__iter__`.
#[pyclass]
pub struct SketchIter {
// Sketches still to be yielded; `__next__` pops them from the front.
sketches: VecDeque<Sketch>,
}
#[pyproto]
impl PyIterProtocol for SketchIter {
    /// Yields the next queued sketch, or `None` once the queue is empty.
    fn __next__(mut slf: PyRefMut<Self>) -> PyResult<Option<Sketch>> {
        let upcoming = slf.sketches.pop_front();
        Ok(upcoming)
    }
}
#[pyproto]
impl PyObjectProtocol for Multisketch {
    /// Renders as `<Multisketch (N sketch)>` when exactly one sketch is held
    /// and `<Multisketch (N sketches)>` otherwise.
    fn __repr__(&self) -> PyResult<String> {
        let count = self.sketches.len();
        let noun = match count {
            1 => "sketch",
            _ => "sketches",
        };
        Ok(format!("<Multisketch ({} {})>", count, noun))
    }
}
/// Resolves a Python subscript `key` to a position in `sketches`.
///
/// * An integer key is treated as a sequence index. Negative values count from
///   the end (`-1` is the last sketch). Anything outside `-len..len` yields
///   `IndexError("index out of range")`.
/// * A string key selects the first sketch whose `name` equals it; if none
///   matches, `KeyError` carrying the key is returned.
/// * Any other key type yields `FinchError`.
///
/// Integer extraction is attempted before string extraction.
#[inline]
fn _get_sketch_index(sketches: &[SketchRs], key: &PyAny) -> PyResult<usize> {
    if let Ok(int_key) = key.extract::<isize>() {
        let len = sketches.len() as isize;
        // Negative indices are offsets from the end, so they are shifted by
        // adding the length (subtracting would always land past the end).
        let resolved = if int_key < 0 { int_key + len } else { int_key };
        if (0..len).contains(&resolved) {
            Ok(resolved as usize)
        } else {
            Err(PyErr::new::<PyIndexError, _>("index out of range"))
        }
    } else if let Ok(str_key) = key.extract::<&str>() {
        sketches
            .iter()
            .position(|s| s.name == str_key)
            .ok_or_else(|| PyErr::new::<PyKeyError, _>(str_key.to_string()))
    } else {
        Err(PyErr::new::<FinchError, _>(
            "key is not a string or integer",
        ))
    }
}
#[pyproto]
impl PyMappingProtocol for Multisketch {
fn __len__(&self) -> PyResult<usize> {
Ok(self.sketches.len())
}
fn __getitem__(&self, key: &PyAny) -> PyResult<Sketch> {
let idx = _get_sketch_index(&self.sketches, key)?;
Ok(self.sketches[idx].clone().into())
}
fn __delitem__(&mut self, key: &PyAny) -> PyResult<()> {
let idx = _get_sketch_index(&self.sketches, key)?;
self.sketches.remove(idx);
Ok(())
}
}
#[pyproto]
impl PySequenceProtocol for Multisketch {
    /// True when some held sketch has a name exactly equal to `key`.
    fn __contains__(&self, key: &str) -> PyResult<bool> {
        let found = self.sketches.iter().any(|candidate| candidate.name == key);
        Ok(found)
    }
}
/// Python-exposed wrapper around a single Rust sketch.
#[pyclass]
pub struct Sketch {
/// The wrapped Rust sketch.
pub s: SketchRs,
}
#[pymethods]
impl Sketch {
    /// Creates an empty sketch called `name`.
    ///
    /// The sketch starts with no hashes, zero sequence length, an empty
    /// comment, default filter parameters and fixed Mash parameters
    /// (1000 k-mers to sketch, final size 1000, k-mer length 21, seed 0,
    /// `no_strict` enabled).
    #[new]
    fn new(name: &str) -> Self {
        let sketch_params = SketchParams::Mash {
            kmers_to_sketch: 1000,
            final_size: 1000,
            no_strict: true,
            kmer_length: 21,
            hash_seed: 0,
        };
        let s = SketchRs {
            name: name.to_string(),
            seq_length: 0,
            num_valid_kmers: 0,
            comment: String::new(),
            hashes: Vec::new(),
            sketch_params,
            filter_params: FilterParams::default(),
        };
        Sketch { s }
    }

    /// The sketch's name.
    #[getter]
    fn get_name(&self) -> PyResult<String> {
        Ok(self.s.name.clone())
    }

    /// Replaces the sketch's name.
    #[setter]
    fn set_name(&mut self, value: &str) -> PyResult<()> {
        self.s.name = value.to_string();
        Ok(())
    }

    /// The stored `seq_length` value.
    #[getter]
    fn get_seq_length(&self) -> PyResult<u64> {
        Ok(self.s.seq_length)
    }

    /// The stored `num_valid_kmers` value.
    #[getter]
    fn get_num_valid_kmers(&self) -> PyResult<u64> {
        Ok(self.s.num_valid_kmers)
    }

    /// The sketch's free-text comment.
    #[getter]
    fn get_comment(&self) -> PyResult<String> {
        Ok(self.s.comment.clone())
    }

    /// Replaces the sketch's comment.
    #[setter]
    fn set_comment(&mut self, value: &str) -> PyResult<()> {
        self.s.comment = value.to_string();
        Ok(())
    }

    /// Returns one `(hash, kmer, count, extra_count)` tuple per stored hash,
    /// with `kmer` as a Python `bytes` object.
    #[getter]
    fn get_hashes(&self) -> PyResult<Vec<(u64, PyObject, u32, u32)>> {
        let gil = Python::acquire_gil();
        let py = gil.python();
        // Iterate by reference: only the k-mer bytes need copying (into the
        // Python `bytes`), so cloning the whole hash vector first is wasted work.
        self.s
            .hashes
            .iter()
            .map(|entry| {
                Ok((
                    entry.hash,
                    PyBytes::new(py, &entry.kmer).into(),
                    entry.count,
                    entry.extra_count,
                ))
            })
            .collect()
    }

    /// Returns the sketch parameters as a dict.
    ///
    /// The `"sketch_type"` key is `"mash"`, `"scaled"` or `"none"`; the other
    /// keys are the fields of the corresponding `SketchParams` variant.
    #[getter]
    pub fn get_sketch_params(&self, py: Python) -> PyResult<PyObject> {
        let ret = PyDict::new(py);
        match self.s.sketch_params {
            SketchParams::Mash {
                kmers_to_sketch,
                final_size,
                no_strict,
                kmer_length,
                hash_seed,
            } => {
                ret.set_item("sketch_type", "mash")?;
                ret.set_item("kmers_to_sketch", kmers_to_sketch)?;
                ret.set_item("final_size", final_size)?;
                ret.set_item("no_strict", no_strict)?;
                ret.set_item("kmer_length", kmer_length)?;
                ret.set_item("hash_seed", hash_seed)?;
            }
            SketchParams::Scaled {
                kmers_to_sketch,
                kmer_length,
                scale,
                hash_seed,
            } => {
                ret.set_item("sketch_type", "scaled")?;
                ret.set_item("kmers_to_sketch", kmers_to_sketch)?;
                ret.set_item("kmer_length", kmer_length)?;
                ret.set_item("scale", scale)?;
                ret.set_item("hash_seed", hash_seed)?;
            }
            SketchParams::AllCounts { kmer_length } => {
                ret.set_item("sketch_type", "none")?;
                ret.set_item("kmer_length", kmer_length)?;
            }
        }
        Ok(ret.to_object(py))
    }

    /// Merges `sketch` into this sketch in place via `merge_sketches`,
    /// optionally clipping the result to `size` hashes.
    ///
    /// Raises `FinchError` if the merge fails.
    pub fn merge(&mut self, sketch: &Sketch, size: Option<usize>) -> PyResult<()> {
        py_try!(merge_sketches(&mut self.s, &sketch.s, size));
        Ok(())
    }

    /// Returns `(containment, jaccard)` from `distance`, called with `sketch`
    /// as its first argument and this sketch as its second.
    ///
    /// `old_mode` is forwarded to `distance` unchanged and defaults to false.
    #[args(old_mode = false)]
    pub fn compare(&self, sketch: &Sketch, old_mode: bool) -> PyResult<(f64, f64)> {
        let dist = py_try!(distance(&sketch.s, &self.s, old_mode));
        Ok((dist.containment, dist.jaccard))
    }

    /// Compares hash counts between this sketch (the reference) and `sketch`
    /// (the query).
    ///
    /// Both hash lists are walked in lockstep, advancing the side with the
    /// smaller hash, until either is exhausted.
    /// NOTE(review): presumes both lists are sorted ascending by hash.
    ///
    /// Returns a tuple of:
    /// 1. the number of hashes present in both sketches,
    /// 2. how far the walk advanced into the reference list,
    /// 3. how far the walk advanced into the query list,
    /// 4. the sum of reference counts over shared hashes,
    /// 5. the sum of query counts over shared hashes,
    /// 6. the variance of the query counts over shared hashes,
    /// 7. their skewness,
    /// 8. their kurtosis minus 3.
    ///
    /// The three statistics are NaN when no hashes are shared, because each
    /// divides zero by zero.
    pub fn compare_counts(
        &self,
        sketch: &Sketch,
    ) -> PyResult<(u64, u64, u64, u64, u64, f64, f64, f64)> {
        let reference = &self.s.hashes;
        let query = &sketch.s.hashes;
        let mut common: u64 = 0;
        let mut ref_pos: usize = 0;
        let mut ref_count: u64 = 0;
        let mut query_pos: usize = 0;
        let mut query_count: u64 = 0;
        // Running mean and 2nd/3rd/4th central moment sums of the query counts.
        let mut query_mean: f64 = 0.;
        let mut query_m2: f64 = 0.;
        let mut query_m3: f64 = 0.;
        let mut query_m4: f64 = 0.;
        while (ref_pos < reference.len()) && (query_pos < query.len()) {
            if reference[ref_pos].hash < query[query_pos].hash {
                ref_pos += 1;
            } else if query[query_pos].hash < reference[ref_pos].hash {
                query_pos += 1;
            } else {
                ref_count += u64::from(reference[ref_pos].count);
                query_count += u64::from(query[query_pos].count);

                // Single-pass moment update. The order matters: m4 must read
                // the previous m2 and m3, and m3 the previous m2, before those
                // accumulators are themselves advanced.
                let n = common as f64 + 1.;
                let float_count = f64::from(query[query_pos].count);
                let delta: f64 = float_count - query_mean;
                let delta_n: f64 = delta / n;
                let delta_n2: f64 = delta_n * delta_n;
                let term1 = delta * delta_n * (n - 1.);
                query_mean += delta_n;
                query_m4 += term1 * delta_n2 * (n * n - 3. * n + 3.) + 6. * delta_n2 * query_m2
                    - 4. * delta_n * query_m3;
                query_m3 += term1 * delta_n * (n - 2.) - 3. * delta_n * query_m2;
                query_m2 += term1;

                ref_pos += 1;
                query_pos += 1;
                common += 1;
            }
        }
        let var = query_m2 / common as f64;
        let skew = (common as f64).sqrt() * query_m3 / query_m2.powf(1.5);
        let kurt = (common as f64) * query_m4 / (query_m2 * query_m2) - 3.;
        Ok((
            common,
            ref_pos as u64,
            query_pos as u64,
            ref_count,
            query_count,
            var,
            skew,
            kurt,
        ))
    }

    /// Passes this sketch's hashes and those of every sketch given as a
    /// positional argument to `minmer_matrix`, returning the result as a 2-D
    /// NumPy `int32` array.
    #[args(args = "*")]
    pub fn compare_matrix(&self, args: &PyTuple) -> PyResult<Py<PyArray2<i32>>> {
        let sketches: Vec<PyRef<Sketch>> = args.extract()?;
        let sketch_kmers: Vec<&[KmerCount]> = sketches.iter().map(|s| &s.s.hashes[..]).collect();
        let result = minmer_matrix(&self.s.hashes, &sketch_kmers);
        let gil = Python::acquire_gil();
        let py = gil.python();
        Ok(PyArray::from_owned_array(py, result).to_owned())
    }

    /// Returns the `count` of every stored hash, in order, as a 1-D NumPy
    /// `int32` array. Counts are converted with an `as i32` cast.
    #[getter]
    pub fn get_counts(&self) -> PyResult<Py<PyArray1<i32>>> {
        let result = self.s.hashes.iter().map(|k| k.count as i32);
        let gil = Python::acquire_gil();
        let py = gil.python();
        Ok(PyArray::from_exact_iter(py, result).to_owned())
    }

    /// Replaces the per-hash counts with `value`.
    ///
    /// `value` must have exactly one entry per stored hash. Hashes given a
    /// count of zero are dropped from the sketch. If the length differs or any
    /// entry is negative, `FinchError` is raised and the sketch is unchanged.
    #[setter]
    pub fn set_counts(&mut self, value: &PyArray1<i32>) -> PyResult<()> {
        let val: Vec<i32> = value.extract()?;
        if val.len() != self.s.hashes.len() {
            return Err(PyErr::new::<FinchError, _>(
                "counts must be same length as sketch",
            ));
        }
        let mut new_hashes = Vec::with_capacity(val.len());
        for (existing, &v) in self.s.hashes.iter().zip(val.iter()) {
            if v < 0 {
                return Err(PyErr::new::<FinchError, _>(format!(
                    "Negative count {} not supported",
                    v
                )));
            } else if v > 0 {
                let mut new_s = existing.clone();
                new_s.count = v as u32;
                new_hashes.push(new_s);
            }
        }
        // Swapped in only after every entry was validated.
        self.s.hashes = new_hashes;
        Ok(())
    }

    /// Returns an independent copy of this sketch.
    pub fn copy(&self) -> PyResult<Sketch> {
        Ok(Sketch { s: self.s.clone() })
    }
}
#[pyproto]
impl PyObjectProtocol for Sketch {
    /// Renders as `<Sketch "NAME">`, with the name inserted verbatim.
    fn __repr__(&self) -> PyResult<String> {
        let rendered = format!("<Sketch \"{}\">", self.s.name);
        Ok(rendered)
    }
}
#[pyproto]
impl PyMappingProtocol for Sketch {
    /// Reports whatever length the wrapped sketch's own `len()` returns.
    fn __len__(&self) -> PyResult<usize> {
        let size = self.s.len();
        Ok(size)
    }
}
impl From<SketchRs> for Sketch {
fn from(s: SketchRs) -> Self {
Sketch { s }
}
}
#[pyfunction(n_hashes = 1000, kmer_length = 21, filter = true, seed = 0)]
pub fn sketch_file(
filename: &str,
n_hashes: usize,
final_size: Option<usize>,
kmer_length: u8,
filter: bool,
seed: u64,
) -> PyResult<Sketch> {
let sketch_params = SketchParams::Mash {
kmers_to_sketch: n_hashes,
final_size: final_size.unwrap_or(n_hashes),
no_strict: false,
kmer_length,
hash_seed: seed,
};
let filters = FilterParams {
filter_on: Some(filter),
abun_filter: (None, None),
err_filter: 1.,
strand_filter: 0.1,
};
let mut sketches = py_try!(rs_sketch_files(&[filename], &sketch_params, &filters));
Ok(Sketch {
s: sketches.pop().unwrap(),
})
}
/// Initializes the `finch` Python extension module.
///
/// Registers the `Multisketch` and `Sketch` classes, the `sketch_file`
/// function and the `FinchError` exception type, in that order.
#[pymodule]
fn finch(py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<Multisketch>()?;
m.add_class::<Sketch>()?;
m.add_wrapped(wrap_pyfunction!(sketch_file))?;
// Expose the exception class so Python callers can catch `finch.FinchError`.
m.add("FinchError", py.get_type::<FinchError>())?;
Ok(())
}