Skip to main content

shape_runtime/
extension_context.rs

//! Context-aware extension discovery and module-artifact registration.
//!
//! This module is the single source of truth for resolving declared
//! `[[extensions]]` across frontmatter and project config (`shape.toml`), and
//! for exposing extension module artifacts to the unified module loader.
6
7use crate::extensions::ParsedModuleSchema;
8use crate::frontmatter::parse_frontmatter;
9use crate::module_loader::{ModuleCode, ModuleLoader};
10use crate::project::find_project_root;
11use crate::provider_registry::ProviderRegistry;
12use std::collections::HashMap;
13use std::path::{Path, PathBuf};
14use std::sync::{Arc, Mutex, OnceLock};
15
16#[derive(Debug, Clone)]
17pub struct ExtensionModuleSpec {
18    pub name: String,
19    pub path: PathBuf,
20    pub config: serde_json::Value,
21    /// Extension sections from the project config, available for section claims.
22    pub extension_sections: HashMap<String, toml::Value>,
23}
24
25static EXTENSION_MODULE_SCHEMA_CACHE: OnceLock<Mutex<HashMap<String, Option<ParsedModuleSchema>>>> =
26    OnceLock::new();
27
28/// Resolve declared extension module specs for the current context.
29///
30/// Precedence: frontmatter > shape.toml.
31pub fn declared_extension_specs_for_context(
32    current_file: Option<&Path>,
33    workspace_root: Option<&Path>,
34    current_source: Option<&str>,
35) -> Vec<ExtensionModuleSpec> {
36    let mut by_name: HashMap<String, ExtensionModuleSpec> = HashMap::new();
37
38    if let Some(source) = current_source {
39        let (frontmatter, _) = parse_frontmatter(source);
40        if let Some(frontmatter) = frontmatter {
41            let base_dir = current_file
42                .and_then(Path::parent)
43                .map(Path::to_path_buf)
44                .or_else(|| std::env::current_dir().ok())
45                .unwrap_or_else(|| PathBuf::from("."));
46            for extension in frontmatter.extensions {
47                let config = extension.config_as_json();
48                let resolved_path = if extension.path.is_absolute() {
49                    extension.path.clone()
50                } else {
51                    base_dir.join(&extension.path)
52                };
53                by_name.insert(
54                    extension.name.clone(),
55                    ExtensionModuleSpec {
56                        name: extension.name,
57                        path: resolved_path,
58                        config,
59                        extension_sections: frontmatter.extension_sections.clone(),
60                    },
61                );
62            }
63        }
64    }
65
66    let project = current_file
67        .and_then(|file| file.parent())
68        .and_then(find_project_root)
69        .or_else(|| workspace_root.and_then(find_project_root));
70    if let Some(project) = project {
71        for extension in project.config.extensions {
72            by_name.entry(extension.name.clone()).or_insert_with(|| {
73                let config = extension.config_as_json();
74                let resolved_path = if extension.path.is_absolute() {
75                    extension.path.clone()
76                } else {
77                    project.root_path.join(&extension.path)
78                };
79                ExtensionModuleSpec {
80                    name: extension.name,
81                    path: resolved_path,
82                    config,
83                    extension_sections: project.config.extension_sections.clone(),
84                }
85            });
86        }
87    }
88
89    let mut specs: Vec<ExtensionModuleSpec> = by_name.into_values().collect();
90    specs.sort_by(|left, right| left.name.cmp(&right.name));
91    specs
92}
93
94/// Resolve one declared extension module spec by module namespace.
95pub fn declared_extension_spec_for_module(
96    module_name: &str,
97    current_file: Option<&Path>,
98    workspace_root: Option<&Path>,
99    current_source: Option<&str>,
100) -> Option<ExtensionModuleSpec> {
101    declared_extension_specs_for_context(current_file, workspace_root, current_source)
102        .into_iter()
103        .find(|spec| spec.name == module_name)
104}
105
106/// Load one declared extension's `shape.module` schema with process-local caching.
107pub fn extension_module_schema_for_spec(spec: &ExtensionModuleSpec) -> Option<ParsedModuleSchema> {
108    if !spec.path.exists() {
109        return None;
110    }
111
112    let canonical = spec
113        .path
114        .canonicalize()
115        .unwrap_or_else(|_| spec.path.clone())
116        .to_string_lossy()
117        .to_string();
118    let config_key = serde_json::to_string(&spec.config).unwrap_or_default();
119    let key = format!("{}|{}|{}", spec.name, canonical, config_key);
120
121    let cache = EXTENSION_MODULE_SCHEMA_CACHE.get_or_init(|| Mutex::new(HashMap::new()));
122    if let Ok(guard) = cache.lock()
123        && let Some(cached) = guard.get(&key)
124    {
125        return cached.clone();
126    }
127
128    let schema = {
129        let registry = ProviderRegistry::new();
130        match registry.load_extension(&spec.path, &spec.config) {
131            Ok(_) => registry
132                .get_extension_module_schema(&spec.name)
133                .or_else(|| {
134                    registry
135                        .list_extensions()
136                        .first()
137                        .and_then(|name| registry.get_extension_module_schema(name))
138                }),
139            Err(_) => None,
140        }
141    };
142
143    if let Ok(mut guard) = cache.lock() {
144        guard.insert(key, schema.clone());
145    }
146
147    schema
148}
149
150/// Load one declared extension module schema by name for current context.
151pub fn extension_module_schema_for_context(
152    module_name: &str,
153    current_file: Option<&Path>,
154    workspace_root: Option<&Path>,
155    current_source: Option<&str>,
156) -> Option<ParsedModuleSchema> {
157    let spec = declared_extension_spec_for_module(
158        module_name,
159        current_file,
160        workspace_root,
161        current_source,
162    )?;
163    extension_module_schema_for_spec(&spec)
164}
165
166/// Register declared extension module artifacts into the given module loader.
167pub fn register_declared_extensions_in_loader(
168    loader: &mut ModuleLoader,
169    current_file: Option<&Path>,
170    workspace_root: Option<&Path>,
171    current_source: Option<&str>,
172) {
173    for spec in declared_extension_specs_for_context(current_file, workspace_root, current_source) {
174        let Some(schema) = extension_module_schema_for_spec(&spec) else {
175            continue;
176        };
177        for artifact in schema.artifacts {
178            let code = match (artifact.source, artifact.compiled) {
179                (Some(source), Some(compiled)) => ModuleCode::Both {
180                    source: Arc::from(source.as_str()),
181                    compiled: Arc::from(compiled),
182                },
183                (Some(source), None) => ModuleCode::Source(Arc::from(source.as_str())),
184                (None, Some(compiled)) => ModuleCode::Compiled(Arc::from(compiled)),
185                (None, None) => continue,
186            };
187            loader.register_extension_module(artifact.module_path, code);
188        }
189    }
190}
191
#[cfg(test)]
mod tests {
    use super::*;

    /// Lay out a minimal project under `root`: `shape.toml` containing `manifest`,
    /// plus `src/main.shape` containing `entry_source`. Returns the entry file path.
    fn scaffold_project(root: &Path, manifest: &str, entry_source: &str) -> PathBuf {
        std::fs::create_dir_all(root.join("src")).expect("create src");
        std::fs::write(root.join("shape.toml"), manifest).expect("write shape.toml");
        let entry_file = root.join("src/main.shape");
        std::fs::write(&entry_file, entry_source).expect("write main");
        entry_file
    }

    #[test]
    fn test_declared_extension_spec_for_module_uses_project_config() {
        let tmp = tempfile::tempdir().expect("temp dir");
        let root = tmp.path();
        let entry_file = scaffold_project(
            root,
            r#"
[[extensions]]
name = "proj_ext_unique_for_test"
path = "./extensions/libproj.so"
"#,
            "use proj_ext_unique_for_test",
        );

        let found = declared_extension_spec_for_module(
            "proj_ext_unique_for_test",
            Some(entry_file.as_path()),
            None,
            None,
        );
        let spec = found.expect("project extension should be discovered");

        assert_eq!(spec.name, "proj_ext_unique_for_test");
        // Relative project paths resolve against the project root.
        assert_eq!(spec.path, root.join("extensions/libproj.so"));
    }

    #[test]
    fn test_declared_extension_specs_frontmatter_overrides_project() {
        let tmp = tempfile::tempdir().expect("temp dir");
        let root = tmp.path();
        let entry_file = scaffold_project(
            root,
            r#"
[[extensions]]
name = "duckdb"
path = "./project/libproject.so"
"#,
            "use duckdb",
        );
        let frontmatter_source = r#"---
[[extensions]]
name = "duckdb"
path = "./frontmatter/libfront.so"
---
use duckdb
"#;

        let spec = declared_extension_spec_for_module(
            "duckdb",
            Some(entry_file.as_path()),
            None,
            Some(frontmatter_source),
        )
        .expect("frontmatter extension should be discovered");

        // Relative frontmatter paths resolve against the entry file's directory.
        assert_eq!(spec.path, root.join("src/frontmatter/libfront.so"));
    }
}