1use std::marker::PhantomData;
4
5pub mod runtime;
6use runtime::*;
7
8#[repr(u8)]
9#[derive(Clone, Copy, PartialEq, PartialOrd)]
10pub enum Kind {
11    Int = halide_type_code_t::halide_type_int as u8,
12    UInt = halide_type_code_t::halide_type_uint as u8,
13    Float = halide_type_code_t::halide_type_float as u8,
14}
15
16#[derive(Clone, Copy, PartialEq, PartialOrd)]
20pub struct Type(pub Kind, pub u8, pub u16);
21
22impl Type {
23    pub fn new(kind: Kind, bits: u8) -> Type {
24        Type(kind, bits, 1)
25    }
26
27    pub fn new_with_lanes(kind: Kind, bits: u8, lanes: u16) -> Type {
28        Type(kind, bits, lanes)
29    }
30
31    pub fn bits(&self) -> u8 {
32        self.1
33    }
34
35    pub fn kind(&self) -> Kind {
36        self.0
37    }
38
39    pub fn size(&self) -> usize {
40        self.bits() as usize / 8
41    }
42}
43
/// A Halide buffer (`halide_buffer_t`) over memory borrowed for the lifetime `'a`.
///
/// `#[repr(transparent)]` gives `Buffer` the exact layout of the wrapped `halide_buffer_t`
/// (the `PhantomData` marker is zero-sized). That is what allows `&Buffer` / `&mut Buffer` to be
/// passed directly to an `extern "C"` pipeline declared with `*const Buffer` / `*mut Buffer`.
/// The `dim` array inside is heap-allocated by this crate and freed in `Drop`. The `host`
/// pointer is borrowed from the slice handed to `Buffer::new` / `Buffer::new_const`.
#[repr(transparent)]
pub struct Buffer<'a>(halide_buffer_t, PhantomData<&'a ()>);
48
49fn halide_buffer(
50    width: i32,
51    height: i32,
52    channels: i32,
53    t: Type,
54    data: *mut u8,
55) -> halide_buffer_t {
56    let t = halide_type_t {
57        code: t.0 as u8,
58        bits: t.1,
59        lanes: t.2,
60    };
61
62    let mut dim = vec![
63        halide_dimension_t {
64            flags: 0,
65            min: 0,
66            extent: width,
67            stride: channels,
68        },
69        halide_dimension_t {
70            flags: 0,
71            min: 0,
72            extent: height,
73            stride: channels * width,
74        },
75    ];
76
77    if channels > 1 {
78        dim.push(halide_dimension_t {
79            flags: 0,
80            min: 0,
81            extent: channels,
82            stride: 1,
83        });
84    }
85
86    dim.shrink_to_fit();
87
88    let buf = halide_buffer_t {
89        device: 0,
90        device_interface: std::ptr::null(),
91        dimensions: if channels < 2 { 2 } else { 3 },
92        host: data,
93        flags: 0,
94        padding: std::ptr::null_mut(),
95        type_: t,
96        dim: dim.as_mut_ptr(),
97    };
98
99    std::mem::forget(dim);
100
101    buf
102}
103
104impl<'a> From<&'a halide_buffer_t> for Buffer<'a> {
105    fn from(buf: &'a halide_buffer_t) -> Buffer {
106        let mut dest = *buf;
107        let mut dim = Vec::new();
108
109        for i in 0..dest.dimensions as usize {
110            unsafe {
111                dim.push(*dest.dim.add(i));
112            }
113        }
114
115        dest.dim = dim.as_mut_ptr();
116        std::mem::forget(dim);
117
118        Buffer(dest, PhantomData)
119    }
120}
121
122impl<'a> Clone for Buffer<'a> {
123    fn clone(&self) -> Self {
124        let mut dest = self.0;
125        let mut dim = Vec::new();
126
127        for i in 0..dest.dimensions as usize {
128            unsafe {
129                dim.push(*dest.dim.add(i));
130            }
131        }
132
133        dest.dim = dim.as_mut_ptr();
134        std::mem::forget(dim);
135
136        Buffer(dest, PhantomData)
137    }
138}
139
140impl<'a> Buffer<'a> {
141    pub fn new<T>(width: i32, height: i32, channels: i32, t: Type, data: &'a mut [T]) -> Self {
142        Buffer(
143            halide_buffer(width, height, channels, t, data.as_mut_ptr() as *mut u8),
144            PhantomData,
145        )
146    }
147
148    pub fn new_const<T>(width: i32, height: i32, channels: i32, t: Type, data: &'a [T]) -> Self {
149        Buffer(
150            halide_buffer(width, height, channels, t, data.as_ptr() as *mut u8),
151            PhantomData,
152        )
153    }
154
155    pub fn copy_to_host(&mut self) {
156        unsafe {
157            runtime::halide_copy_to_host(std::ptr::null_mut(), &mut self.0);
158        }
159    }
160
161    #[cfg(feature = "gpu")]
162    pub fn copy_to_device(&mut self, device: &gpu::Device) {
163        unsafe {
164            runtime::halide_copy_to_device(std::ptr::null_mut(), &mut self.0, device.0);
165        }
166    }
167}
168
169impl<'a> Drop for Buffer<'a> {
170    fn drop(&mut self) {
171        unsafe {
172            Vec::from_raw_parts(
173                self.0.dim,
174                self.0.dimensions as usize,
175                self.0.dimensions as usize,
176            );
177        }
178    }
179}
180
#[cfg(feature = "gpu")]
pub mod gpu {
    //! GPU backend selection: device-interface handles and the runtime's GPU device index.
    use crate::*;

    extern "C" {
        fn halide_opencl_device_interface() -> *const halide_device_interface_t;

        fn halide_opengl_device_interface() -> *const halide_device_interface_t;

        fn halide_cuda_device_interface() -> *const halide_device_interface_t;

        #[cfg(target_os = "macos")]
        fn halide_metal_device_interface() -> *const halide_device_interface_t;
    }

    /// Handle to one of the Halide runtime's device interfaces.
    ///
    /// Wraps the raw `halide_device_interface_t` pointer returned by the runtime. This crate
    /// only copies the pointer (into `halide_buffer_t::device_interface` or
    /// `halide_copy_to_device`) and never dereferences it, so the handle is `Copy`.
    #[derive(Clone, Copy, Debug)]
    pub struct Device(pub *const halide_device_interface_t);

    impl Device {
        /// Device interface from `halide_opencl_device_interface`.
        pub fn opencl() -> Device {
            unsafe { Device(halide_opencl_device_interface()) }
        }

        /// Device interface from `halide_opengl_device_interface`.
        pub fn opengl() -> Device {
            unsafe { Device(halide_opengl_device_interface()) }
        }

        /// Device interface from `halide_cuda_device_interface`.
        pub fn cuda() -> Device {
            unsafe { Device(halide_cuda_device_interface()) }
        }

        /// Device interface from `halide_metal_device_interface` (macOS only).
        #[cfg(target_os = "macos")]
        pub fn metal() -> Device {
            unsafe { Device(halide_metal_device_interface()) }
        }
    }

    /// Sets the GPU device index the Halide runtime should use (`halide_set_gpu_device`).
    pub fn set_gpu_device(i: i32) {
        unsafe {
            halide_set_gpu_device(i);
        }
    }

    /// Returns the GPU device index reported by the Halide runtime.
    ///
    /// Calls `halide_get_gpu_device` with a null user context and returns its result instead of
    /// discarding it.
    pub fn get_gpu_device() -> i32 {
        unsafe { halide_get_gpu_device(std::ptr::null_mut()) }
    }

    impl<'a> Buffer<'a> {
        /// Sets this buffer's `device` handle and `device_interface` to `device` and `handle`.
        ///
        /// NOTE(review): any previously attached device allocation is overwritten, not released.
        pub fn set_device(&mut self, device: u64, handle: Device) {
            self.0.device = device;
            self.0.device_interface = handle.0;
        }
    }
}
236
237pub type Status = runtime::Status;
238
#[cfg(test)]
mod tests {
    use crate::*;

    // Ahead-of-time compiled Halide pipeline that must be linked into the test binary.
    // NOTE(review): judging by the assertions below, it presumably writes `input + 10` into `b`.
    extern "C" {
        pub fn brighter(a: *const Buffer, b: *mut Buffer) -> Status;
    }

    /// Runs `brighter` over an all-zero 800x600x3 u8 image and checks every output byte is 10.
    #[test]
    fn it_works() {
        let (width, height, channels) = (800, 600, 3);
        let t = Type::new(Kind::UInt, 8);
        let len = width * height * channels * t.size();

        let input = vec![0u8; len];
        let mut output = vec![0u8; len];

        // The buffers borrow `input` / `output`; scope them so `output` can be read afterwards.
        {
            let (w, h, c) = (width as i32, height as i32, channels as i32);
            let buf = Buffer::new_const(w, h, c, t, &input);
            let mut out = Buffer::new(w, h, c, t, &mut output);

            unsafe {
                assert!(brighter(&buf, &mut out) == Status::Success);
            }

            out.copy_to_host();
        }

        output.iter().for_each(|i| assert!(*i == 10));
    }
}
271}