1use serde::{Deserialize, Serialize};
2
3use crate::{digest::json_digest, rotation::ROTATION_ALGORITHM_VERSION, FibQuantError, Result};
4
/// Schema identifier a profile must carry in `schema_version`; `FibQuantProfileV1::validate`
/// rejects any other value.
pub const PROFILE_SCHEMA: &str = "fib_quant_profile_v1";
/// Largest `ambient_dim` accepted by `validate_profile_parts`.
///
/// Full validation is stricter: `validate_resource_bounds` also caps
/// `ambient_dim * ambient_dim` at [`MAX_ROTATION_MATRIX_VALUES`].
pub const MAX_AMBIENT_DIM: usize = 16_384;
/// Largest `block_dim` accepted by `validate_profile_parts`.
pub const MAX_BLOCK_DIM: usize = 256;
/// Largest `codebook_size` accepted by `validate_profile_parts` (the minimum is 2).
pub const MAX_CODEBOOK_SIZE: usize = 1 << 20;
/// Largest `training_samples` accepted by `FibQuantProfileV1::validate`; also the ceiling
/// applied by `default_training_samples`.
pub const MAX_TRAINING_SAMPLES: u32 = 10_000_000;
/// Upper bound on `ambient_dim * ambient_dim`, enforced by `validate_resource_bounds`.
///
/// The value is 4096 * 4096, so profiles with `ambient_dim` above 4096 fail that check.
pub const MAX_ROTATION_MATRIX_VALUES: usize = 16_777_216;
/// Upper bound on `codebook_size * block_dim`, enforced by `validate_resource_bounds`.
pub const MAX_CODEBOOK_VALUES: usize = 67_108_864;
/// Upper bound on `(ambient_dim / block_dim) * wire_index_bits`, enforced by
/// `validate_resource_bounds`.
///
/// NOTE(review): `1 << 34` does not fit in a 32-bit `usize` — confirm that 32-bit targets
/// are out of scope for this crate.
pub const MAX_PACKED_INDEX_BITS: usize = 1 << 34;
20
/// Absolute tolerance `validate_rate` allows between a stored rate and its recomputed value.
const RATE_TOLERANCE: f64 = 1.0e-12;
/// Largest `lloyd_restarts` accepted by `FibQuantProfileV1::validate` (the minimum is 1).
const MAX_LLOYD_RESTARTS: u32 = 1_024;
/// Largest `lloyd_iterations` accepted by `FibQuantProfileV1::validate` (the minimum is 1).
const MAX_LLOYD_ITERATIONS: u32 = 100_000;
24
/// Format of the norm side header recorded in a profile.
///
/// Variants serialize in `snake_case`. `FibQuantProfileV1::validate` accepts only
/// `Fp16Paper`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
#[serde(rename_all = "snake_case")]
pub enum NormFormat {
    /// The fp16 norm side header required by the paper profile.
    Fp16Paper,
    /// Hidden from the rendered docs and rejected by `FibQuantProfileV1::validate`.
    /// NOTE(review): presumably an f32 reference path — confirm against its users.
    #[doc(hidden)]
    F32Reference,
}
36
/// Source mode recorded in a profile.
///
/// Variants serialize in `snake_case`. `FibQuantProfileV1::validate` accepts only
/// `CanonicalSphericalBeta`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
#[serde(rename_all = "snake_case")]
pub enum SourceMode {
    /// The canonical spherical-Beta source mode required by the paper profile.
    CanonicalSphericalBeta,
    /// Hidden from the rendered docs and rejected by `FibQuantProfileV1::validate`.
    #[doc(hidden)]
    ReferenceGaussianProjection,
}
48
/// Radius method recorded in a profile.
///
/// Variants serialize in `snake_case`. `validate_method_pair` ties the allowed variant to
/// `block_dim`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
#[serde(rename_all = "snake_case")]
pub enum RadiusMethod {
    /// Required for every `block_dim` other than 2.
    BetaQuantile,
    /// Required when `block_dim` is 2.
    K2ClosedForm,
    /// Hidden from the rendered docs; `validate_method_pair` accepts it for no `block_dim`.
    #[doc(hidden)]
    LargeDSingleShellExplicit,
}
62
/// Direction method recorded in a profile.
///
/// Variants serialize in `snake_case`. `validate_method_pair` ties the allowed variant to
/// `block_dim`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
#[serde(rename_all = "snake_case")]
pub enum DirectionMethod {
    /// Required when `block_dim` is 2.
    FibonacciSpiral,
    /// Required when `block_dim` is 3.
    FibonacciSphere,
    /// Required for every other `block_dim`.
    RobertsKronecker,
}
75
/// Empty-cell policy recorded in a profile.
///
/// Variants serialize in `snake_case`. `FibQuantProfileV1::validate` does not restrict this
/// field.
/// NOTE(review): presumably tells codebook training what to do with a cell that receives
/// no samples — confirm against the trainer, which is not visible here.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
#[serde(rename_all = "snake_case")]
pub enum EmptyCellPolicy {
    /// The policy `FibQuantProfileV1::paper_default` selects.
    SplitHighestDistortion,
    /// Alternative policy; its behaviour is defined by code outside this file.
    FailClosed,
}
86
/// Version 1 quantization profile.
///
/// Every field is public and the type is deserializable, so a value is only known to be
/// self-consistent after `validate` has succeeded. The per-field notes below describe what
/// `validate` enforces.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct FibQuantProfileV1 {
    /// Must equal `PROFILE_SCHEMA`.
    pub schema_version: String,
    /// Non-zero, strictly greater than `block_dim`, a multiple of `block_dim`, and within
    /// `MAX_AMBIENT_DIM` and the rotation-matrix bound.
    pub ambient_dim: u32,
    /// Non-zero, strictly less than `ambient_dim`, and at most `MAX_BLOCK_DIM`.
    pub block_dim: u32,
    /// Number of codebook entries; must lie in `2..=MAX_CODEBOOK_SIZE`.
    pub codebook_size: u32,
    /// Must equal `log2(codebook_size) / block_dim` within `RATE_TOLERANCE`.
    pub paper_rate_bits_per_coord: f64,
    /// Must equal `wire_index_bits(codebook_size)`, i.e. `ceil(log2(codebook_size))`.
    pub wire_index_bits: u8,
    /// Must equal `wire_index_bits / block_dim` within `RATE_TOLERANCE`.
    pub wire_bits_per_coord: f64,
    /// Must be `NormFormat::Fp16Paper`.
    pub norm_format: NormFormat,
    /// Seed stored for the rotation; `paper_default` sets it to the caller's seed.
    pub rotation_seed: u64,
    /// Must equal `ROTATION_ALGORITHM_VERSION`.
    pub rotation_algorithm_version: String,
    /// Seed stored for the codebook; `paper_default` derives it from the caller's seed by a
    /// wrapping add of a fixed constant.
    pub codebook_seed: u64,
    /// Codebook version label; `paper_default` writes `"fib-quant:paper-core-v1"`.
    /// Not checked by `validate`.
    pub codebook_version: String,
    /// Must be `SourceMode::CanonicalSphericalBeta`.
    pub source_mode: SourceMode,
    /// Must form a supported pair with `direction_method` for this `block_dim`.
    pub radius_method: RadiusMethod,
    /// Must form a supported pair with `radius_method` for this `block_dim`.
    pub direction_method: DirectionMethod,
    /// Must lie in `1..=MAX_LLOYD_RESTARTS`.
    pub lloyd_restarts: u32,
    /// Must lie in `1..=MAX_LLOYD_ITERATIONS`.
    pub lloyd_iterations: u32,
    /// Must lie in `codebook_size..=MAX_TRAINING_SAMPLES`.
    pub training_samples: u32,
    /// Not checked by `validate`.
    pub empty_cell_policy: EmptyCellPolicy,
}
129
130impl FibQuantProfileV1 {
131 pub fn paper_default(
133 ambient_dim: usize,
134 block_dim: usize,
135 codebook_size: usize,
136 seed: u64,
137 ) -> Result<Self> {
138 validate_profile_parts(ambient_dim, block_dim, codebook_size)?;
139 let direction_method = match block_dim {
140 2 => DirectionMethod::FibonacciSpiral,
141 3 => DirectionMethod::FibonacciSphere,
142 _ => DirectionMethod::RobertsKronecker,
143 };
144 let radius_method = if block_dim == 2 {
145 RadiusMethod::K2ClosedForm
146 } else {
147 RadiusMethod::BetaQuantile
148 };
149 let wire_index_bits = wire_index_bits(codebook_size)?;
150 let profile = Self {
151 schema_version: PROFILE_SCHEMA.into(),
152 ambient_dim: ambient_dim as u32,
153 block_dim: block_dim as u32,
154 codebook_size: codebook_size as u32,
155 paper_rate_bits_per_coord: (codebook_size as f64).log2() / block_dim as f64,
156 wire_index_bits,
157 wire_bits_per_coord: f64::from(wire_index_bits) / block_dim as f64,
158 norm_format: NormFormat::Fp16Paper,
159 rotation_seed: seed,
160 rotation_algorithm_version: ROTATION_ALGORITHM_VERSION.into(),
161 codebook_seed: seed.wrapping_add(0x9e37_79b9_7f4a_7c15),
162 codebook_version: "fib-quant:paper-core-v1".into(),
163 source_mode: SourceMode::CanonicalSphericalBeta,
164 radius_method,
165 direction_method,
166 lloyd_restarts: 4,
167 lloyd_iterations: 25,
168 training_samples: default_training_samples(codebook_size)?,
169 empty_cell_policy: EmptyCellPolicy::SplitHighestDistortion,
170 };
171 profile.validate()?;
172 Ok(profile)
173 }
174
175 pub fn validate(&self) -> Result<()> {
177 if self.schema_version != PROFILE_SCHEMA {
178 return Err(FibQuantError::CorruptPayload(format!(
179 "profile schema_version {}, expected {PROFILE_SCHEMA}",
180 self.schema_version
181 )));
182 }
183 validate_profile_parts(
184 self.ambient_dim as usize,
185 self.block_dim as usize,
186 self.codebook_size as usize,
187 )?;
188 validate_resource_bounds(
189 self.ambient_dim as usize,
190 self.block_dim as usize,
191 self.codebook_size as usize,
192 self.training_samples,
193 self.wire_index_bits,
194 )?;
195 if self.norm_format != NormFormat::Fp16Paper {
196 return Err(FibQuantError::CorruptPayload(
197 "paper profile requires fp16 norm side header".into(),
198 ));
199 }
200 if self.source_mode != SourceMode::CanonicalSphericalBeta {
201 return Err(FibQuantError::CorruptPayload(
202 "paper profile requires canonical spherical-Beta source mode".into(),
203 ));
204 }
205 if self.rotation_algorithm_version != ROTATION_ALGORITHM_VERSION {
206 return Err(FibQuantError::CorruptPayload(format!(
207 "rotation_algorithm_version {}, expected {ROTATION_ALGORITHM_VERSION}",
208 self.rotation_algorithm_version
209 )));
210 }
211 let expected_bits = wire_index_bits(self.codebook_size as usize)?;
212 if self.wire_index_bits != expected_bits {
213 return Err(FibQuantError::CorruptPayload(format!(
214 "wire_index_bits {} does not match ceil(log2(N)) {expected_bits}",
215 self.wire_index_bits
216 )));
217 }
218 let k = self.block_dim as usize;
219 let expected_paper_rate = (self.codebook_size as f64).log2() / k as f64;
220 validate_rate(
221 "paper_rate_bits_per_coord",
222 self.paper_rate_bits_per_coord,
223 expected_paper_rate,
224 )?;
225 let expected_wire_rate = f64::from(self.wire_index_bits) / k as f64;
226 validate_rate(
227 "wire_bits_per_coord",
228 self.wire_bits_per_coord,
229 expected_wire_rate,
230 )?;
231 validate_method_pair(k, &self.radius_method, &self.direction_method)?;
232 if self.lloyd_restarts == 0 || self.lloyd_restarts > MAX_LLOYD_RESTARTS {
233 return Err(FibQuantError::CorruptPayload(format!(
234 "lloyd_restarts {} outside supported range 1..={MAX_LLOYD_RESTARTS}",
235 self.lloyd_restarts
236 )));
237 }
238 if self.lloyd_iterations == 0 || self.lloyd_iterations > MAX_LLOYD_ITERATIONS {
239 return Err(FibQuantError::CorruptPayload(format!(
240 "lloyd_iterations {} outside supported range 1..={MAX_LLOYD_ITERATIONS}",
241 self.lloyd_iterations
242 )));
243 }
244 if self.training_samples < self.codebook_size
245 || self.training_samples > MAX_TRAINING_SAMPLES
246 {
247 return Err(FibQuantError::CorruptPayload(format!(
248 "training_samples {} outside supported range {}..={MAX_TRAINING_SAMPLES}",
249 self.training_samples, self.codebook_size
250 )));
251 }
252 Ok(())
253 }
254
255 pub fn digest(&self) -> Result<String> {
257 self.validate()?;
258 json_digest(PROFILE_SCHEMA, self)
259 }
260
261 pub fn block_count(&self) -> u32 {
263 self.ambient_dim / self.block_dim
264 }
265}
266
267pub fn wire_index_bits(codebook_size: usize) -> Result<u8> {
269 if codebook_size < 2 {
270 return Err(FibQuantError::InvalidCodebookSize(codebook_size));
271 }
272 let bits = usize::BITS - (codebook_size - 1).leading_zeros();
273 u8::try_from(bits).map_err(|_| FibQuantError::InvalidCodebookSize(codebook_size))
274}
275
276fn validate_profile_parts(
277 ambient_dim: usize,
278 block_dim: usize,
279 codebook_size: usize,
280) -> Result<()> {
281 if ambient_dim == 0 {
282 return Err(FibQuantError::ZeroDimension);
283 }
284 if block_dim == 0 || block_dim > ambient_dim {
285 return Err(FibQuantError::InvalidBlockDim {
286 ambient_dim,
287 block_dim,
288 });
289 }
290 if ambient_dim == block_dim {
291 return Err(FibQuantError::InvalidBlockDim {
292 ambient_dim,
293 block_dim,
294 });
295 }
296 if ambient_dim % block_dim != 0 {
297 return Err(FibQuantError::DimensionNotDivisible {
298 ambient_dim,
299 block_dim,
300 });
301 }
302 if ambient_dim > MAX_AMBIENT_DIM {
303 return Err(FibQuantError::ResourceLimitExceeded(format!(
304 "ambient_dim {ambient_dim} exceeds MAX_AMBIENT_DIM {MAX_AMBIENT_DIM}"
305 )));
306 }
307 if block_dim > MAX_BLOCK_DIM {
308 return Err(FibQuantError::ResourceLimitExceeded(format!(
309 "block_dim {block_dim} exceeds MAX_BLOCK_DIM {MAX_BLOCK_DIM}"
310 )));
311 }
312 if !(2..=MAX_CODEBOOK_SIZE).contains(&codebook_size) {
313 return Err(FibQuantError::InvalidCodebookSize(codebook_size));
314 }
315 Ok(())
316}
317
318fn default_training_samples(codebook_size: usize) -> Result<u32> {
319 let samples = 30usize
320 .checked_mul(codebook_size)
321 .ok_or_else(|| FibQuantError::ResourceLimitExceeded("30 * codebook_size overflow".into()))?
322 .max(256)
323 .min(MAX_TRAINING_SAMPLES as usize);
324 u32::try_from(samples)
325 .map_err(|_| FibQuantError::ResourceLimitExceeded("training sample count overflow".into()))
326}
327
328fn checked_profile_mul(lhs: usize, rhs: usize, label: &str) -> Result<usize> {
329 lhs.checked_mul(rhs)
330 .ok_or_else(|| FibQuantError::ResourceLimitExceeded(format!("{label} overflow")))
331}
332
333fn validate_resource_bounds(
334 ambient_dim: usize,
335 block_dim: usize,
336 codebook_size: usize,
337 training_samples: u32,
338 wire_index_bits: u8,
339) -> Result<()> {
340 let rotation_values =
341 checked_profile_mul(ambient_dim, ambient_dim, "ambient_dim * ambient_dim")?;
342 if rotation_values > MAX_ROTATION_MATRIX_VALUES {
343 return Err(FibQuantError::ResourceLimitExceeded(format!(
344 "rotation matrix values {rotation_values} exceed MAX_ROTATION_MATRIX_VALUES {MAX_ROTATION_MATRIX_VALUES}"
345 )));
346 }
347
348 let codebook_values =
349 checked_profile_mul(codebook_size, block_dim, "codebook_size * block_dim")?;
350 if codebook_values > MAX_CODEBOOK_VALUES {
351 return Err(FibQuantError::ResourceLimitExceeded(format!(
352 "codebook values {codebook_values} exceed MAX_CODEBOOK_VALUES {MAX_CODEBOOK_VALUES}"
353 )));
354 }
355
356 checked_profile_mul(
357 training_samples as usize,
358 block_dim,
359 "training_samples * block_dim",
360 )?;
361
362 let block_count = ambient_dim / block_dim;
363 let packed_bits = checked_profile_mul(
364 block_count,
365 wire_index_bits as usize,
366 "block_count * wire_index_bits",
367 )?;
368 if packed_bits > MAX_PACKED_INDEX_BITS {
369 return Err(FibQuantError::ResourceLimitExceeded(format!(
370 "packed index bits {packed_bits} exceed MAX_PACKED_INDEX_BITS {MAX_PACKED_INDEX_BITS}"
371 )));
372 }
373 Ok(())
374}
375
376fn validate_rate(name: &str, actual: f64, expected: f64) -> Result<()> {
377 if !actual.is_finite() || !expected.is_finite() || (actual - expected).abs() > RATE_TOLERANCE {
378 return Err(FibQuantError::CorruptPayload(format!(
379 "{name} {actual} does not match expected {expected}"
380 )));
381 }
382 Ok(())
383}
384
385fn validate_method_pair(
386 block_dim: usize,
387 radius: &RadiusMethod,
388 direction: &DirectionMethod,
389) -> Result<()> {
390 let valid = match block_dim {
391 2 => {
392 radius == &RadiusMethod::K2ClosedForm && direction == &DirectionMethod::FibonacciSpiral
393 }
394 3 => {
395 radius == &RadiusMethod::BetaQuantile && direction == &DirectionMethod::FibonacciSphere
396 }
397 _ => {
398 radius == &RadiusMethod::BetaQuantile && direction == &DirectionMethod::RobertsKronecker
399 }
400 };
401 if valid {
402 Ok(())
403 } else {
404 Err(FibQuantError::CorruptPayload(format!(
405 "unsupported radius/direction pair for k={block_dim}: {radius:?}/{direction:?}"
406 )))
407 }
408}