1mod builder;
8mod execution;
9mod query_extraction;
10mod stdlib;
11mod types;
12
13pub use crate::query_result::QueryType;
15pub use builder::ShapeEngineBuilder;
16use shape_value::KindedSlot;
17pub use types::{
18 EngineBootstrapState, ExecutionMetrics, ExecutionResult, ExecutionType, Message, MessageLevel,
19};
20
21use crate::Runtime;
22use crate::data::DataFrame;
23use shape_ast::error::{Result, ShapeError};
24
25#[cfg(feature = "jit")]
26use std::collections::HashMap;
27
28use crate::hashing::HashDigest;
29use crate::snapshot::{ContextSnapshot, ExecutionSnapshot, SemanticSnapshot, SnapshotStore};
30use serde::Serialize;
31use shape_ast::Program;
32use shape_wire::WireValue;
33
34pub trait ExpressionEvaluator: Send + Sync {
44 fn eval_statements(
46 &self,
47 stmts: &[shape_ast::Statement],
48 ctx: &mut crate::context::ExecutionContext,
49 ) -> Result<KindedSlot>;
50
51 fn eval_expr(
53 &self,
54 expr: &shape_ast::Expr,
55 ctx: &mut crate::context::ExecutionContext,
56 ) -> Result<KindedSlot>;
57}
58
59pub struct ProgramExecutorResult {
61 pub wire_value: WireValue,
62 pub type_info: Option<shape_wire::metadata::TypeInfo>,
63 pub execution_type: ExecutionType,
64 pub content_json: Option<serde_json::Value>,
65 pub content_html: Option<String>,
66 pub content_terminal: Option<String>,
67}
68
69pub trait ProgramExecutor {
71 fn execute_program(
72 &mut self,
73 engine: &mut ShapeEngine,
74 program: &Program,
75 ) -> Result<ProgramExecutorResult>;
76}
77
78pub struct ShapeEngine {
80 pub runtime: Runtime,
82 pub default_data: DataFrame,
84 #[cfg(feature = "jit")]
86 pub(crate) jit_cache: HashMap<u64, ()>,
87 pub(crate) current_source: Option<String>,
89 pub(crate) snapshot_store: Option<SnapshotStore>,
91 pub(crate) last_snapshot: Option<HashDigest>,
93 pub(crate) script_path: Option<String>,
95 pub(crate) exported_symbols: std::collections::HashSet<String>,
97}
98
99impl ShapeEngine {
100 pub fn new() -> Result<Self> {
102 let mut runtime = Runtime::new_without_stdlib();
103 runtime.enable_persistent_context_without_data();
104
105 Ok(Self {
106 runtime,
107 default_data: DataFrame::default(),
108 #[cfg(feature = "jit")]
109 jit_cache: HashMap::new(),
110 current_source: None,
111 snapshot_store: None,
112 last_snapshot: None,
113 script_path: None,
114 exported_symbols: std::collections::HashSet::new(),
115 })
116 }
117
118 pub fn with_data(data: DataFrame) -> Result<Self> {
120 let mut runtime = Runtime::new_without_stdlib();
121 runtime.enable_persistent_context(&data);
122 Ok(Self {
123 runtime,
124 default_data: data,
125 #[cfg(feature = "jit")]
126 jit_cache: HashMap::new(),
127 current_source: None,
128 snapshot_store: None,
129 last_snapshot: None,
130 script_path: None,
131 exported_symbols: std::collections::HashSet::new(),
132 })
133 }
134
135 pub fn with_async_provider(provider: crate::data::SharedAsyncProvider) -> Result<Self> {
140 let runtime_handle = tokio::runtime::Handle::try_current()
141 .map_err(|_| ShapeError::RuntimeError {
142 message: "No tokio runtime available. Ensure with_async_provider is called within a tokio context.".to_string(),
143 location: None,
144 })?;
145 let mut runtime = Runtime::new_without_stdlib();
146
147 let ctx = crate::context::ExecutionContext::with_async_provider(provider, runtime_handle);
149 runtime.set_persistent_context(ctx);
150
151 Ok(Self {
152 runtime,
153 default_data: DataFrame::default(),
154 #[cfg(feature = "jit")]
155 jit_cache: HashMap::new(),
156 current_source: None,
157 snapshot_store: None,
158 last_snapshot: None,
159 script_path: None,
160 exported_symbols: std::collections::HashSet::new(),
161 })
162 }
163
164 pub fn init_repl(&mut self) {
170 if let Some(ctx) = self.runtime.persistent_context_mut() {
172 ctx.set_output_adapter(Box::new(crate::output_adapter::ReplAdapter));
173 }
174 }
175
176 pub fn capture_bootstrap_state(&self) -> Result<EngineBootstrapState> {
180 let context =
181 self.runtime
182 .persistent_context()
183 .cloned()
184 .ok_or_else(|| ShapeError::RuntimeError {
185 message: "No persistent context available for bootstrap capture".to_string(),
186 location: None,
187 })?;
188 Ok(EngineBootstrapState {
189 semantic: SemanticSnapshot {
190 exported_symbols: self.exported_symbols.clone(),
191 },
192 context,
193 })
194 }
195
196 pub fn apply_bootstrap_state(&mut self, state: &EngineBootstrapState) {
198 self.exported_symbols = state.semantic.exported_symbols.clone();
199 self.runtime.set_persistent_context(state.context.clone());
200 }
201
202 pub fn set_script_path(&mut self, path: impl Into<String>) {
204 self.script_path = Some(path.into());
205 }
206
207 pub fn script_path(&self) -> Option<&str> {
209 self.script_path.as_deref()
210 }
211
212 pub fn enable_snapshot_store(&mut self, store: SnapshotStore) {
214 self.snapshot_store = Some(store);
215 }
216
217 pub fn last_snapshot(&self) -> Option<&HashDigest> {
219 self.last_snapshot.as_ref()
220 }
221
222 pub fn snapshot_store(&self) -> Option<&SnapshotStore> {
224 self.snapshot_store.as_ref()
225 }
226
227 pub fn store_snapshot_blob<T: Serialize>(&self, value: &T) -> Result<HashDigest> {
229 let store = self
230 .snapshot_store
231 .as_ref()
232 .ok_or_else(|| ShapeError::RuntimeError {
233 message: "Snapshot store not configured".to_string(),
234 location: None,
235 })?;
236 Ok(store.put_struct(value)?)
237 }
238
239 pub fn snapshot_with_hashes(
241 &mut self,
242 vm_hash: Option<HashDigest>,
243 bytecode_hash: Option<HashDigest>,
244 ) -> Result<HashDigest> {
245 let store = self
246 .snapshot_store
247 .as_ref()
248 .ok_or_else(|| ShapeError::RuntimeError {
249 message: "Snapshot store not configured".to_string(),
250 location: None,
251 })?;
252
253 let semantic = SemanticSnapshot {
254 exported_symbols: self.exported_symbols.clone(),
255 };
256 let semantic_hash = store.put_struct(&semantic)?;
257
258 let context = if let Some(ctx) = self.runtime.persistent_context() {
259 ctx.snapshot(store)?
260 } else {
261 return Err(ShapeError::RuntimeError {
262 message: "No persistent context for snapshot".to_string(),
263 location: None,
264 });
265 };
266 let context_hash = store.put_struct(&context)?;
267
268 let snapshot = ExecutionSnapshot {
269 version: crate::snapshot::SNAPSHOT_VERSION,
270 created_at_ms: chrono::Utc::now().timestamp_millis(),
271 semantic_hash,
272 context_hash,
273 vm_hash,
274 bytecode_hash,
275 script_path: self.script_path.clone(),
276 };
277
278 let snapshot_hash = store.put_snapshot(&snapshot)?;
279 self.last_snapshot = Some(snapshot_hash.clone());
280 Ok(snapshot_hash)
281 }
282
283 pub fn load_snapshot(
285 &self,
286 snapshot_id: &HashDigest,
287 ) -> Result<(
288 SemanticSnapshot,
289 ContextSnapshot,
290 Option<HashDigest>,
291 Option<HashDigest>,
292 )> {
293 let store = self
294 .snapshot_store
295 .as_ref()
296 .ok_or_else(|| ShapeError::RuntimeError {
297 message: "Snapshot store not configured".to_string(),
298 location: None,
299 })?;
300 let snapshot = store.get_snapshot(snapshot_id)?;
301 let semantic: SemanticSnapshot =
302 store
303 .get_struct(&snapshot.semantic_hash)
304 .map_err(|e| ShapeError::RuntimeError {
305 message: format!("failed to deserialize SemanticSnapshot: {e}"),
306 location: None,
307 })?;
308 let context: ContextSnapshot =
309 store
310 .get_struct(&snapshot.context_hash)
311 .map_err(|e| ShapeError::RuntimeError {
312 message: format!("failed to deserialize ContextSnapshot: {e}"),
313 location: None,
314 })?;
315 Ok((semantic, context, snapshot.vm_hash, snapshot.bytecode_hash))
316 }
317
318 pub fn apply_snapshot(
320 &mut self,
321 semantic: SemanticSnapshot,
322 context: ContextSnapshot,
323 ) -> Result<()> {
324 let _scope = self.runtime.enter_schema_scope();
329
330 self.exported_symbols = semantic.exported_symbols;
331 if let Some(ctx) = self.runtime.persistent_context_mut() {
332 let store = self
333 .snapshot_store
334 .as_ref()
335 .ok_or_else(|| ShapeError::RuntimeError {
336 message: "Snapshot store not configured".to_string(),
337 location: None,
338 })?;
339 ctx.restore_from_snapshot(context, store)?;
340 Ok(())
341 } else {
342 Err(ShapeError::RuntimeError {
343 message: "No persistent context for snapshot".to_string(),
344 location: None,
345 })
346 }
347 }
348
349 pub fn register_extension_modules(
352 &mut self,
353 modules: &[crate::extensions::ParsedModuleSchema],
354 ) {
355 self.runtime.register_extension_module_artifacts(modules);
356 }
357
358 pub fn register_language_runtime_artifacts(&mut self) {
367 let runtimes = self.language_runtimes();
368 let mut schemas = Vec::new();
369 for (_lang_id, runtime) in &runtimes {
370 match runtime.shape_source() {
371 Ok(Some((namespace, source))) => {
372 schemas.push(crate::extensions::ParsedModuleSchema {
373 module_name: namespace.clone(),
374 functions: Vec::new(),
375 artifacts: vec![crate::extensions::ParsedModuleArtifact {
376 module_path: namespace,
377 source: Some(source),
378 compiled: None,
379 }],
380 });
381 }
382 Ok(None) => {}
383 Err(e) => {
384 tracing::warn!(
385 "Failed to get shape source from language runtime: {}",
386 e
387 );
388 }
389 }
390 }
391 if !schemas.is_empty() {
392 self.runtime.register_extension_module_artifacts(&schemas);
393 }
394 }
395
396 pub fn set_source(&mut self, source: &str) {
401 self.current_source = Some(source.to_string());
402 }
403
404 pub fn current_source(&self) -> Option<&str> {
406 self.current_source.as_deref()
407 }
408
409 pub fn register_provider(&mut self, name: &str, provider: crate::data::SharedAsyncProvider) {
420 if let Some(ctx) = self.runtime.persistent_context_mut() {
421 ctx.register_provider(name, provider);
422 }
423 }
424
425 pub fn set_default_provider(&mut self, name: &str) -> Result<()> {
429 if let Some(ctx) = self.runtime.persistent_context_mut() {
430 ctx.set_default_provider(name)
431 } else {
432 Err(ShapeError::RuntimeError {
433 message: "No execution context available".to_string(),
434 location: None,
435 })
436 }
437 }
438
439 pub fn register_type_mapping(
466 &mut self,
467 type_name: &str,
468 mapping: crate::type_mapping::TypeMapping,
469 ) {
470 if let Some(ctx) = self.runtime.persistent_context_mut() {
471 ctx.register_type_mapping(type_name, mapping);
472 }
473 }
474
475 pub fn get_runtime(&self) -> &Runtime {
477 &self.runtime
478 }
479
480 pub fn get_runtime_mut(&mut self) -> &mut Runtime {
482 &mut self.runtime
483 }
484
485 pub fn get_variable_format_hint(&self, name: &str) -> Option<String> {
490 self.runtime
491 .persistent_context()
492 .and_then(|ctx| ctx.get_variable_format_hint(name))
493 }
494
495 pub fn format_value_string(
527 &mut self,
528 value: f64,
529 type_name: &str,
530 format_name: Option<&str>,
531 params: &std::collections::HashMap<String, serde_json::Value>,
532 ) -> Result<String> {
533 let (resolved_type_name, merged_params) =
535 self.resolve_type_alias_for_formatting(type_name, params)?;
536
537 let param_values: std::collections::HashMap<String, WireValue> = merged_params
544 .iter()
545 .map(|(k, v)| {
546 let runtime_val = match v {
547 serde_json::Value::Number(n) => WireValue::Number(n.as_f64().unwrap_or(0.0)),
548 serde_json::Value::String(s) => WireValue::String(s.clone()),
549 serde_json::Value::Bool(b) => WireValue::Bool(*b),
550 _ => WireValue::Null,
551 };
552 (k.clone(), runtime_val)
553 })
554 .collect();
555
556 let runtime_value = WireValue::Number(value);
557
558 self.runtime.format_value(
559 runtime_value,
560 resolved_type_name.as_str(),
561 format_name,
562 param_values,
563 )
564 }
565
566 fn resolve_type_alias_for_formatting(
571 &self,
572 type_name: &str,
573 params: &std::collections::HashMap<String, serde_json::Value>,
574 ) -> Result<(String, std::collections::HashMap<String, serde_json::Value>)> {
575 let resolved = self
577 .runtime
578 .persistent_context()
579 .map(|ctx| ctx.resolve_type_for_format(type_name));
580
581 if let Some((base_type, Some(_overrides))) = resolved {
582 if base_type != type_name {
583 return Ok((base_type, params.clone()));
591 }
592 }
593
594 Ok((type_name.to_string(), params.clone()))
596 }
597
598 pub fn load_extension(
624 &mut self,
625 path: &std::path::Path,
626 config: &serde_json::Value,
627 ) -> Result<crate::extensions::LoadedExtension> {
628 if let Some(ctx) = self.runtime.persistent_context_mut() {
629 ctx.load_extension(path, config)
630 } else {
631 Err(ShapeError::RuntimeError {
632 message: "No execution context available for extension loading".to_string(),
633 location: None,
634 })
635 }
636 }
637
638 pub fn unload_extension(&mut self, name: &str) -> bool {
648 if let Some(ctx) = self.runtime.persistent_context_mut() {
649 ctx.unload_extension(name)
650 } else {
651 false
652 }
653 }
654
655 pub fn list_extensions(&self) -> Vec<String> {
657 if let Some(ctx) = self.runtime.persistent_context() {
658 ctx.list_extensions()
659 } else {
660 Vec::new()
661 }
662 }
663
664 pub fn get_extension_query_schema(
674 &self,
675 name: &str,
676 ) -> Option<crate::extensions::ParsedQuerySchema> {
677 if let Some(ctx) = self.runtime.persistent_context() {
678 ctx.get_extension_query_schema(name)
679 } else {
680 None
681 }
682 }
683
684 pub fn get_extension_output_schema(
694 &self,
695 name: &str,
696 ) -> Option<crate::extensions::ParsedOutputSchema> {
697 if let Some(ctx) = self.runtime.persistent_context() {
698 ctx.get_extension_output_schema(name)
699 } else {
700 None
701 }
702 }
703
704 pub fn get_extension(
706 &self,
707 name: &str,
708 ) -> Option<std::sync::Arc<crate::extensions::ExtensionDataSource>> {
709 if let Some(ctx) = self.runtime.persistent_context() {
710 ctx.get_extension(name)
711 } else {
712 None
713 }
714 }
715
716 pub fn get_extension_module_schema(
718 &self,
719 module_name: &str,
720 ) -> Option<crate::extensions::ParsedModuleSchema> {
721 if let Some(ctx) = self.runtime.persistent_context() {
722 ctx.get_extension_module_schema(module_name)
723 } else {
724 None
725 }
726 }
727
728 pub fn module_exports_from_extensions(&self) -> Vec<crate::module_exports::ModuleExports> {
730 if let Some(ctx) = self.runtime.persistent_context() {
731 ctx.module_exports_from_extensions()
732 } else {
733 Vec::new()
734 }
735 }
736
737 pub fn language_runtimes(
739 &self,
740 ) -> std::collections::HashMap<
741 String,
742 std::sync::Arc<crate::plugins::language_runtime::PluginLanguageRuntime>,
743 > {
744 if let Some(ctx) = self.runtime.persistent_context() {
745 ctx.language_runtimes()
746 } else {
747 std::collections::HashMap::new()
748 }
749 }
750
751 pub fn invoke_extension_module_wire(
753 &self,
754 module_name: &str,
755 function: &str,
756 args: &[shape_wire::WireValue],
757 ) -> Result<shape_wire::WireValue> {
758 if let Some(ctx) = self.runtime.persistent_context() {
759 ctx.invoke_extension_module_wire(module_name, function, args)
760 } else {
761 Err(shape_ast::error::ShapeError::RuntimeError {
762 message: "No runtime context available".to_string(),
763 location: None,
764 })
765 }
766 }
767
768 pub fn enable_progress_tracking(
789 &mut self,
790 ) -> std::sync::Arc<crate::progress::ProgressRegistry> {
791 let registry = crate::progress::ProgressRegistry::new();
793 if let Some(ctx) = self.runtime.persistent_context_mut() {
794 ctx.set_progress_registry(registry.clone());
795 }
796 registry
797 }
798
799 pub fn progress_registry(&self) -> Option<std::sync::Arc<crate::progress::ProgressRegistry>> {
801 self.runtime
802 .persistent_context()
803 .and_then(|ctx| ctx.progress_registry())
804 .cloned()
805 }
806
807 pub fn has_pending_progress(&self) -> bool {
809 if let Some(registry) = self.progress_registry() {
810 !registry.is_empty()
811 } else {
812 false
813 }
814 }
815
816 pub fn poll_progress(&self) -> Option<crate::progress::ProgressEvent> {
820 self.progress_registry()
821 .and_then(|registry| registry.try_recv())
822 }
823}
824
825impl Default for ShapeEngine {
826 fn default() -> Self {
827 Self::new().expect("Failed to create default Shape engine")
828 }
829}
830
#[cfg(test)]
mod tests {
    use super::*;
    use crate::extensions::{ParsedModuleArtifact, ParsedModuleSchema};

    /// A module registered through `register_extension_modules` with inline source must be
    /// loadable by the runtime's configured module loader, exports included.
    #[test]
    fn test_register_extension_modules_registers_module_loader_artifacts() {
        let mut engine = ShapeEngine::new().expect("engine should create");

        let duckdb_schema = ParsedModuleSchema {
            module_name: "duckdb".to_string(),
            functions: Vec::new(),
            artifacts: vec![ParsedModuleArtifact {
                module_path: "duckdb".to_string(),
                source: Some("pub fn connect(uri) { uri }".to_string()),
                compiled: None,
            }],
        };
        engine.register_extension_modules(&[duckdb_schema]);

        let mut module_loader = engine.runtime.configured_module_loader();
        let loaded = module_loader
            .load_module("duckdb")
            .expect("registered extension module artifact should load");
        assert!(
            loaded.exports.contains_key("connect"),
            "expected connect export"
        );
    }
}