Skip to main content

kapsl_rag/extension/
mod.rs

1use std::collections::HashMap;
2use std::fs;
3use std::io;
4use std::path::{Path, PathBuf};
5
6use kapsl_rag_sdk::manifest::{ConnectorManifest, ConnectorRuntime as ManifestRuntime};
7use kapsl_rag_sdk::types::ConnectorConfig;
8use serde::Deserialize;
9use serde_json::Value;
10
11use crate::runtime::{
12    ConnectorClient, ConnectorRuntime as RuntimeTrait, SidecarConnectorRuntime, WasiPermissions,
13    WasmConnectorRuntime,
14};
15
/// Errors returned by the extension registry, the extension manager, and the
/// helper functions in this module.
///
/// Every variant carries a `String`; the `#[error]` attribute on each variant
/// defines how it is rendered by `Display`.
#[derive(thiserror::Error, Debug)]
pub enum ExtensionError {
    /// A filesystem operation failed; holds the stringified `io::Error`.
    #[error("io error: {0}")]
    Io(String),
    /// No manifest file was found; holds the directory that was searched.
    #[error("manifest not found in {0}")]
    ManifestMissing(String),
    /// A manifest could not be parsed; holds the parser's message.
    #[error("invalid manifest: {0}")]
    InvalidManifest(String),
    /// A configuration value was rejected; holds a description of the problem.
    #[error("invalid config: {0}")]
    InvalidConfig(String),
    /// The named extension is not present in the registry; holds the extension id.
    #[error("extension not installed: {0}")]
    NotInstalled(String),
    /// A connector runtime failed to start; holds the runtime's error message.
    #[error("runtime error: {0}")]
    Runtime(String),
}
31
32impl From<io::Error> for ExtensionError {
33    fn from(err: io::Error) -> Self {
34        ExtensionError::Io(err.to_string())
35    }
36}
37
/// An extension present on disk: its parsed manifest plus the directory that
/// contains it.
#[derive(Debug, Clone)]
pub struct InstalledExtension {
    /// Manifest loaded from the extension directory.
    pub manifest: ConnectorManifest,
    /// Directory the extension lives in; relative entrypoints are resolved against it.
    pub path: PathBuf,
}
43
/// On-disk store of installed extensions.
#[derive(Debug, Clone)]
pub struct ExtensionRegistry {
    /// Directory whose immediate sub-directories each hold one installed
    /// extension, named after the extension's manifest id.
    pub root: PathBuf,
}
48
49impl ExtensionRegistry {
50    pub fn new(root: impl Into<PathBuf>) -> Self {
51        Self { root: root.into() }
52    }
53
54    pub fn discover(&self) -> Result<Vec<InstalledExtension>, ExtensionError> {
55        let mut extensions = Vec::new();
56        if !self.root.exists() {
57            return Ok(extensions);
58        }
59        for entry in fs::read_dir(&self.root)? {
60            let entry = entry?;
61            let path = entry.path();
62            if !path.is_dir() {
63                continue;
64            }
65            if let Ok(manifest) = load_manifest(&path) {
66                extensions.push(InstalledExtension { manifest, path });
67            }
68        }
69        Ok(extensions)
70    }
71
72    pub fn install_from_dir(&self, source: &Path) -> Result<InstalledExtension, ExtensionError> {
73        let manifest = load_manifest(source)?;
74        let target = self.root.join(&manifest.id);
75        if target.exists() {
76            fs::remove_dir_all(&target)?;
77        }
78        copy_dir_all(source, &target)?;
79        Ok(InstalledExtension {
80            manifest,
81            path: target,
82        })
83    }
84
85    pub fn uninstall(&self, extension_id: &str) -> Result<(), ExtensionError> {
86        let target = self.root.join(extension_id);
87        if !target.exists() {
88            return Err(ExtensionError::NotInstalled(extension_id.to_string()));
89        }
90        fs::remove_dir_all(target)?;
91        Ok(())
92    }
93}
94
/// Combines the extension registry with per-workspace connector configuration
/// and launches connectors.
#[derive(Debug, Clone)]
pub struct ExtensionManager {
    /// Registry of installed extensions.
    pub registry: ExtensionRegistry,
    /// Directory holding workspace configs, laid out as
    /// `<config_root>/<workspace_id>/<extension_id>.json`.
    pub config_root: PathBuf,
}
100
101impl ExtensionManager {
102    pub fn new(registry: ExtensionRegistry, config_root: impl Into<PathBuf>) -> Self {
103        Self {
104            registry,
105            config_root: config_root.into(),
106        }
107    }
108
109    pub fn set_workspace_config(
110        &self,
111        workspace_id: &str,
112        extension_id: &str,
113        config: &ConnectorConfig,
114    ) -> Result<(), ExtensionError> {
115        let dir = self.config_root.join(workspace_id);
116        fs::create_dir_all(&dir)?;
117        let path = dir.join(format!("{extension_id}.json"));
118        let data = serde_json::to_vec_pretty(config)
119            .map_err(|e| ExtensionError::InvalidManifest(e.to_string()))?;
120        fs::write(path, data)?;
121        Ok(())
122    }
123
124    pub fn get_workspace_config(
125        &self,
126        workspace_id: &str,
127        extension_id: &str,
128    ) -> Result<Option<ConnectorConfig>, ExtensionError> {
129        let path = self
130            .config_root
131            .join(workspace_id)
132            .join(format!("{extension_id}.json"));
133        if !path.exists() {
134            return Ok(None);
135        }
136        let data = fs::read_to_string(path)?;
137        let config = serde_json::from_str(&data)
138            .map_err(|e| ExtensionError::InvalidManifest(e.to_string()))?;
139        Ok(Some(config))
140    }
141
142    pub fn get_workspace_connector_config(
143        &self,
144        workspace_id: &str,
145        extension_id: &str,
146    ) -> Result<Option<ConnectorConfig>, ExtensionError> {
147        let config = self.get_workspace_config(workspace_id, extension_id)?;
148        Ok(config.map(strip_wasi_block))
149    }
150
151    pub fn list_configs(
152        &self,
153        workspace_id: &str,
154    ) -> Result<HashMap<String, ConnectorConfig>, ExtensionError> {
155        let mut configs = HashMap::new();
156        let dir = self.config_root.join(workspace_id);
157        if !dir.exists() {
158            return Ok(configs);
159        }
160        for entry in fs::read_dir(&dir)? {
161            let entry = entry?;
162            let path = entry.path();
163            if path.extension().and_then(|s| s.to_str()) != Some("json") {
164                continue;
165            }
166            if let Some(stem) = path.file_stem().and_then(|s| s.to_str()) {
167                let data = fs::read_to_string(&path)?;
168                let config = serde_json::from_str(&data)
169                    .map_err(|e| ExtensionError::InvalidManifest(e.to_string()))?;
170                configs.insert(stem.to_string(), config);
171            }
172        }
173        Ok(configs)
174    }
175
176    pub fn get_workspace_wasi_permissions(
177        &self,
178        workspace_id: &str,
179        extension_id: &str,
180    ) -> Result<WasiPermissions, ExtensionError> {
181        let config = self.get_workspace_config(workspace_id, extension_id)?;
182        wasi_permissions_from_config(config.as_ref())
183    }
184
185    pub fn launch_connector(
186        &self,
187        workspace_id: &str,
188        extension: &InstalledExtension,
189    ) -> Result<ConnectorClient<ConnectorRuntimeHandle>, ExtensionError> {
190        let entrypoint = resolve_entrypoint(extension)?;
191        let runtime = match extension.manifest.runtime {
192            ManifestRuntime::Wasm => {
193                let permissions =
194                    self.get_workspace_wasi_permissions(workspace_id, &extension.manifest.id)?;
195                ConnectorRuntimeHandle::Wasm(
196                    WasmConnectorRuntime::spawn_with_permissions(&entrypoint, permissions)
197                        .map_err(|e| ExtensionError::Runtime(e.to_string()))?,
198                )
199            }
200            ManifestRuntime::Sidecar => {
201                let runtime = SidecarConnectorRuntime::spawn(&entrypoint)
202                    .map_err(|e| ExtensionError::Runtime(e.to_string()))?;
203                ConnectorRuntimeHandle::Sidecar(runtime)
204            }
205        };
206        Ok(ConnectorClient::new(runtime))
207    }
208}
209
/// A running connector runtime of either supported kind, so callers can hold
/// one type regardless of which runtime the manifest selected.
pub enum ConnectorRuntimeHandle {
    /// Connector running in the Wasm runtime.
    Wasm(WasmConnectorRuntime),
    /// Connector running in the sidecar runtime.
    Sidecar(SidecarConnectorRuntime),
}
214
215impl RuntimeTrait for ConnectorRuntimeHandle {
216    fn send(
217        &mut self,
218        request: kapsl_rag_sdk::protocol::ConnectorRequest,
219    ) -> Result<kapsl_rag_sdk::protocol::ConnectorResponse, crate::runtime::RuntimeError> {
220        match self {
221            ConnectorRuntimeHandle::Wasm(runtime) => runtime.send(request),
222            ConnectorRuntimeHandle::Sidecar(runtime) => runtime.send(request),
223        }
224    }
225
226    fn close(&mut self) -> Result<(), crate::runtime::RuntimeError> {
227        match self {
228            ConnectorRuntimeHandle::Wasm(runtime) => runtime.close(),
229            ConnectorRuntimeHandle::Sidecar(runtime) => runtime.close(),
230        }
231    }
232}
233
/// Shape of the `wasi` object inside a connector config. Both fields may be
/// omitted and then default to empty.
#[derive(Debug, Deserialize, Default)]
struct WasiConfig {
    /// Environment variables to expose to the connector (name -> value).
    #[serde(default)]
    env: HashMap<String, String>,
    /// Host directories to make available to the connector.
    #[serde(default)]
    preopen_dirs: Vec<WasiDirConfig>,
}
241
/// One entry of `wasi.preopen_dirs`: a host directory and the path under which
/// the guest sees it.
#[derive(Debug, Deserialize)]
struct WasiDirConfig {
    /// Directory on the host; validated to be an absolute path before use.
    host_path: String,
    /// Path inside the guest; validated to start with `/` before use.
    guest_path: String,
    /// Forwarded as the read-only flag when the directory is granted.
    /// Defaults to `false` when omitted.
    // NOTE(review): the default presumably makes the directory writable
    // unless the config opts out — confirm this is the intended default.
    #[serde(default)]
    read_only: bool,
}
249
250fn wasi_permissions_from_config(
251    config: Option<&ConnectorConfig>,
252) -> Result<WasiPermissions, ExtensionError> {
253    let Some(config) = config else {
254        return Ok(WasiPermissions::default());
255    };
256    let obj = match config {
257        serde_json::Value::Object(_) => config,
258        _ => return Ok(WasiPermissions::default()),
259    };
260
261    let wasi_value = obj.get("wasi");
262    if wasi_value.is_none() {
263        return Ok(WasiPermissions::default());
264    }
265    let wasi_value = wasi_value.unwrap();
266    let parsed: WasiConfig = serde_json::from_value(wasi_value.clone())
267        .map_err(|e| ExtensionError::InvalidConfig(e.to_string()))?;
268
269    let mut permissions = WasiPermissions::default();
270    for (key, value) in parsed.env {
271        validate_env_kv(&key, &value)?;
272        permissions = permissions.with_env(key, value);
273    }
274
275    for dir in parsed.preopen_dirs {
276        validate_host_path(&dir.host_path)?;
277        validate_guest_path(&dir.guest_path)?;
278        permissions =
279            permissions.allow_dir(PathBuf::from(dir.host_path), dir.guest_path, dir.read_only);
280    }
281
282    Ok(permissions)
283}
284
285fn strip_wasi_block(config: ConnectorConfig) -> ConnectorConfig {
286    match config {
287        Value::Object(mut map) => {
288            map.remove("wasi");
289            Value::Object(map)
290        }
291        other => other,
292    }
293}
294
295fn validate_env_kv(key: &str, value: &str) -> Result<(), ExtensionError> {
296    if key.is_empty() {
297        return Err(ExtensionError::InvalidConfig(
298            "env key cannot be empty".to_string(),
299        ));
300    }
301    if key.contains('\0') || value.contains('\0') {
302        return Err(ExtensionError::InvalidConfig(
303            "env key/value cannot contain NUL".to_string(),
304        ));
305    }
306    Ok(())
307}
308
309fn validate_guest_path(path: &str) -> Result<(), ExtensionError> {
310    if path.is_empty() || !path.starts_with('/') {
311        return Err(ExtensionError::InvalidConfig(
312            "preopened guest path must be absolute".to_string(),
313        ));
314    }
315    if path.contains('\0') {
316        return Err(ExtensionError::InvalidConfig(
317            "preopened guest path cannot contain NUL".to_string(),
318        ));
319    }
320    Ok(())
321}
322
323fn validate_host_path(path: &str) -> Result<(), ExtensionError> {
324    let host_path = Path::new(path);
325    if !host_path.is_absolute() {
326        return Err(ExtensionError::InvalidConfig(
327            "preopened host path must be absolute".to_string(),
328        ));
329    }
330    Ok(())
331}
332
333fn load_manifest(dir: &Path) -> Result<ConnectorManifest, ExtensionError> {
334    let toml_path = dir.join("rag-extension.toml");
335    let json_path = dir.join("rag-extension.json");
336
337    if toml_path.exists() {
338        let data = fs::read_to_string(&toml_path)?;
339        let manifest =
340            toml::from_str(&data).map_err(|e| ExtensionError::InvalidManifest(e.to_string()))?;
341        return Ok(manifest);
342    }
343
344    if json_path.exists() {
345        let data = fs::read_to_string(&json_path)?;
346        let manifest = serde_json::from_str(&data)
347            .map_err(|e| ExtensionError::InvalidManifest(e.to_string()))?;
348        return Ok(manifest);
349    }
350
351    Err(ExtensionError::ManifestMissing(dir.display().to_string()))
352}
353
354fn resolve_entrypoint(extension: &InstalledExtension) -> Result<PathBuf, ExtensionError> {
355    let entry = extension.manifest.entrypoint.as_deref();
356    let runtime = &extension.manifest.runtime;
357    let default_entry = match runtime {
358        ManifestRuntime::Wasm => "connector.wasm",
359        ManifestRuntime::Sidecar => "connector",
360    };
361    let entry = entry.unwrap_or(default_entry);
362    let path = Path::new(entry);
363    let resolved = if path.is_absolute() {
364        path.to_path_buf()
365    } else {
366        extension.path.join(entry)
367    };
368    if !resolved.exists() {
369        return Err(ExtensionError::InvalidConfig(format!(
370            "entrypoint not found: {}",
371            resolved.display()
372        )));
373    }
374    Ok(resolved)
375}
376
377fn copy_dir_all(src: &Path, dst: &Path) -> Result<(), ExtensionError> {
378    fs::create_dir_all(dst)?;
379    for entry in fs::read_dir(src)? {
380        let entry = entry?;
381        let ty = entry.file_type()?;
382        let src_path = entry.path();
383        let dst_path = dst.join(entry.file_name());
384        if ty.is_dir() {
385            copy_dir_all(&src_path, &dst_path)?;
386        } else {
387            fs::copy(&src_path, &dst_path)?;
388        }
389    }
390    Ok(())
391}