use pyo3::prelude::*;
/// One named field of a structured record layout.
///
/// Registered with Python under the name `DtypeField`. Each attribute is
/// marked `#[pyo3(get)]`, so Python code can read it but not assign to it.
#[derive(Debug, Clone)]
#[pyclass(name = "DtypeField", from_py_object)]
pub struct DtypeField {
/// Name used to look the field up.
#[pyo3(get)]
pub name: String,
/// Dtype string exactly as the caller supplied it (for example `"f64"` or `"float64"`).
#[pyo3(get)]
pub dtype: String,
/// Byte offset of this field measured from the start of a record.
#[pyo3(get)]
pub offset: usize,
}
/// Record layout made of an ordered list of fields and the total record size.
///
/// Registered with Python under the name `StructuredDtype`.
#[pyclass(name = "StructuredDtype")]
pub struct StructuredDtype {
/// Fields in the order they were given to the constructor; each one stores its own byte offset.
fields: Vec<DtypeField>,
/// Size of one record in bytes. The constructor sets this to the sum of the
/// field sizes, i.e. fields are packed back-to-back without padding.
itemsize: usize,
}
#[pymethods]
impl StructuredDtype {
#[new]
pub fn new(field_specs: Vec<(String, String)>) -> PyResult<Self> {
let mut offset = 0usize;
let mut fields = Vec::with_capacity(field_specs.len());
for (name, dtype) in field_specs {
let size = dtype_size(&dtype)?;
fields.push(DtypeField {
name,
dtype,
offset,
});
offset += size;
}
Ok(Self {
fields,
itemsize: offset,
})
}
pub fn names(&self) -> Vec<String> {
self.fields.iter().map(|f| f.name.clone()).collect()
}
pub fn itemsize(&self) -> usize {
self.itemsize
}
pub fn offsets(&self) -> Vec<usize> {
self.fields.iter().map(|f| f.offset).collect()
}
pub fn field_count(&self) -> usize {
self.fields.len()
}
}
/// Maps a dtype string to the number of bytes one value of that dtype occupies.
///
/// Both the short spelling (`"f64"`) and the long spelling (`"float64"`) are
/// accepted where the table below lists them.
///
/// # Errors
///
/// Returns a `ValueError` naming the offending string when it is not in the table.
fn dtype_size(dtype: &str) -> PyResult<usize> {
    // Grouped by width rather than by type family.
    let width = match dtype {
        "bool" | "i8" | "int8" | "u8" | "uint8" => Some(1),
        "f32" | "float32" | "i32" | "int32" | "u32" | "uint32" => Some(4),
        "f64" | "float64" | "i64" | "int64" | "u64" | "uint64" => Some(8),
        _ => None,
    };
    width.ok_or_else(|| {
        pyo3::exceptions::PyValueError::new_err(format!(
            "unknown dtype '{dtype}'; supported: f32, f64, i32, i64, u32, u64, bool, i8, u8"
        ))
    })
}
/// Array of fixed-size records stored in a single contiguous byte buffer.
///
/// Registered with Python under the name `StructuredArray`.
#[pyclass(name = "StructuredArray")]
pub struct StructuredArray {
/// Layout shared by every record in `data`.
dtype: StructuredDtype,
/// Raw record bytes; the accessors address record `i` starting at byte `i * itemsize`.
data: Vec<u8>,
/// Number of records the array was created with.
n_records: usize,
}
#[pymethods]
impl StructuredArray {
    /// Creates an array of `n_records` records whose bytes are all zero.
    ///
    /// The layout is built from `field_specs` exactly as `StructuredDtype` does.
    ///
    /// # Errors
    ///
    /// Propagates any error from building the layout, and raises
    /// `OverflowError` when `n_records * itemsize` cannot be represented as a
    /// valid allocation size.
    #[new]
    pub fn new_empty(n_records: usize, field_specs: Vec<(String, String)>) -> PyResult<Self> {
        let dtype = StructuredDtype::new(field_specs)?;
        // An unchecked multiply would wrap in release builds and produce an
        // undersized buffer. `Vec` additionally refuses capacities above
        // `isize::MAX` bytes by panicking, so both limits are checked here.
        let n_bytes = n_records
            .checked_mul(dtype.itemsize)
            .filter(|&total| total <= isize::MAX.unsigned_abs())
            .ok_or_else(|| {
                pyo3::exceptions::PyOverflowError::new_err(format!(
                    "{n_records} records of {} bytes each exceed the maximum buffer size",
                    dtype.itemsize
                ))
            })?;
        Ok(Self {
            dtype,
            data: vec![0u8; n_bytes],
            n_records,
        })
    }

    /// Returns the number of records in the array.
    pub fn n_records(&self) -> usize {
        self.n_records
    }

    /// Returns the size of one record in bytes.
    pub fn itemsize(&self) -> usize {
        self.dtype.itemsize
    }

    /// Reads one f64 field from every record, in record order.
    ///
    /// Values are decoded as little-endian.
    ///
    /// # Errors
    ///
    /// Raises `KeyError` when no field is called `field_name`, and `TypeError`
    /// when that field's dtype is neither `"f64"` nor `"float64"`.
    pub fn get_field_f64(&self, field_name: &str) -> PyResult<Vec<f64>> {
        let offset = self.f64_field_offset(field_name)?;
        // An f64 field exists, so `itemsize` is at least 8 and the non-zero
        // chunk size required by `chunks_exact` is guaranteed.
        let values = self
            .data
            .chunks_exact(self.dtype.itemsize)
            .map(|record| {
                let mut raw = [0u8; Self::F64_BYTES];
                raw.copy_from_slice(&record[offset..offset + Self::F64_BYTES]);
                f64::from_le_bytes(raw)
            })
            .collect();
        Ok(values)
    }

    /// Writes `values` into one f64 field, one value per record.
    ///
    /// Values are encoded as little-endian.
    ///
    /// # Errors
    ///
    /// Raises `ValueError` when `values` does not contain exactly one entry per
    /// record, `KeyError` when no field is called `field_name`, and `TypeError`
    /// when that field's dtype is neither `"f64"` nor `"float64"`. The length
    /// check runs first.
    pub fn set_field_f64(&mut self, field_name: &str, values: Vec<f64>) -> PyResult<()> {
        if values.len() != self.n_records {
            return Err(pyo3::exceptions::PyValueError::new_err(format!(
                "values length {} does not match n_records {}",
                values.len(),
                self.n_records
            )));
        }
        let offset = self.f64_field_offset(field_name)?;
        let itemsize = self.dtype.itemsize;
        // As in the getter, `itemsize` is at least 8 once an f64 field was found.
        for (record, value) in self.data.chunks_exact_mut(itemsize).zip(&values) {
            record[offset..offset + Self::F64_BYTES].copy_from_slice(&value.to_le_bytes());
        }
        Ok(())
    }

    /// Returns the field names in declaration order.
    pub fn field_names(&self) -> Vec<String> {
        self.dtype.names()
    }
}

// Kept outside the `#[pymethods]` block so these items are not exported to Python.
impl StructuredArray {
    /// Width in bytes of one encoded f64 value.
    const F64_BYTES: usize = std::mem::size_of::<f64>();

    /// Finds `field_name` and returns its byte offset, requiring an f64 dtype.
    fn f64_field_offset(&self, field_name: &str) -> PyResult<usize> {
        let field = self
            .dtype
            .fields
            .iter()
            .find(|f| f.name == field_name)
            .ok_or_else(|| {
                pyo3::exceptions::PyKeyError::new_err(format!("field '{field_name}' not found"))
            })?;
        if !matches!(field.dtype.as_str(), "f64" | "float64") {
            return Err(pyo3::exceptions::PyTypeError::new_err(format!(
                "field '{field_name}' has dtype '{}', not f64",
                field.dtype
            )));
        }
        Ok(field.offset)
    }
}
/// Adds the `DtypeField`, `StructuredDtype` and `StructuredArray` classes to `m`.
///
/// Registration stops at the first class that fails to register, and that
/// error is returned.
pub fn register_structured_module(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_class::<DtypeField>()
        .and_then(|()| m.add_class::<StructuredDtype>())
        .and_then(|()| m.add_class::<StructuredArray>())
}