paddle_inference/
tensor.rs

1use crate::call;
2use crate::common::{DataType, OneDimArrayInt32, PlaceType, TwoDimArraySize};
3use crate::ctypes::{
4    PD_Tensor, PD_TensorCopyFromCpuFloat, PD_TensorCopyFromCpuInt32, PD_TensorCopyFromCpuInt64,
5    PD_TensorCopyFromCpuInt8, PD_TensorCopyFromCpuUint8, PD_TensorCopyToCpuFloat,
6    PD_TensorCopyToCpuInt32, PD_TensorCopyToCpuInt64, PD_TensorCopyToCpuInt8,
7    PD_TensorCopyToCpuUint8, PD_TensorDataFloat, PD_TensorDataInt32, PD_TensorDataInt64,
8    PD_TensorDataInt8, PD_TensorDataUint8, PD_TensorDestroy, PD_TensorGetDataType, PD_TensorGetLod,
9    PD_TensorGetName, PD_TensorGetShape, PD_TensorMutableDataFloat, PD_TensorMutableDataInt32,
10    PD_TensorMutableDataInt64, PD_TensorMutableDataInt8, PD_TensorMutableDataUint8,
11    PD_TensorReshape, PD_TensorSetLod,
12};
13use std::borrow::Cow;
14use std::ffi::CStr;
15
16/// Tensor 是 Paddle Inference 的数据组织形式,用于对底层数据进行封装并提供接口对数据进行操作,包括设置 Shape、
17/// 数据、LoD 信息等。
18pub struct Tensor {
19    ptr: *mut PD_Tensor,
20}
21
22impl Tensor {
23    pub fn from_ptr(ptr: *mut PD_Tensor) -> Self {
24        Self { ptr }
25    }
26}
27
28impl Tensor {
29    /// 设置维度信息
30    pub fn reshape(&self, shape: &[i32]) {
31        call! {
32            PD_TensorReshape(self.ptr, shape.len(), shape.as_ptr() as *mut _)
33        };
34    }
35
36    /// 获取维度信息
37    pub fn shape(&self) -> Vec<i32> {
38        let ptr = call! { PD_TensorGetShape(self.ptr) };
39        OneDimArrayInt32::from_ptr(ptr).to_vec()
40    }
41
42    pub fn data_type(&self) -> DataType {
43        call! { PD_TensorGetDataType(self.ptr) }
44    }
45
46    pub fn name(&self) -> Cow<str> {
47        let ptr = call! { PD_TensorGetName(self.ptr) };
48        unsafe { CStr::from_ptr(ptr).to_string_lossy() }
49    }
50}
51
52impl Tensor {
53    pub fn copy_from_f32(&self, data: &[f32]) {
54        call! {
55            PD_TensorCopyFromCpuFloat(self.ptr, data.as_ptr())
56        };
57    }
58
59    pub fn copy_from_i64(&self, data: &[i64]) {
60        call! {
61            PD_TensorCopyFromCpuInt64(self.ptr, data.as_ptr())
62        };
63    }
64
65    pub fn copy_from_i32(&self, data: &[i32]) {
66        call! {
67            PD_TensorCopyFromCpuInt32(self.ptr, data.as_ptr())
68        };
69    }
70
71    pub fn copy_from_u8(&self, data: &[u8]) {
72        call! {
73            PD_TensorCopyFromCpuUint8(self.ptr, data.as_ptr())
74        };
75    }
76
77    pub fn copy_from_i8(&self, data: &[i8]) {
78        call! {
79            PD_TensorCopyFromCpuInt8(self.ptr, data.as_ptr())
80        };
81    }
82}
83
84impl Tensor {
85    #[inline]
86    fn size(&self) -> usize {
87        self.shape().into_iter().fold(1usize, |s, i| s * i as usize)
88    }
89
90    fn check_data_type(&self, ty: DataType) -> bool {
91        let dt = self.data_type();
92        dt != DataType::Unknown && dt == ty
93    }
94
95    fn check(&self, size: usize, ty: DataType) -> bool {
96        size >= self.size() && self.check_data_type(ty)
97    }
98}
99
100impl Tensor {
101    /// 从 Tensor 中获取数据,返回是否获取成功
102    ///
103    /// 如果出现以下情况则获取失败
104    /// - 输入类型和[`Self::data_type`]不匹配
105    /// - 输入数据大小小于[`Self::shape`]结果之积
106    pub fn copy_to_f32(&self, data: &mut [f32]) -> bool {
107        if self.check(data.len(), DataType::Float32) {
108            call! { PD_TensorCopyToCpuFloat(self.ptr, data.as_mut_ptr()) };
109            true
110        } else {
111            false
112        }
113    }
114
115    /// 从 Tensor 中获取数据,返回是否获取成功
116    ///
117    /// 如果出现以下情况则获取失败
118    /// - 输入类型和[`Self::data_type`]不匹配
119    /// - 输入数据大小小于[`Self::shape`]结果之积
120    pub fn copy_to_i64(&self, data: &mut [i64]) -> bool {
121        if self.check(data.len(), DataType::Int64) {
122            call! { PD_TensorCopyToCpuInt64(self.ptr, data.as_mut_ptr()) };
123            true
124        } else {
125            false
126        }
127    }
128
129    /// 从 Tensor 中获取数据,返回是否获取成功
130    ///
131    /// 如果出现以下情况则获取失败
132    /// - 输入类型和[`Self::data_type`]不匹配
133    /// - 输入数据大小小于[`Self::shape`]结果之积
134    pub fn copy_to_i32(&self, data: &mut [i32]) -> bool {
135        if self.check(data.len(), DataType::Int32) {
136            call! { PD_TensorCopyToCpuInt32(self.ptr, data.as_mut_ptr()) };
137            true
138        } else {
139            false
140        }
141    }
142
143    /// 从 Tensor 中获取数据,返回是否获取成功
144    ///
145    /// 如果出现以下情况则获取失败
146    /// - 输入类型和[`Self::data_type`]不匹配
147    /// - 输入数据大小小于[`Self::shape`]结果之积
148    pub fn copy_to_u8(&self, data: &mut [u8]) -> bool {
149        if self.check(data.len(), DataType::Uint8) {
150            call! { PD_TensorCopyToCpuUint8(self.ptr, data.as_mut_ptr()) };
151            true
152        } else {
153            false
154        }
155    }
156
157    /// 从 Tensor 中获取数据,返回是否获取成功
158    ///
159    /// 如果出现以下情况则获取失败
160    /// - 输入类型和[`Self::data_type`]不匹配
161    /// - 输入数据大小小于[`Self::shape`]结果之积
162    pub fn copy_to_i8(&self, data: &mut [i8]) -> bool {
163        if self.check(data.len(), DataType::Uint8) {
164            call! { PD_TensorCopyToCpuInt8(self.ptr, data.as_mut_ptr()) };
165            true
166        } else {
167            false
168        }
169    }
170}
171
172impl Tensor {
173    /// 获取 Tensor 底层数据,用于设置输入数据。
174    ///
175    /// **需要先调用[`Self::reshape`]**
176    ///
177    /// 如果底层数据类型([`DataType`])不对应则返回`None`
178    pub fn as_mut_slice_f32(&self, place_type: PlaceType) -> Option<&mut [f32]> {
179        self.check_data_type(DataType::Float32).then(|| {
180            let ptr = call! { PD_TensorMutableDataFloat(self.ptr, place_type) };
181            unsafe { std::slice::from_raw_parts_mut(ptr, self.size()) }
182        })
183    }
184
185    /// 获取 Tensor 底层数据,用于设置输入数据。
186    ///
187    /// **需要先调用[`Self::reshape`]**
188    ///
189    /// 如果底层数据类型([`DataType`])不对应则返回`None`
190    pub fn as_mut_slice_i64(&self, place_type: PlaceType) -> Option<&mut [i64]> {
191        self.check_data_type(DataType::Int64).then(|| {
192            let ptr = call! { PD_TensorMutableDataInt64(self.ptr, place_type) };
193            unsafe { std::slice::from_raw_parts_mut(ptr, self.size()) }
194        })
195    }
196
197    /// 获取 Tensor 底层数据,用于设置输入数据。
198    ///
199    /// **需要先调用[`Self::reshape`]**
200    ///
201    /// 如果底层数据类型([`DataType`])不对应则返回`None`
202    pub fn as_mut_slice_i32(&self, place_type: PlaceType) -> Option<&mut [i32]> {
203        self.check_data_type(DataType::Int32).then(|| {
204            let ptr = call! { PD_TensorMutableDataInt32(self.ptr, place_type) };
205            unsafe { std::slice::from_raw_parts_mut(ptr, self.size()) }
206        })
207    }
208
209    /// 获取 Tensor 底层数据,用于设置输入数据。
210    ///
211    /// **需要先调用[`Self::reshape`]**
212    ///
213    /// 如果底层数据类型([`DataType`])不对应则返回`None`
214    pub fn as_mut_slice_u8(&self, place_type: PlaceType) -> Option<&mut [u8]> {
215        self.check_data_type(DataType::Uint8).then(|| {
216            let ptr = call! { PD_TensorMutableDataUint8(self.ptr, place_type) };
217            unsafe { std::slice::from_raw_parts_mut(ptr, self.size()) }
218        })
219    }
220
221    /// 获取 Tensor 底层数据,用于设置输入数据。
222    ///
223    /// **需要先调用[`Self::reshape`]**
224    ///
225    /// 如果底层数据类型([`DataType`])不对应则返回`None`
226    pub fn as_mut_slice_i8(&self, place_type: PlaceType) -> Option<&mut [i8]> {
227        self.check_data_type(DataType::Uint8).then(|| {
228            let ptr = call! { PD_TensorMutableDataInt8(self.ptr, place_type) };
229            unsafe { std::slice::from_raw_parts_mut(ptr, self.size()) }
230        })
231    }
232}
233
234impl Tensor {
235    /// 获取 Tensor 底层数据,用于读取输出数据。
236    ///
237    /// 如果底层数据类型([`DataType`])不对应则返回`None`
238    pub fn as_slice_f32(&self) -> Option<(PlaceType, &[f32])> {
239        self.check_data_type(DataType::Float32).then(|| {
240            let mut place_type = PlaceType::Unknown;
241            let mut size = 0;
242            let ptr = call! { PD_TensorDataFloat(self.ptr, &mut place_type, &mut size) };
243            let s = unsafe { std::slice::from_raw_parts(ptr, size as usize) };
244            (place_type, s)
245        })
246    }
247
248    /// 获取 Tensor 底层数据,用于读取输出数据。
249    ///
250    /// 如果底层数据类型([`DataType`])不对应则返回`None`
251    pub fn as_slice_i64(&self) -> Option<(PlaceType, &[i64])> {
252        self.check_data_type(DataType::Int64).then(|| {
253            let mut place_type = PlaceType::Unknown;
254            let mut size = 0;
255            let ptr = call! { PD_TensorDataInt64(self.ptr, &mut place_type, &mut size) };
256            let s = unsafe { std::slice::from_raw_parts(ptr, size as usize) };
257            (place_type, s)
258        })
259    }
260
261    /// 获取 Tensor 底层数据,用于读取输出数据。
262    ///
263    /// 如果底层数据类型([`DataType`])不对应则返回`None`
264    pub fn as_slice_i32(&self) -> Option<(PlaceType, &[i32])> {
265        self.check_data_type(DataType::Int32).then(|| {
266            let mut place_type = PlaceType::Unknown;
267            let mut size = 0;
268            let ptr = call! { PD_TensorDataInt32(self.ptr, &mut place_type, &mut size) };
269            let s = unsafe { std::slice::from_raw_parts(ptr, size as usize) };
270            (place_type, s)
271        })
272    }
273
274    /// 获取 Tensor 底层数据,用于读取输出数据。
275    ///
276    /// 如果底层数据类型([`DataType`])不对应则返回`None`
277    pub fn as_slice_u8(&self) -> Option<(PlaceType, &[u8])> {
278        self.check_data_type(DataType::Uint8).then(|| {
279            let mut place_type = PlaceType::Unknown;
280            let mut size = 0;
281            let ptr = call! { PD_TensorDataUint8(self.ptr, &mut place_type, &mut size) };
282            let s = unsafe { std::slice::from_raw_parts(ptr, size as usize) };
283            (place_type, s)
284        })
285    }
286
287    /// 获取 Tensor 底层数据,用于读取输出数据。
288    ///
289    /// 如果底层数据类型([`DataType`])不对应则返回`None`
290    pub fn as_slice_i8(&self) -> Option<(PlaceType, &[i8])> {
291        self.check_data_type(DataType::Uint8).then(|| {
292            let mut place_type = PlaceType::Unknown;
293            let mut size = 0;
294            let ptr = call! { PD_TensorDataInt8(self.ptr, &mut place_type, &mut size) };
295            let s = unsafe { std::slice::from_raw_parts(ptr, size as usize) };
296            (place_type, s)
297        })
298    }
299}
300
301impl Tensor {
302    pub fn set_lod(&self, lod: TwoDimArraySize) {
303        call! { PD_TensorSetLod(self.ptr, lod.ptr) };
304    }
305
306    pub fn lod(&self) -> TwoDimArraySize {
307        let ptr = call!(PD_TensorGetLod(self.ptr));
308        TwoDimArraySize::from_ptr(ptr)
309    }
310}
311
312impl Drop for Tensor {
313    fn drop(&mut self) {
314        call!(PD_TensorDestroy(self.ptr));
315    }
316}