use anyhow::{Error, Result};
use camino::Utf8PathBuf;
use clap::{Parser, ValueEnum};
use ripline::{
line_buffer::{LineBufferBuilder, LineBufferReader},
lines::LineIter,
};
use rustc_hash::FxHashMap as HashMap;
use std::hash::BuildHasherDefault;
use std::io::{self, IsTerminal, Write};
use std::process::ExitCode;
use termcolor::{ColorChoice, StandardStream};
use geoipsed::{files, geoip, input, mmdb, ExtractorBuilder, IpMatch, Tag, Tagged};
use input::FileOrStdin;
/// Returns `true` when `err`, or any error in its `source()` chain, is an
/// [`io::Error`] of kind [`io::ErrorKind::BrokenPipe`].
///
/// `main` uses this to exit successfully when the consumer of stdout goes
/// away (e.g. output piped into `head`) instead of reporting an error.
#[inline]
fn is_broken_pipe(err: &Error) -> bool {
    err.chain()
        .filter_map(|cause| cause.downcast_ref::<io::Error>())
        .any(|io_err| io_err.kind() == io::ErrorKind::BrokenPipe)
}
// Command-line arguments. The `///` comments on the fields are rendered by
// clap as the per-flag `--help` text; the struct itself uses a plain comment
// so the explicit `about` attribute keeps supplying the program description.
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
    /// Print only the decorated IP addresses, one per line
    #[clap(short, long)]
    only_matching: bool,
    /// When to colorize output (auto: only when stdout is a terminal)
    #[clap(short = 'C', long, value_enum, default_value_t = ArgsColorChoice::Auto)]
    color: ArgsColorChoice,
    /// Template used to decorate each IP address (see --list-templates for fields)
    #[clap(short, long)]
    template: Option<String>,
    /// Output each line as JSON describing the IP addresses it contains
    #[clap(long, conflicts_with = "only_matching")]
    tag: bool,
    /// Tag the IP addresses found in the input files
    #[clap(long, conflicts_with = "only_matching")]
    tag_files: bool,
    /// Include all IP addresses, overriding --no-private, --no-loopback and --no-broadcast
    #[clap(long)]
    all: bool,
    /// Ignore private IP addresses
    #[clap(long)]
    no_private: bool,
    /// Ignore loopback IP addresses
    #[clap(long)]
    no_loopback: bool,
    /// Ignore broadcast IP addresses
    #[clap(long)]
    no_broadcast: bool,
    /// Only decorate routable IP addresses
    #[clap(long)]
    only_routable: bool,
    /// GeoIP database provider to use (see --list-providers)
    #[clap(long, value_name = "PROVIDER", default_value = "maxmind")]
    provider: String,
    /// Directory containing the MMDB database files
    #[clap(
        short = 'I',
        value_name = "DIR",
        value_hint = clap::ValueHint::DirPath,
        env = "GEOIP_MMDB_DIR"
    )]
    include: Option<Utf8PathBuf>,
    /// Print information about the available database providers and exit
    #[clap(long)]
    list_providers: bool,
    /// List the template fields available for the selected provider and exit
    #[clap(short = 'L', long)]
    list_templates: bool,
    /// Input files to read; '-' means stdin, which is also used when no files are given
    #[clap(value_name = "FILE", value_hint = clap::ValueHint::FilePath)]
    input: Vec<Utf8PathBuf>,
}
// Accepted values of the `-C/--color` flag. `run_main` maps each variant to
// a termcolor `ColorChoice`. Plain `//` comments are used deliberately so the
// clap-generated help for the possible values is left unchanged.
#[derive(Copy, Clone, PartialEq, Eq, Debug, ValueEnum)]
enum ArgsColorChoice {
    // Always emit color (maps to `ColorChoice::Always`).
    Always,
    // Never emit color (maps to `ColorChoice::Never`).
    Never,
    // Emit color only when stdout is a terminal; this is the flag's default.
    Auto,
}
fn main() -> ExitCode {
let err = match run_main() {
Ok(code) => return code,
Err(err) => err,
};
if is_broken_pipe(&err) {
return ExitCode::SUCCESS;
}
if std::env::var("RUST_BACKTRACE").is_ok_and(|v| v == "1")
&& std::env::var("RUST_LIB_BACKTRACE").map_or(true, |v| v == "1")
{
writeln!(&mut std::io::stderr(), "{err:?}").unwrap();
} else {
writeln!(&mut std::io::stderr(), "{err:#}").unwrap();
}
ExitCode::FAILURE
}
fn run_main() -> Result<ExitCode> {
let mut args = Args::parse();
let mut provider_registry = mmdb::ProviderRegistry::default();
if args.list_providers {
let info = provider_registry.print_db_info()?;
println!("{info}");
return Ok(ExitCode::SUCCESS);
}
if args.list_templates {
provider_registry.set_active_provider(&args.provider)?;
provider_registry.initialize_active_provider(args.include.clone())?;
let fields = provider_registry.available_fields()?;
println!(
"Available template fields for provider '{}':",
args.provider
);
for field in fields {
println!(
"{{{}}}\t{}\t(example: {})",
field.name, field.description, field.example
);
}
return Ok(ExitCode::SUCCESS);
}
if args.input.is_empty() {
args.input.push(Utf8PathBuf::from("-"));
}
if args.include.is_none() {
if let Ok(legacy_path) = std::env::var("MAXMIND_MMDB_DIR") {
args.include = Some(Utf8PathBuf::from(legacy_path));
eprintln!("Warning: MAXMIND_MMDB_DIR is deprecated, please use GEOIP_MMDB_DIR instead");
}
}
let colormode = match args.color {
ArgsColorChoice::Auto => {
if std::io::stdout().is_terminal() {
ColorChoice::Always
} else {
ColorChoice::Never
}
}
ArgsColorChoice::Always => ColorChoice::Always,
ArgsColorChoice::Never => ColorChoice::Never,
};
run(args, colormode)?;
Ok(ExitCode::SUCCESS)
}
/// Maximum number of distinct IP strings whose decorated output is memoized.
/// Once the cache holds this many entries, lookups still happen but their
/// results are no longer stored, which bounds memory on inputs with many
/// unique addresses.
const MAX_CACHE_ENTRIES: usize = 100_000;

/// Extracts IP addresses from every input and writes the result to stdout.
///
/// The extractor honours `--all` / `--no-private` / `--no-loopback` /
/// `--no-broadcast` (all three exclusions together select `only_public()`).
/// With `--tag-files` the inputs are passed to [`files::tag_files`] and no
/// GeoIP database is opened. Otherwise each input is read line by line and:
///
/// * `--only-matching`: every match's decorated form is written on its own line;
/// * `--tag`: the line is written via [`Tagged::write_json`] with one tag per match;
/// * default: the line is copied with each match replaced by its decorated form.
///
/// Output is buffered and flushed after each input.
///
/// # Errors
///
/// Propagates extractor, provider, database and I/O errors.
fn run(args: Args, colormode: ColorChoice) -> Result<()> {
    let include_private = args.all || !args.no_private;
    let include_loopback = args.all || !args.no_loopback;
    let include_broadcast = args.all || !args.no_broadcast;
    let extractor = if !include_private && !include_loopback && !include_broadcast {
        ExtractorBuilder::new().only_public().build()?
    } else {
        let mut builder = ExtractorBuilder::new();
        if !include_private {
            builder.ignore_private();
        }
        if !include_loopback {
            builder.ignore_loopback();
        }
        if !include_broadcast {
            builder.ignore_broadcast();
        }
        builder.build()?
    };

    let mut out = io::BufWriter::with_capacity(65536, StandardStream::stdout(colormode));

    // `--tag-files` only needs the extractor, so return before opening a
    // GeoIP database that would never be used.
    if args.tag_files {
        files::tag_files(&args.input, &extractor, &mut out)?;
        out.flush()?;
        return Ok(());
    }

    let mut provider_registry = mmdb::ProviderRegistry::default();
    provider_registry.set_active_provider(&args.provider)?;
    provider_registry.initialize_active_provider(args.include.clone())?;
    // `args` is owned, so its last uses of `include`/`template` can move.
    let geoipdb = geoip::GeoIPSed::new_with_provider(
        args.include,
        args.template,
        colormode,
        args.only_routable,
        provider_registry,
    )?;

    let mut cache: HashMap<Vec<u8>, String> =
        HashMap::with_capacity_and_hasher(4096, BuildHasherDefault::default());
    let only_matching = args.only_matching;
    let tag_mode = args.tag;
    let mut line_buffer = LineBufferBuilder::new().capacity(65536).build();
    for path in args.input {
        let file = FileOrStdin::from_path(path);
        let reader = file.reader()?;
        let mut lb_reader = LineBufferReader::new(reader, &mut line_buffer);
        while lb_reader.fill()? {
            for line in LineIter::new(b'\n', lb_reader.buffer()) {
                if only_matching {
                    for m in extractor.match_iter(line) {
                        write_decorated(&mut out, &mut cache, &geoipdb, &m)?;
                        out.write_all(b"\n")?;
                    }
                } else if tag_mode {
                    let mut tagged = Tagged::new(line);
                    for m in extractor.match_iter(line) {
                        tagged = tagged.tag(
                            Tag::new(m.as_matched_str(), m.as_str())
                                .with_range(m.range())
                                .with_decoration(String::new()),
                        );
                    }
                    tagged.write_json(&mut out)?;
                } else {
                    extractor.replace_iter(line, &mut out, |m: &IpMatch, w| {
                        write_decorated(w, &mut cache, &geoipdb, m)
                    })?;
                }
            }
            lb_reader.consume_all();
        }
        out.flush()?;
    }
    Ok(())
}

/// Writes the decorated form of `m` to `w`, memoizing lookups in `cache`.
///
/// The cache is keyed by the match's refanged text (`m.as_str()`). On a miss,
/// `geoipdb.lookup` is called, its result is written, and it is stored only
/// while the cache holds fewer than [`MAX_CACHE_ENTRIES`] entries.
fn write_decorated<W: Write + ?Sized>(
    w: &mut W,
    cache: &mut HashMap<Vec<u8>, String>,
    geoipdb: &geoip::GeoIPSed,
    m: &IpMatch,
) -> io::Result<()> {
    let refanged = m.as_str();
    if let Some(cached) = cache.get(refanged.as_bytes()) {
        return w.write_all(cached.as_bytes());
    }
    let result = geoipdb.lookup(m.ip(), &refanged);
    w.write_all(result.as_bytes())?;
    if cache.len() < MAX_CACHE_ENTRIES {
        cache.insert(refanged.into_owned().into_bytes(), result);
    }
    Ok(())
}