use crate::get_ntp_timestamp;
#[cfg(any(feature = "log", feature = "defmt"))]
use crate::log::debug;
use crate::net::SocketAddr;
use cfg_if::cfg_if;
use core::fmt::Formatter;
use core::fmt::{Debug, Display};
use core::future::Future;
use core::mem;
// Bit layout of the first header byte (`li_vn_mode`): LI in bits 6-7,
// version in bits 3-5, mode in bits 0-2.

/// Mask selecting the 3-bit mode field of the LI/VN/Mode byte.
pub(crate) const MODE_MASK: u8 = 0x07;
/// Shift that aligns the mode field to bit 0 (it already is).
pub(crate) const MODE_SHIFT: u8 = 0;
/// Mask selecting the 3-bit version field of the LI/VN/Mode byte.
pub(crate) const VERSION_MASK: u8 = 0x07 << VERSION_SHIFT;
/// Shift that brings the version field down to bit 0.
pub(crate) const VERSION_SHIFT: u8 = 3;
/// Mask selecting the 2-bit leap-indicator field of the LI/VN/Mode byte.
pub(crate) const LI_MASK: u8 = 0x03 << LI_SHIFT;
/// Shift that brings the leap-indicator field down to bit 0.
pub(crate) const LI_SHIFT: u8 = 6;

/// Picoseconds per second.
pub(crate) const PSEC_IN_SEC: u64 = 10u64.pow(12);
/// Nanoseconds per second.
pub(crate) const NSEC_IN_SEC: u32 = 10u32.pow(9);
/// Microseconds per second.
pub(crate) const USEC_IN_SEC: u32 = 10u32.pow(6);
/// Milliseconds per second.
pub(crate) const MSEC_IN_SEC: u32 = 10u32.pow(3);

/// Upper 32 bits of a 64-bit NTP timestamp (whole seconds).
pub(crate) const SECONDS_MASK: u64 = u64::MAX << 32;
/// Lower 32 bits of a 64-bit NTP timestamp (fraction of a second).
pub(crate) const SECONDS_FRAC_MASK: u64 = u64::MAX >> 32;
/// Shorthand for `core::result::Result` with this crate's [`Error`] as the error type.
pub type Result<T> = core::result::Result<T, Error>;
/// In-memory form of the 48-byte NTP/SNTP header (field order follows RFC 5905).
///
/// Conversion to and from the wire bytes is done through `RawNtpPacket`; the byte
/// offsets noted below are the ones those conversions use.
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub(crate) struct NtpPacket {
    /// Byte 0: leap indicator (bits 6-7), version (bits 3-5) and mode (bits 0-2).
    pub(crate) li_vn_mode: u8,
    /// Byte 1: stratum.
    pub(crate) stratum: u8,
    /// Byte 2: poll field, stored as a signed byte.
    pub(crate) poll: i8,
    /// Byte 3: precision field, stored as a signed byte.
    pub(crate) precision: i8,
    /// Bytes 4..8: root delay.
    pub(crate) root_delay: u32,
    /// Bytes 8..12: root dispersion.
    pub(crate) root_dispersion: u32,
    /// Bytes 12..16: reference ID (the debug dump renders it as 4 text bytes).
    pub(crate) ref_id: u32,
    /// Bytes 16..24: reference timestamp.
    pub(crate) ref_timestamp: u64,
    /// Bytes 24..32: origin timestamp.
    pub(crate) origin_timestamp: u64,
    /// Bytes 32..40: receive timestamp.
    pub(crate) recv_timestamp: u64,
    /// Bytes 40..48: transmit timestamp.
    pub(crate) tx_timestamp: u64,
}
cfg_if! {
    if #[cfg(any(feature = "log", feature = "defmt"))] {
        use crate::shifter;
        use core::str;

        /// Borrowed view pairing an [`NtpPacket`] with the client-side receive
        /// timestamp, so both can be printed together through `Debug`.
        #[cfg_attr(feature = "defmt", derive(defmt::Format))]
        pub(crate) struct DebugNtpPacket<'a> {
            packet: &'a NtpPacket,
            client_recv_timestamp: u64,
        }

        impl<'a> DebugNtpPacket<'a> {
            /// Wraps `packet` together with the time the client received it.
            pub(crate) fn new(packet: &'a NtpPacket, client_recv_timestamp: u64) -> Self {
                DebugNtpPacket { packet, client_recv_timestamp }
            }
        }

        impl Debug for DebugNtpPacket<'_> {
            fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
                let pkt = self.packet;
                let header = pkt.li_vn_mode;

                // Decode the three bit fields packed into the first header byte.
                let mode = shifter(header, MODE_MASK, MODE_SHIFT);
                let version = shifter(header, VERSION_MASK, VERSION_SHIFT);
                let leap = shifter(header, LI_MASK, LI_SHIFT);

                // Show the reference ID as text when its big-endian bytes are
                // valid UTF-8; otherwise show an empty string.
                let ref_id_bytes = pkt.ref_id.to_be_bytes();
                let reference_id = str::from_utf8(&ref_id_bytes).unwrap_or("");

                let mut out = f.debug_struct("NtpPacket");
                out.field("mode", &mode);
                out.field("version", &version);
                out.field("leap", &leap);
                out.field("stratum", &pkt.stratum);
                out.field("poll", &pkt.poll);
                out.field("precision", &pkt.precision);
                out.field("root delay", &pkt.root_delay);
                out.field("root dispersion", &pkt.root_dispersion);
                out.field("reference ID", &reference_id);
                out.field("origin timestamp (client)", &pkt.origin_timestamp);
                out.field("receive timestamp (server)", &pkt.recv_timestamp);
                out.field("transmit timestamp (server)", &pkt.tx_timestamp);
                out.field("receive timestamp (client)", &self.client_recv_timestamp);
                out.field("reference timestamp (server)", &pkt.ref_timestamp);
                out.finish()
            }
        }
    }
}
/// A 64-bit NTP timestamp split into Unix-epoch seconds and a raw 32-bit fraction.
#[derive(Debug, Copy, Clone)]
pub(crate) struct NtpTimestamp {
    /// Whole seconds relative to the Unix epoch; negative for NTP times before 1970.
    pub(crate) seconds: i64,
    /// Fraction of a second as the raw low 32 bits of the NTP timestamp (units of 2^-32 s).
    pub(crate) seconds_fraction: i64,
}

impl From<u64> for NtpTimestamp {
    /// Splits a host-order 32.32 fixed-point NTP timestamp.
    ///
    /// The upper 32 bits are shifted from the NTP epoch to the Unix epoch by
    /// subtracting `NtpPacket::NTP_TIMESTAMP_DELTA` (2 208 988 800 s, the 1900→1970
    /// offset). The lower 32 bits are kept unchanged as the fraction.
    #[allow(clippy::cast_possible_wrap)]
    fn from(v: u64) -> Self {
        // Both halves are < 2^32, so the casts to i64 are lossless. Subtracting in
        // i64 (rather than u64) cannot underflow, so a zero or pre-1970 timestamp
        // yields a negative value instead of panicking in debug builds. The result
        // is bit-identical to what wrapping u64 arithmetic produced in release.
        let ntp_seconds = ((v & SECONDS_MASK) >> 32) as i64;
        let seconds = ntp_seconds - i64::from(NtpPacket::NTP_TIMESTAMP_DELTA);
        let fraction = (v & SECONDS_FRAC_MASK) as i64;
        NtpTimestamp {
            seconds,
            seconds_fraction: fraction,
        }
    }
}
/// Time unit used when reporting durations.
#[derive(Copy, Clone, Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub(crate) enum Units {
    #[allow(dead_code)]
    Milliseconds,
    Microseconds,
}

impl Display for Units {
    /// Writes the short unit suffix (`"ms"` or `"us"`); width/fill flags are ignored.
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        f.write_str(match self {
            Units::Milliseconds => "ms",
            Units::Microseconds => "us",
        })
    }
}
/// Four-byte "kiss code" reported with a Kiss-o'-Death response (see RFC 5905 §7.4).
#[derive(Debug, Copy, Clone, PartialEq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct KissOfDeathCode {
    code: [u8; 4],
}

impl KissOfDeathCode {
    /// Wraps the four raw code bytes.
    pub(crate) fn new(code: [u8; 4]) -> Self {
        Self { code }
    }

    /// Returns the code as text, or `""` if the bytes are not valid UTF-8.
    #[must_use]
    pub fn as_str(&self) -> &str {
        // Fully qualified: this file only brings `core::str` into scope when the
        // `log`/`defmt` features are on, so a bare `str::from_utf8` would not
        // resolve otherwise on toolchains lacking the inherent `str::from_utf8`.
        core::str::from_utf8(&self.code).unwrap_or("")
    }
}
/// Errors reported by this crate (the error type of [`Result`]).
///
/// The enum is `#[non_exhaustive]`: downstream `match` expressions need a wildcard arm.
#[derive(Debug, PartialEq, Copy, Clone)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
#[non_exhaustive]
pub enum Error {
    /// Presumably: the response's origin timestamp did not match the request — TODO confirm.
    IncorrectOriginTimestamp,
    /// Presumably: the response carried an unexpected mode value — TODO confirm.
    IncorrectMode,
    /// Presumably: the response carried an unacceptable leap indicator — TODO confirm.
    IncorrectLeapIndicator,
    /// Presumably: the response version did not match the request — TODO confirm.
    IncorrectResponseVersion,
    /// Presumably: the response had an invalid stratum — TODO confirm.
    IncorrectStratumHeaders,
    /// Presumably: the received payload was malformed or of the wrong size — TODO confirm.
    IncorrectPayload,
    /// A network-level failure (raised by socket implementations, presumably).
    Network,
    /// Presumably: the server address could not be resolved — TODO confirm.
    AddressResolve,
    /// Presumably: the response came from a different address than the request went to.
    ResponseAddressMismatch,
    /// A Kiss-o'-Death response; the payload holds its four-byte code.
    KissOfDeath(KissOfDeathCode),
}
/// Outcome of a completed SNTP exchange.
#[derive(Debug, Copy, Clone)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct NtpResult {
    /// Whole seconds of the resulting time.
    pub seconds: u32,
    /// Fractional part of the resulting time (presumably units of 2^-32 s — TODO confirm).
    pub seconds_fraction: u32,
    /// Round-trip delay of the exchange (units set by the caller — TODO confirm).
    pub roundtrip: u64,
    /// Estimated clock offset (units set by the caller — TODO confirm).
    pub offset: i64,
    /// Stratum reported by the server.
    pub stratum: u8,
    /// Precision reported by the server.
    pub precision: i8,
}

impl NtpResult {
    /// Builds a result, normalising a fraction of exactly `u32::MAX` into the seconds.
    ///
    /// `seconds_fraction / u32::MAX` is 1 only when the fraction equals `u32::MAX`;
    /// in that case one second is carried and the fraction becomes 0.
    #[must_use]
    pub fn new(
        seconds: u32,
        seconds_fraction: u32,
        roundtrip: u64,
        offset: i64,
        stratum: u8,
        precision: i8,
    ) -> Self {
        let carry = seconds_fraction / u32::MAX;
        // Wrap instead of overflowing: a plain `+` panics in debug builds when
        // `seconds == u32::MAX` and a carry occurs. Wrapping matches release-build
        // behaviour and the modular (era-based) nature of 32-bit NTP seconds.
        let seconds = seconds.wrapping_add(carry);
        let seconds_fraction = seconds_fraction % u32::MAX;
        NtpResult {
            seconds,
            seconds_fraction,
            roundtrip,
            offset,
            stratum,
            precision,
        }
    }

    /// Whole seconds of the resulting time.
    #[must_use]
    pub fn sec(&self) -> u32 {
        self.seconds
    }

    /// Fractional part of the resulting time.
    #[must_use]
    pub fn sec_fraction(&self) -> u32 {
        self.seconds_fraction
    }

    /// Round-trip delay of the exchange.
    #[must_use]
    pub fn roundtrip(&self) -> u64 {
        self.roundtrip
    }

    /// Estimated clock offset.
    #[must_use]
    pub fn offset(&self) -> i64 {
        self.offset
    }

    /// Stratum reported by the server.
    #[must_use]
    pub fn stratum(&self) -> u8 {
        self.stratum
    }

    /// Precision reported by the server.
    #[must_use]
    pub fn precision(&self) -> i8 {
        self.precision
    }
}
impl NtpPacket {
pub(crate) const NTP_TIMESTAMP_DELTA: u32 = 2_208_988_800u32;
const SNTP_CLIENT_MODE: u8 = 3;
const SNTP_VERSION: u8 = 4 << 3;
pub fn new<T: NtpTimestampGenerator>(mut timestamp_gen: T) -> Self {
timestamp_gen.init();
let tx_timestamp = get_ntp_timestamp(×tamp_gen);
#[cfg(any(feature = "log", feature = "defmt"))]
debug!("NtpPacket::new(tx_timestamp: {})", tx_timestamp);
NtpPacket {
li_vn_mode: NtpPacket::SNTP_CLIENT_MODE | NtpPacket::SNTP_VERSION,
stratum: 0,
poll: 0,
precision: 0,
root_delay: 0,
root_dispersion: 0,
ref_id: 0,
ref_timestamp: 0,
origin_timestamp: 0,
recv_timestamp: 0,
tx_timestamp,
}
}
}
/// Source of the local time used to timestamp outgoing NTP packets.
///
/// `NtpPacket::new` calls `init` once and then hands the generator to
/// `get_ntp_timestamp`. An implementation may therefore snapshot the clock in
/// `init` and report that snapshot from the getters, as `StdTimestampGen` does.
pub trait NtpTimestampGenerator {
    /// Prepares the generator for reading (`StdTimestampGen` snapshots the system clock here).
    fn init(&mut self);
    /// Whole seconds of the current time (`StdTimestampGen`: seconds since the Unix epoch).
    fn timestamp_sec(&self) -> u64;
    /// Sub-second part of the current time, in microseconds.
    fn timestamp_subsec_micros(&self) -> u32;
}
#[cfg(feature = "std")]
mod sup {
    use crate::NtpTimestampGenerator;
    use std::time::{Duration, SystemTime};

    /// [`NtpTimestampGenerator`] backed by the host wall clock (`SystemTime`).
    ///
    /// `init` snapshots the time elapsed since the Unix epoch. The two getters
    /// split that snapshot into whole seconds and sub-second microseconds.
    #[derive(Copy, Clone, Default)]
    pub struct StdTimestampGen {
        /// Snapshot taken by the latest `init` call (zero before the first call).
        duration: Duration,
    }

    impl NtpTimestampGenerator for StdTimestampGen {
        fn init(&mut self) {
            let since_epoch = SystemTime::now().duration_since(SystemTime::UNIX_EPOCH);
            // Panics if the system clock reports a time before 1970-01-01.
            self.duration = since_epoch.unwrap();
        }

        fn timestamp_sec(&self) -> u64 {
            let Self { duration } = *self;
            duration.as_secs()
        }

        fn timestamp_subsec_micros(&self) -> u32 {
            Duration::subsec_micros(&self.duration)
        }
    }
}

#[cfg(feature = "std")]
pub use sup::*;
/// Asynchronous UDP socket abstraction; both operations return futures that
/// resolve to this crate's [`Result`].
pub trait NtpUdpSocket {
    /// Sends `buf` to `addr`. The `usize` is presumably the number of bytes sent — TODO confirm.
    fn send_to(&self, buf: &[u8], addr: SocketAddr) -> impl Future<Output = Result<usize>>;
    /// Receives a datagram into `buf`. Presumably resolves to `(bytes_received, sender_address)`.
    fn recv_from(&self, buf: &mut [u8]) -> impl Future<Output = Result<(usize, SocketAddr)>>;
}
/// Holds the timestamp generator used for NTP requests.
#[derive(Copy, Clone)]
pub struct NtpContext<T: NtpTimestampGenerator> {
    /// Source of the local time.
    pub timestamp_gen: T,
}

impl<T> NtpContext<T>
where
    T: NtpTimestampGenerator + Copy,
{
    /// Wraps `timestamp_gen` in a new context.
    pub fn new(timestamp_gen: T) -> Self {
        Self { timestamp_gen }
    }
}
#[derive(Copy, Clone, Debug)]
pub struct SendRequestResult {
pub(crate) originate_timestamp: u64,
pub(crate) version: u8,
}
impl From<NtpPacket> for SendRequestResult {
fn from(ntp_packet: NtpPacket) -> Self {
SendRequestResult {
originate_timestamp: ntp_packet.tx_timestamp,
version: ntp_packet.li_vn_mode,
}
}
}
/// Network-to-host byte-order conversion for multi-byte NTP header fields.
pub(crate) trait NtpNum {
    /// Integer type produced by the conversion (`Self` for both impls below).
    type Type;
    /// Reinterprets a value holding network-order bytes in host order.
    ///
    /// Implemented with `to_be`, which performs the same operation as `from_be`:
    /// a byte swap on little-endian targets and a no-op on big-endian ones.
    fn ntohl(&self) -> Self::Type;
}

// `u32` and `u64` share an identical implementation.
macro_rules! impl_ntp_num {
    ($($int:ty),* $(,)?) => {
        $(
            impl NtpNum for $int {
                type Type = $int;
                fn ntohl(&self) -> Self::Type {
                    <$int>::to_be(*self)
                }
            }
        )*
    };
}

impl_ntp_num!(u32, u64);
/// Serialized (byte-buffer) form of an [`NtpPacket`].
///
/// Sized with `size_of::<NtpPacket>()`, which is 48 bytes for the current field set
/// (4×8-bit, 3×32-bit, 4×64-bit). NOTE(review): this relies on the compiler adding
/// no padding to `NtpPacket`.
#[derive(Copy, Clone)]
pub(crate) struct RawNtpPacket(pub(crate) [u8; core::mem::size_of::<NtpPacket>()]);

impl Default for RawNtpPacket {
    /// An all-zero buffer (arrays longer than 32 cannot `#[derive(Default)]`).
    fn default() -> Self {
        Self([0; core::mem::size_of::<NtpPacket>()])
    }
}
impl From<RawNtpPacket> for NtpPacket {
    /// Splits the raw header bytes into fields without changing byte order.
    ///
    /// Multi-byte fields are read in *native* byte order, so each integer still
    /// holds the network-order (big-endian) bytes. Passing it through
    /// `NtpNum::ntohl` (i.e. `to_be`) then yields the host value on every target.
    /// The former `from_le_bytes` matched native reads only on little-endian hosts
    /// and gave byte-swapped values on big-endian ones. Little-endian behaviour is
    /// unchanged.
    fn from(val: RawNtpPacket) -> Self {
        let bytes = &val.0;

        // Reads the 32-bit word starting at `offset`.
        let u32_at = |offset: usize| {
            let mut word = [0u8; mem::size_of::<u32>()];
            word.copy_from_slice(&bytes[offset..offset + mem::size_of::<u32>()]);
            u32::from_ne_bytes(word)
        };
        // Reads the 64-bit word starting at `offset`.
        let u64_at = |offset: usize| {
            let mut word = [0u8; mem::size_of::<u64>()];
            word.copy_from_slice(&bytes[offset..offset + mem::size_of::<u64>()]);
            u64::from_ne_bytes(word)
        };

        NtpPacket {
            li_vn_mode: bytes[0],
            stratum: bytes[1],
            // Reinterpret the byte's bits as two's complement (same as `as i8`).
            poll: i8::from_be_bytes([bytes[2]]),
            precision: i8::from_be_bytes([bytes[3]]),
            root_delay: u32_at(4),
            root_dispersion: u32_at(8),
            ref_id: u32_at(12),
            ref_timestamp: u64_at(16),
            origin_timestamp: u64_at(24),
            recv_timestamp: u64_at(32),
            tx_timestamp: u64_at(40),
        }
    }
}
impl From<&NtpPacket> for RawNtpPacket {
#[allow(clippy::cast_sign_loss)]
fn from(val: &NtpPacket) -> Self {
let mut tmp_buf = [0u8; size_of::<NtpPacket>()];
tmp_buf[0] = val.li_vn_mode;
tmp_buf[1] = val.stratum;
tmp_buf[2] = val.poll as u8;
tmp_buf[3] = val.precision as u8;
tmp_buf[4..8].copy_from_slice(&val.root_delay.to_be_bytes());
tmp_buf[8..12].copy_from_slice(&val.root_dispersion.to_be_bytes());
tmp_buf[12..16].copy_from_slice(&val.ref_id.to_be_bytes());
tmp_buf[16..24].copy_from_slice(&val.ref_timestamp.to_be_bytes());
tmp_buf[24..32].copy_from_slice(&val.origin_timestamp.to_be_bytes());
tmp_buf[32..40].copy_from_slice(&val.recv_timestamp.to_be_bytes());
tmp_buf[40..48].copy_from_slice(&val.tx_timestamp.to_be_bytes());
RawNtpPacket(tmp_buf)
}
}