scirs2_numpy/
structured.rs1use pyo3::prelude::*;
8
9#[derive(Debug, Clone)]
11#[pyclass(name = "DtypeField", from_py_object)]
12pub struct DtypeField {
13 #[pyo3(get)]
15 pub name: String,
16 #[pyo3(get)]
18 pub dtype: String,
19 #[pyo3(get)]
21 pub offset: usize,
22}
23
24#[pyclass(name = "StructuredDtype")]
28pub struct StructuredDtype {
29 fields: Vec<DtypeField>,
31 itemsize: usize,
33}
34
35#[pymethods]
36impl StructuredDtype {
37 #[new]
42 pub fn new(field_specs: Vec<(String, String)>) -> PyResult<Self> {
43 let mut offset = 0usize;
44 let mut fields = Vec::with_capacity(field_specs.len());
45 for (name, dtype) in field_specs {
46 let size = dtype_size(&dtype)?;
47 fields.push(DtypeField {
48 name,
49 dtype,
50 offset,
51 });
52 offset += size;
53 }
54 Ok(Self {
55 fields,
56 itemsize: offset,
57 })
58 }
59
60 pub fn names(&self) -> Vec<String> {
62 self.fields.iter().map(|f| f.name.clone()).collect()
63 }
64
65 pub fn itemsize(&self) -> usize {
67 self.itemsize
68 }
69
70 pub fn offsets(&self) -> Vec<usize> {
72 self.fields.iter().map(|f| f.offset).collect()
73 }
74
75 pub fn field_count(&self) -> usize {
77 self.fields.len()
78 }
79}
80
81fn dtype_size(dtype: &str) -> PyResult<usize> {
85 match dtype {
86 "f32" | "float32" => Ok(4),
87 "f64" | "float64" => Ok(8),
88 "i32" | "int32" => Ok(4),
89 "i64" | "int64" => Ok(8),
90 "u32" | "uint32" => Ok(4),
91 "u64" | "uint64" => Ok(8),
92 "bool" => Ok(1),
93 "i8" | "int8" => Ok(1),
94 "u8" | "uint8" => Ok(1),
95 _ => Err(pyo3::exceptions::PyValueError::new_err(format!(
96 "unknown dtype '{dtype}'; supported: f32, f64, i32, i64, u32, u64, bool, i8, u8"
97 ))),
98 }
99}
100
101#[pyclass(name = "StructuredArray")]
104pub struct StructuredArray {
105 dtype: StructuredDtype,
107 data: Vec<u8>,
109 n_records: usize,
111}
112
113#[pymethods]
114impl StructuredArray {
115 #[new]
121 pub fn new_empty(n_records: usize, field_specs: Vec<(String, String)>) -> PyResult<Self> {
122 let dtype = StructuredDtype::new(field_specs)?;
123 let data = vec![0u8; n_records * dtype.itemsize];
124 Ok(Self {
125 dtype,
126 data,
127 n_records,
128 })
129 }
130
131 pub fn n_records(&self) -> usize {
133 self.n_records
134 }
135
136 pub fn itemsize(&self) -> usize {
138 self.dtype.itemsize
139 }
140
141 pub fn get_field_f64(&self, field_name: &str) -> PyResult<Vec<f64>> {
146 let field = self
147 .dtype
148 .fields
149 .iter()
150 .find(|f| f.name == field_name)
151 .ok_or_else(|| {
152 pyo3::exceptions::PyKeyError::new_err(format!("field '{field_name}' not found"))
153 })?;
154 if field.dtype != "f64" && field.dtype != "float64" {
155 return Err(pyo3::exceptions::PyTypeError::new_err(format!(
156 "field '{field_name}' has dtype '{}', not f64",
157 field.dtype
158 )));
159 }
160 let mut result = Vec::with_capacity(self.n_records);
161 for i in 0..self.n_records {
162 let byte_offset = i * self.dtype.itemsize + field.offset;
163 let bytes: [u8; 8] = self.data[byte_offset..byte_offset + 8]
164 .try_into()
165 .map_err(|_| pyo3::exceptions::PyValueError::new_err("slice conversion error"))?;
166 result.push(f64::from_le_bytes(bytes));
167 }
168 Ok(result)
169 }
170
171 pub fn set_field_f64(&mut self, field_name: &str, values: Vec<f64>) -> PyResult<()> {
176 if values.len() != self.n_records {
177 return Err(pyo3::exceptions::PyValueError::new_err(format!(
178 "values length {} does not match n_records {}",
179 values.len(),
180 self.n_records
181 )));
182 }
183 let field = self
184 .dtype
185 .fields
186 .iter()
187 .find(|f| f.name == field_name)
188 .ok_or_else(|| {
189 pyo3::exceptions::PyKeyError::new_err(format!("field '{field_name}' not found"))
190 })?;
191 if field.dtype != "f64" && field.dtype != "float64" {
192 return Err(pyo3::exceptions::PyTypeError::new_err(format!(
193 "field '{field_name}' has dtype '{}', not f64",
194 field.dtype
195 )));
196 }
197 let field_offset = field.offset;
198 let itemsize = self.dtype.itemsize;
199 for (i, &v) in values.iter().enumerate() {
200 let byte_offset = i * itemsize + field_offset;
201 self.data[byte_offset..byte_offset + 8].copy_from_slice(&v.to_le_bytes());
202 }
203 Ok(())
204 }
205
206 pub fn field_names(&self) -> Vec<String> {
208 self.dtype.names()
209 }
210}
211
212pub fn register_structured_module(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
217 m.add_class::<DtypeField>()?;
218 m.add_class::<StructuredDtype>()?;
219 m.add_class::<StructuredArray>()?;
220 Ok(())
221}