1use std::marker::PhantomData;
4
5pub mod runtime;
6use runtime::*;
7
8#[repr(u8)]
9#[derive(Clone, Copy, PartialEq, PartialOrd)]
10pub enum Kind {
11 Int = halide_type_code_t::halide_type_int as u8,
12 UInt = halide_type_code_t::halide_type_uint as u8,
13 Float = halide_type_code_t::halide_type_float as u8,
14}
15
16#[derive(Clone, Copy, PartialEq, PartialOrd)]
20pub struct Type(pub Kind, pub u8, pub u16);
21
22impl Type {
23 pub fn new(kind: Kind, bits: u8) -> Type {
24 Type(kind, bits, 1)
25 }
26
27 pub fn new_with_lanes(kind: Kind, bits: u8, lanes: u16) -> Type {
28 Type(kind, bits, lanes)
29 }
30
31 pub fn bits(&self) -> u8 {
32 self.1
33 }
34
35 pub fn kind(&self) -> Kind {
36 self.0
37 }
38
39 pub fn size(&self) -> usize {
40 self.bits() as usize / 8
41 }
42}
43
/// A `halide_buffer_t` descriptor whose pixel data is borrowed for `'a`.
///
/// `#[repr(transparent)]` gives `Buffer` the same layout as the wrapped
/// `halide_buffer_t` (the `PhantomData` is zero-sized). This is what lets
/// `*const Buffer` / `*mut Buffer` be passed to `extern "C"` pipelines in
/// place of a raw buffer pointer.
///
/// The `dim` array is heap-allocated by this module's constructors and freed
/// in `Drop`. The host pointer is only borrowed, never freed.
#[repr(transparent)]
pub struct Buffer<'a>(halide_buffer_t, PhantomData<&'a ()>);
48
49fn halide_buffer(
50 width: i32,
51 height: i32,
52 channels: i32,
53 t: Type,
54 data: *mut u8,
55) -> halide_buffer_t {
56 let t = halide_type_t {
57 code: t.0 as u8,
58 bits: t.1,
59 lanes: t.2,
60 };
61
62 let mut dim = vec![
63 halide_dimension_t {
64 flags: 0,
65 min: 0,
66 extent: width,
67 stride: channels,
68 },
69 halide_dimension_t {
70 flags: 0,
71 min: 0,
72 extent: height,
73 stride: channels * width,
74 },
75 ];
76
77 if channels > 1 {
78 dim.push(halide_dimension_t {
79 flags: 0,
80 min: 0,
81 extent: channels,
82 stride: 1,
83 });
84 }
85
86 dim.shrink_to_fit();
87
88 let buf = halide_buffer_t {
89 device: 0,
90 device_interface: std::ptr::null(),
91 dimensions: if channels < 2 { 2 } else { 3 },
92 host: data,
93 flags: 0,
94 padding: std::ptr::null_mut(),
95 type_: t,
96 dim: dim.as_mut_ptr(),
97 };
98
99 std::mem::forget(dim);
100
101 buf
102}
103
104impl<'a> From<&'a halide_buffer_t> for Buffer<'a> {
105 fn from(buf: &'a halide_buffer_t) -> Buffer {
106 let mut dest = *buf;
107 let mut dim = Vec::new();
108
109 for i in 0..dest.dimensions as usize {
110 unsafe {
111 dim.push(*dest.dim.add(i));
112 }
113 }
114
115 dest.dim = dim.as_mut_ptr();
116 std::mem::forget(dim);
117
118 Buffer(dest, PhantomData)
119 }
120}
121
122impl<'a> Clone for Buffer<'a> {
123 fn clone(&self) -> Self {
124 let mut dest = self.0;
125 let mut dim = Vec::new();
126
127 for i in 0..dest.dimensions as usize {
128 unsafe {
129 dim.push(*dest.dim.add(i));
130 }
131 }
132
133 dest.dim = dim.as_mut_ptr();
134 std::mem::forget(dim);
135
136 Buffer(dest, PhantomData)
137 }
138}
139
140impl<'a> Buffer<'a> {
141 pub fn new<T>(width: i32, height: i32, channels: i32, t: Type, data: &'a mut [T]) -> Self {
142 Buffer(
143 halide_buffer(width, height, channels, t, data.as_mut_ptr() as *mut u8),
144 PhantomData,
145 )
146 }
147
148 pub fn new_const<T>(width: i32, height: i32, channels: i32, t: Type, data: &'a [T]) -> Self {
149 Buffer(
150 halide_buffer(width, height, channels, t, data.as_ptr() as *mut u8),
151 PhantomData,
152 )
153 }
154
155 pub fn copy_to_host(&mut self) {
156 unsafe {
157 runtime::halide_copy_to_host(std::ptr::null_mut(), &mut self.0);
158 }
159 }
160
161 #[cfg(feature = "gpu")]
162 pub fn copy_to_device(&mut self, device: &gpu::Device) {
163 unsafe {
164 runtime::halide_copy_to_device(std::ptr::null_mut(), &mut self.0, device.0);
165 }
166 }
167}
168
169impl<'a> Drop for Buffer<'a> {
170 fn drop(&mut self) {
171 unsafe {
172 Vec::from_raw_parts(
173 self.0.dim,
174 self.0.dimensions as usize,
175 self.0.dimensions as usize,
176 );
177 }
178 }
179}
180
#[cfg(feature = "gpu")]
pub mod gpu {
    use crate::*;

    extern "C" {
        fn halide_opencl_device_interface() -> *const halide_device_interface_t;

        fn halide_opengl_device_interface() -> *const halide_device_interface_t;

        fn halide_cuda_device_interface() -> *const halide_device_interface_t;

        #[cfg(target_os = "macos")]
        fn halide_metal_device_interface() -> *const halide_device_interface_t;
    }

    /// Handle to a Halide `halide_device_interface_t` for one GPU backend.
    ///
    /// It is only a pointer, so it is `Copy`. The same handle can be passed
    /// to `Buffer::set_device` for several buffers.
    #[derive(Clone, Copy, Debug)]
    pub struct Device(pub *const halide_device_interface_t);

    impl Device {
        /// Interface returned by `halide_opencl_device_interface`.
        pub fn opencl() -> Device {
            // SAFETY: argument-less runtime query. Assumes the OpenCL backend is linked.
            unsafe { Device(halide_opencl_device_interface()) }
        }

        /// Interface returned by `halide_opengl_device_interface`.
        pub fn opengl() -> Device {
            // SAFETY: argument-less runtime query. Assumes the OpenGL backend is linked.
            unsafe { Device(halide_opengl_device_interface()) }
        }

        /// Interface returned by `halide_cuda_device_interface`.
        pub fn cuda() -> Device {
            // SAFETY: argument-less runtime query. Assumes the CUDA backend is linked.
            unsafe { Device(halide_cuda_device_interface()) }
        }

        /// Interface returned by `halide_metal_device_interface` (macOS only).
        #[cfg(target_os = "macos")]
        pub fn metal() -> Device {
            // SAFETY: argument-less runtime query. Assumes the Metal backend is linked.
            unsafe { Device(halide_metal_device_interface()) }
        }
    }

    /// Passes device index `i` to `halide_set_gpu_device`.
    pub fn set_gpu_device(i: i32) {
        unsafe {
            halide_set_gpu_device(i);
        }
    }

    /// Returns the GPU device index reported by `halide_get_gpu_device`,
    /// queried with a null user context.
    ///
    /// Call sites that ignore the result still compile.
    pub fn get_gpu_device() -> i32 {
        unsafe { halide_get_gpu_device(std::ptr::null_mut()) }
    }

    impl<'a> Buffer<'a> {
        /// Stores a device handle and its interface in this buffer's descriptor.
        pub fn set_device(&mut self, device: u64, handle: Device) {
            self.0.device = device;
            self.0.device_interface = handle.0;
        }
    }
}
236
237pub type Status = runtime::Status;
238
#[cfg(test)]
mod tests {
    use crate::*;

    extern "C" {
        pub fn brighter(a: *const Buffer, b: *mut Buffer) -> Status;
    }

    /// Runs the linked `brighter` pipeline over an all-zero 800x600, 3-channel
    /// `u8` image and expects every output byte to be 10.
    #[test]
    fn it_works() {
        let (width, height, channels) = (800, 600, 3);
        let t = Type::new(Kind::UInt, 8);
        let byte_len = width * height * channels * t.size();
        let input = vec![0u8; byte_len];
        let mut output = vec![0u8; byte_len];

        // Scope the buffers so the mutable borrow of `output` ends before it is read.
        {
            let (w, h, c) = (width as i32, height as i32, channels as i32);
            let src = Buffer::new_const(w, h, c, t, &input);
            let mut dst = Buffer::new(w, h, c, t, &mut output);

            let status = unsafe { brighter(&src, &mut dst) };
            assert!(status == Status::Success);

            dst.copy_to_host();
        }

        assert!(output.iter().all(|&px| px == 10));
    }
}
271}