1use crate::call;
2use crate::common::{DataType, OneDimArrayInt32, PlaceType, TwoDimArraySize};
3use crate::ctypes::{
4 PD_Tensor, PD_TensorCopyFromCpuFloat, PD_TensorCopyFromCpuInt32, PD_TensorCopyFromCpuInt64,
5 PD_TensorCopyFromCpuInt8, PD_TensorCopyFromCpuUint8, PD_TensorCopyToCpuFloat,
6 PD_TensorCopyToCpuInt32, PD_TensorCopyToCpuInt64, PD_TensorCopyToCpuInt8,
7 PD_TensorCopyToCpuUint8, PD_TensorDataFloat, PD_TensorDataInt32, PD_TensorDataInt64,
8 PD_TensorDataInt8, PD_TensorDataUint8, PD_TensorDestroy, PD_TensorGetDataType, PD_TensorGetLod,
9 PD_TensorGetName, PD_TensorGetShape, PD_TensorMutableDataFloat, PD_TensorMutableDataInt32,
10 PD_TensorMutableDataInt64, PD_TensorMutableDataInt8, PD_TensorMutableDataUint8,
11 PD_TensorReshape, PD_TensorSetLod,
12};
13use std::borrow::Cow;
14use std::ffi::CStr;
15
16pub struct Tensor {
19 ptr: *mut PD_Tensor,
20}
21
22impl Tensor {
23 pub fn from_ptr(ptr: *mut PD_Tensor) -> Self {
24 Self { ptr }
25 }
26}
27
28impl Tensor {
29 pub fn reshape(&self, shape: &[i32]) {
31 call! {
32 PD_TensorReshape(self.ptr, shape.len(), shape.as_ptr() as *mut _)
33 };
34 }
35
36 pub fn shape(&self) -> Vec<i32> {
38 let ptr = call! { PD_TensorGetShape(self.ptr) };
39 OneDimArrayInt32::from_ptr(ptr).to_vec()
40 }
41
42 pub fn data_type(&self) -> DataType {
43 call! { PD_TensorGetDataType(self.ptr) }
44 }
45
46 pub fn name(&self) -> Cow<str> {
47 let ptr = call! { PD_TensorGetName(self.ptr) };
48 unsafe { CStr::from_ptr(ptr).to_string_lossy() }
49 }
50}
51
52impl Tensor {
53 pub fn copy_from_f32(&self, data: &[f32]) {
54 call! {
55 PD_TensorCopyFromCpuFloat(self.ptr, data.as_ptr())
56 };
57 }
58
59 pub fn copy_from_i64(&self, data: &[i64]) {
60 call! {
61 PD_TensorCopyFromCpuInt64(self.ptr, data.as_ptr())
62 };
63 }
64
65 pub fn copy_from_i32(&self, data: &[i32]) {
66 call! {
67 PD_TensorCopyFromCpuInt32(self.ptr, data.as_ptr())
68 };
69 }
70
71 pub fn copy_from_u8(&self, data: &[u8]) {
72 call! {
73 PD_TensorCopyFromCpuUint8(self.ptr, data.as_ptr())
74 };
75 }
76
77 pub fn copy_from_i8(&self, data: &[i8]) {
78 call! {
79 PD_TensorCopyFromCpuInt8(self.ptr, data.as_ptr())
80 };
81 }
82}
83
84impl Tensor {
85 #[inline]
86 fn size(&self) -> usize {
87 self.shape().into_iter().fold(1usize, |s, i| s * i as usize)
88 }
89
90 fn check_data_type(&self, ty: DataType) -> bool {
91 let dt = self.data_type();
92 dt != DataType::Unknown && dt == ty
93 }
94
95 fn check(&self, size: usize, ty: DataType) -> bool {
96 size >= self.size() && self.check_data_type(ty)
97 }
98}
99
100impl Tensor {
101 pub fn copy_to_f32(&self, data: &mut [f32]) -> bool {
107 if self.check(data.len(), DataType::Float32) {
108 call! { PD_TensorCopyToCpuFloat(self.ptr, data.as_mut_ptr()) };
109 true
110 } else {
111 false
112 }
113 }
114
115 pub fn copy_to_i64(&self, data: &mut [i64]) -> bool {
121 if self.check(data.len(), DataType::Int64) {
122 call! { PD_TensorCopyToCpuInt64(self.ptr, data.as_mut_ptr()) };
123 true
124 } else {
125 false
126 }
127 }
128
129 pub fn copy_to_i32(&self, data: &mut [i32]) -> bool {
135 if self.check(data.len(), DataType::Int32) {
136 call! { PD_TensorCopyToCpuInt32(self.ptr, data.as_mut_ptr()) };
137 true
138 } else {
139 false
140 }
141 }
142
143 pub fn copy_to_u8(&self, data: &mut [u8]) -> bool {
149 if self.check(data.len(), DataType::Uint8) {
150 call! { PD_TensorCopyToCpuUint8(self.ptr, data.as_mut_ptr()) };
151 true
152 } else {
153 false
154 }
155 }
156
157 pub fn copy_to_i8(&self, data: &mut [i8]) -> bool {
163 if self.check(data.len(), DataType::Uint8) {
164 call! { PD_TensorCopyToCpuInt8(self.ptr, data.as_mut_ptr()) };
165 true
166 } else {
167 false
168 }
169 }
170}
171
172impl Tensor {
173 pub fn as_mut_slice_f32(&self, place_type: PlaceType) -> Option<&mut [f32]> {
179 self.check_data_type(DataType::Float32).then(|| {
180 let ptr = call! { PD_TensorMutableDataFloat(self.ptr, place_type) };
181 unsafe { std::slice::from_raw_parts_mut(ptr, self.size()) }
182 })
183 }
184
185 pub fn as_mut_slice_i64(&self, place_type: PlaceType) -> Option<&mut [i64]> {
191 self.check_data_type(DataType::Int64).then(|| {
192 let ptr = call! { PD_TensorMutableDataInt64(self.ptr, place_type) };
193 unsafe { std::slice::from_raw_parts_mut(ptr, self.size()) }
194 })
195 }
196
197 pub fn as_mut_slice_i32(&self, place_type: PlaceType) -> Option<&mut [i32]> {
203 self.check_data_type(DataType::Int32).then(|| {
204 let ptr = call! { PD_TensorMutableDataInt32(self.ptr, place_type) };
205 unsafe { std::slice::from_raw_parts_mut(ptr, self.size()) }
206 })
207 }
208
209 pub fn as_mut_slice_u8(&self, place_type: PlaceType) -> Option<&mut [u8]> {
215 self.check_data_type(DataType::Uint8).then(|| {
216 let ptr = call! { PD_TensorMutableDataUint8(self.ptr, place_type) };
217 unsafe { std::slice::from_raw_parts_mut(ptr, self.size()) }
218 })
219 }
220
221 pub fn as_mut_slice_i8(&self, place_type: PlaceType) -> Option<&mut [i8]> {
227 self.check_data_type(DataType::Uint8).then(|| {
228 let ptr = call! { PD_TensorMutableDataInt8(self.ptr, place_type) };
229 unsafe { std::slice::from_raw_parts_mut(ptr, self.size()) }
230 })
231 }
232}
233
234impl Tensor {
235 pub fn as_slice_f32(&self) -> Option<(PlaceType, &[f32])> {
239 self.check_data_type(DataType::Float32).then(|| {
240 let mut place_type = PlaceType::Unknown;
241 let mut size = 0;
242 let ptr = call! { PD_TensorDataFloat(self.ptr, &mut place_type, &mut size) };
243 let s = unsafe { std::slice::from_raw_parts(ptr, size as usize) };
244 (place_type, s)
245 })
246 }
247
248 pub fn as_slice_i64(&self) -> Option<(PlaceType, &[i64])> {
252 self.check_data_type(DataType::Int64).then(|| {
253 let mut place_type = PlaceType::Unknown;
254 let mut size = 0;
255 let ptr = call! { PD_TensorDataInt64(self.ptr, &mut place_type, &mut size) };
256 let s = unsafe { std::slice::from_raw_parts(ptr, size as usize) };
257 (place_type, s)
258 })
259 }
260
261 pub fn as_slice_i32(&self) -> Option<(PlaceType, &[i32])> {
265 self.check_data_type(DataType::Int32).then(|| {
266 let mut place_type = PlaceType::Unknown;
267 let mut size = 0;
268 let ptr = call! { PD_TensorDataInt32(self.ptr, &mut place_type, &mut size) };
269 let s = unsafe { std::slice::from_raw_parts(ptr, size as usize) };
270 (place_type, s)
271 })
272 }
273
274 pub fn as_slice_u8(&self) -> Option<(PlaceType, &[u8])> {
278 self.check_data_type(DataType::Uint8).then(|| {
279 let mut place_type = PlaceType::Unknown;
280 let mut size = 0;
281 let ptr = call! { PD_TensorDataUint8(self.ptr, &mut place_type, &mut size) };
282 let s = unsafe { std::slice::from_raw_parts(ptr, size as usize) };
283 (place_type, s)
284 })
285 }
286
287 pub fn as_slice_i8(&self) -> Option<(PlaceType, &[i8])> {
291 self.check_data_type(DataType::Uint8).then(|| {
292 let mut place_type = PlaceType::Unknown;
293 let mut size = 0;
294 let ptr = call! { PD_TensorDataInt8(self.ptr, &mut place_type, &mut size) };
295 let s = unsafe { std::slice::from_raw_parts(ptr, size as usize) };
296 (place_type, s)
297 })
298 }
299}
300
301impl Tensor {
302 pub fn set_lod(&self, lod: TwoDimArraySize) {
303 call! { PD_TensorSetLod(self.ptr, lod.ptr) };
304 }
305
306 pub fn lod(&self) -> TwoDimArraySize {
307 let ptr = call!(PD_TensorGetLod(self.ptr));
308 TwoDimArraySize::from_ptr(ptr)
309 }
310}
311
312impl Drop for Tensor {
313 fn drop(&mut self) {
314 call!(PD_TensorDestroy(self.ptr));
315 }
316}