use crate::prelude::*;
use datafusion::datasource::TableProvider;
use datafusion::execution::context::{SessionConfig, SessionContext};
use datafusion::execution::memory_pool::{FairSpillPool, MemoryPool};
use datafusion::execution::runtime_env::RuntimeEnvBuilder;
use std::collections::HashMap;
use std::sync::Arc;
use tracing::instrument;
/// Tuning knobs for building a [`TermContext`] session.
#[derive(Debug, Clone)]
pub struct TermContextConfig {
    /// Number of rows per record batch produced by operators.
    pub batch_size: usize,
    /// Target number of output partitions for parallel execution.
    pub target_partitions: usize,
    /// Upper bound, in bytes, for the session's memory pool.
    pub max_memory: usize,
    /// Fraction of `max_memory` intended to be made available to the
    /// memory pool (leaving headroom for untracked allocations).
    pub memory_fraction: f64,
}
impl Default for TermContextConfig {
    /// Sensible defaults: 8192-row batches, one partition per available
    /// CPU core (falling back to 4 when parallelism cannot be queried),
    /// a 2 GiB memory cap, and 90% of that cap usable by the pool.
    fn default() -> Self {
        let partitions = match std::thread::available_parallelism() {
            Ok(n) => n.get(),
            Err(_) => 4,
        };
        Self {
            batch_size: 8192,
            target_partitions: partitions,
            max_memory: 2 * 1024 * 1024 * 1024,
            memory_fraction: 0.9,
        }
    }
}
/// Wrapper around a DataFusion [`SessionContext`] that additionally tracks
/// the table providers registered through it, keyed by table name.
pub struct TermContext {
    /// The underlying DataFusion session.
    inner: SessionContext,
    /// Providers registered via this wrapper; kept in sync with `inner`
    /// by the register/deregister methods.
    pub(crate) tables: HashMap<String, Arc<dyn TableProvider>>,
    /// The configuration this context was built with (for introspection).
    config: TermContextConfig,
}
impl TermContext {
    /// Creates a context using [`TermContextConfig::default`].
    ///
    /// # Errors
    /// Returns an error if the DataFusion runtime environment cannot be built.
    #[instrument]
    pub fn new() -> Result<Self> {
        Self::with_config(TermContextConfig::default())
    }

    /// Creates a context from an explicit `config`.
    ///
    /// Builds a [`SessionContext`] with the configured batch size and
    /// partition count, a [`FairSpillPool`] memory pool, and the system
    /// temp directory as the spill location.
    ///
    /// # Errors
    /// Returns an error if the runtime environment cannot be built.
    #[instrument(skip(config))]
    pub fn with_config(config: TermContextConfig) -> Result<Self> {
        let session_config = SessionConfig::new()
            .with_batch_size(config.batch_size)
            .with_target_partitions(config.target_partitions)
            .with_information_schema(true);
        // Scale the hard cap by `memory_fraction` so the pool leaves
        // headroom below `max_memory`. (Previously the fraction field was
        // accepted but never applied — the pool was sized from the raw cap.)
        let pool_bytes = (config.max_memory as f64 * config.memory_fraction) as usize;
        let memory_pool = Arc::new(FairSpillPool::new(pool_bytes)) as Arc<dyn MemoryPool>;
        let runtime_env = RuntimeEnvBuilder::new()
            .with_memory_pool(memory_pool)
            .with_temp_file_path(std::env::temp_dir())
            .build()
            .map(Arc::new)?;
        let inner = SessionContext::new_with_config_rt(session_config, runtime_env);
        Ok(Self {
            inner,
            tables: HashMap::new(),
            config,
        })
    }

    /// Shared borrow of the underlying DataFusion session.
    pub fn inner(&self) -> &SessionContext {
        &self.inner
    }

    /// Mutable borrow of the underlying DataFusion session.
    pub fn inner_mut(&mut self) -> &mut SessionContext {
        &mut self.inner
    }

    /// The configuration this context was created with.
    pub fn config(&self) -> &TermContextConfig {
        &self.config
    }

    /// Names of all tables registered through this wrapper.
    pub fn registered_tables(&self) -> Vec<&str> {
        self.tables.keys().map(|s| s.as_str()).collect()
    }

    /// Whether a table named `name` was registered through this wrapper.
    pub fn has_table(&self, name: &str) -> bool {
        self.tables.contains_key(name)
    }

    /// Resolves the provider the session holds for `name` and records it
    /// in the local table map, keeping `tables` in sync with `inner`.
    async fn track_table(&mut self, name: &str) -> Result<()> {
        let source = self.inner.table_provider(name).await?;
        self.tables.insert(name.to_string(), source);
        Ok(())
    }

    /// Registers the CSV file at `path` as table `name`, using default
    /// CSV read options.
    ///
    /// # Errors
    /// Fails if registration or provider resolution fails.
    #[instrument(skip(self))]
    pub async fn register_csv(&mut self, name: &str, path: &str) -> Result<()> {
        self.inner
            .register_csv(name, path, Default::default())
            .await?;
        self.track_table(name).await
    }

    /// Registers the Parquet file at `path` as table `name`, using default
    /// Parquet read options.
    ///
    /// # Errors
    /// Fails if registration or provider resolution fails.
    #[instrument(skip(self))]
    pub async fn register_parquet(&mut self, name: &str, path: &str) -> Result<()> {
        self.inner
            .register_parquet(name, path, Default::default())
            .await?;
        self.track_table(name).await
    }

    /// Removes `name` from the session and from the local table map.
    ///
    /// # Errors
    /// Fails if DataFusion cannot deregister the table; the local map is
    /// only updated after session deregistration succeeds.
    pub fn deregister_table(&mut self, name: &str) -> Result<()> {
        self.inner.deregister_table(name)?;
        self.tables.remove(name);
        Ok(())
    }

    /// Registers an arbitrary [`TableProvider`] under `name`.
    ///
    /// # Errors
    /// Fails if DataFusion rejects the registration.
    #[instrument(skip(self, provider))]
    pub async fn register_table_provider(
        &mut self,
        name: &str,
        provider: Arc<dyn TableProvider>,
    ) -> Result<()> {
        self.inner.register_table(name, provider.clone())?;
        self.tables.insert(name.to_string(), provider);
        Ok(())
    }

    /// Deregisters every table tracked by this wrapper.
    ///
    /// # Errors
    /// Stops and returns at the first table that fails to deregister.
    pub fn clear_tables(&mut self) -> Result<()> {
        let table_names: Vec<_> = self.tables.keys().cloned().collect();
        for name in table_names {
            self.deregister_table(&name)?;
        }
        Ok(())
    }
}
impl Drop for TermContext {
    /// Best-effort cleanup: deregister all tracked tables, logging rather
    /// than panicking on failure (panicking in `drop` is never acceptable).
    fn drop(&mut self) {
        match self.clear_tables() {
            Ok(()) => {}
            Err(e) => {
                tracing::warn!("Failed to clear tables during TermContext drop: {}", e);
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Int64Array, StringArray};
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use datafusion::datasource::MemTable;

    /// Builds a one-column (`id: Int64`) in-memory table holding `values`.
    fn int64_table(values: Vec<i64>) -> MemTable {
        let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)]));
        let batch =
            RecordBatch::try_new(schema.clone(), vec![Arc::new(Int64Array::from(values))])
                .unwrap();
        MemTable::try_new(schema, vec![vec![batch]]).unwrap()
    }

    #[test]
    fn test_default_config() {
        let config = TermContextConfig::default();
        assert_eq!(config.batch_size, 8192);
        assert_eq!(
            config.target_partitions,
            std::thread::available_parallelism()
                .map(|p| p.get())
                .unwrap_or(4)
        );
        assert_eq!(config.max_memory, 2 * 1024 * 1024 * 1024);
        // Exact comparison is fine: 0.9 is assigned verbatim, not computed.
        assert_eq!(config.memory_fraction, 0.9);
    }

    #[tokio::test]
    async fn test_context_creation() {
        let ctx = TermContext::new().unwrap();
        assert!(ctx.registered_tables().is_empty());
    }

    #[tokio::test]
    async fn test_context_with_custom_config() {
        let config = TermContextConfig {
            batch_size: 16384,
            max_memory: 4 * 1024 * 1024 * 1024,
            ..Default::default()
        };
        // `config` is not read afterwards, so move it instead of cloning.
        let ctx = TermContext::with_config(config).unwrap();
        assert_eq!(ctx.config().batch_size, 16384);
        assert_eq!(ctx.config().max_memory, 4 * 1024 * 1024 * 1024);
    }

    #[tokio::test]
    async fn test_table_registration() {
        let mut ctx = TermContext::new().unwrap();
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, false),
        ]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int64Array::from(vec![1, 2, 3])),
                Arc::new(StringArray::from(vec!["Alice", "Bob", "Charlie"])),
            ],
        )
        .unwrap();
        let table = MemTable::try_new(schema, vec![vec![batch]]).unwrap();
        ctx.register_table_provider("users", Arc::new(table))
            .await
            .unwrap();
        assert!(ctx.has_table("users"));
        assert_eq!(ctx.registered_tables(), vec!["users"]);
    }

    #[tokio::test]
    async fn test_table_deregistration() {
        let mut ctx = TermContext::new().unwrap();
        ctx.register_table_provider("test", Arc::new(int64_table(vec![1, 2, 3])))
            .await
            .unwrap();
        assert!(ctx.has_table("test"));
        ctx.deregister_table("test").unwrap();
        assert!(!ctx.has_table("test"));
        assert!(ctx.registered_tables().is_empty());
    }

    #[tokio::test]
    async fn test_clear_tables() {
        let mut ctx = TermContext::new().unwrap();
        for i in 0..3 {
            let name = format!("table{i}");
            ctx.register_table_provider(&name, Arc::new(int64_table(vec![i])))
                .await
                .unwrap();
        }
        assert_eq!(ctx.registered_tables().len(), 3);
        ctx.clear_tables().unwrap();
        assert!(ctx.registered_tables().is_empty());
    }
}