1use pyo3::prelude::*;
14use pyo3::types::{PyDict, PyList, PyTuple};
15use thiserror::Error;
16
/// Errors produced while parsing or describing array-protocol metadata.
///
/// Each variant carries a free-form detail string that is interpolated into
/// the `Display` message declared by its `#[error(...)]` attribute.
#[derive(Debug, Error)]
pub enum ArrayProtocolError {
    /// A dtype that is not supported; the payload is the detail text.
    // NOTE(review): nothing in the visible code constructs this variant.
    #[error("unsupported dtype: {0}")]
    UnsupportedDtype(String),

    /// A typestr that failed validation; `parse_typestr` returns this variant
    /// with a description of what was wrong.
    #[error("invalid typestr: {0}")]
    InvalidTypestr(String),

    /// A Python exception, kept only as its string rendering (see the
    /// `From<PyErr>` conversion).
    #[error("python error: {0}")]
    PythonError(String),
}
34
35impl From<PyErr> for ArrayProtocolError {
36 fn from(e: PyErr) -> Self {
37 Self::PythonError(e.to_string())
38 }
39}
40
41impl From<ArrayProtocolError> for PyErr {
42 fn from(e: ArrayProtocolError) -> Self {
43 pyo3::exceptions::PyValueError::new_err(e.to_string())
44 }
45}
46
47pub fn parse_typestr(typestr: &str) -> Result<(char, usize), ArrayProtocolError> {
68 if typestr.len() < 3 {
69 return Err(ArrayProtocolError::InvalidTypestr(format!(
70 "too short: {typestr:?}"
71 )));
72 }
73 let mut chars = typestr.chars();
74 let endian = chars
75 .next()
76 .ok_or_else(|| ArrayProtocolError::InvalidTypestr(format!("empty typestr: {typestr:?}")))?;
77 if !matches!(endian, '<' | '>' | '=' | '|') {
79 return Err(ArrayProtocolError::InvalidTypestr(format!(
80 "unknown endianness character {endian:?} in {typestr:?}"
81 )));
82 }
83 let kind = chars.next().ok_or_else(|| {
84 ArrayProtocolError::InvalidTypestr(format!("missing kind in {typestr:?}"))
85 })?;
86 if !kind.is_ascii_alphabetic() {
87 return Err(ArrayProtocolError::InvalidTypestr(format!(
88 "invalid kind character {kind:?} in {typestr:?}"
89 )));
90 }
91 let size_str: String = chars.collect();
92 let byte_count = size_str.parse::<usize>().map_err(|_| {
93 ArrayProtocolError::InvalidTypestr(format!(
94 "invalid byte count {size_str:?} in {typestr:?}"
95 ))
96 })?;
97 if byte_count == 0 {
98 return Err(ArrayProtocolError::InvalidTypestr(format!(
99 "byte count must be > 0 in {typestr:?}"
100 )));
101 }
102 Ok((kind, byte_count))
103}
104
/// Read-only description of an array's memory layout: shape, strides,
/// element typestr, and the location and size of its storage.
pub trait ArrayProtocol {
    /// Builds the complete interface description for this array.
    fn array_interface(&self) -> ArrayInterfaceDict;

    /// Returns the typestr of the element type (for example `"<f8"`).
    fn dtype_str(&self) -> &'static str;

    /// Returns the extent of each axis.
    fn shape(&self) -> Vec<usize>;

    /// Returns the stride of each axis.
    // NOTE(review): the implementor visible in this file reports strides in
    // bytes — confirm that is the intended contract for all implementors.
    fn strides(&self) -> Vec<usize>;

    /// Returns a raw pointer to the start of the element storage.
    fn data_ptr(&self) -> *const u8;

    /// Returns the size of the element storage in bytes.
    fn nbytes(&self) -> usize;
}
130
/// Plain-Rust description of an array's layout, convertible to a Python
/// dictionary with `to_py_dict`.
pub struct ArrayInterfaceDict {
    /// Extent of each axis; emitted as the `"shape"` tuple.
    pub shape: Vec<usize>,
    /// Element type string such as `"<f8"`; emitted under `"typestr"`.
    pub typestr: String,
    /// Address of the element storage as an integer; first item of the
    /// `"data"` tuple.
    pub data_ptr: usize,
    /// Read-only flag; second item of the `"data"` tuple (emitted as 0 or 1).
    pub readonly: bool,
    /// Per-axis strides; the `"strides"` key is only emitted when this is
    /// `Some`.
    pub strides: Option<Vec<usize>>,
    /// Value emitted under `"version"`.
    pub version: u8,
}
150
151impl ArrayInterfaceDict {
152 pub fn to_py_dict<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
157 let dict = PyDict::new(py);
158
159 let shape_tuple = PyTuple::new(py, self.shape.iter().copied())?;
161 dict.set_item("shape", shape_tuple)?;
162
163 dict.set_item("typestr", &self.typestr)?;
165
166 let data_tuple = PyTuple::new(py, [self.data_ptr, self.readonly as usize])?;
168 dict.set_item("data", data_tuple)?;
169
170 dict.set_item("version", self.version)?;
172
173 if let Some(ref strides) = self.strides {
175 let strides_tuple = PyTuple::new(py, strides.iter().copied())?;
176 dict.set_item("strides", strides_tuple)?;
177 }
178
179 Ok(dict)
180 }
181}
182
/// Owned n-dimensional array of `f64` values, exported to Python as
/// `NdArrayWrapper`.
#[pyclass(name = "NdArrayWrapper")]
pub struct NdArrayWrapper {
    /// Flat element storage.
    data: Vec<f64>,
    /// Extent of each axis.
    shape: Vec<usize>,
    /// Per-axis strides in bytes, computed by `compute_c_strides_bytes` in
    /// the constructor.
    strides: Vec<usize>,
    /// Element typestr; the constructor sets it to `"<f8"`.
    dtype: String,
}
201
202#[pymethods]
203impl NdArrayWrapper {
204 #[new]
210 pub fn new(data: Vec<f64>, shape: Vec<usize>) -> PyResult<Self> {
211 let n: usize = shape.iter().product();
212 if data.len() != n {
213 return Err(pyo3::exceptions::PyValueError::new_err(format!(
214 "data length {} does not match shape product {}",
215 data.len(),
216 n
217 )));
218 }
219 let strides = compute_c_strides_bytes(&shape, std::mem::size_of::<f64>());
220 Ok(Self {
221 data,
222 shape,
223 strides,
224 dtype: "<f8".to_owned(),
225 })
226 }
227
228 #[pyo3(name = "__array__")]
236 pub fn array_method(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
237 let np = py.import("numpy").map_err(|e| {
238 pyo3::exceptions::PyImportError::new_err(format!("numpy not available: {e}"))
239 })?;
240 let flat_list = PyList::new(py, &self.data)?;
242 let kwargs = PyDict::new(py);
244 kwargs.set_item("dtype", "f8")?;
245 let arr = np.call_method("array", (flat_list,), Some(&kwargs))?;
246 let shape_tuple = PyTuple::new(py, self.shape.iter().copied())?;
248 let reshaped = arr.call_method1("reshape", (shape_tuple,))?;
249 Ok(reshaped.unbind())
250 }
251
252 #[getter]
254 pub fn array_interface(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
255 let desc = ArrayInterfaceDict {
256 shape: self.shape.clone(),
257 typestr: self.dtype.clone(),
258 data_ptr: self.data.as_ptr() as usize,
259 readonly: true,
260 strides: Some(self.strides.clone()),
261 version: 3,
262 };
263 let dict = desc.to_py_dict(py)?;
264 Ok(dict.into_any().unbind())
265 }
266
267 pub fn shape_tuple(&self, py: Python<'_>) -> Py<PyAny> {
269 PyTuple::new(py, self.shape.iter().copied())
270 .map(|t| t.into_any().unbind())
271 .unwrap_or_else(|_| py.None())
272 }
273
274 pub fn dtype_str(&self) -> &str {
276 &self.dtype
277 }
278
279 pub fn data(&self) -> Vec<f64> {
281 self.data.clone()
282 }
283
284 pub fn ndim(&self) -> usize {
286 self.shape.len()
287 }
288}
289
290impl ArrayProtocol for NdArrayWrapper {
291 fn array_interface(&self) -> ArrayInterfaceDict {
292 ArrayInterfaceDict {
293 shape: self.shape.clone(),
294 typestr: self.dtype.clone(),
295 data_ptr: self.data.as_ptr() as usize,
296 readonly: true,
297 strides: Some(self.strides.clone()),
298 version: 3,
299 }
300 }
301
302 fn dtype_str(&self) -> &'static str {
303 "<f8"
304 }
305
306 fn shape(&self) -> Vec<usize> {
307 self.shape.clone()
308 }
309
310 fn strides(&self) -> Vec<usize> {
311 self.strides.clone()
312 }
313
314 fn data_ptr(&self) -> *const u8 {
315 self.data.as_ptr() as *const u8
316 }
317
318 fn nbytes(&self) -> usize {
319 self.data.len() * std::mem::size_of::<f64>()
320 }
321}
322
323pub fn register_array_protocol_module(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
327 m.add_class::<NdArrayWrapper>()?;
328 Ok(())
329}
330
/// Computes row-major (C-order) strides, in bytes, for an array of `shape`
/// whose elements are `elem_size` bytes wide.
///
/// The last axis has stride `elem_size`; every earlier axis has the stride of
/// the axis after it multiplied by that axis's extent. An empty `shape`
/// yields an empty vector.
fn compute_c_strides_bytes(shape: &[usize], elem_size: usize) -> Vec<usize> {
    // The extent of the outermost axis never contributes to any stride, so
    // only the remaining axes are walked.
    let Some((_, inner_axes)) = shape.split_first() else {
        return Vec::new();
    };

    // Build innermost-first, then flip into axis order.
    let mut reversed = Vec::with_capacity(shape.len());
    let mut step = elem_size;
    reversed.push(step);
    for &extent in inner_axes.iter().rev() {
        step *= extent;
        reversed.push(step);
    }
    reversed.reverse();
    reversed
}
348
// Unit tests for typestr parsing, the Rust-side interface description, and
// the Python dictionary rendering.
#[cfg(test)]
mod tests {
    use super::*;

    // --- parse_typestr: accepted inputs, one per byte-order character ---

    // Little-endian 8-byte float.
    #[test]
    fn test_parse_typestr_f64_le() {
        let (kind, bytes) = parse_typestr("<f8").expect("parse_typestr failed");
        assert_eq!(kind, 'f');
        assert_eq!(bytes, 8);
    }

    // Big-endian 4-byte signed integer.
    #[test]
    fn test_parse_typestr_i32_be() {
        let (kind, bytes) = parse_typestr(">i4").expect("parse_typestr failed");
        assert_eq!(kind, 'i');
        assert_eq!(bytes, 4);
    }

    // `=` byte-order character with a 2-byte unsigned integer.
    #[test]
    fn test_parse_typestr_u16_native() {
        let (kind, bytes) = parse_typestr("=u2").expect("parse_typestr failed");
        assert_eq!(kind, 'u');
        assert_eq!(bytes, 2);
    }

    // `|` byte-order character with a 1-byte `b` kind.
    #[test]
    fn test_parse_typestr_bool_noendian() {
        let (kind, bytes) = parse_typestr("|b1").expect("parse_typestr failed");
        assert_eq!(kind, 'b');
        assert_eq!(bytes, 1);
    }

    // --- parse_typestr: rejected inputs ---

    // Anything shorter than three bytes is rejected, including the empty string.
    #[test]
    fn test_parse_typestr_error_too_short() {
        assert!(parse_typestr("<f").is_err());
        assert!(parse_typestr("").is_err());
        assert!(parse_typestr("<").is_err());
    }

    // `?` is not one of the accepted byte-order characters.
    #[test]
    fn test_parse_typestr_error_bad_endian() {
        assert!(parse_typestr("?f8").is_err());
    }

    // A byte count of zero is rejected.
    #[test]
    fn test_parse_typestr_error_zero_bytes() {
        assert!(parse_typestr("<f0").is_err());
    }

    // --- ArrayProtocol for NdArrayWrapper: Rust-side description ---

    // The description always reports version 3.
    #[test]
    fn test_array_interface_dict_version() {
        let data = vec![1.0_f64, 2.0, 3.0, 4.0];
        let wrapper = NdArrayWrapper::new(data, vec![2, 2]).expect("NdArrayWrapper::new failed");
        let iface = ArrayProtocol::array_interface(&wrapper);
        assert_eq!(iface.version, 3, "version must be 3");
    }

    // The description carries the shape passed to the constructor.
    #[test]
    fn test_array_interface_dict_shape() {
        let data = vec![1.0_f64; 6];
        let wrapper = NdArrayWrapper::new(data, vec![2, 3]).expect("NdArrayWrapper::new failed");
        let iface = ArrayProtocol::array_interface(&wrapper);
        assert_eq!(iface.shape, vec![2, 3]);
    }

    // The constructor fixes the typestr to "<f8".
    #[test]
    fn test_array_interface_dict_typestr() {
        let data = vec![0.0_f64; 4];
        let wrapper = NdArrayWrapper::new(data, vec![4]).expect("NdArrayWrapper::new failed");
        let iface = ArrayProtocol::array_interface(&wrapper);
        assert_eq!(iface.typestr, "<f8");
    }

    // A non-empty buffer must be described by a non-zero address.
    #[test]
    fn test_array_interface_dict_data_ptr_nonzero() {
        let data = vec![1.0_f64, 2.0, 3.0];
        let wrapper = NdArrayWrapper::new(data, vec![3]).expect("NdArrayWrapper::new failed");
        let iface = ArrayProtocol::array_interface(&wrapper);
        assert_ne!(iface.data_ptr, 0, "data pointer must be non-null");
    }

    // --- NdArrayWrapper construction ---

    // Four values cannot fill a 2x3 shape.
    #[test]
    fn test_ndarray_wrapper_shape_mismatch() {
        let result = NdArrayWrapper::new(vec![1.0; 4], vec![2, 3]);
        assert!(result.is_err());
    }

    // Despite the name, this builds a one-element 1-D array (shape `[1]`),
    // not a 0-dimensional scalar.
    #[test]
    fn test_ndarray_wrapper_scalar() {
        let wrapper = NdArrayWrapper::new(vec![42.0], vec![1]).expect("scalar failed");
        assert_eq!(wrapper.ndim(), 1);
        assert_eq!(wrapper.data(), vec![42.0]);
    }

    // For a 3x4 array of 8-byte elements a row spans 4 * 8 = 32 bytes.
    #[test]
    fn test_ndarray_wrapper_strides_c_order() {
        let data = vec![0.0_f64; 12];
        let wrapper = NdArrayWrapper::new(data, vec![3, 4]).expect("NdArrayWrapper::new failed");
        let strides = ArrayProtocol::strides(&wrapper);
        assert_eq!(strides, vec![32, 8]);
    }

    // --- ArrayInterfaceDict::to_py_dict: these tests attach to a Python interpreter ---

    // The rendered dict contains the `shape`, `typestr`, `data` and `version` keys.
    #[test]
    fn test_array_interface_py_dict_keys() {
        Python::attach(|py| {
            let data = vec![1.0_f64, 2.0, 3.0, 4.0];
            let wrapper =
                NdArrayWrapper::new(data, vec![2, 2]).expect("NdArrayWrapper::new failed");
            let iface = ArrayProtocol::array_interface(&wrapper);
            let dict = iface.to_py_dict(py).expect("to_py_dict failed");

            assert!(dict
                .get_item("shape")
                .expect("shape lookup failed")
                .is_some());
            assert!(dict
                .get_item("typestr")
                .expect("typestr lookup failed")
                .is_some());
            assert!(dict.get_item("data").expect("data lookup failed").is_some());
            assert!(dict
                .get_item("version")
                .expect("version lookup failed")
                .is_some());
        });
    }

    // The `shape` entry is a tuple whose items round-trip back to the
    // original extents.
    #[test]
    fn test_array_interface_py_dict_shape_values() {
        Python::attach(|py| {
            let data = vec![0.0_f64; 6];
            let wrapper =
                NdArrayWrapper::new(data, vec![2, 3]).expect("NdArrayWrapper::new failed");
            let iface = ArrayProtocol::array_interface(&wrapper);
            let dict = iface.to_py_dict(py).expect("to_py_dict failed");

            let shape_obj = dict
                .get_item("shape")
                .expect("shape lookup failed")
                .expect("shape missing");
            let shape_tuple = shape_obj.cast::<PyTuple>().expect("shape is not a tuple");
            assert_eq!(shape_tuple.len(), 2);
            let v0: usize = shape_tuple
                .get_item(0)
                .expect("item 0")
                .extract()
                .expect("extract[0]");
            let v1: usize = shape_tuple
                .get_item(1)
                .expect("item 1")
                .extract()
                .expect("extract[1]");
            assert_eq!(v0, 2);
            assert_eq!(v1, 3);
        });
    }
}