use std::sync::Arc;
use chrono_tz::Tz;
use polars_core::frame::DataFrame;
use polars_core::prelude::{Column, DataType, GroupsType, TimeUnit};
use polars_core::schema::Schema;
use polars_error::{PolarsError, PolarsResult, polars_bail, polars_ensure};
use polars_expr::state::ExecutionState;
use polars_time::prelude::{RollingWindower, ensure_duration_matches_dtype};
use polars_time::{ClosedWindow, Duration};
use polars_utils::IdxSize;
use polars_utils::pl_str::PlSmallStr;
use super::ComputeNode;
use crate::DEFAULT_DISTRIBUTOR_BUFFER_SIZE;
use crate::async_executor::{JoinHandle, TaskPriority, TaskScope};
use crate::async_primitives::distributor_channel::distributor_channel;
use crate::async_primitives::wait_group::WaitGroup;
use crate::execute::StreamingExecutionState;
use crate::expression::StreamExpr;
use crate::graph::PortState;
use crate::morsel::{Morsel, MorselSeq, SourceToken};
use crate::pipe::{RecvPort, SendPort};
type NextWindows = (Vec<[IdxSize; 2]>, DataFrame, Column);
pub struct RollingGroupBy {
buf_df: DataFrame,
buf_df_offset: IdxSize,
buf_index_column: Column,
buf_key_column: Column,
seq: MorselSeq,
slice_offset: IdxSize,
slice_length: IdxSize,
index_column: PlSmallStr,
windower: RollingWindower,
aggs: Arc<[(PlSmallStr, StreamExpr)]>,
}
impl RollingGroupBy {
pub fn new(
schema: Arc<Schema>,
index_column: PlSmallStr,
period: Duration,
offset: Duration,
closed: ClosedWindow,
slice: Option<(IdxSize, IdxSize)>,
aggs: Arc<[(PlSmallStr, StreamExpr)]>,
) -> PolarsResult<Self> {
polars_ensure!(
!period.is_zero() && !period.negative(),
ComputeError: "rolling window period should be strictly positive",
);
let key_dtype = schema.get(&index_column).unwrap();
ensure_duration_matches_dtype(period, key_dtype, "period")?;
ensure_duration_matches_dtype(offset, key_dtype, "offset")?;
use DataType as DT;
let (tu, tz) = match key_dtype {
DT::Datetime(tu, tz) => (*tu, tz.clone()),
DT::Date => (TimeUnit::Microseconds, None),
DT::UInt32 | DT::UInt64 | DT::Int64 | DT::Int32 => (TimeUnit::Nanoseconds, None),
dt => polars_bail!(
ComputeError:
"expected any of the following dtypes: {{ Date, Datetime, Int32, Int64, UInt32, UInt64 }}, got {}",
dt
),
};
let buf_df = DataFrame::empty_with_arc_schema(schema.clone());
let buf_key_column = Column::new_empty(index_column.clone(), key_dtype);
let buf_index_column =
Column::new_empty(index_column.clone(), &DT::Datetime(tu, tz.clone()));
let tz = tz.and_then(|tz| tz.parse::<Tz>().ok());
let windower = RollingWindower::new(period, offset, closed, tu, tz);
let (slice_offset, slice_length) = slice.unwrap_or((0, IdxSize::MAX));
Ok(Self {
buf_df,
buf_df_offset: 0,
buf_index_column,
buf_key_column,
seq: MorselSeq::default(),
slice_offset,
slice_length,
index_column,
windower,
aggs,
})
}
async fn evaluate_one(
windows: Vec<[IdxSize; 2]>,
key: Column,
aggs: &[(PlSmallStr, StreamExpr)],
state: &ExecutionState,
mut df: DataFrame,
) -> PolarsResult<DataFrame> {
assert_eq!(windows.len(), key.len());
let groups = GroupsType::new_slice(windows, true, true).into_sliceable();
df.rechunk_mut();
let mut columns = Vec::with_capacity(1 + aggs.len());
let height = key.len();
columns.push(key);
for (name, agg) in aggs.iter() {
let mut agg = agg.evaluate_on_groups(&df, &groups, state).await?;
let agg = agg.finalize();
columns.push(agg.with_name(name.clone()));
}
Ok(unsafe { DataFrame::new_unchecked(height, columns) })
}
fn next_windows(&mut self, finalize: bool) -> PolarsResult<Option<NextWindows>> {
let buf_index_col_dt = self.buf_index_column.datetime()?;
let mut time = Vec::new();
time.extend(
buf_index_col_dt
.physical()
.downcast_iter()
.map(|arr| arr.values().as_slice()),
);
let mut windows = Vec::new();
let num_retired = if finalize {
self.windower.finalize(&time, &mut windows);
self.buf_key_column.len() as IdxSize
} else {
self.windower.insert(&time, &mut windows)?
};
if num_retired == 0 && windows.is_empty() {
return Ok(None);
}
let start_row_offset = self.buf_df_offset;
self.buf_index_column = self.buf_index_column.slice(num_retired as i64, usize::MAX);
let new_buf_df = self.buf_df.slice(num_retired as i64, usize::MAX);
let data = std::mem::replace(&mut self.buf_df, new_buf_df);
self.buf_df_offset += num_retired;
if windows.is_empty() {
return Ok(None);
}
let key;
(key, self.buf_key_column) = self.buf_key_column.split_at(windows.len() as i64);
let key = key.slice(self.slice_offset as i64, self.slice_length as usize);
let offset = windows[0][0];
let end = windows.last().unwrap();
let end = end[0] + end[1];
if self.slice_offset as usize > windows.len() {
self.slice_offset -= windows.len() as IdxSize;
windows.clear();
} else if self.slice_offset > 0 {
let offset = self.slice_offset as usize;
self.slice_offset = self.slice_offset.saturating_sub(windows.len() as IdxSize);
windows.drain(..offset);
}
windows.truncate(windows.len().min(self.slice_length as usize));
self.slice_length -= windows.len() as IdxSize;
if windows.is_empty() {
return Ok(None);
}
windows.iter_mut().for_each(|[s, _]| *s -= offset);
let data = data.slice(
(offset - start_row_offset) as i64,
(end - start_row_offset) as usize,
);
Ok(Some((windows, data, key)))
}
}
impl ComputeNode for RollingGroupBy {
fn name(&self) -> &str {
"rolling-group-by"
}
fn update_state(
&mut self,
recv: &mut [PortState],
send: &mut [PortState],
_state: &StreamingExecutionState,
) -> PolarsResult<()> {
assert!(recv.len() == 1 && send.len() == 1);
if self.slice_length == 0 {
recv[0] = PortState::Done;
send[0] = PortState::Done;
std::mem::take(&mut self.buf_df);
return Ok(());
}
if send[0] == PortState::Done {
recv[0] = PortState::Done;
std::mem::take(&mut self.buf_df);
} else if recv[0] == PortState::Done {
if self.buf_df.height() == 0 {
send[0] = PortState::Done;
} else {
send[0] = PortState::Ready;
}
} else {
recv.swap_with_slice(send);
}
Ok(())
}
fn spawn<'env, 's>(
&'env mut self,
scope: &'s TaskScope<'s, 'env>,
recv_ports: &mut [Option<RecvPort<'_>>],
send_ports: &mut [Option<SendPort<'_>>],
state: &'s StreamingExecutionState,
join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
) {
assert!(recv_ports.len() == 1 && send_ports.len() == 1);
let Some(recv) = recv_ports[0].take() else {
assert!(self.buf_df.height() > 0);
assert!(self.slice_length > 0);
let mut send = send_ports[0].take().unwrap().serial();
join_handles.push(scope.spawn_task(TaskPriority::High, async move {
if let Some((windows, df, key)) = self.next_windows(true)? {
let df = Self::evaluate_one(
windows,
key,
&self.aggs,
&state.in_memory_exec_state,
df,
)
.await?;
_ = send
.send(Morsel::new(df, self.seq.successor(), SourceToken::new()))
.await;
}
self.buf_df = self.buf_df.clear();
self.buf_key_column = self.buf_key_column.clear();
self.buf_index_column = self.buf_index_column.clear();
Ok(())
}));
return;
};
let mut recv = recv.serial();
let send = send_ports[0].take().unwrap().parallel();
let (mut distributor, rxs) = distributor_channel::<(Morsel, Column, Vec<[IdxSize; 2]>)>(
send.len(),
*DEFAULT_DISTRIBUTOR_BUFFER_SIZE,
);
join_handles.extend(rxs.into_iter().zip(send).map(|(mut rx, mut tx)| {
let wg = WaitGroup::default();
let aggs = self.aggs.clone();
let state = state.in_memory_exec_state.split();
scope.spawn_task(TaskPriority::High, async move {
while let Ok((mut morsel, key, windows)) = rx.recv().await {
morsel = morsel
.async_try_map::<PolarsError, _, _>(async |df| {
Self::evaluate_one(windows, key, &aggs, &state, df).await
})
.await?;
morsel.set_consume_token(wg.token());
if tx.send(morsel).await.is_err() {
break;
}
wg.wait().await;
}
Ok(())
})
}));
join_handles.push(scope.spawn_task(TaskPriority::High, async move {
while let Ok(morsel) = recv.recv().await
&& self.slice_length > 0
{
let (df, seq, source_token, wait_token) = morsel.into_inner();
self.seq = seq;
drop(wait_token);
if df.height() == 0 {
continue;
}
let morsel_index_column = df.column(&self.index_column)?;
polars_ensure!(
morsel_index_column.null_count() == 0,
ComputeError: "null values in `rolling` not supported, fill nulls."
);
self.buf_key_column.append(morsel_index_column)?;
use DataType as DT;
let morsel_index_column = match morsel_index_column.dtype() {
DT::Datetime(_, _) => morsel_index_column.clone(),
DT::Date => {
morsel_index_column.cast(&DT::Datetime(TimeUnit::Microseconds, None))?
},
DT::UInt32 | DT::UInt64 | DT::Int32 => morsel_index_column
.cast(&DT::Int64)?
.cast(&DT::Datetime(TimeUnit::Nanoseconds, None))?,
DT::Int64 => {
morsel_index_column.cast(&DT::Datetime(TimeUnit::Nanoseconds, None))?
},
_ => unreachable!(),
};
self.buf_index_column.append(&morsel_index_column)?;
self.buf_df.vstack_mut_owned(df)?;
if let Some((windows, df, key)) = self.next_windows(false)? {
if distributor
.send((Morsel::new(df, seq, source_token), key, windows))
.await
.is_err()
{
break;
}
}
}
Ok(())
}));
}
}