shape_runtime/
extension_context.rs1use crate::extensions::ParsedModuleSchema;
8use crate::frontmatter::parse_frontmatter;
9use crate::module_loader::{ModuleCode, ModuleLoader};
10use crate::project::find_project_root;
11use crate::provider_registry::ProviderRegistry;
12use std::collections::HashMap;
13use std::path::{Path, PathBuf};
14use std::sync::{Arc, Mutex, OnceLock};
15
16#[derive(Debug, Clone)]
17pub struct ExtensionModuleSpec {
18 pub name: String,
19 pub path: PathBuf,
20 pub config: serde_json::Value,
21 pub extension_sections: HashMap<String, toml::Value>,
23}
24
25static EXTENSION_MODULE_SCHEMA_CACHE: OnceLock<Mutex<HashMap<String, Option<ParsedModuleSchema>>>> =
26 OnceLock::new();
27
28pub fn declared_extension_specs_for_context(
32 current_file: Option<&Path>,
33 workspace_root: Option<&Path>,
34 current_source: Option<&str>,
35) -> Vec<ExtensionModuleSpec> {
36 let mut by_name: HashMap<String, ExtensionModuleSpec> = HashMap::new();
37
38 if let Some(source) = current_source {
39 let (frontmatter, _) = parse_frontmatter(source);
40 if let Some(frontmatter) = frontmatter {
41 let base_dir = current_file
42 .and_then(Path::parent)
43 .map(Path::to_path_buf)
44 .or_else(|| std::env::current_dir().ok())
45 .unwrap_or_else(|| PathBuf::from("."));
46 for extension in frontmatter.extensions {
47 let config = extension.config_as_json();
48 let resolved_path = if extension.path.is_absolute() {
49 extension.path.clone()
50 } else {
51 base_dir.join(&extension.path)
52 };
53 by_name.insert(
54 extension.name.clone(),
55 ExtensionModuleSpec {
56 name: extension.name,
57 path: resolved_path,
58 config,
59 extension_sections: frontmatter.extension_sections.clone(),
60 },
61 );
62 }
63 }
64 }
65
66 let project = current_file
67 .and_then(|file| file.parent())
68 .and_then(find_project_root)
69 .or_else(|| workspace_root.and_then(find_project_root));
70 if let Some(project) = project {
71 for extension in project.config.extensions {
72 by_name.entry(extension.name.clone()).or_insert_with(|| {
73 let config = extension.config_as_json();
74 let resolved_path = if extension.path.is_absolute() {
75 extension.path.clone()
76 } else {
77 project.root_path.join(&extension.path)
78 };
79 ExtensionModuleSpec {
80 name: extension.name,
81 path: resolved_path,
82 config,
83 extension_sections: project.config.extension_sections.clone(),
84 }
85 });
86 }
87 }
88
89 let mut specs: Vec<ExtensionModuleSpec> = by_name.into_values().collect();
90 specs.sort_by(|left, right| left.name.cmp(&right.name));
91 specs
92}
93
94pub fn declared_extension_spec_for_module(
96 module_name: &str,
97 current_file: Option<&Path>,
98 workspace_root: Option<&Path>,
99 current_source: Option<&str>,
100) -> Option<ExtensionModuleSpec> {
101 declared_extension_specs_for_context(current_file, workspace_root, current_source)
102 .into_iter()
103 .find(|spec| spec.name == module_name)
104}
105
106pub fn extension_module_schema_for_spec(spec: &ExtensionModuleSpec) -> Option<ParsedModuleSchema> {
108 if !spec.path.exists() {
109 return None;
110 }
111
112 let canonical = spec
113 .path
114 .canonicalize()
115 .unwrap_or_else(|_| spec.path.clone())
116 .to_string_lossy()
117 .to_string();
118 let config_key = serde_json::to_string(&spec.config).unwrap_or_default();
119 let key = format!("{}|{}|{}", spec.name, canonical, config_key);
120
121 let cache = EXTENSION_MODULE_SCHEMA_CACHE.get_or_init(|| Mutex::new(HashMap::new()));
122 if let Ok(guard) = cache.lock()
123 && let Some(cached) = guard.get(&key)
124 {
125 return cached.clone();
126 }
127
128 let schema = {
129 let registry = ProviderRegistry::new();
130 match registry.load_extension(&spec.path, &spec.config) {
131 Ok(_) => registry
132 .get_extension_module_schema(&spec.name)
133 .or_else(|| {
134 registry
135 .list_extensions()
136 .first()
137 .and_then(|name| registry.get_extension_module_schema(name))
138 }),
139 Err(_) => None,
140 }
141 };
142
143 if let Ok(mut guard) = cache.lock() {
144 guard.insert(key, schema.clone());
145 }
146
147 schema
148}
149
150pub fn extension_module_schema_for_context(
152 module_name: &str,
153 current_file: Option<&Path>,
154 workspace_root: Option<&Path>,
155 current_source: Option<&str>,
156) -> Option<ParsedModuleSchema> {
157 let spec = declared_extension_spec_for_module(
158 module_name,
159 current_file,
160 workspace_root,
161 current_source,
162 )?;
163 extension_module_schema_for_spec(&spec)
164}
165
166pub fn register_declared_extensions_in_loader(
168 loader: &mut ModuleLoader,
169 current_file: Option<&Path>,
170 workspace_root: Option<&Path>,
171 current_source: Option<&str>,
172) {
173 for spec in declared_extension_specs_for_context(current_file, workspace_root, current_source) {
174 let Some(schema) = extension_module_schema_for_spec(&spec) else {
175 continue;
176 };
177 for artifact in schema.artifacts {
178 let code = match (artifact.source, artifact.compiled) {
179 (Some(source), Some(compiled)) => ModuleCode::Both {
180 source: Arc::from(source.as_str()),
181 compiled: Arc::from(compiled),
182 },
183 (Some(source), None) => ModuleCode::Source(Arc::from(source.as_str())),
184 (None, Some(compiled)) => ModuleCode::Compiled(Arc::from(compiled)),
185 (None, None) => continue,
186 };
187 loader.register_extension_module(artifact.module_path, code);
188 }
189 }
190}
191
#[cfg(test)]
mod tests {
    use super::*;

    /// A `shape.toml` at the project root is found from a file nested in `src/`, and its
    /// relative extension path is resolved against the project root.
    #[test]
    fn test_declared_extension_spec_for_module_uses_project_config() {
        let dir = tempfile::tempdir().expect("temp dir");
        let project_root = dir.path();
        let main_file = project_root.join("src/main.shape");
        std::fs::create_dir_all(project_root.join("src")).expect("create src");
        std::fs::write(
            project_root.join("shape.toml"),
            r#"
[[extensions]]
name = "proj_ext_unique_for_test"
path = "./extensions/libproj.so"
"#,
        )
        .expect("write shape.toml");
        std::fs::write(&main_file, "use proj_ext_unique_for_test").expect("write main");

        let found = declared_extension_spec_for_module(
            "proj_ext_unique_for_test",
            Some(main_file.as_path()),
            None,
            None,
        )
        .expect("project extension should be discovered");

        assert_eq!(found.name, "proj_ext_unique_for_test");
        assert_eq!(found.path, project_root.join("extensions/libproj.so"));
    }

    /// When frontmatter and `shape.toml` declare the same name, the frontmatter entry wins
    /// and its path is resolved relative to the source file's directory.
    #[test]
    fn test_declared_extension_specs_frontmatter_overrides_project() {
        let dir = tempfile::tempdir().expect("temp dir");
        let project_root = dir.path();
        let main_file = project_root.join("src/main.shape");
        std::fs::create_dir_all(project_root.join("src")).expect("create src");
        std::fs::write(
            project_root.join("shape.toml"),
            r#"
[[extensions]]
name = "duckdb"
path = "./project/libproject.so"
"#,
        )
        .expect("write shape.toml");
        std::fs::write(&main_file, "use duckdb").expect("write main");

        let annotated_source = r#"---
[[extensions]]
name = "duckdb"
path = "./frontmatter/libfront.so"
---
use duckdb
"#;

        let found = declared_extension_spec_for_module(
            "duckdb",
            Some(main_file.as_path()),
            None,
            Some(annotated_source),
        )
        .expect("frontmatter extension should be discovered");

        assert_eq!(found.path, project_root.join("src/frontmatter/libfront.so"));
    }
}