use pyo3::prelude::*;
use pyo3::types::{PyDict, PyList, PyTuple};
use super::{BaseNgrams, Ngrams};
use crate::trie::CountTrie;
/// N-gram counter exposed to Python under the class name `Ngrams`.
///
/// This is a thin wrapper that owns a Rust [`Ngrams`] value. The `subclass`
/// option allows Python code to derive from the class, and the type is
/// `Clone`, so copies duplicate the wrapped counter.
#[pyclass(name = "Ngrams", subclass, from_py_object)]
#[derive(Clone)]
pub struct PyNgrams {
/// The wrapped Rust n-gram counter that holds the actual state.
pub inner: Ngrams,
}
/// Implements [`BaseNgrams`] for the Python wrapper by forwarding every
/// required accessor to the wrapped [`Ngrams`] stored in `inner`.
impl BaseNgrams for PyNgrams {
    /// Forwards to `inner.order()`.
    fn order(&self) -> usize {
        self.inner.order()
    }

    /// Forwards to `inner.min_order()`.
    fn min_order(&self) -> usize {
        self.inner.min_order()
    }

    /// Shared access to the count trie owned by the wrapped counter.
    fn counts(&self) -> &CountTrie<String> {
        self.inner.counts()
    }

    /// Exclusive access to the count trie owned by the wrapped counter.
    fn counts_mut(&mut self) -> &mut CountTrie<String> {
        self.inner.counts_mut()
    }

    /// Shared access to the totals vector owned by the wrapped counter.
    fn totals(&self) -> &Vec<u64> {
        self.inner.totals()
    }

    /// Exclusive access to the totals vector owned by the wrapped counter.
    fn totals_mut(&mut self) -> &mut Vec<u64> {
        self.inner.totals_mut()
    }

    /// Builds a wrapper around a counter assembled by `Ngrams::from_parts`
    /// from the given components, which are passed through unchanged.
    fn from_parts(
        order: usize,
        min_order: usize,
        counts: CountTrie<String>,
        totals: Vec<u64>,
    ) -> Self {
        let inner = Ngrams::from_parts(order, min_order, counts, totals);
        Self { inner }
    }
}
/// Converts `(ngram, count)` pairs into a Python `list` of
/// `(tuple_of_tokens, count)` entries, preserving the input order.
///
/// # Errors
///
/// Returns the first Python error raised while building a tuple or
/// appending to the list.
fn pairs_to_pylist(py: Python<'_>, pairs: Vec<(Vec<String>, u64)>) -> PyResult<Py<PyAny>> {
    let out = PyList::empty(py);
    pairs
        .into_iter()
        .try_for_each(|(tokens, freq)| -> PyResult<()> {
            // Each n-gram becomes a Python tuple, paired with its count.
            let key = PyTuple::new(py, &tokens)?;
            out.append((key, freq))
        })?;
    Ok(out.into_any().unbind())
}
// Python-visible methods of `Ngrams`. Each one forwards to the `BaseNgrams`
// trait (or to the wrapped counter) and converts results to Python objects.
// Calls are written as `BaseNgrams::name(self, ..)` wherever an inherent
// method of the same name exists, so the trait method is the one invoked.
#[pymethods]
impl PyNgrams {
    /// Create a counter from `n` and the optional keyword-only `min_n`.
    ///
    /// Both arguments are handed to `Ngrams::new`; any error it reports is
    /// converted into a Python exception.
    #[new]
    #[pyo3(signature = (n, *, min_n=None))]
    fn new(n: usize, min_n: Option<usize>) -> PyResult<Self> {
        let inner = Ngrams::new(n, min_n)?;
        Ok(Self { inner })
    }

    /// Count the n-grams of a single token sequence.
    fn count(&mut self, seq: Vec<String>) {
        BaseNgrams::count(self, seq);
    }

    /// Count the n-grams of several token sequences.
    fn count_seqs(&mut self, seqs: Vec<Vec<String>>) {
        BaseNgrams::count_seqs(self, seqs);
    }

    /// Return the stored count for `ngram`.
    fn get(&self, ngram: Vec<String>) -> u64 {
        BaseNgrams::get(self, ngram)
    }

    /// Return a list of `(ngram_tuple, count)` pairs produced by
    /// `most_common_items(n, order)`.
    // NOTE(review): presumably sorted by descending count and limited to `n`
    // entries — confirm against `BaseNgrams::most_common_items`.
    #[pyo3(signature = (n=None, *, order=None))]
    fn most_common(
        &self,
        py: Python<'_>,
        n: Option<usize>,
        order: Option<usize>,
    ) -> PyResult<Py<PyAny>> {
        let ranked = self.most_common_items(n, order)?;
        pairs_to_pylist(py, ranked)
    }

    /// Return a list of `(ngram_tuple, count)` pairs produced by
    /// `items_list(order)`.
    #[pyo3(signature = (*, order=None))]
    fn items(&self, py: Python<'_>, order: Option<usize>) -> PyResult<Py<PyAny>> {
        let entries = self.items_list(order)?;
        pairs_to_pylist(py, entries)
    }

    /// Return the total reported by `BaseNgrams::total` for `order`.
    #[pyo3(signature = (*, order=None))]
    fn total(&self, order: Option<usize>) -> PyResult<u64> {
        Ok(BaseNgrams::total(self, order)?)
    }

    /// The counter's order, as reported by `BaseNgrams::order`.
    #[getter]
    fn n(&self) -> usize {
        BaseNgrams::order(self)
    }

    /// The counter's minimum order, as reported by `BaseNgrams::min_order`.
    #[getter]
    fn min_n(&self) -> usize {
        BaseNgrams::min_order(self)
    }

    /// `self[ngram]`: same lookup as `get`.
    fn __getitem__(&self, ngram: Vec<String>) -> u64 {
        BaseNgrams::get(self, ngram)
    }

    /// `len(self)`: the value of `BaseNgrams::len`.
    fn __len__(&self) -> usize {
        BaseNgrams::len(self)
    }

    /// `ngram in self`: the value of `BaseNgrams::contains`.
    fn __contains__(&self, ngram: Vec<String>) -> bool {
        BaseNgrams::contains(self, ngram)
    }

    /// `iter(self)`: iterate over the n-grams as tuples.
    ///
    /// The n-grams are first copied into a Python list, and the iterator of
    /// that list is what gets returned.
    fn __iter__(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
        let snapshot = PyList::empty(py);
        for tokens in self.all_ngrams() {
            snapshot.append(PyTuple::new(py, &tokens)?)?;
        }
        let iterator = snapshot.call_method0("__iter__")?;
        Ok(iterator.into_any().unbind())
    }

    /// `repr(self)`: the string built by `repr_string`.
    fn __repr__(&self) -> String {
        self.repr_string()
    }

    /// `self + other`: a new counter produced by `BaseNgrams::add`.
    fn __add__(&self, other: &PyNgrams) -> PyResult<PyNgrams> {
        Ok(BaseNgrams::add(self, other)?)
    }

    /// `self += other`: in-place merge performed by `BaseNgrams::iadd`.
    fn __iadd__(&mut self, other: &PyNgrams) -> PyResult<()> {
        BaseNgrams::iadd(self, other)?;
        Ok(())
    }

    /// Build a `collections.Counter` mapping n-gram tuples to counts.
    ///
    /// Only n-grams whose length equals `order` are included; when `order`
    /// is omitted the counter's own order is used. The chosen order is
    /// checked with `validate_order` before anything is built.
    #[pyo3(signature = (*, order=None))]
    fn to_counter(&self, py: Python<'_>, order: Option<usize>) -> PyResult<Py<PyAny>> {
        let default_order = self.order();
        let wanted = order.unwrap_or(default_order);
        self.validate_order(Some(wanted))?;

        let counter_cls = py.import("collections")?.getattr("Counter")?;
        let mapping = PyDict::new(py);
        for (tokens, freq) in self.counts().all_counts() {
            // Skip n-grams of any other length.
            if tokens.len() != wanted {
                continue;
            }
            mapping.set_item(PyTuple::new(py, &tokens)?, freq)?;
        }
        Ok(counter_cls.call1((mapping,))?.unbind())
    }

    /// Reset the counter via `BaseNgrams::clear`.
    fn clear(&mut self) {
        BaseNgrams::clear(self);
    }
}
/// Creates the `ngram` submodule, adds the [`PyNgrams`] class to it, and
/// attaches the submodule to `parent_module`.
///
/// # Errors
///
/// Returns the Python error raised by module creation, class registration,
/// or submodule attachment.
// NOTE(review): the submodule is only attached as an attribute of the parent;
// nothing here inserts it into `sys.modules` — confirm whether callers need
// `from <parent>.ngram import ...` to work.
pub(crate) fn register_module(parent_module: &Bound<'_, PyModule>) -> PyResult<()> {
    let py = parent_module.py();
    let submodule = PyModule::new(py, "ngram")?;
    submodule.add_class::<PyNgrams>()?;
    parent_module.add_submodule(&submodule)
}