use pyo3::prelude::*;
use crate::encoding::{CategoryFilter, CategoryMapping, CountMinSketch};
/// Python-facing wrapper around the crate's [`CountMinSketch`].
///
/// Exposed to Python under the class name `CountMinSketch`.
#[pyclass(name = "CountMinSketch")]
pub struct PyCountMinSketch {
// The wrapped Rust sketch that holds the actual state.
inner: CountMinSketch,
}
#[pymethods]
impl PyCountMinSketch {
    /// Create a sketch from an error bound `eps` and a `confidence` level.
    ///
    /// Defaults: `eps=0.001`, `confidence=0.99`.
    ///
    /// Raises `ValueError` if `eps` is NaN, non-positive or infinite, or if
    /// `confidence` is NaN or not strictly between 0 and 1.
    #[new]
    #[pyo3(signature = (eps=0.001, confidence=0.99))]
    fn new(eps: f64, confidence: f64) -> PyResult<Self> {
        // Every ordered comparison with NaN is false, so `eps <= 0.0` alone
        // would let NaN through; test for it explicitly.
        if eps.is_nan() || eps <= 0.0 {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "eps must be positive",
            ));
        }
        // +inf is "positive" but is not a usable error bound.
        if eps.is_infinite() {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "eps must be finite",
            ));
        }
        // Same NaN caveat as above: neither range comparison fires for NaN.
        if confidence.is_nan() || confidence <= 0.0 || confidence >= 1.0 {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "confidence must be in (0, 1)",
            ));
        }
        Ok(Self {
            inner: CountMinSketch::new(eps, confidence),
        })
    }

    /// Forward `hash` to the wrapped sketch's `inc`.
    fn inc(&mut self, hash: u64) {
        self.inner.inc(hash);
    }

    /// Forward `hash` and `count` to the wrapped sketch's `inc_by`.
    fn inc_by(&mut self, hash: u64, count: u64) {
        self.inner.inc_by(hash, count);
    }

    /// Return the wrapped sketch's `estimate` for `hash`.
    fn estimate(&self, hash: u64) -> u64 {
        self.inner.estimate(hash)
    }

    /// Call `clear` on the wrapped sketch.
    fn clear(&mut self) {
        self.inner.clear();
    }

    /// Call `halve` on the wrapped sketch.
    fn halve(&mut self) {
        self.inner.halve();
    }

    /// Width reported by the wrapped sketch.
    #[getter]
    fn width(&self) -> usize {
        self.inner.width()
    }

    /// Depth reported by the wrapped sketch.
    #[getter]
    fn depth(&self) -> usize {
        self.inner.depth()
    }

    /// Memory usage in bytes as reported by the wrapped sketch.
    #[getter]
    fn memory_bytes(&self) -> usize {
        self.inner.memory_bytes()
    }

    /// Debug representation showing width, depth and memory usage.
    fn __repr__(&self) -> String {
        format!(
            "CountMinSketch(width={}, depth={}, memory={}B)",
            self.inner.width(),
            self.inner.depth(),
            self.inner.memory_bytes()
        )
    }
}
/// Python-facing wrapper around the crate's [`CategoryFilter`].
///
/// Exposed to Python under the class name `CategoryFilter`.
#[pyclass(name = "CategoryFilter")]
pub struct PyCategoryFilter {
// The wrapped Rust filter that holds the actual state.
inner: CategoryFilter,
}
#[pymethods]
impl PyCategoryFilter {
    /// Create a filter from `eps`, `confidence` and `min_count`.
    ///
    /// Defaults: `eps=0.001`, `confidence=0.99`, `min_count=5`. All three
    /// values are passed through to `CategoryFilter::new` after validation.
    ///
    /// Raises `ValueError` if `eps` is NaN, non-positive or infinite, or if
    /// `confidence` is NaN or not strictly between 0 and 1.
    #[new]
    #[pyo3(signature = (eps=0.001, confidence=0.99, min_count=5))]
    fn new(eps: f64, confidence: f64, min_count: u64) -> PyResult<Self> {
        // Every ordered comparison with NaN is false, so `eps <= 0.0` alone
        // would let NaN through; test for it explicitly.
        if eps.is_nan() || eps <= 0.0 {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "eps must be positive",
            ));
        }
        // +inf is "positive" but is not a usable error bound.
        if eps.is_infinite() {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "eps must be finite",
            ));
        }
        // Same NaN caveat as above: neither range comparison fires for NaN.
        if confidence.is_nan() || confidence <= 0.0 || confidence >= 1.0 {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "confidence must be in (0, 1)",
            ));
        }
        Ok(Self {
            inner: CategoryFilter::new(eps, confidence, min_count),
        })
    }

    /// Build a filter via `CategoryFilter::default_for_high_cardinality`.
    #[staticmethod]
    fn default_for_high_cardinality() -> Self {
        Self {
            inner: CategoryFilter::default_for_high_cardinality(),
        }
    }

    /// Forward a single `category` to the wrapped filter's `count`.
    fn count(&mut self, category: &str) {
        self.inner.count(category);
    }

    /// Forward every entry of `categories`, in order, to the wrapped
    /// filter's `count`.
    fn count_batch(&mut self, categories: Vec<String>) {
        for cat in &categories {
            self.inner.count(cat);
        }
    }

    /// Hand `unique_categories` to the wrapped filter's `finalize`.
    fn finalize(&mut self, unique_categories: Vec<String>) {
        self.inner.finalize(unique_categories);
    }

    /// Return the wrapped filter's `is_frequent` verdict for `category`.
    fn is_frequent(&self, category: &str) -> bool {
        self.inner.is_frequent(category)
    }

    /// Return the wrapped filter's `estimate_count` for `category`.
    fn estimate_count(&self, category: &str) -> u64 {
        self.inner.estimate_count(category)
    }

    /// Return the wrapped filter's `filter` result for `category` as an
    /// owned string.
    fn filter(&self, category: &str) -> String {
        self.inner.filter(category).to_string()
    }

    /// Apply `filter` to each entry of `categories`, preserving order.
    fn filter_batch(&self, categories: Vec<String>) -> Vec<String> {
        categories
            .iter()
            .map(|c| self.inner.filter(c).to_string())
            .collect()
    }

    /// Value of the wrapped filter's `num_frequent`.
    #[getter]
    fn num_frequent(&self) -> usize {
        self.inner.num_frequent()
    }

    /// Owned copies of the entries returned by the wrapped filter's
    /// `frequent_categories`.
    fn frequent_categories(&self) -> Vec<String> {
        self.inner.frequent_categories().iter().cloned().collect()
    }

    /// Memory usage in bytes as reported by the wrapped filter.
    #[getter]
    fn memory_bytes(&self) -> usize {
        self.inner.memory_bytes()
    }

    /// Build a `CategoryMapping` from this filter via
    /// `CategoryMapping::from_filter`.
    fn to_mapping(&self) -> PyCategoryMapping {
        PyCategoryMapping {
            inner: CategoryMapping::from_filter(&self.inner),
        }
    }

    /// Debug representation showing the frequent-category count and memory
    /// usage.
    fn __repr__(&self) -> String {
        format!(
            "CategoryFilter(num_frequent={}, memory={}B)",
            self.inner.num_frequent(),
            self.inner.memory_bytes()
        )
    }
}
/// Python-facing wrapper around the crate's [`CategoryMapping`].
///
/// Exposed to Python under the class name `CategoryMapping`. The wrapper is
/// `Clone`, which clones the wrapped mapping.
#[pyclass(name = "CategoryMapping")]
#[derive(Clone)]
pub struct PyCategoryMapping {
// The wrapped Rust mapping that holds the actual state.
inner: CategoryMapping,
}
#[pymethods]
impl PyCategoryMapping {
    /// Build a mapping from `filter` via `CategoryMapping::from_filter`.
    #[staticmethod]
    fn from_filter(filter: &PyCategoryFilter) -> Self {
        let inner = CategoryMapping::from_filter(&filter.inner);
        Self { inner }
    }

    /// Return the wrapped mapping's `get_index` result for `category`.
    fn get_index(&self, category: &str) -> u32 {
        let mapping = &self.inner;
        mapping.get_index(category)
    }

    /// Look up every entry of `categories` with `get_index`, preserving
    /// order.
    fn get_indices(&self, categories: Vec<String>) -> Vec<u32> {
        let mut indices = Vec::with_capacity(categories.len());
        for name in &categories {
            indices.push(self.inner.get_index(name));
        }
        indices
    }

    /// Value of the wrapped mapping's `unknown_idx` field.
    #[getter]
    fn unknown_idx(&self) -> u32 {
        let mapping = &self.inner;
        mapping.unknown_idx
    }

    /// Value of the wrapped mapping's `num_categories`.
    #[getter]
    fn num_categories(&self) -> usize {
        let mapping = &self.inner;
        mapping.num_categories()
    }

    /// Copy of the wrapped mapping's `category_to_idx` pairs.
    fn items(&self) -> Vec<(String, u32)> {
        let mapping = &self.inner;
        mapping.category_to_idx.clone()
    }

    /// Debug representation showing the category count and the unknown
    /// index.
    fn __repr__(&self) -> String {
        let count = self.inner.num_categories();
        let unknown = self.inner.unknown_idx;
        format!(
            "CategoryMapping(num_categories={}, unknown_idx={})",
            count, unknown
        )
    }

    /// Python `len()` support; same value as `num_categories`.
    fn __len__(&self) -> usize {
        self.num_categories()
    }
}
impl From<CategoryMapping> for PyCategoryMapping {
fn from(mapping: CategoryMapping) -> Self {
Self { inner: mapping }
}
}
/// Adds the three wrapper classes defined in this file to the Python
/// module `m`.
///
/// Classes are added in the order `PyCountMinSketch`, `PyCategoryFilter`,
/// `PyCategoryMapping`; the first failure is returned and later classes are
/// not added.
pub fn register(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_class::<PyCountMinSketch>()
        .and_then(|()| m.add_class::<PyCategoryFilter>())
        .and_then(|()| m.add_class::<PyCategoryMapping>())
}