use std::ops::Deref;
use half::f16;
use libc::size_t;
use crate::{NumberType, NumberTypedT, NumberTypeKind};
use crate::ffi::{
cl_bool, cl_char, cl_double, cl_float, cl_half, cl_int, cl_long, cl_uchar, cl_uint, cl_ulong,
};
use crate::ffi::{
cl_char16, cl_char2, cl_char3, cl_char4, cl_char8,
cl_float16, cl_float2, cl_float3, cl_float4, cl_float8, cl_int16, cl_int2,
cl_int3, cl_int4, cl_int8, cl_long16, cl_long2, cl_long3, cl_long4, cl_long8,
cl_short, cl_short16, cl_short2, cl_short3, cl_short4, cl_short8, cl_uchar16,
cl_uchar2, cl_uchar3, cl_uchar4, cl_uchar8, cl_uint16, cl_uint2, cl_uint3, cl_uint4,
cl_uint8, cl_ulong16, cl_ulong2, cl_ulong3, cl_ulong4, cl_ulong8, cl_ushort,
cl_ushort16, cl_ushort2, cl_ushort3, cl_ushort4, cl_ushort8,
};
/// Marker for the number types this crate treats as OpenCL numbers.
///
/// Every implementor is `Copy`, `Send + Sync`, `'static`, can produce a zero
/// value through [`Zeroed`], and reports its [`NumberType`] through
/// [`NumberTypedT`].
///
/// # Safety
///
/// NOTE(review): the trait is `unsafe` but its contract is not written down in
/// this file. Presumably implementors must be plain numeric data whose layout
/// matches the corresponding OpenCL type. Confirm before adding new impls.
pub unsafe trait ClNumber:
    'static + Sized + Send + Sync + Copy + Clone + Zeroed + NumberTypedT
{
}
// Scalar Rust primitives that count as OpenCL numbers. The vector FFI types get
// their `ClNumber` impls from `newtype_primitive_and_newtype_vectors!`.

// Signed integers.
unsafe impl ClNumber for i8 {}
unsafe impl ClNumber for i16 {}
unsafe impl ClNumber for i32 {}
unsafe impl ClNumber for i64 {}

// Unsigned integers.
unsafe impl ClNumber for u8 {}
unsafe impl ClNumber for u16 {}
unsafe impl ClNumber for u32 {}
unsafe impl ClNumber for u64 {}

// Floating point.
unsafe impl ClNumber for f32 {}
unsafe impl ClNumber for f64 {}

/// `f64` is reported as the `ClDouble` number type.
///
/// NOTE(review): the scalar impls generated by `impl_number_typed_t!` return
/// `NumberType::X(NumberTypeKind::Primitive)`, but this returns the bare
/// `NumberType::ClDouble`. Confirm that variant is meant to carry no kind.
impl NumberTypedT for f64 {
    fn number_type() -> NumberType {
        NumberType::ClDouble
    }
}
/// Implements `Deref` for a single-field tuple newtype, exposing its `.0` field.
///
/// Usage: `impl_deref!(Wrapper => Inner);`. `Deref` must be in scope at the
/// invocation site, and `Wrapper`'s `.0` field must have type `Inner`.
macro_rules! impl_deref {
    ($wrapper:ty => $inner:ty) => {
        impl Deref for $wrapper {
            type Target = $inner;

            fn deref(&self) -> &Self::Target {
                let inner: &$inner = &self.0;
                inner
            }
        }
    };
}
/// `cl_bool` encoding of `false`; the target `Deref for ClBool` borrows.
const CL_BOOL_FALSE: cl_bool = 0;
/// `cl_bool` encoding of `true`; the target `Deref for ClBool` borrows.
const CL_BOOL_TRUE: cl_bool = 1;

/// Lets a [`ClBool`] be read as its `cl_bool` value (`1` for `True`, `0` for
/// `False`). The result borrows one of the two constants above, because `ClBool`
/// itself stores no `cl_bool`.
impl Deref for ClBool {
    type Target = cl_bool;

    fn deref(&self) -> &Self::Target {
        match *self {
            ClBool::False => &CL_BOOL_FALSE,
            ClBool::True => &CL_BOOL_TRUE,
        }
    }
}
/// Conversion from a Rust-side value into the OpenCL representation `T`.
pub trait ToClNumber<T> {
    /// Consumes `self` and returns the equivalent `T`.
    fn to_cl_number(self) -> T;
}

/// Conversion from an OpenCL representation `T` into a Rust-side value.
pub trait FromClNumber<T> {
    /// Builds `Self` from the raw OpenCL `value`.
    fn from_cl_number(value: T) -> Self;
}

/// Produces the all-zero value of the implementing type.
pub trait Zeroed {
    /// Returns the zero value of `Self`.
    fn zeroed() -> Self;
}
/// Encodes a Rust `bool` as `cl_bool`: `true` becomes `1`, `false` becomes `0`.
impl ToClNumber<cl_bool> for bool {
    fn to_cl_number(self) -> cl_bool {
        if self {
            1
        } else {
            0
        }
    }
}

/// Encodes a [`ClBool`] as `cl_bool`, using the same `1`/`0` encoding as `bool`.
impl ToClNumber<cl_bool> for ClBool {
    fn to_cl_number(self) -> cl_bool {
        let flag = match self {
            ClBool::True => true,
            ClBool::False => false,
        };
        flag.to_cl_number()
    }
}
/// Decodes a `cl_bool` into a Rust `bool`.
///
/// # Panics
///
/// Panics if `b` is neither `0` nor `1`.
impl FromClNumber<cl_bool> for bool {
    fn from_cl_number(b: cl_bool) -> bool {
        if b == 0 {
            return false;
        }
        if b == 1 {
            return true;
        }
        panic!("Invalid cl_bool value {:?}: must be 0 or 1", b)
    }
}

/// Decodes a `cl_bool` into a [`ClBool`].
///
/// # Panics
///
/// Panics under the same conditions as the `bool` decoding: any value other
/// than `0` or `1`.
impl FromClNumber<cl_bool> for ClBool {
    fn from_cl_number(b: cl_bool) -> ClBool {
        match bool::from_cl_number(b) {
            true => ClBool::True,
            false => ClBool::False,
        }
    }
}
// `*half` on a `ClHalf` yields its raw `cl_half` value.
impl_deref!(ClHalf => cl_half);

/// Unwraps a [`ClHalf`] into its raw `cl_half` value.
impl ToClNumber<cl_half> for ClHalf {
    fn to_cl_number(self) -> cl_half {
        self.0
    }
}

/// Wraps a raw `cl_half` value in a [`ClHalf`] unchanged.
impl FromClNumber<cl_half> for ClHalf {
    fn from_cl_number(val: cl_half) -> ClHalf {
        Self(val)
    }
}

/// Reinterprets a `cl_half` as the bit pattern of an `f16`.
impl FromClNumber<cl_half> for f16 {
    fn from_cl_number(val: cl_half) -> f16 {
        let bits = val;
        f16::from_bits(bits)
    }
}
/// Implements scalar conversions between an OpenCL scalar alias, its tuple
/// newtype, and the matching Rust primitive.
///
/// * `$cl_t`: the OpenCL scalar type, e.g. `cl_int`.
/// * `$wrapper`: the newtype whose `.0` field is a `$cl_t`, e.g. `ClInt`.
/// * `$native_t`: the Rust primitive. The identity conversion below only
///   compiles when it is the same type as `$cl_t`, e.g. `i32`.
macro_rules! impl_primitive_conversion {
    ($cl_t:ty, $wrapper:ident, $native_t:ty) => {
        impl ToClNumber<$cl_t> for $wrapper {
            fn to_cl_number(self) -> $cl_t {
                self.0
            }
        }

        impl FromClNumber<$cl_t> for $wrapper {
            fn from_cl_number(val: $cl_t) -> $wrapper {
                Self(val)
            }
        }

        impl FromClNumber<$cl_t> for $native_t {
            fn from_cl_number(val: $cl_t) -> $native_t {
                val
            }
        }
    };
}
/// Rust-side view of an OpenCL `cl_bool`.
///
/// `True` is encoded as `1` and `False` as `0`.
pub enum ClBool {
    True,
    False,
}

/// Newtype wrapper around a `cl_double`.
pub struct ClDouble(cl_double);

/// Newtype wrapper around a C `size_t`.
pub struct SizeT(size_t);

/// Newtype wrapper around a raw `cl_half` value.
pub struct ClHalf(cl_half);
/// Implements conversions between an OpenCL vector FFI type `$t`, its tuple
/// newtype `$new_t`, and the plain Rust array `$rust_t` holding its lanes.
///
/// Generates:
/// * `FromClNumber<$t>` for `$rust_t` (reads the `s` lane array) and for
///   `$new_t` (wraps the value),
/// * `ToClNumber<$t>` for `$rust_t` (writes the `s` lane array) and for
///   `$new_t` (unwraps the value).
///
/// NOTE(review): assumes `$t` is a `Copy` union/struct that exposes its lanes as
/// a field `s` of type `$rust_t`, as the FFI bindings appear to provide. Confirm.
macro_rules! from_cl_number_inner_s {
    ($t:ty, $new_t:ident, $rust_t:ty) => {
        impl FromClNumber<$t> for $rust_t {
            fn from_cl_number(num: $t) -> $rust_t {
                // SAFETY: the lane types used with this macro are plain integers or
                // floats, so any initialized bit pattern in `s` is a valid value.
                unsafe { num.s }
            }
        }

        // Matches the 3-component impl in `from_cl_number_inner_s3!`, so every
        // vector newtype can be built from its FFI value.
        impl FromClNumber<$t> for $new_t {
            fn from_cl_number(num: $t) -> $new_t {
                $new_t(num)
            }
        }

        impl ToClNumber<$t> for $rust_t {
            fn to_cl_number(self) -> $t {
                // Zero the whole value first, in case `$t` has bytes beyond `s`;
                // then overwrite the lanes. Writing a union field is safe.
                // SAFETY: all-zero is a valid bit pattern for numeric lanes.
                let mut num = unsafe { std::mem::zeroed::<$t>() };
                num.s = self;
                num
            }
        }

        impl ToClNumber<$t> for $new_t {
            fn to_cl_number(self) -> $t {
                self.0
            }
        }
    };
}
/// Implements conversions for the 3-component vector FFI type `[<$t 3>]`
/// (e.g. `cl_int3`), its newtype `[<$new_t 3>]`, and the array `[$rust_t; 3]`.
///
/// The FFI value's `s` lane array has a fourth lane: reads drop it and writes
/// fill it with zero.
macro_rules! from_cl_number_inner_s3 {
    ($t:ident, $new_t:ident, $rust_t:ident) => {
        paste::item! {
            impl FromClNumber<[<$t 3>]> for [$rust_t; 3] {
                fn from_cl_number(num: [<$t 3>]) -> [$rust_t; 3] {
                    // SAFETY: lanes are plain numbers, so any bit pattern is valid.
                    let [x, y, z, _padding] = unsafe { num.s };
                    [x, y, z]
                }
            }

            impl FromClNumber<[<$t 3>]> for [<$new_t 3>] {
                fn from_cl_number(num: [<$t 3>]) -> [<$new_t 3>] {
                    Self(num)
                }
            }

            impl ToClNumber<[<$t 3>]> for [$rust_t; 3] {
                fn to_cl_number(self) -> [<$t 3>] {
                    let [x, y, z] = self;
                    // SAFETY: all-zero is a valid bit pattern for numeric lanes.
                    let mut num = unsafe { std::mem::zeroed::<[<$t 3>]>() };
                    num.s = [x, y, z, 0 as $t];
                    num
                }
            }

            impl ToClNumber<[<$t 3>]> for [<$new_t 3>] {
                fn to_cl_number(self) -> [<$t 3>] {
                    self.0
                }
            }
        }
    };
}
/// Implements [`NumberTypedT`] for an OpenCL number type and its newtype.
///
/// * `impl_number_typed_t!(cl_int, ClInt)` covers the scalar pair. Both report
///   `NumberType::ClInt(NumberTypeKind::Primitive)`.
/// * `impl_number_typed_t!(cl_int, ClInt, 4)` covers `cl_int4` / `ClInt4`. Both
///   report `NumberType::ClInt(num_to_kind!(4))`.
///
/// The lane count is captured as a `tt`, not an `expr`. It is forwarded to
/// `num_to_kind!`, whose arms match literal tokens. A forwarded `expr` fragment
/// is opaque to literal matchers; a `tt` fragment is not.
macro_rules! impl_number_typed_t {
    ($snake:ident, $pascal:ident) => {
        impl NumberTypedT for $snake {
            fn number_type() -> NumberType {
                NumberType::$pascal(NumberTypeKind::Primitive)
            }
        }
        impl NumberTypedT for $pascal {
            fn number_type() -> NumberType {
                NumberType::$pascal(NumberTypeKind::Primitive)
            }
        }
    };
    ($snake:ident, $pascal:ident, $num:tt) => {
        paste::item! {
            impl NumberTypedT for [<$pascal $num>] {
                fn number_type() -> NumberType {
                    NumberType::$pascal(num_to_kind!($num))
                }
            }
            impl NumberTypedT for [<$snake $num>] {
                fn number_type() -> NumberType {
                    NumberType::$pascal(num_to_kind!($num))
                }
            }
        }
    };
}
/// Implements [`Zeroed`] for the vector FFI type `[<$base $lanes>]`
/// (e.g. `cl_int` and `4` give `cl_int4`) by zero-filling the whole value.
macro_rules! impl_zeroed_num_vector {
    ($base:ident, $lanes:expr) => {
        paste::item! {
            impl Zeroed for [<$base $lanes>] {
                fn zeroed() -> Self {
                    // SAFETY: NOTE(review): assumes the FFI vector type is plain
                    // numeric lane data, for which all-zero bytes are a valid value.
                    unsafe { std::mem::zeroed() }
                }
            }
        }
    };
}
/// The zero value of `f64` is positive `0.0`.
impl Zeroed for f64 {
    fn zeroed() -> Self {
        0.0_f64
    }
}
/// Implements [`Zeroed`] for the numeric scalar `$scalar`, using its `Default`
/// value (zero for the numeric primitives this is used with). Also implements
/// it for the 2-, 4-, 8- and 16-component vector types via
/// `impl_zeroed_num_vector!`.
///
/// No 3-component impl is generated.
macro_rules! impl_zeroed_num {
    ($scalar:ident) => {
        impl Zeroed for $scalar {
            fn zeroed() -> Self {
                <$scalar as Default>::default()
            }
        }
        impl_zeroed_num_vector!($scalar, 16);
        impl_zeroed_num_vector!($scalar, 8);
        impl_zeroed_num_vector!($scalar, 4);
        impl_zeroed_num_vector!($scalar, 2);
    };
}
/// Maps a literal component count to its `NumberTypeKind`.
///
/// `1` means a scalar (`Primitive`). `2`, `3`, `4`, `8` and `16` name the vector
/// widths. Any other token fails to match.
macro_rules! num_to_kind {
    (16) => {
        $crate::NumberTypeKind::Sixteen
    };
    (8) => {
        $crate::NumberTypeKind::Eight
    };
    (4) => {
        $crate::NumberTypeKind::Four
    };
    (3) => {
        $crate::NumberTypeKind::Three
    };
    (2) => {
        $crate::NumberTypeKind::Two
    };
    (1) => {
        $crate::NumberTypeKind::Primitive
    };
}
/// Defines the newtype `$new_t` around the scalar OpenCL alias `$t`, plus the
/// 2-, 3-, 4-, 8- and 16-component vector newtypes (`[<$new_t N>]` wrapping
/// `[<$t N>]`), and generates their trait impls.
///
/// * `$t`: the scalar FFI alias, e.g. `cl_int`.
/// * `$new_t`: the newtype name, e.g. `ClInt`.
/// * `$rust_t`: the Rust primitive with the same representation, e.g. `i32`.
///   It is the lane type of the plain-array conversions.
///
/// The 3-component types only get `Deref`, `ToClNumber` and `FromClNumber`
/// impls. They get no `ClNumber`, `Zeroed` or `NumberTypedT` impl.
/// NOTE(review): presumably because `cl_*3` aliases `cl_*4` in the FFI bindings
/// (the 3-component conversion fills a fourth lane), so those impls would
/// collide. Confirm.
macro_rules! newtype_primitive_and_newtype_vectors {
    ($t:ident, $new_t:ident, $rust_t:ident) => {
        paste::item! {
            // The doc comments below are plain text on purpose: macro
            // metavariables are not substituted inside `///` comments.
            /// Newtype wrapper around a scalar OpenCL number.
            pub struct $new_t($t);
            /// Newtype wrapper around a 2-component OpenCL vector.
            pub struct [<$new_t 2>]([<$t 2>]);
            /// Newtype wrapper around a 3-component OpenCL vector.
            pub struct [<$new_t 3>]([<$t 3>]);
            /// Newtype wrapper around a 4-component OpenCL vector.
            pub struct [<$new_t 4>]([<$t 4>]);
            /// Newtype wrapper around an 8-component OpenCL vector.
            pub struct [<$new_t 8>]([<$t 8>]);
            /// Newtype wrapper around a 16-component OpenCL vector.
            pub struct [<$new_t 16>]([<$t 16>]);

            // The scalar already implements `ClNumber` via the primitive impls,
            // so only the vector FFI types need it here.
            unsafe impl ClNumber for [<$t 2>] {}
            unsafe impl ClNumber for [<$t 4>] {}
            unsafe impl ClNumber for [<$t 8>] {}
            unsafe impl ClNumber for [<$t 16>] {}

            impl_zeroed_num!($t);

            impl_deref!($new_t => $t);
            impl_deref!([<$new_t 2>] => [<$t 2>]);
            impl_deref!([<$new_t 3>] => [<$t 3>]);
            impl_deref!([<$new_t 4>] => [<$t 4>]);
            impl_deref!([<$new_t 8>] => [<$t 8>]);
            impl_deref!([<$new_t 16>] => [<$t 16>]);

            impl_number_typed_t!($t, $new_t);
            impl_number_typed_t!($t, $new_t, 2);
            impl_number_typed_t!($t, $new_t, 4);
            impl_number_typed_t!($t, $new_t, 8);
            impl_number_typed_t!($t, $new_t, 16);

            impl_primitive_conversion!($t, $new_t, $rust_t);
            from_cl_number_inner_s!([<$t 2>], [<$new_t 2>], [$rust_t; 2]);
            from_cl_number_inner_s!([<$t 4>], [<$new_t 4>], [$rust_t; 4]);
            from_cl_number_inner_s!([<$t 8>], [<$new_t 8>], [$rust_t; 8]);
            from_cl_number_inner_s!([<$t 16>], [<$new_t 16>], [$rust_t; 16]);
            from_cl_number_inner_s3!($t, $new_t, $rust_t);
        }
    };
}
// Signed integer scalars and vectors.
newtype_primitive_and_newtype_vectors!(cl_char, ClChar, i8);
newtype_primitive_and_newtype_vectors!(cl_short, ClShort, i16);
newtype_primitive_and_newtype_vectors!(cl_int, ClInt, i32);
newtype_primitive_and_newtype_vectors!(cl_long, ClLong, i64);

// Unsigned integer scalars and vectors.
newtype_primitive_and_newtype_vectors!(cl_uchar, ClUchar, u8);
newtype_primitive_and_newtype_vectors!(cl_ushort, ClUshort, u16);
newtype_primitive_and_newtype_vectors!(cl_uint, ClUint, u32);
newtype_primitive_and_newtype_vectors!(cl_ulong, ClUlong, u64);

// Single-precision floating-point scalars and vectors.
newtype_primitive_and_newtype_vectors!(cl_float, ClFloat, f32);