1use crate::{AlignError, AlignResult, TimeOffset};
11use std::f64::consts::PI;
12
/// Tuning parameters for time-domain audio cross-correlation.
#[derive(Debug, Clone)]
pub struct SyncConfig {
    /// Audio sample rate (samples per second).
    pub sample_rate: u32,
    /// Number of leading samples of each signal that are correlated.
    pub window_size: usize,
    /// Upper bound on the lag, in samples, searched in each direction.
    pub max_offset: usize,
}

impl Default for SyncConfig {
    /// Returns a configuration with a 48 000 sample rate, a 480 000-sample
    /// window and a 240 000-sample search range.
    fn default() -> Self {
        let sample_rate = 48000;
        let window_size = 480000;
        let max_offset = 240000;
        Self {
            sample_rate,
            window_size,
            max_offset,
        }
    }
}
33
34pub struct AudioSync {
36 config: SyncConfig,
37}
38
39impl AudioSync {
40 #[must_use]
42 pub fn new(config: SyncConfig) -> Self {
43 Self { config }
44 }
45
46 pub fn find_offset(&self, signal1: &[f32], signal2: &[f32]) -> AlignResult<TimeOffset> {
51 if signal1.len() < self.config.window_size || signal2.len() < self.config.window_size {
52 return Err(AlignError::InsufficientData(
53 "Audio signals too short for correlation".to_string(),
54 ));
55 }
56
57 let window1 = &signal1[..self.config.window_size];
59 let window2 = &signal2[..self.config.window_size.min(signal2.len())];
60
61 let (offset, correlation) = self.cross_correlate(window1, window2)?;
63
64 let confidence = self.compute_confidence(window1, window2, offset);
66
67 Ok(TimeOffset::new(offset, confidence, correlation))
68 }
69
70 fn cross_correlate(&self, signal1: &[f32], signal2: &[f32]) -> AlignResult<(i64, f64)> {
72 let mut max_corr = f64::NEG_INFINITY;
73 let mut best_offset = 0i64;
74
75 let max_search = self.config.max_offset.min(signal1.len()).min(signal2.len());
76
77 let norm1 = self.normalize_signal(signal1);
79 let norm2 = self.normalize_signal(signal2);
80
81 for offset in 0..max_search {
83 let corr_pos = self.compute_correlation(&norm1[offset..], &norm2);
85 if corr_pos > max_corr {
86 max_corr = corr_pos;
87 best_offset = offset as i64;
88 }
89
90 if offset > 0 {
92 let corr_neg = self.compute_correlation(&norm1, &norm2[offset..]);
93 if corr_neg > max_corr {
94 max_corr = corr_neg;
95 best_offset = -(offset as i64);
96 }
97 }
98 }
99
100 if max_corr.is_finite() {
101 Ok((best_offset, max_corr))
102 } else {
103 Err(AlignError::SyncError(
104 "Correlation produced non-finite value".to_string(),
105 ))
106 }
107 }
108
109 fn normalize_signal(&self, signal: &[f32]) -> Vec<f32> {
111 let n = signal.len() as f32;
112 let mean = signal.iter().sum::<f32>() / n;
113
114 let variance = signal.iter().map(|&x| (x - mean) * (x - mean)).sum::<f32>() / n;
115
116 let std_dev = variance.sqrt();
117
118 if std_dev < 1e-10 {
119 return vec![0.0; signal.len()];
120 }
121
122 signal.iter().map(|&x| (x - mean) / std_dev).collect()
123 }
124
125 fn compute_correlation(&self, sig1: &[f32], sig2: &[f32]) -> f64 {
127 let len = sig1.len().min(sig2.len());
128 if len == 0 {
129 return 0.0;
130 }
131
132 let sum: f64 = sig1[..len]
133 .iter()
134 .zip(&sig2[..len])
135 .map(|(&a, &b)| f64::from(a) * f64::from(b))
136 .sum();
137
138 sum / len as f64
139 }
140
141 fn compute_confidence(&self, _signal1: &[f32], _signal2: &[f32], _offset: i64) -> f64 {
143 0.95
146 }
147
148 pub fn refine_offset(
153 &self,
154 signal1: &[f32],
155 signal2: &[f32],
156 coarse_offset: i64,
157 ) -> AlignResult<f64> {
158 let offset = coarse_offset.unsigned_abs() as usize;
159
160 if offset >= signal1.len() || offset >= signal2.len() {
161 return Err(AlignError::InvalidConfig("Offset out of range".to_string()));
162 }
163
164 let norm1 = self.normalize_signal(signal1);
166 let norm2 = self.normalize_signal(signal2);
167
168 let c0 = if offset > 0 {
169 self.compute_correlation(&norm1[offset - 1..], &norm2)
170 } else {
171 0.0
172 };
173
174 let c1 = self.compute_correlation(&norm1[offset..], &norm2);
175
176 let c2 = if offset + 1 < norm1.len() {
177 self.compute_correlation(&norm1[offset + 1..], &norm2)
178 } else {
179 0.0
180 };
181
182 let delta = (c0 - c2) / (2.0 * (c0 - 2.0 * c1 + c2));
184
185 if delta.is_finite() {
186 Ok(coarse_offset as f64 + delta)
187 } else {
188 Ok(coarse_offset as f64)
189 }
190 }
191}
192
193pub struct TimecodeSync {
195 pub frame_rate: f64,
197}
198
199impl TimecodeSync {
200 #[must_use]
202 pub fn new(frame_rate: f64) -> Self {
203 Self { frame_rate }
204 }
205
206 #[must_use]
208 pub fn compute_offset(&self, tc1: &Timecode, tc2: &Timecode) -> i64 {
209 let frames1 = tc1.to_frames(self.frame_rate);
210 let frames2 = tc2.to_frames(self.frame_rate);
211 frames2 - frames1
212 }
213
214 #[must_use]
216 pub fn verify_continuity(&self, timecodes: &[Timecode]) -> bool {
217 if timecodes.len() < 2 {
218 return true;
219 }
220
221 for i in 1..timecodes.len() {
222 let offset = self.compute_offset(&timecodes[i - 1], &timecodes[i]);
223 if offset != 1 {
224 return false;
225 }
226 }
227
228 true
229 }
230}
231
/// An `HH:MM:SS:FF` timecode counted at a whole number of frames per second
/// (no drop-frame compensation is applied).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Timecode {
    /// Hours component.
    pub hours: u8,
    /// Minutes component.
    pub minutes: u8,
    /// Seconds component.
    pub seconds: u8,
    /// Frame number within the second.
    pub frames: u8,
}

impl Timecode {
    /// Builds a timecode from its four components without validating them.
    #[must_use]
    pub fn new(hours: u8, minutes: u8, seconds: u8, frames: u8) -> Self {
        Self {
            hours,
            minutes,
            seconds,
            frames,
        }
    }

    /// Converts the timecode to an absolute frame count, using `frame_rate`
    /// rounded to the nearest whole number of frames per second.
    #[must_use]
    pub fn to_frames(&self, frame_rate: f64) -> i64 {
        let fps = frame_rate.round() as i64;
        i64::from(self.hours) * 3600 * fps
            + i64::from(self.minutes) * 60 * fps
            + i64::from(self.seconds) * fps
            + i64::from(self.frames)
    }

    /// Converts an absolute frame count back to a timecode, wrapping at
    /// 24 hours.
    ///
    /// `frame_rate` is rounded to a whole number of frames per second. A rate
    /// that rounds to zero or below (including NaN) is treated as 1 fps so
    /// the call never divides by zero. Negative frame counts wrap backwards
    /// from midnight, e.g. `-1` at 25 fps is `23:59:59:24`.
    #[must_use]
    pub fn from_frames(frames: i64, frame_rate: f64) -> Self {
        let fps = (frame_rate.round() as i64).max(1);

        // Euclidean division keeps the remainder in `0..fps` for negative
        // inputs, so the narrowing casts below cannot wrap.
        let total_seconds = frames.div_euclid(fps);
        let remaining_frames = frames.rem_euclid(fps);
        let seconds_of_day = total_seconds.rem_euclid(24 * 3600);

        let hours = seconds_of_day / 3600;
        let minutes = (seconds_of_day / 60) % 60;
        let seconds = seconds_of_day % 60;

        Self {
            // Each of these is provably below 60, so `as u8` is lossless.
            hours: hours as u8,
            minutes: minutes as u8,
            seconds: seconds as u8,
            // Rates above 256 fps cannot be represented in a u8; saturate
            // instead of wrapping to an unrelated frame number.
            frames: remaining_frames.min(i64::from(u8::MAX)) as u8,
        }
    }
}
286
/// Detects bright flashes in a per-frame brightness series and computes the
/// brightness of packed RGB frames.
pub struct MarkerDetector {
    /// Brightness strictly above this value counts as "in a flash".
    pub flash_threshold: f32,
    /// Minimum number of consecutive bright entries for a flash to be
    /// reported.
    pub min_duration: usize,
}

impl Default for MarkerDetector {
    /// Threshold `0.8`, minimum duration of one entry.
    fn default() -> Self {
        Self {
            flash_threshold: 0.8,
            min_duration: 1,
        }
    }
}

impl MarkerDetector {
    /// Creates a detector with an explicit threshold and minimum duration.
    #[must_use]
    pub fn new(flash_threshold: f32, min_duration: usize) -> Self {
        Self {
            flash_threshold,
            min_duration,
        }
    }

    /// Returns the start index of every flash in `brightness`.
    ///
    /// A flash starts at the first value above `flash_threshold` and ends at
    /// the next value at or below it. Flashes shorter than `min_duration`
    /// are discarded. A flash that is still active when the input ends is
    /// reported as well, provided it already meets `min_duration`.
    #[must_use]
    pub fn detect_flashes(&self, brightness: &[f32]) -> Vec<usize> {
        let mut flashes = Vec::new();
        let mut active_start: Option<usize> = None;

        for (i, &value) in brightness.iter().enumerate() {
            match active_start {
                None if value > self.flash_threshold => active_start = Some(i),
                Some(start) if value <= self.flash_threshold => {
                    active_start = None;
                    if i - start >= self.min_duration {
                        flashes.push(start);
                    }
                }
                // Unchanged state; a NaN value also lands here because both
                // comparisons above are false for it.
                _ => {}
            }
        }

        // A flash running up to the last entry has no closing sample, so it
        // has to be flushed explicitly.
        if let Some(start) = active_start {
            if brightness.len() - start >= self.min_duration {
                flashes.push(start);
            }
        }

        flashes
    }

    /// Computes the mean luma of a packed 8-bit RGB frame, scaled to
    /// `0.0..=1.0`.
    ///
    /// Each pixel contributes `(299 R + 587 G + 114 B) / 1000` using integer
    /// division. Returns `0.0` when the frame has no pixels or when
    /// `rgb.len()` is not exactly `width * height * 3`.
    #[must_use]
    pub fn compute_brightness(&self, rgb: &[u8], width: usize, height: usize) -> f32 {
        let pixel_count = match width.checked_mul(height) {
            Some(count) if count > 0 => count,
            // Empty frame or dimensions that overflow `usize`.
            _ => return 0.0,
        };
        if pixel_count.checked_mul(3) != Some(rgb.len()) {
            return 0.0;
        }

        // Accumulate in u64: a u32 sum overflows for frames above ~16.8 Mpx.
        let luma_sum: u64 = rgb
            .chunks_exact(3)
            .map(|pixel| {
                let r = u32::from(pixel[0]);
                let g = u32::from(pixel[1]);
                let b = u32::from(pixel[2]);
                u64::from((299 * r + 587 * g + 114 * b) / 1000)
            })
            .sum();

        (luma_sum as f32 / pixel_count as f32) / 255.0
    }
}
357
358pub struct PhaseCorrelation {
360 pub fft_size: usize,
362}
363
364impl PhaseCorrelation {
365 #[must_use]
367 pub fn new(fft_size: usize) -> Self {
368 Self { fft_size }
369 }
370
371 pub fn find_offset(&self, signal1: &[f32], signal2: &[f32]) -> AlignResult<f64> {
376 if signal1.len() != signal2.len() || signal1.is_empty() {
377 return Err(AlignError::InvalidConfig(
378 "Signals must have same non-zero length".to_string(),
379 ));
380 }
381
382 let len = signal1.len().min(self.fft_size);
384 let mut max_val = f32::NEG_INFINITY;
385 let mut max_idx = 0;
386
387 for offset in 0..len {
388 let mut sum = 0.0f32;
389 for i in 0..(len - offset) {
390 sum += signal1[i] * signal2[i + offset];
391 }
392 if sum > max_val {
393 max_val = sum;
394 max_idx = offset;
395 }
396 }
397
398 Ok(max_idx as f64)
399 }
400}
401
402pub struct BeatDetector {
404 pub sample_rate: u32,
406 pub hop_size: usize,
408}
409
410impl BeatDetector {
411 #[must_use]
413 pub fn new(sample_rate: u32, hop_size: usize) -> Self {
414 Self {
415 sample_rate,
416 hop_size,
417 }
418 }
419
420 #[must_use]
422 pub fn detect_beats(&self, audio: &[f32]) -> Vec<usize> {
423 let mut beats = Vec::new();
424 let window_size = 2048;
425
426 let energy = self.compute_energy_envelope(audio, window_size);
428
429 for i in 1..energy.len().saturating_sub(1) {
431 if energy[i] > energy[i - 1] && energy[i] > energy[i + 1] {
432 let threshold = energy[i.saturating_sub(10)..i].iter().sum::<f32>() / 10.0 * 1.5;
433
434 if energy[i] > threshold {
435 beats.push(i * self.hop_size);
436 }
437 }
438 }
439
440 beats
441 }
442
443 fn compute_energy_envelope(&self, audio: &[f32], window_size: usize) -> Vec<f32> {
445 let mut envelope = Vec::new();
446
447 for chunk in audio.chunks(self.hop_size) {
448 let energy: f32 = chunk
449 .iter()
450 .take(window_size.min(chunk.len()))
451 .map(|&x| x * x)
452 .sum();
453 envelope.push(energy);
454 }
455
456 envelope
457 }
458
459 pub fn align_beats(&self, audio1: &[f32], audio2: &[f32]) -> AlignResult<TimeOffset> {
464 let beats1 = self.detect_beats(audio1);
465 let beats2 = self.detect_beats(audio2);
466
467 if beats1.is_empty() || beats2.is_empty() {
468 return Err(AlignError::SyncError("No beats detected".to_string()));
469 }
470
471 let offset = beats2[0] as i64 - beats1[0] as i64;
473
474 Ok(TimeOffset::new(offset, 0.8, 0.9))
475 }
476}
477
/// Generators for common symmetric window functions.
pub struct WindowFunction;

impl WindowFunction {
    /// Samples `shape` at `size` evenly spaced points `x` in `0.0..=1.0`.
    ///
    /// A one-point window is defined as `[1.0]`: the general formula would
    /// compute `0 / 0` and fill the window with NaN.
    fn generate(size: usize, shape: impl Fn(f64) -> f64) -> Vec<f32> {
        match size {
            0 => Vec::new(),
            1 => vec![1.0],
            _ => {
                let last = (size - 1) as f64;
                (0..size).map(|i| shape(i as f64 / last) as f32).collect()
            }
        }
    }

    /// Hann window: `0.5 * (1 - cos(2πx))`.
    #[must_use]
    pub fn hann(size: usize) -> Vec<f32> {
        Self::generate(size, |x| 0.5 * (1.0 - (2.0 * PI * x).cos()))
    }

    /// Hamming window: `0.54 - 0.46 * cos(2πx)`.
    #[must_use]
    pub fn hamming(size: usize) -> Vec<f32> {
        Self::generate(size, |x| 0.54 - 0.46 * (2.0 * PI * x).cos())
    }

    /// Blackman window: `0.42 - 0.5 * cos(2πx) + 0.08 * cos(4πx)`.
    #[must_use]
    pub fn blackman(size: usize) -> Vec<f32> {
        Self::generate(size, |x| {
            0.42 - 0.5 * (2.0 * PI * x).cos() + 0.08 * (4.0 * PI * x).cos()
        })
    }
}
515
516pub struct MultiStreamSync {
518 audio_config: SyncConfig,
520 reference_index: usize,
522}
523
524impl MultiStreamSync {
525 #[must_use]
527 pub fn new(audio_config: SyncConfig, reference_index: usize) -> Self {
528 Self {
529 audio_config,
530 reference_index,
531 }
532 }
533
534 pub fn sync_streams(&self, streams: &[&[f32]]) -> AlignResult<Vec<TimeOffset>> {
539 if streams.len() <= self.reference_index {
540 return Err(AlignError::InvalidConfig(
541 "Reference index out of bounds".to_string(),
542 ));
543 }
544
545 let reference = streams[self.reference_index];
546 let sync = AudioSync::new(self.audio_config.clone());
547
548 let mut offsets = Vec::new();
549
550 for (i, stream) in streams.iter().enumerate() {
551 if i == self.reference_index {
552 offsets.push(TimeOffset::new(0, 1.0, 1.0));
553 } else {
554 let offset = sync.find_offset(reference, stream)?;
555 offsets.push(offset);
556 }
557 }
558
559 Ok(offsets)
560 }
561
562 #[must_use]
564 pub fn compute_sync_quality(&self, offsets: &[TimeOffset]) -> f32 {
565 if offsets.is_empty() {
566 return 0.0;
567 }
568
569 let avg_confidence: f64 =
570 offsets.iter().map(|o| o.confidence).sum::<f64>() / offsets.len() as f64;
571 let avg_correlation: f64 =
572 offsets.iter().map(|o| o.correlation).sum::<f64>() / offsets.len() as f64;
573
574 ((avg_confidence + avg_correlation) / 2.0) as f32
575 }
576}
577
578pub struct DriftDetector {
580 pub sample_rate: u32,
582 pub window_size: usize,
584 pub num_windows: usize,
586}
587
588impl DriftDetector {
589 #[must_use]
591 pub fn new(sample_rate: u32, window_size: usize, num_windows: usize) -> Self {
592 Self {
593 sample_rate,
594 window_size,
595 num_windows,
596 }
597 }
598
599 pub fn detect_drift(&self, signal1: &[f32], signal2: &[f32]) -> AlignResult<Vec<TimeOffset>> {
604 let total_samples = self.window_size * self.num_windows;
605 if signal1.len() < total_samples || signal2.len() < total_samples {
606 return Err(AlignError::InsufficientData(
607 "Signals too short for drift analysis".to_string(),
608 ));
609 }
610
611 let config = SyncConfig {
612 sample_rate: self.sample_rate,
613 window_size: self.window_size,
614 max_offset: self.window_size / 2,
615 };
616
617 let sync = AudioSync::new(config);
618 let mut offsets = Vec::new();
619
620 for i in 0..self.num_windows {
621 let start = i * self.window_size;
622 let end = start + self.window_size;
623
624 let window1 = &signal1[start..end];
625 let window2 = &signal2[start..end];
626
627 let offset = sync.find_offset(window1, window2)?;
628 offsets.push(offset);
629 }
630
631 Ok(offsets)
632 }
633
634 #[must_use]
636 pub fn compute_drift_rate(&self, offsets: &[TimeOffset]) -> f32 {
637 if offsets.len() < 2 {
638 return 0.0;
639 }
640
641 let n = offsets.len() as f32;
643 let mut sum_x = 0.0f32;
644 let mut sum_y = 0.0f32;
645 let mut sum_xy = 0.0f32;
646 let mut sum_xx = 0.0f32;
647
648 for (i, offset) in offsets.iter().enumerate() {
649 let x = i as f32;
650 let y = offset.samples as f32;
651 sum_x += x;
652 sum_y += y;
653 sum_xy += x * y;
654 sum_xx += x * x;
655 }
656
657 let slope = (n * sum_xy - sum_x * sum_y) / (n * sum_xx - sum_x * sum_x);
658
659 let window_duration = self.window_size as f32 / self.sample_rate as f32;
661 slope / window_duration
662 }
663}
664
665pub struct SpectralCorrelation {
667 pub fft_size: usize,
669 pub hop_size: usize,
671}
672
673impl SpectralCorrelation {
674 #[must_use]
676 pub fn new(fft_size: usize, hop_size: usize) -> Self {
677 Self { fft_size, hop_size }
678 }
679
680 pub fn correlate(&self, signal1: &[f32], signal2: &[f32]) -> AlignResult<TimeOffset> {
685 if signal1.len() < self.fft_size || signal2.len() < self.fft_size {
686 return Err(AlignError::InsufficientData(
687 "Signals too short for spectral correlation".to_string(),
688 ));
689 }
690
691 let mut max_corr = f32::NEG_INFINITY;
693 let mut best_offset = 0i64;
694
695 let max_offset = signal1.len().min(signal2.len()) / 2;
696
697 for offset in 0..max_offset.min(10000) {
698 let mut corr = 0.0f32;
699 let len = (signal1.len() - offset)
700 .min(signal2.len())
701 .min(self.fft_size);
702
703 for i in 0..len {
704 corr += signal1[i + offset] * signal2[i];
705 }
706
707 if corr > max_corr {
708 max_corr = corr;
709 best_offset = offset as i64;
710 }
711 }
712
713 Ok(TimeOffset::new(best_offset, 0.9, f64::from(max_corr)))
714 }
715}
716
/// Computes interval statistics for a series of event timestamps.
pub struct JitterAnalyzer {
    /// Nominal distance between consecutive timestamps.
    pub expected_interval: usize,
    /// Allowed deviation; stored for callers, not read by `analyze_jitter`.
    pub tolerance: usize,
}

impl JitterAnalyzer {
    /// Creates an analyzer for the given nominal interval and tolerance.
    #[must_use]
    pub fn new(expected_interval: usize, tolerance: usize) -> Self {
        Self {
            expected_interval,
            tolerance,
        }
    }

    /// Computes mean, standard deviation and worst-case deviation of the
    /// intervals between consecutive timestamps.
    ///
    /// Fewer than two timestamps yield `JitterMetrics::default()`.
    /// Out-of-order timestamps produce negative intervals instead of an
    /// arithmetic underflow, and `jitter_ratio` is `0.0` (not NaN) when the
    /// mean interval is zero.
    #[must_use]
    pub fn analyze_jitter(&self, timestamps: &[usize]) -> JitterMetrics {
        if timestamps.len() < 2 {
            return JitterMetrics::default();
        }

        // i128 represents the difference of any two `usize` values exactly,
        // so decreasing timestamps cannot underflow.
        let intervals: Vec<i128> = timestamps
            .windows(2)
            .map(|pair| pair[1] as i128 - pair[0] as i128)
            .collect();
        let count = intervals.len() as f32;

        let mean_interval = intervals.iter().sum::<i128>() as f32 / count;

        let variance = intervals.iter().fold(0.0f32, |acc, &interval| {
            let diff = interval as f32 - mean_interval;
            acc + diff * diff
        }) / count;
        let std_dev = variance.sqrt();

        let expected = self.expected_interval as i128;
        let max_jitter = intervals
            .iter()
            .map(|&interval| (interval - expected).abs())
            .max()
            .unwrap_or(0) as f32;

        let jitter_ratio = if mean_interval == 0.0 {
            0.0
        } else {
            std_dev / mean_interval
        };

        JitterMetrics {
            mean_interval,
            std_dev,
            max_jitter,
            jitter_ratio,
        }
    }
}

/// Interval statistics produced by [`JitterAnalyzer::analyze_jitter`].
#[derive(Debug, Clone, Copy, Default)]
pub struct JitterMetrics {
    /// Mean distance between consecutive timestamps.
    pub mean_interval: f32,
    /// Population standard deviation of the intervals.
    pub std_dev: f32,
    /// Largest absolute deviation of any interval from the expected interval.
    pub max_jitter: f32,
    /// `std_dev / mean_interval`, or `0.0` when the mean interval is zero.
    pub jitter_ratio: f32,
}
784
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration uses the documented sample rate and window.
    #[test]
    fn test_audio_sync_config() {
        let config = SyncConfig::default();
        assert_eq!(config.sample_rate, 48000);
        assert_eq!(config.window_size, 480000);
    }

    /// Converting to frames and back reproduces the timecode.
    #[test]
    fn test_timecode_conversion() {
        let tc = Timecode::new(1, 30, 45, 10);
        let frames = tc.to_frames(25.0);
        let tc2 = Timecode::from_frames(frames, 25.0);
        assert_eq!(tc, tc2);
    }

    /// The offset between two timecodes is the difference of their frame
    /// counts.
    #[test]
    fn test_timecode_offset() {
        let sync = TimecodeSync::new(25.0);
        let tc1 = Timecode::new(1, 0, 0, 0);
        let tc2 = Timecode::new(1, 0, 0, 25);
        assert_eq!(sync.compute_offset(&tc1, &tc2), 25);
    }

    /// A single two-frame flash is reported at its start index.
    #[test]
    fn test_flash_detection() {
        let detector = MarkerDetector::default();
        let brightness = vec![0.1, 0.2, 0.9, 0.9, 0.1, 0.2];
        let flashes = detector.detect_flashes(&brightness);
        assert_eq!(flashes.len(), 1);
        assert_eq!(flashes[0], 2);
    }

    /// An all-white 10x10 frame has brightness 1.0.
    #[test]
    fn test_brightness_computation() {
        let detector = MarkerDetector::default();
        // 10 * 10 pixels, 3 bytes each.
        let rgb = vec![255u8; 300];
        let brightness = detector.compute_brightness(&rgb, 10, 10);
        assert!((brightness - 1.0).abs() < 0.01);
    }

    /// Normalization yields zero mean and unit variance.
    #[test]
    fn test_normalize_signal() {
        let sync = AudioSync::new(SyncConfig::default());
        let signal = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let normalized = sync.normalize_signal(&signal);

        let mean: f32 = normalized.iter().sum::<f32>() / normalized.len() as f32;
        assert!(mean.abs() < 1e-6);

        let variance: f32 =
            normalized.iter().map(|&x| x * x).sum::<f32>() / normalized.len() as f32;
        assert!((variance - 1.0).abs() < 1e-6);
    }

    /// Windows have the requested length; Hann is ~0 at the edge and ~1 in
    /// the middle.
    #[test]
    fn test_window_functions() {
        let hann = WindowFunction::hann(100);
        assert_eq!(hann.len(), 100);
        assert!(hann[0] < 0.01);
        assert!(hann[50] > 0.99);

        let hamming = WindowFunction::hamming(100);
        assert_eq!(hamming.len(), 100);

        let blackman = WindowFunction::blackman(100);
        assert_eq!(blackman.len(), 100);
    }

    /// Periodic 100-sample pulses in silence produce at least one beat.
    #[test]
    fn test_beat_detector() {
        let detector = BeatDetector::new(48000, 512);

        let mut audio = vec![0.0; 48000];
        for i in (0..48000).step_by(4800) {
            for j in 0..100 {
                if i + j < audio.len() {
                    audio[i + j] = 1.0;
                }
            }
        }

        let beats = detector.detect_beats(&audio);
        assert!(!beats.is_empty());
    }

    /// Synchronizing two sufficiently long streams succeeds.
    #[test]
    fn test_multi_stream_sync() {
        let config = SyncConfig {
            sample_rate: 48000,
            window_size: 1000,
            max_offset: 500,
        };
        let sync = MultiStreamSync::new(config, 0);

        let stream1 = vec![0.1f32; 2000];
        let stream2 = vec![0.2f32; 2000];
        let streams = vec![&stream1[..], &stream2[..]];

        let result = sync.sync_streams(&streams);
        assert!(result.is_ok());
    }

    /// The constructor stores its arguments.
    #[test]
    fn test_drift_detector() {
        let detector = DriftDetector::new(48000, 48000, 5);
        assert_eq!(detector.sample_rate, 48000);
        assert_eq!(detector.num_windows, 5);
    }

    /// Slightly irregular timestamps give a positive mean interval and a
    /// non-negative deviation.
    #[test]
    fn test_jitter_analyzer() {
        let analyzer = JitterAnalyzer::new(1000, 10);
        let timestamps = vec![0, 1000, 2000, 3005, 4000];
        let metrics = analyzer.analyze_jitter(&timestamps);

        assert!(metrics.mean_interval > 0.0);
        assert!(metrics.std_dev >= 0.0);
    }

    /// The constructor stores its arguments.
    #[test]
    fn test_spectral_correlation() {
        let corr = SpectralCorrelation::new(1024, 512);
        assert_eq!(corr.fft_size, 1024);
        assert_eq!(corr.hop_size, 512);
    }
}
919
/// Connection settings for [`NtpClient::query_offset`].
#[derive(Debug, Clone)]
pub struct NtpConfig {
    /// Host name or address of the time server.
    pub server: String,
    /// UDP port of the time server.
    pub port: u16,
    /// How long to wait for the reply, in milliseconds.
    pub timeout_ms: u64,
}

impl Default for NtpConfig {
    /// `pool.ntp.org` on port 123 with a two-second timeout.
    fn default() -> Self {
        let server = String::from("pool.ntp.org");
        Self {
            server,
            port: 123,
            timeout_ms: 2_000,
        }
    }
}
949
/// Result of a single time-server query.
#[derive(Debug, Clone, Copy)]
pub struct TimeDelta {
    /// Estimated server-minus-local clock difference in milliseconds;
    /// positive when the server clock is ahead of the local clock.
    pub offset_ms: i64,

    /// Network round-trip time in milliseconds, with the server's own
    /// processing time removed.
    pub round_trip_ms: u64,
}
977
978pub struct NtpClient;
1000
1001impl NtpClient {
1002 pub fn query_offset(config: &NtpConfig) -> AlignResult<TimeDelta> {
1015 use std::net::{ToSocketAddrs, UdpSocket};
1016 use std::time::Duration;
1017
1018 let server_addr = format!("{}:{}", config.server, config.port)
1020 .to_socket_addrs()
1021 .map_err(|e| AlignError::SyncError(format!("DNS resolution failed: {e}")))?
1022 .next()
1023 .ok_or_else(|| AlignError::SyncError("No addresses returned by DNS".to_string()))?;
1024
1025 let socket = UdpSocket::bind("0.0.0.0:0")
1027 .map_err(|e| AlignError::SyncError(format!("Failed to bind UDP socket: {e}")))?;
1028
1029 socket
1030 .set_read_timeout(Some(Duration::from_millis(config.timeout_ms)))
1031 .map_err(|e| AlignError::SyncError(format!("Failed to set socket timeout: {e}")))?;
1032
1033 let mut request = [0u8; 48];
1035 request[0] = 0x23;
1037
1038 let t1_ntp = Self::unix_to_ntp(Self::now_unix_ms());
1041
1042 request[40..48].copy_from_slice(&t1_ntp);
1044
1045 socket
1046 .send_to(&request, server_addr)
1047 .map_err(|e| AlignError::SyncError(format!("UDP send failed: {e}")))?;
1048
1049 let mut response = [0u8; 96]; let (n, _) = socket
1052 .recv_from(&mut response)
1053 .map_err(|e| AlignError::SyncError(format!("UDP receive failed: {e}")))?;
1054
1055 let t4_unix_ms = Self::now_unix_ms();
1056
1057 if n < 48 {
1058 return Err(AlignError::SyncError(format!(
1059 "Response too short: {n} bytes (expected ≥ 48)"
1060 )));
1061 }
1062
1063 let t2_ms = Self::ntp_to_unix_ms(&response[32..40]);
1065 let t3_ms = Self::ntp_to_unix_ms(&response[40..48]);
1066 let t1_unix_ms = Self::ntp_to_unix_ms_from_u64(Self::ntp_bytes_to_u64(&t1_ntp));
1068
1069 let offset_2x_ms = (t2_ms - t1_unix_ms) + (t3_ms - t4_unix_ms as i64);
1073 let offset_ms = offset_2x_ms / 2;
1074
1075 let rtt_ms = (t4_unix_ms as i64 - t1_unix_ms) - (t3_ms - t2_ms);
1076 let round_trip_ms = rtt_ms.max(0) as u64;
1077
1078 Ok(TimeDelta {
1079 offset_ms,
1080 round_trip_ms,
1081 })
1082 }
1083
1084 fn now_unix_ms() -> u64 {
1088 use std::time::{SystemTime, UNIX_EPOCH};
1089 SystemTime::now()
1090 .duration_since(UNIX_EPOCH)
1091 .unwrap_or_default()
1092 .as_millis() as u64
1093 }
1094
1095 fn unix_to_ntp(unix_ms: u64) -> [u8; 8] {
1098 const NTP_EPOCH_OFFSET: u64 = 2_208_988_800; let secs = unix_ms / 1_000 + NTP_EPOCH_OFFSET;
1100 let frac_ms = unix_ms % 1_000;
1101 let frac = ((frac_ms as u128 * (1u128 << 32)) / 1_000) as u32;
1105 let mut out = [0u8; 8];
1106 out[0..4].copy_from_slice(&(secs as u32).to_be_bytes());
1107 out[4..8].copy_from_slice(&frac.to_be_bytes());
1108 out
1109 }
1110
1111 fn ntp_to_unix_ms(bytes: &[u8]) -> i64 {
1113 assert!(bytes.len() >= 8, "NTP timestamp slice must be ≥ 8 bytes");
1114 let raw = Self::ntp_bytes_to_u64(&bytes[..8].try_into().unwrap_or([0u8; 8]));
1115 Self::ntp_to_unix_ms_from_u64(raw)
1116 }
1117
1118 fn ntp_bytes_to_u64(bytes: &[u8; 8]) -> u64 {
1119 let secs = u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) as u64;
1120 let frac = u32::from_be_bytes([bytes[4], bytes[5], bytes[6], bytes[7]]) as u64;
1121 (secs << 32) | frac
1122 }
1123
1124 fn ntp_to_unix_ms_from_u64(raw: u64) -> i64 {
1125 const NTP_EPOCH_OFFSET: u64 = 2_208_988_800;
1126 let secs = (raw >> 32) as u64;
1127 let frac = (raw & 0xFFFF_FFFF) as u64;
1128 let frac_ms = ((frac as u128 * 1_000) >> 32) as u64;
1131 let unix_ms = (secs.saturating_sub(NTP_EPOCH_OFFSET)) * 1_000 + frac_ms;
1132 unix_ms as i64
1133 }
1134}
1135
#[cfg(test)]
mod ntp_tests {
    use super::*;

    /// A hand-built NTP timestamp decodes to the expected Unix milliseconds.
    #[test]
    fn test_ntp_packet_parse_known_bytes() {
        // 3_913_056_000 NTP seconds - 2_208_988_800 = 1_704_067_200 Unix seconds.
        let seconds_field: u32 = 3_913_056_000;
        let fraction_field: u32 = 0;

        let mut stamp = [0u8; 8];
        {
            let (high, low) = stamp.split_at_mut(4);
            high.copy_from_slice(&seconds_field.to_be_bytes());
            low.copy_from_slice(&fraction_field.to_be_bytes());
        }

        let unix_ms = NtpClient::ntp_to_unix_ms(&stamp);
        let expected_ms: i64 = 1_704_067_200_000;

        assert_eq!(
            unix_ms, expected_ms,
            "Unix ms mismatch: got {unix_ms}, expected {expected_ms}"
        );
    }

    /// The offset and round-trip formulas give the expected values for a
    /// local clock that runs ahead of the server.
    #[test]
    fn test_ntp_offset_computation_formula() {
        let (t1, t2, t3, t4): (i64, i64, i64, i64) = (1_050, 1_000, 1_000, 1_060);

        let offset = ((t2 - t1) + (t3 - t4)) / 2;
        let rtt = (t4 - t1) - (t3 - t2);

        assert_eq!(offset, -55, "offset formula mismatch");
        assert_eq!(rtt, 10, "RTT formula mismatch");
    }

    /// Encoding to NTP format and decoding again loses at most 1 ms.
    #[test]
    fn test_ntp_unix_roundtrip() {
        let unix_ms: u64 = 1_735_689_600_500;

        let recovered = NtpClient::ntp_to_unix_ms(&NtpClient::unix_to_ntp(unix_ms));
        let error = (recovered - unix_ms as i64).abs();

        assert!(
            error <= 1,
            "round-trip error: in={unix_ms}, out={recovered}"
        );
    }
}