use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::fmt::Debug;
use std::marker::PhantomData;
use std::sync::{Arc, LazyLock, RwLock};
use crate::array_protocol::{
ArrayFunction, ArrayProtocol, JITArray, JITFunction, JITFunctionFactory,
};
use crate::error::{CoreError, CoreResult, ErrorContext};
/// Code-generation backend a JIT factory targets.
///
/// `LLVM` is the default variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum JITBackend {
    /// LLVM backend (the default).
    #[default]
    LLVM,
    /// Cranelift backend.
    Cranelift,
    /// WASM backend.
    WASM,
    /// A custom backend identified by a `TypeId` (presumably of the implementing
    /// backend type — TODO confirm with callers).
    Custom(TypeId),
}
/// Settings used to build a JIT function factory.
#[derive(Debug, Clone)]
pub struct JITConfig {
    /// Backend this configuration selects.
    pub backend: JITBackend,
    /// Whether optimisation is requested (not read by the built-in factories in this module).
    pub optimize: bool,
    /// Optimisation level; the built-in factories copy it into each function's compile info.
    pub opt_level: usize,
    /// Whether factories should reuse previously compiled functions.
    pub use_cache: bool,
    /// Free-form backend-specific key/value options.
    pub backend_options: HashMap<String, String>,
}

impl Default for JITConfig {
    /// Default backend, optimisation enabled at level 2, caching enabled, no extra options.
    fn default() -> Self {
        let backend_options = HashMap::new();
        let backend = JITBackend::default();
        Self {
            backend,
            optimize: true,
            opt_level: 2,
            use_cache: true,
            backend_options,
        }
    }
}
pub type JITFunctionType = dyn Fn(&[Box<dyn Any>]) -> CoreResult<Box<dyn Any>> + Send + Sync;
pub struct JITFunctionImpl {
source: String,
function: Box<JITFunctionType>,
compile_info: HashMap<String, String>,
}
impl Debug for JITFunctionImpl {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("JITFunctionImpl")
.field("source", &self.source)
.field("compile_info", &self.compile_info)
.finish_non_exhaustive()
}
}
impl JITFunctionImpl {
#[must_use]
pub fn new(
source: String,
function: Box<JITFunctionType>,
compile_info: HashMap<String, String>,
) -> Self {
Self {
source,
function,
compile_info,
}
}
}
impl JITFunction for JITFunctionImpl {
fn evaluate(&self, args: &[Box<dyn Any>]) -> CoreResult<Box<dyn Any>> {
(self.function)(args)
}
fn source(&self) -> String {
self.source.clone()
}
fn compile_info(&self) -> HashMap<String, String> {
self.compile_info.clone()
}
fn clone_box(&self) -> Box<dyn JITFunction> {
let source = self.source.clone();
let compile_info = self.compile_info.clone();
let cloned_function: Box<JITFunctionType> = Box::new(move |_args| {
Ok(Box::new(42.0))
});
Box::new(Self {
source,
function: cloned_function,
compile_info,
})
}
}
pub struct LLVMFunctionFactory {
config: JITConfig,
cache: HashMap<String, Arc<dyn JITFunction>>,
}
impl LLVMFunctionFactory {
pub fn new(config: JITConfig) -> Self {
Self {
config,
cache: HashMap::new(),
}
}
fn compile(&self, expression: &str, array_typeid: TypeId) -> CoreResult<Arc<dyn JITFunction>> {
let mut compile_info = HashMap::new();
compile_info.insert("backend".to_string(), "LLVM".to_string());
compile_info.insert("opt_level".to_string(), self.config.opt_level.to_string());
compile_info.insert("array_type".to_string(), format!("{array_typeid:?}"));
let source = expression.to_string();
let function: Box<JITFunctionType> = Box::new(move |_args| {
Ok(Box::new(42.0))
});
let jit_function = JITFunctionImpl::new(source, function, compile_info);
Ok(Arc::new(jit_function))
}
}
impl JITFunctionFactory for LLVMFunctionFactory {
fn create_jit_function(
&self,
expression: &str,
array_typeid: TypeId,
) -> CoreResult<Box<dyn JITFunction>> {
if self.config.use_cache {
let cache_key = format!("{expression}-{array_typeid:?}");
if let Some(cached_fn) = self.cache.get(&cache_key) {
return Ok(cached_fn.as_ref().clone_box());
}
}
let jit_function = self.compile(expression, array_typeid)?;
if self.config.use_cache {
let cache_key = format!("{expression}-{array_typeid:?}");
let mut cache = self.cache.clone();
cache.insert(cache_key, jit_function.clone());
}
Ok(jit_function.as_ref().clone_box())
}
fn supports_array_type(&self, _array_typeid: TypeId) -> bool {
true
}
}
pub struct CraneliftFunctionFactory {
config: JITConfig,
cache: HashMap<String, Arc<dyn JITFunction>>,
}
impl CraneliftFunctionFactory {
pub fn new(config: JITConfig) -> Self {
Self {
config,
cache: HashMap::new(),
}
}
fn compile(&self, expression: &str, array_typeid: TypeId) -> CoreResult<Arc<dyn JITFunction>> {
let mut compile_info = HashMap::new();
compile_info.insert("backend".to_string(), "Cranelift".to_string());
compile_info.insert("opt_level".to_string(), self.config.opt_level.to_string());
compile_info.insert("array_type".to_string(), format!("{array_typeid:?}"));
let source = expression.to_string();
let function: Box<JITFunctionType> = Box::new(move |_args| {
Ok(Box::new(42.0))
});
let jit_function = JITFunctionImpl::new(source, function, compile_info);
Ok(Arc::new(jit_function))
}
}
impl JITFunctionFactory for CraneliftFunctionFactory {
fn create_jit_function(
&self,
expression: &str,
array_typeid: TypeId,
) -> CoreResult<Box<dyn JITFunction>> {
if self.config.use_cache {
let cache_key = format!("{expression}-{array_typeid:?}");
if let Some(cached_fn) = self.cache.get(&cache_key) {
return Ok(cached_fn.as_ref().clone_box());
}
}
let jit_function = self.compile(expression, array_typeid)?;
if self.config.use_cache {
let cache_key = format!("{expression}-{array_typeid:?}");
let mut cache = self.cache.clone();
cache.insert(cache_key, jit_function.clone());
}
Ok(jit_function.as_ref().clone_box())
}
fn supports_array_type(&self, _array_typeid: TypeId) -> bool {
true
}
}
pub struct JITManager {
factories: Vec<Box<dyn JITFunctionFactory>>,
defaultconfig: JITConfig,
}
impl JITManager {
pub fn new(defaultconfig: JITConfig) -> Self {
Self {
factories: Vec::new(),
defaultconfig,
}
}
pub fn register_factory(&mut self, factory: Box<dyn JITFunctionFactory>) {
self.factories.push(factory);
}
pub fn get_factory_for_array_type(
&self,
array_typeid: TypeId,
) -> Option<&dyn JITFunctionFactory> {
for factory in &self.factories {
if factory.supports_array_type(array_typeid) {
return Some(&**factory);
}
}
None
}
pub fn compile(
&self,
expression: &str,
array_typeid: TypeId,
) -> CoreResult<Box<dyn JITFunction>> {
if let Some(factory) = self.get_factory_for_array_type(array_typeid) {
factory.create_jit_function(expression, array_typeid)
} else {
Err(CoreError::JITError(ErrorContext::new(format!(
"No JIT factory supports array type: {array_typeid:?}"
))))
}
}
pub fn initialize(&mut self) {
let llvm_config = JITConfig {
backend: JITBackend::LLVM,
..self.defaultconfig.clone()
};
let llvm_factory = Box::new(LLVMFunctionFactory::new(llvm_config));
let cranelift_config = JITConfig {
backend: JITBackend::Cranelift,
..self.defaultconfig.clone()
};
let cranelift_factory = Box::new(CraneliftFunctionFactory::new(cranelift_config));
self.register_factory(llvm_factory);
self.register_factory(cranelift_factory);
}
#[must_use]
pub fn global() -> &'static RwLock<Self> {
static INSTANCE: LazyLock<RwLock<JITManager>> = LazyLock::new(|| {
RwLock::new(JITManager {
factories: Vec::new(),
defaultconfig: JITConfig {
backend: JITBackend::LLVM,
optimize: true,
opt_level: 2,
use_cache: true,
backend_options: HashMap::new(),
},
})
});
&INSTANCE
}
}
/// Wrapper that pairs an array-protocol value `A` with JIT compilation support.
///
/// `T` is held only as a `PhantomData` marker (presumably the element type of the
/// wrapped array — TODO confirm).
pub struct JITEnabledArray<T, A> {
    inner: A,
    phantom: PhantomData<T>,
}

impl<T, A> JITEnabledArray<T, A> {
    /// Wraps `inner`.
    pub fn new(inner: A) -> Self {
        let phantom = PhantomData;
        Self { inner, phantom }
    }

    /// Borrows the wrapped array.
    pub const fn inner(&self) -> &A {
        &self.inner
    }
}

// Written by hand rather than derived: `#[derive(Clone)]` would also demand `T: Clone`,
// although `T` is only a marker.
impl<T, A: Clone> Clone for JITEnabledArray<T, A> {
    fn clone(&self) -> Self {
        Self::new(self.inner.clone())
    }
}
impl<T, A> JITArray for JITEnabledArray<T, A>
where
    T: Send + Sync + 'static,
    A: ArrayProtocol + Clone + Send + Sync + 'static,
{
    /// Compiles `expression` via the global [`JITManager`], selecting a factory by the
    /// wrapped array type `A`.
    ///
    /// # Errors
    /// Returns the manager's error when no registered factory supports `A`, or the
    /// factory's own error.
    fn compile(&self, expression: &str) -> CoreResult<Box<dyn JITFunction>> {
        // A poisoned lock only means another thread panicked while holding it; this is a
        // read of the registry, so recover the guard instead of panicking.
        JITManager::global()
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .compile(expression, TypeId::of::<A>())
    }

    /// Whether the global [`JITManager`] has a factory that supports `A`.
    fn supports_jit(&self) -> bool {
        JITManager::global()
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .get_factory_for_array_type(TypeId::of::<A>())
            .is_some()
    }

    /// Reports `"supports_jit"`. When a factory exists for `A`, it also reports `"factory"`.
    fn jit_info(&self) -> HashMap<String, String> {
        let supported = self.supports_jit();
        let mut info = HashMap::new();
        info.insert("supports_jit".to_string(), supported.to_string());
        // `supports_jit` is true exactly when a factory for `A` exists, so no second
        // lock and lookup is needed.
        if supported {
            info.insert("factory".to_string(), "JIT factory available".to_string());
        }
        info
    }
}
// Every array-protocol query is forwarded to the wrapped array. Only `as_any` and
// `box_clone` expose the wrapper itself.
impl<T, A> ArrayProtocol for JITEnabledArray<T, A>
where
    T: Send + Sync + 'static,
    A: ArrayProtocol + Clone + Send + Sync + 'static,
{
    fn array_function(
        &self,
        func: &ArrayFunction,
        types: &[TypeId],
        args: &[Box<dyn Any>],
        kwargs: &HashMap<String, Box<dyn Any>>,
    ) -> Result<Box<dyn Any>, crate::array_protocol::NotImplemented> {
        let wrapped = &self.inner;
        wrapped.array_function(func, types, args, kwargs)
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn shape(&self) -> &[usize] {
        let wrapped = &self.inner;
        wrapped.shape()
    }

    fn dtype(&self) -> TypeId {
        let wrapped = &self.inner;
        wrapped.dtype()
    }

    /// Boxes a clone of the wrapper (and therefore of the wrapped array).
    fn box_clone(&self) -> Box<dyn ArrayProtocol> {
        let duplicate: Self = self.clone();
        Box::new(duplicate)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::array_protocol::NdarrayWrapper;
    use ::ndarray::Array2;

    /// `TypeId` of the 2-D `f64` ndarray wrapper the factory/manager tests compile against.
    fn wrapper_typeid() -> TypeId {
        TypeId::of::<NdarrayWrapper<f64, crate::ndarray::Ix2>>()
    }

    #[test]
    fn test_jit_function_creation() {
        let llvm = LLVMFunctionFactory::new(JITConfig {
            backend: JITBackend::LLVM,
            ..Default::default()
        });
        let source = "x + y";

        let compiled = llvm
            .create_jit_function(source, wrapper_typeid())
            .expect("Operation failed");

        assert_eq!(compiled.source(), source);
        let info = compiled.compile_info();
        assert_eq!(info.get("backend").expect("Operation failed"), "LLVM");
    }

    #[test]
    fn test_jit_manager() {
        let mut manager = JITManager::new(JITConfig::default());
        manager.initialize();
        let typeid = wrapper_typeid();
        assert!(manager.get_factory_for_array_type(typeid).is_some());

        let source = "x + y";
        let compiled = manager.compile(source, typeid).expect("Operation failed");
        assert_eq!(compiled.source(), source);
    }

    #[test]
    fn test_jit_enabled_array() {
        let ones = Array2::<f64>::ones((10, 5));
        let jit_array: JITEnabledArray<f64, _> = JITEnabledArray::new(NdarrayWrapper::new(ones));

        // Register the built-in factories on the shared manager. The write guard is a
        // temporary, so it is released before `jit_array` takes a read lock below.
        JITManager::global()
            .write()
            .expect("Operation failed")
            .initialize();

        assert!(jit_array.supports_jit());
        let source = "x + y";
        let compiled = jit_array.compile(source).expect("Operation failed");
        assert_eq!(compiled.source(), source);
    }
}