use std::collections::{HashMap, HashSet};
use std::fs;
use std::path::{Path, PathBuf};
use super::Extends;
use super::cache::ConfigCache;
use crate::config::{Config, EnvConfig, HooksConfig, ServicesConfig, ToolConfig, ToolHooks};
/// Limit on nested `extends` levels: the resolver refuses to go more than this
/// many levels below the root config.
pub const MAX_DEPTH: usize = 10;

/// Result alias used throughout config inheritance.
pub type Result<T> = std::result::Result<T, InheritanceError>;

/// Everything that can go wrong while resolving a config's `extends` chain.
#[derive(Debug)]
pub enum InheritanceError {
    /// The `extends` chain went deeper than allowed; `path` is the chain that was followed.
    MaxDepthExceeded { path: Vec<String>, depth: usize },
    /// A config (transitively) extends itself; holds the sources forming the loop.
    CircularDependency(Vec<String>),
    /// A local config file could not be read (any I/O error is reported here).
    FileNotFound { path: String, error: String },
    /// A remote config could not be obtained from the network or cache.
    FetchFailed { url: String, error: String },
    /// The content could not be parsed as an `ExtendedConfig`.
    InvalidToml { source: String, error: String },
}

impl std::fmt::Display for InheritanceError {
    /// Renders a user-facing message; chains are joined with ` -> `.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let arrow = " -> ";
        match self {
            Self::MaxDepthExceeded { path, depth } => write!(
                f,
                "Maximum inheritance depth ({}) exceeded.\nChain: {}",
                depth,
                path.join(arrow)
            ),
            Self::CircularDependency(cycle) => {
                let rendered = cycle.join(arrow);
                write!(f, "Circular dependency detected: {}", rendered)
            }
            Self::FileNotFound { path, error } => {
                write!(f, "Config file not found: {} ({})", path, error)
            }
            Self::FetchFailed { url, error } => write!(
                f,
                "Failed to fetch remote config: {}\nError: {}\nHint: Try --offline to use cached config",
                url, error
            ),
            Self::InvalidToml { source, error } => {
                write!(f, "Invalid TOML in '{}': {}", source, error)
            }
        }
    }
}

impl std::error::Error for InheritanceError {}
/// One config visited while resolving an inheritance chain.
#[derive(Debug, Clone)]
pub struct TraceEntry {
    /// Normalized source (path or URL) that was loaded.
    pub source: String,
    /// Nesting level of this config; the root config is at 0.
    pub depth: usize,
    /// Whether the content was served from the remote-config cache.
    pub from_cache: bool,
    /// The `extends` entries exactly as written in this config.
    pub parents: Vec<String>,
}

/// Record of what a single resolution loaded and where it came from.
#[derive(Debug, Default)]
pub struct ResolutionTrace {
    /// Visited configs in load order (a config is recorded before its parents).
    pub entries: Vec<TraceEntry>,
    /// Deepest nesting level reached, counting the root as 1.
    pub max_depth: usize,
    /// URLs that were fetched over the network.
    pub network_fetches: Vec<String>,
    /// Sources that were served from the cache.
    pub cache_hits: Vec<String>,
}

impl ResolutionTrace {
    /// Renders the trace as one line per entry, indented by one space per
    /// nesting level and suffixed with ` (cached)` for cache hits.
    pub fn display_tree(&self) -> String {
        self.entries
            .iter()
            .map(|item| {
                let suffix = match item.from_cache {
                    true => " (cached)",
                    false => "",
                };
                format!("{}{}{}\n", " ".repeat(item.depth), item.source, suffix)
            })
            .collect()
    }
}
/// A single config document as written, before its `extends` chain is flattened.
///
/// Every section is optional in the source TOML (`#[serde(default)]`).
#[derive(Debug, serde::Deserialize, serde::Serialize)]
pub struct ExtendedConfig {
    /// Parent config source(s) named by the top-level `extends` key.
    #[serde(default)]
    pub extends: Option<Extends>,
    /// Tool entries, read from and written to the `[provisioner]` table.
    #[serde(rename = "provisioner", default)]
    pub tools: HashMap<String, ToolConfig>,
    #[serde(default)]
    pub hooks: HooksConfig,
    #[serde(default)]
    pub env: EnvConfig,
    #[serde(default)]
    pub services: ServicesConfig,
}

impl ExtendedConfig {
    /// Parses TOML text into an `ExtendedConfig` without resolving `extends`.
    ///
    /// # Errors
    /// Returns [`InheritanceError::InvalidToml`] (source labeled `<string>`) if
    /// the text does not deserialize.
    pub fn from_str(content: &str) -> Result<Self> {
        toml::from_str(content).map_err(|e| InheritanceError::InvalidToml {
            source: "<string>".to_string(),
            error: e.to_string(),
        })
    }

    /// Converts into a [`Config`] holding this config's tools under `[provisioner]`.
    ///
    /// Only `tools` is carried over; `hooks`, `env` and `services` are dropped by
    /// this conversion. NOTE(review): confirm that is intended.
    ///
    /// The tools are routed through serde (`ToolConfig`'s own `Serialize` into a
    /// `toml::Value`, then `Config`'s `Deserialize`) rather than hand-formatted
    /// TOML text. Hand-formatting did not escape quotes or backslashes and did not
    /// quote keys, and wrote non-numeric option values unquoted. Any such input
    /// produced invalid TOML, and the fallback below then silently discarded
    /// every tool.
    ///
    /// If the conversion still fails, an empty `[provisioner]` config is
    /// returned, as before.
    pub fn into_config(self) -> Config {
        /// Document shape holding only the `[provisioner]` table.
        #[derive(serde::Serialize)]
        struct ProvisionerDoc {
            provisioner: HashMap<String, ToolConfig>,
        }

        let document = ProvisionerDoc {
            provisioner: self.tools,
        };
        toml::Value::try_from(document)
            .ok()
            .and_then(|value| value.try_into::<Config>().ok())
            .unwrap_or_else(|| {
                toml::from_str("[provisioner]\n")
                    .expect("an empty [provisioner] table always parses")
            })
    }
}
/// Resolves a config source and everything it transitively `extends` into a
/// single flattened [`ExtendedConfig`].
///
/// Local root sources are joined onto `base_dir`. Parents are resolved
/// relative to the config that names them. URL sources are served from
/// [`ConfigCache`] when possible and otherwise fetched over HTTP, unless
/// offline mode is enabled.
pub struct InheritanceResolver {
    /// Cache consulted for URL sources.
    cache: ConfigCache,
    /// Sources currently being resolved, for O(1) cycle checks.
    in_progress: HashSet<String>,
    /// The same sources as `in_progress`, root first. Cycle and depth errors
    /// use it to report the chain in the order it was followed (a `HashSet`
    /// has no meaningful order).
    chain: Vec<String>,
    /// NOTE(review): not read anywhere in this file.
    resolved_cache: HashMap<String, ExtendedConfig>,
    /// Directory that relative root sources are joined onto.
    base_dir: PathBuf,
    /// When true, URL sources are served only from the cache (fresh or stale).
    offline: bool,
    /// Trace of the most recent resolution.
    trace: ResolutionTrace,
}

impl InheritanceResolver {
    /// Creates a resolver with a fresh cache and online fetching enabled.
    ///
    /// The base directory is the current working directory, or `.` if that
    /// cannot be determined.
    pub fn new() -> Self {
        Self {
            cache: ConfigCache::new(),
            in_progress: HashSet::new(),
            chain: Vec::new(),
            resolved_cache: HashMap::new(),
            base_dir: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
            offline: false,
            trace: ResolutionTrace::default(),
        }
    }

    /// Sets the directory relative root sources are resolved against.
    pub fn with_base_dir(mut self, dir: PathBuf) -> Self {
        self.base_dir = dir;
        self
    }

    /// Enables or disables offline mode.
    ///
    /// When enabled, remote configs come only from the cache, falling back to
    /// stale entries.
    pub fn offline(mut self, enabled: bool) -> Self {
        self.offline = enabled;
        self
    }

    /// Replaces the cache used for remote configs.
    pub fn with_cache(mut self, cache: ConfigCache) -> Self {
        self.cache = cache;
        self
    }

    /// Resolves `path` and its whole `extends` chain into one flattened config.
    ///
    /// Each config is merged on top of its parents with `merge_configs`. When a
    /// config lists several parents, later entries override earlier ones. The
    /// in-progress state and trace are reset first.
    ///
    /// # Errors
    /// Returns [`InheritanceError`] in any of these cases:
    /// - a config in the chain cannot be read, fetched or parsed;
    /// - the chain nests deeper than [`MAX_DEPTH`];
    /// - the chain contains a cycle.
    pub fn resolve(&mut self, path: &str) -> Result<ExtendedConfig> {
        self.in_progress.clear();
        self.chain.clear();
        self.trace = ResolutionTrace::default();
        self.resolve_recursive(path)
    }

    /// Like [`resolve`](Self::resolve), but also hands back (and clears) the trace.
    pub fn resolve_with_trace(&mut self, path: &str) -> Result<(ExtendedConfig, ResolutionTrace)> {
        let config = self.resolve(path)?;
        let trace = std::mem::take(&mut self.trace);
        Ok((config, trace))
    }

    /// Guards against cycles and excessive depth, then resolves `source`.
    ///
    /// The source stays on the in-progress stack only while it is being
    /// resolved; it is popped whether or not resolution succeeds.
    fn resolve_recursive(&mut self, source: &str) -> Result<ExtendedConfig> {
        let normalized = self.normalize_source(source);
        let depth = self.chain.len();

        if self.in_progress.contains(&normalized) {
            // Report just the loop, starting at the first visit of the repeated source.
            let start = self
                .chain
                .iter()
                .position(|seen| *seen == normalized)
                .unwrap_or(0);
            let mut cycle = self.chain[start..].to_vec();
            cycle.push(normalized);
            return Err(InheritanceError::CircularDependency(cycle));
        }
        if depth > MAX_DEPTH {
            let mut path = self.chain.clone();
            path.push(normalized);
            // `depth` is rendered as "Maximum inheritance depth (N)", so report the limit.
            return Err(InheritanceError::MaxDepthExceeded {
                path,
                depth: MAX_DEPTH,
            });
        }

        self.in_progress.insert(normalized.clone());
        self.chain.push(normalized.clone());
        let result = self.resolve_node(&normalized, depth);
        self.chain.pop();
        self.in_progress.remove(&normalized);
        result
    }

    /// Loads and parses one config, records it in the trace, and merges it on
    /// top of its resolved parents.
    ///
    /// `depth` is the number of configs above this one in the chain.
    fn resolve_node(&mut self, normalized: &str, depth: usize) -> Result<ExtendedConfig> {
        let (content, from_cache) = self.load_config(normalized)?;
        let config: ExtendedConfig =
            toml::from_str(&content).map_err(|e| InheritanceError::InvalidToml {
                source: normalized.to_string(),
                error: e.to_string(),
            })?;
        let parents: Vec<String> = config
            .extends
            .as_ref()
            .map(|e| e.as_vec().iter().map(|s| s.to_string()).collect())
            .unwrap_or_default();
        self.record_trace(normalized, depth, from_cache, &parents);

        // Fold parents left to right so a later `extends` entry overrides an earlier
        // one (and hooks run in listed order). Then lay this config on top so it
        // overrides all of them.
        let mut inherited: Option<ExtendedConfig> = None;
        for parent in &parents {
            let parent_source = self.resolve_relative_path(normalized, parent);
            let parent_config = self.resolve_recursive(&parent_source)?;
            inherited = Some(match inherited {
                Some(accumulated) => merge_configs(accumulated, parent_config),
                None => parent_config,
            });
        }
        Ok(match inherited {
            Some(base) => merge_configs(base, config),
            None => config,
        })
    }

    /// Appends a trace entry and updates the cache-hit, network-fetch and depth statistics.
    fn record_trace(&mut self, source: &str, depth: usize, from_cache: bool, parents: &[String]) {
        self.trace.entries.push(TraceEntry {
            source: source.to_string(),
            depth,
            from_cache,
            parents: parents.to_vec(),
        });
        if from_cache {
            self.trace.cache_hits.push(source.to_string());
        } else if is_url(source) {
            self.trace.network_fetches.push(source.to_string());
        }
        // `max_depth` counts levels with the root as 1, one more than the entry's depth.
        self.trace.max_depth = self.trace.max_depth.max(depth + 1);
    }

    /// Turns a source into the key used for loading and cycle detection.
    ///
    /// GitHub web URLs are rewritten to raw URLs. Local paths are joined onto
    /// `base_dir`; `join` keeps an absolute path unchanged.
    fn normalize_source(&self, source: &str) -> String {
        if is_url(source) {
            return transform_github_url(source);
        }
        // Re-collecting components drops interior `.` segments and duplicate
        // separators, so `dir/./a.toml` and `dir/a.toml` share one key.
        // `..` is deliberately left alone: collapsing it lexically could pick a
        // different file when symlinks are involved.
        self.base_dir
            .join(source)
            .components()
            .collect::<PathBuf>()
            .display()
            .to_string()
    }

    /// Resolves a parent reference `relative` against the config that named it.
    ///
    /// URLs and absolute paths are returned unchanged.
    fn resolve_relative_path(&self, parent_source: &str, relative: &str) -> String {
        if is_url(relative) || Path::new(relative).is_absolute() {
            return relative.to_string();
        }
        if is_url(parent_source) {
            // Drop the last path segment, but never cut into `scheme://host`:
            // a URL without a path (e.g. `https://example.com`) is its own base.
            let authority_start = parent_source.find("://").map_or(0, |i| i + 3);
            let base = match parent_source[authority_start..].rfind('/') {
                Some(pos) => &parent_source[..authority_start + pos],
                None => parent_source,
            };
            return format!("{}/{}", base, relative);
        }
        match Path::new(parent_source).parent() {
            Some(parent_dir) => parent_dir.join(relative).display().to_string(),
            None => relative.to_string(),
        }
    }

    /// Loads raw content for `source`; the flag says whether it came from the cache.
    fn load_config(&self, source: &str) -> Result<(String, bool)> {
        if is_url(source) {
            self.load_remote_config(source)
        } else {
            self.load_local_config(source).map(|c| (c, false))
        }
    }

    /// Reads a local file; any I/O failure becomes [`InheritanceError::FileNotFound`].
    fn load_local_config(&self, path: &str) -> Result<String> {
        fs::read_to_string(path).map_err(|e| InheritanceError::FileNotFound {
            path: path.to_string(),
            error: e.to_string(),
        })
    }

    /// Returns a cached copy of `url` if one is available, otherwise fetches it.
    ///
    /// In offline mode a stale cache entry is accepted, and the network is never
    /// used. A fresh fetch is written to the cache; failing to cache only prints
    /// a warning.
    fn load_remote_config(&self, url: &str) -> Result<(String, bool)> {
        if let Some(cached) = self.cache.get(url) {
            return Ok((cached, true));
        }
        if self.offline {
            if let Some(stale) = self.cache.get_stale(url) {
                return Ok((stale, true));
            }
            return Err(InheritanceError::FetchFailed {
                url: url.to_string(),
                error: "Offline mode enabled and no cached config available".to_string(),
            });
        }
        let content = self.fetch_url(url)?;
        if let Err(e) = self.cache.set(url, &content) {
            eprintln!("Warning: Failed to cache config: {}", e);
        }
        Ok((content, false))
    }

    /// Performs an HTTP GET and returns the response body as text.
    fn fetch_url(&self, url: &str) -> Result<String> {
        let response = ureq::Agent::new_with_defaults()
            .get(url)
            .call()
            .map_err(|e| InheritanceError::FetchFailed {
                url: url.to_string(),
                error: e.to_string(),
            })?;
        let body = response
            .into_body()
            .read_to_string()
            .map_err(|e| InheritanceError::FetchFailed {
                url: url.to_string(),
                error: e.to_string(),
            })?;
        Ok(body)
    }

    /// Trace of the most recent resolution (empty after `resolve_with_trace`).
    pub fn trace(&self) -> &ResolutionTrace {
        &self.trace
    }
}
impl Default for InheritanceResolver {
fn default() -> Self {
Self::new()
}
}
/// Returns `true` when `s` starts with a lowercase `http://` or `https://`
/// scheme; anything else is treated as a local path.
fn is_url(s: &str) -> bool {
    const SCHEMES: [&str; 2] = ["http://", "https://"];
    SCHEMES.iter().any(|scheme| s.starts_with(*scheme))
}
/// Rewrites GitHub web URLs into URLs that serve the raw file content.
///
/// - `github.com/<owner>/<repo>/blob/<ref>/<path>` becomes
///   `raw.githubusercontent.com/<owner>/<repo>/<ref>/<path>`.
/// - `gist.github.com/...` gets `/raw` appended unless it already points at a
///   raw view.
/// - Every other URL is returned unchanged.
///
/// The decision is made on the URL's host, not on substrings. This means:
/// - a repository or directory literally named `blob` is not mangled;
/// - only the `/blob/` segment right after `<owner>/<repo>` is removed;
/// - other hosts that merely mention `github.com` in their path are untouched.
fn transform_github_url(url: &str) -> String {
    let (scheme, rest) = match url.split_once("://") {
        Some(parts) => parts,
        None => return url.to_string(),
    };
    let (host, path) = rest.split_once('/').unwrap_or((rest, ""));

    match host.to_ascii_lowercase().as_str() {
        "github.com" | "www.github.com" => {
            let mut segments = path.splitn(4, '/');
            match (
                segments.next(),
                segments.next(),
                segments.next(),
                segments.next(),
            ) {
                (Some(owner), Some(repo), Some("blob"), Some(tail))
                    if !owner.is_empty() && !repo.is_empty() =>
                {
                    format!(
                        "{}://raw.githubusercontent.com/{}/{}/{}",
                        scheme, owner, repo, tail
                    )
                }
                _ => url.to_string(),
            }
        }
        "gist.github.com" => {
            let trimmed = url.trim_end_matches('/');
            // Already a raw view (`.../raw`, `.../raw/`, `.../raw/<sha>/<file>`): keep as-is.
            if trimmed.ends_with("/raw") || trimmed.contains("/raw/") {
                url.to_string()
            } else {
                format!("{}/raw", trimmed)
            }
        }
        _ => url.to_string(),
    }
}
/// Merges `overlay` on top of `base`, section by section.
///
/// The result has `extends: None` because it is already flattened.
fn merge_configs(base: ExtendedConfig, overlay: ExtendedConfig) -> ExtendedConfig {
    let tools = merge_tools(base.tools, overlay.tools);
    let hooks = merge_hooks(base.hooks, overlay.hooks);
    let env = merge_env(base.env, overlay.env);
    let services = merge_services(base.services, overlay.services);
    ExtendedConfig {
        extends: None,
        tools,
        hooks,
        env,
        services,
    }
}
/// Union of two tool maps; on a name clash the `overlay` entry replaces the `base` one.
fn merge_tools(
    base: HashMap<String, ToolConfig>,
    overlay: HashMap<String, ToolConfig>,
) -> HashMap<String, ToolConfig> {
    let mut combined = base;
    combined.extend(overlay);
    combined
}
/// Merges hook sections.
///
/// - `pre_setup` and `post_setup` scripts are concatenated, base first.
/// - `config` is taken from `overlay` wholesale.
/// - Per-tool hooks go through `merge_tool_hooks`.
fn merge_hooks(base: HooksConfig, overlay: HooksConfig) -> HooksConfig {
    let pre_setup = append_hook_scripts(base.pre_setup, overlay.pre_setup);
    let post_setup = append_hook_scripts(base.post_setup, overlay.post_setup);
    let tool_hooks = merge_tool_hooks(base.tool_hooks, overlay.tool_hooks);
    HooksConfig {
        pre_setup,
        post_setup,
        config: overlay.config,
        tool_hooks,
    }
}
/// Concatenates two optional scripts with a newline (base first).
///
/// If only one is present it is returned as-is; if neither is, `None`.
fn append_hook_scripts(base: Option<String>, overlay: Option<String>) -> Option<String> {
    match (base, overlay) {
        (Some(first), Some(second)) => Some(format!("{}\n{}", first, second)),
        (first, second) => first.or(second),
    }
}
/// Merges per-tool hooks.
///
/// - Tools present only in `overlay` are added unchanged.
/// - For tools present in both maps, the `post_install` scripts are
///   concatenated (base first) via `append_hook_scripts`.
///
/// Only `post_install` is merged for tools present in both maps. Any other
/// fields of the overlay's `ToolHooks` entry are not copied.
fn merge_tool_hooks(
    base: HashMap<String, ToolHooks>,
    overlay: HashMap<String, ToolHooks>,
) -> HashMap<String, ToolHooks> {
    let mut result = base;
    for (name, hooks) in overlay {
        if let Some(existing) = result.get_mut(&name) {
            // `take` moves the current script out instead of cloning it; the field
            // is immediately overwritten with the merged value.
            existing.post_install =
                append_hook_scripts(existing.post_install.take(), hooks.post_install);
        } else {
            result.insert(name, hooks);
        }
    }
    result
}
/// Merges environment sections.
///
/// - In `vars`, `secrets` and `tool_env`, overlay keys replace base keys.
/// - `config` is taken from `overlay` wholesale.
fn merge_env(base: EnvConfig, overlay: EnvConfig) -> EnvConfig {
    let (mut vars, mut secrets, mut tool_env) = (base.vars, base.secrets, base.tool_env);
    overlay.vars.into_iter().for_each(|(key, value)| {
        vars.insert(key, value);
    });
    overlay.secrets.into_iter().for_each(|(key, value)| {
        secrets.insert(key, value);
    });
    overlay.tool_env.into_iter().for_each(|(key, value)| {
        tool_env.insert(key, value);
    });
    EnvConfig {
        vars,
        secrets,
        config: overlay.config,
        tool_env,
    }
}
/// Merges service sections.
///
/// - `enabled` is sticky: it stays true if either side enables services.
/// - `compose_file` and `tilt_file` prefer the overlay and fall back to the base.
/// - `auto_start` and `start_in_ci` are always taken from the overlay.
///
/// NOTE(review): if `auto_start` / `start_in_ci` are serde-defaulted, a child
/// that omits them resets the parent's value. Confirm that is intended.
fn merge_services(base: ServicesConfig, overlay: ServicesConfig) -> ServicesConfig {
    let enabled = base.enabled || overlay.enabled;
    let compose_file = overlay.compose_file.or(base.compose_file);
    let tilt_file = overlay.tilt_file.or(base.tilt_file);
    ServicesConfig {
        enabled,
        auto_start: overlay.auto_start,
        compose_file,
        tilt_file,
        start_in_ci: overlay.start_in_ci,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Writes `content` to `jarvy.toml` in a fresh temporary directory.
    /// Keep the returned `TempDir` alive for as long as the file is needed.
    fn create_test_config(content: &str) -> (TempDir, PathBuf) {
        let dir = TempDir::new().unwrap();
        let file = write_config(&dir, "jarvy.toml", content);
        (dir, file)
    }

    /// Writes `content` to `dir/name` and returns the full path.
    fn write_config(dir: &TempDir, name: &str, content: &str) -> PathBuf {
        let file = dir.path().join(name);
        fs::write(&file, content).unwrap();
        file
    }

    /// A resolver whose base directory is `dir`.
    fn resolver_in(dir: &TempDir) -> InheritanceResolver {
        InheritanceResolver::new().with_base_dir(dir.path().to_path_buf())
    }

    /// Builds a tool map of `ToolConfig::Simple` entries from `(name, version)` pairs.
    fn simple_tools(pairs: &[(&str, &str)]) -> HashMap<String, ToolConfig> {
        pairs
            .iter()
            .map(|&(name, version)| (name.to_string(), ToolConfig::Simple(version.to_string())))
            .collect()
    }

    /// Version of a `ToolConfig::Simple` entry; panics with `message` otherwise.
    fn simple_version<'a>(
        tools: &'a HashMap<String, ToolConfig>,
        name: &str,
        message: &str,
    ) -> &'a str {
        match tools.get(name) {
            Some(ToolConfig::Simple(version)) => version.as_str(),
            _ => panic!("{}", message),
        }
    }

    #[test]
    fn test_parse_extended_config() {
        let parsed: ExtendedConfig = toml::from_str(
            r#"
extends = "https://example.com/base.toml"
[provisioner]
git = "latest"
node = "20"
"#,
        )
        .unwrap();
        assert!(matches!(parsed.extends, Some(Extends::Single(_))));
        assert_eq!(parsed.tools.len(), 2);
    }

    #[test]
    fn test_parse_multiple_extends() {
        let parsed: ExtendedConfig = toml::from_str(
            r#"
extends = ["base.toml", "override.toml"]
[provisioner]
git = "latest"
"#,
        )
        .unwrap();
        assert!(matches!(parsed.extends, Some(Extends::Multiple(_))));
        if let Some(Extends::Multiple(sources)) = &parsed.extends {
            assert_eq!(sources.len(), 2);
        }
    }

    #[test]
    fn test_merge_tools_overlay_wins() {
        let base = simple_tools(&[("git", "2.40"), ("node", "18")]);
        let overlay = simple_tools(&[("git", "2.45"), ("docker", "latest")]);
        let merged = merge_tools(base, overlay);
        assert_eq!(merged.len(), 3);
        assert_eq!(simple_version(&merged, "git", "Expected Simple config"), "2.45");
        assert_eq!(simple_version(&merged, "node", "Expected Simple config"), "18");
        assert!(merged.contains_key("docker"));
    }

    #[test]
    fn test_append_hook_scripts() {
        let some = |s: &str| Some(s.to_string());
        let cases = [
            (None, None, None),
            (some("a"), None, some("a")),
            (None, some("b"), some("b")),
            (some("a"), some("b"), some("a\nb")),
        ];
        for (base, overlay, expected) in cases {
            assert_eq!(append_hook_scripts(base, overlay), expected);
        }
    }

    #[test]
    fn test_is_url() {
        for remote in ["https://example.com/config.toml", "http://localhost:8080/config.toml"] {
            assert!(is_url(remote));
        }
        for local in ["./local/config.toml", "/absolute/path/config.toml"] {
            assert!(!is_url(local));
        }
    }

    #[test]
    fn test_transform_github_url() {
        assert_eq!(
            transform_github_url("https://github.com/user/repo/blob/main/jarvy.toml"),
            "https://raw.githubusercontent.com/user/repo/main/jarvy.toml"
        );
    }

    #[test]
    fn test_transform_gist_url() {
        assert_eq!(
            transform_github_url("https://gist.github.com/user/abc123"),
            "https://gist.github.com/user/abc123/raw"
        );
    }

    #[test]
    fn test_resolve_local_config() {
        let (_guard, path) = create_test_config(
            r#"
[provisioner]
git = "latest"
"#,
        );
        let config = InheritanceResolver::new()
            .resolve(path.to_str().unwrap())
            .unwrap();
        assert_eq!(config.tools.len(), 1);
        assert!(config.tools.contains_key("git"));
    }

    #[test]
    fn test_resolve_with_local_extends() {
        let dir = TempDir::new().unwrap();
        let base_path = write_config(
            &dir,
            "base.toml",
            r#"
[provisioner]
git = "2.40"
node = "18"
"#,
        );
        let child_content = format!(
            r#"
extends = '{}'
[provisioner]
git = "2.45"
docker = "latest"
"#,
            base_path.display()
        );
        let child_path = write_config(&dir, "child.toml", &child_content);
        let config = resolver_in(&dir)
            .resolve(child_path.to_str().unwrap())
            .unwrap();
        assert_eq!(
            simple_version(&config.tools, "git", "Expected git to be overridden"),
            "2.45"
        );
        assert!(config.tools.contains_key("node"));
        assert!(config.tools.contains_key("docker"));
    }

    #[test]
    fn test_circular_dependency_detection() {
        let dir = TempDir::new().unwrap();
        let a_path = dir.path().join("a.toml");
        let b_path = dir.path().join("b.toml");
        fs::write(
            &a_path,
            format!("extends = '{}'\n[provisioner]\na = \"1\"", b_path.display()),
        )
        .unwrap();
        fs::write(
            &b_path,
            format!("extends = '{}'\n[provisioner]\nb = \"1\"", a_path.display()),
        )
        .unwrap();
        let outcome = resolver_in(&dir).resolve(a_path.to_str().unwrap());
        assert!(matches!(
            outcome,
            Err(InheritanceError::CircularDependency(_))
        ));
    }

    #[test]
    fn test_max_depth_exceeded() {
        let dir = TempDir::new().unwrap();
        let paths: Vec<PathBuf> = (0..=MAX_DEPTH + 2)
            .map(|i| dir.path().join(format!("config{}.toml", i)))
            .collect();
        for (i, path) in paths.iter().enumerate() {
            // Each config extends the next one; the last has no parent.
            let content = match paths.get(i + 1) {
                Some(next) => format!(
                    "extends = '{}'\n[provisioner]\nvar{} = \"{}\"",
                    next.display(),
                    i,
                    i
                ),
                None => format!("[provisioner]\nvar{} = \"{}\"", i, i),
            };
            fs::write(path, content).unwrap();
        }
        let outcome = resolver_in(&dir).resolve(paths[0].to_str().unwrap());
        assert!(matches!(
            outcome,
            Err(InheritanceError::MaxDepthExceeded { .. })
        ));
    }

    #[test]
    fn test_diamond_dependency() {
        let dir = TempDir::new().unwrap();
        let d_path = write_config(&dir, "d.toml", "[provisioner]\nd_tool = \"1.0\"");
        let extends_d = |tool: &str| {
            format!(
                "extends = '{}'\n[provisioner]\n{} = \"1.0\"",
                d_path.display(),
                tool
            )
        };
        let b_path = write_config(&dir, "b.toml", &extends_d("b_tool"));
        let c_path = write_config(&dir, "c.toml", &extends_d("c_tool"));
        let a_path = write_config(
            &dir,
            "a.toml",
            &format!(
                "extends = ['{}', '{}']\n[provisioner]\na_tool = \"1.0\"",
                b_path.display(),
                c_path.display()
            ),
        );
        let config = resolver_in(&dir)
            .resolve(a_path.to_str().unwrap())
            .unwrap();
        for tool in ["a_tool", "b_tool", "c_tool", "d_tool"] {
            assert!(config.tools.contains_key(tool));
        }
    }

    #[test]
    fn test_resolution_trace() {
        let dir = TempDir::new().unwrap();
        let base_path = write_config(&dir, "base.toml", "[provisioner]\nbase = \"1\"");
        let child_path = write_config(
            &dir,
            "child.toml",
            &format!(
                "extends = '{}'\n[provisioner]\nchild = \"1\"",
                base_path.display()
            ),
        );
        let (_, trace) = resolver_in(&dir)
            .resolve_with_trace(child_path.to_str().unwrap())
            .unwrap();
        assert_eq!(trace.entries.len(), 2);
        assert_eq!(trace.max_depth, 2);
    }
}