use std::sync::Arc;

use itertools::Itertools;
use polars_core::{prelude::*, series::Series};
use polars_ops::prelude::*;

use re_log_types::{DataCell, EntityPath, RowId, TimeInt};
use re_types_core::ComponentName;

use crate::{ArrayExt, DataStore, LatestAtQuery, RangeQuery};

/// A [`PolarsError`] wrapped in an [`Arc`] so that results can be cheaply cloned and shared.
pub type SharedPolarsError = Arc<PolarsError>;
pub type SharedResult<T> = ::std::result::Result<T, SharedPolarsError>;
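
/// Queries the latest value of a single component (and its cluster key) at the given time, and
/// returns the result as a `DataFrame`, one column per component found.
///
/// Returns an empty `DataFrame` if the store holds no data for this entity and component.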
pub fn latest_component(
store: &DataStore,
query: &LatestAtQuery,
ent_path: &EntityPath,
primary: ComponentName,
) -> SharedResult<DataFrame> {
let cluster_key = store.cluster_key();
let components = &[cluster_key, primary];
    // Fetch the latest cells for the cluster key and the primary component; fall back to a row
    // of empty cells if the store has no data for this query.
    let (_, cells) = store
        .latest_at(query, ent_path, primary, components)
        .unwrap_or((RowId::ZERO, [(); 2].map(|_| None)));
dataframe_from_cells(&cells)
}
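
/// Runs [`latest_component`] for each of the given `primaries` and joins the resulting
/// dataframes on the cluster key, using the given join type.
///
/// The cluster key itself is skipped if it appears in `primaries`, since every per-component
/// dataframe already carries it.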
pub fn latest_components(
store: &DataStore,
query: &LatestAtQuery,
ent_path: &EntityPath,
primaries: &[ComponentName],
join_type: &JoinType,
) -> SharedResult<DataFrame> {
let cluster_key = store.cluster_key();
    let dfs = primaries
        .iter()
        // The cluster key comes back with every single component anyway, no need to query it
        // on its own.
        .filter(|primary| **primary != cluster_key)
        .map(|primary| latest_component(store, query, ent_path, *primary));
join_dataframes(cluster_key, join_type, dfs)
}
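
/// Returns an iterator of dataframes covering the given time range, one per relevant row, all
/// from the point of view of the `primary` component.
///
/// The iterator is seeded with the latest-at state just before the start of the range (if that
/// state contains the primary), then every row in the range is joined into a running state; an
/// item is yielded only when a row carries new data for `primary`.
///
/// `components` must contain both the cluster key and `primary`.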
pub fn range_components<'a, const N: usize>(
store: &'a DataStore,
query: &'a RangeQuery,
ent_path: &'a EntityPath,
primary: ComponentName,
components: [ComponentName; N],
join_type: &'a JoinType,
) -> impl Iterator<Item = SharedResult<(Option<TimeInt>, DataFrame)>> + 'a {
let cluster_key = store.cluster_key();
assert!(components.contains(&cluster_key));
assert!(components.contains(&primary));
    // The running state: the join of every dataframe seen so far.
    let mut state = None;

    // The time just before the start of the range, used to fetch an initial latest-at state so
    // that the iterator doesn't start from scratch.
    let latest_time = query.range.min.as_i64().checked_sub(1).map(Into::into);

    let mut df_latest = None;
    if let Some(latest_time) = latest_time {
let df = latest_components(
store,
&LatestAtQuery::new(query.timeline, latest_time),
ent_path,
&components,
join_type,
);
        // Only keep this initial state if it actually contains the primary component.
        if df.as_ref().map_or(false, |df| {
            !df.is_empty() && df.column(primary.as_ref()).is_ok()
        }) {
df_latest = Some(df);
}
}
    // Index of the primary component within `components`, used to tell whether a given row
    // actually carries new data for the primary.
    let primary_col = components
        .iter()
        .find_position(|component| **component == primary)
        .map(|(col, _)| col)
        .unwrap();

    // Seed the iterator with the latest-at state computed above (if any)...
    df_latest
        .into_iter()
        .map(move |df| (latest_time, true, df))
        // ...then stream every row in the actual time range.
        .chain(
            store
                .range(query, ent_path, components)
                .map(move |(time, _, cells)| {
                    (
                        time,
                        cells[primary_col].is_some(),
                        dataframe_from_cells(&cells),
                    )
                }),
        )
        .filter_map(move |(time, is_primary, df)| {
            // Join every incoming dataframe into the running state.
            state = Some(join_dataframes(
                cluster_key,
                join_type,
                [state.clone(), Some(df)].into_iter().flatten(),
            ));

            // Only yield when the primary component itself was updated, and restrict the output
            // to the columns that were asked for, in the order they were asked for.
            is_primary.then(|| {
                state.clone().unwrap().map(|df| {
                    let columns = df.get_column_names();
                    let df = df
                        .select(
                            components
                                .iter()
                                .filter(|col| columns.contains(&col.as_ref())),
                        )
                        .unwrap();
                    (time, df)
                })
            })
        })
}
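
/// Turns a set of optional [`DataCell`]s into a `DataFrame`, one column per cell that is
/// actually present (missing cells are skipped). Each column is named after its component.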
pub fn dataframe_from_cells<const N: usize>(
cells: &[Option<DataCell>; N],
) -> SharedResult<DataFrame> {
let series: Result<Vec<_>, _> = cells
.iter()
.flatten()
.map(|cell| {
Series::try_from((
cell.component_name().as_ref(),
cell.as_arrow_ref().clean_for_polars(),
))
})
.collect();
DataFrame::new(series?).map_err(Into::into)
}
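
/// Joins the given dataframes into a single one, using the cluster key as the join column.
///
/// Empty dataframes are ignored. When two dataframes carry a column for the same component, the
/// right-hand (most recent) one wins. After each join, rows where every non-cluster column is
/// null are dropped, and the final result is sorted by the cluster key.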
pub fn join_dataframes(
cluster_key: ComponentName,
join_type: &JoinType,
dfs: impl Iterator<Item = SharedResult<DataFrame>>,
) -> SharedResult<DataFrame> {
    let df = dfs
        .filter(|df| df.as_ref().map_or(true, |df| !df.is_empty()))
        .reduce(|left, right| {
            let mut left = left?;
            let right = right?;

            // If both sides have a column for the same component, drop it from the left side:
            // the right-hand side holds the most recent data.
            for col in right
                .get_column_names()
                .iter()
                .filter(|col| *col != &cluster_key.as_str())
            {
                _ = left.drop_in_place(col);
            }

            left.join(
                &right,
                [cluster_key],
                [cluster_key],
                join_type.clone(),
                None,
            )
            .map(|df| drop_all_nulls(&df, &cluster_key).unwrap())
            .map_err(Into::into)
        })
        .unwrap_or_else(|| Ok(DataFrame::default()))?;

    // Sort by cluster key so that rows come out in instance order; if sorting fails (e.g. the
    // dataframe is empty), fall back to the unsorted result.
    Ok(df.sort([cluster_key.as_str()], false).unwrap_or(df))
}
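
/// Returns a new `DataFrame` where all rows whose columns (ignoring the cluster key) are all
/// null have been dropped.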
pub fn drop_all_nulls(df: &DataFrame, cluster_key: &ComponentName) -> PolarsResult<DataFrame> {
let cols = df
.get_column_names()
.into_iter()
.filter(|col| *col != cluster_key.as_str());
let mut iter = df.select_series(cols)?.into_iter();
if iter.clone().all(|s| !s.has_validity()) {
return Ok(df.clone());
}
let mask = iter
.next()
.ok_or_else(|| PolarsError::NoData("No data to drop nulls from".into()))?;
let mut mask = mask.is_not_null();
for s in iter {
mask = mask | s.is_not_null();
}
df.filter(&mask)
}