chrony_candm/
async_net.rs

// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: GPL-2.0-only

use std::collections::{HashMap, VecDeque};
use std::ffi::OsStr;
use std::fs::Permissions;
use std::future::Future;
use std::hash::{Hash, Hasher};
use std::net::{Ipv6Addr, SocketAddr, SocketAddrV6};
use std::ops::{Deref, DerefMut};
use std::pin::Pin;
use std::sync::Arc;

use futures::future::TryFutureExt;
use rand::Rng;
use tokio::sync::mpsc::error::SendError;
use tokio::sync::OnceCell;
use tokio::time::Instant;

use crate::common::QueryError;
use crate::net::ClientOptions;
use crate::reply::Reply;
use crate::request::{Request, RequestBody};

#[cfg(unix)]
use std::os::unix::ffi::OsStrExt;
#[cfg(unix)]
use std::os::unix::fs::PermissionsExt;
#[cfg(unix)]
use tokio::net::UnixDatagram;

#[cfg(unix)]
#[derive(Debug)]
pub struct UnixDatagramClient(UnixDatagram);

#[cfg(unix)]
impl AsRef<UnixDatagram> for UnixDatagramClient {
    fn as_ref(&self) -> &UnixDatagram {
        &self.0
    }
}

#[cfg(unix)]
impl AsMut<UnixDatagram> for UnixDatagramClient {
    fn as_mut(&mut self) -> &mut UnixDatagram {
        &mut self.0
    }
}

#[cfg(unix)]
impl Deref for UnixDatagramClient {
    type Target = UnixDatagram;
    fn deref(&self) -> &UnixDatagram {
        &self.0
    }
}

#[cfg(unix)]
impl DerefMut for UnixDatagramClient {
    fn deref_mut(&mut self) -> &mut UnixDatagram {
        &mut self.0
    }
}

#[cfg(unix)]
impl Drop for UnixDatagramClient {
    fn drop(&mut self) {
        if let Ok(addr) = self.0.local_addr() {
            if let Some(path) = addr.as_pathname() {
                let _ = self.0.shutdown(std::net::Shutdown::Both);
                let _ = std::fs::remove_file(path);
            }
        }
    }
}

#[cfg(unix)]
impl UnixDatagramClient {
    /// Binds a uniquely named client socket under `/var/run/chrony` and connects it to
    /// chronyd's control socket at `/var/run/chrony/chronyd.sock`.
    pub async fn new() -> std::io::Result<UnixDatagramClient> {
        // Use a random 128-bit id in the socket path so concurrent clients don't collide.
        let id: [u8; 16] = rand::random();
        let mut path = b"/var/run/chrony/client-000102030405060708090a0b0c0d0e0f.sock".clone();
        hex::encode_to_slice(id, &mut path[23..55]).unwrap();
        let path_str = OsStr::from_bytes(&path);
        let sock = UnixDatagram::bind(path_str)?;
        let client = UnixDatagramClient(sock);
        // Make the socket writable by chronyd (which may run as a different user) so it
        // can send its reply back to us.
        std::fs::set_permissions(path_str, Permissions::from_mode(0o777))?;
        client.connect("/var/run/chrony/chronyd.sock")?;
        Ok(client)
    }

    /// Query chronyd using this Unix domain socket
    ///
    /// Sends a request to the server and waits for the response.
    ///
    /// # Errors
    /// See [`QueryError`] for more info
    ///
    /// # NOTE
    /// This function takes `&mut self` even though it doesn't strictly need to, in order to
    /// prevent the footgun of making concurrent requests to `chronyd` over the same socket.
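    ///
    /// # Example
    ///
    /// A minimal sketch of intended usage (marked `ignore`: it needs a running local
    /// `chronyd`, and `RequestBody::Tracking` / `ClientOptions::default()` are used here
    /// purely as plausible illustrations):
    ///
    /// ```ignore
    /// let mut client = UnixDatagramClient::new().await?;
    /// let reply = client
    ///     .query(RequestBody::Tracking, ClientOptions::default())
    ///     .await?;
    /// println!("{:?}", reply);
    /// ```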
    pub async fn query(
        &mut self,
        request: RequestBody,
        options: ClientOptions,
    ) -> Result<Reply, QueryError> {
        use bytes::BytesMut;
        let request = Request {
            sequence: rand::random(),
            attempt: 0,
            body: request,
        };

        let mut send_buf = BytesMut::with_capacity(request.length());
        request.serialize(&mut send_buf);
        let mut recv_buf = [0; 1500];
        let mut attempt = 0;

        while attempt < options.n_tries {
            self.0.send(&send_buf).await.map_err(QueryError::Send)?;
            let Ok(io_result) =
                tokio::time::timeout(options.timeout, self.0.recv(&mut recv_buf)).await
            else {
                attempt += 1;
                continue;
            };
            let size = io_result.map_err(QueryError::Recv)?;
            let mut msg = &recv_buf[..size];
            let reply = Reply::deserialize(&mut msg)?;
            if reply.sequence == request.sequence {
                return Ok(reply);
            } else {
                return Err(QueryError::SequenceMismatch {
                    expected: request.sequence,
                    received: reply.sequence,
                });
            }
        }
        Err(QueryError::Timeout)
    }
}

/// Destination for a request handled by the client task.
#[derive(Debug, Hash)]
enum ServerAddr {
    Udp(SocketAddrV6),
    #[cfg(unix)]
    Unix,
}

type ReplySender = tokio::sync::oneshot::Sender<std::io::Result<Reply>>;
type ReplyReceiver = tokio::sync::oneshot::Receiver<std::io::Result<Reply>>;

/// A request queued for the client task, along with the channel on which to deliver its reply.
#[derive(Debug)]
struct RequestMeta {
    body: RequestBody,
    reply_sender: ReplySender,
    server: ServerAddr,
}

type RequestSender = tokio::sync::mpsc::UnboundedSender<RequestMeta>;
type RequestReceiver = tokio::sync::mpsc::UnboundedReceiver<RequestMeta>;

#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
enum ServerKey {
    Udp(SocketAddr),
    #[cfg(unix)]
    Unix,
}

/// Identifies an in-flight request by the server it was sent to and its sequence number.
#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
struct InflightKey {
    server_key: ServerKey,
    sequence: u32,
}

/// Bookkeeping for an in-flight request: the serialized packet (kept for retransmission),
/// the attempt count, and where to send the reply.
#[derive(Debug)]
struct InflightValue {
    request: Vec<u8>,
    attempt: u16,
    reply_sender: ReplySender,
    server: ServerAddr,
}

/// Asynchronously sends requests and receives replies
#[deprecated = "Persistent client overly complicates retry logic"]
#[derive(Debug)]
pub struct Client {
    task_handle: tokio::task::JoinHandle<()>,
    sender: RequestSender,
}

/// A future which can be `await`ed to obtain the reply to a sent query
#[derive(Debug)]
pub struct ReplyFuture(ReplyReceiver);

impl Future for ReplyFuture {
    type Output = std::io::Result<Reply>;

    fn poll(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Self::Output> {
        let receiver = &mut self.get_mut().0;
        // A closed channel means the client task dropped our sender without replying;
        // surface that as a ConnectionAborted I/O error rather than a channel error.
        let mut result = receiver.unwrap_or_else(|e| {
            Err(std::io::Error::new(
                std::io::ErrorKind::ConnectionAborted,
                e,
            ))
        });
        Pin::new(&mut result).poll(cx)
    }
}

impl Client {
    /// Spawns a task to handle sending of requests and receiving of replies, and returns a `Client`
    /// that communicates with this task.
    pub fn spawn(handle: &tokio::runtime::Handle, options: crate::net::ClientOptions) -> Client {
        let (sender, receiver) = tokio::sync::mpsc::unbounded_channel();
        let task_handle = handle.spawn(client_task(options, receiver));
        Client {
            task_handle,
            sender,
        }
    }

    /// Sends a request to the given server, and returns a future which can be `await`ed to obtain the reply.
    ///
    /// Note that this is *not* an async function; it is a synchronous, non-blocking function which returns a future.
    /// Even if you don't await the returned future right away, the request is still dispatched immediately
    /// to the task which was spawned to service this client, which forms it into a packet and sends it.
    ///
    /// The type of the `server` parameter is more restrictive than that of this method's blocking counterpart
    /// [crate::net::blocking_query] because converting some implementations of `ToSocketAddrs` into a `SocketAddr` can
    /// involve blocking on a DNS lookup, which would be unacceptable here. If you're communicating with localhost,
    /// you can call this method with `LOCAL_SERVER_ADDR.into()`. If you're communicating with a remote server,
    /// you'll need to handle the DNS lookup yourself.
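    ///
    /// # Example
    ///
    /// A rough sketch of intended usage (marked `ignore`: `RequestBody::Tracking` and
    /// `ClientOptions::default()` are used purely as plausible illustrations, and
    /// `LOCAL_SERVER_ADDR` is assumed to be in scope):
    ///
    /// ```ignore
    /// let client = Client::spawn(&tokio::runtime::Handle::current(), ClientOptions::default());
    /// // The request is dispatched immediately; awaiting only waits for the reply.
    /// let reply = client.query(RequestBody::Tracking, LOCAL_SERVER_ADDR.into()).await?;
    /// ```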
    pub fn query(&self, request: RequestBody, server: SocketAddr) -> ReplyFuture {
        let mapped_server = match server {
            SocketAddr::V4(v4) => SocketAddrV6::new(v4.ip().to_ipv6_mapped(), v4.port(), 0, 0),
            SocketAddr::V6(v6) => v6,
        };

        let (sender, receiver) = tokio::sync::oneshot::channel();
        if let Err(SendError(request_meta)) = self.sender.send(RequestMeta {
            body: request,
            reply_sender: sender,
            server: ServerAddr::Udp(mapped_server),
        }) {
            request_meta
                .reply_sender
                .send(Err(std::io::Error::new(
                    std::io::ErrorKind::ConnectionAborted,
                    "Client task unexpectedly shut down",
                )))
                .expect("Send failed but the receiver is still in scope?!")
        }

        ReplyFuture(receiver)
    }

    /// Sends a request to the local Chrony server via its UNIX domain socket, and returns a future which can be `await`ed to obtain the reply.
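    ///
    /// A rough sketch of intended usage (marked `ignore`; `RequestBody::Tracking` is purely
    /// illustrative):
    ///
    /// ```ignore
    /// let reply = client.query_uds(RequestBody::Tracking).await?;
    /// ```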
    #[cfg(unix)]
    pub fn query_uds(&self, request: RequestBody) -> ReplyFuture {
        let (sender, receiver) = tokio::sync::oneshot::channel();

        if let Err(SendError(request_meta)) = self.sender.send(RequestMeta {
            body: request,
            reply_sender: sender,
            server: ServerAddr::Unix,
        }) {
            request_meta
                .reply_sender
                .send(Err(std::io::Error::new(
                    std::io::ErrorKind::ConnectionAborted,
                    "Client task unexpectedly shut down",
                )))
                .expect("Send failed but the receiver is still in scope?!")
        }

        ReplyFuture(receiver)
    }
}

#[derive(Debug)]
struct ReplyMeta<'a> {
    reply: &'a [u8],
    server_key: ServerKey,
}

#[derive(Debug)]
enum SelectResult<'a> {
    Request(RequestMeta),
    Reply(ReplyMeta<'a>),
    Timeout,
    Error(std::io::Error),
    Shutdown,
}

async fn client_task(options: ClientOptions, mut receiver: RequestReceiver) {
    let mut deadlines: VecDeque<(Instant, InflightKey)> = std::collections::VecDeque::new();
    let mut inflight: HashMap<InflightKey, InflightValue> = std::collections::HashMap::new();

    let udp_init = || tokio::net::UdpSocket::bind((Ipv6Addr::UNSPECIFIED, 0));
    let udp_cell: OnceCell<tokio::net::UdpSocket> = OnceCell::new();
    let mut udp_buf = [0u8; 1500];

    #[cfg(unix)]
    let uds_init = || UnixDatagramClient::new();
    #[cfg(unix)]
    let uds_cell: OnceCell<UnixDatagramClient> = OnceCell::new();
    #[cfg(unix)]
    let mut uds_buf = [0u8; 1500];

    let (mut sequence, key0, key1): (u32, u64, u64) = {
        let mut rng = rand::thread_rng();
        (rng.gen(), rng.gen(), rng.gen())
    };

    loop {
        let now = tokio::time::Instant::now();

        // Find expired deadlines and retransmit those messages, or give up
        // if the attempt limit has been reached
        while let Some((deadline, _)) = deadlines.front() {
            if *deadline > now {
                break;
            }
            // `deadline_key` and `inflight_key` have the same value, but
            // we keep both copies in order to minimize further cloning.
            let (_, deadline_key) = deadlines.pop_front().unwrap();
            if let Some((inflight_key, mut inflight_val)) = inflight.remove_entry(&deadline_key) {
                inflight_val.attempt += 1;
                if inflight_val.attempt > options.n_tries {
                    let _ = inflight_val.reply_sender.send(Err(std::io::Error::new(
                        std::io::ErrorKind::TimedOut,
                        "request timed out and max retries reached",
                    )));
                } else {
                    crate::request::increment_attempt(inflight_val.request.as_mut());
                    let send_result = match inflight_val.server {
                        // These are retries, so we can safely unwrap() since we must have
                        // gotten a socket the first time around.
                        ServerAddr::Udp(addr) => {
                            udp_cell
                                .get()
                                .unwrap()
                                .send_to(inflight_val.request.as_ref(), addr)
                                .await
                        }
                        #[cfg(unix)]
                        ServerAddr::Unix => {
                            uds_cell
                                .get()
                                .unwrap()
                                .send(inflight_val.request.as_ref())
                                .await
                        }
                    };
                    match send_result {
                        Ok(_) => {
                            inflight.insert(inflight_key, inflight_val);
                            let new_deadline = now + options.timeout;
                            deadlines.push_back((new_deadline, deadline_key));
                        }
                        Err(e) => {
                            let _ = inflight_val.reply_sender.send(Err(e));
                        }
                    }
                }
            }
        }

        // Cull unexpired deadlines that have already been met. On every iteration of
        // the main loop, pop off the front of the queue until we reach something
        // still in-flight. This takes amortized O(1), and it ensures that the next
        // deadline we sleep on is "real", preventing unnecessary wakeups.
        while let Some((_, inflight_key)) = deadlines.front() {
            if inflight.contains_key(inflight_key) {
                break;
            } else {
                deadlines.pop_front();
            }
        }

        // In rare cases, like if the network comes back online in the middle of a
        // big burst of queries, there might be a lot of met deadlines still remaining
        // in the queue. Finding these takes O(N), so we don't want to do it every time.
        // Instead, we do a complete cull only when the length of `deadlines` reaches
        // twice the capacity of `inflight`. This keeps us at amortized O(1) and also
        // keeps the deadline queue from asymptotically outgrowing the inflight table.
        if deadlines.len() >= 2 * inflight.capacity() {
            deadlines.retain(|(_, inflight_key)| inflight.contains_key(inflight_key))
        }

        let timeout = async {
            match deadlines.front() {
                Some((deadline, _)) => tokio::time::sleep_until(*deadline).await,
                None => futures::future::pending().await,
            }
        };

        let udp_recv = async {
            match udp_cell.get() {
                Some(udp) => {
                    let (size, peer) = udp.recv_from(&mut udp_buf).await?;
                    std::io::Result::Ok(ReplyMeta {
                        reply: &udp_buf[0..size],
                        server_key: ServerKey::Udp(peer),
                    })
                }
                _ => futures::future::pending().await,
            }
        };

        #[cfg(unix)]
        let uds_recv = async {
            match uds_cell.get() {
                Some(uds) => {
                    let size = uds.recv(&mut uds_buf).await?;
                    std::io::Result::Ok(ReplyMeta {
                        reply: &uds_buf[0..size],
                        server_key: ServerKey::Unix,
                    })
                }
                _ => futures::future::pending().await,
            }
        };
        #[cfg(not(unix))]
        let uds_recv = futures::future::pending();

        let select_result = tokio::select! {
            result = udp_recv => match result {
                Ok(reply_meta) => {
                    SelectResult::Reply(reply_meta)
                },
                Err(e) => SelectResult::Error(e),
            },
            result = uds_recv => match result {
                Ok(reply_meta) => {
                    SelectResult::Reply(reply_meta)
                },
                Err(e) => SelectResult::Error(e),
            },
            result = receiver.recv() => {
                match result {
                    Some(request) => SelectResult::Request(request),
                    None => SelectResult::Shutdown,
                }
            },
            _ = timeout => SelectResult::Timeout
        };

        match select_result {
            SelectResult::Request(request_meta) => {
                // Sequence numbers should be unpredictable in order to make
                // off-path blind spoofing harder. A single global sequence
                // number accumulator with a random initial state is mixed
                // with a server-specific key derived using SipHash. This
                // way we don't have to keep state for each individual server,
                // sequence numbers sent to one server can't be used to guess
                // another's, and we avoid the birthday-bound collisions that
                // we'd get from picking a random number for every request.
                let mut hasher = siphasher::sip::SipHasher::new_with_keys(key0, key1);
                request_meta.server.hash(&mut hasher);
                let obfuscated_sequence = sequence.wrapping_add(hasher.finish() as u32);
                sequence = sequence.wrapping_add(1);

                let request = Request {
                    sequence: obfuscated_sequence,
                    attempt: 0,
                    body: request_meta.body,
                };
                let mut send_buf = Vec::with_capacity(request.length());
                request.serialize(&mut send_buf);

                let inflight_key = InflightKey {
                    server_key: match request_meta.server {
                        ServerAddr::Udp(addr) => ServerKey::Udp(addr.into()),
                        #[cfg(unix)]
                        ServerAddr::Unix => ServerKey::Unix,
                    },
                    sequence: obfuscated_sequence,
                };

                let inflight_val = InflightValue {
                    request: send_buf,
                    attempt: 0,
                    reply_sender: request_meta.reply_sender,
                    server: request_meta.server,
                };

                let deadline = now + options.timeout;

                match inflight_val.server {
                    ServerAddr::Udp(addr) => match udp_cell.get_or_try_init(udp_init).await {
                        Ok(udp) => {
                            if let Err(e) = udp.send_to(inflight_val.request.as_ref(), addr).await {
                                let _ = inflight_val.reply_sender.send(Err(e));
                                continue;
                            }
                        }
                        Err(e) => {
                            let _ = inflight_val.reply_sender.send(Err(e));
                            continue;
                        }
                    },
                    #[cfg(unix)]
                    ServerAddr::Unix => match uds_cell.get_or_try_init(uds_init).await {
                        Ok(uds) => {
                            if let Err(e) = uds.send(inflight_val.request.as_ref()).await {
                                let _ = inflight_val.reply_sender.send(Err(e));
                                continue;
                            }
                        }
                        Err(e) => {
                            let _ = inflight_val.reply_sender.send(Err(e));
                            continue;
                        }
                    },
                }

                deadlines.push_back((deadline, inflight_key.clone()));
                inflight.insert(inflight_key, inflight_val);
            }
            SelectResult::Reply(reply_meta) => {
                let mut reply_buf = reply_meta.reply;
                if let Ok(reply) = Reply::deserialize(&mut reply_buf) {
                    let inflight_key = InflightKey {
                        server_key: reply_meta.server_key,
                        sequence: reply.sequence,
                    };
                    if let Some(inflight_val) = inflight.remove(&inflight_key) {
                        let _ = inflight_val.reply_sender.send(Ok(reply));
                    }
                }
            }
            SelectResult::Timeout => {}
            SelectResult::Error(e) => {
                if e.kind() == std::io::ErrorKind::Interrupted {
                    continue;
                }

                // Any other kind of error should never happen here, but if it
                // does we have to bail.
                let erc = Arc::new(e);
                receiver.close();
                // Drain any requests from the channel and answer them all with the error
                // we got.
                while let Some(request) = receiver.recv().await {
                    let _ = request
                        .reply_sender
                        .send(Err(std::io::Error::new(erc.kind(), erc.clone())));
                }
                // Also answer in-flight requests with an error.
                for v in inflight.into_values() {
                    let _ = v
                        .reply_sender
                        .send(Err(std::io::Error::new(erc.kind(), erc.clone())));
                }
                return;
            }
            SelectResult::Shutdown => {
                for v in inflight.into_values() {
                    let _ = v.reply_sender.send(Err(std::io::Error::new(
                        std::io::ErrorKind::ConnectionAborted,
                        "Client dropped before arrival of reply",
                    )));
                }
                return;
            }
        }
    }
}

/// Query chronyd using a Unix domain socket
///
/// Creates a Unix domain socket client, sends a request to the local server, and waits for the response.
/// Retry behavior is based on [`ClientOptions`].
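///
/// # Example
///
/// A minimal sketch of intended usage (marked `ignore`: it needs a running local `chronyd`,
/// and `RequestBody::Tracking` / `ClientOptions::default()` are used purely as plausible
/// illustrations):
///
/// ```ignore
/// let reply = query_uds(RequestBody::Tracking, ClientOptions::default()).await?;
/// ```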
#[cfg(unix)]
pub async fn query_uds(request: RequestBody, options: ClientOptions) -> std::io::Result<Reply> {
    let mut client = UnixDatagramClient::new().await?;
    client.query(request, options).await.map_err(QueryError::into_io)
}