executorch_sys/cxx_bridge/tensor_ptr.rs
// Clippy doesn't detect the `# Safety` sections of doc comments inside the cxx bridge, so the
// `missing_safety_doc` lint is silenced for this whole file.
2#![allow(clippy::missing_safety_doc)]
3
/// Utility functions and structs used by the cxx bridge.
pub mod cxx_util {
    use std::any::Any;
    use std::fmt;

    /// A type-erased, owned Rust value that can be handed to C++ through the cxx bridge.
    ///
    /// Wrap any `'static` Rust object in a `RustAny` and pass it to C++ as `Box<RustAny>`. From then
    /// on the C++ side owns the box. When C++ destroys the box, the wrapped Rust object is dropped
    /// with it. This lets a Rust-owned resource (for example the buffer behind a tensor) stay alive
    /// exactly as long as the C++ object that holds the box.
    pub struct RustAny {
        // Never read: the field exists only so that the wrapped value is dropped together with
        // the `RustAny` that owns it.
        #[allow(dead_code)]
        inner: Box<dyn Any>,
    }

    impl RustAny {
        /// Create a new `RustAny` that takes ownership of `inner`.
        ///
        /// The boxed value is dropped when the returned `RustAny` is dropped.
        pub fn new(inner: Box<dyn Any>) -> Self {
            Self { inner }
        }
    }

    // `Box<dyn Any>` does not implement `Debug`, so the wrapped value is intentionally elided.
    impl fmt::Debug for RustAny {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.debug_struct("RustAny").finish_non_exhaustive()
        }
    }
}
20}
21
22use cxx_util::RustAny;
23
24#[cxx::bridge]
25pub(crate) mod ffi {
26
27 extern "Rust" {
28 #[namespace = "executorch_rs::cxx_util"]
29 type RustAny;
30 }
31
32 unsafe extern "C++" {
33 include!("executorch-sys/cpp/executorch_rs/cxx_bridge.hpp");
34
35 /// Redefinition of the [`ScalarType`](crate::ScalarType).
36 type ScalarType = crate::ScalarType;
37 /// Redefinition of the [`TensorShapeDynamism`](crate::TensorShapeDynamism).
38 type TensorShapeDynamism = crate::TensorShapeDynamism;
39 /// Redefinition of the [`Tensor`](crate::Tensor).
40 #[namespace = "executorch::aten"]
41 type Tensor;
42
43 /// Create a new tensor pointer.
44 ///
45 /// Arguments:
46 /// - `sizes`: The dimensions of the tensor.
47 /// - `data`: A pointer to the beginning of the data buffer.
48 /// - `dim_order`: The order of the dimensions.
49 /// - `strides`: The strides of the tensor, in units of elements (not bytes).
50 /// - `scalar_type`: The scalar type of the tensor.
51 /// - `dynamism`: The dynamism of the tensor.
52 /// - `allocation`: A `Box<RustAny>` object that will be dropped when the tensor is dropped. Can be used to
53 /// manage the lifetime of the data buffer.
54 ///
55 /// Returns a shared pointer to the tensor.
56 ///
57 /// # Safety
58 ///
59 /// The `data` pointer must be valid for the lifetime of the tensor, and accessing it according to the data
60 /// type, sizes, dim order, and strides must be valid.
61 #[namespace = "executorch_rs"]
62 unsafe fn TensorPtr_new(
63 sizes: UniquePtr<CxxVector<i32>>,
64 data: *mut u8,
65 dim_order: UniquePtr<CxxVector<u8>>,
66 strides: UniquePtr<CxxVector<i32>>,
67 scalar_type: ScalarType,
68 dynamism: TensorShapeDynamism,
69 allocation: Box<RustAny>,
70 ) -> SharedPtr<Tensor>;
71 }
72
73 impl SharedPtr<Tensor> {}
74}