use async_trait::async_trait;
use datafusion::arrow::array::{
Array, ArrayRef, Int64Array, StringArray, TimestampMillisecondArray,
};
use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit};
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::datasource::MemTable;
use datafusion::execution::context::SessionConfig;
use datafusion::prelude::*;
use std::collections::HashMap;
use std::sync::Arc;
use tracing::{debug, instrument};
use super::{MetricsQuery, MetricsRepository, ResultKey};
use crate::analyzers::context::AnalyzerContext;
use crate::error::{Result, TermError};
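
/// Query executor that loads repository data into Arrow and evaluates
/// filters, ordering, and pagination through DataFusion SQL.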
#[derive(Default)]
pub struct DataFusionQueryExecutor;
impl DataFusionQueryExecutor {
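    /// Builds a `SessionContext` tuned for in-memory workloads: information
    /// schema enabled, one target partition per CPU core, and 8192-row batches.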
fn create_optimized_context() -> SessionContext {
let config = SessionConfig::new()
.with_information_schema(true)
.with_default_catalog_and_schema("term", "repository")
.with_target_partitions(num_cpus::get())
.with_batch_size(8192);
SessionContext::new_with_config(config)
}
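    /// Filters, orders, and paginates the given key/context pairs by registering
    /// them as an in-memory DataFusion table and running a generated SQL query.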
#[instrument(skip(data, tags, analyzers), fields(
data_size = data.len(),
filter_count = tags.len() + analyzers.as_ref().map(|a| a.len()).unwrap_or(0),
time_range = format_args!("{:?}-{:?}", after, before),
limit = limit,
offset = offset
))]
#[allow(clippy::too_many_arguments)]
pub async fn execute_optimized_query(
data: Vec<(ResultKey, AnalyzerContext)>,
before: Option<i64>,
after: Option<i64>,
tags: &HashMap<String, String>,
analyzers: &Option<Vec<String>>,
limit: Option<usize>,
offset: Option<usize>,
ascending: bool,
) -> Result<Vec<(ResultKey, AnalyzerContext)>> {
if data.is_empty() {
debug!("No data to query, returning empty results");
return Ok(vec![]);
}
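        // Stand up a DataFusion context and register the data as an in-memory table.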
let ctx = Self::create_optimized_context();
let record_batch = Self::create_record_batch(&data).map_err(|e| {
TermError::repository_with_source(
"datafusion",
"execute_query",
"Failed to convert repository data to Arrow format",
Box::new(e),
)
})?;
let schema = record_batch.schema();
let table = MemTable::try_new(schema, vec![vec![record_batch]]).map_err(|e| {
TermError::repository_with_source(
"datafusion",
"execute_query",
"Failed to create DataFusion table from Arrow data",
Box::new(e),
)
})?;
let table_name = "metrics_data";
ctx.register_table(table_name, Arc::new(table))
.map_err(|e| {
TermError::repository_with_source(
"datafusion",
"execute_query",
"Failed to register table with DataFusion context",
Box::new(e),
)
})?;
let sql = Self::build_optimized_sql(
table_name, before, after, tags, analyzers, limit, offset, ascending,
)?;
debug!("Executing DataFusion query: {}", sql);
let df = ctx.sql(&sql).await.map_err(|e| {
TermError::repository_with_source(
"datafusion",
"execute_query",
format!("Failed to parse SQL query: {sql}"),
Box::new(e),
)
})?;
let results = df.collect().await.map_err(|e| {
TermError::repository_with_source(
"datafusion",
"execute_query",
"Failed to execute DataFusion query",
Box::new(e),
)
})?;
Self::convert_results_back(&data, results).await
}
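    /// Flattens the key/context pairs into one Arrow `RecordBatch`: a timestamp
    /// column, a `row_index` column for mapping results back, a metrics-presence
    /// flag, and one nullable string column per distinct tag key.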
#[instrument(skip(data), fields(data_size = data.len()))]
fn create_record_batch(data: &[(ResultKey, AnalyzerContext)]) -> Result<RecordBatch> {
let len = data.len();
let timestamps: Vec<i64> = data.iter().map(|(key, _)| key.timestamp).collect();
let timestamp_array = TimestampMillisecondArray::from(timestamps);
let indices: Vec<i64> = (0..len as i64).collect();
let index_array = Int64Array::from(indices);
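        // Collect the union of tag keys across all rows; each becomes its own
        // column. HashSet iteration order is arbitrary, but fields and columns
        // are pushed in the same order below, so schema and data stay aligned.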
let mut all_tag_keys: std::collections::HashSet<String> = std::collections::HashSet::new();
for (key, _) in data {
for tag_key in key.tags.keys() {
all_tag_keys.insert(tag_key.clone());
}
}
let mut tag_arrays: Vec<(String, ArrayRef)> = Vec::new();
for tag_key in &all_tag_keys {
let tag_values: Vec<Option<String>> = data
.iter()
.map(|(key, _)| key.tags.get(tag_key).cloned())
.collect();
let tag_array = StringArray::from(tag_values);
tag_arrays.push((format!("tag_{tag_key}"), Arc::new(tag_array) as ArrayRef));
}
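        // Encode metric presence as 'true'/'false' strings so the SQL filter can
        // use a plain equality predicate.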
let has_metrics: Vec<bool> = data
.iter()
.map(|(_, ctx)| !ctx.all_metrics().is_empty())
.collect();
let metrics_array = Arc::new(
has_metrics
.iter()
.map(|&has| if has { Some("true") } else { Some("false") })
.collect::<StringArray>(),
) as ArrayRef;
let mut fields = vec![
Field::new(
"timestamp",
DataType::Timestamp(TimeUnit::Millisecond, None),
false,
),
Field::new("row_index", DataType::Int64, false),
Field::new("has_metrics", DataType::Utf8, true),
];
for (tag_key, _) in &tag_arrays {
fields.push(Field::new(tag_key, DataType::Utf8, true));
}
let schema = Arc::new(Schema::new(fields));
let mut columns: Vec<ArrayRef> = vec![
Arc::new(timestamp_array) as ArrayRef,
Arc::new(index_array) as ArrayRef,
metrics_array,
];
for (_, array) in tag_arrays {
columns.push(array);
}
RecordBatch::try_new(schema, columns).map_err(|e| {
TermError::repository_with_source(
"datafusion",
"create_record_batch",
format!("Failed to create Arrow RecordBatch for {len} rows"),
Box::new(e),
)
})
}
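    /// Assembles the SQL statement: time-range and tag predicates, an optional
    /// analyzer presence check, ordering by timestamp, and LIMIT/OFFSET paging.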
#[instrument(skip(tags, analyzers), fields(
has_time_filter = before.is_some() || after.is_some(),
tag_filter_count = tags.len(),
has_analyzer_filter = analyzers.is_some()
))]
#[allow(clippy::too_many_arguments)]
fn build_optimized_sql(
table_name: &str,
before: Option<i64>,
after: Option<i64>,
tags: &HashMap<String, String>,
analyzers: &Option<Vec<String>>,
limit: Option<usize>,
offset: Option<usize>,
ascending: bool,
) -> Result<String> {
let mut sql = format!("SELECT * FROM {table_name} WHERE 1=1");
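        // Start from a tautological WHERE clause so every filter below can be
        // appended uniformly with AND.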
if let Some(before_ts) = before {
sql.push_str(&format!(
" AND timestamp < TIMESTAMP '{}'",
chrono::DateTime::from_timestamp_millis(before_ts)
.unwrap_or_else(chrono::Utc::now)
.format("%Y-%m-%d %H:%M:%S%.3f")
));
}
if let Some(after_ts) = after {
sql.push_str(&format!(
" AND timestamp >= TIMESTAMP '{}'",
chrono::DateTime::from_timestamp_millis(after_ts)
.unwrap_or_else(chrono::Utc::now)
.format("%Y-%m-%d %H:%M:%S%.3f")
));
}
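        // Strip quote characters from tag keys and values as a crude guard
        // against SQL injection. A sanitized key that no longer matches its
        // column name from create_record_batch will fail query planning.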
for (tag_key, tag_value) in tags {
            let safe_key = tag_key.replace(['\'', '"'], "_");
            let safe_value = tag_value.replace(['\'', '"'], "_");
sql.push_str(&format!(" AND tag_{safe_key} = '{safe_value}'"));
}
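        // The analyzer filter is intentionally coarse: analyzer names are not
        // materialized in the schema, so this only checks that metrics exist.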
if let Some(_analyzer_list) = analyzers {
sql.push_str(" AND has_metrics = 'true'");
}
let sort_direction = if ascending { "ASC" } else { "DESC" };
sql.push_str(&format!(" ORDER BY timestamp {sort_direction}"));
        if let Some(limit_val) = limit {
            sql.push_str(&format!(" LIMIT {limit_val}"));
        }
        // Apply OFFSET independently of LIMIT so an offset-only query still paginates.
        if let Some(offset_val) = offset {
            sql.push_str(&format!(" OFFSET {offset_val}"));
        }
Ok(sql)
}
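    /// Maps DataFusion result rows back to the original `(ResultKey,
    /// AnalyzerContext)` pairs via the `row_index` column carried through the query.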
#[instrument(skip(original_data, results))]
async fn convert_results_back(
original_data: &[(ResultKey, AnalyzerContext)],
results: Vec<RecordBatch>,
) -> Result<Vec<(ResultKey, AnalyzerContext)>> {
let mut output = Vec::new();
for batch in results {
let row_indices = batch
.column_by_name("row_index")
.ok_or_else(|| TermError::Internal("Missing row_index column".to_string()))?
.as_any()
.downcast_ref::<Int64Array>()
.ok_or_else(|| TermError::Internal("Invalid row_index column type".to_string()))?;
for row_idx in 0..batch.num_rows() {
let original_idx = row_indices.value(row_idx) as usize;
if let Some((key, context)) = original_data.get(original_idx) {
output.push((key.clone(), context.clone()));
}
}
}
debug!(
"Converted {} DataFusion results back to original format",
output.len()
);
Ok(output)
}
}
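/// Extension trait giving any `MetricsRepository` a DataFusion-backed query path.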
#[async_trait]
pub trait DataFusionQueryExecutorExt: MetricsRepository {
#[instrument(skip(self, query))]
async fn execute_datafusion_query(
&self,
query: MetricsQuery,
) -> Result<Vec<(ResultKey, AnalyzerContext)>> {
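        // Load every key/context pair up front; entries that fail to load are
        // skipped rather than failing the whole query.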
let all_keys = self.list_keys().await?;
let mut data = Vec::with_capacity(all_keys.len());
for key in all_keys {
if let Ok(Some(context)) = self.get(&key).await {
data.push((key, context));
}
}
DataFusionQueryExecutor::execute_optimized_query(
data,
query.get_before(),
query.get_after(),
query.get_tags(),
query.get_analyzers(),
query.get_limit(),
query.get_offset(),
query.is_ascending(),
)
.await
}
}
impl<T: MetricsRepository + ?Sized> DataFusionQueryExecutorExt for T {}
#[cfg(test)]
mod tests {
use super::*;
use crate::analyzers::types::MetricValue;
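    /// Builds 100 rows with timestamps 0..99_000 ms and cycling env/region/version tags.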
fn create_test_data() -> Vec<(ResultKey, AnalyzerContext)> {
let mut data = Vec::new();
for i in 0..100 {
let key = ResultKey::new(i * 1000)
.with_tag("env", if i % 2 == 0 { "prod" } else { "staging" })
.with_tag("region", if i % 3 == 0 { "us-east-1" } else { "us-west-2" })
.with_tag("version", format!("v{}.0.0", i % 5));
let mut context = AnalyzerContext::new();
context.store_metric("row_count", MetricValue::Long(i * 100));
context.store_metric(
"completeness",
MetricValue::Double(0.95 + (i as f64 * 0.001)),
);
data.push((key, context));
}
data
}
#[tokio::test]
async fn test_datafusion_query_executor_basic() {
let data = create_test_data();
let results = DataFusionQueryExecutor::execute_optimized_query(
data.clone(),
None,
None,
&HashMap::new(),
&None,
Some(10),
None,
false,
)
.await
.unwrap();
assert_eq!(results.len(), 10);
assert!(results[0].0.timestamp >= results[1].0.timestamp);
}
#[tokio::test]
async fn test_datafusion_query_executor_time_filter() {
let data = create_test_data();
let results = DataFusionQueryExecutor::execute_optimized_query(
data,
            Some(50000),
            Some(10000),
            &HashMap::new(),
            &None,
            None,
            None,
            true,
        )
.await
.unwrap();
assert_eq!(results.len(), 40);
assert!(results[0].0.timestamp >= 10000);
assert!(results[0].0.timestamp < 50000);
assert!(results[0].0.timestamp <= results[1].0.timestamp);
}
#[tokio::test]
async fn test_datafusion_query_executor_tag_filter() {
let data = create_test_data();
let mut tags = HashMap::new();
tags.insert("env".to_string(), "prod".to_string());
let results = DataFusionQueryExecutor::execute_optimized_query(
data, None, None, &tags, &None, None, None, false,
)
.await
.unwrap();
assert_eq!(results.len(), 50);
for (key, _) in results {
assert_eq!(key.get_tag("env"), Some("prod"));
}
}
#[tokio::test]
async fn test_datafusion_query_executor_pagination() {
let data = create_test_data();
let page1 = DataFusionQueryExecutor::execute_optimized_query(
data.clone(),
None,
None,
&HashMap::new(),
&None,
            Some(20),
            Some(0),
            true,
        )
.await
.unwrap();
let page2 = DataFusionQueryExecutor::execute_optimized_query(
data,
None,
None,
&HashMap::new(),
&None,
            Some(20),
            Some(20),
            true,
        )
.await
.unwrap();
assert_eq!(page1.len(), 20);
assert_eq!(page2.len(), 20);
assert_ne!(page1[0].0.timestamp, page2[0].0.timestamp);
assert!(page1[19].0.timestamp < page2[0].0.timestamp);
}
#[tokio::test]
async fn test_record_batch_creation() {
let data = create_test_data();
let batch = DataFusionQueryExecutor::create_record_batch(&data[0..10]).unwrap();
assert_eq!(batch.num_rows(), 10);
assert!(batch.num_columns() >= 5);
let timestamps = batch.column_by_name("timestamp").unwrap();
assert_eq!(
timestamps.data_type(),
&DataType::Timestamp(TimeUnit::Millisecond, None)
);
let indices = batch.column_by_name("row_index").unwrap();
assert_eq!(indices.data_type(), &DataType::Int64);
}
}