use std::fs;
use std::io::{self, Write};
#[cfg(unix)]
use std::os::unix::fs::PermissionsExt;
use std::path::PathBuf;
pub fn handle(cli: &crate::cli::Cli, srcs: &[String], target: Option<&PathBuf>) -> bool {
let target_path = target.cloned().unwrap_or_else(|| cli.hosts_file.clone());
let date = chrono::Utc::now().format("%Y-%m-%d %H:%M UTC").to_string();
let mut merged = Vec::new();
for src in srcs {
let content = match fetch(src) {
Ok(c) => c,
Err(e) => {
eprintln!("Error reading '{}': {}", src, e);
return false;
}
};
let label = src_label(src);
merged.push(format!(
"### source: {} — {}\n{}",
label,
date,
strip_comments(&content)
));
}
if let Err(e) = crate::util::validation::validate_secure_path(&target_path) {
eprintln!("Error: {}", e);
return false;
}
let existing = fs::read_to_string(&target_path).unwrap_or_default();
let preserved = remove_previous_merges(&existing);
let mut output = preserved.trim_end().to_string();
output.push_str("\n\n");
output.push_str(&merged.join("\n\n"));
let dir = target_path
.parent()
.unwrap_or_else(|| std::path::Path::new("."));
let mut tmp = match tempfile::NamedTempFile::new_in(dir) {
Ok(f) => f,
Err(e) => {
eprintln!("Error: {}", e);
return false;
}
};
#[cfg(unix)]
{
let orig_perms = fs::metadata(&target_path)
.ok()
.map(|m| m.permissions().mode());
let mode = orig_perms.unwrap_or(0o100644);
if let Err(e) = tmp
.as_file()
.set_permissions(fs::Permissions::from_mode(mode & 0o777))
{
eprintln!("Warning: failed to set permissions: {}", e);
}
}
if tmp.write_all(output.as_bytes()).is_err() {
return false;
}
let _ = tmp.flush();
let _ = tmp.as_file().sync_all();
if tmp.persist(&target_path).is_err() {
eprintln!("Error saving {}", target_path.display());
return false;
}
if !cli.quiet {
println!(
"Merged {} source(s) → {}",
srcs.len(),
target_path.display()
);
}
true
}
/// Reads the raw text of one hosts source.
///
/// Sources beginning with `http://` or `https://` are downloaded with a 10 s connect
/// timeout and a 30 s overall timeout per request. Anything else is read as a local file
/// path.
///
/// Redirects are followed manually, up to `MAX_REDIRECTS` hops, so that every hop's target
/// passes `validate_url_target`. If the HTTP client followed them itself, a public URL
/// could bounce the request to an address the check would have blocked.
///
/// # Errors
/// Returns an error for blocked or unresolvable hosts, transport failures, non-2xx
/// responses, redirects without a usable or supported `Location`, too many redirects,
/// body read failures, and local file I/O errors.
fn fetch(src: &str) -> io::Result<String> {
    use std::time::Duration;
    const MAX_REDIRECTS: usize = 10;

    if !(src.starts_with("http://") || src.starts_with("https://")) {
        return fs::read_to_string(src);
    }

    let config = ureq::config::Config::builder()
        .timeout_connect(Some(Duration::from_secs(10)))
        .timeout_global(Some(Duration::from_secs(30)))
        // Redirects are handled below so each hop can be validated.
        .max_redirects(0)
        .build();
    let agent = ureq::Agent::new_with_config(config);

    let mut url = src.to_string();
    for _ in 0..=MAX_REDIRECTS {
        validate_url_target(&url)?;
        let response = agent
            .get(url.as_str())
            .call()
            .map_err(|e| io::Error::other(format!("HTTP request failed: {}", e)))?;
        let status = response.status();
        if status.is_redirection() {
            let location = response
                .headers()
                .get("location")
                .and_then(|v| v.to_str().ok())
                .ok_or_else(|| {
                    io::Error::other(format!(
                        "Redirect ({}) without a usable Location header",
                        status
                    ))
                })?;
            url = resolve_redirect(&url, location)?;
            continue;
        }
        if !status.is_success() {
            return Err(io::Error::other(format!(
                "HTTP request failed: status {}",
                status
            )));
        }
        let mut body = response.into_body();
        return body
            .read_to_string()
            .map_err(|e| io::Error::other(format!("Read response body failed: {}", e)));
    }
    Err(io::Error::other(format!(
        "Too many redirects (more than {})",
        MAX_REDIRECTS
    )))
}

/// Resolves a redirect `Location` value against the URL that produced it.
///
/// Supports absolute (`http(s)://…`), scheme-relative (`//host/…`) and root-relative
/// (`/path`) locations. Path-relative locations are rejected rather than guessed.
fn resolve_redirect(base: &str, location: &str) -> io::Result<String> {
    if location.starts_with("http://") || location.starts_with("https://") {
        return Ok(location.to_string());
    }
    let (scheme, rest) = base
        .split_once("://")
        .ok_or_else(|| io::Error::other(format!("Invalid URL: {}", base)))?;
    if let Some(stripped) = location.strip_prefix("//") {
        return Ok(format!("{}://{}", scheme, stripped));
    }
    if location.starts_with('/') {
        let authority_end = rest
            .find(|c: char| matches!(c, '/' | '?' | '#'))
            .unwrap_or(rest.len());
        return Ok(format!("{}://{}{}", scheme, &rest[..authority_end], location));
    }
    Err(io::Error::other(format!(
        "Unsupported redirect location: {}",
        location
    )))
}
/// Rejects `http(s)://` URLs whose host resolves to a non-public address (basic SSRF guard).
///
/// The host is taken from the URL authority:
/// - the authority ends at the first `/`, `?` or `#`;
/// - any `user:pass@` prefix is dropped, because the request goes to the part after the
///   last `@`;
/// - bracketed IPv6 literals are unwrapped;
/// - a `:port` suffix is ignored.
///
/// Every address the host resolves to must pass.
///
/// NOTE(review): this resolves the name independently of the HTTP client. A DNS answer
/// that changes between the two lookups (DNS rebinding) is therefore not caught.
///
/// # Errors
/// Empty host, DNS failure, a resolution with no addresses, or any blocked address.
fn validate_url_target(url_str: &str) -> io::Result<()> {
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, ToSocketAddrs};

    /// True for IPv4 ranges that must not be fetched from.
    fn v4_blocked(ip: Ipv4Addr) -> bool {
        let [a, b, _, _] = ip.octets();
        ip.is_private()
            || ip.is_loopback()
            || ip.is_link_local()
            || ip.is_multicast()
            || ip.is_broadcast()
            || ip.is_documentation()
            // 0.0.0.0/8 "this network"
            || a == 0
            // 100.64.0.0/10 shared address space (carrier-grade NAT)
            || (a == 100 && (b & 0xc0) == 64)
            // 198.18.0.0/15 benchmarking
            || (a == 198 && (b & 0xfe) == 18)
            // 240.0.0.0/4 reserved
            || a >= 240
    }

    /// True for IPv6 ranges that must not be fetched from.
    fn v6_blocked(ip: Ipv6Addr) -> bool {
        // IPv4-mapped (::ffff:a.b.c.d) and IPv4-compatible (::a.b.c.d) addresses reach the
        // embedded IPv4 host, so apply the IPv4 rules to them.
        if let Some(v4) = ip.to_ipv4() {
            if v4_blocked(v4) {
                return true;
            }
        }
        let o = ip.octets();
        ip.is_loopback()
            || ip.is_multicast()
            || ip.is_unspecified()
            // fc00::/7 unique local
            || (o[0] & 0xfe) == 0xfc
            // fe80::/10 link-local and deprecated fec0::/10 site-local
            || (o[0] == 0xfe && o[1] >= 0x80)
    }

    let after_scheme = url_str
        .strip_prefix("http://")
        .or_else(|| url_str.strip_prefix("https://"))
        .unwrap_or(url_str);
    let authority = after_scheme
        .split(|c: char| matches!(c, '/' | '?' | '#'))
        .next()
        .unwrap_or("");
    // Userinfo cannot contain an unescaped '@', so the host follows the last one.
    let host_port = authority.rsplit('@').next().unwrap_or("");
    let host = match host_port.strip_prefix('[') {
        Some(rest) => rest.split(']').next().unwrap_or(""),
        None => host_port.split(':').next().unwrap_or(""),
    };
    if host.is_empty() {
        return Err(io::Error::other("Empty host in URL"));
    }

    let addrs: Vec<_> = (host, 0)
        .to_socket_addrs()
        .map_err(|e| io::Error::other(format!("DNS resolution failed: {}", e)))?
        .collect();
    if addrs.is_empty() {
        return Err(io::Error::other(format!(
            "DNS resolution returned no addresses for {}",
            host
        )));
    }

    let blocked = addrs.iter().map(|a| a.ip()).find(|ip| match ip {
        IpAddr::V4(v4) => v4_blocked(*v4),
        IpAddr::V6(v6) => v6_blocked(*v6),
    });
    if let Some(ip) = blocked {
        return Err(io::Error::other(format!(
            "Blocked request to private/reserved IP: {}",
            ip
        )));
    }
    Ok(())
}
/// Returns a human-readable label for a source in a merged section header.
///
/// URLs are used verbatim. Local paths are reduced to their final component, or to an
/// empty string when the path has none (e.g. `/` or `..`).
fn src_label(src: &str) -> String {
    let is_url = ["http://", "https://"]
        .iter()
        .any(|&prefix| src.starts_with(prefix));
    if is_url {
        return src.to_owned();
    }
    match std::path::Path::new(src).file_name() {
        Some(name) => name.to_string_lossy().into_owned(),
        None => String::new(),
    }
}
/// Keeps only the hosts entries from `content`, joined with `\n`.
///
/// A line is kept (untrimmed) when it is not blank, does not start with `#` after trimming,
/// and its first whitespace-separated field passes
/// `crate::util::validation::is_valid_ip`.
fn strip_comments(content: &str) -> String {
    let mut entries: Vec<&str> = Vec::new();
    for line in content.lines() {
        let trimmed = line.trim();
        if trimmed.is_empty() || trimmed.starts_with('#') {
            continue;
        }
        let Some(address) = line.split_whitespace().next() else {
            continue;
        };
        if crate::util::validation::is_valid_ip(address) {
            entries.push(line);
        }
    }
    entries.join("\n")
}
/// Removes sections written by earlier merges, returning the remaining lines joined by `\n`.
///
/// A section starts at a line beginning with `### source:` and runs through the following
/// non-blank lines. The blank line that ends a section is kept; everything outside a
/// section is kept unchanged.
fn remove_previous_merges(content: &str) -> String {
    let mut in_section = false;
    let mut kept = Vec::new();
    for line in content.lines() {
        if line.starts_with("### source:") {
            in_section = true;
        } else if in_section && !line.trim().is_empty() {
            // Entry belonging to a previously merged section: drop it.
        } else {
            in_section = false;
            kept.push(line);
        }
    }
    kept.join("\n")
}