executorch_sys/cxx_bridge/tensor_ptr.rs
// Clippy doesn't detect the `# Safety` sections in the doc comments inside the cxx bridge, so the
// `missing_safety_doc` lint is disabled for this file.
2#![allow(clippy::missing_safety_doc)]
3
pub mod cxx_util {
    /// An owned, type-erased Rust value (`Box<dyn Any>`) wrapped in a concrete type so that it can
    /// be passed through a cxx bridge.
    ///
    /// Any Rust object can be handed to C++ code as `Box<RustAny>`. When the C++ side destroys that
    /// box, the `RustAny` is dropped and, with it, the wrapped object (running its `Drop` impl, if
    /// any). The wrapped value is never read back; the wrapper only ties the value's lifetime to
    /// its owner.
    #[derive(Debug)]
    pub struct RustAny {
        // Never read: the field is held only so that dropping `RustAny` also drops the value.
        #[allow(dead_code)]
        inner: Box<dyn std::any::Any>,
    }

    impl RustAny {
        /// Wraps `inner` in a new `RustAny`.
        ///
        /// `inner` is dropped when the returned `RustAny` is dropped.
        pub fn new(inner: Box<dyn std::any::Any>) -> Self {
            Self { inner }
        }
    }
}
20
21use cxx_util::RustAny;
22
// cxx bridge between Rust and the C++ helpers declared in
// `executorch-sys/cpp/executorch_rs/cxx_bridge.hpp` (pulled in by the `include!` below).
#[cxx::bridge]
pub(crate) mod ffi {

    // Rust types made visible to C++.
    extern "Rust" {
        // Opaque to C++, where it is named `executorch_rs::cxx_util::RustAny`. In this bridge it is
        // only handed to C++ as `Box<RustAny>` (the `allocation` argument of `TensorPtr_new`).
        #[namespace = "executorch_rs::cxx_util"]
        type RustAny;
    }

    // C++ types and functions provided by the included header.
    unsafe extern "C++" {
        include!("executorch-sys/cpp/executorch_rs/cxx_bridge.hpp");

        /// Reuses the existing [`ScalarType`](crate::ScalarType) binding instead of declaring a
        /// new type in the bridge.
        type ScalarType = crate::ScalarType;
        /// Reuses the existing [`TensorShapeDynamism`](crate::TensorShapeDynamism) binding instead
        /// of declaring a new type in the bridge.
        type TensorShapeDynamism = crate::TensorShapeDynamism;
        /// A minimal Tensor type whose API is a source-compatible subset of `at::Tensor`.
        // Opaque on the Rust side; only ever handled here through `SharedPtr<Tensor>`.
        #[namespace = "executorch::aten"]
        type Tensor;

        /// Create a new tensor and return it behind a shared pointer.
        ///
        /// Arguments:
        /// - `sizes`: The dimensions of the tensor.
        /// - `data`: A pointer to the beginning of the data buffer.
        /// - `dim_order`: The order of the dimensions.
        /// - `strides`: The strides of the tensor, in units of elements (not bytes).
        /// - `scalar_type`: The scalar type of the tensor's elements.
        /// - `dynamism`: The shape dynamism of the tensor.
        /// - `allocation`: A `Box<RustAny>` object that will be dropped when the tensor is dropped.
        ///   Can be used to keep the data buffer alive for as long as the tensor, e.g. by wrapping
        ///   the buffer's owner.
        ///
        /// Returns a shared pointer to the new tensor.
        ///
        /// # Safety
        ///
        /// The `data` pointer must be valid for the lifetime of the tensor, and accessing it
        /// according to the data type, sizes, dim order, and strides must be valid.
        #[namespace = "executorch_rs"]
        unsafe fn TensorPtr_new(
            sizes: UniquePtr<CxxVector<i32>>,
            // Raw, untyped byte pointer; its element type is given by `scalar_type`.
            data: *mut u8,
            dim_order: UniquePtr<CxxVector<u8>>,
            strides: UniquePtr<CxxVector<i32>>,
            scalar_type: ScalarType,
            dynamism: TensorShapeDynamism,
            allocation: Box<RustAny>,
        ) -> SharedPtr<Tensor>;
    }

    // Empty impl block: cxx syntax for explicitly requesting the `SharedPtr<Tensor>` glue code.
    // NOTE(review): `TensorPtr_new` already returns `SharedPtr<Tensor>` within this bridge — confirm
    // whether this explicit instantiation is still required (e.g. for use from other bridges).
    impl SharedPtr<Tensor> {}
}
73}