numpy/npyffi/
mod.rs

//! Low-level bindings for the NumPy C API.
//!
//! <https://numpy.org/doc/stable/reference/c-api>
4#![allow(
5    non_camel_case_types,
6    missing_docs,
7    missing_debug_implementations,
8    clippy::too_many_arguments,
9    clippy::missing_safety_doc
10)]
11
12use std::mem::forget;
13use std::os::raw::{c_uint, c_void};
14
15use pyo3::{
16    sync::PyOnceLock,
17    types::{PyAnyMethods, PyCapsule, PyCapsuleMethods, PyModule},
18    PyResult, Python,
19};
20
/// Threshold that [`is_numpy_2`] compares the runtime C API feature version
/// against: a runtime value at or above this constant is treated as NumPy 2.
pub const API_VERSION_2_0: c_uint = 0x0000_0012;
22
// Cache for the runtime C API feature version. It is filled on first use by
// `is_numpy_2` and read back by the version assertions in `impl_api!`.
static API_VERSION: PyOnceLock<c_uint> = PyOnceLock::new();
24
25fn get_numpy_api<'py>(
26    py: Python<'py>,
27    module: &str,
28    capsule: &str,
29) -> PyResult<*const *const c_void> {
30    let module = PyModule::import(py, module)?;
31    let capsule = module.getattr(capsule)?.cast_into::<PyCapsule>()?;
32
33    let api = capsule
34        .pointer_checked(None)?
35        .cast::<*const c_void>()
36        .as_ptr()
37        .cast_const();
38
39    // Intentionally leak a reference to the capsule
40    // so we can safely cache a pointer into its interior.
41    forget(capsule);
42
43    Ok(api)
44}
45
46/// Returns whether the runtime `numpy` version is 2.0 or greater.
47pub fn is_numpy_2<'py>(py: Python<'py>) -> bool {
48    let api_version = *API_VERSION.get_or_init(py, || unsafe {
49        PY_ARRAY_API.PyArray_GetNDArrayCFeatureVersion(py)
50    });
51    api_version >= API_VERSION_2_0
52}
53
// Generates the `unsafe` wrapper methods for entries of NumPy's Array and UFunc API.
//
// The macro is meant to be invoked inside an `impl` block: every expansion is a
// method taking `&self` that calls `self.get(py, $offset)`, reinterprets the
// result as a pointer to an `extern "C"` function with the declared signature,
// and calls that function with the given arguments. `get` is supplied by the
// implementing type and is not defined in this file.
//
// `macro_rules!` definitions are textually scoped, so this one has to stay above
// any `mod` declaration whose contents invoke it.
macro_rules! impl_api {
    // Entry without a version constraint.
    [$offset: expr; $fname: ident ($($arg: ident: $t: ty),* $(,)?) $(-> $ret: ty)?] => {
        #[allow(non_snake_case)]
        pub unsafe fn $fname<'py>(&self, py: Python<'py>, $($arg: $t),*) $(-> $ret)? {
            let slot = self.get(py, $offset);
            // Read the function pointer out of its slot, then call it.
            let func = *(slot as *const extern "C" fn($($arg: $t),*) $(-> $ret)?);
            func($($arg),*)
        }
    };

    // Entry restricted to NumPy 1: panics when the runtime reports NumPy 2.
    [$offset: expr; NumPy1; $fname: ident ($($arg: ident: $t: ty),* $(,)?) $(-> $ret: ty)?] => {
        #[allow(non_snake_case)]
        pub unsafe fn $fname<'py>(&self, py: Python<'py>, $($arg: $t),*) $(-> $ret)? {
            // `is_numpy_2` fills `API_VERSION` before the message is formatted,
            // so the `expect` below only fires on a broken invariant.
            assert!(
                !is_numpy_2(py),
                "{} requires API < {:08X} (NumPy 1) but the runtime version is API {:08X}",
                stringify!($fname),
                API_VERSION_2_0,
                *API_VERSION.get(py).expect("API_VERSION is initialized"),
            );
            let slot = self.get(py, $offset);
            let func = *(slot as *const extern "C" fn($($arg: $t),*) $(-> $ret)?);
            func($($arg),*)
        }
    };

    // Entry restricted to NumPy 2: panics when the runtime reports NumPy 1.
    [$offset: expr; NumPy2; $fname: ident ($($arg: ident: $t: ty),* $(,)?) $(-> $ret: ty)?] => {
        #[allow(non_snake_case)]
        pub unsafe fn $fname<'py>(&self, py: Python<'py>, $($arg: $t),*) $(-> $ret)? {
            // `is_numpy_2` fills `API_VERSION` before the message is formatted,
            // so the `expect` below only fires on a broken invariant.
            assert!(
                is_numpy_2(py),
                "{} requires API {:08X} or greater (NumPy 2) but the runtime version is API {:08X}",
                stringify!($fname),
                API_VERSION_2_0,
                *API_VERSION.get(py).expect("API_VERSION is initialized"),
            );
            let slot = self.get(py, $offset);
            let func = *(slot as *const extern "C" fn($($arg: $t),*) $(-> $ret)?);
            func($($arg),*)
        }
    };
}
97
98pub mod array;
99pub mod flags;
100pub mod objects;
101pub mod types;
102pub mod ufunc;
103
104pub use self::array::*;
105pub use self::flags::*;
106pub use self::objects::*;
107pub use self::types::*;
108pub use self::ufunc::*;