// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors
//! SORF inverse transform scalar function.
//!
//! SORF (Structured Orthogonal Random Features, [Yu et al. 2016][sorf-paper]) is a fast structured
//! approximation to a random orthogonal matrix. It composes random sign diagonals with the
//! Walsh-Hadamard transform to achieve O(d log d) matrix-vector products instead of the O(d^2) cost
//! of a dense orthogonal matrix.
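//!
//! As a sketch of that structure, assuming the three-block form from the SORF paper with its
//! Gaussian scaling factor dropped (only the orthogonal part matters here), the forward map is
//! `y = H D3 H D2 H D1 x`. In this form `H` is the Walsh-Hadamard matrix scaled by `1/sqrt(d)` and
//! each `Di` is a diagonal of random `+1`/`-1` signs. Both factors are symmetric and self-inverse,
//! so the inverse is the transpose `x = D1 H D2 H D3 H y`. The helper names below are hypothetical
//! and are not this crate's API:
//!
//! ```rust
//! // In-place fast Walsh-Hadamard transform, scaled by 1/sqrt(n) so it is orthonormal
//! // (and therefore its own inverse). Requires a power-of-two length; O(n log n).
//! fn fwht(x: &mut [f32]) {
//!     let n = x.len();
//!     assert!(n.is_power_of_two(), "length must be a power of two");
//!     let mut h = 1;
//!     while h < n {
//!         // Butterflies over blocks of width 2h: (a, b) -> (a + b, a - b).
//!         for start in (0..n).step_by(2 * h) {
//!             for i in start..start + h {
//!                 let (a, b) = (x[i], x[i + h]);
//!                 x[i] = a + b;
//!                 x[i + h] = a - b;
//!             }
//!         }
//!         h *= 2;
//!     }
//!     let scale = 1.0 / (n as f32).sqrt();
//!     for v in x.iter_mut() {
//!         *v *= scale;
//!     }
//! }
//!
//! // Multiplies `x` element-wise by a diagonal of +1/-1 signs.
//! fn apply_signs(x: &mut [f32], signs: &[f32]) {
//!     for (v, s) in x.iter_mut().zip(signs) {
//!         *v *= *s;
//!     }
//! }
//!
//! // Forward map: y = H D3 H D2 H D1 x.
//! fn sorf_forward(x: &mut [f32], diags: &[Vec<f32>; 3]) {
//!     for d in diags {
//!         apply_signs(x, d);
//!         fwht(x);
//!     }
//! }
//!
//! // Inverse map (the transpose): x = D1 H D2 H D3 H y.
//! fn sorf_inverse(y: &mut [f32], diags: &[Vec<f32>; 3]) {
//!     for d in diags.iter().rev() {
//!         fwht(y);
//!         apply_signs(y, d);
//!     }
//! }
//! ```
//!
//! Each call performs three `O(d log d)` Walsh-Hadamard passes and never materializes the
//! `d x d` matrix.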
//!
//! This module wraps a [`Vector`] extension array whose dimension is the padded SORF dimension
//! (e.g. a `Vector` wrapping `FSL(Dict(codes, centroids))`) and applies the inverse SORF transform
//! at execution time, producing a [`Vector`] extension array with the original (pre-padding)
//! dimensionality.
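//!
//! The Walsh-Hadamard transform needs a power-of-two length. Assuming the padded SORF dimension is
//! the original dimension rounded up to the next power of two with zero-filled tail lanes (an
//! assumption for illustration), padding and truncation look like this. The helper names are
//! hypothetical:
//!
//! ```rust
//! // Hypothetical helper: the smallest power-of-two length that can hold `dim` values.
//! fn padded_dim(dim: usize) -> usize {
//!     dim.next_power_of_two()
//! }
//!
//! // Encode side: zero-fill one vector up to the padded dimension.
//! fn pad(row: &[f32]) -> Vec<f32> {
//!     let mut out = row.to_vec();
//!     out.resize(padded_dim(row.len()), 0.0);
//!     out
//! }
//!
//! // Decode side: keep only the original `dim` lanes after the inverse transform.
//! fn truncate(row: &[f32], dim: usize) -> &[f32] {
//!     &row[..dim]
//! }
//! ```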
//!
//! The transform parameters are stored as a deterministic seed in [`SorfOptions`], so the
//! [`SorfMatrix`] is reconstructed cheaply at decode time. Sign diagonals are defined by Vortex's
//! frozen local SplitMix64 stream contract rather than by an external RNG crate.
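//!
//! For reference, the standard SplitMix64 step is shown below, together with one possible way to
//! turn its output into a sign diagonal. The bit-to-sign mapping (the top bit of each output word)
//! and the helper names are assumptions for illustration; the frozen Vortex stream contract may
//! consume the generator differently:
//!
//! ```rust
//! // Standard SplitMix64: a 64-bit counter advanced by the golden-ratio increment,
//! // followed by a bijective mixing function.
//! struct SplitMix64 {
//!     state: u64,
//! }
//!
//! impl SplitMix64 {
//!     fn new(seed: u64) -> Self {
//!         Self { state: seed }
//!     }
//!
//!     fn next_u64(&mut self) -> u64 {
//!         self.state = self.state.wrapping_add(0x9E37_79B9_7F4A_7C15);
//!         let mut z = self.state;
//!         z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
//!         z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
//!         z ^ (z >> 31)
//!     }
//! }
//!
//! // Hypothetical mapping: one sign per output word, negative when the top bit is set.
//! fn sign_diagonal(seed: u64, len: usize) -> Vec<f32> {
//!     let mut rng = SplitMix64::new(seed);
//!     (0..len)
//!         .map(|_| if rng.next_u64() >> 63 == 1 { -1.0 } else { 1.0 })
//!         .collect()
//! }
//! ```
//!
//! Because the generator is a pure function of the seed, the same seed always reproduces the same
//! diagonal, which is what makes storing only the seed sufficient.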
//!
//! # Input element type: `f32` only (TODO(connor): for now...)
//!
//! The child [`Vector`] **must** have `f32` storage elements. This hard constraint is enforced by
//! `SorfTransform`'s `return_dtype` check. Callers with `f16` or `f64` source data must cast it to
//! `f32` before wrapping it in a [`Vector`] and passing it to `SorfTransform`.
//!
//! The reason for this constraint is that TurboQuant (the only production caller today) stores its
//! dictionary centroids as `f32`, and the SORF transform itself operates internally in `f32`.
//!
//! Supporting other float storage types would require an implicit up- or down-cast that we do not
//! yet want to bake into `SorfTransform`. This restriction is intentional and may be relaxed in the
//! future, but today it is load-bearing.
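//!
//! The rule can be sketched as a check that rejects rather than coerces, plus the explicit cast a
//! caller with `f64` data would perform. `ElementType`, `check_sorf_input`, and `cast_f64_to_f32`
//! are hypothetical stand-ins for the real `PType`-based `return_dtype` logic; `F16` appears only
//! as a tag because `f16` is not a stable built-in Rust type:
//!
//! ```rust
//! // Hypothetical stand-in for the storage element type of a `Vector` child.
//! #[derive(Debug, Clone, Copy, PartialEq, Eq)]
//! enum ElementType {
//!     F16,
//!     F32,
//!     F64,
//! }
//!
//! // Mirrors the f32-only rule: any other float storage type is rejected, not coerced.
//! fn check_sorf_input(storage: ElementType) -> Result<(), String> {
//!     if storage == ElementType::F32 {
//!         Ok(())
//!     } else {
//!         Err(format!("SorfTransform requires f32 storage, got {storage:?}"))
//!     }
//! }
//!
//! // Caller-side fix for f64 source data: cast explicitly before wrapping in a Vector.
//! fn cast_f64_to_f32(values: &[f64]) -> Vec<f32> {
//!     values.iter().map(|&v| v as f32).collect()
//! }
//! ```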
//!
//! # Output element type
//!
//! The output [`Vector`]'s element type is whatever [`SorfOptions::element_ptype`] is set to. It
//! does **not** have to match the child's `f32` storage: we apply an explicit `f32 -> T` cast
//! while materializing the output. This lets `SorfTransform` hand its result directly to a
//! downstream consumer (e.g. [`L2Denorm`](crate::scalar_fns::l2_denorm::L2Denorm)) whose
//! element-type expectation may differ from the `f32` the transform operated on internally.
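//!
//! A minimal sketch of the explicit `f32 -> T` output cast, using a hypothetical `FromF32` trait
//! with only `f32` and `f64` targets (the actual output type is selected by
//! [`SorfOptions::element_ptype`], which this sketch does not model):
//!
//! ```rust
//! // Hypothetical conversion trait for the output element type.
//! trait FromF32 {
//!     fn from_f32(v: f32) -> Self;
//! }
//!
//! impl FromF32 for f32 {
//!     fn from_f32(v: f32) -> Self {
//!         v
//!     }
//! }
//!
//! impl FromF32 for f64 {
//!     fn from_f32(v: f32) -> Self {
//!         // Widening is exact: every f32 value is representable as an f64.
//!         f64::from(v)
//!     }
//! }
//!
//! // Casts the transform's internal f32 results to the requested output type.
//! fn cast_output<T: FromF32>(values: &[f32]) -> Vec<T> {
//!     values.iter().map(|&v| T::from_f32(v)).collect()
//! }
//! ```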
//!
//! [sorf-paper]: https://proceedings.neurips.cc/paper_files/paper/2016/file/53adaf494dc89ef7196d73636eb2451b-Paper.pdf
//! [`Vector`]: crate::vector::Vector
use fmt;
use Formatter;
use ArrayRef;
use ScalarFnArray;
use PType;
use TypedScalarFnInstance;
use VortexResult;
use vortex_ensure;
pub use SorfMatrix;
/// Inverse SORF orthogonal transform scalar function.
///
/// Takes a [`Vector`](crate::vector::Vector) extension child at the padded dimension with `f32`
/// storage, applies the inverse structured Walsh-Hadamard orthogonal transform, truncates to the
/// original (pre-padding) dimension, casts element-wise to [`SorfOptions::element_ptype`], and
/// wraps the result in a new [`Vector`](crate::vector::Vector) extension array.
///
/// See the [module-level docs](crate::scalar_fns::sorf_transform) for the rationale behind the
/// `f32`-only input constraint.
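///
/// The per-row work can be sketched over flattened fixed-size-list storage holding `padded_dim`
/// values per row. The inverse transform is passed in as a closure so the example stays
/// independent of the matrix code, and the output type is fixed to `f64` for brevity;
/// `decode_rows` and its signature are hypothetical:
///
/// ```rust
/// // Hypothetical decode loop: inverse-transform each padded row, keep the first
/// // `dim` lanes, and widen them to f64.
/// fn decode_rows(
///     storage: &[f32],
///     padded_dim: usize,
///     dim: usize,
///     inverse: impl Fn(&mut [f32]),
/// ) -> Vec<f64> {
///     assert!(padded_dim > 0 && dim <= padded_dim && storage.len() % padded_dim == 0);
///     let num_rows = storage.len() / padded_dim;
///     let mut out = Vec::with_capacity(num_rows * dim);
///     // Scratch buffer reused across rows so the transform can run in place.
///     let mut row = vec![0.0f32; padded_dim];
///     for chunk in storage.chunks_exact(padded_dim) {
///         row.copy_from_slice(chunk);
///         inverse(row.as_mut_slice());
///         out.extend(row[..dim].iter().map(|&v| f64::from(v)));
///     }
///     out
/// }
/// ```
///
/// Reusing one scratch row lets each row be transformed in place without a per-row allocation.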
;
/// Options for the SORF inverse transform scalar function.
///
/// Stored in the [`ScalarFnArray`] and used to deterministically reconstruct the
/// [`SorfMatrix`] at decode time.
/// Checks that the SORF configuration is valid.
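///
/// For illustration only, assuming the options carry an original and a padded dimension (the
/// `Options` type and its field names are hypothetical), such a check might enforce:
///
/// ```rust
/// // Hypothetical options shape; the field names are illustrative.
/// struct Options {
///     dimension: usize,
///     padded_dimension: usize,
/// }
///
/// impl Options {
///     // Rejects configurations a Walsh-Hadamard-based transform cannot handle.
///     fn validate(&self) -> Result<(), String> {
///         if self.dimension == 0 {
///             return Err("dimension must be non-zero".into());
///         }
///         if !self.padded_dimension.is_power_of_two() {
///             return Err(format!(
///                 "padded dimension {} is not a power of two",
///                 self.padded_dimension
///             ));
///         }
///         if self.dimension > self.padded_dimension {
///             return Err("dimension exceeds padded dimension".into());
///         }
///         Ok(())
///     }
/// }
/// ```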
pub