#![allow(dead_code)]
use std::fmt;
/// Tuning parameters for a [`MediaClock`].
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct ClockConfig {
    /// Expected number of PTS ticks per second of wall-clock time; drift is
    /// reported relative to this value.
    pub nominal_rate: u64,
    /// Weight given to each new instantaneous rate sample in the exponential
    /// moving average (`new = alpha * sample + (1 - alpha) * old`).
    ///
    /// NOTE(review): this type performs no range check; values outside (0, 1]
    /// presumably make the average unstable — confirm the intended bounds.
    pub ema_alpha: f64,
}

impl Default for ClockConfig {
    /// Nominal rate of 90 000 ticks per second, smoothing weight 0.1.
    fn default() -> Self {
        let nominal_rate = 90_000;
        let ema_alpha = 0.1;
        Self {
            nominal_rate,
            ema_alpha,
        }
    }
}
/// One `(wall-clock time, PTS)` sample fed to the clock.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct ClockObservation {
    /// Wall-clock timestamp, in nanoseconds.
    pub wall_ns: u64,
    /// Media presentation timestamp, in media-clock ticks.
    pub pts: i64,
}
/// Snapshot of the clock's rate estimate.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct DriftEstimate {
    /// Smoothed observed rate, in PTS ticks per wall-clock second.
    pub observed_rate: f64,
    /// Deviation of `observed_rate` from the nominal rate, in parts per
    /// million; positive means the source runs fast.
    pub drift_ppm: f64,
}
/// A fixed offset between presentation (PTS) and decode (DTS) timestamps.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PtsDtsRelation {
    /// `pts - dts`: how far presentation time is ahead of decode time.
    pub pts_dts_offset: i64,
}

impl PtsDtsRelation {
    /// Converts a PTS to its DTS by subtracting the offset.
    ///
    /// Uses wrapping arithmetic, matching how PTS differences are handled
    /// elsewhere, so extreme inputs wrap instead of panicking in debug builds.
    #[must_use]
    pub fn dts_from_pts(&self, pts: i64) -> i64 {
        pts.wrapping_sub(self.pts_dts_offset)
    }

    /// Converts a DTS to its PTS by adding the offset (wrapping on overflow).
    #[must_use]
    pub fn pts_from_dts(&self, dts: i64) -> i64 {
        dts.wrapping_add(self.pts_dts_offset)
    }
}
/// Mutable estimation state, created on the first accepted observation.
struct ClockState {
    /// First observation accepted; origin used for PTS prediction.
    anchor: ClockObservation,
    /// Most recently accepted observation; rate samples are measured from it.
    latest: ClockObservation,
    /// Exponentially smoothed rate, in PTS ticks per wall-clock second.
    ema_rate: f64,
    /// Number of observations accepted so far, including `anchor`.
    obs_count: u64,
}

impl ClockState {
    /// Seeds the state from its first observation, starting the smoothed rate
    /// at `nominal_rate`.
    fn new(first: ClockObservation, nominal_rate: f64) -> Self {
        Self {
            obs_count: 1,
            ema_rate: nominal_rate,
            latest: first,
            anchor: first,
        }
    }

    /// Folds `obs` into the smoothed rate with weight `alpha` and returns the
    /// instantaneous rate sample that was used.
    ///
    /// If the step since the previous observation has no positive wall-clock
    /// or PTS advance, the current smoothed rate is used as the sample. The
    /// estimate therefore stays where it was, up to float rounding.
    fn update(&mut self, obs: ClockObservation, alpha: f64) -> f64 {
        const NS_PER_SEC: f64 = 1_000_000_000.0;

        let step_ns = obs.wall_ns.saturating_sub(self.latest.wall_ns);
        let step_pts = obs.pts.wrapping_sub(self.latest.pts);

        let sample = if step_ns == 0 || step_pts <= 0 {
            self.ema_rate
        } else {
            let step_secs = step_ns as f64 / NS_PER_SEC;
            step_pts as f64 / step_secs
        };

        self.ema_rate = alpha * sample + (1.0 - alpha) * self.ema_rate;
        self.obs_count += 1;
        self.latest = obs;
        sample
    }

    /// Deviation of the smoothed rate from `nominal_rate`, in parts per million.
    fn drift_ppm(&self, nominal_rate: f64) -> f64 {
        let excess = self.ema_rate - nominal_rate;
        excess / nominal_rate * 1_000_000.0
    }
}
/// Estimates the rate of a media PTS timeline against wall-clock time.
///
/// Feed `(wall_ns, pts)` samples through [`MediaClock::observe`]. Once two
/// samples have been accepted, the clock reports:
///
/// - a smoothed rate ([`MediaClock::estimate`]),
/// - its drift from the configured nominal rate,
/// - PTS predictions ([`MediaClock::predict_pts`]).
///
/// Separately, it can remember a fixed PTS/DTS offset captured by
/// [`MediaClock::observe_pts_dts`].
pub struct MediaClock {
    config: ClockConfig,
    /// `None` until the first observation, and again after `reset`.
    state: Option<ClockState>,
    /// Offset captured from the first `observe_pts_dts` call, if any.
    pts_dts_relation: Option<PtsDtsRelation>,
}

impl MediaClock {
    /// Creates an uncalibrated clock using `config`.
    #[must_use]
    pub fn new(config: ClockConfig) -> Self {
        Self {
            config,
            state: None,
            pts_dts_relation: None,
        }
    }

    /// Creates a clock with `ClockConfig::default()` (nominal rate 90 000
    /// ticks per second).
    #[must_use]
    pub fn default_90k() -> Self {
        Self::new(ClockConfig::default())
    }

    /// The configured nominal rate, in PTS ticks per second.
    #[must_use]
    pub fn nominal_rate(&self) -> u64 {
        self.config.nominal_rate
    }

    /// Number of observations accepted since creation or the last reset.
    /// Rejected (non-increasing wall-time) samples are not counted.
    #[must_use]
    pub fn observation_count(&self) -> u64 {
        self.state.as_ref().map_or(0, |s| s.obs_count)
    }

    /// Records a `(wall-clock nanoseconds, PTS)` sample.
    ///
    /// The first sample becomes the anchor. A later sample whose `wall_ns` is
    /// not strictly greater than the last accepted one is silently ignored.
    pub fn observe(&mut self, wall_ns: u64, pts: i64) {
        let obs = ClockObservation { wall_ns, pts };
        match &mut self.state {
            None => {
                self.state = Some(ClockState::new(obs, self.config.nominal_rate as f64));
            }
            Some(state) => {
                if obs.wall_ns <= state.latest.wall_ns {
                    return;
                }
                state.update(obs, self.config.ema_alpha);
            }
        }
    }

    /// Captures the `pts - dts` offset from the first call. Later calls are
    /// ignored until [`MediaClock::reset`].
    pub fn observe_pts_dts(&mut self, pts: i64, dts: i64) {
        if self.pts_dts_relation.is_none() {
            self.pts_dts_relation = Some(PtsDtsRelation {
                // Wrapping, like the other PTS arithmetic in this type, so extreme
                // inputs cannot trigger a debug-build overflow panic.
                pts_dts_offset: pts.wrapping_sub(dts),
            });
        }
    }

    /// Current rate estimate, or `None` until at least two observations have
    /// been accepted.
    #[must_use]
    pub fn estimate(&self) -> Option<DriftEstimate> {
        let state = self.state.as_ref()?;
        if state.obs_count < 2 {
            return None;
        }
        let nominal = self.config.nominal_rate as f64;
        Some(DriftEstimate {
            observed_rate: state.ema_rate,
            drift_ppm: state.drift_ppm(nominal),
        })
    }

    /// Drift in parts per million, or `0.0` while no estimate is available.
    #[must_use]
    pub fn drift_ppm(&self) -> f64 {
        self.estimate().map_or(0.0, |e| e.drift_ppm)
    }

    /// Predicts the PTS at `wall_ns`. It linearly extrapolates from the anchor
    /// (first) observation using the smoothed rate.
    ///
    /// Times earlier than the anchor extrapolate backwards. Returns `None`
    /// until at least two observations have been accepted.
    #[must_use]
    pub fn predict_pts(&self, wall_ns: u64) -> Option<i64> {
        let state = self.state.as_ref()?;
        if state.obs_count < 2 {
            return None;
        }
        // Casting both operands and wrapping yields the signed difference,
        // so times before the anchor produce a negative delta.
        let delta_ns = (wall_ns as i64).wrapping_sub(state.anchor.wall_ns as i64);
        let delta_secs = delta_ns as f64 / 1_000_000_000.0;
        let delta_pts = (delta_secs * state.ema_rate).round() as i64;
        Some(state.anchor.pts.wrapping_add(delta_pts))
    }

    /// The captured PTS/DTS relation, if `observe_pts_dts` has been called.
    #[must_use]
    pub fn pts_dts_relation(&self) -> Option<PtsDtsRelation> {
        self.pts_dts_relation
    }

    /// Converts `pts` to DTS using the captured offset, if any.
    #[must_use]
    pub fn pts_to_dts(&self, pts: i64) -> Option<i64> {
        self.pts_dts_relation.map(|r| r.dts_from_pts(pts))
    }

    /// Converts `dts` to PTS using the captured offset, if any.
    #[must_use]
    pub fn dts_to_pts(&self, dts: i64) -> Option<i64> {
        self.pts_dts_relation.map(|r| r.pts_from_dts(dts))
    }

    /// Discards all observations and the captured PTS/DTS offset; the
    /// configuration is kept.
    pub fn reset(&mut self) {
        self.state = None;
        self.pts_dts_relation = None;
    }
}
impl fmt::Debug for MediaClock {
    // Summarises the clock as its nominal rate, accepted-observation count and
    // current drift (ppm, two decimals); internal state is not dumped.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let nominal = self.config.nominal_rate;
        let (obs, drift) = (self.observation_count(), self.drift_ppm());
        f.write_fmt(format_args!(
            "MediaClock {{ nominal_rate: {}, obs: {obs}, drift_ppm: {drift:.2} }}",
            nominal
        ))
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for `MediaClock`, its estimator state and `PtsDtsRelation`.
    use super::*;

    /// Clock at 90 kHz nominal with a fairly responsive EMA (alpha = 0.5).
    fn make_clock() -> MediaClock {
        MediaClock::new(ClockConfig {
            nominal_rate: 90_000,
            ema_alpha: 0.5,
        })
    }

    #[test]
    fn test_initial_state() {
        let clock = make_clock();
        assert_eq!(clock.observation_count(), 0);
        assert!(clock.estimate().is_none());
        assert_eq!(clock.drift_ppm(), 0.0);
    }

    #[test]
    fn test_single_observation_no_estimate() {
        let mut clock = make_clock();
        clock.observe(0, 0);
        assert_eq!(clock.observation_count(), 1);
        assert!(clock.estimate().is_none());
    }

    #[test]
    fn test_perfect_rate_drift_near_zero() {
        let mut clock = make_clock();
        clock.observe(0, 0);
        clock.observe(1_000_000_000, 90_000);
        let drift = clock.drift_ppm();
        assert!(drift.abs() < 1.0, "drift = {drift} ppm, expected ~0");
    }

    #[test]
    fn test_fast_source_positive_drift() {
        let mut clock = make_clock();
        let pts_per_sec: i64 = 90_009;
        clock.observe(0, 0);
        clock.observe(1_000_000_000, pts_per_sec);
        let drift = clock.drift_ppm();
        assert!(drift > 0.0, "fast source must yield positive drift");
    }

    #[test]
    fn test_slow_source_negative_drift() {
        let mut clock = make_clock();
        let pts_per_sec: i64 = 89_991;
        clock.observe(0, 0);
        clock.observe(1_000_000_000, pts_per_sec);
        let drift = clock.drift_ppm();
        assert!(drift < 0.0, "slow source must yield negative drift");
    }

    #[test]
    fn test_non_monotonic_observation_ignored() {
        let mut clock = make_clock();
        clock.observe(1_000_000_000, 90_000);
        clock.observe(500_000_000, 45_000);
        assert_eq!(clock.observation_count(), 1);
    }

    #[test]
    fn test_equal_wall_time_observation_ignored() {
        let mut clock = make_clock();
        clock.observe(0, 0);
        clock.observe(0, 100);
        assert_eq!(clock.observation_count(), 1);
    }

    #[test]
    fn test_predict_pts_without_observations() {
        let clock = make_clock();
        assert!(clock.predict_pts(0).is_none());
    }

    #[test]
    fn test_predict_pts_uncalibrated() {
        let mut clock = make_clock();
        clock.observe(0, 0);
        assert!(clock.predict_pts(1_000_000_000).is_none());
    }

    #[test]
    fn test_predict_pts_nominal_rate() {
        let mut clock = MediaClock::new(ClockConfig {
            nominal_rate: 90_000,
            ema_alpha: 1.0,
        });
        clock.observe(0, 0);
        clock.observe(1_000_000_000, 90_000);
        let predicted = clock.predict_pts(2_000_000_000).expect("should predict");
        let error = (predicted - 180_000i64).abs();
        assert!(error <= 2, "predicted PTS = {predicted}, error = {error}");
    }

    #[test]
    fn test_predict_pts_before_anchor_extrapolates_backwards() {
        let mut clock = MediaClock::new(ClockConfig {
            nominal_rate: 90_000,
            ema_alpha: 1.0,
        });
        clock.observe(1_000_000_000, 90_000);
        clock.observe(2_000_000_000, 180_000);
        assert_eq!(clock.predict_pts(0), Some(0));
    }

    #[test]
    fn test_pts_discontinuity_keeps_rate() {
        let mut clock = MediaClock::new(ClockConfig {
            nominal_rate: 90_000,
            ema_alpha: 1.0,
        });
        clock.observe(0, 0);
        clock.observe(1_000_000_000, 90_000);
        // PTS jumps backwards: the sample carries no rate information.
        clock.observe(2_000_000_000, 0);
        assert_eq!(clock.observation_count(), 3);
        let est = clock.estimate().expect("estimate available");
        assert!((est.observed_rate - 90_000.0).abs() < 1e-9);
    }

    #[test]
    fn test_pts_dts_relation_stored() {
        let mut clock = make_clock();
        clock.observe_pts_dts(100, 90);
        let rel = clock.pts_dts_relation().expect("relation stored");
        assert_eq!(rel.pts_dts_offset, 10);
    }

    #[test]
    fn test_pts_dts_relation_immutable_after_first() {
        let mut clock = make_clock();
        clock.observe_pts_dts(100, 90);
        clock.observe_pts_dts(200, 150);
        let rel = clock.pts_dts_relation().expect("relation stored");
        assert_eq!(rel.pts_dts_offset, 10);
    }

    #[test]
    fn test_pts_dts_roundtrip() {
        let mut clock = make_clock();
        clock.observe_pts_dts(180_000, 162_000);
        let dts = clock.pts_to_dts(270_000).expect("relation set");
        assert_eq!(dts, 252_000);
        let pts_back = clock.dts_to_pts(dts).expect("relation set");
        assert_eq!(pts_back, 270_000);
    }

    #[test]
    fn test_reset_clears_state() {
        let mut clock = make_clock();
        clock.observe(0, 0);
        clock.observe(1_000_000_000, 90_000);
        clock.observe_pts_dts(90_000, 81_000);
        clock.reset();
        assert_eq!(clock.observation_count(), 0);
        assert!(clock.estimate().is_none());
        assert!(clock.pts_dts_relation().is_none());
    }

    #[test]
    fn test_nominal_rate() {
        let clock = MediaClock::new(ClockConfig {
            nominal_rate: 48_000,
            ema_alpha: 0.1,
        });
        assert_eq!(clock.nominal_rate(), 48_000);
    }

    #[test]
    fn test_drift_estimate_observed_rate() {
        let mut clock = MediaClock::new(ClockConfig {
            nominal_rate: 90_000,
            ema_alpha: 1.0,
        });
        clock.observe(0, 0);
        clock.observe(1_000_000_000, 90_000);
        let est = clock.estimate().expect("estimate available");
        let diff = (est.observed_rate - 90_000.0).abs();
        assert!(diff < 1.0, "observed_rate = {}", est.observed_rate);
    }

    #[test]
    fn test_pts_dts_relation_helpers() {
        let rel = PtsDtsRelation {
            pts_dts_offset: 9000,
        };
        assert_eq!(rel.dts_from_pts(180_000), 171_000);
        assert_eq!(rel.pts_from_dts(171_000), 180_000);
    }

    #[test]
    fn test_debug_impl() {
        let mut clock = make_clock();
        clock.observe(0, 0);
        clock.observe(1_000_000_000, 90_000);
        let s = format!("{clock:?}");
        assert!(s.contains("MediaClock"));
        assert_eq!(
            s,
            "MediaClock { nominal_rate: 90000, obs: 2, drift_ppm: 0.00 }"
        );
    }

    #[test]
    fn test_multiple_observations_ema_converges() {
        let mut clock = MediaClock::new(ClockConfig {
            nominal_rate: 90_000,
            ema_alpha: 0.3,
        });
        let ns_per_sec: u64 = 1_000_000_000;
        for i in 0..10u64 {
            clock.observe(i * ns_per_sec, (i * 90_000) as i64);
        }
        let drift = clock.drift_ppm();
        assert!(drift.abs() < 5.0, "EMA drift = {drift} ppm after 10 obs");
    }
}