// micromegas_datafusion_wasm — lib.rs
use std::sync::{Arc, Once};

use arrow::array::RecordBatch;
use arrow::datatypes::Schema;
use arrow::ipc::reader::StreamReader;
use arrow::ipc::writer::StreamWriter;
use datafusion::execution::SessionStateBuilder;
use datafusion::physical_optimizer::PhysicalOptimizerRule;
use datafusion::physical_optimizer::optimizer::PhysicalOptimizer;
use datafusion::prelude::*;
use micromegas_tracing::prelude::*;
use wasm_bindgen::prelude::*;
/// Guards one-time telemetry initialization for the lifetime of the module.
static INIT: Once = Once::new();

/// Initialize the telemetry sink exactly once.
///
/// On non-wasm targets this still marks `INIT` as completed but performs
/// no work; on wasm32 it installs the telemetry guard and leaks it, since
/// the WASM module lives for the lifetime of the page.
fn ensure_tracing() {
    /// Target-specific body executed by `Once::call_once`.
    #[cfg(target_arch = "wasm32")]
    fn init_body() {
        let guard =
            micromegas_telemetry_sink::init_telemetry().expect("failed to init telemetry");
        // Intentionally leaked — WASM module lives for page lifetime.
        std::mem::forget(guard);
    }

    #[cfg(not(target_arch = "wasm32"))]
    fn init_body() {}

    INIT.call_once(init_body);
}
26
/// A client-side SQL query engine exported to JavaScript via wasm-bindgen.
///
/// Tables are registered from Arrow IPC bytes and held in memory inside the
/// wrapped DataFusion session; queries execute entirely within the WASM module.
#[wasm_bindgen]
pub struct WasmQueryEngine {
    // DataFusion session holding all registered tables and extension UDFs.
    ctx: SessionContext,
}
31
// `Default` mirrors the wasm-bindgen constructor so Rust-side callers
// (and clippy's `new_without_default` lint) are satisfied.
impl Default for WasmQueryEngine {
    fn default() -> Self {
        Self::new()
    }
}
37
38#[wasm_bindgen]
39impl WasmQueryEngine {
40    #[wasm_bindgen(constructor)]
41    pub fn new() -> Self {
42        ensure_tracing();
43        info!("WasmQueryEngine created");
44
45        // Work around a DataFusion 52.1 bug where the LimitPushdown physical
46        // optimizer rule removes GlobalLimitExec without actually pushing the
47        // fetch into DataSourceExec, causing LIMIT to be silently ignored.
48        // Fixed upstream in https://github.com/apache/datafusion/pull/20048
49        // but not yet released.
50        // TODO: remove after upgrading DataFusion past 52.1
51        // https://github.com/madesroches/micromegas/issues/809
52        let filtered_rules = PhysicalOptimizer::default()
53            .rules
54            .into_iter()
55            .filter(|rule: &Arc<dyn PhysicalOptimizerRule + Send + Sync>| {
56                rule.name() != "LimitPushdown"
57            })
58            .collect::<Vec<_>>();
59
60        let config = SessionConfig::default().with_information_schema(true);
61        let state = SessionStateBuilder::new()
62            .with_config(config)
63            .with_default_features()
64            .with_physical_optimizer_rules(filtered_rules)
65            .build();
66        let ctx = SessionContext::new_with_state(state);
67        micromegas_datafusion_extensions::register_extension_udfs(&ctx);
68        Self { ctx }
69    }
70
71    /// Register Arrow IPC stream bytes as a named table.
72    /// Replaces any existing table with the same name.
73    /// Returns the number of rows registered.
74    pub fn register_table(&self, name: &str, ipc_bytes: &[u8]) -> Result<usize, JsValue> {
75        info!("registering table '{name}' ({} bytes)", ipc_bytes.len());
76        let cursor = std::io::Cursor::new(ipc_bytes);
77        let reader = StreamReader::try_new(cursor, None)
78            .map_err(|e| JsValue::from_str(&format!("Failed to read IPC stream: {e}")))?;
79
80        let schema = reader.schema();
81        let mut batches = Vec::new();
82        let mut row_count: usize = 0;
83
84        for batch_result in reader {
85            let batch = batch_result
86                .map_err(|e| JsValue::from_str(&format!("Failed to read batch: {e}")))?;
87            row_count += batch.num_rows();
88            batches.push(batch);
89        }
90
91        let table = datafusion::datasource::MemTable::try_new(schema, vec![batches])
92            .map_err(|e| JsValue::from_str(&format!("Failed to create MemTable: {e}")))?;
93
94        let _ = self.ctx.deregister_table(name);
95        self.ctx
96            .register_table(name, Arc::new(table))
97            .map_err(|e| JsValue::from_str(&format!("Failed to register table: {e}")))?;
98
99        info!("registered table '{name}': {row_count} rows");
100        Ok(row_count)
101    }
102
103    /// Execute SQL, return Arrow IPC stream bytes.
104    pub async fn execute_sql(&self, sql: &str) -> Result<Vec<u8>, JsValue> {
105        info!("execute_sql: {sql}");
106        let df = self
107            .ctx
108            .sql(sql)
109            .await
110            .map_err(|e| JsValue::from_str(&format!("SQL error: {e}")))?;
111
112        let schema = Arc::new(df.schema().as_arrow().clone());
113
114        let batches = df
115            .collect()
116            .await
117            .map_err(|e| JsValue::from_str(&format!("Execution error: {e}")))?;
118
119        let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
120        info!("execute_sql completed: {total_rows} rows");
121        serialize_to_ipc(&schema, &batches)
122    }
123
124    /// Execute SQL, register result as a named table, return Arrow IPC stream bytes.
125    pub async fn execute_and_register(
126        &self,
127        sql: &str,
128        register_as: &str,
129    ) -> Result<Vec<u8>, JsValue> {
130        info!("execute_and_register: {sql} -> '{register_as}'");
131        let df = self
132            .ctx
133            .sql(sql)
134            .await
135            .map_err(|e| JsValue::from_str(&format!("SQL error: {e}")))?;
136
137        let schema = Arc::new(df.schema().as_arrow().clone());
138
139        let batches = df
140            .collect()
141            .await
142            .map_err(|e| JsValue::from_str(&format!("Execution error: {e}")))?;
143
144        let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
145
146        // Register result as named table
147        let _ = self.ctx.deregister_table(register_as);
148        let mem_table =
149            datafusion::datasource::MemTable::try_new(schema.clone(), vec![batches.clone()])
150                .map_err(|e| JsValue::from_str(&format!("Failed to create MemTable: {e}")))?;
151        self.ctx
152            .register_table(register_as, Arc::new(mem_table))
153            .map_err(|e| JsValue::from_str(&format!("Failed to register table: {e}")))?;
154
155        info!("execute_and_register completed: {total_rows} rows registered as '{register_as}'");
156        serialize_to_ipc(&schema, &batches)
157    }
158
159    /// Deregister a single named table. Returns true if the table existed.
160    pub fn deregister_table(&self, name: &str) -> Result<bool, JsValue> {
161        let existed = self
162            .ctx
163            .deregister_table(name)
164            .map_err(|e| JsValue::from_str(&format!("Failed to deregister table: {e}")))?;
165        Ok(existed.is_some())
166    }
167
168    /// Deregister all tables.
169    pub fn reset(&self) {
170        let names: Vec<String> = self
171            .ctx
172            .catalog_names()
173            .into_iter()
174            .flat_map(|catalog_name| {
175                self.ctx
176                    .catalog(&catalog_name)
177                    .into_iter()
178                    .flat_map(move |catalog| {
179                        catalog
180                            .schema_names()
181                            .into_iter()
182                            .flat_map(move |schema_name| {
183                                catalog
184                                    .schema(&schema_name)
185                                    .map(|schema| {
186                                        schema
187                                            .table_names()
188                                            .into_iter()
189                                            .map(move |t| t.to_string())
190                                            .collect::<Vec<_>>()
191                                    })
192                                    .unwrap_or_default()
193                            })
194                    })
195            })
196            .collect();
197
198        for table_name in names {
199            let _ = self.ctx.deregister_table(&table_name);
200        }
201    }
202}
203
204fn serialize_to_ipc(schema: &Arc<Schema>, batches: &[RecordBatch]) -> Result<Vec<u8>, JsValue> {
205    let mut buf = Vec::new();
206    let mut writer = StreamWriter::try_new(&mut buf, schema)
207        .map_err(|e| JsValue::from_str(&format!("IPC writer error: {e}")))?;
208
209    for batch in batches {
210        writer
211            .write(batch)
212            .map_err(|e| JsValue::from_str(&format!("IPC write error: {e}")))?;
213    }
214
215    writer
216        .finish()
217        .map_err(|e| JsValue::from_str(&format!("IPC finish error: {e}")))?;
218
219    Ok(buf)
220}