executorch_sys/cxx_bridge/
tensor_ptr.rs

// Clippy doesn't detect the `# Safety` sections in the doc comments inside the cxx bridge, so `missing_safety_doc` is allowed for this module.
2#![allow(clippy::missing_safety_doc)]
3
pub mod cxx_util {
    use std::any::Any;

    /// Type-erased owner of an arbitrary Rust value that can cross the cxx bridge.
    ///
    /// Wrap any Rust object in a `Box<RustAny>` and hand it to C++. When the C++ side
    /// destroys that box, the wrapped value is dropped with it. This lets C++ keep a Rust
    /// resource alive exactly as long as it needs to.
    pub struct RustAny {
        // Never read; the value is stored only so that it is dropped together with `RustAny`.
        // NOTE(review): the boxed value carries no `Send` bound. Confirm that the owning C++
        // object is never destroyed on a different thread than the one that created it.
        _owned: Box<dyn Any>,
    }

    impl RustAny {
        /// Takes ownership of `inner` so it can be handed to C++ as `Box<RustAny>`.
        pub fn new(inner: Box<dyn Any>) -> Self {
            RustAny { _owned: inner }
        }
    }
}
20
21use cxx_util::RustAny;
22
// cxx bridge that declares the C++ tensor constructor (`TensorPtr_new`) and the Rust type
// (`RustAny`) that C++ uses to keep an arbitrary Rust value alive until the tensor is dropped.
#[cxx::bridge]
pub(crate) mod ffi {

    // Rust types exposed to C++.
    extern "Rust" {
        // Seen from C++ as the opaque type `executorch_rs::cxx_util::RustAny`. On the Rust side
        // cxx resolves it in the enclosing module, which imports `cxx_util::RustAny`.
        #[namespace = "executorch_rs::cxx_util"]
        type RustAny;
    }

    // C++ items exposed to Rust. Because the block is `unsafe extern`, its functions are safe
    // to call unless individually declared `unsafe fn` (as `TensorPtr_new` is).
    unsafe extern "C++" {
        include!("executorch-sys/cpp/executorch_rs/cxx_bridge.hpp");

        /// Redefinition of the [`ScalarType`](crate::ScalarType).
        type ScalarType = crate::ScalarType;
        /// Redefinition of the [`TensorShapeDynamism`](crate::TensorShapeDynamism).
        type TensorShapeDynamism = crate::TensorShapeDynamism;
        /// A minimal Tensor type whose API is a source-compatible subset of `at::Tensor`.
        #[namespace = "executorch::aten"]
        type Tensor;

        // NOTE(review): the element types below (`i32` sizes/strides, `u8` dim order) presumably
        // mirror ExecuTorch's `SizesType`, `StridesType` and `DimOrderType`. Confirm against
        // cxx_bridge.hpp if those typedefs ever change.
        /// Create a new tensor pointer.
        ///
        /// Arguments:
        /// - `sizes`: The dimensions of the tensor.
        /// - `data`: A pointer to the beginning of the data buffer.
        /// - `dim_order`: The order of the dimensions.
        /// - `strides`: The strides of the tensor, in units of elements (not bytes).
        /// - `scalar_type`: The scalar type of the tensor.
        /// - `dynamism`: The dynamism of the tensor.
        /// - `allocation`: A `Box<RustAny>` object that will be dropped when the tensor is dropped. Can be used to
        ///    manage the lifetime of the data buffer.
        ///
        /// Returns a shared pointer to the tensor.
        ///
        /// # Safety
        ///
        /// The `data` pointer must be valid for the lifetime of the tensor, and accessing it according to the data
        /// type, sizes, dim order, and strides must be valid.
        #[namespace = "executorch_rs"]
        unsafe fn TensorPtr_new(
            sizes: UniquePtr<CxxVector<i32>>,
            data: *mut u8,
            dim_order: UniquePtr<CxxVector<u8>>,
            strides: UniquePtr<CxxVector<i32>>,
            scalar_type: ScalarType,
            dynamism: TensorShapeDynamism,
            allocation: Box<RustAny>,
        ) -> SharedPtr<Tensor>;
    }

    // Explicit impl: asks cxx to emit the `SharedPtr<Tensor>` support code from this bridge.
    impl SharedPtr<Tensor> {}
}