use std::mem;
use std::sync::Arc;
use arrow::datatypes::{DataType, SchemaRef};
use crossbeam::channel::{Receiver, TryRecvError};
use datafusion::common::{DataFusionError, TableReference};
use datafusion::execution::SendableRecordBatchStream;
use datafusion::functions::expr_fn::concat;
use datafusion::logical_expr::{binary_expr, col as datafusion_col, lit};
use datafusion::prelude::{SessionContext, cast, encode};
use futures::{StreamExt as _, TryStreamExt as _};
use re_log::{error, warn};
use re_log_types::Timestamp;
use re_mutex::Mutex;
use re_quota_channel::send_crossbeam;
use re_sorbet::{BatchType, SorbetBatch, SorbetSchema};
use re_viewer_context::AsyncRuntimeHandle;
use crate::ColumnFilter;
use crate::grid_view::FlagChangeEvent;
use crate::table_blueprint::{EntryLinksSpec, SegmentLinksSpec, SortBy, TableBlueprint};
use crate::table_selection::TableSelectionState;
/// Wrap a column name in quotes (via `Debug` formatting of the `&str`) before
/// handing it to DataFusion, so the name is taken verbatim as an identifier
/// rather than being parsed/normalized.
fn col(name: &str) -> datafusion::logical_expr::Expr {
    let quoted = format!("{:?}", name);
    datafusion_col(quoted)
}
/// The subset of [`TableBlueprint`] state that shapes the DataFusion query.
///
/// Kept separate so query re-execution can be triggered by a simple
/// `PartialEq` comparison (see `DataFusionAdapter::update_query`).
#[derive(Debug, Clone, PartialEq, Default)]
struct DataFusionQueryData {
    /// Optional column to sort by (name + direction).
    pub sort_by: Option<SortBy>,

    /// If set, a link column to the dataset's segments is synthesized.
    pub segment_links: Option<SegmentLinksSpec>,

    /// If set, a link column to catalog entries is synthesized.
    pub entry_links: Option<EntryLinksSpec>,

    /// Filter applied before the user's column filters.
    pub prefilter: Option<datafusion::prelude::Expr>,

    /// Per-column filters; valid ones are AND-ed together.
    pub column_filters: Vec<ColumnFilter>,
}
impl From<&TableBlueprint> for DataFusionQueryData {
    /// Extract the query-relevant subset of a [`TableBlueprint`].
    fn from(value: &TableBlueprint) -> Self {
        // Exhaustive destructuring on purpose: adding a field to
        // `TableBlueprint` fails to compile here, forcing a decision about
        // whether the new field should affect the query.
        let TableBlueprint {
            sort_by,
            segment_links,
            entry_links,
            prefilter,
            column_filters,
            // Ignored here: these do not influence the query itself.
            grid_view_card_title: _,
            flag_column: _,
        } = value;

        Self {
            sort_by: sort_by.clone(),
            segment_links: segment_links.clone(),
            entry_links: entry_links.clone(),
            prefilter: prefilter.clone(),
            column_filters: column_filters.clone(),
        }
    }
}
/// Accumulated results of a (possibly still running) streaming query.
#[derive(Debug, Clone)]
pub struct DataFusionQueryResult {
    /// Batches received so far, in arrival order.
    pub sorbet_batches: Vec<SorbetBatch>,

    /// The raw Arrow schema as reported by the DataFusion stream.
    pub original_schema: SchemaRef,

    /// The sorbet interpretation of that schema.
    pub sorbet_schema: re_sorbet::SorbetSchema,

    /// `true` once the query channel has disconnected, i.e. no more batches
    /// will arrive.
    pub finished: bool,
}
impl DataFusionQueryResult {
    /// Map a global row index to `(batch index, row index within that batch)`.
    ///
    /// Returns `None` when `global_row` lies past the end of all received
    /// batches.
    fn find_row_indices(&self, global_row: u64) -> Option<(usize, usize)> {
        let mut offset = global_row as usize;
        for (batch_idx, batch) in self.sorbet_batches.iter().enumerate() {
            match offset.checked_sub(batch.num_rows()) {
                // Underflow means the row lives inside this batch.
                None => return Some((batch_idx, offset)),
                Some(rest) => offset = rest,
            }
        }
        None
    }

    /// Borrow the batch containing `global_row`, along with the row's offset
    /// within it.
    pub fn find_row_batch(&self, global_row: u64) -> Option<(&SorbetBatch, usize)> {
        self.find_row_indices(global_row)
            .map(|(batch_idx, offset)| (&self.sorbet_batches[batch_idx], offset))
    }

    /// Mutably borrow the batch containing `global_row`, along with the row's
    /// offset within it.
    pub fn find_row_batch_mut(&mut self, global_row: u64) -> Option<(&mut SorbetBatch, usize)> {
        self.find_row_indices(global_row)
            .map(|(batch_idx, offset)| (&mut self.sorbet_batches[batch_idx], offset))
    }
}
/// A DataFusion query together with everything needed to execute it.
#[derive(Clone)]
struct DataFusionQuery {
    session_ctx: Arc<SessionContext>,
    table_ref: TableReference,

    /// The blueprint-derived parameters that shape the query.
    query_data: DataFusionQueryData,
}
impl DataFusionQuery {
fn new(
session_ctx: Arc<SessionContext>,
table_ref: TableReference,
query_data: DataFusionQueryData,
) -> Self {
Self {
session_ctx,
table_ref,
query_data,
}
}
async fn batch_stream(self) -> Result<SendableRecordBatchStream, DataFusionError> {
let mut dataframe = self.session_ctx.table(self.table_ref).await?;
let DataFusionQueryData {
sort_by,
segment_links,
entry_links,
prefilter,
column_filters,
} = &self.query_data;
if let Some(segment_links) = segment_links {
let uri = format!(
"{}/dataset/{}/data?segment_id=",
segment_links.origin, segment_links.dataset_id
);
dataframe = dataframe.with_column(
&segment_links.column_name,
concat(vec![lit(uri), col(&segment_links.segment_id_column_name)]),
)?;
}
if let Some(entry_links) = entry_links {
let uri = format!("{}/entry/", entry_links.origin);
let column = concat(vec![
lit(uri),
encode(
cast(col(&entry_links.entry_id_column_name), DataType::Binary),
lit("hex"),
),
]);
dataframe = dataframe.with_column(&entry_links.column_name, column)?;
}
if let Some(prefilter) = prefilter {
dataframe = dataframe.filter(prefilter.clone())?;
}
let filter_exprs = column_filters
.iter()
.filter_map(|filter| {
filter
.as_filter_expression()
.inspect_err(|err| {
re_log::warn_once!("invalid filter: {err}");
})
.ok()
})
.collect();
let filter_expr =
balanced_binary_exprs(filter_exprs, datafusion::logical_expr::Operator::And);
if let Some(filter_expr) = filter_expr {
dataframe = dataframe.filter(filter_expr)?;
}
if let Some(sort_by) = sort_by {
dataframe = dataframe.sort(vec![
col(&sort_by.column_physical_name).sort(sort_by.direction.is_ascending(), true),
])?;
}
let stream = dataframe.execute_stream().await?;
Ok(stream)
}
fn execute_streaming(self, runtime: &AsyncRuntimeHandle) -> Receiver<QueryEvent> {
let (tx, rx) = crate::create_channel(1000);
runtime.spawn_future(async move {
if let Ok(stream) = self.batch_stream().await {
let schema = stream.schema();
let mut sorbet_stream = stream.and_then(|s| {
std::future::ready(
SorbetBatch::try_from_record_batch(&s, BatchType::Dataframe)
.map_err(|err| DataFusionError::External(err.into())),
)
});
let mut sent_schemas = false;
let mut sent_error = false;
while let Some(frame) = sorbet_stream.next().await {
match frame {
Ok(batch) => {
if !sent_schemas {
let sorbet_schema = batch.sorbet_schema().clone();
let original_schema = Arc::clone(&schema);
if send_crossbeam(
&tx,
QueryEvent::Schema {
original_schema,
sorbet_schema,
},
)
.is_err()
{
return; }
sent_schemas = true;
}
if send_crossbeam(&tx, QueryEvent::Batch(batch)).is_err() {
return; }
}
Err(err) => {
sent_error = true;
send_crossbeam(&tx, QueryEvent::Error(err)).ok();
}
}
}
if !sent_schemas && !sent_error {
let sorbet_schema = SorbetSchema::try_from_raw_arrow_schema(schema.clone());
match sorbet_schema {
Ok(sorbet_schema) => {
send_crossbeam(
&tx,
QueryEvent::Schema {
original_schema: schema,
sorbet_schema,
},
)
.ok();
}
Err(err) => {
send_crossbeam(
&tx,
QueryEvent::Error(DataFusionError::External(err.into())),
)
.ok();
}
}
}
}
});
rx
}
}
/// Events streamed from the query task back to the UI.
#[derive(Debug)]
pub enum QueryEvent {
    /// Sent once, before any batch (or after an empty stream completes).
    Schema {
        original_schema: SchemaRef,
        sorbet_schema: re_sorbet::SorbetSchema,
    },

    /// One batch of query results.
    Batch(SorbetBatch),

    /// A query error; may arrive at any point.
    Error(DataFusionError),
}
impl PartialEq for DataFusionQuery {
    /// Two queries are equal when they share the same session (by pointer),
    /// target the same table, and carry equal query parameters.
    fn eq(&self, other: &Self) -> bool {
        // Destructure so that adding a field forces this comparison to be
        // revisited.
        let Self {
            session_ctx,
            table_ref,
            query_data,
        } = other;

        Arc::ptr_eq(&self.session_ctx, session_ctx)
            && &self.table_ref == table_ref
            && &self.query_data == query_data
    }
}
/// Bridges a streaming DataFusion query to the egui immediate-mode UI.
///
/// The adapter itself is persisted in (and restored from) egui temp data
/// under `id`, so it survives across frames.
#[derive(Clone)]
pub struct DataFusionAdapter {
    id: egui::Id,

    /// The blueprint currently driving the query.
    blueprint: TableBlueprint,

    /// The query derived from `blueprint`.
    query: DataFusionQuery,

    /// Results of the previous query, retained while a new query is in
    /// flight so the UI still has something to show.
    pub last_query_results: Option<Result<DataFusionQueryResult, Arc<DataFusionError>>>,

    /// Channel receiving events from the currently running query.
    pub rx: Arc<Mutex<Receiver<QueryEvent>>>,

    /// Results of the current query, accumulated as batches arrive.
    pub results: Option<Result<DataFusionQueryResult, Arc<DataFusionError>>>,

    /// When the adapter (and its initial query) was created.
    pub queried_at: Timestamp,
}
impl DataFusionAdapter {
    /// Remove any persisted adapter state for `id` from the egui context.
    pub fn clear_state(egui_ctx: &egui::Context, id: egui::Id) {
        egui_ctx.data_mut(|data| {
            data.remove::<Self>(id);
        });
    }

    /// Fetch the adapter for `id` from egui temp data, creating it (and
    /// starting the initial query) if it doesn't exist yet, then drain any
    /// pending [`QueryEvent`]s into `results`.
    pub fn get(
        runtime: &AsyncRuntimeHandle,
        ui: &egui::Ui,
        session_ctx: &Arc<SessionContext>,
        table_ref: TableReference,
        id: egui::Id,
        initial_blueprint: TableBlueprint,
    ) -> Self {
        let adapter = ui.data(|data| data.get_temp::<Self>(id));

        let mut adapter = adapter.unwrap_or_else(|| {
            // First frame for this id: build the query and kick it off.
            let initial_query = DataFusionQueryData::from(&initial_blueprint);
            let query = DataFusionQuery::new(Arc::clone(session_ctx), table_ref, initial_query);

            let rx = query.clone().execute_streaming(runtime);

            let table_state = Self {
                id,
                blueprint: initial_blueprint,
                rx: Arc::new(Mutex::new(rx)),
                results: None,
                query,
                last_query_results: None,
                queried_at: Timestamp::now(),
            };

            ui.data_mut(|data| {
                data.insert_temp(id, table_state.clone());
            });

            table_state
        });

        // Non-blocking drain of everything that arrived since last frame.
        {
            let rx = adapter.rx.lock();
            let mut changed = false;

            loop {
                match rx.try_recv() {
                    Ok(QueryEvent::Schema {
                        sorbet_schema,
                        original_schema,
                    }) => {
                        // A schema event starts a fresh, empty result set.
                        adapter.results = Some(Ok(DataFusionQueryResult {
                            original_schema,
                            sorbet_schema,
                            sorbet_batches: vec![],
                            finished: false,
                        }));
                        changed = true;
                    }

                    Ok(QueryEvent::Batch(batch)) => match &mut adapter.results {
                        Some(Ok(data)) => {
                            data.sorbet_batches.push(batch);
                            changed = true;
                            // Fresh data arrived: drop the stale results kept
                            // from the previous query.
                            adapter.last_query_results = None;
                        }
                        Some(Err(err)) => {
                            warn!("Received data after receiving an error: {err}");
                        }
                        None => {
                            // The query task always sends the schema before
                            // any batch, so this indicates a bug upstream.
                            error!("Received data before receiving schema");
                        }
                    },

                    Ok(QueryEvent::Error(err)) => {
                        error!("DataFusion query error: {err}");
                        adapter.results = Some(Err(Arc::new(err)));
                        changed = true;
                    }

                    Err(TryRecvError::Empty) => {
                        // Nothing more for now; poll again next frame.
                        break;
                    }

                    Err(TryRecvError::Disconnected) => {
                        // The query task finished (or died): mark results as
                        // complete.
                        if let Some(Ok(data)) = &mut adapter.results {
                            data.finished = true;
                            changed = true;
                        }
                        break;
                    }
                }
            }

            if changed {
                ui.data_mut(|data| {
                    data.insert_temp(adapter.id, adapter.clone());
                });
            }
        }

        adapter
    }

    /// The blueprint currently driving the query.
    pub fn blueprint(&self) -> &TableBlueprint {
        &self.blueprint
    }

    /// Update the blueprint, re-running the query only if the change affects
    /// query-relevant state, and persist the adapter back into egui temp data.
    ///
    /// When a new query starts, the previous results are moved into
    /// `last_query_results` so the UI can keep displaying them.
    pub fn update_query(
        mut self,
        runtime: &AsyncRuntimeHandle,
        ui: &egui::Ui,
        new_blueprint: TableBlueprint,
    ) {
        self.blueprint = new_blueprint;

        // `DataFusionQueryData` captures exactly the query-relevant subset,
        // so this comparison decides whether to re-execute.
        let new_query_data = DataFusionQueryData::from(&self.blueprint);
        if self.query.query_data != new_query_data {
            self.query.query_data = new_query_data;

            self.last_query_results = mem::take(&mut self.results);
            if let Some(Ok(results)) = &mut self.last_query_results {
                results.finished = true;
            }

            let rx = self.query.clone().execute_streaming(runtime);
            self.rx = Arc::new(Mutex::new(rx));

            // The old row selection may no longer apply under the new query.
            TableSelectionState::clear(ui.ctx(), self.id);
        }

        ui.data_mut(|data| {
            data.insert_temp(self.id, self);
        });
    }

    /// Apply local flag edits to the in-memory results (no re-query), then
    /// persist the updated adapter into egui temp data.
    pub fn apply_flag_changes(
        &mut self,
        ui: &egui::Ui,
        flag_column_name: &str,
        changes: &[FlagChangeEvent],
    ) {
        let Some(Ok(results)) = &mut self.results else {
            return;
        };

        // Locate the component column whose name matches the flag column.
        let Some(col_idx) = results.sorbet_schema.columns.iter().position(|desc| {
            matches!(desc, re_sorbet::ColumnDescriptor::Component(c) if c.component.as_str() == flag_column_name)
        }) else {
            return;
        };

        update_existing_flag_column(results, col_idx, changes);

        ui.data_mut(|data| {
            data.insert_temp(self.id, self.clone());
        });
    }
}
/// Rewrite the boolean column at `col_idx` so each row mentioned in `changes`
/// takes its new value; all other rows (including nulls) are preserved.
///
/// Rows past the end of all batches are ignored. If the column is not a
/// boolean column, a warning is emitted and no further changes are applied.
fn update_existing_flag_column(
    results: &mut DataFusionQueryResult,
    col_idx: usize,
    changes: &[FlagChangeEvent],
) {
    use arrow::array::{Array as _, BooleanArray};

    for change in changes {
        // Locate the batch (and local row offset) holding this global row.
        let Some((batch, row_offset)) = results.find_row_batch_mut(change.row) else {
            continue;
        };

        let Some(current) = batch
            .column(col_idx)
            .as_any()
            .downcast_ref::<BooleanArray>()
        else {
            re_log::warn_once!("Flag column at index {col_idx} is not a boolean column");
            break;
        };

        // Arrow arrays are immutable, so rebuild the column with the single
        // row replaced.
        let rebuilt: BooleanArray = (0..batch.num_rows())
            .map(|row| {
                if row == row_offset {
                    Some(change.new_value)
                } else if current.is_null(row) {
                    None
                } else {
                    Some(current.value(row))
                }
            })
            .collect();

        if let Some(updated) = batch.with_replaced_column(col_idx, std::sync::Arc::new(rebuilt)) {
            *batch = updated;
        }
    }
}
/// Combine `exprs` with `op` into a balanced binary expression tree.
///
/// Returns `None` for an empty input. Balancing keeps the tree depth
/// logarithmic in the number of expressions, instead of the linear chain a
/// simple left fold would produce.
fn balanced_binary_exprs(
    mut exprs: Vec<datafusion::logical_expr::Expr>,
    op: datafusion::logical_expr::Operator,
) -> Option<datafusion::logical_expr::Expr> {
    // Repeatedly merge adjacent pairs until at most one expression remains.
    while exprs.len() > 1 {
        let mut merged = Vec::with_capacity(exprs.len() / 2 + 1);
        let mut pending = exprs.into_iter();

        while let Some(lhs) = pending.next() {
            match pending.next() {
                Some(rhs) => merged.push(binary_expr(lhs, op, rhs)),
                // Odd element out: carried unchanged into the next round.
                None => merged.push(lhs),
            }
        }

        exprs = merged;
    }

    exprs.pop()
}