use super::fallback_merge::{built_in_fallback_defaults, merge_fallback_configs};
use super::types::UnifiedConfig;
use crate::agents::fallback::FallbackConfig;
impl UnifiedConfig {
/// Layers `local` over `self` (treated as the global config) and returns the result.
///
/// No presence information is available to this method, so a local `general` scalar
/// is considered "not set" whenever it equals the corresponding value in
/// `GeneralConfig::default()`; the global value is kept in that case. A local value
/// that was deliberately set to the default is therefore indistinguishable from an
/// omitted one and also yields the global value.
///
/// Remaining rules:
/// - optional `general` fields use the local value when it is `Some`, otherwise the
///   global one;
/// - `general.provider_fallback` and the `ccs` string flags fall back to the global
///   value when the local value is empty;
/// - `ccs.can_commit` is compared against `CcsConfig::default()` like the scalars;
/// - `agents`, `ccs_aliases`, `agent_chains` and `agent_drains` are unioned, with a
///   local entry replacing a global entry that has the same key;
/// - `agent_chain` is produced by `merge_fallback_configs`, called with a predicate
///   that answers `true` for every field name and `false` as its last argument.
#[must_use]
pub fn merge_with(&self, local: &Self) -> Self {
    use super::types::{
        CcsConfig, GeneralBehaviorFlags, GeneralConfig, GeneralExecutionFlags,
        GeneralWorkflowFlags,
    };

    // The local value wins unless it still equals the built-in default.
    fn unless_default<T: PartialEq + Clone>(local: &T, default: &T, global: &T) -> T {
        if local == default {
            global.clone()
        } else {
            local.clone()
        }
    }

    // First `Some` of (local, global), cloned.
    fn first_some<T: Clone>(local: Option<&T>, global: Option<&T>) -> Option<T> {
        local.or(global).cloned()
    }

    // The local string wins unless it is empty.
    fn non_empty_or(local: &str, global: &str) -> String {
        let chosen = if local.is_empty() { global } else { local };
        chosen.to_string()
    }

    // Union of two maps; on a key collision the `over` value replaces the `base` value.
    fn overlay_maps<K, V>(
        base: &std::collections::HashMap<K, V>,
        over: &std::collections::HashMap<K, V>,
    ) -> std::collections::HashMap<K, V>
    where
        K: Eq + std::hash::Hash + Clone,
        V: Clone,
    {
        let mut merged = base.clone();
        for (key, value) in over {
            merged.insert(key.clone(), value.clone());
        }
        merged
    }

    let defaults = GeneralConfig::default();
    let base = &self.general;
    let over = &local.general;

    let behavior = GeneralBehaviorFlags {
        interactive: unless_default(
            &over.behavior.interactive,
            &defaults.behavior.interactive,
            &base.behavior.interactive,
        ),
        auto_detect_stack: unless_default(
            &over.behavior.auto_detect_stack,
            &defaults.behavior.auto_detect_stack,
            &base.behavior.auto_detect_stack,
        ),
        strict_validation: unless_default(
            &over.behavior.strict_validation,
            &defaults.behavior.strict_validation,
            &base.behavior.strict_validation,
        ),
    };

    let workflow = GeneralWorkflowFlags {
        checkpoint_enabled: unless_default(
            &over.workflow.checkpoint_enabled,
            &defaults.workflow.checkpoint_enabled,
            &base.workflow.checkpoint_enabled,
        ),
    };

    let execution = GeneralExecutionFlags {
        force_universal_prompt: unless_default(
            &over.execution.force_universal_prompt,
            &defaults.execution.force_universal_prompt,
            &base.execution.force_universal_prompt,
        ),
        isolation_mode: unless_default(
            &over.execution.isolation_mode,
            &defaults.execution.isolation_mode,
            &base.execution.isolation_mode,
        ),
    };

    // Floats are compared with a tolerance instead of `==`.
    let backoff_multiplier =
        if (over.backoff_multiplier - defaults.backoff_multiplier).abs() < f64::EPSILON {
            base.backoff_multiplier
        } else {
            over.backoff_multiplier
        };

    // An empty local map means "nothing configured locally".
    let provider_fallback = if over.provider_fallback.is_empty() {
        base.provider_fallback.clone()
    } else {
        over.provider_fallback.clone()
    };

    let general = GeneralConfig {
        verbosity: unless_default(&over.verbosity, &defaults.verbosity, &base.verbosity),
        behavior,
        workflow,
        execution,
        developer_iters: unless_default(
            &over.developer_iters,
            &defaults.developer_iters,
            &base.developer_iters,
        ),
        reviewer_reviews: unless_default(
            &over.reviewer_reviews,
            &defaults.reviewer_reviews,
            &base.reviewer_reviews,
        ),
        developer_context: unless_default(
            &over.developer_context,
            &defaults.developer_context,
            &base.developer_context,
        ),
        reviewer_context: unless_default(
            &over.reviewer_context,
            &defaults.reviewer_context,
            &base.reviewer_context,
        ),
        review_depth: unless_default(
            &over.review_depth,
            &defaults.review_depth,
            &base.review_depth,
        ),
        prompt_path: first_some(over.prompt_path.as_ref(), base.prompt_path.as_ref()),
        templates_dir: first_some(over.templates_dir.as_ref(), base.templates_dir.as_ref()),
        git_user_name: first_some(over.git_user_name.as_ref(), base.git_user_name.as_ref()),
        git_user_email: first_some(over.git_user_email.as_ref(), base.git_user_email.as_ref()),
        provider_fallback,
        max_dev_continuations: unless_default(
            &over.max_dev_continuations,
            &defaults.max_dev_continuations,
            &base.max_dev_continuations,
        ),
        max_xsd_retries: unless_default(
            &over.max_xsd_retries,
            &defaults.max_xsd_retries,
            &base.max_xsd_retries,
        ),
        max_same_agent_retries: unless_default(
            &over.max_same_agent_retries,
            &defaults.max_same_agent_retries,
            &base.max_same_agent_retries,
        ),
        max_commit_residual_retries: unless_default(
            &over.max_commit_residual_retries,
            &defaults.max_commit_residual_retries,
            &base.max_commit_residual_retries,
        ),
        max_retries: unless_default(&over.max_retries, &defaults.max_retries, &base.max_retries),
        retry_delay_ms: unless_default(
            &over.retry_delay_ms,
            &defaults.retry_delay_ms,
            &base.retry_delay_ms,
        ),
        backoff_multiplier,
        max_backoff_ms: unless_default(
            &over.max_backoff_ms,
            &defaults.max_backoff_ms,
            &base.max_backoff_ms,
        ),
        max_cycles: unless_default(&over.max_cycles, &defaults.max_cycles, &base.max_cycles),
        execution_history_limit: unless_default(
            &over.execution_history_limit,
            &defaults.execution_history_limit,
            &base.execution_history_limit,
        ),
    };

    let ccs = CcsConfig {
        output_flag: non_empty_or(&local.ccs.output_flag, &self.ccs.output_flag),
        yolo_flag: non_empty_or(&local.ccs.yolo_flag, &self.ccs.yolo_flag),
        verbose_flag: non_empty_or(&local.ccs.verbose_flag, &self.ccs.verbose_flag),
        print_flag: non_empty_or(&local.ccs.print_flag, &self.ccs.print_flag),
        streaming_flag: non_empty_or(&local.ccs.streaming_flag, &self.ccs.streaming_flag),
        json_parser: non_empty_or(&local.ccs.json_parser, &self.ccs.json_parser),
        session_flag: non_empty_or(&local.ccs.session_flag, &self.ccs.session_flag),
        can_commit: unless_default(
            &local.ccs.can_commit,
            &CcsConfig::default().can_commit,
            &self.ccs.can_commit,
        ),
    };

    let agent_chain = merge_fallback_configs(
        self.agent_chain.as_ref(),
        local.agent_chain.as_ref(),
        |_field| true,
        false,
    );

    Self {
        general,
        ccs,
        agents: overlay_maps(&self.agents, &local.agents),
        ccs_aliases: overlay_maps(&self.ccs_aliases, &local.ccs_aliases),
        agent_chains: overlay_maps(&self.agent_chains, &local.agent_chains),
        agent_drains: overlay_maps(&self.agent_drains, &local.agent_drains),
        agent_chain,
    }
}
/// Layers a local config over `self` (the global config), using the raw local TOML
/// text to decide which keys the local file actually contains.
///
/// `local_content` is parsed into a generic `toml::Value`. If parsing fails, an empty
/// table is substituted, so no key counts as present and every presence-driven field
/// keeps its global value (optional fields and maps are still taken from
/// `local_parsed`).
///
/// Presence rules:
/// - most `general` scalars are taken from `local_parsed` only when their key exists
///   under `[general]`; `interactive`, `auto_detect_stack` and `strict_validation`
///   are looked up under `general.behavior` instead;
/// - `provider_fallback`, `max_retries`, `retry_delay_ms`, `backoff_multiplier`,
///   `max_backoff_ms` and `max_cycles` have a second source: when the key is absent
///   from `[general]` but present under `[agent_chain]`, the value comes from
///   `local_parsed.agent_chain` (or the global `general` value if that is `None`);
/// - optional `general` fields use the local value when it is `Some`, otherwise the
///   global one;
/// - `ccs` fields are taken from `local_parsed` only when their key exists under
///   `[ccs]`;
/// - `agents`, `ccs_aliases`, `agent_chains` and `agent_drains` are unioned, with a
///   local entry replacing a global entry that has the same key.
///
/// The resulting `agent_chain` is `None` when the local TOML has no `agent_chain`
/// table and neither side carries one. Otherwise, if the merged `agent_chains` or
/// `agent_drains` map is non-empty it is built by `merge_named_schema_legacy_metadata`;
/// if both are empty it is built by `merge_fallback_configs`, with
/// `built_in_fallback_defaults()` standing in for a missing global chain.
#[must_use]
pub fn merge_with_content(&self, local_content: &str, local_parsed: &Self) -> Self {
    use super::types::{
        CcsConfig, GeneralBehaviorFlags, GeneralConfig, GeneralExecutionFlags,
        GeneralWorkflowFlags,
    };

    // True when `table` exists and contains `key`.
    fn has_key(table: Option<&toml::Value>, key: &str) -> bool {
        table.and_then(|entries| entries.get(key)).is_some()
    }

    // Clone of `local` when the key was present locally, else of `global`.
    fn select<T: Clone>(local_wins: bool, local: &T, global: &T) -> T {
        let chosen = if local_wins { local } else { global };
        chosen.clone()
    }

    // Three-level lookup: `[general]` key, then `[agent_chain]` key, then global.
    fn select_layered<T: Clone>(
        in_general: bool,
        in_chain: bool,
        general_local: &T,
        chain_local: Option<&T>,
        global: &T,
    ) -> T {
        if in_general {
            general_local.clone()
        } else if in_chain {
            chain_local.unwrap_or(global).clone()
        } else {
            global.clone()
        }
    }

    // First `Some` of (local, global), cloned.
    fn first_some<T: Clone>(local: Option<&T>, global: Option<&T>) -> Option<T> {
        local.or(global).cloned()
    }

    // Union of two maps; on a key collision the `over` value replaces the `base` value.
    fn overlay_maps<K, V>(
        base: &std::collections::HashMap<K, V>,
        over: &std::collections::HashMap<K, V>,
    ) -> std::collections::HashMap<K, V>
    where
        K: Eq + std::hash::Hash + Clone,
        V: Clone,
    {
        let mut merged = base.clone();
        for (key, value) in over {
            merged.insert(key.clone(), value.clone());
        }
        merged
    }

    // Unparseable content degrades to an empty table: nothing is "present".
    let raw_local: toml::Value = toml::from_str(local_content)
        .unwrap_or_else(|_| toml::Value::Table(toml::map::Map::default()));

    let general_table = raw_local.get("general");
    let behavior_table = general_table.and_then(|table| table.get("behavior"));
    let chain_table = raw_local.get("agent_chain");
    let ccs_table = raw_local.get("ccs");

    let in_general = |key: &str| has_key(general_table, key);
    let in_behavior = |key: &str| has_key(behavior_table, key);
    let in_chain = |key: &str| has_key(chain_table, key);
    let in_ccs = |key: &str| has_key(ccs_table, key);

    let base = &self.general;
    let over = &local_parsed.general;
    let local_chain = local_parsed.agent_chain.as_ref();

    let behavior = GeneralBehaviorFlags {
        interactive: select(
            in_behavior("interactive"),
            &over.behavior.interactive,
            &base.behavior.interactive,
        ),
        auto_detect_stack: select(
            in_behavior("auto_detect_stack"),
            &over.behavior.auto_detect_stack,
            &base.behavior.auto_detect_stack,
        ),
        strict_validation: select(
            in_behavior("strict_validation"),
            &over.behavior.strict_validation,
            &base.behavior.strict_validation,
        ),
    };

    // Workflow and execution flags are probed directly under `[general]`.
    let workflow = GeneralWorkflowFlags {
        checkpoint_enabled: select(
            in_general("checkpoint_enabled"),
            &over.workflow.checkpoint_enabled,
            &base.workflow.checkpoint_enabled,
        ),
    };

    let execution = GeneralExecutionFlags {
        force_universal_prompt: select(
            in_general("force_universal_prompt"),
            &over.execution.force_universal_prompt,
            &base.execution.force_universal_prompt,
        ),
        isolation_mode: select(
            in_general("isolation_mode"),
            &over.execution.isolation_mode,
            &base.execution.isolation_mode,
        ),
    };

    let general = GeneralConfig {
        verbosity: select(in_general("verbosity"), &over.verbosity, &base.verbosity),
        behavior,
        workflow,
        execution,
        developer_iters: select(
            in_general("developer_iters"),
            &over.developer_iters,
            &base.developer_iters,
        ),
        reviewer_reviews: select(
            in_general("reviewer_reviews"),
            &over.reviewer_reviews,
            &base.reviewer_reviews,
        ),
        developer_context: select(
            in_general("developer_context"),
            &over.developer_context,
            &base.developer_context,
        ),
        reviewer_context: select(
            in_general("reviewer_context"),
            &over.reviewer_context,
            &base.reviewer_context,
        ),
        review_depth: select(
            in_general("review_depth"),
            &over.review_depth,
            &base.review_depth,
        ),
        prompt_path: first_some(over.prompt_path.as_ref(), base.prompt_path.as_ref()),
        templates_dir: first_some(over.templates_dir.as_ref(), base.templates_dir.as_ref()),
        git_user_name: first_some(over.git_user_name.as_ref(), base.git_user_name.as_ref()),
        git_user_email: first_some(over.git_user_email.as_ref(), base.git_user_email.as_ref()),
        provider_fallback: select_layered(
            in_general("provider_fallback"),
            in_chain("provider_fallback"),
            &over.provider_fallback,
            local_chain.map(|chain| &chain.provider_fallback),
            &base.provider_fallback,
        ),
        max_dev_continuations: select(
            in_general("max_dev_continuations"),
            &over.max_dev_continuations,
            &base.max_dev_continuations,
        ),
        max_xsd_retries: select(
            in_general("max_xsd_retries"),
            &over.max_xsd_retries,
            &base.max_xsd_retries,
        ),
        max_same_agent_retries: select(
            in_general("max_same_agent_retries"),
            &over.max_same_agent_retries,
            &base.max_same_agent_retries,
        ),
        max_commit_residual_retries: select(
            in_general("max_commit_residual_retries"),
            &over.max_commit_residual_retries,
            &base.max_commit_residual_retries,
        ),
        max_retries: select_layered(
            in_general("max_retries"),
            in_chain("max_retries"),
            &over.max_retries,
            local_chain.map(|chain| &chain.max_retries),
            &base.max_retries,
        ),
        retry_delay_ms: select_layered(
            in_general("retry_delay_ms"),
            in_chain("retry_delay_ms"),
            &over.retry_delay_ms,
            local_chain.map(|chain| &chain.retry_delay_ms),
            &base.retry_delay_ms,
        ),
        backoff_multiplier: select_layered(
            in_general("backoff_multiplier"),
            in_chain("backoff_multiplier"),
            &over.backoff_multiplier,
            local_chain.map(|chain| &chain.backoff_multiplier),
            &base.backoff_multiplier,
        ),
        max_backoff_ms: select_layered(
            in_general("max_backoff_ms"),
            in_chain("max_backoff_ms"),
            &over.max_backoff_ms,
            local_chain.map(|chain| &chain.max_backoff_ms),
            &base.max_backoff_ms,
        ),
        max_cycles: select_layered(
            in_general("max_cycles"),
            in_chain("max_cycles"),
            &over.max_cycles,
            local_chain.map(|chain| &chain.max_cycles),
            &base.max_cycles,
        ),
        execution_history_limit: select(
            in_general("execution_history_limit"),
            &over.execution_history_limit,
            &base.execution_history_limit,
        ),
    };

    let ccs = CcsConfig {
        output_flag: select(
            in_ccs("output_flag"),
            &local_parsed.ccs.output_flag,
            &self.ccs.output_flag,
        ),
        yolo_flag: select(
            in_ccs("yolo_flag"),
            &local_parsed.ccs.yolo_flag,
            &self.ccs.yolo_flag,
        ),
        verbose_flag: select(
            in_ccs("verbose_flag"),
            &local_parsed.ccs.verbose_flag,
            &self.ccs.verbose_flag,
        ),
        print_flag: select(
            in_ccs("print_flag"),
            &local_parsed.ccs.print_flag,
            &self.ccs.print_flag,
        ),
        streaming_flag: select(
            in_ccs("streaming_flag"),
            &local_parsed.ccs.streaming_flag,
            &self.ccs.streaming_flag,
        ),
        json_parser: select(
            in_ccs("json_parser"),
            &local_parsed.ccs.json_parser,
            &self.ccs.json_parser,
        ),
        session_flag: select(
            in_ccs("session_flag"),
            &local_parsed.ccs.session_flag,
            &self.ccs.session_flag,
        ),
        can_commit: select(
            in_ccs("can_commit"),
            &local_parsed.ccs.can_commit,
            &self.ccs.can_commit,
        ),
    };

    let agents = overlay_maps(&self.agents, &local_parsed.agents);
    let ccs_aliases = overlay_maps(&self.ccs_aliases, &local_parsed.ccs_aliases);
    let agent_chains = overlay_maps(&self.agent_chains, &local_parsed.agent_chains);
    let agent_drains = overlay_maps(&self.agent_drains, &local_parsed.agent_drains);

    // The named schema is in use as soon as either merged map has an entry.
    let named_schema_present = !(agent_chains.is_empty() && agent_drains.is_empty());
    let any_chain_source =
        chain_table.is_some() || self.agent_chain.is_some() || local_chain.is_some();

    let agent_chain = match (any_chain_source, named_schema_present) {
        (false, _) => None,
        (true, true) => merge_named_schema_legacy_metadata(
            self.agent_chain.as_ref(),
            local_chain,
            |field| has_key(chain_table, field),
        ),
        (true, false) => {
            // A missing global chain is replaced by the built-in defaults.
            let built_in_chain = built_in_fallback_defaults();
            merge_fallback_configs(
                Some(self.agent_chain.as_ref().unwrap_or(&built_in_chain)),
                local_chain,
                |field| has_key(chain_table, field),
                true,
            )
        }
    };

    Self {
        general,
        ccs,
        agents,
        ccs_aliases,
        agent_chains,
        agent_drains,
        agent_chain,
    }
}
}
/// Merges the non-role parts of two legacy `FallbackConfig` values.
///
/// Returns `None` when both `global` and `local` are `None`. Otherwise a missing side
/// is replaced by `FallbackConfig::default()` and the result is built as follows:
/// - the `developer`, `reviewer`, `commit` and `analysis` lists are always empty;
/// - `provider_fallback` is the union of both maps, with a local entry replacing a
///   global entry that has the same key;
/// - `max_retries`, `retry_delay_ms`, `backoff_multiplier`, `max_backoff_ms` and
///   `max_cycles` come from `local` when `is_local_field_present` returns `true` for
///   that field name, otherwise from `global`;
/// - `legacy_role_keys_present` is `true` when either side reports
///   `has_legacy_role_key_presence()`.
fn merge_named_schema_legacy_metadata(
    global: Option<&FallbackConfig>,
    local: Option<&FallbackConfig>,
    is_local_field_present: impl Fn(&str) -> bool,
) -> Option<FallbackConfig> {
    if global.is_none() && local.is_none() {
        return None;
    }

    let fallback_defaults = FallbackConfig::default();
    let base = global.unwrap_or(&fallback_defaults);
    let overlay = local.unwrap_or(&fallback_defaults);

    // Start from the global entries, then let local entries overwrite on collision.
    let mut provider_fallback = base.provider_fallback.clone();
    provider_fallback.extend(
        overlay
            .provider_fallback
            .iter()
            .map(|(provider, models)| (provider.clone(), models.clone())),
    );

    // Chooses which side supplies a scalar, based on local key presence.
    let source_of = |field: &str| {
        if is_local_field_present(field) {
            overlay
        } else {
            base
        }
    };

    Some(FallbackConfig {
        developer: Vec::new(),
        reviewer: Vec::new(),
        commit: Vec::new(),
        analysis: Vec::new(),
        provider_fallback,
        max_retries: source_of("max_retries").max_retries,
        retry_delay_ms: source_of("retry_delay_ms").retry_delay_ms,
        backoff_multiplier: source_of("backoff_multiplier").backoff_multiplier,
        max_backoff_ms: source_of("max_backoff_ms").max_backoff_ms,
        max_cycles: source_of("max_cycles").max_cycles,
        legacy_role_keys_present: base.has_legacy_role_key_presence()
            || overlay.has_legacy_role_key_presence(),
    })
}