Skip to main content

tl_data/
engine.rs

1use datafusion::arrow::array::RecordBatch;
2use datafusion::arrow::util::pretty::pretty_format_batches;
3use datafusion::execution::context::SessionContext;
4use datafusion::execution::disk_manager::DiskManagerConfig;
5use datafusion::execution::memory_pool::FairSpillPool;
6use datafusion::execution::runtime_env::RuntimeEnvBuilder;
7use datafusion::prelude::*;
8use std::sync::Arc;
9use tokio::runtime::Runtime;
10
/// Configuration for the `DataEngine`.
#[derive(Debug, Clone)]
pub struct DataEngineConfig {
    /// Maximum memory in bytes for the DataFusion pool (default: 512MB).
    pub max_memory_bytes: usize,
    /// Enable spill-to-disk when memory limit is reached.
    pub spill_to_disk: bool,
    /// Directory for spill files (default: system temp dir).
    pub spill_path: Option<String>,
}

impl Default for DataEngineConfig {
    /// 512 MB memory limit, spilling enabled, OS-chosen spill directory.
    fn default() -> Self {
        DataEngineConfig {
            max_memory_bytes: 512 * 1024 * 1024, // 512 MB
            spill_to_disk: true,
            spill_path: None,
        }
    }
}
30
/// Synchronous wrapper around DataFusion's async SessionContext.
///
/// Owns a dedicated tokio `Runtime` so callers can drive DataFusion's
/// async APIs from synchronous code via `block_on` (see `collect`/`sql`).
pub struct DataEngine {
    // The DataFusion session: registered tables, SQL planner, runtime env.
    pub ctx: SessionContext,
    // Shared tokio runtime used to block on DataFusion futures.
    pub rt: Arc<Runtime>,
}
36
37impl Default for DataEngine {
38    fn default() -> Self {
39        Self::new()
40    }
41}
42
43impl DataEngine {
44    /// Create a new DataEngine with default configuration.
45    /// Backward-compatible with existing code.
46    pub fn new() -> Self {
47        Self::with_config(DataEngineConfig::default())
48    }
49
50    /// Create a new DataEngine with custom configuration.
51    pub fn with_config(config: DataEngineConfig) -> Self {
52        let rt = Arc::new(Runtime::new().expect("Failed to create tokio runtime for DataEngine"));
53
54        // Build runtime environment with memory pool and disk manager
55        let pool = FairSpillPool::new(config.max_memory_bytes);
56
57        let mut rt_builder = RuntimeEnvBuilder::new().with_memory_pool(Arc::new(pool));
58
59        if config.spill_to_disk {
60            let disk_config = if let Some(ref path) = config.spill_path {
61                DiskManagerConfig::new_specified(vec![path.clone().into()])
62            } else {
63                DiskManagerConfig::NewOs
64            };
65            rt_builder = rt_builder.with_disk_manager(disk_config);
66        }
67
68        let runtime_env = rt_builder.build().expect("Failed to build RuntimeEnv");
69
70        // Configure session with parallelism
71        let target_partitions = num_cpus::get();
72        let session_config = SessionConfig::new().with_target_partitions(target_partitions);
73
74        let ctx = SessionContext::new_with_config_rt(session_config, Arc::new(runtime_env));
75
76        DataEngine { ctx, rt }
77    }
78
79    /// Execute a DataFusion DataFrame and collect results synchronously.
80    pub fn collect(&self, df: DataFrame) -> Result<Vec<RecordBatch>, String> {
81        self.rt
82            .block_on(df.collect())
83            .map_err(|e| format!("DataFusion collect error: {e}"))
84    }
85
86    /// Format collected batches as a pretty table string.
87    pub fn format_batches(batches: &[RecordBatch]) -> Result<String, String> {
88        pretty_format_batches(batches)
89            .map(|t| t.to_string())
90            .map_err(|e| format!("Format error: {e}"))
91    }
92
93    /// Register a RecordBatch as a named table in the session.
94    pub fn register_batch(&self, name: &str, batch: RecordBatch) -> Result<(), String> {
95        let schema = batch.schema();
96        let provider = datafusion::datasource::MemTable::try_new(schema, vec![vec![batch]])
97            .map_err(|e| format!("MemTable error: {e}"))?;
98        self.ctx
99            .register_table(name, Arc::new(provider))
100            .map_err(|e| format!("Register table error: {e}"))?;
101        Ok(())
102    }
103
104    /// Run a SQL query and return results.
105    pub fn sql(&self, query: &str) -> Result<DataFrame, String> {
106        self.rt
107            .block_on(self.ctx.sql(query))
108            .map_err(|e| format!("SQL error: {e}"))
109    }
110
111    /// Get the underlying session context for DataFusion operations.
112    pub fn session_ctx(&self) -> &SessionContext {
113        &self.ctx
114    }
115
116    /// Get a reference to the tokio Runtime.
117    pub fn runtime(&self) -> &Arc<Runtime> {
118        &self.rt
119    }
120}
121
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion::arrow::array::{Int64Array, StringArray};
    use datafusion::arrow::datatypes::{DataType, Field, Schema};

    /// End-to-end: register a small table, filter it with SQL, count rows.
    #[test]
    fn test_engine_basic() {
        let engine = DataEngine::new();

        let fields = vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, false),
        ];
        let ids = Int64Array::from(vec![1, 2, 3]);
        let names = StringArray::from(vec!["Alice", "Bob", "Charlie"]);
        let batch = RecordBatch::try_new(
            Arc::new(Schema::new(fields)),
            vec![Arc::new(ids), Arc::new(names)],
        )
        .unwrap();

        engine.register_batch("test_table", batch).unwrap();
        let df = engine.sql("SELECT * FROM test_table WHERE id > 1").unwrap();
        let rows = engine.collect(df).unwrap();
        assert_eq!(rows[0].num_rows(), 2);
    }

    /// Same round-trip through an engine built from an explicit config.
    #[test]
    fn test_engine_with_config() {
        let engine = DataEngine::with_config(DataEngineConfig {
            max_memory_bytes: 256 * 1024 * 1024,
            spill_to_disk: true,
            spill_path: None,
        });

        let schema = Schema::new(vec![Field::new("x", DataType::Int64, false)]);
        let column = Int64Array::from(vec![1, 2, 3]);
        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(column)]).unwrap();

        engine.register_batch("t", batch).unwrap();
        let df = engine.sql("SELECT * FROM t").unwrap();
        let rows = engine.collect(df).unwrap();
        assert_eq!(rows[0].num_rows(), 3);
    }
}