1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
#![allow(dead_code)]
use std::os::raw::c_void;
/// Integer code identifying the kind of device a buffer lives on.
///
/// Stored as a plain `i32` so that [`Context`] keeps a C-compatible layout.
/// NOTE(review): the values in [`device_type_codes`] appear to mirror DLPack's
/// `DLDeviceType` enum — confirm against the header revision this binds to.
pub type DeviceTypeCode = i32;

/// Known [`DeviceTypeCode`] values.
///
/// The numbering is not contiguous (5–7 are not assigned in this list). These
/// values are stored in the `#[repr(C)]` [`Context`] struct, so do not renumber
/// them: any other side reading the struct must agree on the exact integers.
pub mod device_type_codes {
    use super::DeviceTypeCode;

    /// `CPU` device kind.
    pub const CPU: DeviceTypeCode = 1;
    /// `GPU` device kind.
    pub const GPU: DeviceTypeCode = 2;
    /// Pinned CPU memory (`CPU_PINNED`).
    pub const CPU_PINNED: DeviceTypeCode = 3;
    /// OpenCL device kind.
    pub const OPENCL: DeviceTypeCode = 4;
    /// Metal device kind.
    pub const METAL: DeviceTypeCode = 8;
    /// VPI device kind.
    pub const VPI: DeviceTypeCode = 9;
    /// ROCm device kind.
    pub const ROCM: DeviceTypeCode = 10;
}

/// Location of a tensor's data: a device kind plus a device ordinal.
///
/// `#[repr(C)]` keeps the field order and layout stable (two 32-bit integers).
/// `Debug`/`PartialEq`/`Eq` are derived so contexts can be logged and compared,
/// for example to check that two tensors live on the same device.
#[repr(C)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Context {
    /// Device kind; expected to be one of the [`device_type_codes`] constants
    /// (not enforced by the type).
    pub device_type: DeviceTypeCode,
    /// Presumably the index of the device among devices of the same kind — confirm.
    pub device_id: i32,
}
/// Code identifying the base numeric kind of a tensor element.
///
/// NOTE(review): values in [`data_type_codes`] appear to mirror DLPack's
/// `DLDataTypeCode` — confirm against the header revision this binds to.
pub type DataTypeCode = u8;

/// Known [`DataTypeCode`] values.
pub mod data_type_codes {
    use super::DataTypeCode;

    /// Signed integer (`INT`).
    pub const INT: DataTypeCode = 0;
    /// Unsigned integer (`UINT`).
    pub const UINT: DataTypeCode = 1;
    /// Floating point (`FLOAT`).
    pub const FLOAT: DataTypeCode = 2;
}

/// Element type descriptor: base kind, bit width and lane count.
///
/// `#[repr(C)]` with fields `u8`, `u8`, `u16`, giving 4 bytes total.
/// `Debug`/`PartialEq`/`Eq` are derived so dtypes can be logged and compared
/// directly instead of field-by-field.
#[repr(C)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct DataType {
    /// Base kind; expected to be one of the [`data_type_codes`] constants
    /// (not enforced by the type).
    pub code: DataTypeCode,
    /// Bit width; presumably per lane (e.g. 32 for a single `f32`) — confirm.
    pub bits: u8,
    /// Lane count; presumably `1` for scalar elements — confirm.
    pub lanes: u16,
}
/// C-layout descriptor of an n-dimensional array.
///
/// Every buffer is referenced through a raw pointer. This struct does not
/// describe who owns or frees those buffers.
/// NOTE(review): confirm ownership rules with whoever produces these values.
/// Field order is part of the `#[repr(C)]` layout and must not change.
#[repr(C)]
pub struct Tensor {
    /// Base pointer of the element buffer (see also `byte_offset`).
    pub data: *mut c_void,
    /// Device on which `data` resides.
    pub ctx: Context,
    /// Number of dimensions; presumably the length of the `shape`/`strides` arrays — confirm.
    pub ndim: i32,
    /// Element type of the array.
    pub dtype: DataType,
    /// Pointer to the per-dimension extents; presumably `ndim` entries — confirm.
    pub shape: *mut i64,
    /// Pointer to the per-dimension strides; units and nullability not established here — confirm.
    pub strides: *mut i64,
    /// Offset in bytes applied relative to `data`.
    pub byte_offset: u64,
}
#[repr(C)]
pub struct ManagedTensor {
pub dl_tensor: Tensor,
pub manager_ctx: *mut c_void,
pub deleter: extern fn (*mut ManagedTensor),
}