Skip to main content

mdns_proto/
cache.rs

1//! Passive record cache observed from incoming traffic.
2
3use core::time::Duration;
4
5use crate::{
6  Instant, Name, Pool,
7  backend::RdataBuf,
8  trace::*,
9  wire::{ResourceClass, ResourceType},
10};
11
12/// One cached resource record.
13
14#[derive(Debug, Clone)]
15pub struct CacheEntry<I: Instant> {
16  name: Name,
17  rtype: ResourceType,
18  /// cache key includes ResourceClass so a non-IN-class record
19  /// cannot dedupe with, evict, or count as an IN-class record.  Without
20  /// this, a malformed or hostile response with the same `(name, rtype)`
21  /// but class != IN could corrupt the cache across protocol identity
22  /// boundaries.
23  rclass: ResourceClass,
24  rdata: RdataBuf,
25  expires_at: I,
26  /// When this record was last received / refreshed.  used to
27  /// implement the RFC 6762 §10.2 "1-second grace" on cache-flush —
28  /// an incoming cache-flush only affects siblings whose
29  /// `now - received_at >= 1 second`, so a multi-address RRSet
30  /// announced across two back-to-back packets is not collapsed.
31  received_at: I,
32}
33
34impl<I: Instant> CacheEntry<I> {
35  /// Build a new cache entry.  `received_at` is the wall instant at
36  /// which this record arrived; `expires_at` is the TTL-derived
37  /// future deadline.
38  pub(crate) fn new(
39    name: Name,
40    rtype: ResourceType,
41    rclass: ResourceClass,
42    rdata: RdataBuf,
43    expires_at: I,
44    received_at: I,
45  ) -> Self {
46    Self {
47      name,
48      rtype,
49      rclass,
50      rdata,
51      expires_at,
52      received_at,
53    }
54  }
55
56  /// The record's name.
57  #[inline(always)]
58  pub fn name(&self) -> &Name {
59    &self.name
60  }
61
62  /// The record's type.
63  #[inline(always)]
64  pub const fn rtype(&self) -> ResourceType {
65    self.rtype
66  }
67
68  /// The record's class.
69  #[inline(always)]
70  pub const fn rclass(&self) -> ResourceClass {
71    self.rclass
72  }
73
74  /// The record's raw rdata bytes.
75  #[inline(always)]
76  pub fn rdata_slice(&self) -> &[u8] {
77    self.rdata.as_ref()
78  }
79
80  /// Absolute expiration deadline.
81  #[inline(always)]
82  pub fn expires_at(&self) -> I {
83    self.expires_at
84  }
85
86  /// Wall instant at which this record was last received / refreshed.
87  #[inline(always)]
88  pub fn received_at(&self) -> I {
89    self.received_at
90  }
91}
92
/// Entry cap applied by `Cache::new` (and therefore `Cache::default`) when no
/// explicit maximum is given; once reached, inserts evict first.
const DEFAULT_MAX_ENTRIES: usize = 1_024;
95
/// Passive record cache populated from observed traffic.
///
/// Records are stored in the backing pool `P`, and at most `max_entries` of
/// them are kept at once.
pub struct Cache<I, P> {
  /// Backing storage holding the cached records.
  entries: P,
  /// Upper bound on the number of stored entries; `0` disables caching.
  max_entries: usize,
  /// Ties the instant type `I` to the cache without storing a value of it.
  _phantom: core::marker::PhantomData<I>,
  /// Endpoint-wide statistics handle, `None` until one is attached.
  #[cfg(feature = "stats")]
  stats: Option<std::sync::Arc<hick_trace::stats::Stats>>,
}
104
105impl<I, P> Cache<I, P>
106where
107  I: Instant,
108  P: Pool<CacheEntry<I>>,
109{
110  /// Empty cache with the default maximum entry cap (1024).
111  pub fn new() -> Self {
112    Self {
113      entries: P::new(),
114      max_entries: DEFAULT_MAX_ENTRIES,
115      _phantom: core::marker::PhantomData,
116      #[cfg(feature = "stats")]
117      stats: None,
118    }
119  }
120
121  /// Empty cache with a custom maximum entry cap.
122  ///
123  /// When `try_insert` is called and the number of stored entries has reached
124  /// `max`, the soonest-expiring entry is evicted proactively before the new
125  /// entry is inserted. This bounds memory usage even when the backing
126  /// [`Pool`] grows without error (e.g. `slab::Slab`).
127  pub fn with_max_entries(max: usize) -> Self {
128    Self {
129      entries: P::new(),
130      max_entries: max,
131      _phantom: core::marker::PhantomData,
132      #[cfg(feature = "stats")]
133      stats: None,
134    }
135  }
136
137  /// Attach the shared [`hick_trace::stats::Stats`] handle from the owning
138  /// [`crate::endpoint::Endpoint`]. No allocation — the Arc is cloned from the
139  /// endpoint's existing single Arc. Called immediately after construction by
140  /// the `Endpoint` so that all per-cache counters accumulate into the
141  /// endpoint-level stats. Before this is called, stats bumps are no-ops
142  /// (the field is `None`).
143  #[cfg(feature = "stats")]
144  pub(crate) fn set_stats(&mut self, stats: std::sync::Arc<hick_trace::stats::Stats>) {
145    self.stats = Some(stats);
146  }
147
148  /// Borrow the stats handle if one has been attached.
149  #[cfg(feature = "stats")]
150  #[inline]
151  fn stat(&self) -> Option<&hick_trace::stats::Stats> {
152    self.stats.as_deref()
153  }
154
155  /// The configured maximum number of entries.
156  #[inline(always)]
157  pub const fn max_entries(&self) -> usize {
158    self.max_entries
159  }
160
161  /// The current number of cached entries.
162  #[inline(always)]
163  pub fn len(&self) -> usize {
164    self.entries.len()
165  }
166
167  /// Whether the cache is empty.
168  #[inline(always)]
169  pub fn is_empty(&self) -> bool {
170    self.entries.len() == 0
171  }
172
173  /// Insert (or update / remove) a record observation.
174  ///
175  /// Semantics:
176  /// - If `ttl == 0`, treat as "record going away" (RFC 6762 §10.1): clamp the
177  ///   matching `(name, rtype, rclass, rdata)` entry's `expires_at` to one
178  ///   second out (rescue window, never extending a sooner expiry) and return
179  ///   `Ok(None)` — NOT an immediate delete.
180  /// - If `cache_flush == true` (RFC 6762 §10.2), DEFER eviction by 1 second:
181  ///   clamp the `expires_at` of every existing sibling matching
182  ///   `(name, rtype, rclass)` (and not the new rdata) to `min(current,
183  ///   now + 1s)`.  This gives a refresh burst time to re-announce
184  ///   missing siblings before they disappear from the cache.
185  /// - If a matching `(name, rtype, rclass, rdata)` entry already exists,
186  ///   refresh its expiration in place and return `Ok(Some(key))`
187  ///   (deduplication).
188  /// - Otherwise insert a new entry.  If the pool is full, evict the
189  ///   soonest-expiring entry first (best-effort) then retry.  If the
190  ///   retry still fails the error is propagated.
191  ///
192  /// cache identity is `(name, rtype, rclass, rdata)`.  A non-IN
193  /// class record cannot dedupe with, evict, or count as an IN record.
194  #[allow(clippy::too_many_arguments)]
195  pub fn try_insert(
196    &mut self,
197    name: Name,
198    rtype: ResourceType,
199    rclass: ResourceClass,
200    rdata: impl Into<RdataBuf>,
201    ttl: Duration,
202    now: I,
203    cache_flush: bool,
204  ) -> Result<Option<usize>, P::Error> {
205    // max_entries == 0 means caching is disabled.  Honour it on every
206    // insert path (including cache_flush, which would otherwise insert a fresh
207    // entry after evicting matching ones).  Returning Ok(None) keeps the
208    // existing "no entry was inserted" semantic (same as the TTL=0 branch).
209    if self.max_entries == 0 {
210      // Still honour TTL=0 removals so a zero-cap cache stays consistent if a
211      // caller is shrinking max_entries dynamically — but no entry to remove
212      // here (the cache is empty by construction), so just bail.
213      return Ok(None);
214    }
215    let rdata: RdataBuf = rdata.into();
216    // TTL=0 → goodbye (RFC 6762 §10.1). Do NOT delete immediately: shorten the
217    // matching entry to expire in ONE SECOND. This gives any responder still
218    // using the record a window to rescue it (a positive-TTL re-announce before
219    // then refreshes it via the dedup path below), and bounds the disruption of
220    // an accidental or malicious on-link goodbye to a brief disappearance window
221    // rather than instant deletion. Only ever SHORTENS — never extends a sooner
222    // natural expiry. (mirrors the cache-flush deferred-expiry below.)
223    if ttl == Duration::ZERO {
224      if let Some(deadline) = now.checked_add_duration(Duration::from_secs(1)) {
225        let mut victim: Option<usize> = None;
226        for (key, entry) in self.entries.iter() {
227          if entry.rtype() == rtype
228            && entry.rclass() == rclass
229            && entry.name().as_str() == name.as_str()
230            && entry.rdata_slice() == rdata.as_ref()
231          {
232            victim = Some(key);
233            break;
234          }
235        }
236        if let Some(key) = victim
237          && let Some(entry) = self.entries.get_mut(key)
238          && entry.expires_at() > deadline
239        {
240          entry.expires_at = deadline;
241        }
242      }
243      return Ok(None);
244    }
245
246    // cache_flush=true → RFC 6762 §10.2: the sender is authoritative for
247    // records of this (name, rtype, rclass).  This implements the
248    // RFC-specified DEFERRED expiry: instead of immediately removing
249    // matching siblings, clamp their `expires_at` to `min(current,
250    // now + 1s)`.  Behaviour:
251    //   * Refresh bursts that re-announce missing siblings within 1s
252    //     update those siblings' received_at/expires_at via the dedup
253    //     path below — the clamp is undone.
254    //   * Siblings that are NOT re-announced expire naturally one
255    //     second later via the normal TTL sweep — callers have a 1s
256    //     window to observe them before they vanish.
257    //
258    // the "skip recent siblings" semantics still apply: entries
259    // received within the last second are left untouched (no clamp).
260    if cache_flush {
261      let one_sec_from_now = now.checked_add_duration(Duration::from_secs(1));
262      if let Some(deadline) = one_sec_from_now {
263        // Collect first to avoid mutable-while-iterating problems.
264        let mut to_clamp: std::vec::Vec<usize> = std::vec::Vec::new();
265        for (key, entry) in self.entries.iter() {
266          if entry.rtype() != rtype
267            || entry.rclass() != rclass
268            || entry.name().as_str() != name.as_str()
269          {
270            continue;
271          }
272          // grace: do not touch entries received within the last second.
273          let age = now.checked_duration_since(entry.received_at());
274          let recent = match age {
275            Some(d) => d < Duration::from_secs(1),
276            None => true, // received_at in the future — treat as recent
277          };
278          if recent || entry.rdata_slice() == rdata.as_ref() {
279            continue;
280          }
281          // Only clamp if it would shorten the deadline.
282          if entry.expires_at() > deadline {
283            to_clamp.push(key);
284          }
285        }
286        for key in to_clamp {
287          if let Some(entry) = self.entries.get_mut(key) {
288            entry.expires_at = deadline;
289          }
290        }
291      }
292      // Fall through to the dedup/insert path: the new record either
293      // refreshes an existing copy of itself or inserts fresh.
294    }
295
296    let expires_at = now.checked_add_duration(ttl).unwrap_or(now);
297
298    // Deduplicate: refresh the expiration of an existing matching entry.
299    let mut update_key: Option<usize> = None;
300    for (key, entry) in self.entries.iter() {
301      if entry.rtype() == rtype
302        && entry.rclass() == rclass
303        && entry.name().as_str() == name.as_str()
304        && entry.rdata_slice() == rdata.as_ref()
305      {
306        update_key = Some(key);
307        break;
308      }
309    }
310    if let Some(key) = update_key {
311      if let Some(entry) = self.entries.get_mut(key) {
312        entry.expires_at = expires_at;
313        entry.received_at = now;
314      }
315      trace!(
316        target: "mdns_proto::cache",
317        rtype = ?rtype,
318        "cache: refreshed existing entry (dedup)"
319      );
320      #[cfg(feature = "stats")]
321      if let Some(s) = self.stat() {
322        s.cache_refreshes(1);
323        s.set_cache_size(self.entries.len() as u64);
324      }
325      return Ok(Some(key));
326    }
327
328    // Insert through the bounded helper (proactive cap eviction + reactive retry).
329    let result = self
330      .bounded_insert(CacheEntry::new(name, rtype, rclass, rdata, expires_at, now))
331      .map(Some);
332    if result.is_ok() {
333      trace!(
334        target: "mdns_proto::cache",
335        rtype = ?rtype,
336        "cache: inserted new entry"
337      );
338      #[cfg(feature = "stats")]
339      if let Some(s) = self.stat() {
340        s.cache_inserts(1);
341        s.set_cache_size(self.entries.len() as u64);
342      }
343    }
344    result
345  }
346
347  /// Insert `entry` into the backing pool while respecting `max_entries`.
348  ///
349  /// Algorithm:
350  /// 1. Proactive eviction: if the pool is at or above `max_entries`, evict the
351  ///    soonest-expiring entry BEFORE attempting the insert.  This bounds memory
352  ///    usage even when the backing pool is infallible (e.g. `slab::Slab`).
353  /// 2. Attempt the insert.
354  /// 3. Reactive eviction + retry: if the pool returns a capacity error (e.g.
355  ///    `heapless` fixed-size collections), evict the soonest-expiring entry
356  ///    and retry once.
357  fn bounded_insert(&mut self, entry: CacheEntry<I>) -> Result<usize, P::Error> {
358    // Step 1: proactive eviction when at or above the cap.
359    if self.entries.len() >= self.max_entries {
360      let mut victim: Option<(usize, I)> = None;
361      for (key, e) in self.entries.iter() {
362        let exp = e.expires_at();
363        if !matches!(victim, Some((_, prev_exp)) if prev_exp <= exp) {
364          victim = Some((key, exp));
365        }
366      }
367      if let Some((vk, _)) = victim {
368        self.entries.try_remove(vk);
369        trace!(
370          target: "mdns_proto::cache",
371          "cache: proactive eviction (cap reached)"
372        );
373        #[cfg(feature = "stats")]
374        if let Some(s) = self.stat() {
375          s.cache_evictions(1);
376        }
377      }
378    }
379
380    // Step 2: attempt insert.
381    match self.entries.insert(entry.clone()) {
382      Ok(k) => Ok(k),
383      Err(_) => {
384        // Step 3: reactive eviction + single retry.
385        let mut victim: Option<(usize, I)> = None;
386        for (key, e) in self.entries.iter() {
387          let exp = e.expires_at();
388          if !matches!(victim, Some((_, prev_exp)) if prev_exp <= exp) {
389            victim = Some((key, exp));
390          }
391        }
392        if let Some((vk, _)) = victim {
393          self.entries.try_remove(vk);
394          trace!(
395            target: "mdns_proto::cache",
396            "cache: reactive eviction (pool capacity error)"
397          );
398          #[cfg(feature = "stats")]
399          if let Some(s) = self.stat() {
400            s.cache_evictions(1);
401          }
402        }
403        self.entries.insert(entry)
404      }
405    }
406  }
407
408  /// Sweep expired entries, returning how many were removed.
409  pub fn sweep_expired(&mut self, now: I) -> usize {
410    let mut to_remove: std::vec::Vec<usize> = std::vec::Vec::new();
411    for (key, entry) in self.entries.iter() {
412      if entry.expires_at() <= now {
413        to_remove.push(key);
414      }
415    }
416    let mut removed = 0usize;
417    for key in to_remove {
418      if self.entries.try_remove(key).is_some() {
419        removed = removed.saturating_add(1);
420      }
421    }
422    if removed > 0 {
423      trace!(
424        target: "mdns_proto::cache",
425        removed,
426        "cache: swept expired entries"
427      );
428      #[cfg(feature = "stats")]
429      if let Some(s) = self.stat() {
430        #[allow(clippy::cast_possible_truncation)]
431        s.cache_expirations(removed as u64);
432        s.set_cache_size(self.entries.len() as u64);
433      }
434    }
435    removed
436  }
437
438  /// Next deadline (soonest expiration), if any.
439  pub fn next_expiration(&self) -> Option<I> {
440    let mut best: Option<I> = None;
441    for (_, entry) in self.entries.iter() {
442      let exp = entry.expires_at();
443      best = Some(match best {
444        Some(prev) if prev < exp => prev,
445        _ => exp,
446      });
447    }
448    best
449  }
450
451  /// Look up whether the cache contains a record matching
452  /// `(name, rtype, rclass)`.  class is part of the cache key.
453  pub fn contains(&self, name: &Name, rtype: ResourceType, rclass: ResourceClass) -> bool {
454    self.entries.iter().any(|(_, e)| {
455      e.rtype() == rtype && e.rclass() == rclass && e.name().as_str() == name.as_str()
456    })
457  }
458
459  /// Count the number of cached entries matching `(name, rtype, rclass)`.
460  ///
461  /// Multiple distinct records can share `(name, rtype, rclass)` (e.g. a
462  /// multi-homed host with several A records), so a single `contains`
463  /// check cannot tell you whether the full RRSet landed.  Use this for
464  /// RRSet-coherency checks.
465  pub fn count_matching(&self, name: &Name, rtype: ResourceType, rclass: ResourceClass) -> usize {
466    self
467      .entries
468      .iter()
469      .filter(|(_, e)| {
470        e.rtype() == rtype && e.rclass() == rclass && e.name().as_str() == name.as_str()
471      })
472      .count()
473  }
474}
475
476impl<I, P> Default for Cache<I, P>
477where
478  I: Instant,
479  P: Pool<CacheEntry<I>>,
480{
481  fn default() -> Self {
482    Self::new()
483  }
484}
485
// Unit tests live in a sibling `tests` module and need both the `std` and
// `slab` features (presumably for the `slab::Slab` pool backend — confirm).
// `unwrap`/`expect` are allowed there because a panic is the intended
// failure signal inside a test.
#[allow(clippy::unwrap_used, clippy::expect_used)]
#[cfg(all(test, feature = "std", feature = "slab"))]
mod tests;