// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use crate::types::BitWidth;
5
// ─────────────────────────────────────────────────────────────────────────────
// Public config.
// ─────────────────────────────────────────────────────────────────────────────
9
10/// Code width: the maximum dictionary size is `2^bits`. Validated to `9..=16`
11/// at construction, so a [`Bits`] always holds an in-range value.
12#[derive(Copy, Clone, Debug, PartialEq, Eq)]
13pub struct Bits(u8);
14
15impl Bits {
16    /// Construct a [`Bits`], returning [`Error::InvalidArg`] unless
17    /// `value` is in `9..=16`.
18    pub const fn new(value: u8) -> Result<Self, Error> {
19        if 9 <= value && value <= 16 {
20            Ok(Self(value))
21        } else {
22            Err(Error::InvalidArg)
23        }
24    }
25
26    /// The validated code width, in `9..=16`.
27    pub const fn value(self) -> u8 {
28        self.0
29    }
30}
31
32impl TryFrom<u8> for Bits {
33    type Error = Error;
34    fn try_from(value: u8) -> Result<Self, Error> {
35        Self::new(value)
36    }
37}
38
39/// Dynamic-threshold sample fraction. Validated to `(0.0, 1.0]` at
40/// construction, so a [`Threshold`] always holds an in-range value.
41#[derive(Copy, Clone, Debug, PartialEq)]
42pub struct Threshold(f64);
43
44impl Threshold {
45    /// Construct a [`Threshold`], returning [`Error::InvalidArg`] unless
46    /// `value` is in `(0.0, 1.0]`.
47    pub const fn new(value: f64) -> Result<Self, Error> {
48        if value > 0.0 && value <= 1.0 {
49            Ok(Self(value))
50        } else {
51            Err(Error::InvalidArg)
52        }
53    }
54
55    /// The validated sample fraction, in `(0.0, 1.0]`.
56    pub const fn value(self) -> f64 {
57        self.0
58    }
59}
60
61impl TryFrom<f64> for Threshold {
62    type Error = Error;
63    fn try_from(value: f64) -> Result<Self, Error> {
64        Self::new(value)
65    }
66}
67
68/// Training configuration. See [`DEFAULT_CONFIG`] for a reasonable starting
69/// point.
70#[derive(Copy, Clone, Debug)]
71pub struct Config {
72    /// Code width; see [`Bits`].
73    pub bits: Bits,
74    /// Dynamic-threshold sample fraction; see [`Threshold`].
75    pub threshold: Threshold,
76    /// RNG seed for sampling; `None` means non-deterministic.
77    pub seed: Option<u64>,
78}
79
80/// Reasonable starting point: 12-bit codes, dynamic threshold sampling 20 %.
81pub const DEFAULT_CONFIG: Config = Config {
82    bits: match Bits::new(12) {
83        Ok(b) => b,
84        Err(_) => unreachable!(),
85    },
86    threshold: match Threshold::new(0.2) {
87        Ok(t) => t,
88        Err(_) => unreachable!(),
89    },
90    seed: None,
91};
92
93impl Default for Config {
94    fn default() -> Self {
95        DEFAULT_CONFIG
96    }
97}
98
// ─────────────────────────────────────────────────────────────────────────────
// Error type (currently a single variant).
// ─────────────────────────────────────────────────────────────────────────────
102
/// Error returned by the public training and encoding API.
///
/// The enum currently has a single variant.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum Error {
    /// A configuration value or input buffer was out of range or malformed.
    InvalidArg,
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match *self {
            Self::InvalidArg => "onpair: invalid argument",
        };
        f.write_str(message)
    }
}

/// Lets [`Error`] be used as a `dyn std::error::Error`. It keeps the default
/// `source()`, which reports no underlying cause.
impl std::error::Error for Error {}
119
// ─────────────────────────────────────────────────────────────────────────────
// Internal training config (crate-private). Richer than the public `Config` so
// unit tests can still drive fixed-threshold training.
// ─────────────────────────────────────────────────────────────────────────────
124
/// Fixed training threshold (crate-internal).
///
/// NOTE(review): `value` is presumably the count a candidate must reach during
/// training; confirm against the training code.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub(crate) struct FixedThreshold {
    /// The fixed threshold value.
    pub(crate) value: u8,
}
129
/// Dynamic training threshold (crate-internal), parameterised by a sample
/// fraction.
#[derive(Debug, Clone, Copy, PartialEq)]
pub(crate) struct DynamicThreshold {
    /// Sample fraction. The public [`Threshold`] wrapper restricts this to
    /// `(0.0, 1.0]`, but this internal field is a plain `f64` with no check.
    pub(crate) sample_fraction: f64,
}
134
135impl Default for DynamicThreshold {
136    fn default() -> Self {
137        Self {
138            sample_fraction: 0.2,
139        }
140    }
141}
142
143#[derive(Copy, Clone, Debug)]
144#[allow(dead_code)] // `Fixed` is used only in tests
145pub(crate) enum ThresholdSpec {
146    Fixed(FixedThreshold),
147    Dynamic(DynamicThreshold),
148}
149
150impl Default for ThresholdSpec {
151    fn default() -> Self {
152        Self::Dynamic(DynamicThreshold::default())
153    }
154}
155
156#[derive(Clone, Debug)]
157pub(crate) struct TrainingConfig {
158    pub(crate) bits: BitWidth,
159    pub(crate) threshold: ThresholdSpec,
160    pub(crate) seed: Option<u64>,
161}
162
163impl Default for TrainingConfig {
164    fn default() -> Self {
165        Self {
166            bits: 16,
167            threshold: ThresholdSpec::default(),
168            seed: None,
169        }
170    }
171}
172
173impl From<Config> for TrainingConfig {
174    fn from(c: Config) -> Self {
175        Self {
176            bits: c.bits.value(),
177            threshold: ThresholdSpec::Dynamic(DynamicThreshold {
178                sample_fraction: c.threshold.value(),
179            }),
180            seed: c.seed,
181        }
182    }
183}