use std::path::{Path, PathBuf};
/// Locations of the three prompt files that make up the system prompt.
///
/// The fields are public so callers (and tests) can point the loader at
/// arbitrary files instead of the conventional locations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SystemPromptLoader {
    /// Path of the user-wide ("global") prompt file.
    pub global_path: PathBuf,
    /// Path of the project-level prompt file.
    pub project_path: PathBuf,
    /// Path of the project-local override prompt file.
    pub local_path: PathBuf,
}
impl SystemPromptLoader {
pub fn new(cwd: &Path) -> Self {
let home = dirs::home_dir().unwrap_or_default();
Self {
global_path: home.join(".kx").join("CLAUDE.md"),
project_path: cwd.join("CLAUDE.md"),
local_path: cwd.join(".kx").join("CLAUDE.md"),
}
}
pub fn load(&self) -> String {
let candidates: &[&PathBuf] = &[&self.global_path, &self.project_path, &self.local_path];
let parts: Vec<String> = candidates
.iter()
.filter(|p| p.exists())
.filter_map(|p| {
let content = std::fs::read_to_string(p).ok()?;
let base_dir = p.parent().unwrap_or_else(|| Path::new("."));
Some(resolve_imports(&content, base_dir))
})
.collect();
parts.join("\n\n")
}
}
/// Largest file size, in bytes, accepted for an `@import` target (256 KiB).
const MAX_IMPORT_BYTES: u64 = 256 << 10;
/// Expands `@import <path>` directives found in `content`.
///
/// Each line whose trimmed text starts with `@import ` is replaced by the
/// contents of the referenced file, resolved relative to `base_dir`. A
/// directive whose file cannot be read (or is rejected) is dropped entirely.
/// All other lines are passed through untouched. The resulting lines are
/// joined with `\n`.
fn resolve_imports(content: &str, base_dir: &Path) -> String {
    content
        .lines()
        .filter_map(|line| match line.trim().strip_prefix("@import ") {
            // Directive: substitute the file body, or drop the line on failure.
            Some(target) => read_confined_import(target.trim(), base_dir),
            // Ordinary text: keep verbatim.
            None => Some(line.to_string()),
        })
        .collect::<Vec<String>>()
        .join("\n")
}
/// Reads the target of an `@import` directive, confined to `base_dir`.
///
/// `raw` is the path text taken from the directive. Returns the file's
/// contents, or `None` when the import must be dropped because:
///
/// * the path is absolute,
/// * the path (after resolving `..` and symlinks) lies outside `base_dir`,
/// * the target or `base_dir` cannot be canonicalized (e.g. does not exist),
/// * the target is not a regular file,
/// * the target is larger than [`MAX_IMPORT_BYTES`], or
/// * the target cannot be read or is not valid UTF-8.
fn read_confined_import(raw: &str, base_dir: &Path) -> Option<String> {
    let candidate = Path::new(raw);
    if candidate.is_absolute() {
        tracing::warn!(import = raw, "rejected @import: absolute path");
        return None;
    }

    // Compare canonical forms so `..` segments and symlinks cannot be used to
    // step outside the base directory.
    let canonical_target = std::fs::canonicalize(base_dir.join(candidate)).ok()?;
    let canonical_base = std::fs::canonicalize(base_dir).ok()?;
    if !canonical_target.starts_with(&canonical_base) {
        tracing::warn!(
            import = raw,
            "rejected @import: resolves outside base directory"
        );
        return None;
    }

    let metadata = std::fs::metadata(&canonical_target).ok()?;
    // Directories, FIFOs and devices have no meaningful length and may block
    // or never reach end-of-file when read; refuse them before opening.
    if !metadata.is_file() {
        tracing::warn!(import = raw, "rejected @import: not a regular file");
        return None;
    }
    if metadata.len() > MAX_IMPORT_BYTES {
        tracing::warn!(
            import = raw,
            size = metadata.len(),
            "rejected @import: file exceeds size cap"
        );
        return None;
    }

    // The length reported above may be stale by the time the file is read, so
    // the read itself is bounded: at most one byte past the cap is pulled in,
    // which is enough to detect an oversized file.
    let file = std::fs::File::open(&canonical_target).ok()?;
    let mut limited = std::io::Read::take(file, MAX_IMPORT_BYTES + 1);
    let mut bytes: Vec<u8> = Vec::new();
    std::io::Read::read_to_end(&mut limited, &mut bytes).ok()?;
    let over_cap = u64::try_from(bytes.len()).map_or(true, |len| len > MAX_IMPORT_BYTES);
    if over_cap {
        tracing::warn!(
            import = raw,
            size = bytes.len(),
            "rejected @import: file exceeds size cap"
        );
        return None;
    }

    String::from_utf8(bytes).ok()
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;

    /// Creates a fresh temporary directory. The returned handle must be kept
    /// alive for as long as the directory is in use.
    fn scratch_dir() -> tempfile::TempDir {
        tempfile::tempdir().unwrap()
    }

    /// Builds a loader whose three sources are the named files directly
    /// under `root`.
    fn loader_in(root: &Path, global: &str, project: &str, local: &str) -> SystemPromptLoader {
        SystemPromptLoader {
            global_path: root.join(global),
            project_path: root.join(project),
            local_path: root.join(local),
        }
    }

    // ---- SystemPromptLoader::load -------------------------------------

    #[test]
    fn load_no_files_returns_empty() {
        let dir = scratch_dir();
        let loader = loader_in(dir.path(), "g.md", "p.md", "l.md");
        assert!(loader.load().is_empty());
    }

    #[test]
    fn load_single_project_file() {
        let dir = scratch_dir();
        fs::write(dir.path().join("CLAUDE.md"), "Be helpful.").unwrap();
        let loader = loader_in(dir.path(), "missing_g.md", "CLAUDE.md", "missing_l.md");
        assert_eq!(loader.load(), "Be helpful.");
    }

    #[test]
    fn load_merges_all_three_files() {
        let dir = scratch_dir();
        let sources = [
            ("global.md", "Global rules."),
            ("project.md", "Project rules."),
            ("local.md", "Local rules."),
        ];
        for (name, body) in &sources {
            fs::write(dir.path().join(*name), *body).unwrap();
        }
        let merged = loader_in(dir.path(), "global.md", "project.md", "local.md").load();
        for (_, body) in &sources {
            assert!(merged.contains(*body));
        }
    }

    #[test]
    fn load_order_global_before_project_before_local() {
        let dir = scratch_dir();
        fs::write(dir.path().join("a.md"), "first").unwrap();
        fs::write(dir.path().join("b.md"), "second").unwrap();
        fs::write(dir.path().join("c.md"), "third").unwrap();
        let merged = loader_in(dir.path(), "a.md", "b.md", "c.md").load();
        // A marker that is absent maps to usize::MAX so the ordering
        // assertions below fail instead of panicking on unwrap.
        let offset_of = |needle: &str| merged.find(needle).unwrap_or(usize::MAX);
        assert!(offset_of("first") < offset_of("second"));
        assert!(offset_of("second") < offset_of("third"));
    }

    #[test]
    fn load_skips_missing_files_silently() {
        let dir = scratch_dir();
        fs::write(dir.path().join("project.md"), "only project").unwrap();
        let loader = loader_in(dir.path(), "missing_g.md", "project.md", "missing_l.md");
        assert_eq!(loader.load(), "only project");
    }

    // ---- resolve_imports ----------------------------------------------

    #[test]
    fn resolve_imports_expands_directive() {
        let dir = scratch_dir();
        fs::write(dir.path().join("rules.md"), "imported content").unwrap();
        let expanded = resolve_imports("before\n@import rules.md\nafter", dir.path());
        for expected in &["imported content", "before", "after"] {
            assert!(expanded.contains(*expected));
        }
    }

    #[test]
    fn resolve_imports_missing_file_drops_line() {
        let dir = scratch_dir();
        let expanded = resolve_imports("@import nonexistent.md\nstays", dir.path());
        assert!(expanded.contains("stays"));
        assert!(!expanded.contains("@import"));
    }

    #[test]
    fn resolve_imports_plain_lines_unchanged() {
        let dir = scratch_dir();
        let text = "line one\nline two\nline three";
        assert_eq!(resolve_imports(text, dir.path()), text);
    }

    #[test]
    fn resolve_imports_indented_directive() {
        let dir = scratch_dir();
        fs::write(dir.path().join("extra.md"), "extra rules").unwrap();
        let expanded = resolve_imports(" @import extra.md", dir.path());
        assert!(expanded.contains("extra rules"));
    }

    #[test]
    fn resolve_imports_rejects_absolute_path() {
        let dir = scratch_dir();
        let expanded = resolve_imports("before\n@import /etc/passwd\nafter", dir.path());
        assert!(expanded.contains("before"));
        assert!(expanded.contains("after"));
        assert!(!expanded.contains("root:"));
    }

    #[test]
    fn resolve_imports_rejects_parent_traversal() {
        let outer = scratch_dir();
        let inner = outer.path().join("inner");
        fs::create_dir_all(&inner).unwrap();
        fs::write(outer.path().join("secret.md"), "should not leak").unwrap();
        fs::write(inner.join("ok.md"), "fine").unwrap();
        let expanded = resolve_imports("@import ../secret.md\n@import ok.md", &inner);
        assert!(!expanded.contains("should not leak"));
        assert!(expanded.contains("fine"));
    }

    #[test]
    fn resolve_imports_drops_oversized_file() {
        let dir = scratch_dir();
        // One byte over the cap is enough to trigger rejection.
        let oversized = "X".repeat((MAX_IMPORT_BYTES as usize) + 1);
        fs::write(dir.path().join("big.md"), &oversized).unwrap();
        let expanded = resolve_imports("head\n@import big.md\ntail", dir.path());
        assert!(expanded.contains("head"));
        assert!(expanded.contains("tail"));
        assert!(!expanded.contains("XXXX"));
    }
}