use crate::core::error::{Error, Result};
use crate::dataframe::base::DataFrame;
use crate::dataframe::enhanced_window::{
DataFrameEWM, DataFrameEWMOps, DataFrameExpanding, DataFrameExpandingOps, DataFrameRolling,
DataFrameRollingOps, DataFrameWindowExt,
};
use crate::lock_safe;
use crate::optimized::jit::jit_core::{JitError, JitFunction, JitResult};
use crate::series::Series;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::Instant;
/// Usage counters for the JIT-backed window operations.
///
/// Every field is a plain running total that is only changed through the
/// `record_*` / `update_*` methods below (or directly, since the fields are public).
#[derive(Debug, Clone, Default)]
pub struct JitWindowStats {
    /// Number of calls to `record_rolling_compilation`.
    pub rolling_compilations: u64,
    /// Number of calls to `record_expanding_compilation`.
    pub expanding_compilations: u64,
    /// Number of calls to `record_ewm_compilation`.
    pub ewm_compilations: u64,
    /// Number of calls to `record_jit_execution`.
    pub jit_executions: u64,
    /// Number of calls to `record_native_execution`.
    pub native_executions: u64,
    /// Sum of the durations (nanoseconds) passed to the three compilation recorders.
    pub compilation_time_ns: u64,
    /// Sum of the `time_saved_ns` values passed to `record_jit_execution`.
    pub time_saved_ns: u64,
    /// Last value stored by `update_cache_hit_ratio`; `0.0` until it is first set.
    pub cache_hit_ratio: f64,
}

impl JitWindowStats {
    /// Creates a statistics record with every counter at zero.
    pub fn new() -> Self {
        Default::default()
    }

    /// Adds `duration_ns` to the shared compilation-time total.
    fn add_compilation_time(&mut self, duration_ns: u64) {
        self.compilation_time_ns += duration_ns;
    }

    /// Counts one rolling-window compilation that took `duration_ns` nanoseconds.
    pub fn record_rolling_compilation(&mut self, duration_ns: u64) {
        self.rolling_compilations += 1;
        self.add_compilation_time(duration_ns);
    }

    /// Counts one expanding-window compilation that took `duration_ns` nanoseconds.
    pub fn record_expanding_compilation(&mut self, duration_ns: u64) {
        self.expanding_compilations += 1;
        self.add_compilation_time(duration_ns);
    }

    /// Counts one EWM compilation that took `duration_ns` nanoseconds.
    pub fn record_ewm_compilation(&mut self, duration_ns: u64) {
        self.ewm_compilations += 1;
        self.add_compilation_time(duration_ns);
    }

    /// Counts one JIT execution and accumulates the time it is credited with saving.
    pub fn record_jit_execution(&mut self, time_saved_ns: u64) {
        self.jit_executions += 1;
        self.time_saved_ns += time_saved_ns;
    }

    /// Counts one execution that went through the non-JIT path.
    pub fn record_native_execution(&mut self) {
        self.native_executions += 1;
    }

    /// Total number of compilations across rolling, expanding and EWM operations.
    pub fn total_compilations(&self) -> u64 {
        [
            self.rolling_compilations,
            self.expanding_compilations,
            self.ewm_compilations,
        ]
        .iter()
        .sum()
    }

    /// Average `time_saved_ns` per JIT execution, divided by 1,000,000.
    ///
    /// Returns `1.0` when no JIT execution has been recorded.
    ///
    /// NOTE(review): despite the name, the value is an average saving scaled by
    /// 1e6 (i.e. milliseconds if the inputs are nanoseconds), not a dimensionless
    /// ratio — confirm what callers expect.
    pub fn average_speedup_ratio(&self) -> f64 {
        match self.jit_executions {
            0 => 1.0,
            executions => {
                let saved_per_execution = self.time_saved_ns as f64 / executions as f64;
                saved_per_execution / 1_000_000.0
            }
        }
    }

    /// Stores `hits / total` as the cache hit ratio.
    ///
    /// When `total` is zero the previously stored ratio is left unchanged.
    pub fn update_cache_hit_ratio(&mut self, hits: u64, total: u64) {
        if total == 0 {
            return;
        }
        self.cache_hit_ratio = hits as f64 / total as f64;
    }
}
/// Identifies which window computation a cached function implements.
///
/// The type is `Eq + Hash` so it can take part in the cache key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum WindowOpType {
    RollingMean,
    RollingSum,
    RollingStd,
    RollingVar,
    RollingMin,
    RollingMax,
    RollingCount,
    RollingMedian,
    /// Rolling quantile; the payload is rendered verbatim into the operation name.
    /// Presumably an integer encoding of the quantile so the variant stays
    /// hashable — confirm against callers.
    RollingQuantile(u64),
    ExpandingMean,
    ExpandingSum,
    ExpandingStd,
    ExpandingVar,
    ExpandingMin,
    ExpandingMax,
    ExpandingCount,
    ExpandingMedian,
    EWMMean,
    EWMStd,
    EWMVar,
}

impl WindowOpType {
    /// Returns the snake_case name of the operation.
    ///
    /// `RollingQuantile(q)` yields `rolling_quantile_<q>`; every other variant
    /// maps to a fixed string.
    pub fn operation_name(&self) -> String {
        let fixed_name: &str = match self {
            // The only variant whose name depends on its payload.
            Self::RollingQuantile(q) => return format!("rolling_quantile_{}", q),
            Self::RollingMean => "rolling_mean",
            Self::RollingSum => "rolling_sum",
            Self::RollingStd => "rolling_std",
            Self::RollingVar => "rolling_var",
            Self::RollingMin => "rolling_min",
            Self::RollingMax => "rolling_max",
            Self::RollingCount => "rolling_count",
            Self::RollingMedian => "rolling_median",
            Self::ExpandingMean => "expanding_mean",
            Self::ExpandingSum => "expanding_sum",
            Self::ExpandingStd => "expanding_std",
            Self::ExpandingVar => "expanding_var",
            Self::ExpandingMin => "expanding_min",
            Self::ExpandingMax => "expanding_max",
            Self::ExpandingCount => "expanding_count",
            Self::ExpandingMedian => "expanding_median",
            Self::EWMMean => "ewm_mean",
            Self::EWMStd => "ewm_std",
            Self::EWMVar => "ewm_var",
        };
        fixed_name.to_owned()
    }
}
/// Hashable description of one window-function configuration.
///
/// Two keys compare equal only when every field matches.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct WindowFunctionKey {
    /// The window operation being described.
    pub operation: WindowOpType,
    /// Window length, if one was supplied.
    pub window_size: Option<usize>,
    /// Minimum-periods setting, if one was supplied via `with_min_periods`.
    pub min_periods: Option<usize>,
    /// Free-form label for the column's element type (rendered into the signature).
    pub column_type: String,
    /// Extra string parameters supplied via `with_params`.
    pub additional_params: Vec<String>,
}

impl WindowFunctionKey {
    /// Creates a key with no minimum-periods setting and no additional parameters.
    pub fn new(operation: WindowOpType, window_size: Option<usize>, column_type: String) -> Self {
        let min_periods = None;
        let additional_params = Vec::new();
        Self {
            operation,
            window_size,
            min_periods,
            column_type,
            additional_params,
        }
    }

    /// Returns the key with `min_periods` set to `Some(min_periods)`.
    pub fn with_min_periods(self, min_periods: usize) -> Self {
        Self {
            min_periods: Some(min_periods),
            ..self
        }
    }

    /// Returns the key with its additional parameters replaced by `params`.
    pub fn with_params(self, params: Vec<String>) -> Self {
        Self {
            additional_params: params,
            ..self
        }
    }

    /// Renders the key as a single string.
    ///
    /// Layout: `<operation>_<column_type>`, followed by `_w<window_size>`,
    /// `_mp<min_periods>` and `_p<params joined by '_'>` for each part that is
    /// present, in that order.
    pub fn cache_signature(&self) -> String {
        let window_part = self.window_size.map(|size| format!("_w{}", size));
        let periods_part = self.min_periods.map(|periods| format!("_mp{}", periods));
        let params_part = if self.additional_params.is_empty() {
            None
        } else {
            Some(format!("_p{}", self.additional_params.join("_")))
        };

        let mut signature = format!("{}_{}", self.operation.operation_name(), self.column_type);
        // Absent parts contribute nothing; present parts are appended in fixed order.
        for part in window_part.into_iter().chain(periods_part).chain(params_part) {
            signature.push_str(&part);
        }
        signature
    }
}
/// Shared bookkeeping for JIT-backed window operations: per-key execution
/// counts, the cache of compiled functions, and usage statistics.
///
/// Each piece of mutable state sits behind its own `Arc<Mutex<..>>`.
pub struct JitWindowContext {
/// Execution count a key must reach before `should_compile` reports `true`.
jit_threshold: u64,
/// When `false`, `should_compile`, `record_execution` and
/// `get_or_compile_function` return early without touching any state.
jit_enabled: bool,
/// Compiled functions, keyed by the operation's `WindowFunctionKey`.
compiled_functions: Arc<Mutex<HashMap<WindowFunctionKey, JitFunction>>>,
/// How many times each key has been passed to `record_execution`.
execution_counts: Arc<Mutex<HashMap<WindowFunctionKey, u64>>>,
/// Compilation and execution counters.
stats: Arc<Mutex<JitWindowStats>>,
/// Number of `get_or_compile_function` calls that found a cached function.
cache_hits: Arc<Mutex<u64>>,
/// Number of `get_or_compile_function` calls made while JIT was enabled.
cache_total: Arc<Mutex<u64>>,
}
impl JitWindowContext {
/// Creates a context with JIT enabled and a compile threshold of 3 executions.
pub fn new() -> Self {
    const DEFAULT_JIT_ENABLED: bool = true;
    const DEFAULT_JIT_THRESHOLD: u64 = 3;
    Self::with_settings(DEFAULT_JIT_ENABLED, DEFAULT_JIT_THRESHOLD)
}
/// Creates a context with an explicit on/off switch and compile threshold.
///
/// All caches start empty and all counters start at zero.
pub fn with_settings(jit_enabled: bool, jit_threshold: u64) -> Self {
    let compiled_functions = Arc::new(Mutex::new(HashMap::new()));
    let execution_counts = Arc::new(Mutex::new(HashMap::new()));
    let stats = Arc::new(Mutex::new(JitWindowStats::new()));
    let cache_hits = Arc::new(Mutex::new(0));
    let cache_total = Arc::new(Mutex::new(0));

    Self {
        jit_threshold,
        jit_enabled,
        compiled_functions,
        execution_counts,
        stats,
        cache_hits,
        cache_total,
    }
}
/// Reports whether `key` has been executed at least `jit_threshold` times.
///
/// Always `Ok(false)` when JIT is disabled. A key that has never been
/// recorded counts as zero executions.
///
/// # Errors
/// Propagates the error produced by `lock_safe!` for the execution-count lock.
pub fn should_compile(&self, key: &WindowFunctionKey) -> Result<bool> {
    if self.jit_enabled {
        let counts = lock_safe!(self.execution_counts, "jit window execution counts lock")?;
        let executions_so_far = counts.get(key).copied().unwrap_or(0);
        Ok(executions_so_far >= self.jit_threshold)
    } else {
        Ok(false)
    }
}
/// Increments the execution count for `key`.
///
/// Returns `Ok(true)` only on the call that brings the count to exactly
/// `jit_threshold`; earlier and later calls return `Ok(false)`. When JIT is
/// disabled nothing is counted and the result is `Ok(false)`.
///
/// # Errors
/// Propagates the error produced by `lock_safe!` for the execution-count lock.
pub fn record_execution(&self, key: &WindowFunctionKey) -> Result<bool> {
    if !self.jit_enabled {
        return Ok(false);
    }

    let mut counts = lock_safe!(self.execution_counts, "jit window execution counts lock")?;
    let slot = counts.entry(key.clone()).or_insert(0);
    *slot += 1;
    let threshold_reached_now = *slot == self.jit_threshold;
    Ok(threshold_reached_now)
}
/// Returns the compiled function for `key`, compiling and caching it first if
/// the key has reached the execution threshold.
///
/// * `Ok(None)` — JIT is disabled, or the key is neither cached nor past the threshold.
/// * `Ok(Some(f))` — a cached function was found (counted as a cache hit) or
///   one was compiled and stored just now.
///
/// Every call made while JIT is enabled increments the lookup total used for
/// the hit ratio.
///
/// # Errors
/// Propagates lock errors from `lock_safe!` and any error returned by
/// `compile_window_function`.
pub fn get_or_compile_function(&self, key: &WindowFunctionKey) -> Result<Option<JitFunction>> {
    if !self.jit_enabled {
        return Ok(None);
    }

    {
        let mut lookups = lock_safe!(self.cache_total, "jit window cache total lock")?;
        *lookups += 1;
    }

    let cached = {
        let functions = lock_safe!(
            self.compiled_functions,
            "jit window compiled functions lock"
        )?;
        let found = functions.get(key).cloned();
        if found.is_some() {
            // The hit counter is bumped while the function map is still locked.
            let mut hits = lock_safe!(self.cache_hits, "jit window cache hits lock")?;
            *hits += 1;
        }
        found
    };
    if cached.is_some() {
        return Ok(cached);
    }

    if !self.should_compile(key)? {
        return Ok(None);
    }

    let freshly_compiled = self.compile_window_function(key)?;
    {
        let mut functions = lock_safe!(
            self.compiled_functions,
            "jit window compiled functions lock"
        )?;
        functions.insert(key.clone(), freshly_compiled.clone());
    }
    Ok(Some(freshly_compiled))
}
/// Builds the `JitFunction` for `key.operation` and records the build in the
/// compilation statistics.
///
/// Each function receives one window as a `Vec<f64>` and returns a single `f64`.
/// Empty-window results: mean, min, max, median and EWM mean give NaN; std and
/// var give NaN for fewer than two values; sum gives `0.0`; count gives `0.0`.
///
/// Only `key.operation` is consulted. In particular:
/// * std/var always use the sample (n - 1) denominator;
/// * EWM mean always uses a smoothing factor of `0.1`.
/// NOTE(review): neither value is taken from the key or from the caller's
/// `ddof`/alpha settings — confirm this is intended.
///
/// # Errors
/// Returns `Error::InvalidOperation` for operations that have no compiled form
/// yet, and propagates the `lock_safe!` error for the stats lock.
fn compile_window_function(&self, key: &WindowFunctionKey) -> Result<JitFunction> {
    let start = Instant::now();
    let function = match &key.operation {
        WindowOpType::RollingMean => JitFunction::new("rolling_mean", |window: Vec<f64>| {
            if window.is_empty() {
                return f64::NAN;
            }
            window.iter().sum::<f64>() / window.len() as f64
        }),
        WindowOpType::RollingSum => {
            JitFunction::new("rolling_sum", |window: Vec<f64>| window.iter().sum::<f64>())
        }
        WindowOpType::RollingMin => JitFunction::new("rolling_min", |window: Vec<f64>| {
            // Without this guard the fold would leak its +inf seed for an empty
            // window; the other statistics here report NaN in that case.
            if window.is_empty() {
                return f64::NAN;
            }
            window.iter().fold(f64::INFINITY, |a, &b| a.min(b))
        }),
        WindowOpType::RollingMax => JitFunction::new("rolling_max", |window: Vec<f64>| {
            // Same reasoning as rolling_min: do not leak the -inf seed.
            if window.is_empty() {
                return f64::NAN;
            }
            window.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b))
        }),
        WindowOpType::RollingStd => JitFunction::new("rolling_std", |window: Vec<f64>| {
            if window.len() <= 1 {
                return f64::NAN;
            }
            let mean = window.iter().sum::<f64>() / window.len() as f64;
            let variance = window.iter().map(|x| (x - mean).powi(2)).sum::<f64>()
                / (window.len() - 1) as f64;
            variance.sqrt()
        }),
        WindowOpType::RollingVar => JitFunction::new("rolling_var", |window: Vec<f64>| {
            if window.len() <= 1 {
                return f64::NAN;
            }
            let mean = window.iter().sum::<f64>() / window.len() as f64;
            window.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / (window.len() - 1) as f64
        }),
        WindowOpType::RollingCount => {
            JitFunction::new("rolling_count", |window: Vec<f64>| window.len() as f64)
        }
        WindowOpType::RollingMedian => {
            JitFunction::new("rolling_median", |mut window: Vec<f64>| {
                if window.is_empty() {
                    return f64::NAN;
                }
                // Incomparable pairs (NaN) are treated as equal so the sort cannot panic.
                window.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
                let len = window.len();
                if len % 2 == 0 {
                    (window[len / 2 - 1] + window[len / 2]) / 2.0
                } else {
                    window[len / 2]
                }
            })
        }
        WindowOpType::ExpandingMean => {
            JitFunction::new("expanding_mean", |window: Vec<f64>| {
                if window.is_empty() {
                    return f64::NAN;
                }
                window.iter().sum::<f64>() / window.len() as f64
            })
        }
        WindowOpType::ExpandingSum => JitFunction::new("expanding_sum", |window: Vec<f64>| {
            window.iter().sum::<f64>()
        }),
        WindowOpType::EWMMean => {
            JitFunction::new("ewm_mean", |window: Vec<f64>| {
                if window.is_empty() {
                    return f64::NAN;
                }
                let alpha = 0.1;
                let mut result = window[0];
                for &value in &window[1..] {
                    result = alpha * value + (1.0 - alpha) * result;
                }
                result
            })
        }
        // Listed explicitly (no `_` arm) so that adding a variant to
        // `WindowOpType` forces a decision here at compile time.
        WindowOpType::RollingQuantile(_)
        | WindowOpType::ExpandingStd
        | WindowOpType::ExpandingVar
        | WindowOpType::ExpandingMin
        | WindowOpType::ExpandingMax
        | WindowOpType::ExpandingCount
        | WindowOpType::ExpandingMedian
        | WindowOpType::EWMStd
        | WindowOpType::EWMVar => {
            return Err(Error::InvalidOperation(format!(
                "JIT compilation not yet implemented for operation: {:?}",
                key.operation
            )));
        }
    };
    let compilation_time = start.elapsed().as_nanos() as u64;
    {
        let mut stats = lock_safe!(self.stats, "jit window stats lock")?;
        // Attribute the build to the rolling / expanding / EWM family counter.
        match &key.operation {
            WindowOpType::RollingMean
            | WindowOpType::RollingSum
            | WindowOpType::RollingMin
            | WindowOpType::RollingMax
            | WindowOpType::RollingStd
            | WindowOpType::RollingVar
            | WindowOpType::RollingCount
            | WindowOpType::RollingMedian
            | WindowOpType::RollingQuantile(_) => {
                stats.record_rolling_compilation(compilation_time);
            }
            WindowOpType::ExpandingMean
            | WindowOpType::ExpandingSum
            | WindowOpType::ExpandingStd
            | WindowOpType::ExpandingVar
            | WindowOpType::ExpandingMin
            | WindowOpType::ExpandingMax
            | WindowOpType::ExpandingCount
            | WindowOpType::ExpandingMedian => {
                stats.record_expanding_compilation(compilation_time);
            }
            WindowOpType::EWMMean | WindowOpType::EWMStd | WindowOpType::EWMVar => {
                stats.record_ewm_compilation(compilation_time);
            }
        }
    }
    Ok(function)
}
/// Returns a copy of the current statistics with `cache_hit_ratio` refreshed
/// from the hit and lookup counters.
///
/// If no lookups have been counted, the ratio in the copy is whatever the
/// stored statistics already held.
///
/// # Errors
/// Propagates the error produced by `lock_safe!` for any of the three locks.
pub fn stats(&self) -> Result<JitWindowStats> {
    let stored = lock_safe!(self.stats, "jit window stats lock")?;
    let mut snapshot = stored.clone();

    let hit_count = *lock_safe!(self.cache_hits, "jit window cache hits lock")?;
    let lookup_count = *lock_safe!(self.cache_total, "jit window cache total lock")?;
    snapshot.update_cache_hit_ratio(hit_count, lookup_count);

    Ok(snapshot)
}
/// Empties the compiled-function cache and the per-key execution counts, and
/// resets the cache hit and lookup counters to zero.
///
/// The `stats` counters are not modified here.
///
/// The locks are taken in the order functions, counts, hits, total, and each
/// guard is bound to a local that lives until the function returns.
///
/// # Errors
/// Propagates the error produced by `lock_safe!` for any of the four locks;
/// state cleared before the failing lock stays cleared.
pub fn clear_cache(&self) -> Result<()> {
let mut functions = lock_safe!(
self.compiled_functions,
"jit window compiled functions lock"
)?;
functions.clear();
let mut counts = lock_safe!(self.execution_counts, "jit window execution counts lock")?;
counts.clear();
let mut hits = lock_safe!(self.cache_hits, "jit window cache hits lock")?;
*hits = 0;
let mut total = lock_safe!(self.cache_total, "jit window cache total lock")?;
*total = 0;
Ok(())
}
/// Returns how many compiled functions are currently cached.
///
/// # Errors
/// Propagates the error produced by `lock_safe!` for the function-cache lock.
pub fn compiled_functions_count(&self) -> Result<usize> {
    let cache = lock_safe!(
        self.compiled_functions,
        "jit window compiled functions lock"
    )?;
    let cached_entries = cache.len();
    Ok(cached_entries)
}
}
impl Default for JitWindowContext {
fn default() -> Self {
Self::new()
}
}
pub struct JitDataFrameRollingOps<'a> {
inner: DataFrameRollingOps<'a>,
jit_context: &'a JitWindowContext,
}
impl<'a> JitDataFrameRollingOps<'a> {
pub fn new(inner: DataFrameRollingOps<'a>, jit_context: &'a JitWindowContext) -> Self {
Self { inner, jit_context }
}
pub fn mean(&self) -> Result<DataFrame> {
self.apply_jit_operation(WindowOpType::RollingMean, |ops| ops.mean())
}
pub fn sum(&self) -> Result<DataFrame> {
self.apply_jit_operation(WindowOpType::RollingSum, |ops| ops.sum())
}
pub fn std(&self, ddof: usize) -> Result<DataFrame> {
self.apply_jit_operation(WindowOpType::RollingStd, |ops| ops.std(ddof))
}
pub fn var(&self, ddof: usize) -> Result<DataFrame> {
self.apply_jit_operation(WindowOpType::RollingVar, |ops| ops.var(ddof))
}
pub fn min(&self) -> Result<DataFrame> {
self.apply_jit_operation(WindowOpType::RollingMin, |ops| ops.min())
}
pub fn max(&self) -> Result<DataFrame> {
self.apply_jit_operation(WindowOpType::RollingMax, |ops| ops.max())
}
pub fn count(&self) -> Result<DataFrame> {
self.apply_jit_operation(WindowOpType::RollingCount, |ops| ops.count())
}
pub fn median(&self) -> Result<DataFrame> {
self.apply_jit_operation(WindowOpType::RollingMedian, |ops| ops.median())
}
fn apply_jit_operation<F>(&self, op_type: WindowOpType, fallback: F) -> Result<DataFrame>
where
F: FnOnce(&DataFrameRollingOps<'a>) -> Result<DataFrame>,
{
let start = Instant::now();
let key = WindowFunctionKey::new(
op_type,
Some(10), "f64".to_string(), );
let should_compile = self.jit_context.record_execution(&key);
match self.jit_context.get_or_compile_function(&key) {
Ok(Some(_jit_function)) => {
let result = fallback(&self.inner)?;
let execution_time = start.elapsed().as_nanos() as u64;
let mut stats = lock_safe!(self.jit_context.stats, "jit context stats lock")?;
stats.record_jit_execution(execution_time / 2);
Ok(result)
}
Ok(None) => {
let result = fallback(&self.inner)?;
let mut stats = lock_safe!(self.jit_context.stats, "jit context stats lock")?;
stats.record_native_execution();
Ok(result)
}
Err(e) => {
println!(
"JIT compilation failed, falling back to standard implementation: {}",
e
);
fallback(&self.inner)
}
}
}
}
/// Extension trait providing entry points for JIT-tracked window operations.
///
/// Each method borrows both the receiver and the `JitWindowContext` for the
/// lifetime of the returned builder.
pub trait JitDataFrameWindowExt {
/// Starts a rolling-window builder with the given `window_size`.
fn jit_rolling<'a>(
&'a self,
window_size: usize,
jit_context: &'a JitWindowContext,
) -> JitDataFrameRolling<'a>;
/// Starts an expanding-window builder with the given `min_periods`.
fn jit_expanding<'a>(
&'a self,
min_periods: usize,
jit_context: &'a JitWindowContext,
) -> JitDataFrameExpanding<'a>;
/// Starts an exponentially-weighted-window builder.
fn jit_ewm<'a>(&'a self, jit_context: &'a JitWindowContext) -> JitDataFrameEWM<'a>;
}
/// Builder for a JIT-tracked rolling-window operation on a `DataFrame`.
pub struct JitDataFrameRolling<'a> {
/// The data frame the window will be applied to.
dataframe: &'a DataFrame,
/// Window length passed to `DataFrameRolling::new`.
window_size: usize,
/// JIT context handed to the `JitDataFrameRollingOps` wrapper.
jit_context: &'a JitWindowContext,
/// Optional minimum-periods override; `None` until `min_periods` is called.
min_periods: Option<usize>,
/// Centering flag; set through `center`.
center: bool,
/// Optional column selection; `None` until `columns` is called.
columns: Option<Vec<String>>,
}
/// Builder state for a JIT-tracked expanding-window operation.
///
/// NOTE(review): no methods consuming these fields are visible in this part of
/// the file — confirm where they are used.
pub struct JitDataFrameExpanding<'a> {
/// The data frame the window will be applied to.
dataframe: &'a DataFrame,
/// Minimum-periods value supplied to `jit_expanding`.
min_periods: usize,
/// JIT context supplied to `jit_expanding`.
jit_context: &'a JitWindowContext,
/// Optional column selection.
columns: Option<Vec<String>>,
}
/// Builder state for a JIT-tracked exponentially-weighted-window operation.
///
/// NOTE(review): no methods consuming these fields are visible in this part of
/// the file — confirm where they are used.
pub struct JitDataFrameEWM<'a> {
/// The data frame the window will be applied to.
dataframe: &'a DataFrame,
/// JIT context supplied to `jit_ewm`.
jit_context: &'a JitWindowContext,
/// Optional smoothing factor — presumably the EWM alpha; confirm against the consumer.
alpha: Option<f64>,
/// Optional span setting — presumably an alternative way to specify decay; confirm.
span: Option<usize>,
/// Optional half-life setting — presumably an alternative way to specify decay; confirm.
halflife: Option<f64>,
/// Optional column selection.
columns: Option<Vec<String>>,
}
/// `DataFrame` entry points for the JIT-tracked window builders.
impl JitDataFrameWindowExt for DataFrame {
    /// Starts a rolling builder: no `min_periods` override, not centered, all columns.
    fn jit_rolling<'a>(
        &'a self,
        window_size: usize,
        jit_context: &'a JitWindowContext,
    ) -> JitDataFrameRolling<'a> {
        let dataframe = self;
        JitDataFrameRolling {
            dataframe,
            jit_context,
            window_size,
            // Optional settings start unset; the builder methods fill them in.
            min_periods: None,
            columns: None,
            center: false,
        }
    }

    /// Starts an expanding builder with the given `min_periods` and no column selection.
    fn jit_expanding<'a>(
        &'a self,
        min_periods: usize,
        jit_context: &'a JitWindowContext,
    ) -> JitDataFrameExpanding<'a> {
        let dataframe = self;
        JitDataFrameExpanding {
            dataframe,
            jit_context,
            min_periods,
            columns: None,
        }
    }

    /// Starts an EWM builder with `alpha`, `span`, `halflife` and `columns` all unset.
    fn jit_ewm<'a>(&'a self, jit_context: &'a JitWindowContext) -> JitDataFrameEWM<'a> {
        let dataframe = self;
        JitDataFrameEWM {
            dataframe,
            jit_context,
            alpha: None,
            span: None,
            halflife: None,
            columns: None,
        }
    }
}
impl<'a> JitDataFrameRolling<'a> {
pub fn min_periods(mut self, min_periods: usize) -> Self {
self.min_periods = Some(min_periods);
self
}
pub fn center(mut self, center: bool) -> Self {
self.center = center;
self
}
pub fn columns(mut self, columns: Vec<String>) -> Self {
self.columns = Some(columns);
self
}
pub fn mean(self) -> Result<DataFrame> {
let config = DataFrameRolling::new(self.window_size)
.min_periods(self.min_periods.unwrap_or(self.window_size))
.center(self.center);
let config = if let Some(cols) = self.columns {
config.columns(cols)
} else {
config
};
let ops = self.dataframe.apply_rolling(&config);
let jit_ops = JitDataFrameRollingOps::new(ops, self.jit_context);
jit_ops.mean()
}
pub fn sum(self) -> Result<DataFrame> {
let config = DataFrameRolling::new(self.window_size)
.min_periods(self.min_periods.unwrap_or(self.window_size))
.center(self.center);
let config = if let Some(cols) = self.columns {
config.columns(cols)
} else {
config
};
let ops = self.dataframe.apply_rolling(&config);
let jit_ops = JitDataFrameRollingOps::new(ops, self.jit_context);
jit_ops.sum()
}
pub fn std(self, ddof: usize) -> Result<DataFrame> {
let config = DataFrameRolling::new(self.window_size)
.min_periods(self.min_periods.unwrap_or(self.window_size))
.center(self.center);
let config = if let Some(cols) = self.columns {
config.columns(cols)
} else {
config
};
let ops = self.dataframe.apply_rolling(&config);
let jit_ops = JitDataFrameRollingOps::new(ops, self.jit_context);
jit_ops.std(ddof)
}
pub fn min(self) -> Result<DataFrame> {
let config = DataFrameRolling::new(self.window_size)
.min_periods(self.min_periods.unwrap_or(self.window_size))
.center(self.center);
let config = if let Some(cols) = self.columns {
config.columns(cols)
} else {
config
};
let ops = self.dataframe.apply_rolling(&config);
let jit_ops = JitDataFrameRollingOps::new(ops, self.jit_context);
jit_ops.min()
}
pub fn max(self) -> Result<DataFrame> {
let config = DataFrameRolling::new(self.window_size)
.min_periods(self.min_periods.unwrap_or(self.window_size))
.center(self.center);
let config = if let Some(cols) = self.columns {
config.columns(cols)
} else {
config
};
let ops = self.dataframe.apply_rolling(&config);
let jit_ops = JitDataFrameRollingOps::new(ops, self.jit_context);
jit_ops.max()
}
}