//! tl_data/engine.rs — synchronous `DataEngine` wrapper around DataFusion's
//! async `SessionContext`, with configurable memory pool and spill-to-disk.
use std::sync::Arc;

use datafusion::arrow::array::RecordBatch;
use datafusion::arrow::util::pretty::pretty_format_batches;
use datafusion::execution::context::SessionContext;
use datafusion::execution::disk_manager::DiskManagerConfig;
use datafusion::execution::memory_pool::FairSpillPool;
use datafusion::execution::runtime_env::RuntimeEnvBuilder;
use datafusion::prelude::*;
use tokio::runtime::Runtime;
10
/// Configuration for the DataEngine.
pub struct DataEngineConfig {
    /// Maximum memory in bytes for the DataFusion pool (default: 512MB).
    pub max_memory_bytes: usize,
    /// Enable spill-to-disk when memory limit is reached.
    pub spill_to_disk: bool,
    /// Directory for spill files (default: system temp dir).
    pub spill_path: Option<String>,
}

impl Default for DataEngineConfig {
    /// 512 MiB memory budget, spilling enabled into the OS temp directory.
    fn default() -> Self {
        Self {
            max_memory_bytes: 512 * 1024 * 1024, // 512 MB
            spill_to_disk: true,
            spill_path: None,
        }
    }
}
30
31/// Synchronous wrapper around DataFusion's async SessionContext.
32pub struct DataEngine {
33    pub ctx: SessionContext,
34    pub rt: Arc<Runtime>,
35}
36
37impl Default for DataEngine {
38    fn default() -> Self {
39        Self::new()
40    }
41}
42
43impl DataEngine {
44    /// Create a new DataEngine with default configuration.
45    /// Backward-compatible with existing code.
46    pub fn new() -> Self {
47        Self::with_config(DataEngineConfig::default())
48    }
49
50    /// Create a new DataEngine with custom configuration.
51    pub fn with_config(config: DataEngineConfig) -> Self {
52        let rt = Arc::new(Runtime::new().expect("Failed to create tokio runtime for DataEngine"));
53
54        // Build runtime environment with memory pool and disk manager
55        let pool = FairSpillPool::new(config.max_memory_bytes);
56
57        let mut rt_builder = RuntimeEnvBuilder::new().with_memory_pool(Arc::new(pool));
58
59        if config.spill_to_disk {
60            let disk_config = if let Some(ref path) = config.spill_path {
61                DiskManagerConfig::new_specified(vec![path.clone().into()])
62            } else {
63                DiskManagerConfig::NewOs
64            };
65            rt_builder = rt_builder.with_disk_manager(disk_config);
66        }
67
68        let runtime_env = rt_builder.build().expect("Failed to build RuntimeEnv");
69
70        // Configure session with parallelism
71        let target_partitions = num_cpus::get();
72        let session_config = SessionConfig::new().with_target_partitions(target_partitions);
73
74        let ctx = SessionContext::new_with_config_rt(session_config, Arc::new(runtime_env));
75
76        DataEngine { ctx, rt }
77    }
78
79    /// Execute a DataFusion DataFrame and collect results synchronously.
80    pub fn collect(&self, df: DataFrame) -> Result<Vec<RecordBatch>, String> {
81        self.rt
82            .block_on(df.collect())
83            .map_err(|e| format!("DataFusion collect error: {e}"))
84    }
85
86    /// Format collected batches as a pretty table string.
87    pub fn format_batches(batches: &[RecordBatch]) -> Result<String, String> {
88        pretty_format_batches(batches)
89            .map(|t| t.to_string())
90            .map_err(|e| format!("Format error: {e}"))
91    }
92
93    /// Register a RecordBatch as a named table in the session.
94    pub fn register_batch(&self, name: &str, batch: RecordBatch) -> Result<(), String> {
95        let schema = batch.schema();
96        let provider = datafusion::datasource::MemTable::try_new(schema, vec![vec![batch]])
97            .map_err(|e| format!("MemTable error: {e}"))?;
98        self.ctx
99            .register_table(name, Arc::new(provider))
100            .map_err(|e| format!("Register table error: {e}"))?;
101        Ok(())
102    }
103
104    /// Register multiple RecordBatches as a single named table.
105    /// Each batch becomes a separate partition, enabling DataFusion parallelism.
106    pub fn register_batches(
107        &self,
108        name: &str,
109        schema: Arc<datafusion::arrow::datatypes::Schema>,
110        batches: Vec<RecordBatch>,
111    ) -> Result<(), String> {
112        if batches.is_empty() {
113            // Register an empty table with the given schema
114            let provider = datafusion::datasource::MemTable::try_new(schema, vec![])
115                .map_err(|e| format!("MemTable error: {e}"))?;
116            self.ctx
117                .register_table(name, Arc::new(provider))
118                .map_err(|e| format!("Register table error: {e}"))?;
119            return Ok(());
120        }
121        // Each batch is its own partition for DataFusion parallelism
122        let partitions: Vec<Vec<RecordBatch>> = batches.into_iter().map(|b| vec![b]).collect();
123        let provider = datafusion::datasource::MemTable::try_new(schema, partitions)
124            .map_err(|e| format!("MemTable error: {e}"))?;
125        // Deregister previous table if it exists
126        let _ = self.ctx.deregister_table(name);
127        self.ctx
128            .register_table(name, Arc::new(provider))
129            .map_err(|e| format!("Register table error: {e}"))?;
130        Ok(())
131    }
132
133    /// Run a SQL query and return results.
134    pub fn sql(&self, query: &str) -> Result<DataFrame, String> {
135        self.rt
136            .block_on(self.ctx.sql(query))
137            .map_err(|e| format!("SQL error: {e}"))
138    }
139
140    /// Get the underlying session context for DataFusion operations.
141    pub fn session_ctx(&self) -> &SessionContext {
142        &self.ctx
143    }
144
145    /// Get a reference to the tokio Runtime.
146    pub fn runtime(&self) -> &Arc<Runtime> {
147        &self.rt
148    }
149}
150
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion::arrow::array::{Int64Array, StringArray};
    use datafusion::arrow::datatypes::{DataType, Field, Schema};

    /// Round-trip: register a two-column batch, filter it via SQL, count rows.
    #[test]
    fn test_engine_basic() {
        let eng = DataEngine::new();
        let fields = vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, false),
        ];
        let ids: Int64Array = vec![1i64, 2, 3].into();
        let names: StringArray = vec!["Alice", "Bob", "Charlie"].into();
        let batch = RecordBatch::try_new(
            Arc::new(Schema::new(fields)),
            vec![Arc::new(ids), Arc::new(names)],
        )
        .unwrap();

        eng.register_batch("test_table", batch).unwrap();
        let frame = eng.sql("SELECT * FROM test_table WHERE id > 1").unwrap();
        let rows = eng.collect(frame).unwrap();
        assert_eq!(rows[0].num_rows(), 2);
    }

    /// Same round-trip through the configurable constructor with a 256 MB pool.
    #[test]
    fn test_engine_with_config() {
        let eng = DataEngine::with_config(DataEngineConfig {
            max_memory_bytes: 256 * 1024 * 1024,
            spill_to_disk: true,
            spill_path: None,
        });
        let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, false)]));
        let col: Int64Array = vec![1i64, 2, 3].into();
        let batch = RecordBatch::try_new(schema, vec![Arc::new(col)]).unwrap();
        eng.register_batch("t", batch).unwrap();
        let frame = eng.sql("SELECT * FROM t").unwrap();
        let rows = eng.collect(frame).unwrap();
        assert_eq!(rows[0].num_rows(), 3);
    }
}