1use std::collections::{HashMap, VecDeque};
5use std::ffi::OsStr;
6use std::fs::Permissions;
7use std::future::Future;
8use std::hash::{Hash, Hasher};
9use std::net::{Ipv6Addr, SocketAddr, SocketAddrV6};
10use std::ops::{Deref, DerefMut};
11use std::pin::Pin;
12use std::sync::Arc;
13
14use futures::future::TryFutureExt;
15use rand::Rng;
16use tokio::sync::mpsc::error::SendError;
17use tokio::sync::OnceCell;
18use tokio::time::Instant;
19
20use crate::common::QueryError;
21use crate::net::ClientOptions;
22use crate::reply::Reply;
23use crate::request::{Request, RequestBody};
24
25#[cfg(unix)]
26use std::os::unix::ffi::OsStrExt;
27#[cfg(unix)]
28use std::os::unix::fs::PermissionsExt;
29#[cfg(unix)]
30use tokio::net::UnixDatagram;
31
/// Unix datagram socket bound to a randomly named path under `/var/run/chrony/`
/// and connected to `/var/run/chrony/chronyd.sock` (see [`UnixDatagramClient::new`]).
///
/// The bound socket file is shut down and removed when the client is dropped.
#[cfg(unix)]
#[derive(Debug)]
pub struct UnixDatagramClient(UnixDatagram);
35
36#[cfg(unix)]
37impl AsRef<UnixDatagram> for UnixDatagramClient {
38 fn as_ref(&self) -> &UnixDatagram {
39 &self.0
40 }
41}
42
43#[cfg(unix)]
44impl AsMut<UnixDatagram> for UnixDatagramClient {
45 fn as_mut(&mut self) -> &mut UnixDatagram {
46 &mut self.0
47 }
48}
49
50#[cfg(unix)]
51impl Deref for UnixDatagramClient {
52 type Target = UnixDatagram;
53 fn deref(&self) -> &UnixDatagram {
54 &self.0
55 }
56}
57
58#[cfg(unix)]
59impl DerefMut for UnixDatagramClient {
60 fn deref_mut(&mut self) -> &mut UnixDatagram {
61 &mut self.0
62 }
63}
64
65#[cfg(unix)]
66impl Drop for UnixDatagramClient {
67 fn drop(&mut self) {
68 if let Ok(addr) = self.0.local_addr() {
69 if let Some(path) = addr.as_pathname() {
70 let _ = self.0.shutdown(std::net::Shutdown::Both);
71 let _ = std::fs::remove_file(path);
72 }
73 }
74 }
75}
76
77#[cfg(unix)]
78impl UnixDatagramClient {
79 pub async fn new() -> std::io::Result<UnixDatagramClient> {
80 let id: [u8; 16] = rand::random();
81 let mut path = b"/var/run/chrony/client-000102030405060708090a0b0c0d0e0f.sock".clone();
82 hex::encode_to_slice(id, &mut path[23..55]).unwrap();
83 let path_str = OsStr::from_bytes(&path);
84 let sock = UnixDatagram::bind(path_str)?;
85 let client = UnixDatagramClient(sock);
86 std::fs::set_permissions(path_str, Permissions::from_mode(0o777))?;
87 client.connect("/var/run/chrony/chronyd.sock")?;
88 Ok(client)
89 }
90
91 pub async fn query(
101 &mut self,
102 request: RequestBody,
103 options: ClientOptions,
104 ) -> Result<Reply, QueryError> {
105 use bytes::BytesMut;
106 let request = Request {
107 sequence: rand::random(),
108 attempt: 0,
109 body: request,
110 };
111
112 let mut send_buf = BytesMut::with_capacity(request.length());
113 request.serialize(&mut send_buf);
114 let mut recv_buf = [0; 1500];
115 let mut attempt = 0;
116
117 while attempt < options.n_tries {
118 self.0.send(&send_buf).await.map_err(QueryError::Send)?;
119 let Ok(io_result) =
120 tokio::time::timeout(options.timeout, self.0.recv(&mut recv_buf)).await
121 else {
122 attempt += 1;
123 continue;
124 };
125 let size = io_result.map_err(QueryError::Recv)?;
126 let mut msg = &recv_buf[..size];
127 let reply = Reply::deserialize(&mut msg)?;
128 if reply.sequence == request.sequence {
129 return Ok(reply);
130 } else {
131 return Err(QueryError::SequenceMismatch {
132 expected: request.sequence,
133 received: reply.sequence,
134 });
135 }
136 }
137 Err(QueryError::Timeout)
138 }
139}
140
/// Destination of a request handed to the client task.
///
/// `Client::query` stores IPv4 destinations as IPv4-mapped IPv6 addresses.
/// The derived `Hash` is fed into the per-destination sequence-number
/// obfuscation in `client_task`.
#[derive(Debug, Hash)]
enum ServerAddr {
    /// UDP server address.
    Udp(SocketAddrV6),
    /// The local chronyd socket, reached through a `UnixDatagramClient`.
    #[cfg(unix)]
    Unix,
}
147
/// Channel on which the client task delivers the outcome of one request.
type ReplySender = tokio::sync::oneshot::Sender<std::io::Result<Reply>>;
/// Receiving half of [`ReplySender`]; wrapped by [`ReplyFuture`].
type ReplyReceiver = tokio::sync::oneshot::Receiver<std::io::Result<Reply>>;
/// A request queued from a `Client` to the client task.
#[derive(Debug)]
struct RequestMeta {
    /// Request payload; the task assigns the sequence number and attempt count.
    body: RequestBody,
    /// Where the reply (or error) for this request is sent.
    reply_sender: ReplySender,
    /// Destination server.
    server: ServerAddr,
}
156
/// Unbounded queue carrying requests from a `Client` to its background task.
type RequestSender = tokio::sync::mpsc::UnboundedSender<RequestMeta>;
/// Task-side end of [`RequestSender`].
type RequestReceiver = tokio::sync::mpsc::UnboundedReceiver<RequestMeta>;
159
/// Server identity used to match incoming replies to in-flight requests.
///
/// Unlike [`ServerAddr`], the UDP variant holds a [`SocketAddr`] so it can be
/// compared directly with the peer address returned by `recv_from`.
#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
enum ServerKey {
    /// UDP peer address.
    Udp(SocketAddr),
    /// The connected Unix datagram socket.
    #[cfg(unix)]
    Unix,
}
166
/// Identifies an in-flight request: the server it was sent to plus the
/// (obfuscated) sequence number that was put on the wire.
#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
struct InflightKey {
    server_key: ServerKey,
    sequence: u32,
}
172
/// State kept by the client task for a request that was sent and awaits a reply.
#[derive(Debug)]
struct InflightValue {
    /// Serialized request; resent on timeout after
    /// `crate::request::increment_attempt` has been applied to it.
    request: Vec<u8>,
    /// Number of timeouts so far (0 right after the initial send).
    attempt: u16,
    /// Where the outcome is delivered.
    reply_sender: ReplySender,
    /// Destination, used when resending.
    server: ServerAddr,
}
180
/// Persistent client that multiplexes requests through one background task.
///
/// Create it with [`Client::spawn`]. Each query is queued to the task, which
/// owns the sockets, matches replies and retransmits on timeout. Dropping the
/// client closes the queue, and the task then fails every request still
/// awaiting a reply.
#[deprecated = "Persistent client overly complicates retry logic"]
#[derive(Debug)]
pub struct Client {
    /// Handle of the spawned `client_task`.
    task_handle: tokio::task::JoinHandle<()>,
    /// Queue into the task.
    sender: RequestSender,
}
188
/// Future resolving to the outcome of a request queued on a `Client`.
///
/// If the client task drops the request without answering, the future
/// resolves to a `ConnectionAborted` I/O error.
#[derive(Debug)]
pub struct ReplyFuture(ReplyReceiver);
192
193impl Future for ReplyFuture {
194 type Output = std::io::Result<Reply>;
195
196 fn poll(
197 self: std::pin::Pin<&mut Self>,
198 cx: &mut std::task::Context<'_>,
199 ) -> std::task::Poll<Self::Output> {
200 let receiver = &mut self.get_mut().0;
201 let mut result = receiver.unwrap_or_else(|e| {
202 Err(std::io::Error::new(
203 std::io::ErrorKind::ConnectionAborted,
204 e,
205 ))
206 });
207 Pin::new(&mut result).poll(cx)
208 }
209}
210
211impl Client {
212 pub fn spawn(handle: &tokio::runtime::Handle, options: crate::net::ClientOptions) -> Client {
215 let (sender, receiver) = tokio::sync::mpsc::unbounded_channel();
216 let task_handle = handle.spawn(client_task(options, receiver));
217 Client {
218 task_handle,
219 sender,
220 }
221 }
222
223 pub fn query(&self, request: RequestBody, server: SocketAddr) -> ReplyFuture {
236 let mapped_server = match server {
237 SocketAddr::V4(v4) => SocketAddrV6::new(v4.ip().to_ipv6_mapped(), v4.port(), 0, 0),
238 SocketAddr::V6(v6) => v6,
239 };
240
241 let (sender, receiver) = tokio::sync::oneshot::channel();
242 if let Err(SendError(request_meta)) = self.sender.send(RequestMeta {
243 body: request,
244 reply_sender: sender,
245 server: ServerAddr::Udp(mapped_server),
246 }) {
247 request_meta
248 .reply_sender
249 .send(Err(std::io::Error::new(
250 std::io::ErrorKind::ConnectionAborted,
251 "Client task unexpectedly shut down",
252 )))
253 .expect("Send failed but the receiver is still in scope?!")
254 }
255
256 ReplyFuture(receiver)
257 }
258
259 #[cfg(unix)]
261 pub fn query_uds(&self, request: RequestBody) -> ReplyFuture {
262 let (sender, receiver) = tokio::sync::oneshot::channel();
263
264 if let Err(SendError(request_meta)) = self.sender.send(RequestMeta {
265 body: request,
266 reply_sender: sender,
267 server: ServerAddr::Unix,
268 }) {
269 request_meta
270 .reply_sender
271 .send(Err(std::io::Error::new(
272 std::io::ErrorKind::ConnectionAborted,
273 "Client task unexpectedly shut down",
274 )))
275 .expect("Send failed but the receiver is still in scope?!")
276 }
277
278 ReplyFuture(receiver)
279 }
280}
281
/// A datagram received by the client task, borrowed from a socket receive buffer.
#[derive(Debug)]
struct ReplyMeta<'a> {
    /// Raw reply bytes.
    reply: &'a [u8],
    /// Server the datagram came from.
    server_key: ServerKey,
}
287
/// The event that woke one iteration of the client task's main loop.
#[derive(Debug)]
enum SelectResult<'a> {
    /// A new request arrived from the `Client`.
    Request(RequestMeta),
    /// A datagram arrived on one of the sockets.
    Reply(ReplyMeta<'a>),
    /// The earliest pending deadline elapsed.
    Timeout,
    /// Receiving on a socket failed.
    Error(std::io::Error),
    /// The request queue was closed (the `Client` was dropped).
    Shutdown,
}
296
297async fn client_task(options: ClientOptions, mut receiver: RequestReceiver) {
298 let mut deadlines: VecDeque<(Instant, InflightKey)> = std::collections::VecDeque::new();
299 let mut inflight: HashMap<InflightKey, InflightValue> = std::collections::HashMap::new();
300
301 let udp_init = || tokio::net::UdpSocket::bind((Ipv6Addr::UNSPECIFIED, 0));
302 let udp_cell: OnceCell<tokio::net::UdpSocket> = OnceCell::new();
303 let mut udp_buf = [0u8; 1500];
304
305 #[cfg(unix)]
306 let uds_init = || UnixDatagramClient::new();
307 #[cfg(unix)]
308 let uds_cell: OnceCell<UnixDatagramClient> = OnceCell::new();
309 #[cfg(unix)]
310 let mut uds_buf = [0u8; 1500];
311
312 let (mut sequence, key0, key1): (u32, u64, u64) = {
313 let mut rng = rand::thread_rng();
314 (rng.gen(), rng.gen(), rng.gen())
315 };
316
317 loop {
318 let now = tokio::time::Instant::now();
319
320 while let Some((deadline, _)) = deadlines.front() {
323 if *deadline > now {
324 break;
325 }
326 let (_, deadline_key) = deadlines.pop_front().unwrap();
329 if let Some((inflight_key, mut inflight_val)) = inflight.remove_entry(&deadline_key) {
330 inflight_val.attempt += 1;
331 if inflight_val.attempt > options.n_tries {
332 let _ = inflight_val.reply_sender.send(Err(std::io::Error::new(
333 std::io::ErrorKind::TimedOut,
334 "request timed out and max retries reached",
335 )));
336 } else {
337 crate::request::increment_attempt(inflight_val.request.as_mut());
338 let send_result = match inflight_val.server {
339 ServerAddr::Udp(addr) => {
342 udp_cell
343 .get()
344 .unwrap()
345 .send_to(inflight_val.request.as_ref(), addr)
346 .await
347 }
348 #[cfg(unix)]
349 ServerAddr::Unix => {
350 uds_cell
351 .get()
352 .unwrap()
353 .send(inflight_val.request.as_ref())
354 .await
355 }
356 };
357 match send_result {
358 Ok(_) => {
359 inflight.insert(inflight_key, inflight_val);
360 let new_deadline = now + options.timeout;
361 deadlines.push_back((new_deadline, deadline_key));
362 }
363 Err(e) => {
364 let _ = inflight_val.reply_sender.send(Err(e));
365 }
366 }
367 }
368 }
369 }
370
371 while let Some((_, inflight_key)) = deadlines.front() {
376 if inflight.contains_key(inflight_key) {
377 break;
378 } else {
379 deadlines.pop_front();
380 }
381 }
382
383 if deadlines.len() >= 2 * inflight.capacity() {
390 deadlines.retain(|(_, inflight_key)| inflight.contains_key(inflight_key))
391 }
392
393 let timeout = async {
394 match deadlines.front() {
395 Some((deadline, _)) => tokio::time::sleep_until(*deadline).await,
396 None => futures::future::pending().await,
397 }
398 };
399
400 let udp_recv = async {
401 match udp_cell.get() {
402 Some(udp) => {
403 let (size, peer) = udp.recv_from(&mut udp_buf).await?;
404 std::io::Result::Ok(ReplyMeta {
405 reply: &udp_buf[0..size],
406 server_key: ServerKey::Udp(peer),
407 })
408 }
409 _ => futures::future::pending().await,
410 }
411 };
412
413 #[cfg(unix)]
414 let uds_recv = async {
415 match uds_cell.get() {
416 Some(uds) => {
417 let size = uds.recv(&mut uds_buf).await?;
418 std::io::Result::Ok(ReplyMeta {
419 reply: &uds_buf[0..size],
420 server_key: ServerKey::Unix,
421 })
422 }
423 _ => futures::future::pending().await,
424 }
425 };
426 #[cfg(not(unix))]
427 let uds_recv = futures::future::pending();
428
429 let select_result = tokio::select! {
430 result = udp_recv => match result {
431 Ok(reply_meta) => {
432 SelectResult::Reply(reply_meta)
433 },
434 Err(e) => SelectResult::Error(e),
435 },
436 result = uds_recv => match result {
437 Ok(reply_meta) => {
438 SelectResult::Reply(reply_meta)
439 },
440 Err(e) => SelectResult::Error(e),
441 },
442 result = receiver.recv() => {
443 match result {
444 Some(request) => SelectResult::Request(request),
445 None => SelectResult::Shutdown,
446 }
447 },
448 _ = timeout => SelectResult::Timeout
449 };
450
451 match select_result {
452 SelectResult::Request(request_meta) => {
453 let mut hasher = siphasher::sip::SipHasher::new_with_keys(key0, key1);
462 request_meta.server.hash(&mut hasher);
463 let obfuscated_sequence = sequence.wrapping_add(hasher.finish() as u32);
464 sequence = sequence.wrapping_add(1);
465
466 let request = Request {
467 sequence: obfuscated_sequence,
468 attempt: 0,
469 body: request_meta.body,
470 };
471 let mut send_buf = Vec::with_capacity(request.length());
472 request.serialize(&mut send_buf);
473
474 let inflight_key = InflightKey {
475 server_key: match request_meta.server {
476 ServerAddr::Udp(addr) => ServerKey::Udp(addr.into()),
477 #[cfg(unix)]
478 ServerAddr::Unix => ServerKey::Unix,
479 },
480 sequence: obfuscated_sequence,
481 };
482
483 let inflight_val = InflightValue {
484 request: send_buf,
485 attempt: 0,
486 reply_sender: request_meta.reply_sender,
487 server: request_meta.server,
488 };
489
490 let deadline = now + options.timeout;
491
492 match inflight_val.server {
493 ServerAddr::Udp(addr) => match udp_cell.get_or_try_init(udp_init).await {
494 Ok(udp) => {
495 if let Err(e) = udp.send_to(inflight_val.request.as_ref(), addr).await {
496 let _ = inflight_val.reply_sender.send(Err(e));
497 continue;
498 }
499 }
500 Err(e) => {
501 let _ = inflight_val.reply_sender.send(Err(e));
502 continue;
503 }
504 },
505 #[cfg(unix)]
506 ServerAddr::Unix => match uds_cell.get_or_try_init(uds_init).await {
507 Ok(uds) => {
508 if let Err(e) = uds.send(inflight_val.request.as_ref()).await {
509 let _ = inflight_val.reply_sender.send(Err(e));
510 continue;
511 }
512 }
513 Err(e) => {
514 let _ = inflight_val.reply_sender.send(Err(e));
515 continue;
516 }
517 },
518 }
519
520 deadlines.push_back((deadline, inflight_key.clone()));
521 inflight.insert(inflight_key.clone(), inflight_val);
522 }
523 SelectResult::Reply(reply_meta) => {
524 let mut reply_buf = reply_meta.reply;
525 if let Ok(reply) = Reply::deserialize(&mut reply_buf) {
526 let inflight_key = InflightKey {
527 server_key: reply_meta.server_key,
528 sequence: reply.sequence,
529 };
530 if let Some(inflight_val) = inflight.remove(&inflight_key) {
531 let _ = inflight_val.reply_sender.send(Ok(reply));
532 }
533 }
534 }
535 SelectResult::Timeout => {}
536 SelectResult::Error(e) => {
537 if e.kind() == std::io::ErrorKind::Interrupted {
538 continue;
539 }
540
541 let erc = Arc::new(e);
544 receiver.close();
545 while let Some(request) = receiver.recv().await {
548 let _ = request
549 .reply_sender
550 .send(Err(std::io::Error::new(erc.kind(), erc.clone())));
551 }
552 for v in inflight.into_values() {
554 let _ = v
555 .reply_sender
556 .send(Err(std::io::Error::new(erc.kind(), erc.clone())));
557 }
558 return;
559 }
560 SelectResult::Shutdown => {
561 for v in inflight.into_values() {
562 let _ = v.reply_sender.send(Err(std::io::Error::new(
563 std::io::ErrorKind::ConnectionAborted,
564 "Client dropped before arrival of reply",
565 )));
566 }
567 return;
568 }
569 }
570 }
571}
572
573#[cfg(unix)]
578pub async fn query_uds(request: RequestBody, options: ClientOptions) -> std::io::Result<Reply> {
579 let mut client = UnixDatagramClient::new().await?;
580 client.query(request, options).await.map_err(QueryError::into_io)
581}