micromegas_datafusion_wasm/
lib.rs1use std::sync::{Arc, Once};
2
3use arrow::array::RecordBatch;
4use arrow::datatypes::Schema;
5use arrow::ipc::reader::StreamReader;
6use arrow::ipc::writer::StreamWriter;
7use datafusion::execution::SessionStateBuilder;
8use datafusion::physical_optimizer::PhysicalOptimizerRule;
9use datafusion::physical_optimizer::optimizer::PhysicalOptimizer;
10use datafusion::prelude::*;
11use micromegas_tracing::prelude::*;
12use wasm_bindgen::prelude::*;
13
/// One-shot guard so telemetry is initialized at most once per process.
static INIT: Once = Once::new();

/// Initializes the Micromegas telemetry sink exactly once.
///
/// Safe to call repeatedly; every call after the first is a no-op.
/// On non-wasm32 targets this initializes nothing.
fn ensure_tracing() {
    INIT.call_once(|| {
        #[cfg(target_arch = "wasm32")]
        {
            let guard = micromegas_telemetry_sink::init_telemetry()
                .expect("failed to init telemetry");
            // Leak the guard on purpose: dropping it would tear telemetry
            // down, and we want it alive for the process lifetime.
            std::mem::forget(guard);
        }
    });
}
26
/// In-browser SQL query engine backed by DataFusion, exposed to JavaScript
/// through `wasm_bindgen`.
#[wasm_bindgen]
pub struct WasmQueryEngine {
    /// DataFusion session holding the registered tables, the session
    /// configuration, and the extension UDFs.
    ctx: SessionContext,
}
31
32impl Default for WasmQueryEngine {
33 fn default() -> Self {
34 Self::new()
35 }
36}
37
38#[wasm_bindgen]
39impl WasmQueryEngine {
40 #[wasm_bindgen(constructor)]
41 pub fn new() -> Self {
42 ensure_tracing();
43 info!("WasmQueryEngine created");
44
45 let filtered_rules = PhysicalOptimizer::default()
53 .rules
54 .into_iter()
55 .filter(|rule: &Arc<dyn PhysicalOptimizerRule + Send + Sync>| {
56 rule.name() != "LimitPushdown"
57 })
58 .collect::<Vec<_>>();
59
60 let config = SessionConfig::default().with_information_schema(true);
61 let state = SessionStateBuilder::new()
62 .with_config(config)
63 .with_default_features()
64 .with_physical_optimizer_rules(filtered_rules)
65 .build();
66 let ctx = SessionContext::new_with_state(state);
67 micromegas_datafusion_extensions::register_extension_udfs(&ctx);
68 Self { ctx }
69 }
70
71 pub fn register_table(&self, name: &str, ipc_bytes: &[u8]) -> Result<usize, JsValue> {
75 info!("registering table '{name}' ({} bytes)", ipc_bytes.len());
76 let cursor = std::io::Cursor::new(ipc_bytes);
77 let reader = StreamReader::try_new(cursor, None)
78 .map_err(|e| JsValue::from_str(&format!("Failed to read IPC stream: {e}")))?;
79
80 let schema = reader.schema();
81 let mut batches = Vec::new();
82 let mut row_count: usize = 0;
83
84 for batch_result in reader {
85 let batch = batch_result
86 .map_err(|e| JsValue::from_str(&format!("Failed to read batch: {e}")))?;
87 row_count += batch.num_rows();
88 batches.push(batch);
89 }
90
91 let table = datafusion::datasource::MemTable::try_new(schema, vec![batches])
92 .map_err(|e| JsValue::from_str(&format!("Failed to create MemTable: {e}")))?;
93
94 let _ = self.ctx.deregister_table(name);
95 self.ctx
96 .register_table(name, Arc::new(table))
97 .map_err(|e| JsValue::from_str(&format!("Failed to register table: {e}")))?;
98
99 info!("registered table '{name}': {row_count} rows");
100 Ok(row_count)
101 }
102
103 pub async fn execute_sql(&self, sql: &str) -> Result<Vec<u8>, JsValue> {
105 info!("execute_sql: {sql}");
106 let df = self
107 .ctx
108 .sql(sql)
109 .await
110 .map_err(|e| JsValue::from_str(&format!("SQL error: {e}")))?;
111
112 let schema = Arc::new(df.schema().as_arrow().clone());
113
114 let batches = df
115 .collect()
116 .await
117 .map_err(|e| JsValue::from_str(&format!("Execution error: {e}")))?;
118
119 let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
120 info!("execute_sql completed: {total_rows} rows");
121 serialize_to_ipc(&schema, &batches)
122 }
123
124 pub async fn execute_and_register(
126 &self,
127 sql: &str,
128 register_as: &str,
129 ) -> Result<Vec<u8>, JsValue> {
130 info!("execute_and_register: {sql} -> '{register_as}'");
131 let df = self
132 .ctx
133 .sql(sql)
134 .await
135 .map_err(|e| JsValue::from_str(&format!("SQL error: {e}")))?;
136
137 let schema = Arc::new(df.schema().as_arrow().clone());
138
139 let batches = df
140 .collect()
141 .await
142 .map_err(|e| JsValue::from_str(&format!("Execution error: {e}")))?;
143
144 let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
145
146 let _ = self.ctx.deregister_table(register_as);
148 let mem_table =
149 datafusion::datasource::MemTable::try_new(schema.clone(), vec![batches.clone()])
150 .map_err(|e| JsValue::from_str(&format!("Failed to create MemTable: {e}")))?;
151 self.ctx
152 .register_table(register_as, Arc::new(mem_table))
153 .map_err(|e| JsValue::from_str(&format!("Failed to register table: {e}")))?;
154
155 info!("execute_and_register completed: {total_rows} rows registered as '{register_as}'");
156 serialize_to_ipc(&schema, &batches)
157 }
158
159 pub fn deregister_table(&self, name: &str) -> Result<bool, JsValue> {
161 let existed = self
162 .ctx
163 .deregister_table(name)
164 .map_err(|e| JsValue::from_str(&format!("Failed to deregister table: {e}")))?;
165 Ok(existed.is_some())
166 }
167
168 pub fn reset(&self) {
170 let names: Vec<String> = self
171 .ctx
172 .catalog_names()
173 .into_iter()
174 .flat_map(|catalog_name| {
175 self.ctx
176 .catalog(&catalog_name)
177 .into_iter()
178 .flat_map(move |catalog| {
179 catalog
180 .schema_names()
181 .into_iter()
182 .flat_map(move |schema_name| {
183 catalog
184 .schema(&schema_name)
185 .map(|schema| {
186 schema
187 .table_names()
188 .into_iter()
189 .map(move |t| t.to_string())
190 .collect::<Vec<_>>()
191 })
192 .unwrap_or_default()
193 })
194 })
195 })
196 .collect();
197
198 for table_name in names {
199 let _ = self.ctx.deregister_table(&table_name);
200 }
201 }
202}
203
204fn serialize_to_ipc(schema: &Arc<Schema>, batches: &[RecordBatch]) -> Result<Vec<u8>, JsValue> {
205 let mut buf = Vec::new();
206 let mut writer = StreamWriter::try_new(&mut buf, schema)
207 .map_err(|e| JsValue::from_str(&format!("IPC writer error: {e}")))?;
208
209 for batch in batches {
210 writer
211 .write(batch)
212 .map_err(|e| JsValue::from_str(&format!("IPC write error: {e}")))?;
213 }
214
215 writer
216 .finish()
217 .map_err(|e| JsValue::from_str(&format!("IPC finish error: {e}")))?;
218
219 Ok(buf)
220}