Skip to main content

systemprompt_loader/
extension_loader.rs

1use anyhow::{Context, Result};
2use std::collections::HashMap;
3use std::fs;
4use std::path::Path;
5
6use systemprompt_models::{DiscoveredExtension, ExtensionManifest};
7
/// Name of the build output directory that is joined onto the project root
/// (followed by `release` or `debug`) when looking for compiled binaries.
const CARGO_TARGET: &str = "target";
9
/// Stateless entry point for extension discovery and binary lookup.
///
/// The type carries no data; every operation is an associated function that
/// takes the project root explicitly.
#[derive(Debug, Clone, Copy)]
pub struct ExtensionLoader;
12
13impl ExtensionLoader {
14    pub fn discover(project_root: &Path) -> Vec<DiscoveredExtension> {
15        let extensions_dir = project_root.join("extensions");
16
17        if !extensions_dir.exists() {
18            return vec![];
19        }
20
21        let mut discovered = vec![];
22
23        Self::scan_directory(&extensions_dir, &mut discovered);
24
25        if let Ok(entries) = fs::read_dir(&extensions_dir) {
26            for entry in entries.flatten() {
27                let path = entry.path();
28                if path.is_dir() {
29                    Self::scan_directory(&path, &mut discovered);
30                }
31            }
32        }
33
34        discovered
35    }
36
37    fn scan_directory(dir: &Path, discovered: &mut Vec<DiscoveredExtension>) {
38        let Ok(entries) = fs::read_dir(dir) else {
39            return;
40        };
41
42        for entry in entries.flatten() {
43            let ext_dir = entry.path();
44            if !ext_dir.is_dir() {
45                continue;
46            }
47
48            let manifest_path = ext_dir.join("manifest.yaml");
49            if manifest_path.exists() {
50                match Self::load_manifest(&manifest_path) {
51                    Ok(manifest) => {
52                        discovered.push(DiscoveredExtension::new(manifest, ext_dir, manifest_path));
53                    },
54                    Err(e) => {
55                        tracing::warn!(
56                            path = %manifest_path.display(),
57                            error = %e,
58                            "Failed to parse extension manifest, skipping"
59                        );
60                    },
61                }
62            }
63        }
64    }
65
66    fn load_manifest(path: &Path) -> Result<ExtensionManifest> {
67        let content = fs::read_to_string(path)
68            .with_context(|| format!("Failed to read manifest: {}", path.display()))?;
69
70        serde_yaml::from_str(&content)
71            .with_context(|| format!("Failed to parse manifest: {}", path.display()))
72    }
73
74    pub fn get_enabled_mcp_extensions(project_root: &Path) -> Vec<DiscoveredExtension> {
75        Self::discover(project_root)
76            .into_iter()
77            .filter(|e| e.is_mcp() && e.is_enabled())
78            .collect()
79    }
80
81    pub fn get_enabled_cli_extensions(project_root: &Path) -> Vec<DiscoveredExtension> {
82        Self::discover(project_root)
83            .into_iter()
84            .filter(|e| e.is_cli() && e.is_enabled())
85            .collect()
86    }
87
88    pub fn find_cli_extension(project_root: &Path, name: &str) -> Option<DiscoveredExtension> {
89        Self::get_enabled_cli_extensions(project_root)
90            .into_iter()
91            .find(|e| {
92                e.binary_name()
93                    .is_some_and(|b| b == name || e.manifest.extension.name == name)
94            })
95    }
96
97    pub fn get_cli_binary_path(
98        project_root: &Path,
99        binary_name: &str,
100    ) -> Option<std::path::PathBuf> {
101        let release_path = project_root
102            .join(CARGO_TARGET)
103            .join("release")
104            .join(binary_name);
105        if release_path.exists() {
106            return Some(release_path);
107        }
108
109        let debug_path = project_root
110            .join(CARGO_TARGET)
111            .join("debug")
112            .join(binary_name);
113        if debug_path.exists() {
114            return Some(debug_path);
115        }
116
117        None
118    }
119
120    pub fn resolve_bin_directory(
121        project_root: &Path,
122        override_path: Option<&Path>,
123    ) -> std::path::PathBuf {
124        if let Some(path) = override_path {
125            return path.to_path_buf();
126        }
127
128        let release_dir = project_root.join(CARGO_TARGET).join("release");
129        let debug_dir = project_root.join(CARGO_TARGET).join("debug");
130
131        let release_binary = release_dir.join("systemprompt");
132        let debug_binary = debug_dir.join("systemprompt");
133
134        match (release_binary.exists(), debug_binary.exists()) {
135            (true, true) => {
136                let release_mtime = fs::metadata(&release_binary)
137                    .and_then(|m| m.modified())
138                    .ok();
139                let debug_mtime = fs::metadata(&debug_binary).and_then(|m| m.modified()).ok();
140
141                match (release_mtime, debug_mtime) {
142                    (Some(r), Some(d)) if d > r => debug_dir,
143                    _ => release_dir,
144                }
145            },
146            (true | false, false) => release_dir,
147            (false, true) => debug_dir,
148        }
149    }
150
151    pub fn validate_mcp_binaries(project_root: &Path) -> Vec<(String, std::path::PathBuf)> {
152        let extensions = Self::get_enabled_mcp_extensions(project_root);
153        let target_dir = project_root.join(CARGO_TARGET).join("release");
154
155        extensions
156            .into_iter()
157            .filter_map(|ext| {
158                ext.binary_name().and_then(|binary| {
159                    let binary_path = target_dir.join(binary);
160                    if binary_path.exists() {
161                        None
162                    } else {
163                        Some((binary.to_string(), ext.path.clone()))
164                    }
165                })
166            })
167            .collect()
168    }
169
170    pub fn get_mcp_binary_names(project_root: &Path) -> Vec<String> {
171        Self::get_enabled_mcp_extensions(project_root)
172            .iter()
173            .filter_map(|e| e.binary_name().map(String::from))
174            .collect()
175    }
176
177    pub fn get_production_mcp_binary_names(
178        project_root: &Path,
179        services_config: &systemprompt_models::ServicesConfig,
180    ) -> Vec<String> {
181        Self::get_enabled_mcp_extensions(project_root)
182            .iter()
183            .filter_map(|e| {
184                let binary = e.binary_name()?;
185                let is_dev_only = services_config
186                    .mcp_servers
187                    .values()
188                    .find(|d| d.binary == binary)
189                    .is_some_and(|d| d.dev_only);
190                (!is_dev_only).then(|| binary.to_string())
191            })
192            .collect()
193    }
194
195    pub fn build_binary_map(project_root: &Path) -> HashMap<String, DiscoveredExtension> {
196        Self::discover(project_root)
197            .into_iter()
198            .filter_map(|ext| {
199                let name = ext.binary_name()?.to_string();
200                Some((name, ext))
201            })
202            .collect()
203    }
204
205    pub fn validate(project_root: &Path) -> ExtensionValidationResult {
206        ExtensionValidationResult {
207            discovered: Self::discover(project_root),
208            missing_binaries: Self::validate_mcp_binaries(project_root),
209            missing_manifests: vec![],
210        }
211    }
212}
213
/// Outcome of `ExtensionLoader::validate`.
#[derive(Debug)]
pub struct ExtensionValidationResult {
    /// Every extension found during discovery.
    pub discovered: Vec<DiscoveredExtension>,
    /// `(binary name, extension path)` for each enabled MCP extension whose
    /// binary was not found.
    pub missing_binaries: Vec<(String, std::path::PathBuf)>,
    /// Paths recorded for missing manifests. `ExtensionLoader::validate`
    /// always leaves this empty.
    pub missing_manifests: Vec<std::path::PathBuf>,
}
220
221impl ExtensionValidationResult {
222    pub fn is_valid(&self) -> bool {
223        self.missing_binaries.is_empty()
224    }
225
226    pub fn format_missing_binaries(&self) -> String {
227        self.missing_binaries
228            .iter()
229            .map(|(binary, path)| format!("  ✗ {} ({})", binary, path.display()))
230            .collect::<Vec<_>>()
231            .join("\n")
232    }
233}