1use crate::codec::range_encoder::RangeEncoder;
2use crate::codec::rc_qs_model::RCQsModel;
3use crate::core::pc_map;
4
/// Widest value precision (in bits) that is still coded with the "narrow"
/// scheme, where every possible residual gets its own model symbol. Values
/// wider than this use the sign/magnitude-class ("wide") scheme instead.
const PC_BIT_MAX: u32 = 8;

/// Number of model symbols a predictive coder over `bits`-bit values can emit.
///
/// * `bits <= PC_BIT_MAX` (narrow): `2 * 2^bits - 1` — one symbol per signed
///   residual in `-(2^bits - 1)..=(2^bits - 1)`.
/// * `bits > PC_BIT_MAX` (wide): `2 * bits + 1` — one exact-prediction symbol
///   plus `bits` magnitude classes on either side of it.
pub fn symbol_count(bits: u32) -> usize {
    let symbols = if bits <= PC_BIT_MAX {
        // `bits` is at most PC_BIT_MAX here, so the shift cannot overflow.
        (2u32 << bits) - 1
    } else {
        2 * bits + 1
    };
    symbols as usize
}
18
19pub struct PCEncoderFloat<'a> {
21 encoder: &'a mut RangeEncoder,
22 model: &'a mut RCQsModel,
23 bits: u32,
24}
25
26impl<'a> PCEncoderFloat<'a> {
27 pub fn new(encoder: &'a mut RangeEncoder, model: &'a mut RCQsModel, bits: u32) -> Self {
28 Self {
29 encoder,
30 model,
31 bits,
32 }
33 }
34
35 #[inline]
37 pub fn encode(&mut self, actual: u32, predicted: u32) -> u32 {
38 let actual = pc_map::mask_u32(actual, self.bits);
39 let predicted = pc_map::mask_u32(predicted, self.bits);
40
41 if self.bits > PC_BIT_MAX {
42 self.encode_wide(actual, predicted);
43 } else {
44 self.encode_narrow(actual, predicted);
45 }
46 actual
47 }
48
49 #[inline]
50 fn encode_wide(&mut self, actual: u32, predicted: u32) {
51 let bias = self.bits;
52 if predicted < actual {
53 let d = actual - predicted;
54 let k = 31 - d.leading_zeros();
55 self.encoder.encode_with_model(bias + 1 + k, self.model);
56 self.encoder.encode_uint(d - (1u32 << k), k as i32);
57 } else if predicted > actual {
58 let d = predicted - actual;
59 let k = 31 - d.leading_zeros();
60 self.encoder.encode_with_model(bias - 1 - k, self.model);
61 self.encoder.encode_uint(d - (1u32 << k), k as i32);
62 } else {
63 self.encoder.encode_with_model(bias, self.model);
64 }
65 }
66
67 #[inline]
68 fn encode_narrow(&mut self, actual: u32, predicted: u32) {
69 let bias = (1u32 << self.bits) - 1;
70 let symbol = bias.wrapping_add(actual).wrapping_sub(predicted);
71 self.encoder.encode_with_model(symbol, self.model);
72 }
73}
74
75pub struct PCEncoderDouble<'a> {
77 encoder: &'a mut RangeEncoder,
78 model: &'a mut RCQsModel,
79 bits: u32,
80}
81
82impl<'a> PCEncoderDouble<'a> {
83 pub fn new(encoder: &'a mut RangeEncoder, model: &'a mut RCQsModel, bits: u32) -> Self {
84 Self {
85 encoder,
86 model,
87 bits,
88 }
89 }
90
91 #[inline]
93 pub fn encode(&mut self, actual: u64, predicted: u64) -> u64 {
94 let actual = pc_map::mask_u64(actual, self.bits);
95 let predicted = pc_map::mask_u64(predicted, self.bits);
96
97 if self.bits > PC_BIT_MAX {
98 self.encode_wide(actual, predicted);
99 } else {
100 self.encode_narrow(actual, predicted);
101 }
102 actual
103 }
104
105 #[inline]
106 fn encode_wide(&mut self, actual: u64, predicted: u64) {
107 let bias = self.bits;
108 if predicted < actual {
109 let d = actual - predicted;
110 let k = 63 - d.leading_zeros();
111 self.encoder.encode_with_model(bias + 1 + k, self.model);
112 self.encoder.encode_ulong(d - (1u64 << k), k as i32);
113 } else if predicted > actual {
114 let d = predicted - actual;
115 let k = 63 - d.leading_zeros();
116 self.encoder.encode_with_model(bias - 1 - k, self.model);
117 self.encoder.encode_ulong(d - (1u64 << k), k as i32);
118 } else {
119 self.encoder.encode_with_model(bias, self.model);
120 }
121 }
122
123 #[inline]
124 fn encode_narrow(&mut self, actual: u64, predicted: u64) {
125 let bias = (1u64 << self.bits) - 1;
126 let symbol = bias.wrapping_add(actual).wrapping_sub(predicted);
127 self.encoder.encode_with_model(symbol as u32, self.model);
128 }
129}
130
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codec::range_decoder::RangeDecoder;
    use crate::core::pc_decoder::{PCDecoderDouble, PCDecoderFloat};

    /// Encodes one `(actual, predicted)` pair with a fresh 32-bit coder and
    /// model, then decodes it and returns the reconstructed value.
    fn round_trip_float(actual: u32, predicted: u32, bits: u32) -> u32 {
        let symbols = symbol_count(bits);
        let mut enc = RangeEncoder::new();
        let mut model = RCQsModel::with_defaults(true, symbols);
        {
            let mut pc = PCEncoderFloat::new(&mut enc, &mut model, bits);
            pc.encode(actual, predicted);
        }
        let data = enc.finish();

        let mut dec = RangeDecoder::new(&data);
        dec.init();
        let mut dmodel = RCQsModel::with_defaults(false, symbols);
        let mut pcd = PCDecoderFloat::new(&mut dec, &mut dmodel, bits);
        pcd.decode(predicted)
    }

    /// 64-bit counterpart of `round_trip_float`.
    fn round_trip_double(actual: u64, predicted: u64, bits: u32) -> u64 {
        let symbols = symbol_count(bits);
        let mut enc = RangeEncoder::new();
        let mut model = RCQsModel::with_defaults(true, symbols);
        {
            let mut pc = PCEncoderDouble::new(&mut enc, &mut model, bits);
            pc.encode(actual, predicted);
        }
        let data = enc.finish();

        let mut dec = RangeDecoder::new(&data);
        dec.init();
        let mut dmodel = RCQsModel::with_defaults(false, symbols);
        let mut pcd = PCDecoderDouble::new(&mut dec, &mut dmodel, bits);
        pcd.decode(predicted)
    }

    #[test]
    fn symbol_count_matches_both_schemes() {
        // Narrow: one symbol per signed residual.
        assert_eq!(symbol_count(0), 1);
        assert_eq!(symbol_count(2), 7);
        assert_eq!(symbol_count(8), 511);
        // Wide: exact-prediction symbol plus `bits` classes on each side.
        assert_eq!(symbol_count(9), 19);
        assert_eq!(symbol_count(32), 65);
        assert_eq!(symbol_count(64), 129);
    }

    #[test]
    fn float_perfect_prediction() {
        assert_eq!(round_trip_float(100, 100, 32), 100);
    }

    #[test]
    fn float_underprediction() {
        assert_eq!(round_trip_float(200, 100, 32), 200);
    }

    #[test]
    fn float_overprediction() {
        assert_eq!(round_trip_float(50, 200, 32), 50);
    }

    #[test]
    fn float_all_delta_sizes() {
        // Every residual class k, including the widest (k = 31).
        for k in 0..u32::BITS {
            let delta = 1u32 << k;

            // Underprediction by exactly 2^k.
            assert_eq!(round_trip_float(delta, 0, 32), delta, "k={k} under");

            // Overprediction by exactly 2^k (prediction above the actual value).
            let below_max = u32::MAX - delta;
            assert_eq!(
                round_trip_float(below_max, u32::MAX, 32),
                below_max,
                "k={k} over"
            );

            // Value that wrapped below zero: a large underprediction residual.
            let wrapped = 0u32.wrapping_sub(delta);
            assert_eq!(
                round_trip_float(wrapped, 0, 32),
                wrapped,
                "k={k} wrapped"
            );
        }
    }

    #[test]
    fn double_perfect_prediction() {
        assert_eq!(round_trip_double(100, 100, 64), 100);
    }

    #[test]
    fn double_underprediction() {
        assert_eq!(round_trip_double(200, 100, 64), 200);
    }

    #[test]
    fn double_overprediction() {
        assert_eq!(round_trip_double(50, 200, 64), 50);
    }

    #[test]
    fn float_sequence() {
        let symbols = symbol_count(32);
        let mut enc = RangeEncoder::new();
        let mut model = RCQsModel::with_defaults(true, symbols);
        let pairs: Vec<(u32, u32)> =
            vec![(100, 100), (200, 100), (50, 200), (0, 0), (0xFFFFFFFF, 0)];
        {
            let mut pc = PCEncoderFloat::new(&mut enc, &mut model, 32);
            for &(a, p) in &pairs {
                pc.encode(a, p);
            }
        }
        let data = enc.finish();

        let mut dec = RangeDecoder::new(&data);
        dec.init();
        let mut dmodel = RCQsModel::with_defaults(false, symbols);
        let mut pcd = PCDecoderFloat::new(&mut dec, &mut dmodel, 32);
        for &(a, p) in &pairs {
            assert_eq!(pcd.decode(p), a);
        }
    }

    #[test]
    fn double_all_delta_sizes() {
        // Every residual class k, including the widest (k = 63).
        for k in 0..u64::BITS {
            let delta = 1u64 << k;

            // Underprediction by exactly 2^k.
            assert_eq!(round_trip_double(delta, 0, 64), delta, "k={k} under");

            // Overprediction by exactly 2^k (prediction above the actual value).
            let below_max = u64::MAX - delta;
            assert_eq!(
                round_trip_double(below_max, u64::MAX, 64),
                below_max,
                "k={k} over"
            );
        }
    }

    #[test]
    fn float_narrow_round_trip() {
        for bits in [2, 4, 8] {
            let mask = (1u32 << bits) - 1;
            for a in 0..=mask.min(15) {
                for p in 0..=mask.min(15) {
                    let result = round_trip_float(a, p, bits);
                    assert_eq!(result, a, "bits={bits} a={a} p={p}");
                }
            }
        }
    }

    #[test]
    fn double_narrow_round_trip() {
        // Exercises PCEncoderDouble::encode_narrow, which is otherwise untested.
        for bits in [2, 4, 8] {
            let mask = (1u64 << bits) - 1;
            for a in 0..=mask.min(15) {
                for p in 0..=mask.min(15) {
                    let result = round_trip_double(a, p, bits);
                    assert_eq!(result, a, "bits={bits} a={a} p={p}");
                }
            }
        }
    }

    #[test]
    fn float_reduced_precision_wide() {
        let mask = 0xFFFFu32;
        assert_eq!(round_trip_float(0, 0, 16), 0);
        assert_eq!(round_trip_float(mask, 0, 16), mask);
        assert_eq!(round_trip_float(100, 200, 16), 100);
        assert_eq!(round_trip_float(200, 100, 16), 200);
    }

    #[test]
    fn double_reduced_precision_wide() {
        let mask = 0xFFFF_FFFFu64;
        assert_eq!(round_trip_double(0, 0, 32), 0);
        assert_eq!(round_trip_double(mask, 0, 32), mask);
        assert_eq!(round_trip_double(0, mask, 32), 0);
        assert_eq!(round_trip_double(100, 200, 32), 100);
        assert_eq!(round_trip_double(200, 100, 32), 200);
    }
}
282}