use std::sync::Arc;
use std::marker::PhantomData;
use super::jit_core::{JitCompilable, GenericJitCompilable, JitFunction, JitResult};
use super::types::{JitType, JitNumeric, TypedVector, NumericValue};
/// Typed front-end over a type-erased `GenericJitCompilable`.
///
/// `I` is the element type accepted by [`GenericJitFunction::execute`] and `O`
/// is the type its result is converted into. Neither is stored; both are only
/// tracked through `PhantomData` markers.
// NOTE(review): `derive(Clone)` requires `Box<dyn GenericJitCompilable>: Clone`
// and adds `I: Clone` / `O: Clone` bounds to the generated impl — confirm the
// boxed trait object is cloneable somewhere outside this file.
#[derive(Clone)]
pub struct GenericJitFunction<I: JitType, O: JitType> {
    /// Type-erased implementation that performs the actual evaluation.
    inner: Box<dyn GenericJitCompilable>,
    /// Marker for the input element type.
    _input_type: PhantomData<I>,
    /// Marker for the output type.
    _output_type: PhantomData<O>,
}

impl<I: JitType, O: JitType> GenericJitFunction<I, O> {
    /// Wraps an already boxed implementation.
    pub fn new(inner: Box<dyn GenericJitCompilable>) -> Self {
        Self {
            inner,
            _input_type: PhantomData::<I>,
            _output_type: PhantomData::<O>,
        }
    }

    /// Runs the wrapped implementation on `input`.
    ///
    /// Returns `None` when `TypedVector::new` yields `None` for `input`, or
    /// when `O::from_numeric_value` cannot produce an `O` from the result.
    pub fn execute(&self, input: Vec<I>) -> Option<O> {
        TypedVector::new(input)
            .map(|typed_input| self.inner.execute_typed(typed_input))
            .and_then(|result| O::from_numeric_value(&result))
    }
}
/// `f64` flavour of the typed JIT functions: a thin wrapper around
/// `JitFunction` for closures of shape `Fn(Vec<f64>) -> f64`.
#[derive(Clone)]
pub struct F64JitFunction {
    /// Underlying function object; all calls are forwarded to it.
    inner: JitFunction<dyn Fn(Vec<f64>) -> f64 + Send + Sync>,
}

impl F64JitFunction {
    /// Builds the wrapper by handing `name` and `f` to `JitFunction::new`.
    pub fn new(
        name: impl Into<String>,
        f: impl Fn(Vec<f64>) -> f64 + Send + Sync + 'static,
    ) -> Self {
        Self {
            inner: JitFunction::new(name, f),
        }
    }

    /// Consumes `self` and returns a wrapper around `inner.with_jit()`.
    ///
    /// Any error reported by the inner `with_jit` call is propagated.
    #[cfg(feature = "jit")]
    pub fn with_jit(self) -> JitResult<Self> {
        Ok(Self {
            inner: self.inner.with_jit()?,
        })
    }
}
// Exposes `F64JitFunction` through the `Vec<f64> -> f64` `JitCompilable` interface.
impl JitCompilable<Vec<f64>, f64> for F64JitFunction {
/// Evaluates the wrapped function on `args` by forwarding to `inner.execute`.
fn execute(&self, args: Vec<f64>) -> f64 {
self.inner.execute(args)
}
}
impl GenericJitCompilable for F64JitFunction {
    /// Evaluates the function on a dynamically typed vector.
    ///
    /// An `F64` payload is used as-is; every other variant is first turned
    /// into a `Vec<f64>` through `TypedVector::to_f64_vec`. The result is
    /// always wrapped in `NumericValue::F64`.
    fn execute_typed(&self, args: TypedVector) -> NumericValue {
        let values = match args {
            TypedVector::F64(values) => values,
            other => other.to_f64_vec(),
        };
        NumericValue::F64(self.inner.execute(values))
    }

    /// Name of the element type this function consumes.
    fn input_type_name(&self) -> &'static str {
        "f64"
    }

    /// Name of the type this function produces.
    fn output_type_name(&self) -> &'static str {
        "f64"
    }
}
/// Named `Fn(Vec<f32>) -> f32` closure, optionally paired with a JIT context.
///
/// Cloning is cheap: the closure (and the context, when present) sit behind
/// `Arc`, so clones share them.
#[derive(Clone)]
pub struct F32JitFunction {
    /// The native closure.
    f: Arc<dyn Fn(Vec<f32>) -> f32 + Send + Sync>,
    /// Name given at construction; passed to `JitContext::compile` by `with_jit`.
    name: String,
    /// `None` until `with_jit` succeeds.
    #[cfg(feature = "jit")]
    jit_context: Option<Arc<super::jit_core::JitContext>>,
}

impl F32JitFunction {
    /// Creates a function called `name` backed by the closure `f`.
    ///
    /// No JIT context is attached by the constructor.
    pub fn new(
        name: impl Into<String>,
        f: impl Fn(Vec<f32>) -> f32 + Send + Sync + 'static,
    ) -> Self {
        // Erase the concrete closure type so it fits the struct field.
        let f: Arc<dyn Fn(Vec<f32>) -> f32 + Send + Sync> = Arc::new(f);
        let name = name.into();
        Self {
            f,
            name,
            #[cfg(feature = "jit")]
            jit_context: None,
        }
    }

    /// Calls `JitContext::compile` with this function's name and, on success,
    /// stores the resulting context.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `JitContext::compile` unchanged.
    #[cfg(feature = "jit")]
    pub fn with_jit(mut self) -> JitResult<Self> {
        let ctx = super::jit_core::JitContext::compile(&self.name)?;
        self.jit_context = Some(Arc::new(ctx));
        Ok(self)
    }
}
impl JitCompilable<Vec<f32>, f32> for F32JitFunction {
    /// Evaluates the stored closure on `args`.
    fn execute(&self, args: Vec<f32>) -> f32 {
        let callback = &*self.f;
        callback(args)
    }
}
impl GenericJitCompilable for F32JitFunction {
    /// Evaluates the function on a dynamically typed vector.
    ///
    /// `F32` payloads are passed through untouched. `F64`, `I64` and `I32`
    /// payloads are converted element-wise with `as f32`, which rounds to a
    /// representable `f32` and can therefore lose precision. The result is
    /// always wrapped in `NumericValue::F32`.
    fn execute_typed(&self, args: TypedVector) -> NumericValue {
        let values: Vec<f32> = match args {
            TypedVector::F32(values) => values,
            TypedVector::F64(values) => values.into_iter().map(|v| v as f32).collect(),
            TypedVector::I64(values) => values.into_iter().map(|v| v as f32).collect(),
            TypedVector::I32(values) => values.into_iter().map(|v| v as f32).collect(),
        };
        NumericValue::F32(self.execute(values))
    }

    /// Name of the element type this function consumes.
    fn input_type_name(&self) -> &'static str {
        "f32"
    }

    /// Name of the type this function produces.
    fn output_type_name(&self) -> &'static str {
        "f32"
    }
}
/// Named `Fn(Vec<i64>) -> i64` closure, optionally paired with a JIT context.
///
/// Cloning is cheap: the closure (and the context, when present) sit behind
/// `Arc`, so clones share them.
#[derive(Clone)]
pub struct I64JitFunction {
    /// The native closure.
    f: Arc<dyn Fn(Vec<i64>) -> i64 + Send + Sync>,
    /// Name given at construction; passed to `JitContext::compile` by `with_jit`.
    name: String,
    /// `None` until `with_jit` succeeds.
    #[cfg(feature = "jit")]
    jit_context: Option<Arc<super::jit_core::JitContext>>,
}

impl I64JitFunction {
    /// Creates a function called `name` backed by the closure `f`.
    ///
    /// No JIT context is attached by the constructor.
    pub fn new(
        name: impl Into<String>,
        f: impl Fn(Vec<i64>) -> i64 + Send + Sync + 'static,
    ) -> Self {
        // Erase the concrete closure type so it fits the struct field.
        let f: Arc<dyn Fn(Vec<i64>) -> i64 + Send + Sync> = Arc::new(f);
        let name = name.into();
        Self {
            f,
            name,
            #[cfg(feature = "jit")]
            jit_context: None,
        }
    }

    /// Calls `JitContext::compile` with this function's name and, on success,
    /// stores the resulting context.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `JitContext::compile` unchanged.
    #[cfg(feature = "jit")]
    pub fn with_jit(mut self) -> JitResult<Self> {
        let ctx = super::jit_core::JitContext::compile(&self.name)?;
        self.jit_context = Some(Arc::new(ctx));
        Ok(self)
    }
}
impl JitCompilable<Vec<i64>, i64> for I64JitFunction {
    /// Evaluates the stored closure on `args`.
    fn execute(&self, args: Vec<i64>) -> i64 {
        let callback = &*self.f;
        callback(args)
    }
}
impl GenericJitCompilable for I64JitFunction {
    /// Evaluates the function on a dynamically typed vector.
    ///
    /// `I64` payloads are passed through untouched and `I32` payloads are
    /// widened losslessly. `F64` and `F32` payloads are converted with
    /// `as i64`: the fractional part is dropped, out-of-range values saturate
    /// at the `i64` bounds and NaN becomes `0`. The result is always wrapped
    /// in `NumericValue::I64`.
    fn execute_typed(&self, args: TypedVector) -> NumericValue {
        let values: Vec<i64> = match args {
            TypedVector::I64(values) => values,
            TypedVector::I32(values) => values.into_iter().map(i64::from).collect(),
            TypedVector::F64(values) => values.into_iter().map(|v| v as i64).collect(),
            TypedVector::F32(values) => values.into_iter().map(|v| v as i64).collect(),
        };
        NumericValue::I64(self.execute(values))
    }

    /// Name of the element type this function consumes.
    fn input_type_name(&self) -> &'static str {
        "i64"
    }

    /// Name of the type this function produces.
    fn output_type_name(&self) -> &'static str {
        "i64"
    }
}
/// Named `Fn(Vec<i32>) -> i32` closure, optionally paired with a JIT context.
///
/// Cloning is cheap: the closure (and the context, when present) sit behind
/// `Arc`, so clones share them.
#[derive(Clone)]
pub struct I32JitFunction {
    /// The native closure.
    f: Arc<dyn Fn(Vec<i32>) -> i32 + Send + Sync>,
    /// Name given at construction; passed to `JitContext::compile` by `with_jit`.
    name: String,
    /// `None` until `with_jit` succeeds.
    #[cfg(feature = "jit")]
    jit_context: Option<Arc<super::jit_core::JitContext>>,
}

impl I32JitFunction {
    /// Creates a function called `name` backed by the closure `f`.
    ///
    /// No JIT context is attached by the constructor.
    pub fn new(
        name: impl Into<String>,
        f: impl Fn(Vec<i32>) -> i32 + Send + Sync + 'static,
    ) -> Self {
        // Erase the concrete closure type so it fits the struct field.
        let f: Arc<dyn Fn(Vec<i32>) -> i32 + Send + Sync> = Arc::new(f);
        let name = name.into();
        Self {
            f,
            name,
            #[cfg(feature = "jit")]
            jit_context: None,
        }
    }

    /// Calls `JitContext::compile` with this function's name and, on success,
    /// stores the resulting context.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `JitContext::compile` unchanged.
    #[cfg(feature = "jit")]
    pub fn with_jit(mut self) -> JitResult<Self> {
        let ctx = super::jit_core::JitContext::compile(&self.name)?;
        self.jit_context = Some(Arc::new(ctx));
        Ok(self)
    }
}
impl JitCompilable<Vec<i32>, i32> for I32JitFunction {
    /// Evaluates the stored closure on `args`.
    fn execute(&self, args: Vec<i32>) -> i32 {
        let callback = &*self.f;
        callback(args)
    }
}
impl GenericJitCompilable for I32JitFunction {
    /// Evaluates the function on a dynamically typed vector.
    ///
    /// `I32` payloads are passed through untouched. Every other payload is
    /// narrowed element-wise with `as i32`, which is lossy:
    ///
    /// * `I64` values keep only their low 32 bits (out-of-range values wrap);
    /// * `F64` / `F32` values lose their fractional part, saturate at the
    ///   `i32` bounds, and NaN becomes `0`.
    ///
    /// The result is always wrapped in `NumericValue::I32`.
    fn execute_typed(&self, args: TypedVector) -> NumericValue {
        let values: Vec<i32> = match args {
            TypedVector::I32(values) => values,
            TypedVector::I64(values) => values.into_iter().map(|v| v as i32).collect(),
            TypedVector::F64(values) => values.into_iter().map(|v| v as i32).collect(),
            TypedVector::F32(values) => values.into_iter().map(|v| v as i32).collect(),
        };
        NumericValue::I32(self.execute(values))
    }

    /// Name of the element type this function consumes.
    fn input_type_name(&self) -> &'static str {
        "i32"
    }

    /// Name of the type this function produces.
    fn output_type_name(&self) -> &'static str {
        "i32"
    }
}
/// Shorthand for [`F64JitFunction::new`]: wraps `f` under `name`.
pub fn jit_f64(
    name: impl Into<String>,
    f: impl Fn(Vec<f64>) -> f64 + Send + Sync + 'static,
) -> F64JitFunction {
    F64JitFunction::new(name, f)
}
/// Shorthand for [`F32JitFunction::new`]: wraps `f` under `name`.
pub fn jit_f32(
    name: impl Into<String>,
    f: impl Fn(Vec<f32>) -> f32 + Send + Sync + 'static,
) -> F32JitFunction {
    F32JitFunction::new(name, f)
}
/// Shorthand for [`I64JitFunction::new`]: wraps `f` under `name`.
pub fn jit_i64(
    name: impl Into<String>,
    f: impl Fn(Vec<i64>) -> i64 + Send + Sync + 'static,
) -> I64JitFunction {
    I64JitFunction::new(name, f)
}
/// Shorthand for [`I32JitFunction::new`]: wraps `f` under `name`.
pub fn jit_i32(
    name: impl Into<String>,
    f: impl Fn(Vec<i32>) -> i32 + Send + Sync + 'static,
) -> I32JitFunction {
    I32JitFunction::new(name, f)
}