#[cfg(feature = "nts")]
use rkik_nts::{NtsClient, NtsClientConfig};
use chrono::{DateTime, Utc};
use std::time::Duration;
use crate::error::RkikError;
#[cfg(feature = "json")]
use serde::Serialize;
/// Machine-readable category of an NTS failure.
///
/// Kinds are normally derived from a client error message by `map_error_to_kind`. The
/// snake_case names returned by [`NtsErrorKind::as_str`] match the serialized form used when
/// the `json` feature is enabled.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "json", derive(Serialize))]
// `#[serde(...)]` is a helper attribute registered by the `Serialize` derive, so it must be
// gated on the same feature: left unconditional it is an unknown attribute in non-`json` builds.
#[cfg_attr(feature = "json", serde(rename_all = "snake_case"))]
pub enum NtsErrorKind {
    /// NTS-KE handshake failure (messages mentioning "handshake", "nts-ke" or "tls").
    KeHandshakeFailed,
    /// Certificate problem (messages mentioning "certificate"/"cert").
    CertificateInvalid,
    /// Cookie problem (messages mentioning "cookie").
    MissingCookies,
    /// AEAD failure (messages mentioning "aead" or "authentication tag").
    AeadFailure,
    /// Authenticator problem (messages mentioning "authenticator").
    MissingAuthenticator,
    /// Unique identifier problem (messages mentioning "unique identifier" or "uid").
    InvalidUniqueId,
    /// Origin timestamp problem (messages mentioning "origin timestamp" or "replay").
    InvalidOriginTimestamp,
    /// Extension parsing problem (messages mentioning "extension" or "malformed").
    MalformedExtensions,
    /// The server's response was not authenticated after NTS-KE (raised by `query_nts`).
    UnauthenticatedResponse,
    /// Messages mentioning "timeout" or "timed out".
    Timeout,
    /// Messages mentioning "network", "connection" or "refused".
    Network,
    /// Anything that matched no other category.
    Unknown,
}

impl NtsErrorKind {
    /// Stable snake_case identifier for this kind (used by `Display`).
    pub fn as_str(&self) -> &'static str {
        match self {
            NtsErrorKind::KeHandshakeFailed => "ke_handshake_failed",
            NtsErrorKind::CertificateInvalid => "certificate_invalid",
            NtsErrorKind::MissingCookies => "missing_cookies",
            NtsErrorKind::AeadFailure => "aead_failure",
            NtsErrorKind::MissingAuthenticator => "missing_authenticator",
            NtsErrorKind::InvalidUniqueId => "invalid_unique_id",
            NtsErrorKind::InvalidOriginTimestamp => "invalid_origin_timestamp",
            NtsErrorKind::MalformedExtensions => "malformed_extensions",
            NtsErrorKind::UnauthenticatedResponse => "unauthenticated_response",
            NtsErrorKind::Timeout => "timeout",
            NtsErrorKind::Network => "network",
            NtsErrorKind::Unknown => "unknown",
        }
    }

    /// Process exit code for plugin-style reporting of this failure.
    ///
    /// Authentication/integrity failures (AEAD, authenticator, unique identifier, origin
    /// timestamp, unauthenticated response) yield `2`; every other kind yields `3`.
    pub fn plugin_exit_code(&self) -> i32 {
        match self {
            NtsErrorKind::AeadFailure
            | NtsErrorKind::MissingAuthenticator
            | NtsErrorKind::UnauthenticatedResponse
            | NtsErrorKind::InvalidUniqueId
            | NtsErrorKind::InvalidOriginTimestamp => 2,
            // NOTE(review): this group currently shares code 3 with the transport errors
            // below. It is kept as its own arm in case KE/certificate failures were meant to
            // map to a different code — confirm the intended value.
            NtsErrorKind::KeHandshakeFailed
            | NtsErrorKind::CertificateInvalid
            | NtsErrorKind::MissingCookies
            | NtsErrorKind::MalformedExtensions => 3,
            NtsErrorKind::Timeout | NtsErrorKind::Network | NtsErrorKind::Unknown => 3,
        }
    }
}

impl std::fmt::Display for NtsErrorKind {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
/// A classified NTS failure: an [`NtsErrorKind`] plus a human-readable message.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "json", derive(Serialize))]
pub struct NtsError {
    /// Category of the failure.
    pub kind: NtsErrorKind,
    /// Human-readable description; this is exactly what `Display` prints.
    pub message: String,
}

impl NtsError {
    /// Build an error from a kind and anything convertible into a `String`.
    pub fn new(kind: NtsErrorKind, message: impl Into<String>) -> Self {
        Self {
            kind,
            message: message.into(),
        }
    }
}

impl std::fmt::Display for NtsError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.message)
    }
}

// Lets `NtsError` be used wherever a standard error is expected (`Box<dyn Error>`, `?` into
// boxed errors, error-reporting crates). `Debug` and `Display` are already provided above.
impl std::error::Error for NtsError {}
/// Summary of NTS validation for a query: whether it was authenticated and, if not, why.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "json", derive(Serialize))]
pub struct NtsValidationOutcome {
    /// `true` for [`NtsValidationOutcome::success`], `false` for [`NtsValidationOutcome::failure`].
    pub authenticated: bool,
    /// The failure, if any; `None` on success.
    // Gated like the `Serialize` derive: an unconditional `#[serde]` helper attribute does not
    // compile when the `json` feature is off.
    #[cfg_attr(feature = "json", serde(skip_serializing_if = "Option::is_none"))]
    pub error: Option<NtsError>,
}

impl NtsValidationOutcome {
    /// Outcome for an authenticated exchange (no error attached).
    pub fn success() -> Self {
        Self {
            authenticated: true,
            error: None,
        }
    }

    /// Outcome for a failed validation, carrying the error that caused it.
    pub fn failure(error: NtsError) -> Self {
        Self {
            authenticated: false,
            error: Some(error),
        }
    }
}
/// Result of an NTS time query, as assembled by `query_nts`.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "json", derive(Serialize))]
pub struct NtsTimeResult {
    /// Server-reported time, converted into `DateTime<Utc>`.
    pub network_time: DateTime<Utc>,
    /// Clock offset in milliseconds (`query_nts` computes it as `offset.as_secs_f64() * 1000.0`).
    pub offset_ms: f64,
    /// Round-trip delay in milliseconds.
    pub rtt_ms: f64,
    /// Whether the response was NTS-authenticated; `query_nts` only returns `Ok` with `true`.
    pub authenticated: bool,
    /// Server name as reported by the client's time snapshot.
    pub server: String,
    /// NTS-KE details when the client exposes them, otherwise `None`.
    // NOTE(review): unlike other optional fields in this module this one has no
    // `skip_serializing_if`, so it serializes as `null` when absent — confirm that is intended.
    pub nts_ke_data: Option<NtsKeData>,
    /// Validation summary; `query_nts` always sets `NtsValidationOutcome::success()`.
    pub nts_validation: NtsValidationOutcome,
}
#[derive(Debug, Clone)]
#[cfg_attr(feature = "json", derive(Serialize))]
pub struct NtsKeData {
pub ke_duration_ms: f64,
pub cookie_count: usize,
pub cookie_sizes: Vec<usize>,
pub aead_algorithm: String,
pub ntp_server: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub certificate: Option<CertificateInfo>,
}
/// Server certificate details captured during NTS-KE.
///
/// `query_nts` copies every field verbatim from the key-establishment result's certificate;
/// the exact string formats (dates, serial, fingerprint encoding) are defined by that client
/// library — NOTE(review): confirm the formats there before parsing these strings.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "json", derive(Serialize))]
pub struct CertificateInfo {
    /// Certificate subject.
    pub subject: String,
    /// Certificate issuer.
    pub issuer: String,
    /// Start of the validity period (string form as provided by the client).
    pub valid_from: String,
    /// End of the validity period (string form as provided by the client).
    pub valid_until: String,
    /// Certificate serial number.
    pub serial_number: String,
    /// DNS names from the Subject Alternative Name extension.
    pub san_dns_names: Vec<String>,
    /// Signature algorithm name.
    pub signature_algorithm: String,
    /// Public key algorithm name.
    pub public_key_algorithm: String,
    /// SHA-256 fingerprint (encoding as provided by the client — presumably hex; verify).
    pub fingerprint_sha256: String,
    /// Whether the certificate is self-signed.
    pub is_self_signed: bool,
}
/// Best-effort classification of an NTS client error message into an [`NtsErrorKind`].
///
/// Matching is case-insensitive and keyword-based, and the first matching rule wins. The order
/// of the checks is therefore significant. For example, protocol/TLS keywords are tested
/// before "timeout", so "TLS handshake timed out" is a `KeHandshakeFailed`. Messages that match
/// no rule yield [`NtsErrorKind::Unknown`].
#[cfg(feature = "nts")]
pub fn map_error_to_kind(error_msg: &str) -> NtsErrorKind {
    let msg = error_msg.to_lowercase();
    let has = |needle: &str| msg.contains(needle);
    // "uid" is short enough to occur inside unrelated words ("guid", "fluid", "guide"), so it
    // is only accepted as a standalone token (delimited by non-alphanumeric characters).
    let has_token = |word: &str| {
        msg.split(|c: char| !c.is_alphanumeric())
            .any(|token| token == word)
    };

    if has("aead") || has("authentication tag") {
        NtsErrorKind::AeadFailure
    } else if has("authenticator") {
        NtsErrorKind::MissingAuthenticator
    } else if has("unique identifier") || has_token("uid") {
        NtsErrorKind::InvalidUniqueId
    } else if has("origin timestamp") || has("replay") {
        NtsErrorKind::InvalidOriginTimestamp
    } else if has("cookie") {
        NtsErrorKind::MissingCookies
    } else if has("certificate") || has("cert") {
        NtsErrorKind::CertificateInvalid
    } else if has("extension") || has("malformed") {
        NtsErrorKind::MalformedExtensions
    } else if has("handshake") || has("nts-ke") || has("tls") {
        NtsErrorKind::KeHandshakeFailed
    } else if has("timeout") || has("timed out") {
        NtsErrorKind::Timeout
    } else if has("network") || has("connection") || has("refused") {
        NtsErrorKind::Network
    } else {
        NtsErrorKind::Unknown
    }
}
/// Query `server` over NTS: perform NTS Key Establishment, then one time request.
///
/// * `nts_ke_port` — when `Some`, passed to the client config via `with_port`.
/// * `timeout` — passed to the client config via `with_timeout`.
///
/// # Errors
///
/// Returns [`RkikError::Nts`] if the NTS-KE step fails, if the time query fails, or if the
/// response is not authenticated. Each message ends with the classified [`NtsErrorKind`] in
/// brackets, e.g. `"... [timeout]"`.
#[cfg(feature = "nts")]
pub async fn query_nts(
    server: &str,
    nts_ke_port: Option<u16>,
    timeout: Duration,
) -> Result<NtsTimeResult, RkikError> {
    /// Render a client error as text and classify that text.
    fn describe<E: ToString>(err: E) -> (String, NtsErrorKind) {
        let text = err.to_string();
        let kind = map_error_to_kind(&text);
        (text, kind)
    }

    let base_config = NtsClientConfig::new(server);
    let config = match nts_ke_port {
        Some(port) => base_config.with_port(port),
        None => base_config,
    }
    .with_timeout(timeout);
    let mut client = NtsClient::new(config);

    // Step 1: key establishment.
    client.connect().await.map_err(|err| {
        let (text, kind) = describe(err);
        RkikError::Nts(format!("NTS-KE failed: {} [{}]", text, kind))
    })?;

    // Step 2: time request.
    let snapshot = client.get_time().await.map_err(|err| {
        let (text, kind) = describe(err);
        RkikError::Nts(format!("NTS time query failed: {} [{}]", text, kind))
    })?;

    // An unauthenticated answer is turned into an error rather than returned as a result.
    if !snapshot.authenticated {
        return Err(RkikError::Nts(format!(
            "NTS validation failed: server returned unauthenticated response after NTS-KE [{}]",
            NtsErrorKind::UnauthenticatedResponse
        )));
    }

    let nts_ke_data = client.nts_ke_info().map(|ke| {
        let certificate = ke.certificate.as_ref().map(|cert| CertificateInfo {
            subject: cert.subject.clone(),
            issuer: cert.issuer.clone(),
            valid_from: cert.valid_from.clone(),
            valid_until: cert.valid_until.clone(),
            serial_number: cert.serial_number.clone(),
            san_dns_names: cert.san_dns_names.clone(),
            signature_algorithm: cert.signature_algorithm.clone(),
            public_key_algorithm: cert.public_key_algorithm.clone(),
            fingerprint_sha256: cert.fingerprint_sha256.clone(),
            is_self_signed: cert.is_self_signed,
        });
        NtsKeData {
            ke_duration_ms: ke.ke_duration.as_secs_f64() * 1000.0,
            cookie_count: ke.initial_cookie_count,
            // NOTE(review): cookie sizes are never populated — confirm whether the KE info
            // exposes them.
            cookie_sizes: vec![],
            aead_algorithm: ke.aead_algorithm.clone(),
            ntp_server: ke.ntp_server.to_string(),
            certificate,
        }
    });

    let network_time: DateTime<Utc> = snapshot.network_time.into();
    // NOTE(review): if `offset` is an unsigned `std::time::Duration`, it cannot express a
    // negative clock offset — confirm how the client encodes the sign.
    let offset_ms = snapshot.offset.as_secs_f64() * 1000.0;
    let rtt_ms = snapshot.round_trip_delay.as_secs_f64() * 1000.0;

    Ok(NtsTimeResult {
        network_time,
        offset_ms,
        rtt_ms,
        authenticated: true,
        server: snapshot.server.clone(),
        nts_ke_data,
        nts_validation: NtsValidationOutcome::success(),
    })
}
/// Fallback used when the crate is compiled without the `nts` feature.
///
/// # Errors
///
/// Always returns [`RkikError::Other`] explaining how to enable NTS support; all arguments
/// are ignored.
#[cfg(not(feature = "nts"))]
pub async fn query_nts(
    _server: &str,
    _nts_ke_port: Option<u16>,
    _timeout: Duration,
) -> Result<NtsTimeResult, RkikError> {
    let reason = String::from("NTS support not enabled. Compile with --features nts");
    Err(RkikError::Other(reason))
}