use log::debug;
use std::io;
use std::mem;
use std::net;
use std::net::{SocketAddr, ToSocketAddrs, UdpSocket};
use std::str;
use std::time;
// Bit fields of the first NTP header byte (`li_vn_mode`), most significant bit first:
// leap indicator in bits 7..=6, version number in bits 5..=3, mode in bits 2..=0.
// Each mask is the field's all-ones value moved into place by its shift, so
// `(byte & MASK) >> SHIFT` extracts the field.
const MODE_MASK: u8 = 0x07 << MODE_SHIFT;
const MODE_SHIFT: u8 = 0;
const VERSION_MASK: u8 = 0x07 << VERSION_SHIFT;
const VERSION_SHIFT: u8 = 3;
const LI_MASK: u8 = 0x03 << LI_SHIFT;
const LI_SHIFT: u8 = 6;
/// In-memory form of the NTP/SNTP packet header, with fields in wire order.
///
/// The rest of this module uses `mem::size_of::<NtpPacket>()` as the size of the datagram on
/// the wire. `#[repr(C)]` pins this exact field order, which guarantees that size is 48 bytes
/// with no padding: four single bytes, three 32-bit words, then four 64-bit timestamps.
/// The default Rust representation gives no such guarantee.
#[repr(C)]
#[derive(Debug)]
struct NtpPacket {
    /// Leap indicator (top 2 bits), version number (next 3 bits) and mode (low 3 bits).
    li_vn_mode: u8,
    stratum: u8,
    poll: i8,
    precision: i8,
    root_delay: u32,
    root_dispersion: u32,
    ref_id: u32,
    /// 64-bit NTP timestamps: whole seconds in the high 32 bits, fraction in the low 32 bits.
    ref_timestamp: u64,
    origin_timestamp: u64,
    recv_timestamp: u64,
    tx_timestamp: u64,
}
impl NtpPacket {
    /// Seconds between the NTP epoch (1900-01-01) and the Unix epoch (1970-01-01).
    const NTP_TIMESTAMP_DELTA: u32 = 2_208_988_800u32;
    /// Mode value 3 (client), which occupies the low three bits of `li_vn_mode`.
    const SNTP_CLIENT_MODE: u8 = 3;
    /// Version number 4, already shifted into bits 5..=3 of `li_vn_mode`.
    const SNTP_VERSION: u8 = 4 << 3;
    // Masks for the fields of `li_vn_mode`, most significant bit first: LI in bits 7..=6,
    // VN in bits 5..=3, Mode in bits 2..=0. The `dead_code` allowances indicate they are not
    // currently referenced.
    #[allow(dead_code)]
    const LI_MASK: u8 = 0b1100_0000;
    #[allow(dead_code)]
    const VN_MASK: u8 = 0b0011_1000;
    #[allow(dead_code)]
    const MODE_MASK: u8 = 0b0000_0111;

    /// Builds a client-mode (mode 3), version-4 request.
    ///
    /// The transmit timestamp is the current system time in NTP format. Every other field
    /// is zero.
    pub fn new() -> NtpPacket {
        // A clock set before 1970 should not abort the request. The transmit timestamp only
        // has to be echoed back by the server and compared, so fall back to the Unix epoch.
        let now_since_unix = time::SystemTime::now()
            .duration_since(time::SystemTime::UNIX_EPOCH)
            .unwrap_or_default();
        NtpPacket {
            li_vn_mode: NtpPacket::SNTP_CLIENT_MODE | NtpPacket::SNTP_VERSION,
            stratum: 0,
            poll: 0,
            precision: 0,
            root_delay: 0,
            root_dispersion: 0,
            ref_id: 0,
            ref_timestamp: 0,
            origin_timestamp: 0,
            recv_timestamp: 0,
            tx_timestamp: NtpPacket::ntp_timestamp_from_unix(now_since_unix),
        }
    }

    /// Converts a duration since the Unix epoch into a 64-bit NTP timestamp.
    ///
    /// The high 32 bits hold whole seconds since 1900. The low 32 bits hold the binary
    /// fraction of a second, in units of 2^-32 s.
    fn ntp_timestamp_from_unix(since_unix: time::Duration) -> u64 {
        let seconds = since_unix.as_secs() + u64::from(NtpPacket::NTP_TIMESTAMP_DELTA);
        // subsec_nanos() < 10^9 < 2^30, so the shifted value fits comfortably in a u64,
        // and the quotient is always < 2^32.
        let fraction = (u64::from(since_unix.subsec_nanos()) << 32) / 1_000_000_000;
        // Bits of `seconds` above 32 are shifted out, i.e. seconds are kept modulo 2^32
        // as the NTP timestamp format requires.
        (seconds << 32) | fraction
    }
}
/// Network-to-host byte-order conversion for the multi-byte integer fields of an NTP packet.
///
/// Named after C's `ntohl`, but provided for both 32- and 64-bit integers. The value is read
/// as holding big-endian (network order) data and is returned in host order. This is a byte
/// swap on little-endian hosts and the identity on big-endian hosts.
trait NtpNum {
    /// Result of the conversion; the implementing integer type itself.
    type Type;
    /// Reinterprets `self` from network byte order into host byte order.
    fn ntohl(&self) -> Self::Type;
}

/// Implements [`NtpNum`] for each listed integer type via its `from_be` conversion.
macro_rules! impl_ntp_num {
    ($($int:ty),*) => {
        $(
            impl NtpNum for $int {
                type Type = $int;
                fn ntohl(&self) -> Self::Type {
                    <$int>::from_be(*self)
                }
            }
        )*
    };
}

impl_ntp_num!(u32, u64);
/// Byte image of an [`NtpPacket`] as sent and received on the wire.
///
/// Its length is `mem::size_of::<NtpPacket>()`.
struct RawNtpPacket([u8; mem::size_of::<NtpPacket>()]);

impl Default for RawNtpPacket {
    /// An all-zero buffer, e.g. for receiving a datagram into.
    fn default() -> Self {
        const LEN: usize = mem::size_of::<NtpPacket>();
        let zeroed: [u8; LEN] = [0; LEN];
        RawNtpPacket(zeroed)
    }
}
impl From<RawNtpPacket> for NtpPacket {
    /// Splits a raw datagram into [`NtpPacket`] fields at the fixed NTP header offsets.
    ///
    /// Multi-byte fields are loaded in *native* byte order. The resulting integers hold the
    /// wire bytes exactly as a C `memcpy` into the struct would, so they are still in network
    /// byte order. `process_response` converts them with `convert_from_network`, whose
    /// `to_be()`-based swap is a no-op on big-endian hosts. A fixed little-endian load would
    /// therefore leave every value byte-reversed on big-endian hosts. The native-order load
    /// is correct on both.
    fn from(val: RawNtpPacket) -> Self {
        let buf = &val.0;
        let read_u32 = |at: usize| {
            let mut word = [0u8; mem::size_of::<u32>()];
            word.copy_from_slice(&buf[at..at + mem::size_of::<u32>()]);
            u32::from_ne_bytes(word)
        };
        let read_u64 = |at: usize| {
            let mut word = [0u8; mem::size_of::<u64>()];
            word.copy_from_slice(&buf[at..at + mem::size_of::<u64>()]);
            u64::from_ne_bytes(word)
        };
        NtpPacket {
            li_vn_mode: buf[0],
            stratum: buf[1],
            poll: buf[2] as i8,
            precision: buf[3] as i8,
            root_delay: read_u32(4),
            root_dispersion: read_u32(8),
            ref_id: read_u32(12),
            ref_timestamp: read_u64(16),
            origin_timestamp: read_u64(24),
            recv_timestamp: read_u64(32),
            tx_timestamp: read_u64(40),
        }
    }
}
impl From<&NtpPacket> for RawNtpPacket {
    /// Encodes `val` into its wire image.
    ///
    /// The four single-byte fields come first. They are followed by the three 32-bit words
    /// and the four 64-bit timestamps, each written big-endian (network byte order) into
    /// consecutive slots.
    fn from(val: &NtpPacket) -> Self {
        let mut wire = [0u8; mem::size_of::<NtpPacket>()];
        wire[..4].copy_from_slice(&[
            val.li_vn_mode,
            val.stratum,
            val.poll as u8,
            val.precision as u8,
        ]);

        let words = [val.root_delay, val.root_dispersion, val.ref_id];
        for (slot, word) in wire[4..16].chunks_mut(4).zip(words.iter()) {
            slot.copy_from_slice(&word.to_be_bytes());
        }

        let stamps = [
            val.ref_timestamp,
            val.origin_timestamp,
            val.recv_timestamp,
            val.tx_timestamp,
        ];
        for (slot, stamp) in wire[16..48].chunks_mut(8).zip(stamps.iter()) {
            slot.copy_from_slice(&stamp.to_be_bytes());
        }

        RawNtpPacket(wire)
    }
}
pub fn request(pool: &str, port: u32) -> io::Result<u32> {
debug!("Pool: {}", pool);
let socket = net::UdpSocket::bind("0.0.0.0:0")
.expect("Unable to create a UDP socket");
let dest = format!("{}:{}", pool, port).to_socket_addrs()?;
socket
.set_read_timeout(Some(time::Duration::new(2, 0)))
.expect("Unable to set up socket timeout");
let req = NtpPacket::new();
let dest = process_request(dest, &req, &socket)?;
let mut buf: RawNtpPacket = RawNtpPacket::default();
let (response, src) = socket.recv_from(buf.0.as_mut())?;
debug!("Response: {}", response);
if src != dest {
return Err(io::Error::new(
io::ErrorKind::Other,
"SNTP response port / address mismatch",
));
}
if response == mem::size_of::<NtpPacket>() {
let result = process_response(&req, buf);
match result {
Ok(timestamp) => return Ok(timestamp),
Err(err_str) => {
return Err(io::Error::new(io::ErrorKind::Other, err_str));
}
}
}
Err(io::Error::new(
io::ErrorKind::Other,
"Incorrect NTP packet size read",
))
}
/// Sends `req` to each resolved address in turn and returns the first address that accepts it.
///
/// A failed send is logged at debug level, and the next address is tried.
///
/// # Errors
///
/// Returns `ErrorKind::AddrNotAvailable` when every address failed or `dest` was empty.
///
/// # Panics
///
/// Panics if a send succeeds but reports a byte count other than the full packet size.
fn process_request(
    dest: std::vec::IntoIter<SocketAddr>,
    req: &NtpPacket,
    socket: &UdpSocket,
) -> io::Result<SocketAddr> {
    let mut candidates = dest;
    candidates
        .find(|candidate| {
            debug!("Address: {}", candidate);
            match send_request(req, socket, *candidate) {
                Ok(sent) => {
                    assert_eq!(sent, mem::size_of::<NtpPacket>());
                    true
                }
                Err(why) => {
                    debug!("{}. Try another one", why);
                    false
                }
            }
        })
        .ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::AddrNotAvailable,
                "SNTP servers not responding",
            )
        })
}
fn send_request(
req: &NtpPacket,
socket: &net::UdpSocket,
dest: net::SocketAddr,
) -> io::Result<usize> {
let buf: RawNtpPacket = req.into();
socket.send_to(&buf.0, dest)
}
/// Validates an SNTP server reply and returns its transmit time as Unix seconds.
///
/// `req` is the request that was sent. Its transmit timestamp must come back as the reply's
/// origin timestamp, which ties the reply to this request.
///
/// # Errors
///
/// Returns a static description of the first failed check:
/// - origin timestamp mismatch;
/// - mode other than unicast (4) or broadcast (5);
/// - leap indicator 3, meaning the server clock is not synchronized;
/// - version differing from the request's;
/// - stratum 0 or above 15;
/// - zero origin, receive or transmit timestamp.
fn process_response(
    req: &NtpPacket,
    resp: RawNtpPacket,
) -> Result<u32, &'static str> {
    const SNTP_UNICAST: u8 = 4;
    const SNTP_BROADCAST: u8 = 5;
    // LI = 3 is the alarm value ("clock not synchronized"); the server's time is unusable.
    // LI is a 2-bit field, so "greater than 3" can never happen and is not checked.
    const LI_ALARM: u8 = 3;
    // Stratum 16 means "unsynchronized" and 17..=255 are reserved.
    const MAX_STRATUM: u8 = 15;

    /// Extracts the bit field selected by `mask`/`shift` from a header byte.
    fn field(byte: u8, mask: u8, shift: u8) -> u8 {
        (byte & mask) >> shift
    }

    let mut packet = NtpPacket::from(resp);
    convert_from_network(&mut packet);
    #[cfg(debug_assertions)]
    debug_ntp_packet(&packet);

    if req.tx_timestamp != packet.origin_timestamp {
        return Err("Incorrect origin timestamp");
    }
    let mode = field(packet.li_vn_mode, MODE_MASK, MODE_SHIFT);
    let li = field(packet.li_vn_mode, LI_MASK, LI_SHIFT);
    let resp_version = field(packet.li_vn_mode, VERSION_MASK, VERSION_SHIFT);
    let req_version = field(req.li_vn_mode, VERSION_MASK, VERSION_SHIFT);
    if mode != SNTP_UNICAST && mode != SNTP_BROADCAST {
        return Err("Incorrect MODE value");
    }
    if li == LI_ALARM {
        return Err("Incorrect LI value");
    }
    if req_version != resp_version {
        return Err("Incorrect response version");
    }
    if packet.stratum == 0 || packet.stratum > MAX_STRATUM {
        return Err("Incorrect STRATUM headers");
    }
    if packet.origin_timestamp == 0 || packet.recv_timestamp == 0 {
        return Err("Invalid origin/receive timestamp");
    }
    if packet.tx_timestamp == 0 {
        return Err("Transmit timestamp is 0");
    }
    // Upper 32 bits of an NTP timestamp: whole seconds since 1900, modulo 2^32.
    let seconds = (packet.tx_timestamp >> 32) as u32;
    // NTP seconds wrap every 2^32 s (era rollover in February 2036). Wrapping subtraction
    // maps both era 0 and era 1 onto the correct Unix time, which fits in a u32 until 2106.
    // Plain subtraction would panic in debug builds once seconds < delta.
    Ok(seconds.wrapping_sub(NtpPacket::NTP_TIMESTAMP_DELTA))
}
fn convert_from_network(packet: &mut NtpPacket) {
fn ntohl<T: NtpNum>(val: T) -> T::Type {
val.ntohl()
}
packet.root_delay = ntohl(packet.root_delay);
packet.root_dispersion = ntohl(packet.root_dispersion);
packet.ref_id = ntohl(packet.ref_id);
packet.ref_timestamp = ntohl(packet.ref_timestamp);
packet.origin_timestamp = ntohl(packet.origin_timestamp);
packet.recv_timestamp = ntohl(packet.recv_timestamp);
packet.tx_timestamp = ntohl(packet.tx_timestamp);
}
#[cfg(debug_assertions)]
fn debug_ntp_packet(packet: &NtpPacket) {
let shifter = |val, mask, shift| (val & mask) >> shift;
let mode = shifter(packet.li_vn_mode, MODE_MASK, MODE_SHIFT);
let version = shifter(packet.li_vn_mode, VERSION_MASK, VERSION_SHIFT);
let li = shifter(packet.li_vn_mode, LI_MASK, LI_SHIFT);
debug!("{}", (0..52).map(|_| "=").collect::<String>());
debug!("| Mode:\t\t{}", mode);
debug!("| Version:\t{}", version);
debug!("| Leap:\t\t{}", li);
debug!("| Stratum:\t{}", packet.stratum);
debug!("| Poll:\t\t{}", packet.poll);
debug!("| Precision:\t\t{}", packet.precision);
debug!("| Root delay:\t\t{}", packet.root_delay);
debug!("| Root dispersion:\t{}", packet.root_dispersion);
debug!(
"| Reference ID:\t\t{}",
str::from_utf8(&packet.ref_id.to_be_bytes()).unwrap_or("")
);
debug!("| Reference timestamp:\t{:>16}", packet.ref_timestamp);
debug!("| Origin timestamp:\t\t{:>16}", packet.origin_timestamp);
debug!("| Receive timestamp:\t\t{:>16}", packet.recv_timestamp);
debug!("| Transmit timestamp:\t\t{:>16}", packet.tx_timestamp);
debug!("{}", (0..52).map(|_| "=").collect::<String>());
}