use anyhow::{Context, Result};
use gnaw_core::{
builtin_templates::BuiltinTemplates,
configuration::{DiffMode, GnawConfig, TomlConfig},
session::SelectionState,
sort::FileSortMethod,
template::OutputFormat,
tokenizer::TokenizerType,
};
use log::error;
use std::path::PathBuf;
use crate::{args::Cli, config_loader::ConfigSource};
use gnaw_core::configuration::CompressionOptions;
/// Token names accepted by the compression strip overrides; each may also be written with a
/// `no-` prefix to turn the option off. Used both in the "valid tokens" error message and as the
/// candidate list for typo suggestions — keep it in sync with the match arms in
/// `apply_strip_overrides`.
const STRIP_TOKENS: [&str; 4] = ["tests", "fn-bodies", "doc-comments", "private-bodies"];
/// Built-in template keys that count as git-narrative templates: an explicitly chosen template
/// whose name appears here sets the `git_narrative` flag in `resolve_flag_template_from_parts`.
const GIT_NARRATIVE_TEMPLATES: &[&str] = &[
    "write-git-commit",
    "write-git-changeset-commits",
    "write-github-pull-request",
];
/// Builds the initial [`SelectionState`] by layering CLI arguments over an optional config file.
///
/// Resolution rules, as implemented below:
/// - Root path: the config file's `path`, when set, takes precedence over `args.path`.
/// - Include/exclude patterns: config patterns first, then the comma-expanded CLI patterns.
/// - Boolean switches (line numbers, absolute paths, full tree, diff, token map, deselected)
///   are on if either the CLI or the config turns them on; `tui_mode` also forces the token
///   map on.
/// - Valued options (output format, sort, encoding, token format, template, branch pairs,
///   secret-scan policy): CLI value, else config value, else a built-in default. Compression
///   is resolved by `resolve_compression`.
/// - Secret-scan allow paths: a non-empty CLI list replaces the config list entirely.
/// - Diff mode and diff SHAs are taken from the CLI only.
///
/// When no template was chosen explicitly, a git-narrative template may be inferred from the
/// branch/diff settings via `resolve_flag_template_from_parts`.
///
/// # Errors
/// Fails if the CLI template is neither a built-in nor a readable file, if the compression
/// strip overrides contain an unknown token, or if the configuration builder rejects the
/// assembled settings.
pub fn build_session(
    base: Option<&ConfigSource>,
    args: &Cli,
    tui_mode: bool,
) -> Result<SelectionState> {
    let mut configuration = GnawConfig::builder();
    let cfg = base.map(|b| &b.config);

    // Root path: a `path` in the config file wins over the CLI path.
    if let Some(path) = cfg.and_then(|c| c.path.as_ref()) {
        configuration.path(PathBuf::from(path));
    } else {
        configuration.path(args.path.clone());
    }

    // Glob patterns accumulate: config patterns first, then CLI patterns.
    let (mut include_patterns, mut exclude_patterns) = cfg
        .map(|c| (c.include_patterns.clone(), c.exclude_patterns.clone()))
        .unwrap_or_default();
    include_patterns.extend(expand_comma_separated_patterns(&args.include));
    exclude_patterns.extend(expand_comma_separated_patterns(&args.exclude));
    configuration
        .include_patterns(include_patterns)
        .exclude_patterns(exclude_patterns);

    // Boolean switches: enabled if either the CLI or the config enables them.
    let cfg_line_numbers = cfg.map(|c| c.line_numbers).unwrap_or(false);
    let cfg_absolute = cfg.map(|c| c.absolute_path).unwrap_or(false);
    let cfg_full_tree = cfg.map(|c| c.full_directory_tree).unwrap_or(false);
    configuration
        .line_numbers(args.line_numbers || cfg_line_numbers)
        .absolute_path(args.absolute_paths || cfg_absolute)
        .full_directory_tree(args.full_directory_tree || cfg_full_tree);

    // Valued options: CLI value, then config value, then built-in default.
    let output_format = args
        .output_format
        .or_else(|| cfg.and_then(|c| c.output_format))
        .unwrap_or(OutputFormat::Markdown);
    configuration.output_format(output_format);

    let sort_method = args
        .sort
        .or_else(|| cfg.and_then(|c| c.sort_method))
        .unwrap_or(FileSortMethod::NameAsc);
    configuration.sort_method(sort_method);

    let tokenizer_type = args
        .encoding
        .or_else(|| cfg.and_then(|c| c.encoding))
        .unwrap_or(TokenizerType::Cl100kBase);
    let token_format = args
        .token_format
        .or_else(|| cfg.and_then(|c| c.token_format))
        .unwrap_or(gnaw_core::tokenizer::TokenFormat::Format);
    configuration
        .encoding(tokenizer_type)
        .token_format(token_format);

    // Template: CLI argument (built-in key or file path), then config, then the default.
    let (template_str, template_name) = if args.template.is_some() {
        parse_template(&args.template).map_err(|e| {
            error!("Failed to parse template: {}", e);
            e
        })?
    } else if let Some(c) = cfg {
        (
            c.template_str.clone().unwrap_or_default(),
            c.template_name
                .clone()
                .unwrap_or_else(|| "default".to_string()),
        )
    } else {
        (String::new(), "default".to_string())
    };

    // Branch pairs: CLI first; a config list is only used when it has exactly two entries.
    let diff_branches = parse_branch_argument(&args.git_diff_branch).or_else(|| {
        cfg.and_then(|c| {
            c.diff_branches.as_ref().and_then(|branches| {
                if branches.len() == 2 {
                    Some((branches[0].clone(), branches[1].clone()))
                } else {
                    None
                }
            })
        })
    });
    let log_branches = parse_branch_argument(&args.git_log_branch).or_else(|| {
        cfg.and_then(|c| {
            c.log_branches.as_ref().and_then(|branches| {
                if branches.len() == 2 {
                    Some((branches[0].clone(), branches[1].clone()))
                } else {
                    None
                }
            })
        })
    });
    let diff_shas = parse_ref_range(&args.git_diff_shas);

    // Resolve the diff settings exactly once so that template inference below and the
    // values stored in the configuration can never disagree.
    let diff_enabled = args.diff || cfg.map(|c| c.diff_enabled).unwrap_or(false);
    let diff_mode = args.diff_mode.unwrap_or_default();
    let cfg_token_map_enabled = cfg.map(|c| c.token_map_enabled).unwrap_or(false);
    let cfg_deselected = cfg.map(|c| c.deselected).unwrap_or(false);

    let (template_str, template_name, git_narrative) = resolve_flag_template_from_parts(
        &template_str,
        &template_name,
        diff_branches.is_some(),
        log_branches.is_some(),
        diff_enabled,
        diff_mode,
    );
    configuration
        .template_str(template_str)
        .template_name(template_name)
        .git_narrative(git_narrative);

    let policy = args
        .secret_scan
        .or_else(|| cfg.and_then(|c| c.secret_scan))
        .unwrap_or_default();
    configuration.secret_scan(policy);

    // A non-empty CLI allow-list replaces (does not extend) the config allow-list.
    let allow_paths = if !args.secret_scan_allow.is_empty() {
        args.secret_scan_allow.clone()
    } else {
        cfg.map(|c| c.secret_scan_allow_paths.clone())
            .unwrap_or_default()
    };
    configuration.secret_scan_allow_paths(allow_paths);
    configuration.compression(resolve_compression(args, cfg)?);

    configuration
        .diff_enabled(diff_enabled)
        .diff_mode(diff_mode)
        .diff_branches(diff_branches)
        .log_branches(log_branches)
        .diff_shas(diff_shas)
        .diff_shas_content(args.git_diff_shas_content)
        .diff_shas_max_bytes(args.git_diff_shas_max_bytes)
        .no_ignore(args.no_ignore)
        .hidden(args.hidden)
        .no_codeblock(args.no_codeblock)
        .follow_symlinks(args.follow_symlinks)
        .token_map_enabled(args.token_map || cfg_token_map_enabled || tui_mode)
        .deselected(args.deselected || cfg_deselected);
    if let Some(c) = cfg {
        configuration.user_variables(c.user_variables.clone());
    }

    Ok(SelectionState::new(configuration.build()?))
}
/// Interprets a branch argument as a `(from, to)` pair.
///
/// Returns `Some` only when exactly two values were supplied; an absent argument or any other
/// count yields `None`.
pub fn parse_branch_argument(branch_arg: &Option<Vec<String>>) -> Option<(String, String)> {
    let values = branch_arg.as_ref()?;
    if let [from, to] = values.as_slice() {
        Some((from.clone(), to.clone()))
    } else {
        None
    }
}
/// Parses a pair of git revisions for a SHA-range diff.
///
/// Accepts either two separate values (`["abc123", "def456"]`) or one value written as
/// `"A..B"` or `"A,B"`. Both revisions are trimmed.
///
/// Returns `None` when:
/// - the argument is absent or has the wrong number of values;
/// - either side is empty;
/// - either side cannot be a single revision, i.e. it contains `..` or starts with `.`.
///
/// The last rule rejects git's three-dot form `"A...B"`, which a plain split on the first
/// `..` would otherwise turn into the bogus pair `("A", ".B")`.
pub fn parse_ref_range(arg: &Option<Vec<String>>) -> Option<(String, String)> {
    let (from, to) = match arg.as_deref()? {
        [a, b] => (a.as_str(), b.as_str()),
        [single] => single
            .split_once("..")
            .or_else(|| single.split_once(','))?,
        _ => return None,
    };
    let (from, to) = (from.trim(), to.trim());
    // Git refnames may not contain ".." nor have a component starting with '.', so either
    // indicates a mis-split range rather than a real revision.
    let plausible = |rev: &str| !rev.is_empty() && !rev.starts_with('.') && !rev.contains("..");
    (plausible(from) && plausible(to)).then(|| (from.to_string(), to.to_string()))
}
pub fn parse_template(template_arg: &Option<String>) -> Result<(String, String)> {
match template_arg {
Some(arg) => {
if let Some(t) = gnaw_core::builtin_templates::BuiltinTemplates::get_template(arg) {
return Ok((t.content.to_string(), arg.clone()));
}
let content = std::fs::read_to_string(arg).with_context(|| {
let keys = gnaw_core::builtin_templates::BuiltinTemplates::get_template_keys();
format!(
"'{arg}' is not a built-in template and no file exists at that path.\n\
Available built-ins: {}",
keys.join(", ")
)
})?;
Ok((content, "custom".to_string()))
}
None => Ok(("".to_string(), "default".to_string())),
}
}
/// Expands CLI include/exclude arguments into individual glob patterns.
///
/// Each argument may hold several patterns separated by commas (`"*.rs, *.toml"`). A comma
/// inside a brace group (`"src/{a,b}/*.rs"`) belongs to glob alternation and is not treated as
/// a separator, so mixed arguments such as `"*.md,src/{a,b}/*.rs"` are still split correctly.
/// Every resulting pattern is trimmed, and empty pieces are dropped.
fn expand_comma_separated_patterns(patterns: &[String]) -> Vec<String> {
    let mut expanded = Vec::new();
    for pattern in patterns {
        let mut depth = 0usize;
        let mut start = 0;
        let mut pieces = Vec::new();
        for (idx, ch) in pattern.char_indices() {
            match ch {
                '{' => depth += 1,
                // Saturate so a stray '}' cannot make later top-level commas look nested.
                '}' => depth = depth.saturating_sub(1),
                ',' if depth == 0 => {
                    pieces.push(&pattern[start..idx]);
                    start = idx + ch.len_utf8();
                }
                _ => {}
            }
        }
        pieces.push(&pattern[start..]);
        expanded.extend(
            pieces
                .into_iter()
                .map(str::trim)
                .filter(|p| !p.is_empty())
                .map(String::from),
        );
    }
    expanded
}
/// Determines the effective compression options.
///
/// The base comes from the CLI compression level if given, otherwise from the config file's
/// `compression`, otherwise from `CompressionOptions::default()`. Any strip overrides from the
/// CLI are then applied on top.
///
/// # Errors
/// Propagates the error from `apply_strip_overrides` for an unknown strip token.
fn resolve_compression(args: &Cli, cfg: Option<&TomlConfig>) -> Result<CompressionOptions> {
    let base = if let Some(level) = args.compress {
        level.options()
    } else {
        cfg.and_then(|c| c.compression).unwrap_or_default()
    };

    if let Some(tokens) = &args.compress_strip {
        apply_strip_overrides(base, tokens)
    } else {
        Ok(base)
    }
}
/// Applies strip-token overrides to `o` and returns the updated options.
///
/// Each token names a strip option (see `STRIP_TOKENS`). A bare name turns the option on, and
/// a `no-`-prefixed name turns it off. Values may also be comma-joined (`"tests,no-fn-bodies"`).
/// Surrounding whitespace and empty tokens are ignored, and later tokens override earlier ones.
///
/// # Errors
/// Fails on the first unrecognised token. The message quotes the token exactly as written and,
/// when a close match exists, suggests it with the user's `no-` prefix preserved.
fn apply_strip_overrides(
    mut o: CompressionOptions,
    tokens: &[String],
) -> Result<CompressionOptions> {
    for raw in tokens.iter().flat_map(|t| t.split(',')) {
        let raw = raw.trim();
        if raw.is_empty() {
            continue;
        }
        let (on, prefix, name) = match raw.strip_prefix("no-") {
            Some(rest) => (false, "no-", rest),
            None => (true, "", raw),
        };
        match name {
            "tests" => o.strip_test_modules = on,
            "fn-bodies" => o.strip_fn_bodies = on,
            "doc-comments" => o.strip_doc_comments = on,
            "private-bodies" => o.strip_private_bodies = on,
            _ => {
                // Keep the user's `no-` prefix in the suggestion so it can be pasted back as-is.
                let hint = closest(name)
                    .map(|s| format!(" — did you mean '{prefix}{s}'?"))
                    .unwrap_or_default();
                anyhow::bail!(
                    "unknown compression flag '{raw}'{hint}\n valid tokens: {} \
                     (each optionally prefixed with `no-` to disable)",
                    STRIP_TOKENS.join(", ")
                );
            }
        }
    }
    Ok(o)
}
/// Returns the entry of `STRIP_TOKENS` nearest to `input` by edit distance.
///
/// Only candidates within distance 3 are considered. On a tie, the token listed first wins.
/// Returns `None` when no candidate is close enough.
fn closest(input: &str) -> Option<&'static str> {
    let mut best: Option<(&'static str, usize)> = None;
    for &candidate in STRIP_TOKENS.iter() {
        let distance = levenshtein(input, candidate);
        if distance > 3 {
            continue;
        }
        match best {
            // An earlier candidate at the same or smaller distance is kept.
            Some((_, best_distance)) if best_distance <= distance => {}
            _ => best = Some((candidate, distance)),
        }
    }
    best.map(|(token, _)| token)
}
/// Computes the Levenshtein edit distance between `a` and `b`.
///
/// Insertions, deletions and substitutions each cost 1, and the strings are compared by
/// `char`. Uses a single rolling DP row of length `b.chars().count() + 1`.
fn levenshtein(a: &str, b: &str) -> usize {
    let target: Vec<char> = b.chars().collect();
    // row[j] = distance between the part of `a` processed so far and target[..j].
    let mut row: Vec<usize> = (0..=target.len()).collect();
    for (i, left) in a.chars().enumerate() {
        // `diagonal` carries the previous row's value at column j (before it is overwritten).
        let mut diagonal = row[0];
        row[0] = i + 1;
        for (j, right) in target.iter().enumerate() {
            let substitution = diagonal + usize::from(left != *right);
            let deletion = row[j + 1] + 1;
            let insertion = row[j] + 1;
            diagonal = row[j + 1];
            row[j + 1] = substitution.min(deletion).min(insertion);
        }
    }
    row[target.len()]
}
/// Picks the final `(template_source, template_name, git_narrative)` triple.
///
/// If the caller already chose a template, that choice is returned unchanged. A choice means a
/// non-empty source or a name other than `"default"`. `git_narrative` is then set only when the
/// name is listed in `GIT_NARRATIVE_TEMPLATES`.
///
/// Otherwise a built-in template is inferred from the git settings:
/// - branch pairs (diff or log) select `write-github-pull-request`;
/// - an enabled staged diff selects `write-git-commit`;
/// - an enabled unstaged/all diff selects `write-git-changeset-commits`.
///
/// An inferred template that resolves to a built-in returns its content with
/// `git_narrative = true`. If nothing is inferred, or the key is not a built-in, the result is
/// the empty `"default"` template with `git_narrative = false`.
fn resolve_flag_template_from_parts(
    explicit_template_str: &str,
    explicit_template_name: &str,
    has_diff_branches: bool,
    has_log_branches: bool,
    diff_enabled: bool,
    diff_mode: DiffMode,
) -> (String, String, bool) {
    let explicitly_chosen =
        explicit_template_name != "default" || !explicit_template_str.is_empty();
    if explicitly_chosen {
        return (
            explicit_template_str.to_owned(),
            explicit_template_name.to_owned(),
            GIT_NARRATIVE_TEMPLATES.contains(&explicit_template_name),
        );
    }

    let inferred = if has_diff_branches || has_log_branches {
        Some("write-github-pull-request")
    } else if !diff_enabled {
        None
    } else {
        Some(match diff_mode {
            DiffMode::Staged => "write-git-commit",
            DiffMode::Unstaged | DiffMode::All => "write-git-changeset-commits",
        })
    };

    match inferred.and_then(|key| BuiltinTemplates::get_template(key).map(|t| (key, t))) {
        Some((key, template)) => (template.content.to_string(), key.to_string(), true),
        None => (String::new(), "default".to_string(), false),
    }
}