Skip to main content

fib_quant/
profile.rs

1use serde::{Deserialize, Serialize};
2
3use crate::{digest::json_digest, rotation::ROTATION_ALGORITHM_VERSION, FibQuantError, Result};
4
/// Schema marker stored in `FibQuantProfileV1::schema_version`; it is also passed to
/// `json_digest` by `FibQuantProfileV1::digest`.
pub const PROFILE_SCHEMA: &str = "fib_quant_profile_v1";
/// Maximum ambient dimension accepted by the alpha profile validator.
///
/// Note that full profile validation also bounds `ambient_dim * ambient_dim` by
/// [`MAX_ROTATION_MATRIX_VALUES`], which in practice caps validated profiles at `d <= 4096`.
pub const MAX_AMBIENT_DIM: usize = 16_384;
/// Maximum block dimension accepted by the alpha profile validator.
pub const MAX_BLOCK_DIM: usize = 256;
/// Maximum codebook size accepted by the alpha profile validator (`2^20`).
pub const MAX_CODEBOOK_SIZE: usize = 1 << 20;
/// Maximum Lloyd training samples accepted by the alpha profile validator.
pub const MAX_TRAINING_SAMPLES: u32 = 10_000_000;
/// Maximum number of scalar values in a dense rotation matrix (`4096 * 4096`).
pub const MAX_ROTATION_MATRIX_VALUES: usize = 16_777_216;
/// Maximum number of scalar values in an `N x k` codebook (`2^26`).
///
/// Combined with [`MAX_CODEBOOK_SIZE`] (`2^20`), a maximum-size codebook is limited to `k <= 64`.
pub const MAX_CODEBOOK_VALUES: usize = 67_108_864;
/// Maximum bits in a packed fixed-rate payload.
///
/// This is `2^34` on 64-bit targets. A literal `1 << 34` does not compile when `usize` is
/// 32 bits wide (e.g. wasm32), so on such targets the bound saturates to `usize::MAX`.
/// There, any representable bit count is accepted.
pub const MAX_PACKED_INDEX_BITS: usize = match 1usize.checked_shl(34) {
    Some(bits) => bits,
    None => usize::MAX,
};

/// Absolute tolerance used when comparing stored rate fields with recomputed values.
const RATE_TOLERANCE: f64 = 1.0e-12;
/// Upper bound (inclusive) for `lloyd_restarts`; the lower bound is 1.
const MAX_LLOYD_RESTARTS: u32 = 1_024;
/// Upper bound (inclusive) for `lloyd_iterations`; the lower bound is 1.
const MAX_LLOYD_ITERATIONS: u32 = 100_000;
24
/// Norm payload representation.
///
/// Serialized in snake_case: `"fp16_paper"` / `"f32_reference"`. Only
/// [`NormFormat::Fp16Paper`] passes `FibQuantProfileV1::validate`; the hidden `f32` variant
/// is rejected there. The enum is `#[non_exhaustive]`, so downstream `match`es need a
/// wildcard arm.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
#[serde(rename_all = "snake_case")]
pub enum NormFormat {
    /// Paper path: fp16 scalar norm side header.
    Fp16Paper,
    /// Reference/test path: f32 scalar norm side header (rejected by the paper profile validator).
    #[doc(hidden)]
    F32Reference,
}
36
/// Source used for training samples.
///
/// Serialized in snake_case: `"canonical_spherical_beta"` / `"reference_gaussian_projection"`.
/// `FibQuantProfileV1::validate` accepts only [`SourceMode::CanonicalSphericalBeta`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
#[serde(rename_all = "snake_case")]
pub enum SourceMode {
    /// Direct spherical-Beta sampler.
    CanonicalSphericalBeta,
    /// Normalized Gaussian projection reference sampler (rejected by the paper profile validator).
    #[doc(hidden)]
    ReferenceGaussianProjection,
}
48
/// Radius initialization method.
///
/// Serialized in snake_case: `"beta_quantile"`, `"k2_closed_form"`,
/// `"large_d_single_shell_explicit"`. Profile validation pairs `K2ClosedForm` with `k = 2`
/// and `BetaQuantile` with every other `k`. `LargeDSingleShellExplicit` is never accepted by
/// the paper profile validator.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
#[serde(rename_all = "snake_case")]
pub enum RadiusMethod {
    /// Bennett-Gersho Beta-quantile radii.
    BetaQuantile,
    /// Paper closed form for k=2.
    K2ClosedForm,
    /// Explicit large-d single-shell initialization.
    #[doc(hidden)]
    LargeDSingleShellExplicit,
}
62
/// Direction initialization method.
///
/// Serialized in snake_case: `"fibonacci_spiral"`, `"fibonacci_sphere"`,
/// `"roberts_kronecker"`. Profile validation requires `FibonacciSpiral` for `k = 2`,
/// `FibonacciSphere` for `k = 3`, and `RobertsKronecker` for every other `k`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
#[serde(rename_all = "snake_case")]
pub enum DirectionMethod {
    /// Planar Fibonacci spiral.
    FibonacciSpiral,
    /// Fibonacci sphere.
    FibonacciSphere,
    /// Roberts-Kronecker rank-one sequence.
    RobertsKronecker,
}
75
/// Empty-cell handling during Lloyd-Max refinement.
///
/// Serialized in snake_case: `"split_highest_distortion"` / `"fail_closed"`.
/// `FibQuantProfileV1::paper_default` selects `SplitHighestDistortion`. Profile validation
/// does not restrict this field.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
#[serde(rename_all = "snake_case")]
pub enum EmptyCellPolicy {
    /// Split the occupied cell with highest distortion.
    SplitHighestDistortion,
    /// Fail if any cell is empty.
    FailClosed,
}
86
/// Stable profile for paper-faithful FibQuant codebooks and payloads.
///
/// Build one with `FibQuantProfileV1::paper_default`, and check externally supplied profiles
/// with `FibQuantProfileV1::validate`. All fields are public, so a profile obtained any other
/// way (hand-built or deserialized) is not known to be consistent until it is validated.
///
/// NOTE(review): `FibQuantProfileV1::digest` hashes this struct via `json_digest`. Renaming
/// or reordering fields therefore presumably changes existing digests; confirm before
/// editing the field list.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct FibQuantProfileV1 {
    /// Stable schema marker; validation requires it to equal `PROFILE_SCHEMA`.
    pub schema_version: String,
    /// Ambient vector dimension `d`.
    pub ambient_dim: u32,
    /// Block dimension `k`; validation requires `0 < k < d` with `k` dividing `d`.
    pub block_dim: u32,
    /// Codebook size `N`; validation requires `2..=MAX_CODEBOOK_SIZE`.
    pub codebook_size: u32,
    /// Paper dense rate `log2(N) / k` (checked to within `RATE_TOLERANCE`).
    pub paper_rate_bits_per_coord: f64,
    /// Practical fixed-rate index width `ceil(log2(N))`.
    pub wire_index_bits: u8,
    /// Practical wire rate `wire_index_bits / k` (checked to within `RATE_TOLERANCE`).
    pub wire_bits_per_coord: f64,
    /// Norm header format; validation accepts only `NormFormat::Fp16Paper`.
    pub norm_format: NormFormat,
    /// Seed for ambient rotation.
    pub rotation_seed: u64,
    /// Rotation generation algorithm identity; must equal `ROTATION_ALGORITHM_VERSION`.
    pub rotation_algorithm_version: String,
    /// Seed for codebook construction and Lloyd training.
    pub codebook_seed: u64,
    /// Codebook algorithm/version string (not checked by `validate`).
    pub codebook_version: String,
    /// Training source mode; validation accepts only `SourceMode::CanonicalSphericalBeta`.
    pub source_mode: SourceMode,
    /// Radius method; must form the allowed pair with `direction_method` for this `k`.
    pub radius_method: RadiusMethod,
    /// Direction method; must form the allowed pair with `radius_method` for this `k`.
    pub direction_method: DirectionMethod,
    /// Number of Lloyd restarts (`1..=MAX_LLOYD_RESTARTS`).
    pub lloyd_restarts: u32,
    /// Number of Lloyd iterations per restart (`1..=MAX_LLOYD_ITERATIONS`).
    pub lloyd_iterations: u32,
    /// Number of training samples (`codebook_size..=MAX_TRAINING_SAMPLES`).
    pub training_samples: u32,
    /// Empty-cell repair policy (not checked by `validate`).
    pub empty_cell_policy: EmptyCellPolicy,
}
129
130impl FibQuantProfileV1 {
131    /// Build a validated paper profile with method choices derived from `k`.
132    pub fn paper_default(
133        ambient_dim: usize,
134        block_dim: usize,
135        codebook_size: usize,
136        seed: u64,
137    ) -> Result<Self> {
138        validate_profile_parts(ambient_dim, block_dim, codebook_size)?;
139        let direction_method = match block_dim {
140            2 => DirectionMethod::FibonacciSpiral,
141            3 => DirectionMethod::FibonacciSphere,
142            _ => DirectionMethod::RobertsKronecker,
143        };
144        let radius_method = if block_dim == 2 {
145            RadiusMethod::K2ClosedForm
146        } else {
147            RadiusMethod::BetaQuantile
148        };
149        let wire_index_bits = wire_index_bits(codebook_size)?;
150        let profile = Self {
151            schema_version: PROFILE_SCHEMA.into(),
152            ambient_dim: ambient_dim as u32,
153            block_dim: block_dim as u32,
154            codebook_size: codebook_size as u32,
155            paper_rate_bits_per_coord: (codebook_size as f64).log2() / block_dim as f64,
156            wire_index_bits,
157            wire_bits_per_coord: f64::from(wire_index_bits) / block_dim as f64,
158            norm_format: NormFormat::Fp16Paper,
159            rotation_seed: seed,
160            rotation_algorithm_version: ROTATION_ALGORITHM_VERSION.into(),
161            codebook_seed: seed.wrapping_add(0x9e37_79b9_7f4a_7c15),
162            codebook_version: "fib-quant:paper-core-v1".into(),
163            source_mode: SourceMode::CanonicalSphericalBeta,
164            radius_method,
165            direction_method,
166            lloyd_restarts: 4,
167            lloyd_iterations: 25,
168            training_samples: default_training_samples(codebook_size)?,
169            empty_cell_policy: EmptyCellPolicy::SplitHighestDistortion,
170        };
171        profile.validate()?;
172        Ok(profile)
173    }
174
175    /// Validate the complete profile.
176    pub fn validate(&self) -> Result<()> {
177        if self.schema_version != PROFILE_SCHEMA {
178            return Err(FibQuantError::CorruptPayload(format!(
179                "profile schema_version {}, expected {PROFILE_SCHEMA}",
180                self.schema_version
181            )));
182        }
183        validate_profile_parts(
184            self.ambient_dim as usize,
185            self.block_dim as usize,
186            self.codebook_size as usize,
187        )?;
188        validate_resource_bounds(
189            self.ambient_dim as usize,
190            self.block_dim as usize,
191            self.codebook_size as usize,
192            self.training_samples,
193            self.wire_index_bits,
194        )?;
195        if self.norm_format != NormFormat::Fp16Paper {
196            return Err(FibQuantError::CorruptPayload(
197                "paper profile requires fp16 norm side header".into(),
198            ));
199        }
200        if self.source_mode != SourceMode::CanonicalSphericalBeta {
201            return Err(FibQuantError::CorruptPayload(
202                "paper profile requires canonical spherical-Beta source mode".into(),
203            ));
204        }
205        if self.rotation_algorithm_version != ROTATION_ALGORITHM_VERSION {
206            return Err(FibQuantError::CorruptPayload(format!(
207                "rotation_algorithm_version {}, expected {ROTATION_ALGORITHM_VERSION}",
208                self.rotation_algorithm_version
209            )));
210        }
211        let expected_bits = wire_index_bits(self.codebook_size as usize)?;
212        if self.wire_index_bits != expected_bits {
213            return Err(FibQuantError::CorruptPayload(format!(
214                "wire_index_bits {} does not match ceil(log2(N)) {expected_bits}",
215                self.wire_index_bits
216            )));
217        }
218        let k = self.block_dim as usize;
219        let expected_paper_rate = (self.codebook_size as f64).log2() / k as f64;
220        validate_rate(
221            "paper_rate_bits_per_coord",
222            self.paper_rate_bits_per_coord,
223            expected_paper_rate,
224        )?;
225        let expected_wire_rate = f64::from(self.wire_index_bits) / k as f64;
226        validate_rate(
227            "wire_bits_per_coord",
228            self.wire_bits_per_coord,
229            expected_wire_rate,
230        )?;
231        validate_method_pair(k, &self.radius_method, &self.direction_method)?;
232        if self.lloyd_restarts == 0 || self.lloyd_restarts > MAX_LLOYD_RESTARTS {
233            return Err(FibQuantError::CorruptPayload(format!(
234                "lloyd_restarts {} outside supported range 1..={MAX_LLOYD_RESTARTS}",
235                self.lloyd_restarts
236            )));
237        }
238        if self.lloyd_iterations == 0 || self.lloyd_iterations > MAX_LLOYD_ITERATIONS {
239            return Err(FibQuantError::CorruptPayload(format!(
240                "lloyd_iterations {} outside supported range 1..={MAX_LLOYD_ITERATIONS}",
241                self.lloyd_iterations
242            )));
243        }
244        if self.training_samples < self.codebook_size
245            || self.training_samples > MAX_TRAINING_SAMPLES
246        {
247            return Err(FibQuantError::CorruptPayload(format!(
248                "training_samples {} outside supported range {}..={MAX_TRAINING_SAMPLES}",
249                self.training_samples, self.codebook_size
250            )));
251        }
252        Ok(())
253    }
254
255    /// Stable digest over all explicit profile fields.
256    pub fn digest(&self) -> Result<String> {
257        self.validate()?;
258        json_digest(PROFILE_SCHEMA, self)
259    }
260
261    /// Number of `k`-blocks per vector.
262    pub fn block_count(&self) -> u32 {
263        self.ambient_dim / self.block_dim
264    }
265}
266
267/// Return the fixed wire width for one index in `[0, N)`.
268pub fn wire_index_bits(codebook_size: usize) -> Result<u8> {
269    if codebook_size < 2 {
270        return Err(FibQuantError::InvalidCodebookSize(codebook_size));
271    }
272    let bits = usize::BITS - (codebook_size - 1).leading_zeros();
273    u8::try_from(bits).map_err(|_| FibQuantError::InvalidCodebookSize(codebook_size))
274}
275
276fn validate_profile_parts(
277    ambient_dim: usize,
278    block_dim: usize,
279    codebook_size: usize,
280) -> Result<()> {
281    if ambient_dim == 0 {
282        return Err(FibQuantError::ZeroDimension);
283    }
284    if block_dim == 0 || block_dim > ambient_dim {
285        return Err(FibQuantError::InvalidBlockDim {
286            ambient_dim,
287            block_dim,
288        });
289    }
290    if ambient_dim == block_dim {
291        return Err(FibQuantError::InvalidBlockDim {
292            ambient_dim,
293            block_dim,
294        });
295    }
296    if ambient_dim % block_dim != 0 {
297        return Err(FibQuantError::DimensionNotDivisible {
298            ambient_dim,
299            block_dim,
300        });
301    }
302    if ambient_dim > MAX_AMBIENT_DIM {
303        return Err(FibQuantError::ResourceLimitExceeded(format!(
304            "ambient_dim {ambient_dim} exceeds MAX_AMBIENT_DIM {MAX_AMBIENT_DIM}"
305        )));
306    }
307    if block_dim > MAX_BLOCK_DIM {
308        return Err(FibQuantError::ResourceLimitExceeded(format!(
309            "block_dim {block_dim} exceeds MAX_BLOCK_DIM {MAX_BLOCK_DIM}"
310        )));
311    }
312    if !(2..=MAX_CODEBOOK_SIZE).contains(&codebook_size) {
313        return Err(FibQuantError::InvalidCodebookSize(codebook_size));
314    }
315    Ok(())
316}
317
318fn default_training_samples(codebook_size: usize) -> Result<u32> {
319    let samples = 30usize
320        .checked_mul(codebook_size)
321        .ok_or_else(|| FibQuantError::ResourceLimitExceeded("30 * codebook_size overflow".into()))?
322        .max(256)
323        .min(MAX_TRAINING_SAMPLES as usize);
324    u32::try_from(samples)
325        .map_err(|_| FibQuantError::ResourceLimitExceeded("training sample count overflow".into()))
326}
327
328fn checked_profile_mul(lhs: usize, rhs: usize, label: &str) -> Result<usize> {
329    lhs.checked_mul(rhs)
330        .ok_or_else(|| FibQuantError::ResourceLimitExceeded(format!("{label} overflow")))
331}
332
333fn validate_resource_bounds(
334    ambient_dim: usize,
335    block_dim: usize,
336    codebook_size: usize,
337    training_samples: u32,
338    wire_index_bits: u8,
339) -> Result<()> {
340    let rotation_values =
341        checked_profile_mul(ambient_dim, ambient_dim, "ambient_dim * ambient_dim")?;
342    if rotation_values > MAX_ROTATION_MATRIX_VALUES {
343        return Err(FibQuantError::ResourceLimitExceeded(format!(
344            "rotation matrix values {rotation_values} exceed MAX_ROTATION_MATRIX_VALUES {MAX_ROTATION_MATRIX_VALUES}"
345        )));
346    }
347
348    let codebook_values =
349        checked_profile_mul(codebook_size, block_dim, "codebook_size * block_dim")?;
350    if codebook_values > MAX_CODEBOOK_VALUES {
351        return Err(FibQuantError::ResourceLimitExceeded(format!(
352            "codebook values {codebook_values} exceed MAX_CODEBOOK_VALUES {MAX_CODEBOOK_VALUES}"
353        )));
354    }
355
356    checked_profile_mul(
357        training_samples as usize,
358        block_dim,
359        "training_samples * block_dim",
360    )?;
361
362    let block_count = ambient_dim / block_dim;
363    let packed_bits = checked_profile_mul(
364        block_count,
365        wire_index_bits as usize,
366        "block_count * wire_index_bits",
367    )?;
368    if packed_bits > MAX_PACKED_INDEX_BITS {
369        return Err(FibQuantError::ResourceLimitExceeded(format!(
370            "packed index bits {packed_bits} exceed MAX_PACKED_INDEX_BITS {MAX_PACKED_INDEX_BITS}"
371        )));
372    }
373    Ok(())
374}
375
376fn validate_rate(name: &str, actual: f64, expected: f64) -> Result<()> {
377    if !actual.is_finite() || !expected.is_finite() || (actual - expected).abs() > RATE_TOLERANCE {
378        return Err(FibQuantError::CorruptPayload(format!(
379            "{name} {actual} does not match expected {expected}"
380        )));
381    }
382    Ok(())
383}
384
385fn validate_method_pair(
386    block_dim: usize,
387    radius: &RadiusMethod,
388    direction: &DirectionMethod,
389) -> Result<()> {
390    let valid = match block_dim {
391        2 => {
392            radius == &RadiusMethod::K2ClosedForm && direction == &DirectionMethod::FibonacciSpiral
393        }
394        3 => {
395            radius == &RadiusMethod::BetaQuantile && direction == &DirectionMethod::FibonacciSphere
396        }
397        _ => {
398            radius == &RadiusMethod::BetaQuantile && direction == &DirectionMethod::RobertsKronecker
399        }
400    };
401    if valid {
402        Ok(())
403    } else {
404        Err(FibQuantError::CorruptPayload(format!(
405            "unsupported radius/direction pair for k={block_dim}: {radius:?}/{direction:?}"
406        )))
407    }
408}