use pyo3::prelude::*;
/// A fixed-shape, zero-initialised array whose element type is chosen at
/// runtime by dtype name and whose elements are kept as raw bytes.
#[pyclass(name = "UntypedArray")]
pub struct UntypedArray {
    /// Dtype name exactly as supplied by the caller (canonical or short alias).
    dtype_name: String,
    /// Width in bytes of a single element, derived from `dtype_name`.
    itemsize: usize,
    /// Length of each dimension; an empty vector means a 0-d array.
    shape: Vec<usize>,
    /// Element storage: `itemsize` bytes per element, little-endian encoded.
    data: Vec<u8>,
}
#[pymethods]
impl UntypedArray {
    /// Creates a zero-filled array of the given `shape` and dtype.
    ///
    /// An empty `shape` describes a 0-d array holding exactly one element. A
    /// shape containing a zero-length dimension holds no elements, so
    /// `nbytes()` is always `size() * itemsize()`.
    ///
    /// # Errors
    ///
    /// Returns `PyValueError` if `dtype_name` is not accepted by
    /// `resolve_itemsize`, or if the element count or total byte size
    /// overflows `usize`.
    #[new]
    pub fn new(shape: Vec<usize>, dtype_name: String) -> PyResult<Self> {
        let itemsize = resolve_itemsize(&dtype_name)?;
        // Multiply left to right with overflow checks, in the same order as
        // `size()`, so any accepted shape also has a non-overflowing `size()`.
        let nbytes = shape
            .iter()
            .try_fold(1usize, |count, &dim| count.checked_mul(dim))
            .and_then(|count| count.checked_mul(itemsize))
            .ok_or_else(|| {
                pyo3::exceptions::PyValueError::new_err(format!(
                    "shape {shape:?} is too large for dtype '{dtype_name}'"
                ))
            })?;
        Ok(Self {
            data: vec![0u8; nbytes],
            dtype_name,
            shape,
            itemsize,
        })
    }

    /// Returns the dtype name exactly as passed to the constructor.
    pub fn dtype_name(&self) -> &str {
        &self.dtype_name
    }

    /// Returns the size in bytes of one element.
    pub fn itemsize(&self) -> usize {
        self.itemsize
    }

    /// Returns a copy of the array's dimensions.
    pub fn shape(&self) -> Vec<usize> {
        self.shape.clone()
    }

    /// Returns the number of dimensions (0 for a 0-d array).
    pub fn ndim(&self) -> usize {
        self.shape.len()
    }

    /// Returns the number of bytes of element storage.
    pub fn nbytes(&self) -> usize {
        self.data.len()
    }

    /// Returns the number of elements: the product of `shape` (1 for a 0-d array).
    pub fn size(&self) -> usize {
        self.shape.iter().product()
    }

    /// Returns `true` for `float32`/`f32` and `float64`/`f64`.
    pub fn is_floating(&self) -> bool {
        matches!(
            self.dtype_name.as_str(),
            "float32" | "f32" | "float64" | "f64"
        )
    }

    /// Returns `true` for the signed and unsigned integer dtypes.
    ///
    /// `bool`/`b` is not counted as an integer dtype.
    pub fn is_integer(&self) -> bool {
        matches!(
            self.dtype_name.as_str(),
            "int32" | "i32" | "int64" | "i64" | "int8" | "i8" | "uint8" | "u8"
        )
    }

    /// Reads the element at `flat_index` (an index into the flattened
    /// storage) and converts it to `f64`.
    ///
    /// `int64` values with magnitude above 2^53 lose precision. `uint8` and
    /// `bool` elements read back as their raw byte value.
    ///
    /// # Errors
    ///
    /// Returns `PyIndexError` if `flat_index` is not less than `size()`.
    pub fn read_as_f64(&self, flat_index: usize) -> PyResult<f64> {
        let offset = self.byte_offset(flat_index)?;
        let value = match self.dtype_name.as_str() {
            "float32" | "f32" => f64::from(f32::from_le_bytes(self.bytes_at(offset))),
            "float64" | "f64" => f64::from_le_bytes(self.bytes_at(offset)),
            "int32" | "i32" => f64::from(i32::from_le_bytes(self.bytes_at(offset))),
            // Lossy by design: i64 has more significant bits than f64.
            "int64" | "i64" => i64::from_le_bytes(self.bytes_at(offset)) as f64,
            "int8" | "i8" => f64::from(self.data[offset] as i8),
            "uint8" | "u8" | "bool" | "b" => f64::from(self.data[offset]),
            // Not reachable through `new`, which rejects unknown dtypes.
            _ => 0.0,
        };
        Ok(value)
    }

    /// Converts `value` to the array's dtype and stores it at `flat_index`.
    ///
    /// Conversions follow Rust `as` semantics. For integer dtypes, NaN becomes
    /// 0 and out-of-range values saturate. For `float32`, the value is rounded
    /// to the nearest representable value. For `bool`, any non-zero value
    /// (including NaN) is stored as 1 and 0.0 is stored as 0.
    ///
    /// # Errors
    ///
    /// Returns `PyIndexError` if `flat_index` is not less than `size()`.
    pub fn write_f64(&mut self, flat_index: usize, value: f64) -> PyResult<()> {
        let offset = self.byte_offset(flat_index)?;
        // `slot` is exactly one element (`itemsize` bytes) wide, matching the
        // width of each encoding below.
        let slot = &mut self.data[offset..offset + self.itemsize];
        match self.dtype_name.as_str() {
            "float32" | "f32" => slot.copy_from_slice(&(value as f32).to_le_bytes()),
            "float64" | "f64" => slot.copy_from_slice(&value.to_le_bytes()),
            "int32" | "i32" => slot.copy_from_slice(&(value as i32).to_le_bytes()),
            "int64" | "i64" => slot.copy_from_slice(&(value as i64).to_le_bytes()),
            "int8" | "i8" => slot[0] = value as i8 as u8,
            "uint8" | "u8" => slot[0] = value as u8,
            "bool" | "b" => slot[0] = u8::from(value != 0.0),
            // Not reachable through `new`, which rejects unknown dtypes.
            _ => {}
        }
        Ok(())
    }
}

// Rust-only helpers. They live outside `#[pymethods]` because every method in
// that block is exported to Python.
impl UntypedArray {
    /// Returns the byte offset of element `flat_index` in `data`.
    ///
    /// Returns `PyIndexError` if the index is out of bounds.
    fn byte_offset(&self, flat_index: usize) -> PyResult<usize> {
        // Compare element counts instead of checking `flat_index * itemsize`.
        // That product can overflow for huge indices coming from Python and
        // wrap around to an in-bounds offset. `itemsize` is never zero because
        // it always comes from `resolve_itemsize`.
        if flat_index < self.data.len() / self.itemsize {
            Ok(flat_index * self.itemsize)
        } else {
            Err(pyo3::exceptions::PyIndexError::new_err(format!(
                "flat_index {flat_index} is out of bounds"
            )))
        }
    }

    /// Copies the `N` bytes starting at `offset` into a fixed-size array.
    fn bytes_at<const N: usize>(&self, offset: usize) -> [u8; N] {
        let mut buf = [0u8; N];
        buf.copy_from_slice(&self.data[offset..offset + N]);
        buf
    }
}
/// Maps a dtype name (canonical or short alias) to its per-element byte width.
///
/// # Errors
///
/// Returns `PyValueError` listing every accepted name when `dtype_name` is not
/// recognised.
fn resolve_itemsize(dtype_name: &str) -> PyResult<usize> {
    let width = match dtype_name {
        "float64" | "f64" | "int64" | "i64" => 8,
        "float32" | "f32" | "int32" | "i32" => 4,
        "bool" | "b" | "uint8" | "u8" | "int8" | "i8" => 1,
        unknown => {
            return Err(pyo3::exceptions::PyValueError::new_err(format!(
                "unsupported dtype '{unknown}'; supported: float32, f32, float64, f64, \
                 int32, i32, int64, i64, bool, b, uint8, u8, int8, i8"
            )));
        }
    };
    Ok(width)
}
/// Adds the `UntypedArray` class to the Python module `m`.
///
/// # Errors
///
/// Propagates any error returned by `add_class`.
pub fn register_untyped_module(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_class::<UntypedArray>()
}