// executorch_sys/cxx_bridge/tensor_ptr.rs

// Clippy doesn't detect the `# Safety` sections written on the `unsafe fn` declarations inside the
// `#[cxx::bridge]` module below, so `missing_safety_doc` is allowed for this whole file instead.
#![allow(clippy::missing_safety_doc)]

/// Utility types shared with the C++ side through the cxx bridge.
pub mod cxx_util {
    /// Type-erased owner of an arbitrary Rust value, usable from the cxx bridge.
    ///
    /// Wrap any Rust object in a `Box<RustAny>` and hand it to C++. The wrapped value is never
    /// accessed again; it is only kept alive. Its destructor runs when C++ drops the
    /// `RustAny` object.
    pub struct RustAny {
        // Never read on purpose: this field exists only so that the boxed value lives exactly as
        // long as the `RustAny` that owns it. The leading underscore keeps `dead_code` quiet.
        _payload: Box<dyn std::any::Any>,
    }

    impl RustAny {
        /// Take ownership of `inner` and keep it alive until this `RustAny` is dropped.
        pub fn new(inner: Box<dyn std::any::Any>) -> Self {
            RustAny { _payload: inner }
        }
    }
}

22use cxx_util::RustAny;
// Bridge between this crate and the C++ helpers declared in
// `executorch-sys/cpp/executorch_rs/cxx_bridge.hpp`.
//
// Rust exports the opaque `RustAny` type to C++. C++ exports `TensorPtr_new`, which builds an
// `executorch::aten::Tensor` and returns it as a `SharedPtr<Tensor>`.
#[cxx::bridge]
pub(crate) mod ffi {

    extern "Rust" {
        /// Opaque Rust type made visible to C++ in the `executorch_rs::cxx_util` namespace.
        ///
        /// In this bridge C++ receives it as a `Box<RustAny>` (see `TensorPtr_new`). Dropping that
        /// box drops the wrapped Rust value.
        #[namespace = "executorch_rs::cxx_util"]
        type RustAny;
    }

    unsafe extern "C++" {
        include!("executorch-sys/cpp/executorch_rs/cxx_bridge.hpp");

        /// Alias that reuses [`ScalarType`](crate::ScalarType) in this bridge instead of declaring
        /// a new type.
        type ScalarType = crate::ScalarType;
        /// Alias that reuses [`TensorShapeDynamism`](crate::TensorShapeDynamism) in this bridge
        /// instead of declaring a new type.
        type TensorShapeDynamism = crate::TensorShapeDynamism;
        /// Opaque binding to the C++ `executorch::aten::Tensor` class. This is a redefinition of
        /// [`Tensor`](crate::Tensor) for use within this bridge.
        #[namespace = "executorch::aten"]
        type Tensor;

        /// Create a new tensor and return it behind a shared pointer.
        ///
        /// # Arguments
        ///
        /// - `sizes`: The dimensions of the tensor.
        /// - `data`: A pointer to the beginning of the data buffer.
        /// - `dim_order`: The order of the dimensions.
        /// - `strides`: The strides of the tensor.
        /// - `scalar_type`: The scalar type of the tensor.
        /// - `dynamism`: The shape dynamism of the tensor.
        /// - `allocation`: A `Box<RustAny>` object that will be dropped when the tensor is
        ///   dropped. Can be used to manage the lifetime of the data buffer.
        ///
        /// `sizes`, `dim_order` and `strides` are passed as owned C++ vectors (`UniquePtr`), so
        /// ownership of them moves to the C++ side.
        ///
        /// # Returns
        ///
        /// A shared pointer (`SharedPtr`) to the newly created tensor.
        ///
        /// # Safety
        ///
        /// The `data` pointer must be valid for the lifetime of the tensor, and accessing it
        /// according to the data type, sizes, dim order, and strides must be valid.
        #[namespace = "executorch_rs"]
        unsafe fn TensorPtr_new(
            sizes: UniquePtr<CxxVector<i32>>,
            data: *mut u8,
            dim_order: UniquePtr<CxxVector<u8>>,
            strides: UniquePtr<CxxVector<i32>>,
            scalar_type: ScalarType,
            dynamism: TensorShapeDynamism,
            allocation: Box<RustAny>,
        ) -> SharedPtr<Tensor>;
    }

    // Explicit shim impl: tells cxx to emit the `SharedPtr<Tensor>` glue code in this bridge
    // (see "Explicit shim trait impls" in the cxx book).
    impl SharedPtr<Tensor> {}
}