use std::{
io::{Error, ErrorKind, Result},
net::{IpAddr, SocketAddr},
ops::Range,
path::PathBuf,
time::Duration,
};
use clap::{Parser, Subcommand};
use color_print::ceprintln;
use futures::executor::block_on;
use n3agent::Agent;
use n3io::reactor::{Reactor, set_global_reactor};
/// Parses the value of `--n3-port-range` into a half-open range of ports.
///
/// Accepted syntax:
/// - `PORT`: a single port, returned as `PORT..PORT + 1`;
/// - `FROM:TO`: returned as `FROM..TO`, so `TO` itself is excluded and must be greater
///   than `FROM`.
///
/// # Errors
///
/// Returns a human-readable message (reported by clap, which uses this as a
/// `value_parser`) when:
/// - a port is not a valid `u16`;
/// - `TO <= FROM`;
/// - the value contains more than one `:`;
/// - a single port is `65535`, whose exclusive end (`65536`) cannot be represented in a
///   `Range<u16>`.
fn parse_port_range(arg: &str) -> std::result::Result<Range<u16>, String> {
    let parts: Vec<&str> = arg.split(':').collect();
    match parts.as_slice() {
        [port] => {
            let port = port
                .parse::<u16>()
                .map_err(|err| format!("failed to parse port: {err}"))?;
            // `port + 1` would overflow for 65535 (a panic in debug builds, an empty
            // `65535..0` range in release builds), so reject it explicitly.
            let end = port.checked_add(1).ok_or_else(|| {
                format!(
                    "failed to parse port: {port} cannot be used as a single port, use a port below {}",
                    u16::MAX
                )
            })?;
            Ok(port..end)
        }
        [from, to] => {
            let from = from
                .parse::<u16>()
                .map_err(|err| format!("failed to parse port(from): {err}"))?;
            let to = to
                .parse::<u16>()
                .map_err(|err| format!("failed to parse port(to): {err}"))?;
            if to <= from {
                return Err("failed to parse port range: ensure `to > from`".to_owned());
            }
            Ok(from..to)
        }
        _ => Err("Invalid port range, valid syntax: `xxx:xxx` or `xxx`".to_owned()),
    }
}
// Command-line options of the n3 agent.
//
// Plain `//` comments are used on the struct itself (rather than `///`) so clap keeps taking
// `about` from the crate metadata; the `///` comments on fields become each flag's help text.
#[derive(Parser)]
#[command(version, about, long_about = None)]
struct Cli {
    /// ALPN application protocols offered on n3 connections
    #[arg(long, value_name = "PROTO_LIST", default_values_t = ["n3".to_string()])]
    protos: Vec<String>,
    /// IP address combined with every port of `--n3-port-range` to form the n3 addresses
    #[arg(short = 'i', long, value_name = "ADDR")]
    n3_ip: IpAddr,
    /// n3 port(s): `PORT`, or `FROM:TO` where `TO` is exclusive
    #[arg(short = 'p', long, value_name = "PORT", value_parser = parse_port_range)]
    n3_port_range: Range<u16>,
    /// Optional certificate chain (PEM) loaded into the QUIC configuration
    #[arg(short, long, value_name = "PEM_FILE")]
    cert: Option<PathBuf>,
    /// Private key (PEM) loaded into the QUIC configuration
    #[arg(short, long, value_name = "PEM_FILE", default_value = "n3.key")]
    key: PathBuf,
    /// Initial per-stream flow-control limit, in bytes
    #[arg(long, value_name = "SIZE", default_value_t = 1024 * 1024 * 10)]
    initial_max_stream_data: u64,
    /// QUIC `max_idle_timeout` transport parameter, in milliseconds
    #[arg(long, value_name = "MILLIS", default_value_t = 60 * 1000)]
    max_idle_timeout: u64,
    /// QUIC `max_ack_delay` transport parameter, in milliseconds
    #[arg(long, value_name = "MILLIS", default_value_t = 40)]
    max_ack_delay: u64,
    /// Initial limit on both bidirectional and unidirectional streams
    #[arg(long, value_name = "STREAMS", default_value_t = 100)]
    initial_max_streams: u64,
    /// Tick interval of the I/O reactor's timer, in milliseconds
    #[arg(long, value_name = "INTERVAL", default_value_t = 200)]
    io_timer_tick_interval: u64,
    /// Enable logging via `pretty_env_logger`
    #[arg(short, long, default_value_t = false, action)]
    debug: bool,
    #[command(subcommand)]
    commands: Commands,
}
// Subcommands accepted by the agent CLI.
//
// Plain `//` comments are used (rather than `///`) because clap turns doc comments into
// `--help` text; the variant and field names also determine the derived subcommand and
// argument names, so they must not be renamed.
#[derive(Subcommand)]
enum Commands {
    // Bind the agent to `target`; `run_n3_agent` falls back to `[::]:1812` when omitted.
    Listen {
        target: Option<SocketAddr>,
    },
}
/// Expands `--n3-ip` and `--n3-port-range` into the list of n3 socket addresses.
///
/// Yields one address per port in the half-open port range, in ascending port order. The
/// result is empty only when the range is empty. This function currently never fails; it
/// returns `Result` so callers can use `?`.
fn parse_n3_addrs(cli: &Cli) -> Result<Vec<SocketAddr>> {
    let ip = cli.n3_ip;
    // `Range<u16>` is not `Copy`, so iterate over a clone and leave `cli` untouched.
    let addrs = cli
        .n3_port_range
        .clone()
        .map(|port| SocketAddr::new(ip, port))
        .collect();
    Ok(addrs)
}
/// Builds an [`Agent`] over the n3 addresses derived from `cli`, configures its quiche
/// connector from the CLI options, and binds it to `laddr`.
///
/// # Errors
///
/// - `InvalidInput` if the port range yields no n3 address, if a certificate/key path is not
///   valid UTF-8, or if the application protocol list is rejected;
/// - `NotFound` if the certificate chain or private key file cannot be loaded;
/// - any error produced by `Agent::bind`.
async fn run_agent(cli: Cli, laddr: SocketAddr) -> Result<()> {
    let n3_addrs = parse_n3_addrs(&cli)?;
    if n3_addrs.is_empty() {
        return Err(Error::new(
            ErrorKind::InvalidInput,
            format!(
                "`n3_addrs` is empty: port range {:?} contains no ports",
                cli.n3_port_range
            ),
        ));
    }
    let protos = cli
        .protos
        .iter()
        .map(|proto| proto.as_bytes())
        .collect::<Vec<_>>();
    // Connection-level window sized so every initial stream can use its full window.
    // `saturating_mul` avoids an overflow panic (debug) or wrap-around (release) when
    // large values are passed on the command line.
    let initial_max_data = cli
        .initial_max_streams
        .saturating_mul(cli.initial_max_stream_data);
    Agent::new(n3_addrs.as_slice())
        .connector(|connector| {
            connector.quiche_config(|config| {
                config.set_initial_max_data(initial_max_data);
                config.set_initial_max_stream_data_bidi_local(cli.initial_max_stream_data);
                config.set_initial_max_stream_data_bidi_remote(cli.initial_max_stream_data);
                config.set_initial_max_stream_data_uni(cli.initial_max_stream_data);
                config.set_initial_max_streams_bidi(cli.initial_max_streams);
                config.set_initial_max_streams_uni(cli.initial_max_streams);
                config.set_max_idle_timeout(cli.max_idle_timeout);
                config.set_max_ack_delay(cli.max_ack_delay);
                if let Some(cert) = &cli.cert {
                    // The PEM loaders take `&str` paths: report non-UTF-8 paths as an
                    // error instead of panicking.
                    let cert_path = cert.to_str().ok_or_else(|| {
                        Error::new(
                            ErrorKind::InvalidInput,
                            format!("certificate chain path {:?} is not valid UTF-8", cert),
                        )
                    })?;
                    config
                        .load_cert_chain_from_pem_file(cert_path)
                        .map_err(|err| {
                            Error::new(
                                ErrorKind::NotFound,
                                format!(
                                    "Unable to load certificate chain file {:?}, {}",
                                    cert, err
                                ),
                            )
                        })?;
                }
                let key_path = cli.key.to_str().ok_or_else(|| {
                    Error::new(
                        ErrorKind::InvalidInput,
                        format!("key path {:?} is not valid UTF-8", cli.key),
                    )
                })?;
                config
                    .load_priv_key_from_pem_file(key_path)
                    .map_err(|err| {
                        Error::new(
                            ErrorKind::NotFound,
                            format!("Unable to load key file {:?}, {}", cli.key, err),
                        )
                    })?;
                config.set_application_protos(&protos).map_err(|err| {
                    Error::new(
                        ErrorKind::InvalidInput,
                        format!(
                            "failed to set application protos as {:?}, {}",
                            cli.protos, err
                        ),
                    )
                })?;
                Ok(())
            })
        })
        .bind(laddr)
        .await
}
async fn run_n3_agent() -> Result<()> {
let cli = Cli::parse();
let io_timer_tick_interval = cli.io_timer_tick_interval;
set_global_reactor(move || {
let (reactor, _) =
Reactor::with_background_thread(Duration::from_millis(io_timer_tick_interval), 1024)
.unwrap();
reactor
});
if cli.debug {
pretty_env_logger::try_init_timed().map_err(Error::other)?;
}
match cli.commands {
Commands::Listen { target } => {
run_agent(
cli,
target.unwrap_or("[::]:1812".parse().map_err(Error::other)?),
)
.await?;
}
}
Ok(())
}
fn main() {
if let Err(err) = block_on(run_n3_agent()) {
ceprintln!("<s><r>error:</r></s> {}", err)
}
}