1use drmem_api::{
2 device,
3 driver::{self, DriverConfig},
4 Error, Result,
5};
6use std::future::Future;
7use std::sync::Arc;
8use std::{convert::Infallible, pin::Pin};
9use std::{
10 net::{SocketAddr, SocketAddrV4},
11 str,
12};
13use tokio::net::UdpSocket;
14use tokio::sync::Mutex;
15use tokio::time::{self, Duration};
16use tracing::{debug, error, trace, warn, Span};
17
mod server {
    use super::*;

    /// Peer information decoded from an ntpd "read variables" reply: the
    /// peer's address (`srcadr`), its clock offset (`offset`) and the
    /// round-trip delay (`delay`). The driver publishes the two numbers on
    /// devices registered with units of "ms".
    #[derive(Debug, PartialEq)]
    pub struct Info(String, f64, f64);

    impl Info {
        /// Builds an `Info` from a host address, offset and delay.
        pub fn new(host: String, offset: f64, delay: f64) -> Info {
            Info(host, offset, delay)
        }

        /// Returns a placeholder value: empty host, zero offset and delay.
        ///
        /// The driver seeds its "last reported" state with this so that the
        /// first poll, successful or not, is treated as a change.
        pub fn bad_value() -> Info {
            Info(String::from(""), 0.0, 0.0)
        }

        /// The peer's address, as reported in `srcadr`.
        pub fn get_host(&self) -> &String {
            &self.0
        }

        /// The peer's clock offset, as reported in `offset`.
        pub fn get_offset(&self) -> f64 {
            self.1
        }

        /// The round-trip delay to the peer, as reported in `delay`.
        pub fn get_delay(&self) -> f64 {
            self.2
        }
    }

    /// Partially decoded `(srcadr, offset, delay)` values.
    type Partial = (Option<String>, Option<f64>, Option<f64>);

    /// Folds one `key=value` item into the partially decoded state.
    ///
    /// Unknown keys, items without `=`, and numeric values that don't parse
    /// leave the state untouched.
    fn update_host_info(mut state: Partial, item: &str) -> Partial {
        // Split on the first '=' only; everything after it is the value.
        match item.split_once('=') {
            Some(("srcadr", adr)) => state.0 = Some(String::from(adr)),
            Some(("offset", offset)) => {
                if let Ok(o) = offset.parse::<f64>() {
                    state.1 = Some(o)
                }
            }
            Some(("delay", delay)) => {
                if let Ok(d) = delay.parse::<f64>() {
                    state.2 = Some(d)
                }
            }
            _ => (),
        }
        state
    }

    /// Decodes a comma-separated `key=value` variable list.
    ///
    /// Returns `Some` only when `srcadr`, `offset` and `delay` are all
    /// present and both numbers parse as `f64`; otherwise `None`.
    pub fn decode_info(input: &str) -> Option<Info> {
        let (host, offset, delay) = input
            .split(',')
            // Trim both ends: items may be preceded by line breaks, and a
            // trailing "\r\n" on the last item must not break `parse`.
            .map(|item| item.trim())
            .filter(|item| !item.is_empty())
            .fold((None, None, None), update_host_info);

        Some(Info::new(host?, offset?, delay?))
    }
}
107
/// Driver instance state for one monitored ntpd.
pub struct Instance {
    /// UDP socket bound to an ephemeral local port and connected to the
    /// configured `addr` (see `create_instance`).
    sock: UdpSocket,
    /// Sequence number placed in the next control request.
    seq: u16,
}
112
/// The read-only devices this driver publishes.
pub struct Devices {
    /// `true` after a synced host's info was read; `false` when ntpd reports
    /// no synced host or its info can't be obtained.
    d_state: driver::ReadOnlyDevice<bool>,
    /// Address (`srcadr`) of the host ntpd is synced to.
    d_source: driver::ReadOnlyDevice<String>,
    /// Clock offset reported for that host (registered with units "ms").
    d_offset: driver::ReadOnlyDevice<f64>,
    /// Round-trip delay reported for that host (registered with units "ms").
    d_delay: driver::ReadOnlyDevice<f64>,
}
119
120impl Instance {
121 pub const NAME: &'static str = "ntp";
122
123 pub const SUMMARY: &'static str =
124 "monitors an NTP server and reports its state";
125
126 pub const DESCRIPTION: &'static str = include_str!("../README.md");
127
128 fn get_cfg_address(cfg: &DriverConfig) -> Result<SocketAddrV4> {
131 match cfg.get("addr") {
132 Some(toml::value::Value::String(addr)) => {
133 if let Ok(addr) = addr.parse::<SocketAddrV4>() {
134 Ok(addr)
135 } else {
136 Err(Error::ConfigError(String::from(
137 "'addr' not in hostname:port format",
138 )))
139 }
140 }
141 Some(_) => Err(Error::ConfigError(String::from(
142 "'addr' config parameter should be a string",
143 ))),
144 None => Err(Error::ConfigError(String::from(
145 "missing 'addr' parameter in config",
146 ))),
147 }
148 }
149
150 fn read_u16(buf: &[u8]) -> u16 {
154 (buf[0] as u16) * 256 + (buf[1] as u16)
155 }
156
157 async fn get_synced_host(&mut self) -> Option<u16> {
158 let req: [u8; 12] = [
159 0x26,
160 0x01,
161 (self.seq / 256) as u8,
162 (self.seq % 256) as u8,
163 0x00,
164 0x00,
165 0x00,
166 0x00,
167 0x00,
168 0x00,
169 0x00,
170 0x00,
171 ];
172
173 self.seq += 1;
174
175 if let Err(e) = self.sock.send(&req).await {
179 error!("couldn't send \"synced hosts\" request -> {}", e);
180 return None;
181 }
182
183 let mut buf = [0u8; 500];
184
185 #[rustfmt::skip]
186 tokio::select! {
187 result = self.sock.recv(&mut buf) => {
188 match result {
189 Ok(len) if len < 12 => {
194 warn!(
195 "response from ntpd < 12 bytes -> only {} bytes",
196 len
197 )
198 }
199
200 Ok(len) => {
201 let total = Instance::read_u16(&buf[10..=11]) as usize;
202 let expected_len = total + 12 + (4 - total % 4) % 4;
203
204 if expected_len == len {
209 for ii in buf[12..len].chunks_exact(4) {
210 if (ii[2] & 0x7) == 6 {
211 return Some(Instance::read_u16(
212 &ii[0..=1],
213 ));
214 }
215 }
216 } else {
217 warn!(
218 "bad packet length -> expected {}, got {}",
219 expected_len, len
220 );
221 }
222 }
223 Err(e) => error!("couldn't receive data -> {}", e),
224 }
225 },
226 _ = tokio::time::sleep(std::time::Duration::from_millis(1_000)) => {
227 warn!("timed-out waiting for reply to \"get synced host\" request")
228 }
229 }
230
231 None
232 }
233
234 pub async fn get_host_info(&mut self, id: u16) -> Option<server::Info> {
238 let req = &[
239 0x26,
240 0x02,
241 (self.seq / 256) as u8,
242 (self.seq % 256) as u8,
243 0x00,
244 0x00,
245 (id / 256) as u8,
246 (id % 256) as u8,
247 0x00,
248 0x00,
249 0x00,
250 0x00,
251 ];
252
253 self.seq += 1;
254
255 if let Err(e) = self.sock.send(req).await {
256 error!("couldn't send \"host info\" request -> {}", e);
257 return None;
258 }
259
260 let mut buf = [0u8; 500];
261 let mut payload = [0u8; 2048];
262 let mut next_offset = 0;
263
264 loop {
265 #[rustfmt::skip]
266 tokio::select! {
267 result = self.sock.recv(&mut buf) => {
268 match result {
269 Ok(len) if len < 12 => {
274 warn!("response from ntpd < 12 bytes -> {}", len);
275 break;
276 }
277
278 Ok(len) => {
279 let offset = Instance::read_u16(&buf[8..=9]) as usize;
280
281 if offset != next_offset {
288 warn!("dropped packet (incorrect offset)");
289 break;
290 }
291
292 let total = Instance::read_u16(&buf[10..=11]) as usize;
293 let expected_len = total + 12 + (4 - total % 4) % 4;
294
295 if expected_len != len {
301 warn!(
302 "bad packet length -> expected {}, got {}",
303 expected_len, len
304 );
305 break;
306 }
307
308 if offset + total > payload.len() {
313 warn!(
314 "payload too big (offset {}, total {}, target buf: {})",
315 offset,
316 total,
317 payload.len()
318 );
319 break;
320 }
321
322 next_offset += total;
325
326 let dst_range = offset..offset + total;
329 let src_range = 12..12 + total;
330
331 trace!(
332 "copying {} bytes into {} through {}",
333 dst_range.len(),
334 dst_range.start,
335 dst_range.end - 1
336 );
337
338 payload[dst_range].clone_from_slice(&buf[src_range]);
339
340 if (buf[1] & 0x20) == 0 {
345 let payload = &payload[..next_offset];
346
347 return str::from_utf8(payload)
348 .ok()
349 .and_then(server::decode_info)
350 }
351 }
352 Err(e) => {
353 error!("couldn't receive data -> {}", e);
354 break;
355 }
356 }
357 },
358 _ = tokio::time::sleep(std::time::Duration::from_millis(1_000)) => {
359 warn!("timed-out waiting for reply to \"get host info\" request")
360 }
361 }
362 }
363 None
364 }
365}
366
367impl driver::API for Instance {
368 type DeviceSet = Devices;
369
370 fn register_devices(
371 core: driver::RequestChan,
372 _: &DriverConfig,
373 max_history: Option<usize>,
374 ) -> Pin<Box<dyn Future<Output = Result<Self::DeviceSet>> + Send>> {
375 let state_name = "state".parse::<device::Base>().unwrap();
380 let source_name = "source".parse::<device::Base>().unwrap();
381 let offset_name = "offset".parse::<device::Base>().unwrap();
382 let delay_name = "delay".parse::<device::Base>().unwrap();
383
384 Box::pin(async move {
385 let d_state =
388 core.add_ro_device(state_name, None, max_history).await?;
389 let d_source =
390 core.add_ro_device(source_name, None, max_history).await?;
391 let d_offset = core
392 .add_ro_device(offset_name, Some("ms"), max_history)
393 .await?;
394 let d_delay = core
395 .add_ro_device(delay_name, Some("ms"), max_history)
396 .await?;
397
398 Ok(Devices {
399 d_state,
400 d_source,
401 d_offset,
402 d_delay,
403 })
404 })
405 }
406
407 fn create_instance(
408 cfg: &DriverConfig,
409 ) -> Pin<Box<dyn Future<Output = Result<Box<Self>>> + Send>> {
410 let addr = Instance::get_cfg_address(cfg);
411
412 let fut = async move {
413 let addr = addr?;
416 let loc_if = "0.0.0.0:0".parse::<SocketAddr>().unwrap();
417
418 Span::current().record("cfg", addr.to_string());
419
420 if let Ok(sock) = UdpSocket::bind(loc_if).await {
421 if sock.connect(addr).await.is_ok() {
422 return Ok(Box::new(Instance { sock, seq: 1 }));
423 }
424 }
425 Err(Error::OperationError("couldn't create socket".to_owned()))
426 };
427
428 Box::pin(fut)
429 }
430
431 fn run<'a>(
432 &'a mut self,
433 devices: Arc<Mutex<Devices>>,
434 ) -> Pin<Box<dyn Future<Output = Infallible> + Send + 'a>> {
435 let fut = async move {
436 {
440 let addr = self
441 .sock
442 .peer_addr()
443 .map(|v| format!("{}", v))
444 .unwrap_or_else(|_| String::from("**unknown**"));
445
446 Span::current().record("cfg", addr.as_str());
447 }
448
449 let mut info = Some(server::Info::bad_value());
455 let mut interval = time::interval(Duration::from_millis(20_000));
456
457 let mut devices = devices.lock().await;
458
459 loop {
460 interval.tick().await;
461
462 if let Some(id) = self.get_synced_host().await {
463 debug!("synced to host ID: {:#04x}", id);
464
465 let host_info = self.get_host_info(id).await;
466
467 match host_info {
468 Some(ref tmp) => {
469 if info != host_info {
470 debug!(
471 "host: {}, offset: {} ms, delay: {} ms",
472 tmp.get_host(),
473 tmp.get_offset(),
474 tmp.get_delay()
475 );
476 devices
477 .d_source
478 .report_update(tmp.get_host().clone())
479 .await;
480 devices
481 .d_offset
482 .report_update(tmp.get_offset())
483 .await;
484 devices
485 .d_delay
486 .report_update(tmp.get_delay())
487 .await;
488 devices.d_state.report_update(true).await;
489 info = host_info;
490 }
491 continue;
492 }
493 None => {
494 if info.is_some() {
495 warn!("no synced host information found");
496 info = None;
497 devices.d_state.report_update(false).await;
498 }
499 }
500 }
501 } else if info.is_some() {
502 warn!("we're not synced to any host");
503 info = None;
504 devices.d_state.report_update(false).await;
505 }
506 }
507 };
508
509 Box::pin(fut)
510 }
511}
512
#[cfg(test)]
mod tests {
    use super::*;

    /// Basic decoding, missing variables, and unparsable numbers.
    #[test]
    fn test_decoding() {
        assert_eq!(
            server::decode_info("srcadr=192.168.1.1,offset=0.0,delay=0.0"),
            Some(server::Info::new(String::from("192.168.1.1"), 0.0, 0.0))
        );
        assert_eq!(
            server::decode_info(" srcadr=192.168.1.1, offset=0.0, delay=0.0"),
            Some(server::Info::new(String::from("192.168.1.1"), 0.0, 0.0))
        );

        assert_eq!(server::decode_info(" offset=0.0, delay=0.0"), None);
        assert_eq!(server::decode_info(" srcadr=192.168.1.1, delay=0.0"), None);
        assert_eq!(
            server::decode_info(" srcadr=192.168.1.1, offset=0.0"),
            None
        );

        assert!(server::decode_info("srcadr=192.168.1.1,offset=b,delay=0.0")
            .is_none());
        assert!(server::decode_info("srcadr=192.168.1.1,offset=0.0,delay=b")
            .is_none());
    }

    /// Unrelated variables are ignored and items may start with line breaks.
    #[test]
    fn test_decoding_full_reply() {
        assert_eq!(
            server::decode_info(
                "associd=0,status=961a,srcadr=10.1.2.3,\r\noffset=-0.25, delay=1.5, jitter=0.1"
            ),
            Some(server::Info::new(String::from("10.1.2.3"), -0.25, 1.5))
        );
        assert_eq!(server::decode_info(""), None);
        assert_eq!(server::decode_info(",,,"), None);
    }

    /// The accessors return the values given to the constructor.
    #[test]
    fn test_info_accessors() {
        let info = server::Info::new(String::from("10.1.2.3"), -0.25, 1.5);

        assert_eq!(info.get_host().as_str(), "10.1.2.3");
        assert_eq!(info.get_offset(), -0.25);
        assert_eq!(info.get_delay(), 1.5);
    }
}