1#![allow(non_local_definitions)]
7#![allow(missing_docs)]
8
9use pyo3::prelude::*;
10use pyo3::types::PyBytes;
11use std::collections::HashMap;
12
13use crate::{
14 Algorithm, CompressionParameters, EncParameters, FormatFlags, KemParameters, PqcBinaryFormat,
15 PqcMetadata, SigParameters,
16};
17
18#[pyclass(name = "Algorithm")]
20#[derive(Clone)]
21pub struct PyAlgorithm {
22 inner: Algorithm,
23}
24
25#[pymethods]
26impl PyAlgorithm {
27 #[new]
28 fn new(name: &str) -> PyResult<Self> {
29 let inner = match name.to_lowercase().as_str() {
30 "classical" => Algorithm::Classical,
31 "hybrid" => Algorithm::Hybrid,
32 "post-quantum" | "postquantum" => Algorithm::PostQuantum,
33 "ml-kem-1024" | "mlkem1024" => Algorithm::MlKem1024,
34 "multi-algorithm" | "multialgorithm" => Algorithm::MultiAlgorithm,
35 "multi-kem" | "multikem" => Algorithm::MultiKem,
36 "multi-kem-triple" | "multikemtriple" => Algorithm::MultiKemTriple,
37 "quad-layer" | "quadlayer" => Algorithm::QuadLayer,
38 "pq3-stack" | "pq3stack" => Algorithm::Pq3Stack,
39 "lattice-code-hybrid" | "latticecodehybrid" => Algorithm::LatticeCodeHybrid,
40 _ => {
41 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
42 "Unknown algorithm: {}",
43 name
44 )))
45 }
46 };
47 Ok(Self { inner })
48 }
49
50 #[getter]
51 fn name(&self) -> String {
52 self.inner.name().to_string()
53 }
54
55 #[getter]
56 fn id(&self) -> u16 {
57 self.inner.as_id()
58 }
59
60 fn __str__(&self) -> String {
61 self.name()
62 }
63
64 fn __repr__(&self) -> String {
65 format!("Algorithm('{}')", self.name())
66 }
67}
68
69#[pyclass(name = "EncParameters")]
71#[derive(Clone)]
72pub struct PyEncParameters {
73 #[pyo3(get, set)]
74 pub iv: Vec<u8>,
75 #[pyo3(get, set)]
76 pub tag: Vec<u8>,
77}
78
79#[pymethods]
80impl PyEncParameters {
81 #[new]
82 fn new(iv: Vec<u8>, tag: Vec<u8>) -> Self {
83 Self { iv, tag }
84 }
85
86 fn to_dict(&self) -> HashMap<String, Vec<u8>> {
87 let mut map = HashMap::new();
88 map.insert("iv".to_string(), self.iv.clone());
89 map.insert("tag".to_string(), self.tag.clone());
90 map
91 }
92}
93
94#[pyclass(name = "KemParameters")]
96#[derive(Clone)]
97pub struct PyKemParameters {
98 #[pyo3(get, set)]
99 pub public_key: Vec<u8>,
100 #[pyo3(get, set)]
101 pub ciphertext: Vec<u8>,
102}
103
104#[pymethods]
105impl PyKemParameters {
106 #[new]
107 fn new(public_key: Vec<u8>, ciphertext: Vec<u8>) -> Self {
108 Self {
109 public_key,
110 ciphertext,
111 }
112 }
113}
114
115#[pyclass(name = "SigParameters")]
117#[derive(Clone)]
118pub struct PySigParameters {
119 #[pyo3(get, set)]
120 pub public_key: Vec<u8>,
121 #[pyo3(get, set)]
122 pub signature: Vec<u8>,
123}
124
125#[pymethods]
126impl PySigParameters {
127 #[new]
128 fn new(public_key: Vec<u8>, signature: Vec<u8>) -> Self {
129 Self {
130 public_key,
131 signature,
132 }
133 }
134}
135
136#[pyclass(name = "CompressionParameters")]
138#[derive(Clone)]
139pub struct PyCompressionParameters {
140 #[pyo3(get, set)]
141 pub algorithm: String,
142 #[pyo3(get, set)]
143 pub level: u8,
144 #[pyo3(get, set)]
145 pub original_size: u64,
146}
147
148#[pymethods]
149impl PyCompressionParameters {
150 #[new]
151 fn new(algorithm: String, level: u8, original_size: u64) -> Self {
152 Self {
153 algorithm,
154 level,
155 original_size,
156 }
157 }
158}
159
160#[pyclass(name = "PqcMetadata")]
162#[derive(Clone)]
163pub struct PyPqcMetadata {
164 #[pyo3(get, set)]
165 pub enc_params: PyEncParameters,
166 #[pyo3(get, set)]
167 pub kem_params: Option<PyKemParameters>,
168 #[pyo3(get, set)]
169 pub sig_params: Option<PySigParameters>,
170 #[pyo3(get, set)]
171 pub compression_params: Option<PyCompressionParameters>,
172}
173
174#[pymethods]
175impl PyPqcMetadata {
176 #[new]
177 fn new(
178 enc_params: PyEncParameters,
179 kem_params: Option<PyKemParameters>,
180 sig_params: Option<PySigParameters>,
181 compression_params: Option<PyCompressionParameters>,
182 ) -> Self {
183 Self {
184 enc_params,
185 kem_params,
186 sig_params,
187 compression_params,
188 }
189 }
190
191 fn add_custom(&mut self, _key: String, _value: Vec<u8>) {
192 }
194}
195
196impl PyPqcMetadata {
197 fn to_rust(&self) -> PqcMetadata {
198 PqcMetadata {
199 kem_params: self.kem_params.as_ref().map(|k| KemParameters {
200 public_key: k.public_key.clone(),
201 ciphertext: k.ciphertext.clone(),
202 params: HashMap::new(),
203 }),
204 sig_params: self.sig_params.as_ref().map(|s| SigParameters {
205 public_key: s.public_key.clone(),
206 signature: s.signature.clone(),
207 params: HashMap::new(),
208 }),
209 enc_params: EncParameters {
210 iv: self.enc_params.iv.clone(),
211 tag: self.enc_params.tag.clone(),
212 params: HashMap::new(),
213 },
214 compression_params: self
215 .compression_params
216 .as_ref()
217 .map(|c| CompressionParameters {
218 algorithm: c.algorithm.clone(),
219 level: c.level,
220 original_size: c.original_size,
221 params: HashMap::new(),
222 }),
223 custom: HashMap::new(),
224 }
225 }
226}
227
228#[pyclass(name = "FormatFlags")]
230#[derive(Clone)]
231pub struct PyFormatFlags {
232 inner: FormatFlags,
233}
234
235#[pymethods]
236impl PyFormatFlags {
237 #[new]
238 fn new() -> Self {
239 Self {
240 inner: FormatFlags::new(),
241 }
242 }
243
244 fn with_compression(&mut self) -> Self {
245 Self {
246 inner: self.inner.with_compression(),
247 }
248 }
249
250 fn with_streaming(&mut self) -> Self {
251 Self {
252 inner: self.inner.with_streaming(),
253 }
254 }
255
256 fn with_additional_auth(&mut self) -> Self {
257 Self {
258 inner: self.inner.with_additional_auth(),
259 }
260 }
261
262 fn with_experimental(&mut self) -> Self {
263 Self {
264 inner: self.inner.with_experimental(),
265 }
266 }
267
268 #[getter]
269 fn has_compression(&self) -> bool {
270 self.inner.has_compression()
271 }
272
273 #[getter]
274 fn has_streaming(&self) -> bool {
275 self.inner.has_streaming()
276 }
277
278 #[getter]
279 fn has_additional_auth(&self) -> bool {
280 self.inner.has_additional_auth()
281 }
282
283 #[getter]
284 fn has_experimental(&self) -> bool {
285 self.inner.has_experimental()
286 }
287}
288
289#[pyclass(name = "PqcBinaryFormat")]
291pub struct PyPqcBinaryFormat {
292 inner: PqcBinaryFormat,
293}
294
295#[pymethods]
296impl PyPqcBinaryFormat {
297 #[new]
307 fn new(algorithm: PyAlgorithm, metadata: PyPqcMetadata, data: Vec<u8>) -> Self {
308 let rust_metadata = metadata.to_rust();
309 let inner = PqcBinaryFormat::new(algorithm.inner, rust_metadata, data);
310 Self { inner }
311 }
312
313 #[staticmethod]
315 fn with_flags(
316 algorithm: PyAlgorithm,
317 flags: PyFormatFlags,
318 metadata: PyPqcMetadata,
319 data: Vec<u8>,
320 ) -> Self {
321 let rust_metadata = metadata.to_rust();
322 let inner = PqcBinaryFormat::with_flags(algorithm.inner, flags.inner, rust_metadata, data);
323 Self { inner }
324 }
325
326 fn to_bytes<'py>(&self, py: Python<'py>) -> PyResult<&'py PyBytes> {
331 let bytes = self
332 .inner
333 .to_bytes()
334 .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
335 Ok(PyBytes::new(py, &bytes))
336 }
337
338 #[staticmethod]
346 fn from_bytes(data: &[u8]) -> PyResult<Self> {
347 let inner = PqcBinaryFormat::from_bytes(data)
348 .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
349 Ok(Self { inner })
350 }
351
352 fn validate(&self) -> PyResult<()> {
354 self.inner
355 .validate()
356 .map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()))
357 }
358
359 #[getter]
361 fn algorithm(&self) -> PyAlgorithm {
362 PyAlgorithm {
363 inner: self.inner.algorithm(),
364 }
365 }
366
367 #[getter]
369 fn data(&self) -> Vec<u8> {
370 self.inner.data().to_vec()
371 }
372
373 #[getter]
375 fn flags(&self) -> PyFormatFlags {
376 PyFormatFlags {
377 inner: self.inner.flags(),
378 }
379 }
380
381 fn total_size(&self) -> usize {
383 self.inner.total_size()
384 }
385
386 fn __repr__(&self) -> String {
387 format!(
388 "PqcBinaryFormat(algorithm='{}', data_len={})",
389 self.inner.algorithm().name(),
390 self.inner.data().len()
391 )
392 }
393}
394
395#[pymodule]
397fn pqc_binary_format(_py: Python, m: &PyModule) -> PyResult<()> {
398 m.add_class::<PyAlgorithm>()?;
399 m.add_class::<PyEncParameters>()?;
400 m.add_class::<PyKemParameters>()?;
401 m.add_class::<PySigParameters>()?;
402 m.add_class::<PyCompressionParameters>()?;
403 m.add_class::<PyPqcMetadata>()?;
404 m.add_class::<PyFormatFlags>()?;
405 m.add_class::<PyPqcBinaryFormat>()?;
406
407 m.add("__version__", env!("CARGO_PKG_VERSION"))?;
409 m.add("PQC_BINARY_VERSION", crate::PQC_BINARY_VERSION)?;
410
411 Ok(())
412}