#![allow(non_local_definitions)]
#![allow(missing_docs)]
use pyo3::prelude::*;
use pyo3::types::PyBytes;
use std::collections::HashMap;
use crate::{
Algorithm, CompressionParameters, EncParameters, FormatFlags, KemParameters, PqcBinaryFormat,
PqcMetadata, SigParameters,
};
/// Python handle for a core [`Algorithm`] identifier.
///
/// Exposed to Python as `Algorithm` and constructed from the algorithm's name.
#[pyclass(name = "Algorithm")]
#[derive(Clone)]
pub struct PyAlgorithm {
    inner: Algorithm,
}

#[pymethods]
impl PyAlgorithm {
    /// Looks up an algorithm by its name.
    ///
    /// # Errors
    /// Raises `ValueError` when `Algorithm::from_name` does not recognise `name`.
    #[new]
    fn new(name: &str) -> PyResult<Self> {
        match Algorithm::from_name(name) {
            Some(inner) => Ok(Self { inner }),
            None => Err(pyo3::exceptions::PyValueError::new_err(format!(
                "Unknown algorithm: {}",
                name
            ))),
        }
    }

    /// The algorithm's name as reported by `Algorithm::name`.
    #[getter]
    fn name(&self) -> String {
        self.inner.name().to_string()
    }

    /// The numeric identifier returned by `Algorithm::as_id`.
    #[getter]
    fn id(&self) -> u16 {
        self.inner.as_id()
    }

    /// `str(algorithm)` yields the bare algorithm name.
    fn __str__(&self) -> String {
        self.inner.name().to_string()
    }

    /// `repr(algorithm)` yields `Algorithm('<name>')`.
    fn __repr__(&self) -> String {
        format!("Algorithm('{}')", self.inner.name())
    }
}
/// Python-facing mirror of the core `EncParameters`: the `iv` and `tag` byte strings.
///
/// Both fields are readable and writable as Python attributes.
#[pyclass(name = "EncParameters", get_all, set_all)]
#[derive(Clone)]
pub struct PyEncParameters {
    pub iv: Vec<u8>,
    pub tag: Vec<u8>,
}

#[pymethods]
impl PyEncParameters {
    /// Creates encryption parameters from an `iv` and a `tag`.
    #[new]
    fn new(iv: Vec<u8>, tag: Vec<u8>) -> Self {
        PyEncParameters { iv, tag }
    }

    /// Returns a dict with the keys `"iv"` and `"tag"` mapped to copies of the fields.
    fn to_dict(&self) -> HashMap<String, Vec<u8>> {
        HashMap::from([
            (String::from("iv"), self.iv.clone()),
            (String::from("tag"), self.tag.clone()),
        ])
    }
}
/// Python-facing mirror of the core `KemParameters`: a public key and a ciphertext.
///
/// Both fields are readable and writable as Python attributes.
#[pyclass(name = "KemParameters", get_all, set_all)]
#[derive(Clone)]
pub struct PyKemParameters {
    pub public_key: Vec<u8>,
    pub ciphertext: Vec<u8>,
}

#[pymethods]
impl PyKemParameters {
    /// Creates KEM parameters from a `public_key` and a `ciphertext`.
    #[new]
    fn new(public_key: Vec<u8>, ciphertext: Vec<u8>) -> Self {
        PyKemParameters { public_key, ciphertext }
    }
}
/// Python-facing mirror of the core `SigParameters`: a public key and a signature.
///
/// Both fields are readable and writable as Python attributes.
#[pyclass(name = "SigParameters", get_all, set_all)]
#[derive(Clone)]
pub struct PySigParameters {
    pub public_key: Vec<u8>,
    pub signature: Vec<u8>,
}

#[pymethods]
impl PySigParameters {
    /// Creates signature parameters from a `public_key` and a `signature`.
    #[new]
    fn new(public_key: Vec<u8>, signature: Vec<u8>) -> Self {
        PySigParameters { public_key, signature }
    }
}
/// Python-facing mirror of the core `CompressionParameters`.
///
/// Holds the compression `algorithm` name, its `level`, and an `original_size`
/// (presumably the uncompressed length — confirm against the core crate). All fields
/// are readable and writable as Python attributes.
#[pyclass(name = "CompressionParameters", get_all, set_all)]
#[derive(Clone)]
pub struct PyCompressionParameters {
    pub algorithm: String,
    pub level: u8,
    pub original_size: u64,
}

#[pymethods]
impl PyCompressionParameters {
    /// Creates compression parameters from their three components.
    #[new]
    fn new(algorithm: String, level: u8, original_size: u64) -> Self {
        PyCompressionParameters { algorithm, level, original_size }
    }
}
/// Python-facing mirror of the core [`PqcMetadata`].
///
/// Exposed to Python as `PqcMetadata`. Only `enc_params` is required; the KEM,
/// signature and compression parameter groups are optional. Convert to the core type
/// with [`PyPqcMetadata::to_rust`] before building a `PqcBinaryFormat`.
#[pyclass(name = "PqcMetadata")]
#[derive(Clone)]
pub struct PyPqcMetadata {
    #[pyo3(get, set)]
    pub enc_params: PyEncParameters,
    #[pyo3(get, set)]
    pub kem_params: Option<PyKemParameters>,
    #[pyo3(get, set)]
    pub sig_params: Option<PySigParameters>,
    #[pyo3(get, set)]
    pub compression_params: Option<PyCompressionParameters>,
    /// Custom key/value entries recorded via `add_custom` and forwarded to
    /// `PqcMetadata::custom` by `to_rust`. Not exposed as a Python attribute.
    pub custom: HashMap<String, Vec<u8>>,
}

#[pymethods]
impl PyPqcMetadata {
    /// Creates metadata from its parameter groups.
    ///
    /// The optional groups default to `None` when omitted from Python; the custom
    /// entry map starts empty.
    #[new]
    #[pyo3(signature = (enc_params, kem_params=None, sig_params=None, compression_params=None))]
    fn new(
        enc_params: PyEncParameters,
        kem_params: Option<PyKemParameters>,
        sig_params: Option<PySigParameters>,
        compression_params: Option<PyCompressionParameters>,
    ) -> Self {
        Self {
            enc_params,
            kem_params,
            sig_params,
            compression_params,
            custom: HashMap::new(),
        }
    }

    /// Records a custom metadata entry, replacing any earlier value stored under the
    /// same key. The entry is carried into the core metadata by `to_rust`.
    // The `_key`/`_value` names are kept because PyO3 exposes Rust parameter names as
    // Python keyword-argument names; renaming them would break keyword callers.
    fn add_custom(&mut self, _key: String, _value: Vec<u8>) {
        self.custom.insert(_key, _value);
    }
}

impl PyPqcMetadata {
    /// Converts this Python-facing metadata into an owned core [`PqcMetadata`].
    ///
    /// Every field is cloned, including the custom entries collected by `add_custom`.
    /// The Python wrappers carry no algorithm-specific extras, so each nested `params`
    /// map is left empty.
    fn to_rust(&self) -> PqcMetadata {
        PqcMetadata {
            kem_params: self.kem_params.as_ref().map(|k| KemParameters {
                public_key: k.public_key.clone(),
                ciphertext: k.ciphertext.clone(),
                params: HashMap::new(),
            }),
            sig_params: self.sig_params.as_ref().map(|s| SigParameters {
                public_key: s.public_key.clone(),
                signature: s.signature.clone(),
                params: HashMap::new(),
            }),
            enc_params: EncParameters {
                iv: self.enc_params.iv.clone(),
                tag: self.enc_params.tag.clone(),
                params: HashMap::new(),
            },
            compression_params: self
                .compression_params
                .as_ref()
                .map(|c| CompressionParameters {
                    algorithm: c.algorithm.clone(),
                    level: c.level,
                    original_size: c.original_size,
                    params: HashMap::new(),
                }),
            custom: self.custom.clone(),
        }
    }
}
/// Python wrapper around the core [`FormatFlags`].
///
/// Exposed to Python as `FormatFlags`. Each `with_*` builder returns a new
/// `FormatFlags` object and leaves the receiver untouched, so calls can be chained,
/// e.g. `FormatFlags().with_compression().with_streaming()`.
#[pyclass(name = "FormatFlags")]
#[derive(Clone)]
pub struct PyFormatFlags {
    inner: FormatFlags,
}

// The builders take `&self` rather than `&mut self`: they never modify the receiver,
// and a shared borrow avoids PyO3's exclusive-borrow check ("Already borrowed") on the
// Python object.
#[pymethods]
impl PyFormatFlags {
    /// Creates flags from `FormatFlags::new()`.
    #[new]
    fn new() -> Self {
        Self {
            inner: FormatFlags::new(),
        }
    }

    /// Returns a new object built by `FormatFlags::with_compression`.
    fn with_compression(&self) -> Self {
        Self {
            inner: self.inner.with_compression(),
        }
    }

    /// Returns a new object built by `FormatFlags::with_streaming`.
    fn with_streaming(&self) -> Self {
        Self {
            inner: self.inner.with_streaming(),
        }
    }

    /// Returns a new object built by `FormatFlags::with_additional_auth`.
    fn with_additional_auth(&self) -> Self {
        Self {
            inner: self.inner.with_additional_auth(),
        }
    }

    /// Returns a new object built by `FormatFlags::with_experimental`.
    fn with_experimental(&self) -> Self {
        Self {
            inner: self.inner.with_experimental(),
        }
    }

    /// Result of `FormatFlags::has_compression`.
    #[getter]
    fn has_compression(&self) -> bool {
        self.inner.has_compression()
    }

    /// Result of `FormatFlags::has_streaming`.
    #[getter]
    fn has_streaming(&self) -> bool {
        self.inner.has_streaming()
    }

    /// Result of `FormatFlags::has_additional_auth`.
    #[getter]
    fn has_additional_auth(&self) -> bool {
        self.inner.has_additional_auth()
    }

    /// Result of `FormatFlags::has_experimental`.
    #[getter]
    fn has_experimental(&self) -> bool {
        self.inner.has_experimental()
    }
}
#[pyclass(name = "PqcBinaryFormat")]
pub struct PyPqcBinaryFormat {
inner: PqcBinaryFormat,
}
#[pymethods]
impl PyPqcBinaryFormat {
#[new]
fn new(algorithm: PyAlgorithm, metadata: PyPqcMetadata, data: Vec<u8>) -> Self {
let rust_metadata = metadata.to_rust();
let inner = PqcBinaryFormat::new(algorithm.inner, rust_metadata, data);
Self { inner }
}
#[staticmethod]
fn with_flags(
algorithm: PyAlgorithm,
flags: PyFormatFlags,
metadata: PyPqcMetadata,
data: Vec<u8>,
) -> Self {
let rust_metadata = metadata.to_rust();
let inner = PqcBinaryFormat::with_flags(algorithm.inner, flags.inner, rust_metadata, data);
Self { inner }
}
fn to_bytes<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
let bytes = self
.inner
.to_bytes()
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
Ok(PyBytes::new(py, &bytes))
}
#[staticmethod]
fn from_bytes(data: &[u8]) -> PyResult<Self> {
let inner = PqcBinaryFormat::from_bytes(data)
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
Ok(Self { inner })
}
fn validate(&self) -> PyResult<()> {
self.inner
.validate()
.map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()))
}
#[getter]
fn algorithm(&self) -> PyAlgorithm {
PyAlgorithm {
inner: self.inner.algorithm(),
}
}
#[getter]
fn data(&self) -> Vec<u8> {
self.inner.data().to_vec()
}
#[getter]
fn flags(&self) -> PyFormatFlags {
PyFormatFlags {
inner: self.inner.flags(),
}
}
fn total_size(&self) -> usize {
self.inner.total_size()
}
fn __repr__(&self) -> String {
format!(
"PqcBinaryFormat(algorithm='{}', data_len={})",
self.inner.algorithm().name(),
self.inner.data().len()
)
}
}
/// Initialiser for the `pqc_binary_format` Python extension module.
///
/// Registers every wrapper class and exposes two constants: `__version__` (this
/// crate's Cargo package version) and `PQC_BINARY_VERSION` (from the crate root).
#[pymodule]
fn pqc_binary_format(m: &Bound<'_, PyModule>) -> PyResult<()> {
    // Registration order is unchanged so the module namespace is populated identically.
    m.add_class::<PyAlgorithm>()?;
    m.add_class::<PyEncParameters>()?;
    m.add_class::<PyKemParameters>()?;
    m.add_class::<PySigParameters>()?;
    m.add_class::<PyCompressionParameters>()?;
    m.add_class::<PyPqcMetadata>()?;
    m.add_class::<PyFormatFlags>()?;
    m.add_class::<PyPqcBinaryFormat>()?;

    let package_version: &str = env!("CARGO_PKG_VERSION");
    m.add("__version__", package_version)?;
    m.add("PQC_BINARY_VERSION", crate::PQC_BINARY_VERSION)
}