use super::super::scalars::*;
use derive_more::{Deref, From, Into};
use pyo3::prelude::*;
use pyo3::types::{PyAny, PyBytes};
use pyo3::Py;
/// Python wrapper around [`ScalarNonZero`], exposed to Python as `ScalarNonZero`.
#[derive(Copy, Clone, Eq, PartialEq, Debug, From, Into, Deref)]
#[pyclass(name = "ScalarNonZero", from_py_object)]
pub struct PyScalarNonZero(pub(crate) ScalarNonZero);

#[pymethods]
impl PyScalarNonZero {
    /// Encode this scalar via `ScalarNonZero::to_bytes` and return it as Python `bytes`.
    #[pyo3(name = "to_bytes")]
    fn encode(&self, py: Python) -> Py<PyAny> {
        PyBytes::new(py, &self.0.to_bytes()).into()
    }

    /// Decode a scalar from raw bytes.
    ///
    /// Returns `None` when `ScalarNonZero::from_slice` rejects the input.
    #[staticmethod]
    #[pyo3(name = "from_bytes")]
    fn decode(bytes: &[u8]) -> Option<PyScalarNonZero> {
        ScalarNonZero::from_slice(bytes).map(PyScalarNonZero)
    }

    /// Decode a scalar from a hex string.
    ///
    /// Returns `None` when `ScalarNonZero::from_hex` rejects the input.
    #[staticmethod]
    #[pyo3(name = "from_hex")]
    fn from_hex(hex: &str) -> Option<PyScalarNonZero> {
        ScalarNonZero::from_hex(hex).map(PyScalarNonZero)
    }

    /// Return the hex string produced by `ScalarNonZero::to_hex`.
    #[pyo3(name = "to_hex")]
    fn as_hex(&self) -> String {
        self.0.to_hex()
    }

    /// Sample a random scalar using the thread-local RNG from `rand::rng()`.
    #[staticmethod]
    #[pyo3(name = "random")]
    fn random() -> PyScalarNonZero {
        ScalarNonZero::random(&mut rand::rng()).into()
    }

    /// Derive a scalar from a 64-byte hash via `ScalarNonZero::from_hash`.
    ///
    /// Raises `ValueError` if `v` is not exactly 64 bytes long.
    #[staticmethod]
    #[pyo3(name = "from_hash")]
    fn from_hash(v: &[u8]) -> PyResult<PyScalarNonZero> {
        // Borrow the slice as a fixed-size array; the conversion fails for any other length.
        let arr: &[u8; 64] = v.try_into().map_err(|_| {
            pyo3::exceptions::PyValueError::new_err("Hash must be 64 bytes")
        })?;
        Ok(ScalarNonZero::from_hash(arr).into())
    }

    /// Return `ScalarNonZero::one()`.
    #[staticmethod]
    #[pyo3(name = "one")]
    fn one() -> PyScalarNonZero {
        ScalarNonZero::one().into()
    }

    /// Return `ScalarNonZero::invert` of this scalar (presumably the multiplicative inverse).
    #[pyo3(name = "invert")]
    fn invert(&self) -> PyScalarNonZero {
        self.0.invert().into()
    }

    /// Multiply this scalar by `other`.
    #[pyo3(name = "mul")]
    fn mul(&self, other: &PyScalarNonZero) -> PyScalarNonZero {
        (self.0 * other.0).into()
    }

    /// Convert to a `ScalarCanBeZero` using the underlying `From`/`Into` conversion.
    #[pyo3(name = "to_can_be_zero")]
    fn to_can_be_zero(&self) -> PyScalarCanBeZero {
        let s: ScalarCanBeZero = self.0.into();
        PyScalarCanBeZero(s)
    }

    /// Python `a * b`; same as [`Self::mul`].
    fn __mul__(&self, other: &PyScalarNonZero) -> PyScalarNonZero {
        self.mul(other)
    }

    /// Python `repr()`: `ScalarNonZero(<hex>)`.
    fn __repr__(&self) -> String {
        format!("ScalarNonZero({})", self.as_hex())
    }

    /// Python `str()`: the hex encoding.
    fn __str__(&self) -> String {
        self.as_hex()
    }

    /// Python `==`: compares the wrapped scalars.
    fn __eq__(&self, other: &PyScalarNonZero) -> bool {
        self.0 == other.0
    }

    /// Python `hash()`, so instances can be used in sets and as dict keys.
    ///
    /// Without this, defining `__eq__` makes the type unhashable. The hash covers the byte
    /// encoding. This assumes `to_bytes` is canonical (equal scalars give equal bytes, so
    /// `a == b` implies `hash(a) == hash(b)`) — TODO confirm against the scalar implementation.
    fn __hash__(&self) -> u64 {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        let mut hasher = DefaultHasher::new();
        self.0.to_bytes().hash(&mut hasher);
        hasher.finish()
    }
}
#[derive(Copy, Clone, Eq, PartialEq, Debug, From, Into, Deref)]
#[pyclass(name = "ScalarCanBeZero", from_py_object)]
pub struct PyScalarCanBeZero(pub(crate) ScalarCanBeZero);
#[pymethods]
impl PyScalarCanBeZero {
#[pyo3(name = "to_bytes")]
fn encode(&self, py: Python) -> Py<PyAny> {
PyBytes::new(py, &self.0.to_bytes()).into()
}
#[staticmethod]
#[pyo3(name = "from_bytes")]
fn decode(bytes: &[u8]) -> Option<PyScalarCanBeZero> {
ScalarCanBeZero::from_slice(bytes).map(PyScalarCanBeZero)
}
#[staticmethod]
#[pyo3(name = "from_hex")]
fn from_hex(hex: &str) -> Option<PyScalarCanBeZero> {
ScalarCanBeZero::from_hex(hex).map(PyScalarCanBeZero)
}
#[pyo3(name = "to_hex")]
fn as_hex(&self) -> String {
self.0.to_hex()
}
#[staticmethod]
#[pyo3(name = "one")]
fn one() -> PyScalarCanBeZero {
ScalarCanBeZero::one().into()
}
#[staticmethod]
#[pyo3(name = "zero")]
fn zero() -> PyScalarCanBeZero {
ScalarCanBeZero::zero().into()
}
#[staticmethod]
#[pyo3(name = "random")]
fn random() -> PyScalarCanBeZero {
ScalarCanBeZero::random(&mut rand::rng()).into()
}
#[pyo3(name = "is_zero")]
fn is_zero(&self) -> bool {
self.0.is_zero()
}
#[pyo3(name = "add")]
fn add(&self, other: &PyScalarCanBeZero) -> PyScalarCanBeZero {
(self.0 + other.0).into()
}
#[pyo3(name = "sub")]
fn sub(&self, other: &PyScalarCanBeZero) -> PyScalarCanBeZero {
(self.0 - other.0).into()
}
#[pyo3(name = "to_non_zero")]
fn to_non_zero(&self) -> Option<PyScalarNonZero> {
let s: ScalarNonZero = self.0.try_into().ok()?;
Some(PyScalarNonZero(s))
}
fn __add__(&self, other: &PyScalarCanBeZero) -> PyScalarCanBeZero {
self.add(other)
}
fn __sub__(&self, other: &PyScalarCanBeZero) -> PyScalarCanBeZero {
self.sub(other)
}
fn __repr__(&self) -> String {
format!("ScalarCanBeZero({})", self.as_hex())
}
fn __str__(&self) -> String {
self.as_hex()
}
fn __eq__(&self, other: &PyScalarCanBeZero) -> bool {
self.0 == other.0
}
}
/// Add the `ScalarNonZero` and `ScalarCanBeZero` Python classes to module `m`.
///
/// Classes are added in that order. The first error from `add_class` is returned and the
/// remaining class is not added.
pub fn register(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_class::<PyScalarNonZero>()
        .and_then(|()| m.add_class::<PyScalarCanBeZero>())
}