use crate::{
decrypt_from_storage as rust_decrypt_from_storage, encrypt_from_file as rust_encrypt_from_file,
streaming_decrypt_from_storage as rust_streaming_decrypt_from_storage, ChunkInfo, DataMap,
EncryptedChunk, XorName,
};
use bytes::Bytes;
use pyo3::prelude::*;
use pyo3::types::{PyBytes, PyInt, PyTuple};
use std::borrow::Cow;
use std::path::Path;
/// Python wrapper around an [`XorName`] content address.
#[pyclass]
#[derive(Clone)]
pub struct PyXorName {
    inner: XorName,
}

#[pymethods]
impl PyXorName {
    /// Computes the name of `content` using `XorName::from_content`.
    #[staticmethod]
    pub fn from_content(content: &[u8]) -> Self {
        Self {
            inner: XorName::from_content(content),
        }
    }

    /// Returns the raw bytes of the name.
    ///
    /// The bytes are lent directly out of `self` instead of first being copied into an
    /// intermediate `Vec`. PyO3 turns the returned `Cow<[u8]>` into a Python `bytes`
    /// object, so callers observe the same value as before.
    pub fn as_bytes(&self) -> Cow<'_, [u8]> {
        Cow::Borrowed(&self.inner.0[..])
    }
}
/// Python handle for one [`ChunkInfo`] record: its index, the destination and
/// source hashes, and the source size stored alongside them.
#[pyclass]
#[derive(Clone)]
pub struct PyChunkInfo {
    inner: ChunkInfo,
}

#[pymethods]
impl PyChunkInfo {
    /// Assembles a record from its parts, unwrapping the hash wrappers into the
    /// underlying `XorName`s.
    #[new]
    pub fn new(index: usize, dst_hash: PyXorName, src_hash: PyXorName, src_size: usize) -> Self {
        let PyXorName { inner: dst } = dst_hash;
        let PyXorName { inner: src } = src_hash;
        let record = ChunkInfo {
            index,
            dst_hash: dst,
            src_hash: src,
            src_size,
        };
        Self { inner: record }
    }

    /// Index stored in this record.
    #[getter]
    pub fn index(&self) -> usize {
        let record = &self.inner;
        record.index
    }

    /// Destination hash, wrapped for Python.
    #[getter]
    pub fn dst_hash(&self) -> PyXorName {
        let inner = self.inner.dst_hash;
        PyXorName { inner }
    }

    /// Source hash, wrapped for Python.
    #[getter]
    pub fn src_hash(&self) -> PyXorName {
        let inner = self.inner.src_hash;
        PyXorName { inner }
    }

    /// Source size stored in this record.
    #[getter]
    pub fn src_size(&self) -> usize {
        let record = &self.inner;
        record.src_size
    }
}
/// Python wrapper around a [`DataMap`] built from a list of chunk records.
#[pyclass]
#[derive(Clone)]
pub struct PyDataMap {
    inner: DataMap,
}

#[pymethods]
impl PyDataMap {
    /// Parses a data map from its serde JSON representation.
    ///
    /// Raises `ValueError` if `json_str` cannot be deserialized into a data map.
    #[staticmethod]
    pub fn from_json(json_str: &str) -> PyResult<Self> {
        let inner = serde_json::from_str(json_str)
            .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("Invalid JSON: {e}")))?;
        Ok(Self { inner })
    }

    /// Serializes the data map to JSON (the format accepted by `from_json`).
    ///
    /// Raises `ValueError` if serialization fails.
    pub fn to_json(&self) -> PyResult<String> {
        serde_json::to_string(&self.inner).map_err(|e| {
            pyo3::exceptions::PyValueError::new_err(format!("Failed to serialize: {e}"))
        })
    }

    /// Builds a data map from the given chunk records via `DataMap::new`.
    #[new]
    pub fn new(chunk_infos: Vec<PyChunkInfo>) -> Self {
        let inner_infos = chunk_infos.into_iter().map(|info| info.inner).collect();
        Self {
            inner: DataMap::new(inner_infos),
        }
    }

    /// Builds a data map from the given chunk records with an explicit `child` value,
    /// via `DataMap::with_child`.
    #[staticmethod]
    pub fn with_child(chunk_infos: Vec<PyChunkInfo>, child: usize) -> Self {
        let inner_infos = chunk_infos.into_iter().map(|info| info.inner).collect();
        Self {
            inner: DataMap::with_child(inner_infos, child),
        }
    }

    /// Returns the `child` value of the underlying map, if one is set.
    pub fn child(&self) -> Option<usize> {
        self.inner.child()
    }

    /// Returns copies of all chunk records, wrapped for Python.
    pub fn infos(&self) -> Vec<PyChunkInfo> {
        self.inner
            .infos()
            .iter()
            .cloned()
            .map(|inner| PyChunkInfo { inner })
            .collect()
    }

    /// Returns the length reported by the underlying `DataMap::len`.
    pub fn len(&self) -> usize {
        self.inner.len()
    }

    /// Returns `True` when `len()` is zero.
    ///
    /// Provided as the companion to `len` (clippy `len_without_is_empty`). It is
    /// deliberately not exposed as `__len__`, because that would change the Python
    /// truthiness of existing data-map objects.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns the `is_child` flag of the underlying map.
    pub fn is_child(&self) -> bool {
        self.inner.is_child()
    }
}
/// Python handle for one [`EncryptedChunk`] produced by encryption.
#[pyclass]
#[derive(Clone)]
pub struct PyEncryptedChunk {
    inner: EncryptedChunk,
}

#[pymethods]
impl PyEncryptedChunk {
    /// Number of bytes in the chunk's encrypted content.
    pub fn content_size(&self) -> usize {
        let content = &self.inner.content;
        content.len()
    }

    /// `XorName` of the encrypted content, returned as raw bytes.
    pub fn hash(&self) -> Cow<'_, [u8]> {
        let digest = XorName::from_content(&self.inner.content);
        Cow::Owned(digest.0.to_vec())
    }
}
/// Encrypts an in-memory buffer.
///
/// Returns the data map together with the encrypted chunks. Raises `ValueError` if
/// `crate::encrypt` reports an error.
#[pyfunction]
pub fn encrypt(data: &[u8]) -> PyResult<(PyDataMap, Vec<PyEncryptedChunk>)> {
    match crate::encrypt(Bytes::copy_from_slice(data)) {
        Ok((map, encrypted)) => {
            let wrapped: Vec<PyEncryptedChunk> = encrypted
                .into_iter()
                .map(|inner| PyEncryptedChunk { inner })
                .collect();
            Ok((PyDataMap { inner: map }, wrapped))
        }
        Err(e) => Err(pyo3::exceptions::PyValueError::new_err(format!(
            "Encryption failed: {e}"
        ))),
    }
}
#[pyfunction]
pub fn decrypt(
data_map: &PyDataMap,
chunks: Vec<PyEncryptedChunk>,
) -> PyResult<std::borrow::Cow<'_, [u8]>> {
let inner_chunks = chunks
.into_iter()
.map(|chunk| chunk.inner)
.collect::<Vec<_>>();
let bytes = crate::decrypt(&data_map.inner, &inner_chunks)
.map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("Decryption failed: {e}")))?;
Ok(bytes.to_vec().into())
}
/// Encrypts the file at `input_file`, handing `output_dir` to the encryption routine
/// for chunk output.
///
/// Returns the data map and the hex-encoded chunk names reported by the routine.
/// Raises `OSError` on failure.
#[pyfunction]
pub fn encrypt_from_file(input_file: &str, output_dir: &str) -> PyResult<(PyDataMap, Vec<String>)> {
    let outcome = rust_encrypt_from_file(Path::new(input_file), Path::new(output_dir));
    let (map, names) = outcome.map_err(|e| {
        pyo3::exceptions::PyOSError::new_err(format!("Failed to encrypt file: {e}"))
    })?;
    let hex_names: Vec<String> = names.iter().map(|xor| hex::encode(xor.0)).collect();
    Ok((PyDataMap { inner: map }, hex_names))
}
#[pyfunction]
pub fn decrypt_from_storage(
data_map: PyDataMap,
output_file: &str,
get_chunk: Bound<'_, PyAny>,
) -> PyResult<()> {
let output_path = Path::new(output_file);
let get_chunk_wrapper = |name: XorName| -> crate::Result<Bytes> {
let name_str = hex::encode(name.0);
let chunk = get_chunk
.call1((name_str,))
.map_err(|e| crate::Error::Python(format!("Failed to call get_chunk: {e}")))?;
let bytes = chunk
.downcast::<PyBytes>()
.map_err(|e| crate::Error::Python(format!("get_chunk must return bytes: {e}")))?;
Ok(Bytes::copy_from_slice(bytes.as_bytes()))
};
rust_decrypt_from_storage(&data_map.inner, output_path, get_chunk_wrapper)
.map_err(|e| pyo3::exceptions::PyOSError::new_err(format!("Decryption failed: {e}")))
}
#[pyfunction]
pub fn streaming_decrypt_from_storage(
data_map: PyDataMap,
output_file: &str,
get_chunks: Bound<'_, PyAny>,
) -> PyResult<()> {
let output_path = Path::new(output_file);
let get_chunks_wrapper = |names: &[(usize, XorName)]| -> crate::Result<Vec<(usize, Bytes)>> {
let name_strs: Vec<(usize, String)> =
names.iter().map(|(i, x)| (*i, hex::encode(x.0))).collect();
let chunks = get_chunks
.call1((name_strs,))
.map_err(|e| crate::Error::Python(format!("Failed to call get_chunks: {e}")))?;
let chunks = chunks
.try_iter()
.map_err(|e| crate::Error::Python(format!("get_chunks must return a list: {e}")))?;
let mut result = Vec::new();
for chunk in chunks {
let chunk = chunk
.map_err(|e| crate::Error::Python(format!("Failed to iterate chunks: {e}")))?;
let chunk_tuple = chunk
.downcast::<PyTuple>()
.map_err(|e| crate::Error::Python(format!("get_chunks must return tuple: {e}")))?;
if chunk_tuple.len() != 2 {
return Err(crate::Error::Python(
"get_chunks must return tuples of length 2".to_string(),
));
}
let index_item = chunk_tuple.get_item(0)?;
let index = index_item
.downcast::<PyInt>()
.map_err(|e| crate::Error::Python(format!("First element must be integer: {e}")))?
.extract::<usize>()
.map_err(|e| crate::Error::Python(format!("Failed to extract index: {e}")))?;
let bytes_item = chunk_tuple.get_item(1)?;
let bytes = bytes_item
.downcast::<PyBytes>()
.map_err(|e| crate::Error::Python(format!("Second element must be bytes: {e}")))?;
result.push((index, Bytes::copy_from_slice(bytes.as_bytes())));
}
Ok(result)
};
rust_streaming_decrypt_from_storage(&data_map.inner, output_path, get_chunks_wrapper).map_err(
|e| pyo3::exceptions::PyOSError::new_err(format!("Streaming decryption failed: {e}")),
)
}
/// Entry point of the `_self_encryption` extension module; registers the wrapper
/// classes and module-level functions defined in this file.
#[pymodule]
#[pyo3(name = "_self_encryption")]
fn self_encryption_module(module: &Bound<'_, PyModule>) -> PyResult<()> {
    // Wrapper classes.
    module.add_class::<PyDataMap>()?;
    module.add_class::<PyEncryptedChunk>()?;
    module.add_class::<PyXorName>()?;
    module.add_class::<PyChunkInfo>()?;

    // Module-level functions.
    let encrypt_fn = wrap_pyfunction!(encrypt, module)?;
    module.add_function(encrypt_fn)?;
    let decrypt_fn = wrap_pyfunction!(decrypt, module)?;
    module.add_function(decrypt_fn)?;
    let encrypt_file_fn = wrap_pyfunction!(encrypt_from_file, module)?;
    module.add_function(encrypt_file_fn)?;
    let storage_fn = wrap_pyfunction!(decrypt_from_storage, module)?;
    module.add_function(storage_fn)?;
    let streaming_fn = wrap_pyfunction!(streaming_decrypt_from_storage, module)?;
    module.add_function(streaming_fn)?;

    Ok(())
}