1mod builder;
8mod execution;
9mod query_extraction;
10mod stdlib;
11mod types;
12
13pub use crate::query_result::QueryType;
15pub use builder::ShapeEngineBuilder;
16use shape_value::ValueWord;
17pub use types::{
18 EngineBootstrapState, ExecutionMetrics, ExecutionResult, ExecutionType, Message, MessageLevel,
19};
20
21use crate::Runtime;
22use crate::data::DataFrame;
23use shape_ast::error::{Result, ShapeError};
24
25#[cfg(feature = "jit")]
26use std::collections::HashMap;
27
28use crate::hashing::HashDigest;
29use crate::snapshot::{ContextSnapshot, ExecutionSnapshot, SemanticSnapshot, SnapshotStore};
30use serde::Serialize;
31use shape_ast::Program;
32use shape_wire::WireValue;
33
34pub trait ExpressionEvaluator: Send + Sync {
40 fn eval_statements(
42 &self,
43 stmts: &[shape_ast::Statement],
44 ctx: &mut crate::context::ExecutionContext,
45 ) -> Result<ValueWord>;
46
47 fn eval_expr(
49 &self,
50 expr: &shape_ast::Expr,
51 ctx: &mut crate::context::ExecutionContext,
52 ) -> Result<ValueWord>;
53}
54
/// Outcome of running a program through a [`ProgramExecutor`].
///
/// Carries the result value in wire format plus optional metadata and
/// optional alternative renderings of the result content.
pub struct ProgramExecutorResult {
    /// The program's result value, encoded as a wire value.
    pub wire_value: WireValue,
    /// Optional type metadata describing `wire_value`.
    pub type_info: Option<shape_wire::metadata::TypeInfo>,
    /// The kind of execution that produced this result.
    pub execution_type: ExecutionType,
    /// Optional JSON content for the result (presumably a JSON rendering — confirm with producers).
    pub content_json: Option<serde_json::Value>,
    /// Optional HTML content for the result (presumably an HTML rendering — confirm with producers).
    pub content_html: Option<String>,
    /// Optional terminal-formatted content for the result.
    pub content_terminal: Option<String>,
}
64
65pub trait ProgramExecutor {
67 fn execute_program(
68 &mut self,
69 engine: &mut ShapeEngine,
70 program: &Program,
71 ) -> Result<ProgramExecutorResult>;
72}
73
/// The top-level Shape engine.
///
/// Wraps a [`Runtime`] (whose persistent execution context most engine
/// methods delegate to) together with engine-level bookkeeping: the current
/// source text, optional snapshot storage, the script path, and the set of
/// exported symbols.
// NOTE: field declaration order is also the drop order.
pub struct ShapeEngine {
    /// Underlying runtime; the persistent execution context is reached through it.
    pub runtime: Runtime,
    /// Data frame associated with the engine (`DataFrame::default()` in the
    /// `new`/`with_async_provider` constructors, the supplied frame in `with_data`).
    pub default_data: DataFrame,
    /// JIT cache keyed by `u64`, currently storing `()` values.
    /// NOTE(review): how keys are derived is not visible in this file.
    #[cfg(feature = "jit")]
    pub(crate) jit_cache: HashMap<u64, ()>,
    /// Source text recorded by `set_source`.
    pub(crate) current_source: Option<String>,
    /// Content store used by the snapshot methods; `None` until configured.
    pub(crate) snapshot_store: Option<SnapshotStore>,
    /// Hash of the most recent snapshot written by `snapshot_with_hashes`.
    pub(crate) last_snapshot: Option<HashDigest>,
    /// Script path recorded into execution snapshots.
    pub(crate) script_path: Option<String>,
    /// Exported symbol names; captured in semantic snapshots and bootstrap state.
    pub(crate) exported_symbols: std::collections::HashSet<String>,
}
94
95impl ShapeEngine {
96 pub fn new() -> Result<Self> {
98 let mut runtime = Runtime::new_without_stdlib();
99 runtime.enable_persistent_context_without_data();
100
101 Ok(Self {
102 runtime,
103 default_data: DataFrame::default(),
104 #[cfg(feature = "jit")]
105 jit_cache: HashMap::new(),
106 current_source: None,
107 snapshot_store: None,
108 last_snapshot: None,
109 script_path: None,
110 exported_symbols: std::collections::HashSet::new(),
111 })
112 }
113
114 pub fn with_data(data: DataFrame) -> Result<Self> {
116 let mut runtime = Runtime::new_without_stdlib();
117 runtime.enable_persistent_context(&data);
118 Ok(Self {
119 runtime,
120 default_data: data,
121 #[cfg(feature = "jit")]
122 jit_cache: HashMap::new(),
123 current_source: None,
124 snapshot_store: None,
125 last_snapshot: None,
126 script_path: None,
127 exported_symbols: std::collections::HashSet::new(),
128 })
129 }
130
131 pub fn with_async_provider(provider: crate::data::SharedAsyncProvider) -> Result<Self> {
136 let runtime_handle = tokio::runtime::Handle::try_current()
137 .map_err(|_| ShapeError::RuntimeError {
138 message: "No tokio runtime available. Ensure with_async_provider is called within a tokio context.".to_string(),
139 location: None,
140 })?;
141 let mut runtime = Runtime::new_without_stdlib();
142
143 let ctx = crate::context::ExecutionContext::with_async_provider(provider, runtime_handle);
145 runtime.set_persistent_context(ctx);
146
147 Ok(Self {
148 runtime,
149 default_data: DataFrame::default(),
150 #[cfg(feature = "jit")]
151 jit_cache: HashMap::new(),
152 current_source: None,
153 snapshot_store: None,
154 last_snapshot: None,
155 script_path: None,
156 exported_symbols: std::collections::HashSet::new(),
157 })
158 }
159
160 pub fn init_repl(&mut self) {
166 if let Some(ctx) = self.runtime.persistent_context_mut() {
168 ctx.set_output_adapter(Box::new(crate::output_adapter::ReplAdapter));
169 }
170 }
171
172 pub fn capture_bootstrap_state(&self) -> Result<EngineBootstrapState> {
176 let context =
177 self.runtime
178 .persistent_context()
179 .cloned()
180 .ok_or_else(|| ShapeError::RuntimeError {
181 message: "No persistent context available for bootstrap capture".to_string(),
182 location: None,
183 })?;
184 Ok(EngineBootstrapState {
185 semantic: SemanticSnapshot {
186 exported_symbols: self.exported_symbols.clone(),
187 },
188 context,
189 })
190 }
191
192 pub fn apply_bootstrap_state(&mut self, state: &EngineBootstrapState) {
194 self.exported_symbols = state.semantic.exported_symbols.clone();
195 self.runtime.set_persistent_context(state.context.clone());
196 }
197
198 pub fn set_script_path(&mut self, path: impl Into<String>) {
200 self.script_path = Some(path.into());
201 }
202
203 pub fn script_path(&self) -> Option<&str> {
205 self.script_path.as_deref()
206 }
207
208 pub fn enable_snapshot_store(&mut self, store: SnapshotStore) {
210 self.snapshot_store = Some(store);
211 }
212
213 pub fn last_snapshot(&self) -> Option<&HashDigest> {
215 self.last_snapshot.as_ref()
216 }
217
218 pub fn snapshot_store(&self) -> Option<&SnapshotStore> {
220 self.snapshot_store.as_ref()
221 }
222
223 pub fn store_snapshot_blob<T: Serialize>(&self, value: &T) -> Result<HashDigest> {
225 let store = self
226 .snapshot_store
227 .as_ref()
228 .ok_or_else(|| ShapeError::RuntimeError {
229 message: "Snapshot store not configured".to_string(),
230 location: None,
231 })?;
232 Ok(store.put_struct(value)?)
233 }
234
235 pub fn snapshot_with_hashes(
237 &mut self,
238 vm_hash: Option<HashDigest>,
239 bytecode_hash: Option<HashDigest>,
240 ) -> Result<HashDigest> {
241 let store = self
242 .snapshot_store
243 .as_ref()
244 .ok_or_else(|| ShapeError::RuntimeError {
245 message: "Snapshot store not configured".to_string(),
246 location: None,
247 })?;
248
249 let semantic = SemanticSnapshot {
250 exported_symbols: self.exported_symbols.clone(),
251 };
252 let semantic_hash = store.put_struct(&semantic)?;
253
254 let context = if let Some(ctx) = self.runtime.persistent_context() {
255 ctx.snapshot(store)?
256 } else {
257 return Err(ShapeError::RuntimeError {
258 message: "No persistent context for snapshot".to_string(),
259 location: None,
260 });
261 };
262 let context_hash = store.put_struct(&context)?;
263
264 let snapshot = ExecutionSnapshot {
265 version: crate::snapshot::SNAPSHOT_VERSION,
266 created_at_ms: chrono::Utc::now().timestamp_millis(),
267 semantic_hash,
268 context_hash,
269 vm_hash,
270 bytecode_hash,
271 script_path: self.script_path.clone(),
272 };
273
274 let snapshot_hash = store.put_snapshot(&snapshot)?;
275 self.last_snapshot = Some(snapshot_hash.clone());
276 Ok(snapshot_hash)
277 }
278
279 pub fn load_snapshot(
281 &self,
282 snapshot_id: &HashDigest,
283 ) -> Result<(
284 SemanticSnapshot,
285 ContextSnapshot,
286 Option<HashDigest>,
287 Option<HashDigest>,
288 )> {
289 let store = self
290 .snapshot_store
291 .as_ref()
292 .ok_or_else(|| ShapeError::RuntimeError {
293 message: "Snapshot store not configured".to_string(),
294 location: None,
295 })?;
296 let snapshot = store.get_snapshot(snapshot_id)?;
297 let semantic: SemanticSnapshot =
298 store
299 .get_struct(&snapshot.semantic_hash)
300 .map_err(|e| ShapeError::RuntimeError {
301 message: format!("failed to deserialize SemanticSnapshot: {e}"),
302 location: None,
303 })?;
304 let context: ContextSnapshot =
305 store
306 .get_struct(&snapshot.context_hash)
307 .map_err(|e| ShapeError::RuntimeError {
308 message: format!("failed to deserialize ContextSnapshot: {e}"),
309 location: None,
310 })?;
311 Ok((semantic, context, snapshot.vm_hash, snapshot.bytecode_hash))
312 }
313
314 pub fn apply_snapshot(
316 &mut self,
317 semantic: SemanticSnapshot,
318 context: ContextSnapshot,
319 ) -> Result<()> {
320 self.exported_symbols = semantic.exported_symbols;
321 if let Some(ctx) = self.runtime.persistent_context_mut() {
322 let store = self
323 .snapshot_store
324 .as_ref()
325 .ok_or_else(|| ShapeError::RuntimeError {
326 message: "Snapshot store not configured".to_string(),
327 location: None,
328 })?;
329 ctx.restore_from_snapshot(context, store)?;
330 Ok(())
331 } else {
332 Err(ShapeError::RuntimeError {
333 message: "No persistent context for snapshot".to_string(),
334 location: None,
335 })
336 }
337 }
338
339 pub fn register_extension_modules(
342 &mut self,
343 modules: &[crate::extensions::ParsedModuleSchema],
344 ) {
345 self.runtime.register_extension_module_artifacts(modules);
346 }
347
348 pub fn register_language_runtime_artifacts(&mut self) {
357 let runtimes = self.language_runtimes();
358 let mut schemas = Vec::new();
359 for (_lang_id, runtime) in &runtimes {
360 match runtime.shape_source() {
361 Ok(Some((namespace, source))) => {
362 schemas.push(crate::extensions::ParsedModuleSchema {
363 module_name: namespace.clone(),
364 functions: Vec::new(),
365 artifacts: vec![crate::extensions::ParsedModuleArtifact {
366 module_path: namespace,
367 source: Some(source),
368 compiled: None,
369 }],
370 });
371 }
372 Ok(None) => {}
373 Err(e) => {
374 tracing::warn!(
375 "Failed to get shape source from language runtime: {}",
376 e
377 );
378 }
379 }
380 }
381 if !schemas.is_empty() {
382 self.runtime.register_extension_module_artifacts(&schemas);
383 }
384 }
385
386 pub fn set_source(&mut self, source: &str) {
391 self.current_source = Some(source.to_string());
392 }
393
394 pub fn current_source(&self) -> Option<&str> {
396 self.current_source.as_deref()
397 }
398
399 pub fn register_provider(&mut self, name: &str, provider: crate::data::SharedAsyncProvider) {
410 if let Some(ctx) = self.runtime.persistent_context_mut() {
411 ctx.register_provider(name, provider);
412 }
413 }
414
415 pub fn set_default_provider(&mut self, name: &str) -> Result<()> {
419 if let Some(ctx) = self.runtime.persistent_context_mut() {
420 ctx.set_default_provider(name)
421 } else {
422 Err(ShapeError::RuntimeError {
423 message: "No execution context available".to_string(),
424 location: None,
425 })
426 }
427 }
428
429 pub fn register_type_mapping(
456 &mut self,
457 type_name: &str,
458 mapping: crate::type_mapping::TypeMapping,
459 ) {
460 if let Some(ctx) = self.runtime.persistent_context_mut() {
461 ctx.register_type_mapping(type_name, mapping);
462 }
463 }
464
465 pub fn get_runtime(&self) -> &Runtime {
467 &self.runtime
468 }
469
470 pub fn get_runtime_mut(&mut self) -> &mut Runtime {
472 &mut self.runtime
473 }
474
475 pub fn get_variable_format_hint(&self, name: &str) -> Option<String> {
480 self.runtime
481 .persistent_context()
482 .and_then(|ctx| ctx.get_variable_format_hint(name))
483 }
484
485 pub fn format_value_string(
517 &mut self,
518 value: f64,
519 type_name: &str,
520 format_name: Option<&str>,
521 params: &std::collections::HashMap<String, serde_json::Value>,
522 ) -> Result<String> {
523 use std::sync::Arc;
524
525 let (resolved_type_name, merged_params) =
527 self.resolve_type_alias_for_formatting(type_name, params)?;
528
529 let param_values: std::collections::HashMap<String, ValueWord> = merged_params
531 .iter()
532 .map(|(k, v)| {
533 let runtime_val = match v {
534 serde_json::Value::Number(n) => ValueWord::from_f64(n.as_f64().unwrap_or(0.0)),
535 serde_json::Value::String(s) => ValueWord::from_string(Arc::new(s.clone())),
536 serde_json::Value::Bool(b) => ValueWord::from_bool(*b),
537 _ => ValueWord::none(),
538 };
539 (k.clone(), runtime_val)
540 })
541 .collect();
542
543 let runtime_value = ValueWord::from_f64(value);
545
546 self.runtime.format_value(
548 runtime_value,
549 resolved_type_name.as_str(),
550 format_name,
551 param_values,
552 )
553 }
554
555 fn resolve_type_alias_for_formatting(
560 &self,
561 type_name: &str,
562 params: &std::collections::HashMap<String, serde_json::Value>,
563 ) -> Result<(String, std::collections::HashMap<String, serde_json::Value>)> {
564 let resolved = self
566 .runtime
567 .persistent_context()
568 .map(|ctx| ctx.resolve_type_for_format(type_name));
569
570 if let Some((base_type, Some(overrides))) = resolved {
571 if base_type != type_name {
572 let mut merged = std::collections::HashMap::new();
573
574 for (key, val) in overrides {
576 let json_val = if let Some(n) = val.as_f64() {
577 serde_json::json!(n)
578 } else if val.is_bool() {
579 serde_json::json!(val.as_bool())
580 } else {
581 continue;
583 };
584 merged.insert(key, json_val);
585 }
586
587 for (key, val) in params {
589 merged.insert(key.clone(), val.clone());
590 }
591
592 return Ok((base_type, merged));
593 }
594 }
595
596 Ok((type_name.to_string(), params.clone()))
598 }
599
600 pub fn load_extension(
626 &mut self,
627 path: &std::path::Path,
628 config: &serde_json::Value,
629 ) -> Result<crate::extensions::LoadedExtension> {
630 if let Some(ctx) = self.runtime.persistent_context_mut() {
631 ctx.load_extension(path, config)
632 } else {
633 Err(ShapeError::RuntimeError {
634 message: "No execution context available for extension loading".to_string(),
635 location: None,
636 })
637 }
638 }
639
640 pub fn unload_extension(&mut self, name: &str) -> bool {
650 if let Some(ctx) = self.runtime.persistent_context_mut() {
651 ctx.unload_extension(name)
652 } else {
653 false
654 }
655 }
656
657 pub fn list_extensions(&self) -> Vec<String> {
659 if let Some(ctx) = self.runtime.persistent_context() {
660 ctx.list_extensions()
661 } else {
662 Vec::new()
663 }
664 }
665
666 pub fn get_extension_query_schema(
676 &self,
677 name: &str,
678 ) -> Option<crate::extensions::ParsedQuerySchema> {
679 if let Some(ctx) = self.runtime.persistent_context() {
680 ctx.get_extension_query_schema(name)
681 } else {
682 None
683 }
684 }
685
686 pub fn get_extension_output_schema(
696 &self,
697 name: &str,
698 ) -> Option<crate::extensions::ParsedOutputSchema> {
699 if let Some(ctx) = self.runtime.persistent_context() {
700 ctx.get_extension_output_schema(name)
701 } else {
702 None
703 }
704 }
705
706 pub fn get_extension(
708 &self,
709 name: &str,
710 ) -> Option<std::sync::Arc<crate::extensions::ExtensionDataSource>> {
711 if let Some(ctx) = self.runtime.persistent_context() {
712 ctx.get_extension(name)
713 } else {
714 None
715 }
716 }
717
718 pub fn get_extension_module_schema(
720 &self,
721 module_name: &str,
722 ) -> Option<crate::extensions::ParsedModuleSchema> {
723 if let Some(ctx) = self.runtime.persistent_context() {
724 ctx.get_extension_module_schema(module_name)
725 } else {
726 None
727 }
728 }
729
730 pub fn module_exports_from_extensions(&self) -> Vec<crate::module_exports::ModuleExports> {
732 if let Some(ctx) = self.runtime.persistent_context() {
733 ctx.module_exports_from_extensions()
734 } else {
735 Vec::new()
736 }
737 }
738
739 pub fn language_runtimes(
741 &self,
742 ) -> std::collections::HashMap<
743 String,
744 std::sync::Arc<crate::plugins::language_runtime::PluginLanguageRuntime>,
745 > {
746 if let Some(ctx) = self.runtime.persistent_context() {
747 ctx.language_runtimes()
748 } else {
749 std::collections::HashMap::new()
750 }
751 }
752
753 pub fn invoke_extension_module_nb(
755 &self,
756 module_name: &str,
757 function: &str,
758 args: &[shape_value::ValueWord],
759 ) -> Result<shape_value::ValueWord> {
760 if let Some(ctx) = self.runtime.persistent_context() {
761 ctx.invoke_extension_module_nb(module_name, function, args)
762 } else {
763 Err(shape_ast::error::ShapeError::RuntimeError {
764 message: "No runtime context available".to_string(),
765 location: None,
766 })
767 }
768 }
769
770 pub fn invoke_extension_module_wire(
772 &self,
773 module_name: &str,
774 function: &str,
775 args: &[shape_wire::WireValue],
776 ) -> Result<shape_wire::WireValue> {
777 if let Some(ctx) = self.runtime.persistent_context() {
778 ctx.invoke_extension_module_wire(module_name, function, args)
779 } else {
780 Err(shape_ast::error::ShapeError::RuntimeError {
781 message: "No runtime context available".to_string(),
782 location: None,
783 })
784 }
785 }
786
787 pub fn enable_progress_tracking(
808 &mut self,
809 ) -> std::sync::Arc<crate::progress::ProgressRegistry> {
810 let registry = crate::progress::ProgressRegistry::new();
812 if let Some(ctx) = self.runtime.persistent_context_mut() {
813 ctx.set_progress_registry(registry.clone());
814 }
815 registry
816 }
817
818 pub fn progress_registry(&self) -> Option<std::sync::Arc<crate::progress::ProgressRegistry>> {
820 self.runtime
821 .persistent_context()
822 .and_then(|ctx| ctx.progress_registry())
823 .cloned()
824 }
825
826 pub fn has_pending_progress(&self) -> bool {
828 if let Some(registry) = self.progress_registry() {
829 !registry.is_empty()
830 } else {
831 false
832 }
833 }
834
835 pub fn poll_progress(&self) -> Option<crate::progress::ProgressEvent> {
839 self.progress_registry()
840 .and_then(|registry| registry.try_recv())
841 }
842}
843
844impl Default for ShapeEngine {
845 fn default() -> Self {
846 Self::new().expect("Failed to create default Shape engine")
847 }
848}
849
#[cfg(test)]
mod tests {
    use super::*;
    use crate::extensions::{ParsedModuleArtifact, ParsedModuleSchema};

    /// Registering an extension module with a source artifact must make that
    /// module loadable through the runtime's configured module loader.
    #[test]
    fn test_register_extension_modules_registers_module_loader_artifacts() {
        let duckdb_schema = ParsedModuleSchema {
            module_name: "duckdb".to_string(),
            functions: Vec::new(),
            artifacts: vec![ParsedModuleArtifact {
                module_path: "duckdb".to_string(),
                source: Some("pub fn connect(uri) { uri }".to_string()),
                compiled: None,
            }],
        };

        let mut shape_engine = ShapeEngine::new().expect("engine should create");
        shape_engine.register_extension_modules(&[duckdb_schema]);

        let mut module_loader = shape_engine.runtime.configured_module_loader();
        let loaded = module_loader
            .load_module("duckdb")
            .expect("registered extension module artifact should load");

        let has_connect = loaded.exports.contains_key("connect");
        assert!(has_connect, "expected connect export");
    }
}