#![warn(
missing_debug_implementations,
// missing_docs, // TODO
rust_2018_idioms,
non_snake_case,
non_upper_case_globals
)]
#![deny(broken_intra_doc_links)]
#![allow(clippy::cognitive_complexity)]
use std::{
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket},
os::unix::prelude::{FromRawFd, IntoRawFd},
sync::Arc,
time::Instant,
};
#[cfg(feature = "script")]
use std::path::PathBuf;
use anyhow::{Context, Result};
use argh::FromArgs;
use crossbeam_channel::{Receiver, Sender};
use dhcproto::{v4, v6};
use mac_address::MacAddress;
use opts::LogStructure;
use tracing::{error, info, trace};
mod decline;
mod discover;
mod inform;
mod opts;
mod release;
mod request;
mod runner;
#[cfg(feature = "script")]
mod script;
use opts::{parse_mac, parse_opts, parse_params};
use runner::TimeoutRunner;
use crate::{
decline::DeclineArgs, discover::DiscoverArgs, inform::InformArgs, release::ReleaseArgs,
request::RequestArgs, util::Msg,
};
#[allow(clippy::collapsible_else_if)]
fn main() -> Result<()> {
let mut args: Args = argh::from_env();
let mut default_port = false;
if args.port.is_none() {
default_port = true;
if args.target.is_ipv6() {
args.port = Some(546);
} else {
args.port = Some(67);
}
}
if args.bind.is_none() {
if args.target.is_ipv6() {
if default_port {
args.bind = Some(SocketAddr::new(
IpAddr::V6(Ipv6Addr::UNSPECIFIED),
v6::CLIENT_PORT,
));
} else {
args.bind = Some(SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0));
}
} else {
if default_port {
args.bind = Some(SocketAddr::new(
IpAddr::V4(Ipv4Addr::UNSPECIFIED),
v4::CLIENT_PORT,
));
} else {
args.bind = Some(SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0));
}
}
}
opts::init_tracing(&args);
trace!(?args);
let bind_addr: SocketAddr = args.bind.context("bind address must be specified")?;
let soc = match args.interface {
Some(ref int) => {
let socket = socket2::Socket::new(
if args.target.is_ipv6() {
socket2::Domain::IPV6
} else {
socket2::Domain::IPV4
},
socket2::Type::DGRAM,
None,
)?;
socket.bind_device(Some(int.as_bytes()))?;
socket.bind(&bind_addr.into())?;
info!("bound to interface {}", int);
#[cfg(windows)]
unsafe {
UdpSocket::from_raw_socket(socket.into_raw_socket())?
}
#[cfg(unix)]
unsafe {
UdpSocket::from_raw_fd(socket.into_raw_fd())
}
}
None => UdpSocket::bind(bind_addr)?,
};
soc.set_broadcast(true)?;
let soc = Arc::new(soc);
let shutdown_rx = ctrl_channel()?;
let (send_tx, send_rx) = crossbeam_channel::bounded(1);
let (recv_tx, recv_rx) = crossbeam_channel::bounded(1);
runner::sender_thread(send_rx, soc.clone());
runner::recv_thread(recv_tx, soc);
let start = Instant::now();
#[cfg(feature = "script")]
if let Some(path) = &args.script {
info!("evaluating rhai script");
let mut args = args.clone();
args.no_retry = true;
if let Err(err) = script::main(
path,
TimeoutRunner {
args,
shutdown_rx,
send_tx,
recv_rx,
},
) {
error!(?err, "error running rhai script");
}
info!(elapsed = %util::PrettyTime(start.elapsed()), "script completed");
return Ok(());
}
let mut new_args = args.clone();
let msg = run_it(
move || match &new_args.msg {
Some(MsgType::Dora(dora)) => {
new_args.msg = Some(MsgType::Discover(dora.discover()));
new_args
}
_ => new_args,
},
shutdown_rx.clone(),
send_tx.clone(),
recv_rx.clone(),
)?;
let new_args = match (&args.msg, msg) {
(Some(MsgType::Dora(dora)), Msg::V4(msg)) => {
let mut new_args = args.clone();
new_args.msg = Some(MsgType::Request(dora.request(msg.yiaddr())));
new_args
}
_ => {
drop(send_tx);
drop(recv_rx);
return Ok(());
}
};
run_it(move || new_args, shutdown_rx, send_tx, recv_rx)?;
info!(elapsed = %util::PrettyTime(start.elapsed()), "total time");
Ok(())
}
fn run_it<F: FnOnce() -> Args>(
f: F,
shutdown_rx: Receiver<()>,
send_tx: Sender<(Msg, SocketAddr, bool)>,
recv_rx: Receiver<(Msg, SocketAddr)>,
) -> Result<Msg> {
let args = f();
let runner = TimeoutRunner {
args,
shutdown_rx,
send_tx,
recv_rx,
};
match runner.send() {
Err(err) => {
error!(%err, "got an error");
Err(err)
}
Ok(msg) => Ok(msg),
}
}
/// Installs a Ctrl-C handler and returns a receiver that yields `()` when the
/// signal fires.
///
/// The channel buffers a single notification. The handler's send waits while a
/// notification is already pending. A send error (receiver dropped) is ignored.
///
/// # Errors
/// Fails if `ctrlc::set_handler` cannot install the handler.
fn ctrl_channel() -> Result<Receiver<()>> {
    let (notify, notified) = crossbeam_channel::bounded::<()>(1);
    let on_signal = move || {
        // best-effort: nobody listening anymore is not an error
        let _ = notify.send(());
    };
    ctrlc::set_handler(on_signal)?;
    Ok(notified)
}
#[derive(Debug, FromArgs, Clone, PartialEq)]
#[argh(description = "dhcpm is a cli tool for sending dhcpv4/v6 messages
ex dhcpv4:
dhcpm 255.255.255.255 discover (broadcast discover to default dhcp port)
dhcpm 192.168.0.255 discover (broadcast discover on interface bound to 192.168.0.x)
dhcpm 0.0.0.0 -p 9901 discover (unicast discover to 0.0.0.0:9901)
dhcpm 192.168.0.1 dora (unicast DORA to 192.168.0.1)
dhcpm 192.168.0.1 dora -o 118,C0A80001 (unicast DORA, incl opt 118:192.168.0.1)
dhcpv6:
dhcpm ::0 -p 9901 solicit (unicast solicit to [::0]:9901)
dhcpm ff02::1:2 solicit (multicast solicit to default port)
")]
// Top-level CLI arguments. `main` fills in `port` and `bind` when they are omitted.
// Argh requires `///` field descriptions to start with a lowercase letter.
pub struct Args {
    /// destination ip address
    #[argh(positional)]
    pub target: IpAddr,
    // the message to send; `None` when no subcommand is given
    #[argh(subcommand)]
    pub msg: Option<MsgType>,
    /// local address to bind the socket to
    #[argh(option, short = 'b')]
    pub bind: Option<SocketAddr>,
    /// network interface to bind the socket to
    #[argh(option, short = 'i')]
    pub interface: Option<String>,
    /// destination port (a default is chosen per ip version when omitted)
    #[argh(option, short = 'p')]
    pub port: Option<u16>,
    /// timeout passed to the message runner
    // NOTE(review): unit is not visible in this file — see opts::default_timeout
    #[argh(option, short = 't', default = "opts::default_timeout()")]
    pub timeout: u64,
    /// log output format
    #[argh(option, default = "LogStructure::Pretty")]
    pub output: LogStructure,
    /// path to a rhai script to run instead of a single subcommand
    #[cfg(feature = "script")]
    #[argh(option)]
    pub script: Option<PathBuf>,
    /// do not retry sending
    // NOTE(review): retry behaviour lives in runner::TimeoutRunner; main forces
    // this to true when running a script
    #[argh(option, default = "false")]
    pub no_retry: bool,
}
impl Args {
    /// Resolves the destination socket address from `target` and `port`.
    ///
    /// Also reports whether the send should be treated as broadcast or multicast:
    /// - An IPv4 target counts when it is `255.255.255.255` or its last octet is
    ///   255 (a heuristic for subnet-directed broadcast).
    /// - An IPv6 target counts when it is a multicast address.
    ///
    /// # Panics
    /// Panics if `port` is `None`; `main` fills it in before sending.
    pub fn get_target(&self) -> (SocketAddr, bool) {
        let is_broadcast = match self.target {
            IpAddr::V4(ip) => {
                let last_octet = ip.octets()[3];
                let brd = ip.is_broadcast() || last_octet == 255_u8;
                if brd {
                    trace!("using broadcast address");
                }
                brd
            }
            IpAddr::V6(ip) => ip.is_multicast(),
        };
        (SocketAddr::new(self.target, self.port.unwrap()), is_broadcast)
    }
}
// CLI subcommands: each variant wraps the argument struct for one message kind.
// `Dora` is not sent directly; `main` turns it into a Discover followed by a
// Request for the offered address.
#[derive(PartialEq, Debug, Clone, FromArgs)]
#[argh(subcommand)]
pub enum MsgType {
    Discover(DiscoverArgs),
    Request(RequestArgs),
    Release(ReleaseArgs),
    Inform(InformArgs),
    Decline(DeclineArgs),
    Dora(DoraArgs),
    Solicit(SolicitArgs),
}
/// send a discover, then a request for the address offered in the reply
#[derive(FromArgs, PartialEq, Debug, Clone)]
#[argh(subcommand, name = "dora")]
pub struct DoraArgs {
    /// chaddr value (client hardware address)
    #[argh(
        option,
        short = 'c',
        from_str_fn(parse_mac),
        default = "opts::get_mac()"
    )]
    pub chaddr: MacAddress,
    /// ciaddr value
    #[argh(option, default = "Ipv4Addr::UNSPECIFIED")]
    pub ciaddr: Ipv4Addr,
    /// yiaddr value (only used for the request)
    #[argh(option, short = 'y', default = "Ipv4Addr::UNSPECIFIED")]
    pub yiaddr: Ipv4Addr,
    /// server identifier (only used for the request)
    #[argh(option, short = 's')]
    pub sident: Option<Ipv4Addr>,
    /// requested address for the discover
    // the request half always asks for the offered yiaddr instead (see DoraArgs::request)
    #[argh(option, short = 'r')]
    pub req_addr: Option<Ipv4Addr>,
    /// giaddr value
    #[argh(option, short = 'g', default = "Ipv4Addr::UNSPECIFIED")]
    pub giaddr: Ipv4Addr,
    /// subnet selection address
    #[argh(option)]
    pub subnet_select: Option<Ipv4Addr>,
    /// relay link selection address
    #[argh(option)]
    pub relay_link: Option<Ipv4Addr>,
    /// extra options as code,hex-value (e.g. 118,C0A80001)
    #[argh(option, short = 'o', from_str_fn(parse_opts))]
    pub opt: Vec<v4::DhcpOption>,
    /// requested parameter option codes
    #[argh(option, from_str_fn(parse_params), default = "opts::default_params()")]
    pub params: Vec<v4::OptionCode>,
}
impl DoraArgs {
    /// Builds the DISCOVER half of a DORA exchange.
    ///
    /// The shared fields are copied from these arguments. `yiaddr` and `sident`
    /// are not part of a discover.
    pub fn discover(&self) -> DiscoverArgs {
        let Self {
            chaddr,
            ciaddr,
            req_addr,
            giaddr,
            subnet_select,
            relay_link,
            opt,
            params,
            ..
        } = self;
        DiscoverArgs {
            chaddr: *chaddr,
            ciaddr: *ciaddr,
            req_addr: *req_addr,
            giaddr: *giaddr,
            subnet_select: *subnet_select,
            relay_link: *relay_link,
            opt: opt.to_vec(),
            params: params.to_vec(),
        }
    }

    /// Builds the REQUEST half of a DORA exchange, asking for `req_addr`.
    ///
    /// `main` passes the address offered in the discover reply as `req_addr`.
    /// The user's own `req_addr` is not used here.
    pub fn request(&self, req_addr: Ipv4Addr) -> RequestArgs {
        let Self {
            chaddr,
            ciaddr,
            yiaddr,
            sident,
            giaddr,
            subnet_select,
            relay_link,
            opt,
            params,
            ..
        } = self;
        RequestArgs {
            chaddr: *chaddr,
            ciaddr: *ciaddr,
            yiaddr: *yiaddr,
            req_addr: Some(req_addr),
            sident: *sident,
            giaddr: *giaddr,
            subnet_select: *subnet_select,
            relay_link: *relay_link,
            opt: opt.to_vec(),
            params: params.to_vec(),
        }
    }
}
/// send a solicit message (dhcpv6)
// No options yet: the subcommand takes no arguments.
#[derive(FromArgs, PartialEq, Debug, Clone, Copy)]
#[argh(subcommand, name = "solicit")]
pub struct SolicitArgs {}
pub mod util {
    //! Shared helpers: the `Msg` wrapper over DHCPv4/DHCPv6 messages and small
    //! formatting adapters used in log output.

    use std::{fmt, time::Duration};

    use anyhow::Result;
    use dhcproto::{v4, v6, Encodable};

    /// A DHCP message of either protocol version.
    #[derive(Clone, PartialEq, Eq)]
    pub enum Msg {
        /// A DHCPv4 message.
        V4(v4::Message),
        /// A DHCPv6 message.
        V6(v6::Message),
    }

    impl Msg {
        /// Returns the message type, formatted with `Debug`, for logging.
        ///
        /// A v4 message without a message-type option yields `"Unknown"`. Replies
        /// read off the network are not guaranteed to carry that option, so the
        /// method does not panic on them.
        pub fn get_type(&self) -> String {
            match self {
                Msg::V4(m) => m
                    .opts()
                    .msg_type()
                    .map_or_else(|| "Unknown".to_owned(), |ty| format!("{:?}", ty)),
                Msg::V6(m) => format!("{:?}", m.msg_type()),
            }
        }

        /// Unwraps the inner v4 message.
        ///
        /// # Panics
        /// Panics if this is a v6 message.
        #[cfg(feature = "script")]
        pub fn unwrap_v4(self) -> v4::Message {
            match self {
                Msg::V4(m) => m,
                Msg::V6(_) => panic!("unwrapped wrong variant on message"),
            }
        }

        /// Encodes the message into its wire format.
        ///
        /// # Errors
        /// Returns the encoder's error if encoding fails.
        pub fn to_vec(&self) -> Result<Vec<u8>> {
            Ok(match self {
                Msg::V4(m) => m.to_vec()?,
                Msg::V6(m) => m.to_vec()?,
            })
        }
    }

    impl fmt::Debug for Msg {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            match self {
                Msg::V4(msg) => f
                    .debug_struct("v4::Message")
                    .field("xid", &msg.xid())
                    .field("broadcast flag", &msg.flags().broadcast())
                    .field("ciaddr", &msg.ciaddr())
                    .field("yiaddr", &msg.yiaddr())
                    .field("siaddr", &msg.siaddr())
                    .field("giaddr", &msg.giaddr())
                    .field("chaddr", &format!("0x{}", hex::encode(msg.chaddr())))
                    .field("opts", &msg.opts())
                    .finish(),
                // Previously `todo!()`: logging any v6 message would panic.
                Msg::V6(msg) => f
                    .debug_struct("v6::Message")
                    .field("msg_type", &msg.msg_type())
                    .field("xid", &format!("0x{}", hex::encode(msg.xid())))
                    .field("opts", &msg.opts())
                    .finish(),
            }
        }
    }

    /// Displays a `Duration` as fractional seconds, truncated (not rounded) to
    /// at most six characters, followed by `s`.
    ///
    /// Example: `0.123456789s` displays as `0.1234s`.
    #[derive(Copy, Clone, Default, PartialEq, Eq, PartialOrd, Ord)]
    pub struct PrettyTime(pub Duration);

    impl fmt::Display for PrettyTime {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            let secs = self.0.as_secs_f32().to_string();
            // f32 `Display` output is plain ASCII digits and '.', so byte slicing
            // cannot split a character.
            write!(f, "{}s", if secs.len() <= 5 { &secs } else { &secs[0..=5] })
        }
    }

    impl fmt::Debug for PrettyTime {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            write!(f, "{:?}", self.0)
        }
    }

    /// Wrapper whose `Display` is the pretty (`{:#?}`) `Debug` form of the inner value.
    #[derive(Copy, Clone, Default, PartialEq, Eq, PartialOrd, Ord)]
    pub struct PrettyPrint<T>(pub T);

    impl<T: fmt::Debug> fmt::Display for PrettyPrint<T> {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            write!(f, "{:#?}", &self.0)
        }
    }

    impl<T: fmt::Debug> fmt::Debug for PrettyPrint<T> {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            write!(f, "{:?}", self.0)
        }
    }
}