#![cfg(feature = "python")]
use std::path::PathBuf;
use pyo3::prelude::*;
use super::FeatureSketch;
use crate::pyutil::{KeyValueLike, StrOrBytes, to_py_err};
/// Python wrapper around the Rust [`FeatureSketch`], exposed as `rcf3.rcf3.FeatureSketch`.
///
/// Constructor keyword arguments (defaults in parentheses): `value_projection_dims` (32),
/// `presence_projection_dims` (32), `chains_per_ensemble` (16), `chain_depth` (8),
/// `sketch_rows` (2), `sketch_buckets` (2048), `decay_half_life` (2048) and an optional
/// `seed`. If the Rust builder rejects the configuration, the error is raised as a Python
/// exception.
///
/// Instances support `copy.copy`, `copy.deepcopy` and pickling (state is carried as JSON).
#[pyclass(name = "FeatureSketch", module = "rcf3.rcf3", skip_from_py_object)]
#[derive(Clone, Debug)]
pub struct PyFeatureSketch {
// The wrapped Rust sketch; every Python-facing method delegates to it.
inner: FeatureSketch,
}
#[pymethods]
impl PyFeatureSketch {
#[new]
#[pyo3(signature = (
value_projection_dims = 32,
presence_projection_dims = 32,
chains_per_ensemble = 16,
chain_depth = 8,
sketch_rows = 2,
sketch_buckets = 2048,
decay_half_life = 2048,
seed = None
))]
#[allow(clippy::too_many_arguments)]
fn py_new(
value_projection_dims: usize,
presence_projection_dims: usize,
chains_per_ensemble: usize,
chain_depth: usize,
sketch_rows: usize,
sketch_buckets: usize,
decay_half_life: u64,
seed: Option<u64>,
) -> PyResult<Self> {
let mut builder = FeatureSketch::builder()
.value_projection_dims(value_projection_dims)
.presence_projection_dims(presence_projection_dims)
.chains_per_ensemble(chains_per_ensemble)
.chain_depth(chain_depth)
.sketch_rows(sketch_rows)
.sketch_buckets(sketch_buckets)
.decay_half_life(decay_half_life);
if let Some(seed) = seed {
builder = builder.seed(seed);
}
let inner = builder.build().map_err(to_py_err)?;
Ok(Self { inner })
}
fn update(&mut self, py: Python<'_>, feature: KeyValueLike) -> PyResult<()> {
py.detach(|| match feature {
KeyValueLike::Pairs(pairs) => self.inner.update(pairs).map_err(to_py_err),
KeyValueLike::Dict(mapping) => self.inner.update(mapping).map_err(to_py_err),
})
}
fn score(&self, py: Python<'_>, feature: KeyValueLike) -> PyResult<f64> {
py.detach(|| match feature {
KeyValueLike::Pairs(pairs) => self.inner.score(pairs).map_err(to_py_err),
KeyValueLike::Dict(mapping) => self.inner.score(mapping).map_err(to_py_err),
})
}
fn update_and_score(&mut self, py: Python<'_>, feature: KeyValueLike) -> PyResult<f64> {
py.detach(|| match feature {
KeyValueLike::Pairs(pairs) => self.inner.update_and_score(pairs).map_err(to_py_err),
KeyValueLike::Dict(mapping) => self.inner.update_and_score(mapping).map_err(to_py_err),
})
}
fn is_ready(&self) -> bool {
self.inner.is_ready()
}
fn entries_seen(&self) -> u64 {
self.inner.entries_seen()
}
fn to_json(&self) -> PyResult<String> {
self.inner.to_json().map_err(to_py_err)
}
#[staticmethod]
fn from_json(json: StrOrBytes) -> PyResult<Self> {
let inner = FeatureSketch::from_json(json).map_err(to_py_err)?;
Ok(Self { inner })
}
fn save_json(&self, path: PathBuf) -> PyResult<()> {
self.inner.save_json(path).map_err(to_py_err)
}
#[staticmethod]
fn load_json(path: PathBuf) -> PyResult<Self> {
let inner = FeatureSketch::load_json(path).map_err(to_py_err)?;
Ok(Self { inner })
}
fn __repr__(&self) -> String {
let c = self.inner.config();
format!(
"FeatureSketch(value_projection_dims={}, presence_projection_dims={}, chains_per_ensemble={}, chain_depth={}, sketch_rows={}, sketch_buckets={}, decay_half_life={}, entries_seen={})",
c.value_projection_dims(),
c.presence_projection_dims(),
c.chains_per_ensemble(),
c.chain_depth(),
c.sketch_rows(),
c.sketch_buckets(),
c.decay_half_life(),
self.inner.entries_seen(),
)
}
fn __str__(&self) -> String {
self.__repr__()
}
fn __copy__(&self) -> Self {
self.clone()
}
#[allow(unused_variables)]
fn __deepcopy__<'py>(&self, memo: Bound<'py, PyAny>) -> Self {
self.clone()
}
fn __getstate__(&self) -> PyResult<String> {
self.to_json()
}
fn __setstate__(&mut self, state: String) -> PyResult<()> {
let new = Self::from_json(state.into())?;
*self = new;
Ok(())
}
}