drmem_drv_ntp/
lib.rs

1use drmem_api::{
2    device,
3    driver::{self, DriverConfig},
4    Error, Result,
5};
6use std::future::Future;
7use std::sync::Arc;
8use std::{convert::Infallible, pin::Pin};
9use std::{
10    net::{SocketAddr, SocketAddrV4},
11    str,
12};
13use tokio::net::UdpSocket;
14use tokio::sync::Mutex;
15use tokio::time::{self, Duration};
16use tracing::{debug, error, trace, warn, Span};
17
18// Encapsulates data types and algorithms related to NTP server
19// information.
20
mod server {
    use super::*;

    // Holds interesting state information for an NTP server. The
    // fields are, in order: the server's address (as reported by
    // ntpd), the clock offset, and the time-of-flight delay.

    #[derive(Debug, PartialEq)]
    pub struct Info(String, f64, f64);

    impl Info {
        // Creates a new, initialized `Info` type.

        pub fn new(host: String, offset: f64, delay: f64) -> Info {
            Info(host, offset, delay)
        }

        // Creates a value which will never match any value returned
        // by an NTP server (because the host will never be blank.)

        pub fn bad_value() -> Info {
            Info(String::new(), 0.0, 0.0)
        }

        // Returns the IP address of the NTP server.

        pub fn get_host(&self) -> &String {
            &self.0
        }

        // Returns the estimated offset (in milliseconds) of the
        // system time compared to the NTP server.

        pub fn get_offset(&self) -> f64 {
            self.1
        }

        // Returns the estimated time-of-flight delay (in
        // milliseconds) to the NTP server.

        pub fn get_delay(&self) -> f64 {
            self.2
        }
    }

    // Accumulator used while scanning a reply. The slots hold the
    // `srcadr`, `offset`, and `delay` values (in that order) once
    // they have been seen and successfully parsed.

    type Partial = (Option<String>, Option<f64>, Option<f64>);

    // Folds one `key=value` item into the accumulator. Only the
    // "srcadr", "offset", and "delay" keys are of interest; any
    // other key, an item without exactly one '=', or a numeric
    // value that doesn't parse leaves the accumulator untouched.
    // The accumulator is taken by value and handed back so this
    // function can be used directly with `Iterator::fold`.

    fn update_host_info(mut state: Partial, item: &str) -> Partial {
        // Asking for at most three pieces lets us tell "exactly one
        // '='" apart from "more than one" without allocating.

        let mut parts = item.splitn(3, '=');

        match (parts.next(), parts.next(), parts.next()) {
            (Some("srcadr"), Some(adr), None) => {
                state.0 = Some(String::from(adr))
            }
            (Some("offset"), Some(offset), None) => {
                if let Ok(o) = offset.parse::<f64>() {
                    state.1 = Some(o)
                }
            }
            (Some("delay"), Some(delay), None) => {
                if let Ok(d) = delay.parse::<f64>() {
                    state.2 = Some(d)
                }
            }
            _ => (),
        }
        state
    }

    // Returns an `Info` type that has been initialized with the
    // parameters defined in `input`, which is a comma-separated
    // list of `key=value` items. Returns `None` unless all three
    // of "srcadr", "offset", and "delay" were found and the two
    // numeric values parsed.

    pub fn decode_info(input: &str) -> Option<Info> {
        // Each item is trimmed on *both* sides. Replies may pad
        // items with spaces or line breaks, and whitespace left on
        // the end of a value would otherwise make the floating
        // point parse fail and the whole reply be rejected.

        let fields = input
            .split(',')
            .map(str::trim)
            .filter(|item| !item.is_empty())
            .fold((None, None, None), update_host_info);

        match fields {
            (Some(host), Some(offset), Some(delay)) => {
                Some(Info::new(host, offset, delay))
            }
            _ => None,
        }
    }
}
107
// Holds the state needed to talk to one NTP server.

pub struct Instance {
    // UDP socket used for every request/reply exchange. It is
    // bound to an ephemeral local port and connected to the
    // configured server address when the instance is created.
    sock: UdpSocket,

    // Sequence number written into bytes 2 and 3 of each outgoing
    // request. It is advanced after every request is built.
    seq: u16,
}
112
// The set of read-only devices this driver registers and updates.

pub struct Devices {
    // `true` once information about the synced host has been
    // obtained; `false` when no synced host (or no information
    // about it) could be found.
    d_state: driver::ReadOnlyDevice<bool>,

    // Address of the host the NTP server is synced to.
    d_source: driver::ReadOnlyDevice<String>,

    // Reported clock offset. Registered with units of "ms".
    d_offset: driver::ReadOnlyDevice<f64>,

    // Reported time-of-flight delay. Registered with units of "ms".
    d_delay: driver::ReadOnlyDevice<f64>,
}
119
120impl Instance {
121    pub const NAME: &'static str = "ntp";
122
123    pub const SUMMARY: &'static str =
124        "monitors an NTP server and reports its state";
125
126    pub const DESCRIPTION: &'static str = include_str!("../README.md");
127
128    // Attempts to pull the hostname/port for the remote process.
129
130    fn get_cfg_address(cfg: &DriverConfig) -> Result<SocketAddrV4> {
131        match cfg.get("addr") {
132            Some(toml::value::Value::String(addr)) => {
133                if let Ok(addr) = addr.parse::<SocketAddrV4>() {
134                    Ok(addr)
135                } else {
136                    Err(Error::ConfigError(String::from(
137                        "'addr' not in hostname:port format",
138                    )))
139                }
140            }
141            Some(_) => Err(Error::ConfigError(String::from(
142                "'addr' config parameter should be a string",
143            ))),
144            None => Err(Error::ConfigError(String::from(
145                "missing 'addr' parameter in config",
146            ))),
147        }
148    }
149
150    // Combines and returns the first two bytes from a buffer as a
151    // big-endian, 16-bit value.
152
153    fn read_u16(buf: &[u8]) -> u16 {
154        (buf[0] as u16) * 256 + (buf[1] as u16)
155    }
156
157    async fn get_synced_host(&mut self) -> Option<u16> {
158        let req: [u8; 12] = [
159            0x26,
160            0x01,
161            (self.seq / 256) as u8,
162            (self.seq % 256) as u8,
163            0x00,
164            0x00,
165            0x00,
166            0x00,
167            0x00,
168            0x00,
169            0x00,
170            0x00,
171        ];
172
173        self.seq += 1;
174
175        // Try to send the request. If there's a failure with the
176        // socket, report the error and return `None`.
177
178        if let Err(e) = self.sock.send(&req).await {
179            error!("couldn't send \"synced hosts\" request -> {}", e);
180            return None;
181        }
182
183        let mut buf = [0u8; 500];
184
185        #[rustfmt::skip]
186	tokio::select! {
187	    result = self.sock.recv(&mut buf) => {
188		match result {
189		    // The packet has to be at least 12 bytes so we
190		    // can use all parts of the header without
191		    // worrying about panicking.
192
193		    Ok(len) if len < 12 => {
194			warn!(
195			    "response from ntpd < 12 bytes -> only {} bytes",
196			    len
197			)
198		    }
199
200		    Ok(len) => {
201			let total = Instance::read_u16(&buf[10..=11]) as usize;
202			let expected_len = total + 12 + (4 - total % 4) % 4;
203
204			// Make sure the incoming buffer is as large
205			// as the length field says it is (so we can
206			// safely access the entire payload.)
207
208			if expected_len == len {
209			    for ii in buf[12..len].chunks_exact(4) {
210				if (ii[2] & 0x7) == 6 {
211				    return Some(Instance::read_u16(
212					&ii[0..=1],
213				    ));
214				}
215			    }
216			} else {
217			    warn!(
218				"bad packet length -> expected {}, got {}",
219				expected_len, len
220			    );
221			}
222		    }
223		    Err(e) => error!("couldn't receive data -> {}", e),
224		}
225	    },
226	    _ = tokio::time::sleep(std::time::Duration::from_millis(1_000)) => {
227		warn!("timed-out waiting for reply to \"get synced host\" request")
228	    }
229	}
230
231        None
232    }
233
234    // Requests information about a given association ID. An `Info`
235    // type is returned containing the parameters we find interesting.
236
237    pub async fn get_host_info(&mut self, id: u16) -> Option<server::Info> {
238        let req = &[
239            0x26,
240            0x02,
241            (self.seq / 256) as u8,
242            (self.seq % 256) as u8,
243            0x00,
244            0x00,
245            (id / 256) as u8,
246            (id % 256) as u8,
247            0x00,
248            0x00,
249            0x00,
250            0x00,
251        ];
252
253        self.seq += 1;
254
255        if let Err(e) = self.sock.send(req).await {
256            error!("couldn't send \"host info\" request -> {}", e);
257            return None;
258        }
259
260        let mut buf = [0u8; 500];
261        let mut payload = [0u8; 2048];
262        let mut next_offset = 0;
263
264        loop {
265            #[rustfmt::skip]
266	    tokio::select! {
267		result = self.sock.recv(&mut buf) => {
268		    match result {
269			// The packet has to be at least 12 bytes so
270			// we can use all parts of the header without
271			// worrying about panicking.
272
273			Ok(len) if len < 12 => {
274			    warn!("response from ntpd < 12 bytes -> {}", len);
275			    break;
276			}
277
278			Ok(len) => {
279			    let offset = Instance::read_u16(&buf[8..=9]) as usize;
280
281			    // We don't keep track of which of the
282			    // multiple packets we've already
283			    // received. Instead, we require the
284			    // packets are sent in order. This warning
285			    // has never been emitted.
286
287			    if offset != next_offset {
288				warn!("dropped packet (incorrect offset)");
289				break;
290			    }
291
292			    let total = Instance::read_u16(&buf[10..=11]) as usize;
293			    let expected_len = total + 12 + (4 - total % 4) % 4;
294
295			    // Make sure the incoming buffer is as
296			    // large as the length field says it is
297			    // (so we can safely access the entire
298			    // payload.)
299
300			    if expected_len != len {
301				warn!(
302				    "bad packet length -> expected {}, got {}",
303				    expected_len, len
304				);
305				break;
306			    }
307
308			    // Make sure the reply's offset and total
309			    // won't push us past the end of our
310			    // buffer.
311
312			    if offset + total > payload.len() {
313				warn!(
314				    "payload too big (offset {}, total {}, target buf: {})",
315				    offset,
316				    total,
317				    payload.len()
318				);
319				break;
320			    }
321
322			    // Update the next, expected offset.
323
324			    next_offset += total;
325
326			    // Copy the fragment into the final buffer.
327
328			    let dst_range = offset..offset + total;
329			    let src_range = 12..12 + total;
330
331			    trace!(
332				"copying {} bytes into {} through {}",
333				dst_range.len(),
334				dst_range.start,
335				dst_range.end - 1
336			    );
337
338			    payload[dst_range].clone_from_slice(&buf[src_range]);
339
340			    // If this is the last packet, we can
341			    // process it. Convert the byte buffer to
342			    // text and decode it.
343
344			    if (buf[1] & 0x20) == 0 {
345				let payload = &payload[..next_offset];
346
347				return str::from_utf8(payload)
348				    .ok()
349				    .and_then(server::decode_info)
350			    }
351			}
352			Err(e) => {
353			    error!("couldn't receive data -> {}", e);
354			    break;
355			}
356		    }
357		},
358		_ = tokio::time::sleep(std::time::Duration::from_millis(1_000)) => {
359		    warn!("timed-out waiting for reply to \"get host info\" request")
360		}
361	    }
362        }
363        None
364    }
365}
366
367impl driver::API for Instance {
368    type DeviceSet = Devices;
369
370    fn register_devices(
371        core: driver::RequestChan,
372        _: &DriverConfig,
373        max_history: Option<usize>,
374    ) -> Pin<Box<dyn Future<Output = Result<Self::DeviceSet>> + Send>> {
375        // It's safe to use `.unwrap()` for these names because, in a
376        // fully-tested, released version of this driver, we would
377        // have seen and fixed any panics.
378
379        let state_name = "state".parse::<device::Base>().unwrap();
380        let source_name = "source".parse::<device::Base>().unwrap();
381        let offset_name = "offset".parse::<device::Base>().unwrap();
382        let delay_name = "delay".parse::<device::Base>().unwrap();
383
384        Box::pin(async move {
385            // Define the devices managed by this driver.
386
387            let d_state =
388                core.add_ro_device(state_name, None, max_history).await?;
389            let d_source =
390                core.add_ro_device(source_name, None, max_history).await?;
391            let d_offset = core
392                .add_ro_device(offset_name, Some("ms"), max_history)
393                .await?;
394            let d_delay = core
395                .add_ro_device(delay_name, Some("ms"), max_history)
396                .await?;
397
398            Ok(Devices {
399                d_state,
400                d_source,
401                d_offset,
402                d_delay,
403            })
404        })
405    }
406
407    fn create_instance(
408        cfg: &DriverConfig,
409    ) -> Pin<Box<dyn Future<Output = Result<Box<Self>>> + Send>> {
410        let addr = Instance::get_cfg_address(cfg);
411
412        let fut = async move {
413            // Validate the configuration.
414
415            let addr = addr?;
416            let loc_if = "0.0.0.0:0".parse::<SocketAddr>().unwrap();
417
418            Span::current().record("cfg", addr.to_string());
419
420            if let Ok(sock) = UdpSocket::bind(loc_if).await {
421                if sock.connect(addr).await.is_ok() {
422                    return Ok(Box::new(Instance { sock, seq: 1 }));
423                }
424            }
425            Err(Error::OperationError("couldn't create socket".to_owned()))
426        };
427
428        Box::pin(fut)
429    }
430
431    fn run<'a>(
432        &'a mut self,
433        devices: Arc<Mutex<Devices>>,
434    ) -> Pin<Box<dyn Future<Output = Infallible> + Send + 'a>> {
435        let fut = async move {
436            // Record the peer's address in the "cfg" field of the
437            // span.
438
439            {
440                let addr = self
441                    .sock
442                    .peer_addr()
443                    .map(|v| format!("{}", v))
444                    .unwrap_or_else(|_| String::from("**unknown**"));
445
446                Span::current().record("cfg", addr.as_str());
447            }
448
449            // Set `info` to an initial, unmatchable value. `None`
450            // would be preferrable here but, if DrMem had a problem
451            // at startup getting the NTP state, it wouldn't print the
452            // warning(s).
453
454            let mut info = Some(server::Info::bad_value());
455            let mut interval = time::interval(Duration::from_millis(20_000));
456
457            let mut devices = devices.lock().await;
458
459            loop {
460                interval.tick().await;
461
462                if let Some(id) = self.get_synced_host().await {
463                    debug!("synced to host ID: {:#04x}", id);
464
465                    let host_info = self.get_host_info(id).await;
466
467                    match host_info {
468                        Some(ref tmp) => {
469                            if info != host_info {
470                                debug!(
471                                    "host: {}, offset: {} ms, delay: {} ms",
472                                    tmp.get_host(),
473                                    tmp.get_offset(),
474                                    tmp.get_delay()
475                                );
476                                devices
477                                    .d_source
478                                    .report_update(tmp.get_host().clone())
479                                    .await;
480                                devices
481                                    .d_offset
482                                    .report_update(tmp.get_offset())
483                                    .await;
484                                devices
485                                    .d_delay
486                                    .report_update(tmp.get_delay())
487                                    .await;
488                                devices.d_state.report_update(true).await;
489                                info = host_info;
490                            }
491                            continue;
492                        }
493                        None => {
494                            if info.is_some() {
495                                warn!("no synced host information found");
496                                info = None;
497                                devices.d_state.report_update(false).await;
498                            }
499                        }
500                    }
501                } else if info.is_some() {
502                    warn!("we're not synced to any host");
503                    info = None;
504                    devices.d_state.report_update(false).await;
505                }
506            }
507        };
508
509        Box::pin(fut)
510    }
511}
512
#[cfg(test)]
mod tests {
    use super::*;

    // The value every successful decode in `test_decoding` is
    // expected to produce.

    fn decoded() -> Option<server::Info> {
        Some(server::Info::new(String::from("192.168.1.1"), 0.0, 0.0))
    }

    #[test]
    fn test_decoding() {
        // Well-formed input, with and without padding after the
        // separators.

        for text in [
            "srcadr=192.168.1.1,offset=0.0,delay=0.0",
            " srcadr=192.168.1.1, offset=0.0, delay=0.0",
        ] {
            assert_eq!(server::decode_info(text), decoded());
        }

        // Should return `None` if fields are missing.

        for text in [
            " offset=0.0, delay=0.0",
            " srcadr=192.168.1.1, delay=0.0",
            " srcadr=192.168.1.1, offset=0.0",
        ] {
            assert_eq!(server::decode_info(text), None);
        }

        // Test badly formed input.

        for text in [
            "srcadr=192.168.1.1,offset=b,delay=0.0",
            "srcadr=192.168.1.1,offset=0.0,delay=b",
        ] {
            assert!(server::decode_info(text).is_none());
        }
    }
}