use std::collections::HashMap;
use std::sync::Arc;

use datafusion::arrow::datatypes::SchemaRef;
use datafusion::dataframe::DataFrame;
use datafusion::datasource::{source_as_provider, TableProvider};
use datafusion::error::{DataFusionError, Result};
use datafusion::execution::context::DataFilePaths;
use datafusion::logical_expr::{
CreateExternalTable, DdlStatement, LogicalPlan, TableScan,
};
use datafusion::prelude::{
AvroReadOptions, CsvReadOptions, NdJsonReadOptions, ParquetReadOptions,
SessionConfig, SessionContext,
};
use datafusion::sql::parser::{DFParser, Statement as DFStatement};
use datafusion::sql::TableReference;
use datafusion_proto::protobuf::LogicalPlanNode;
use log::info;
use parking_lot::Mutex;
use sqlparser::ast::Statement;

use kapot_core::config::KapotConfig;
use kapot_core::serde::protobuf::scheduler_grpc_client::SchedulerGrpcClient;
use kapot_core::serde::protobuf::{CreateSessionParams, KeyValuePair};
use kapot_core::utils::{
create_df_ctx_with_kapot_query_planner, create_grpc_client_connection,
};
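/// Client-side state for a kapot session: the effective configuration, the
/// scheduler endpoint, and the tables registered through this context.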
struct KapotContextState {
config: KapotConfig,
scheduler_host: String,
scheduler_port: u16,
tables: HashMap<String, Arc<dyn TableProvider>>,
}
impl KapotContextState {
pub fn new(
scheduler_host: String,
scheduler_port: u16,
config: &KapotConfig,
) -> Self {
Self {
config: config.clone(),
scheduler_host,
scheduler_port,
tables: HashMap::new(),
}
}
pub fn config(&self) -> &KapotConfig {
&self.config
}
}
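/// Context for executing DataFusion queries against a kapot cluster. Wraps a
/// [`SessionContext`] configured with the kapot query planner, plus the
/// client-side table registry kept in [`KapotContextState`].
///
/// A minimal sketch of remote usage (not compiled as a doctest; assumes a
/// kapot scheduler is listening on `localhost:50050` -- adjust host and port
/// for your deployment):
///
/// ```ignore
/// let config = KapotConfig::new()?;
/// let ctx = KapotContext::remote("localhost", 50050, &config).await?;
/// let df = ctx.sql("SELECT 1").await?;
/// df.show().await?;
/// ```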
pub struct KapotContext {
state: Arc<Mutex<KapotContextState>>,
context: Arc<SessionContext>,
}
impl KapotContext {
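/// Connects to the kapot scheduler at `host:port` over gRPC, creates a
/// server-side session, and returns a context whose query planner submits
/// jobs to that scheduler.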
pub async fn remote(
host: &str,
port: u16,
config: &KapotConfig,
) -> Result<Self> {
let state = KapotContextState::new(host.to_owned(), port, config);
let scheduler_url =
format!("http://{}:{}", &state.scheduler_host, state.scheduler_port);
info!("Connecting to kapot scheduler at {scheduler_url}");
let connection = create_grpc_client_connection(scheduler_url.clone())
.await
.map_err(|e| DataFusionError::Execution(format!("{e:?}")))?;
let limit = config.default_grpc_client_max_message_size();
let mut scheduler = SchedulerGrpcClient::new(connection)
.max_encoding_message_size(limit)
.max_decoding_message_size(limit);
let remote_session_id = scheduler
.create_session(CreateSessionParams {
settings: config
.settings()
.iter()
.map(|(k, v)| KeyValuePair {
key: k.to_owned(),
value: v.to_owned(),
})
.collect::<Vec<_>>(),
})
.await
.map_err(|e| DataFusionError::Execution(format!("{e:?}")))?
.into_inner()
.session_id;
info!(
"Server side SessionContext created with session id: {}",
remote_session_id
);
let context = create_df_ctx_with_kapot_query_planner::<LogicalPlanNode>(
scheduler_url,
remote_session_id,
state.config(),
);
Ok(Self {
state: Arc::new(Mutex::new(state)),
context: Arc::new(context),
})
}
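/// Starts an in-process scheduler and executor and returns a context bound
/// to them; `concurrent_tasks` controls how many task slots the embedded
/// executor gets. Intended for tests and local development.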
#[cfg(feature = "standalone")]
pub async fn standalone(
config: &KapotConfig,
concurrent_tasks: usize,
) -> kapot_core::error::Result<Self> {
use datafusion_proto::protobuf::PhysicalPlanNode;
use kapot_core::serde::KapotCodec;
info!("Running in standalone mode; scheduler will run in-process");
let addr = kapot_scheduler::standalone::new_standalone_scheduler().await?;
let scheduler_url = format!("http://localhost:{}", addr.port());
// Retry until the in-process scheduler is ready to accept connections.
let mut scheduler = loop {
match SchedulerGrpcClient::connect(scheduler_url.clone()).await {
Err(_) => {
info!("Waiting for in-process scheduler to start...");
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
}
Ok(scheduler) => break scheduler,
}
};
let remote_session_id = scheduler
.create_session(CreateSessionParams {
settings: config
.settings()
.iter()
.map(|(k, v)| KeyValuePair {
key: k.to_owned(),
value: v.to_owned(),
})
.collect::<Vec<_>>(),
})
.await
.map_err(|e| DataFusionError::Execution(format!("{e:?}")))?
.into_inner()
.session_id;
info!(
"Server side SessionContext created with session id: {}",
remote_session_id
);
let context = create_df_ctx_with_kapot_query_planner::<LogicalPlanNode>(
scheduler_url,
remote_session_id,
config,
);
let default_codec: KapotCodec<LogicalPlanNode, PhysicalPlanNode> =
KapotCodec::default();
kapot_executor::new_standalone_executor(
scheduler,
concurrent_tasks,
default_codec,
)
.await?;
let state =
KapotContextState::new("localhost".to_string(), addr.port(), config);
Ok(Self {
state: Arc::new(Mutex::new(state)),
context: Arc::new(context),
})
}
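/// Reads newline-delimited JSON from `paths` into a [`DataFrame`].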
pub async fn read_json<P: DataFilePaths>(
&self,
paths: P,
options: NdJsonReadOptions<'_>,
) -> Result<DataFrame> {
let df = self.context.read_json(paths, options).await?;
Ok(df)
}
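/// Reads Avro data from `paths` into a [`DataFrame`].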
pub async fn read_avro<P: DataFilePaths>(
&self,
paths: P,
options: AvroReadOptions<'_>,
) -> Result<DataFrame> {
let df = self.context.read_avro(paths, options).await?;
Ok(df)
}
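/// Reads Parquet data from `paths` into a [`DataFrame`].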
pub async fn read_parquet<P: DataFilePaths>(
&self,
paths: P,
options: ParquetReadOptions<'_>,
) -> Result<DataFrame> {
let df = self.context.read_parquet(paths, options).await?;
Ok(df)
}
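/// Reads CSV data from `paths` into a [`DataFrame`].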
pub async fn read_csv<P: DataFilePaths>(
&self,
paths: P,
options: CsvReadOptions<'_>,
) -> Result<DataFrame> {
let df = self.context.read_csv(paths, options).await?;
Ok(df)
}
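/// Registers a [`TableProvider`] under `name` in the client-side table
/// registry; the table is attached to the DataFusion context lazily when
/// [`Self::sql`] runs.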
pub fn register_table(
&self,
name: &str,
table: Arc<dyn TableProvider>,
) -> Result<()> {
let mut state = self.state.lock();
state.tables.insert(name.to_owned(), table);
Ok(())
}
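/// Reads the CSV data at `path` and registers its table provider under `name`.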
pub async fn register_csv(
&self,
name: &str,
path: &str,
options: CsvReadOptions<'_>,
) -> Result<()> {
let plan = self
.read_csv(path, options)
.await
.map_err(|e| {
DataFusionError::Context(format!("Can't read CSV: {path}"), Box::new(e))
})?
.into_optimized_plan()?;
match plan {
LogicalPlan::TableScan(TableScan { source, .. }) => {
self.register_table(name, source_as_provider(&source)?)
}
_ => Err(DataFusionError::Internal("Expected table scan".to_owned())),
}
}
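/// Reads the Parquet data at `path` and registers its table provider under `name`.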
pub async fn register_parquet(
&self,
name: &str,
path: &str,
options: ParquetReadOptions<'_>,
) -> Result<()> {
match self
.read_parquet(path, options)
.await?
.into_optimized_plan()?
{
LogicalPlan::TableScan(TableScan { source, .. }) => {
self.register_table(name, source_as_provider(&source)?)
}
_ => Err(DataFusionError::Internal("Expected table scan".to_owned())),
}
}
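/// Reads the Avro data at `path` and registers its table provider under `name`.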
pub async fn register_avro(
&self,
name: &str,
path: &str,
options: AvroReadOptions<'_>,
) -> Result<()> {
match self.read_avro(path, options).await?.into_optimized_plan()? {
LogicalPlan::TableScan(TableScan { source, .. }) => {
self.register_table(name, source_as_provider(&source)?)
}
_ => Err(DataFusionError::Internal("Expected table scan".to_owned())),
}
}
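/// Returns `true` when `sql` is a single SHOW statement (`SHOW VARIABLE`,
/// `SHOW COLUMNS`, or `SHOW TABLES`). Errors if the string contains more
/// than one statement.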
pub async fn is_show_statement(&self, sql: &str) -> Result<bool> {
let statements = DFParser::parse_sql(sql)?;
if statements.len() != 1 {
return Err(DataFusionError::NotImplemented(
"The context currently only supports a single SQL statement".to_string(),
));
}
if let DFStatement::Statement(statement) = &statements[0] {
Ok(matches!(
**statement,
Statement::ShowVariable { .. }
| Statement::ShowColumns { .. }
| Statement::ShowTables { .. }
))
} else {
Ok(false)
}
}
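/// Returns the wrapped DataFusion [`SessionContext`].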
pub fn context(&self) -> &SessionContext {
&self.context
}
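/// Plans and executes a SQL statement. SHOW statements are served from a
/// local context, `CREATE EXTERNAL TABLE` statements register the referenced
/// file with this context, and all other statements are planned and executed
/// via the kapot query planner.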
pub async fn sql(&self, sql: &str) -> Result<DataFrame> {
let mut ctx = self.context.clone();
let is_show = self.is_show_statement(sql).await?;
if is_show {
// SHOW statements are answered from a fresh local SessionContext, with
// information_schema enabled according to the session configuration.
let state = self.state.lock();
ctx = Arc::new(SessionContext::new_with_config(
SessionConfig::new().with_information_schema(
state.config.default_with_information_schema(),
),
));
}
{
// Expose client-side registered tables to the planning context.
let state = self.state.lock();
for (name, prov) in &state.tables {
let table_ref = TableReference::Bare {
table: name.as_str().into(),
};
if !ctx.table_exist(table_ref.clone())? {
ctx.register_table(table_ref, Arc::clone(prov))?;
}
}
}
let plan = ctx.state().create_logical_plan(sql).await?;
match &plan {
LogicalPlan::Ddl(DdlStatement::CreateExternalTable(
CreateExternalTable {
ref schema,
ref name,
ref location,
ref file_type,
ref table_partition_cols,
ref if_not_exists,
options,
..
},
)) => {
let table_exists = ctx.table_exist(name.to_owned())?;
let schema: SchemaRef = Arc::new(schema.as_ref().to_owned().into());
let table_partition_cols = table_partition_cols
.iter()
.map(|col| {
schema
.field_with_name(col)
.map(|f| (f.name().to_owned(), f.data_type().to_owned()))
.map_err(|e| DataFusionError::ArrowError(e, None))
})
.collect::<Result<Vec<_>>>()?;
match (if_not_exists, table_exists) {
(_, false) => match file_type.to_lowercase().as_str() {
"csv" => {
let has_header = options
.get("format.has_header")
.map(|s| s.parse::<bool>())
.transpose()
.map_err(|e| {
DataFusionError::Execution(format!(
"Invalid 'format.has_header' value: {e}"
))
})?
.unwrap_or(false);
let delimiter = options
.get("format.delimiter")
.and_then(|s| s.chars().next())
.unwrap_or(',');
let mut options = CsvReadOptions::new()
.has_header(has_header)
.delimiter(delimiter as u8)
.table_partition_cols(table_partition_cols.to_vec());
if !schema.fields().is_empty() {
options = options.schema(&schema);
}
self.register_csv(name.table(), location, options).await?;
Ok(DataFrame::new(ctx.state(), plan))
}
"parquet" => {
self.register_parquet(
name.table(),
location,
ParquetReadOptions::default()
.table_partition_cols(table_partition_cols),
)
.await?;
Ok(DataFrame::new(ctx.state(), plan))
}
"avro" => {
self.register_avro(
name.table(),
location,
AvroReadOptions::default()
.table_partition_cols(table_partition_cols),
)
.await?;
Ok(DataFrame::new(ctx.state(), plan))
}
_ => Err(DataFusionError::NotImplemented(format!(
"Unsupported file type {file_type:?}."
))),
},
(true, true) => Ok(DataFrame::new(ctx.state(), plan)),
(false, true) => Err(DataFusionError::Execution(format!(
"Table '{name}' already exists"
))),
}
}
_ => ctx.execute_logical_plan(plan).await,
}
}
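/// Executes a pre-built [`LogicalPlan`] using this context's DataFusion session.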
pub async fn execute_logical_plan(&self, plan: LogicalPlan) -> Result<DataFrame> {
self.context.execute_logical_plan(plan).await
}
}
#[cfg(test)]
#[cfg(feature = "standalone")]
mod standalone_tests {
use datafusion::arrow;
use datafusion::arrow::util::pretty::pretty_format_batches;
use crate::context::KapotContext;
use kapot_core::config::{
KapotConfigBuilder, KAPOT_WITH_INFORMATION_SCHEMA,
};
use kapot_core::error::Result;
use datafusion::config::TableParquetOptions;
use datafusion::dataframe::DataFrameWriteOptions;
use datafusion::datasource::listing::ListingTableUrl;
use datafusion::prelude::ParquetReadOptions;
use tempfile::TempDir;
#[tokio::test]
async fn test_standalone_mode() {
use super::*;
let context = KapotContext::standalone(&KapotConfig::new().unwrap(), 1)
.await
.unwrap();
let df = context.sql("SELECT 1;").await.unwrap();
df.collect().await.unwrap();
}
#[tokio::test]
async fn test_write_parquet() -> Result<()> {
use super::*;
let context =
KapotContext::standalone(&KapotConfig::new().unwrap(), 1).await?;
let df = context.sql("SELECT 1;").await?;
let tmp_dir = TempDir::new().unwrap();
let file_path = format!(
"{}",
tmp_dir.path().join("test_write_parquet.parquet").display()
);
df.write_parquet(
&file_path,
DataFrameWriteOptions::default(),
Some(TableParquetOptions::default()),
)
.await?;
Ok(())
}
#[tokio::test]
async fn test_write_csv() -> Result<()> {
use super::*;
let context =
KapotContext::standalone(&KapotConfig::new().unwrap(), 1).await?;
let df = context.sql("SELECT 1;").await?;
let tmp_dir = TempDir::new().unwrap();
let file_path =
format!("{}", tmp_dir.path().join("test_write_csv.csv").display());
df.write_csv(&file_path, DataFrameWriteOptions::default(), None)
.await?;
Ok(())
}
#[tokio::test]
async fn test_kapot_show_tables() {
use super::*;
use std::fs::File;
use std::io::Write;
let context = KapotContext::standalone(&KapotConfig::new().unwrap(), 1)
.await
.unwrap();
let data = "Jorge,2018-12-13T12:12:10.011Z\n\
Andrew,2018-11-13T17:11:10.011Z";
let tmp_dir = TempDir::new().unwrap();
let file_path = tmp_dir.path().join("timestamps.csv");
{
File::create(&file_path)
.expect("creating temp file")
.write_all(data.as_bytes())
.expect("writing data");
}
let sql = format!(
"CREATE EXTERNAL TABLE csv_with_timestamps (
name VARCHAR,
ts TIMESTAMP
)
STORED AS CSV
LOCATION '{}'
OPTIONS ('has_header' 'false', 'delimiter' ',')
",
file_path.to_str().expect("path is utf8")
);
context.sql(sql.as_str()).await.unwrap();
let df = context.sql("show columns from csv_with_timestamps;").await;
assert!(df.is_err());
}
#[tokio::test]
async fn test_show_tables_with_information_schema() {
use super::*;
use std::fs::File;
use std::io::Write;
let config = KapotConfigBuilder::default()
.set(KAPOT_WITH_INFORMATION_SCHEMA, "true")
.build()
.unwrap();
let context = KapotContext::standalone(&config, 1).await.unwrap();
let data = "Jorge,2018-12-13T12:12:10.011Z\n\
Andrew,2018-11-13T17:11:10.011Z";
let tmp_dir = TempDir::new().unwrap();
let file_path = tmp_dir.path().join("timestamps.csv");
{
File::create(&file_path)
.expect("creating temp file")
.write_all(data.as_bytes())
.expect("writing data");
}
let sql = format!(
"CREATE EXTERNAL TABLE csv_with_timestamps (
name VARCHAR,
ts TIMESTAMP
)
STORED AS CSV
LOCATION '{}'
",
file_path.to_str().expect("path is utf8")
);
context.sql(sql.as_str()).await.unwrap();
let df = context.sql("show tables;").await;
assert!(df.is_ok());
}
#[tokio::test]
#[ignore]
async fn test_task_stuck_when_referenced_task_failed() {
use super::*;
use datafusion::arrow::datatypes::Schema;
use datafusion::arrow::util::pretty;
use datafusion::datasource::file_format::csv::CsvFormat;
use datafusion::datasource::listing::{
ListingOptions, ListingTable, ListingTableConfig,
};
let config = KapotConfigBuilder::default()
.set(KAPOT_WITH_INFORMATION_SCHEMA, "true")
.build()
.unwrap();
let context = KapotContext::standalone(&config, 1).await.unwrap();
context
.register_parquet(
"single_nan",
"testdata/single_nan.parquet",
ParquetReadOptions::default(),
)
.await
.unwrap();
{
// Swap the registered Parquet table's format for CSV so that scans fail,
// reproducing the scenario where a referenced task fails.
let mut guard = context.state.lock();
let table_provider = guard.tables.get("single_nan");
if let Some(table_provider) = table_provider {
if let Some(listing_table) = table_provider
.clone()
.as_any()
.downcast_ref::<ListingTable>()
{
let x = listing_table.options();
let error_options = ListingOptions {
file_extension: x.file_extension.clone(),
format: Arc::new(CsvFormat::default()),
table_partition_cols: x.table_partition_cols.clone(),
collect_stat: x.collect_stat,
target_partitions: x.target_partitions,
file_sort_order: vec![],
};
let table_paths = listing_table
.table_paths()
.iter()
.map(|t| ListingTableUrl::parse(t).unwrap())
.collect();
let config = ListingTableConfig::new_with_multi_paths(table_paths)
.with_schema(Arc::new(Schema::empty()))
.with_listing_options(error_options);
let error_table = ListingTable::try_new(config).unwrap();
guard
.tables
.insert("single_nan".to_string(), Arc::new(error_table));
}
}
}
let df = context
.sql("select count(1) from single_nan;")
.await
.unwrap();
let results = df.collect().await.unwrap();
pretty::print_batches(&results).unwrap();
}
#[tokio::test]
async fn test_empty_exec_with_one_row() {
let config = KapotConfigBuilder::default()
.set(KAPOT_WITH_INFORMATION_SCHEMA, "true")
.build()
.unwrap();
let context = KapotContext::standalone(&config, 1).await.unwrap();
let sql = "select EXTRACT(year FROM to_timestamp('2020-09-08T12:13:14+00:00'));";
let df = context.sql(sql).await.unwrap();
assert!(!df.collect().await.unwrap().is_empty());
}
#[tokio::test]
async fn test_union_and_union_all() {
use super::*;
let config = KapotConfigBuilder::default()
.set(KAPOT_WITH_INFORMATION_SCHEMA, "true")
.build()
.unwrap();
let context = KapotContext::standalone(&config, 1).await.unwrap();
let df = context
.sql("SELECT 1 as NUMBER union SELECT 1 as NUMBER;")
.await
.unwrap();
let res1 = df.collect().await.unwrap();
let expected1 = vec![
"+--------+",
"| number |",
"+--------+",
"| 1 |",
"+--------+",
];
assert_eq!(
expected1,
pretty_format_batches(&res1)
.unwrap()
.to_string()
.trim()
.lines()
.collect::<Vec<&str>>()
);
let expected2 = vec![
"+--------+",
"| number |",
"+--------+",
"| 1 |",
"| 1 |",
"+--------+",
];
let df = context
.sql("SELECT 1 as NUMBER union all SELECT 1 as NUMBER;")
.await
.unwrap();
let res2 = df.collect().await.unwrap();
assert_eq!(
expected2,
pretty_format_batches(&res2)
.unwrap()
.to_string()
.trim()
.lines()
.collect::<Vec<&str>>()
);
}
#[tokio::test]
async fn test_aggregate_min_max() {
let context = create_test_context().await;
let df = context.sql("select min(\"id\") from test").await.unwrap();
let res = df.collect().await.unwrap();
let expected = vec![
"+--------------+",
"| min(test.id) |",
"+--------------+",
"| 0 |",
"+--------------+",
];
assert_result_eq(expected, &res);
let df = context.sql("select max(\"id\") from test").await.unwrap();
let res = df.collect().await.unwrap();
let expected = vec![
"+--------------+",
"| max(test.id) |",
"+--------------+",
"| 7 |",
"+--------------+",
];
assert_result_eq(expected, &res);
}
#[tokio::test]
async fn test_aggregate_sum() {
let context = create_test_context().await;
let df = context.sql("select SUM(\"id\") from test").await.unwrap();
let res = df.collect().await.unwrap();
let expected = vec![
"+--------------+",
"| sum(test.id) |",
"+--------------+",
"| 28 |",
"+--------------+",
];
assert_result_eq(expected, &res);
}
#[tokio::test]
async fn test_aggregate_avg() {
let context = create_test_context().await;
let df = context.sql("select AVG(\"id\") from test").await.unwrap();
let res = df.collect().await.unwrap();
let expected = vec![
"+--------------+",
"| avg(test.id) |",
"+--------------+",
"| 3.5 |",
"+--------------+",
];
assert_result_eq(expected, &res);
}
#[tokio::test]
async fn test_aggregate_count() {
let context = create_test_context().await;
let df = context.sql("select COUNT(\"id\") from test").await.unwrap();
let res = df.collect().await.unwrap();
let expected = vec![
"+----------------+",
"| count(test.id) |",
"+----------------+",
"| 8 |",
"+----------------+",
];
assert_result_eq(expected, &res);
}
#[tokio::test]
async fn test_aggregate_approx_distinct() {
let context = create_test_context().await;
let df = context
.sql("select approx_distinct(\"id\") from test")
.await
.unwrap();
let res = df.collect().await.unwrap();
let expected = vec![
"+--------------------------+",
"| approx_distinct(test.id) |",
"+--------------------------+",
"| 8 |",
"+--------------------------+",
];
assert_result_eq(expected, &res);
}
#[tokio::test]
async fn test_aggregate_array_agg() {
let context = create_test_context().await;
let df = context
.sql("select ARRAY_AGG(\"id\") from test")
.await
.unwrap();
let res = df.collect().await.unwrap();
let expected = vec![
"+--------------------------+",
"| array_agg(test.id) |",
"+--------------------------+",
"| [4, 5, 6, 7, 2, 3, 0, 1] |",
"+--------------------------+",
];
assert_result_eq(expected, &res);
}
#[tokio::test]
async fn test_aggregate_var() {
let context = create_test_context().await;
let df = context.sql("select VAR(\"id\") from test").await.unwrap();
let res = df.collect().await.unwrap();
let expected = vec![
"+-------------------+",
"| var(test.id) |",
"+-------------------+",
"| 6.000000000000001 |",
"+-------------------+",
];
assert_result_eq(expected, &res);
let df = context
.sql("select VAR_POP(\"id\") from test")
.await
.unwrap();
let res = df.collect().await.unwrap();
let expected = vec![
"+-------------------+",
"| var_pop(test.id) |",
"+-------------------+",
"| 5.250000000000001 |",
"+-------------------+",
];
assert_result_eq(expected, &res);
let df = context
.sql("select VAR_SAMP(\"id\") from test")
.await
.unwrap();
let res = df.collect().await.unwrap();
let expected = vec![
"+-------------------+",
"| var(test.id) |",
"+-------------------+",
"| 6.000000000000001 |",
"+-------------------+",
];
assert_result_eq(expected, &res);
}
#[tokio::test]
async fn test_aggregate_stddev() {
let context = create_test_context().await;
let df = context
.sql("select STDDEV(\"id\") from test")
.await
.unwrap();
let res = df.collect().await.unwrap();
let expected = vec![
"+--------------------+",
"| stddev(test.id) |",
"+--------------------+",
"| 2.4494897427831783 |",
"+--------------------+",
];
assert_result_eq(expected, &res);
let df = context
.sql("select STDDEV_SAMP(\"id\") from test")
.await
.unwrap();
let res = df.collect().await.unwrap();
let expected = vec![
"+--------------------+",
"| stddev(test.id) |",
"+--------------------+",
"| 2.4494897427831783 |",
"+--------------------+",
];
assert_result_eq(expected, &res);
}
#[tokio::test]
async fn test_aggregate_covar() {
let context = create_test_context().await;
let df = context
.sql("select COVAR(id, tinyint_col) from test")
.await
.unwrap();
let res = df.collect().await.unwrap();
let expected = vec![
"+--------------------------------------+",
"| covar_samp(test.id,test.tinyint_col) |",
"+--------------------------------------+",
"| 0.28571428571428586 |",
"+--------------------------------------+",
];
assert_result_eq(expected, &res);
}
#[tokio::test]
async fn test_aggregate_correlation() {
let context = create_test_context().await;
let df = context
.sql("select CORR(id, tinyint_col) from test")
.await
.unwrap();
let res = df.collect().await.unwrap();
let expected = vec![
"+--------------------------------+",
"| corr(test.id,test.tinyint_col) |",
"+--------------------------------+",
"| 0.21821789023599245 |",
"+--------------------------------+",
];
assert_result_eq(expected, &res);
}
#[ignore]
#[tokio::test]
async fn test_aggregate_approx_percentile() {
let context = create_test_context().await;
let df = context
.sql("select approx_percentile_cont_with_weight(id, 2, 0.5) from test")
.await
.unwrap();
let res = df.collect().await.unwrap();
let expected = vec![
"+-------------------------------------------------------------------+",
"| approx_percentile_cont_with_weight(test.id,Int64(2),Float64(0.5)) |",
"+-------------------------------------------------------------------+",
"| 1 |",
"+-------------------------------------------------------------------+",
];
assert_result_eq(expected, &res);
let df = context
.sql("select approx_percentile_cont(\"double_col\", 0.5) from test")
.await
.unwrap();
let res = df.collect().await.unwrap();
let expected = vec![
"+------------------------------------------------------+",
"| approx_percentile_cont(test.double_col,Float64(0.5)) |",
"+------------------------------------------------------+",
"| 7.574999999999999 |",
"+------------------------------------------------------+",
];
assert_result_eq(expected, &res);
}
fn assert_result_eq(
expected: Vec<&str>,
results: &[arrow::record_batch::RecordBatch],
) {
assert_eq!(
expected,
pretty_format_batches(results)
.unwrap()
.to_string()
.trim()
.lines()
.collect::<Vec<&str>>()
);
}
async fn create_test_context() -> KapotContext {
let config = KapotConfigBuilder::default()
.set(KAPOT_WITH_INFORMATION_SCHEMA, "true")
.build()
.unwrap();
let context = KapotContext::standalone(&config, 4).await.unwrap();
context
.register_parquet(
"test",
"testdata/alltypes_plain.parquet",
ParquetReadOptions::default(),
)
.await
.unwrap();
context
}
}