Skip to main content

vyre_driver/
routing.rs

//! Runtime distribution-aware algorithm routing.
//!
//! `routing` records lightweight input-distribution summaries per call site
//! and, from each summary, chooses one of several algorithm variants that all
//! produce byte-identical output. The first users are sort-like ops, where
//! tiny inputs prefer insertion sort and skewed large inputs prefer
//! radix-style passes.
7
8/// Profile-guided backend routing table and cert-gate latency measurement.
9pub mod pgo;
10
11use dashmap::DashMap;
12use std::borrow::Cow;
13use std::sync::Arc;
14
/// Sort algorithm variants with identical output contracts.
///
/// Every variant yields the same sorted output; they differ only in cost
/// profile. `Hash` is derived alongside `Eq` so backends can be used as keys
/// in maps and sets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum SortBackend {
    /// Insertion sort: minimal fixed overhead for small inputs.
    InsertionSort,
    /// Radix sort: stable throughput for skewed integer distributions.
    RadixSort,
    /// Bitonic sort: GPU-friendly general-purpose fallback.
    BitonicSort,
}
26
/// Observed input distribution for one call.
///
/// A cheap, allocation-free summary of a `u32` input: its length, a count of
/// distinct values, and the length of the longest run of equal adjacent
/// values. Routing decisions are made from this summary alone.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Distribution {
    /// Number of observed values.
    len: usize,
    /// Distinct-value count; saturates to `len` once more than 512 distinct
    /// values have been seen (the exact count is not tracked past that).
    unique: usize,
    /// Longest run of equal adjacent values (0 for empty input).
    max_run: usize,
}

impl Distribution {
    /// Build a distribution summary from u32 inputs.
    ///
    /// Single pass, no heap allocation. Distinct values are tracked in a
    /// fixed inline set of 512 entries; if the input contains more distinct
    /// values than that, `unique` is reported as `values.len()`, i.e. the
    /// input is treated as having no duplicate skew.
    #[must_use]
    pub fn observe(values: &[u32]) -> Self {
        let (&first, rest) = match values.split_first() {
            Some(split) => split,
            None => {
                return Self {
                    len: 0,
                    unique: 0,
                    max_run: 0,
                }
            }
        };

        let mut seen = FixedUniqueU32::default();
        seen.observe(first);
        let mut previous = first;
        let mut max_run = 1usize;
        let mut current_run = 1usize;
        for &value in rest {
            if value == previous {
                current_run += 1;
                max_run = max_run.max(current_run);
            } else {
                // A repeat of the previous value is already recorded (or the
                // set has overflowed), so only a change of value can add a new
                // distinct entry; this skips the membership scan for runs.
                seen.observe(value);
                previous = value;
                current_run = 1;
            }
        }

        Self {
            len: values.len(),
            unique: seen.unique_len(values.len()),
            max_run,
        }
    }

    /// Number of values in the observed input.
    #[must_use]
    pub fn len(self) -> usize {
        self.len
    }

    /// Whether the observed input was empty.
    #[must_use]
    pub fn is_empty(self) -> bool {
        self.len == 0
    }

    /// Distinct-value count, saturated to [`Distribution::len`] when the input
    /// had more than 512 distinct values.
    #[must_use]
    pub fn unique(self) -> usize {
        self.unique
    }

    /// Length of the longest run of equal adjacent values (0 for empty input).
    #[must_use]
    pub fn max_run(self) -> usize {
        self.max_run
    }

    /// Fraction of duplicate values, `1 - unique / len`; `0.0` for empty input.
    #[must_use]
    fn skew_ratio(self) -> f32 {
        if self.len == 0 {
            return 0.0;
        }
        1.0 - (self.unique as f32 / self.len as f32)
    }
}

/// Capacity of the allocation-free distinct-value set used by
/// `Distribution::observe`.
const INLINE_UNIQUE_CAP: usize = 512;

/// Bounded, allocation-free set of distinct `u32` values.
///
/// Holds up to `INLINE_UNIQUE_CAP` distinct values inline. The first new value
/// beyond that capacity flips the set into an overflowed state, after which
/// the exact distinct count is no longer known.
struct FixedUniqueU32 {
    values: [u32; INLINE_UNIQUE_CAP],
    len: usize,
    overflowed: bool,
}

impl Default for FixedUniqueU32 {
    fn default() -> Self {
        Self {
            values: [0; INLINE_UNIQUE_CAP],
            len: 0,
            overflowed: false,
        }
    }
}

impl FixedUniqueU32 {
    /// Record `value`, ignoring values already present.
    fn observe(&mut self, value: u32) {
        // Once overflowed, `unique_len` no longer depends on the stored
        // values, so skip the O(capacity) membership scan entirely.
        if self.overflowed {
            return;
        }
        if self.values[..self.len].contains(&value) {
            return;
        }
        if self.len == INLINE_UNIQUE_CAP {
            self.overflowed = true;
            return;
        }
        self.values[self.len] = value;
        self.len += 1;
    }

    /// Distinct values seen, or `input_len` if the set overflowed.
    fn unique_len(&self, input_len: usize) -> usize {
        if self.overflowed {
            input_len
        } else {
            self.len
        }
    }
}
114
115/// Per-call-site profile used by routing decisions.
116#[derive(Debug, Default)]
117pub struct RoutingTable {
118    profiles: DashMap<Arc<str>, Distribution>,
119}
120
121impl RoutingTable {
122    /// Record one call-site observation and return the selected backend.
123    ///
124    /// # Errors
125    ///
126    /// Returns an error if the routing table mutex is poisoned.
127    pub fn observe_sort_u32(
128        &self,
129        call_site: Cow<'_, str>,
130        values: &[u32],
131    ) -> Result<SortBackend, String> {
132        let distribution = Distribution::observe(values);
133        let key = match call_site {
134            Cow::Borrowed(value) => Arc::<str>::from(value),
135            Cow::Owned(value) => Arc::<str>::from(value.into_boxed_str()),
136        };
137        self.profiles.insert(key, distribution);
138        Ok(select_sort_backend(distribution))
139    }
140
141    /// Return the last observed distribution for a call site.
142    #[must_use]
143    pub fn distribution(&self, call_site: &str) -> Option<Distribution> {
144        self.profiles.get(call_site).map(|profile| *profile)
145    }
146}
147
148/// Pick a sort backend from a distribution summary.
149#[must_use]
150pub fn select_sort_backend(distribution: Distribution) -> SortBackend {
151    if distribution.len <= 32 {
152        SortBackend::InsertionSort
153    } else if distribution.skew_ratio() >= 0.75 || distribution.max_run >= 16 {
154        SortBackend::RadixSort
155    } else {
156        SortBackend::BitonicSort
157    }
158}
159
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn skewed_input_picks_radix_sort() {
        let table = RoutingTable::default();
        let mut values = vec![7u32; 240];
        values.extend(0..16);
        let selected = table
            .observe_sort_u32(Cow::Borrowed("sort.callsite.skewed"), &values)
            .expect("Fix: routing profile should record");
        assert_eq!(selected, SortBackend::RadixSort);
    }

    #[test]
    fn small_input_picks_insertion_sort() {
        let table = RoutingTable::default();
        let selected = table
            .observe_sort_u32(Cow::Borrowed("sort.callsite.small"), &[4, 1, 3, 2])
            .expect("Fix: routing profile should record");
        assert_eq!(selected, SortBackend::InsertionSort);
    }

    #[test]
    fn empty_input_picks_insertion_sort() {
        let table = RoutingTable::default();
        let selected = table
            .observe_sort_u32(Cow::Borrowed("sort.callsite.empty"), &[])
            .expect("Fix: routing profile should record");
        assert_eq!(selected, SortBackend::InsertionSort);
        let profile = table
            .distribution("sort.callsite.empty")
            .expect("Fix: profile retained");
        assert_eq!((profile.len, profile.unique, profile.max_run), (0, 0, 0));
    }

    #[test]
    fn spread_input_picks_bitonic_sort() {
        let table = RoutingTable::default();
        let values: Vec<u32> = (0..64).rev().collect();
        let selected = table
            .observe_sort_u32(Cow::Borrowed("sort.callsite.spread"), &values)
            .expect("Fix: routing profile should record");
        assert_eq!(selected, SortBackend::BitonicSort);
    }

    #[test]
    fn long_run_picks_radix_sort_even_when_mostly_unique() {
        // 49 distinct values out of 64 (skew ~0.23) but a 16-long run.
        let mut values: Vec<u32> = (0..48).collect();
        values.extend_from_slice(&[1000; 16]);
        let distribution = Distribution::observe(&values);
        assert_eq!(distribution.unique, 49);
        assert_eq!(distribution.max_run, 16);
        assert_eq!(select_sort_backend(distribution), SortBackend::RadixSort);
    }

    #[test]
    fn high_cardinality_input_saturates_unique_count() {
        let values: Vec<u32> = (0..600).collect();
        let distribution = Distribution::observe(&values);
        assert_eq!(distribution.unique, distribution.len);
        assert_eq!(select_sort_backend(distribution), SortBackend::BitonicSort);
    }

    #[test]
    fn call_site_profile_tracks_latest_observation() {
        let table = RoutingTable::default();
        assert_eq!(
            table
                .observe_sort_u32(Cow::Borrowed("op.sort"), &[8, 3, 1])
                .expect("Fix: routing profile should record"),
            SortBackend::InsertionSort
        );
        assert_eq!(
            table
                .observe_sort_u32(Cow::Owned(String::from("op.sort")), &[42; 128])
                .expect("Fix: routing profile should record"),
            SortBackend::RadixSort
        );
        assert_eq!(
            table
                .distribution("op.sort")
                .expect("Fix: profile retained")
                .len,
            128
        );
        assert!(table.distribution("op.missing").is_none());
    }
}