use anyhow::{anyhow, Result};
use atty;
use clap::{ArgAction, Parser};
use diffsquare::factor::difference_of_squares;
use indicatif::{ProgressBar, ProgressStyle};
use malachite::{
base::num::conversion::traits::{FromSciString, FromStringBase},
Integer,
};
use rayon::prelude::*;
use serde::Serialize;
use std::{
fs::OpenOptions,
io::{self, Read, Write},
thread,
time::{Duration, Instant},
};
// Command-line interface. Clap's built-in help/version flags are disabled and
// re-declared at the bottom of the struct so that `-v` (lowercase) prints the
// version and both flags are listed last in `--help` (display_order 100/101).
// The `///` comments on the fields below become the per-flag `--help` text.
#[derive(Parser)]
#[command(
    version = env!("CARGO_PKG_VERSION"),
    disable_help_flag = true,
    disable_version_flag = true,
    about,
    long_about = None
)]
struct Args {
    /// Modulus to factor (decimal, or hexadecimal with a 0x prefix)
    #[arg(short = 'n', long = "mod", display_order = 1)]
    modulus: Option<String>,
    /// Starting iteration value passed to the factorizer (default: 1)
    #[arg(short, long, display_order = 2)]
    iter: Option<String>,
    /// Precision passed to the factorizer (default: 30)
    #[arg(short, long, display_order = 3)]
    prec: Option<u64>,
    /// Print only the two factors, separated by a space
    #[arg(short, long, display_order = 4)]
    quiet: bool,
    /// Print the result as a JSON object
    #[arg(long, display_order = 5)]
    json: bool,
    /// Print the result as a CSV row: modulus,p,q,iterations,time_ms
    #[arg(long, display_order = 6)]
    csv: bool,
    /// Print only the elapsed time in milliseconds
    #[arg(long, display_order = 7)]
    time_only: bool,
    /// Read moduli from stdin, one per line (batch mode)
    #[arg(long, display_order = 8)]
    stdin: bool,
    /// Read moduli from the given file, one per line (batch mode)
    #[arg(long, display_order = 9)]
    input: Option<String>,
    /// Number of worker threads used in batch mode
    #[arg(long, display_order = 10)]
    threads: Option<usize>,
    /// Also append every result to the given file
    #[arg(long, display_order = 11)]
    output: Option<String>,
    /// Give up on a modulus after this many milliseconds
    #[arg(long, display_order = 12)]
    timeout: Option<u64>,
    /// Print help
    #[arg(short = 'h', long = "help", action = ArgAction::Help, display_order = 100)]
    help: Option<bool>,
    /// Print version
    #[arg(short = 'v', long = "version", action = ArgAction::Version, display_order = 101)]
    version: Option<bool>,
}
impl Args {
    /// Reports whether any terse or machine-oriented output flag is set
    /// (`--quiet`, `--json`, `--csv` or `--time-only`).
    ///
    /// Callers use this to hide the progress bar, to refuse interactive
    /// prompts, and as the quiet flag handed to the factorizer.
    fn is_quiet(&self) -> bool {
        [self.quiet, self.json, self.csv, self.time_only]
            .iter()
            .any(|&flag| flag)
    }
}
/// Successful factorization result, serialized for `--json` output.
///
/// Big integers are stored as strings (via `to_string`), presumably so they
/// survive JSON consumers that cannot represent arbitrary-precision numbers.
/// Field order here determines key order in the emitted JSON.
#[derive(Serialize)]
struct JsonResult {
    /// The number that was factored.
    modulus: String,
    /// First factor (`p`) returned by the factorizer.
    factor_1: String,
    /// Second factor (`q`) returned by the factorizer.
    factor_2: String,
    /// Iteration value reported back by the factorizer.
    iterations: String,
    /// Wall-clock duration of the factorization in milliseconds.
    time_ms: u128,
}
/// Shows `prompt` without a trailing newline, then reads a single line from
/// stdin and returns it with leading and trailing whitespace stripped.
///
/// # Errors
/// Fails if stdout cannot be flushed or stdin cannot be read.
fn input(prompt: &str) -> Result<String> {
    print!("{prompt}");
    // `print!` is line-buffered; flush so the prompt appears before we block.
    io::stdout().flush()?;
    let mut answer = String::new();
    io::stdin().read_line(&mut answer)?;
    let answer = answer.trim();
    Ok(String::from(answer))
}
/// Parses an arbitrary-precision integer from user-supplied text.
///
/// A `0x` / `0X` prefix selects hexadecimal. Anything else is handed to
/// malachite's `FromSciString` (plain decimal; see the malachite docs for
/// which other notations it accepts). Surrounding whitespace is ignored, so
/// padded lines from `--stdin` or a quoted `--mod " 15 "` still parse.
///
/// # Errors
/// Returns "Invalid hex" or "Invalid integer" when the text does not parse.
fn parse_bigint(s: &str) -> Result<Integer> {
    let s = s.trim();
    match s.strip_prefix("0x").or_else(|| s.strip_prefix("0X")) {
        Some(hex) => Integer::from_string_base(16, hex).ok_or_else(|| anyhow!("Invalid hex")),
        None => Integer::from_sci_string(s).ok_or_else(|| anyhow!("Invalid integer")),
    }
}
/// Appends `content` followed by a newline to `file`, creating the file if it
/// does not exist.
///
/// Batch mode calls this concurrently from rayon workers. The line (content
/// plus newline) is therefore assembled first and written with a single
/// `write_all` on an append-mode handle. `writeln!` would issue separate
/// writes for the content and the newline, letting another record slip in
/// between them. A single append-mode write to a local file is typically
/// atomic, so records should no longer interleave in practice.
///
/// # Errors
/// Fails if the file cannot be opened or written.
fn write_output(file: &str, content: &str) -> Result<()> {
    let mut line = String::with_capacity(content.len() + 1);
    line.push_str(content);
    line.push('\n');
    OpenOptions::new()
        .create(true)
        .append(true)
        .open(file)?
        .write_all(line.as_bytes())?;
    Ok(())
}
/// Extension trait adding a bounded wait to thread handles.
trait JoinTimeout<T> {
    /// Waits at most `dur` for the thread to finish and returns its result.
    ///
    /// Yields `None` if the deadline passes first or if the thread panicked.
    /// A thread still running at the deadline is not cancelled; it keeps
    /// running in the background.
    fn join_timeout(self, dur: Duration) -> Option<T>;
}

impl<T: Send + 'static> JoinTimeout<T> for thread::JoinHandle<T> {
    fn join_timeout(self, dur: Duration) -> Option<T> {
        // `JoinHandle::join` has no deadline. A helper thread therefore does
        // the blocking join and forwards the value over a channel, and the
        // caller waits on that channel with `recv_timeout` instead.
        let (sender, receiver) = std::sync::mpsc::channel();
        thread::spawn(move || {
            // If the joined thread panicked, `join` yields `Err` and nothing is
            // sent. The sender is then dropped, so `recv_timeout` returns
            // `Disconnected` promptly instead of waiting out the deadline.
            let _ = self.join().map(|value| sender.send(value));
        });
        receiver.recv_timeout(dur).ok()
    }
}
fn factor_and_print(
n: Integer,
iter: Integer,
prec: u64,
args: &Args,
write_if_needed: &dyn Fn(&str) -> Result<()>,
) -> Result<()> {
let start_time = Instant::now();
let quiet = args.is_quiet();
let mut iter_clone = iter.clone();
let result = if let Some(ms) = args.timeout {
let n_clone = n.clone();
let handle = thread::spawn(move || {
difference_of_squares(&n_clone, &mut iter_clone, prec, quiet)
.map(|(p, q)| (p, q, iter_clone))
});
handle.join_timeout(Duration::from_millis(ms)).flatten()
} else {
difference_of_squares(&n, &mut iter_clone, prec, quiet).map(|(p, q)| (p, q, iter_clone))
};
let duration = start_time.elapsed();
if let Some((p, q, iterations)) = result {
if args.csv {
let out = format!("{},{},{},{},{}", n, p, q, iterations, duration.as_millis());
println!("{}", &out);
write_if_needed(&out)?;
} else if args.json {
let result = JsonResult {
modulus: n.to_string(),
factor_1: p.to_string(),
factor_2: q.to_string(),
iterations: iterations.to_string(),
time_ms: duration.as_millis(),
};
let out = serde_json::to_string_pretty(&result)?;
println!("{}", &out);
write_if_needed(&out)?;
} else if args.time_only {
let out = duration.as_millis().to_string();
println!("{}", &out);
write_if_needed(&out)?;
} else if args.quiet {
let out = format!("{} {}", p, q);
println!("{}", &out);
write_if_needed(&out)?;
} else {
let out = format!(
"\n✅ Factors of {}:\n\np = {}\nq = {}\n⏱️ Execution time: {:?}",
n, p, q, duration
);
println!("{}", &out);
write_if_needed(&out)?;
}
} else {
let err = if args.csv {
format!("{},ERROR,ERROR,ERROR,ERROR", n)
} else if args.json {
format!(
"{{\n \"modulus\": \"{}\",\n \"error\": \"Factorization failed or timed out\"\n}}",
n
)
} else {
format!("❌ Failed to factor {} (timeout or failure).", n)
};
eprintln!("{}", &err);
write_if_needed(&err)?;
}
Ok(())
}
fn main() -> Result<()> {
let args = Args::parse();
let write_if_needed = |content: &str| -> Result<()> {
if let Some(ref file) = args.output {
write_output(file, content)?;
}
Ok(())
};
let prec = args.prec.unwrap_or(30);
if args.stdin || args.input.is_some() {
let inputs: Vec<String> = if args.stdin {
io::stdin()
.lines()
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.filter(|line| !line.trim().is_empty())
.collect()
} else {
let file = args.input.as_ref().unwrap();
std::fs::read_to_string(file)?
.lines()
.map(|s| s.trim().to_string())
.filter(|s| !s.is_empty())
.collect()
};
if let Some(t) = args.threads {
rayon::ThreadPoolBuilder::new()
.num_threads(t)
.build_global()?;
}
let pb = if !args.is_quiet() {
let pb = ProgressBar::new(inputs.len() as u64);
pb.set_style(
ProgressStyle::with_template("{bar:40.cyan/blue} {pos}/{len} [{elapsed_precise}]")
.unwrap()
.progress_chars("=> "),
);
Some(pb)
} else {
None
};
inputs.par_iter().for_each(|input| {
let n = match parse_bigint(input) {
Ok(val) => val,
Err(e) => {
eprintln!("❌ Error parsing '{}': {e}", input);
return;
}
};
let iter = Integer::from(1);
let _ = factor_and_print(n, iter, prec, &args, &write_if_needed);
if let Some(ref pb) = pb {
pb.inc(1);
}
});
if let Some(ref pb) = pb {
pb.finish_with_message("Done");
}
} else {
if args.modulus.is_some() || !atty::is(atty::Stream::Stdin) {
let n = if let Some(ref m) = args.modulus {
parse_bigint(m)?
} else {
let mut s = String::new();
io::stdin().read_to_string(&mut s)?;
let cleaned = s.replace("\\\n", "").replace('\n', "").trim().to_string();
parse_bigint(&cleaned)?
};
let iter = if let Some(ref i) = args.iter {
parse_bigint(i)?
} else {
Integer::from(1)
};
factor_and_print(n, iter, prec, &args, &write_if_needed)?;
} else if args.is_quiet() {
return Err(anyhow!(
"Modulus must be provided in quiet/json/csv/time-only mode (prompts are disabled)"
));
} else {
loop {
let m = input("Modulus (or type 'exit' to quit): ")?;
if m.eq_ignore_ascii_case("exit") || m.eq_ignore_ascii_case("quit") {
println!("👋 Exiting diffsquare.");
break;
}
let n = match parse_bigint(&m) {
Ok(val) => val,
Err(e) => {
eprintln!("❌ Invalid input: {e}");
continue;
}
};
let iter = if let Some(ref i) = args.iter {
parse_bigint(i)?
} else {
Integer::from(1)
};
factor_and_print(n, iter.clone(), prec, &args, &write_if_needed)?;
}
}
}
Ok(())
}