1use ps_buffer::Buffer;
2
3use crate::{codeword::Codeword, long, DecodeError, EncodeError, ReedSolomon};
4
5pub fn encode(message: &[u8], parity: u8) -> Result<Buffer, EncodeError> {
10 if message.len() + (usize::from(parity) << 1) > 0xff {
11 let segment_length = 0xFF - (parity << 1);
12 let codeword = long::encode(message, parity, segment_length, segment_length)?;
13
14 return Ok(codeword);
15 }
16
17 let rs = ReedSolomon::new(parity)?;
18
19 Ok(rs.encode(message)?)
20}
21
22pub fn decode(received: &[u8], parity: u8) -> Result<Codeword<'_>, DecodeError> {
28 if let Ok(length) = u8::try_from(received.len()) {
29 if parity > length >> 1 {
30 return Err(DecodeError::InsufficientParityBytes(parity, length));
31 }
32
33 let rs = ReedSolomon::new(parity)?;
34
35 Ok(rs.decode(received)?)
36 } else {
37 Ok(long::decode(received)?)
38 }
39}
40
41#[must_use]
42pub fn validate(received: &[u8], parity: u8) -> bool {
44 if let Ok(length) = u8::try_from(received.len()) {
45 if parity > length >> 1 {
46 return false;
47 }
48
49 let Ok(rs) = ReedSolomon::new(parity) else {
50 return false;
51 };
52
53 match rs.validate(received) {
54 Ok(None) => true,
55 Ok(Some(_)) | Err(_) => false,
56 }
57 } else {
58 long::fast_validate(received).unwrap_or_default()
59 }
60}
61
#[cfg(test)]
mod tests {
    use crate::EccError;

    use super::{decode, encode};

    /// Encodes a short UTF-8 message with `parity = 13`, then corrupts one
    /// additional byte per round. The corruptions accumulate because the
    /// codeword is never reset. After each round, decoding must still yield
    /// the original message.
    #[test]
    fn ecc_works() -> Result<(), EccError> {
        const PARITY: u8 = 13;

        let original = "Strč prst skrz krk! ¯\\_(ツ)_/¯".as_bytes();
        let mut codeword = encode(original, PARITY)?;

        for round in 0..usize::from(PARITY) {
            let position = (round * 37) % codeword.len();
            // Low byte of the little-endian representation is the XOR mask.
            let [mask, ..] = (round * position + 13).to_le_bytes();
            codeword[position] ^= mask;

            let recovered = decode(&codeword, PARITY)?;
            assert_eq!(original, &recovered[..]);
        }

        Ok(())
    }
}
84
#[cfg(test)]
mod validate_tests {
    use crate::{
        long, validate, LongEccConstructorError, LongEccDecodeError, LongEccEncodeError,
        LongEccToBytesError, RSEncodeError, ReedSolomon,
    };

    /// Aggregate error so every test can use `?` on both codecs.
    #[derive(thiserror::Error, Debug)]
    enum TestError {
        #[error(transparent)]
        LongEccConstructor(#[from] LongEccConstructorError),
        #[error(transparent)]
        LongEccEncode(#[from] LongEccEncodeError),
        #[error(transparent)]
        LongEccDecode(#[from] LongEccDecodeError),
        #[error(transparent)]
        LongEccToBytes(#[from] LongEccToBytesError),
        #[error(transparent)]
        Buffer(#[from] ps_buffer::BufferError),
        #[error(transparent)]
        RSConstructorError(#[from] crate::RSConstructorError),
        #[error(transparent)]
        RSEncodeError(#[from] RSEncodeError),
    }

    /// Longest input that `validate` checks as a single Reed-Solomon block
    /// (its length must fit in a `u8`). Anything longer is routed to
    /// `long::fast_validate`.
    const MAX_SINGLE_BLOCK_LEN: usize = 0xFF;

    /// Guards the "long" tests.
    ///
    /// A long-format codeword of 255 bytes or fewer would be checked by the
    /// single-block path. The test would then pass or fail for the wrong
    /// reason.
    fn assert_takes_long_path(encoded: &[u8]) {
        assert!(
            encoded.len() > MAX_SINGLE_BLOCK_LEN,
            "{}-byte codeword would not reach long::fast_validate",
            encoded.len()
        );
    }

    #[test]
    fn test_validate_short_data_valid_no_errors() -> Result<(), TestError> {
        let data = b"test";
        let parity = 2;
        let rs = ReedSolomon::new(parity)?;
        let codeword = rs.encode(data)?;

        assert!(validate(&codeword, parity));
        Ok(())
    }

    #[test]
    fn test_validate_short_data_invalid_with_errors() -> Result<(), TestError> {
        let data = b"test";
        let parity = 2;
        let rs = ReedSolomon::new(parity)?;
        let mut codeword = rs.encode(data)?;

        codeword[0] ^= 1;
        codeword[1] ^= 1;
        codeword[2] ^= 1;

        assert!(!validate(&codeword, parity));
        Ok(())
    }

    #[test]
    fn test_validate_short_data_parity_too_large() {
        let data = b"test";
        let parity = 3;

        assert!(!validate(data, parity));
    }

    #[test]
    fn test_validate_short_data_rs_constructor_error() {
        // With a 4-byte input, `parity = 255` is rejected by the
        // `parity > length / 2` guard before `ReedSolomon::new` is reached.
        let data = b"test";
        let parity = 255;

        assert!(!validate(data, parity));
    }

    #[test]
    fn test_validate_short_data_correctable_errors() -> Result<(), TestError> {
        let data = b"test";
        let parity = 2;
        let rs = ReedSolomon::new(parity)?;
        let mut codeword = rs.encode(data)?;

        codeword[0] ^= 1;

        // `validate` reports any detected error, even one `decode` may be
        // able to correct.
        assert!(!validate(&codeword, parity));
        Ok(())
    }

    #[test]
    fn test_validate_long_data_valid_no_errors() -> Result<(), TestError> {
        let message = b"This is a longer message that will use long ECC".repeat(7);
        let parity = 2;
        let segment_length = 20;
        let segment_distance = 16;

        let encoded = long::encode(&message, parity, segment_length, segment_distance)?;
        assert_takes_long_path(&encoded);

        assert!(validate(&encoded, parity));
        Ok(())
    }

    #[test]
    fn test_validate_long_data_invalid_with_errors() -> Result<(), TestError> {
        // Repeated so the codeword exceeds 255 bytes and really hits the
        // long-format validator.
        let message = b"This is a longer message that will use long ECC".repeat(7);
        let parity = 2;
        let segment_length = 20;
        let segment_distance = 16;

        let mut encoded = long::encode(&message, parity, segment_length, segment_distance)?;
        assert_takes_long_path(&encoded);

        encoded[32] ^= 1;
        encoded[37] ^= 1;

        assert!(!validate(&encoded, parity));
        Ok(())
    }

    #[test]
    fn test_validate_long_data_fast_path_valid() -> Result<(), TestError> {
        let message = b"Fast path validation test".repeat(12);
        let parity = 1;
        let segment_length = 15;
        let segment_distance = 12;

        let encoded = long::encode(&message, parity, segment_length, segment_distance)?;
        assert_takes_long_path(&encoded);

        assert!(validate(&encoded, parity));
        Ok(())
    }

    #[test]
    fn test_validate_empty_data() {
        let data = b"";
        let parity = 0;

        assert!(validate(data, parity));
    }

    #[test]
    fn test_validate_single_byte() -> Result<(), TestError> {
        let data = b"A";
        let parity = 1;
        let rs = ReedSolomon::new(parity)?;
        let codeword = rs.encode(data)?;

        assert!(validate(&codeword, parity));
        Ok(())
    }

    #[test]
    fn test_validate_large_short_data() -> Result<(), TestError> {
        let data = b"This is exactly 32 bytes of data";
        assert_eq!(data.len(), 32);
        let parity = 4;
        let rs = ReedSolomon::new(parity)?;
        let codeword = rs.encode(data)?;

        assert!(validate(&codeword, parity));
        Ok(())
    }

    #[test]
    fn test_validate_edge_case_parity_equals_length_div_2() -> Result<(), TestError> {
        let data = b"test";
        let parity = 2;
        let rs = ReedSolomon::new(parity)?;
        let codeword = rs.encode(data)?;

        assert!(validate(&codeword, parity));
        Ok(())
    }

    #[test]
    fn test_validate_edge_case_parity_just_over_length_div_2() {
        let data = b"test";
        let parity = 3;

        assert!(!validate(data, parity));
    }

    #[test]
    fn test_validate_long_data_with_zero_parity() -> Result<(), TestError> {
        // Repeated so the codeword exceeds 255 bytes. A 16-byte message was
        // previously validated by the single-block path instead.
        let message = b"Zero parity test".repeat(20);
        let parity = 0;
        let segment_length = 15;
        let segment_distance = 12;

        let encoded = long::encode(&message, parity, segment_length, segment_distance)?;
        assert_takes_long_path(&encoded);

        assert!(validate(&encoded, parity));
        Ok(())
    }

    #[test]
    fn test_validate_long_data_header_corrupted() -> Result<(), TestError> {
        let message = b"Header corruption test".repeat(15);
        let parity = 2;
        let segment_length = 15;
        let segment_distance = 12;

        let mut encoded = long::encode(&message, parity, segment_length, segment_distance)?;
        assert_takes_long_path(&encoded);

        encoded[0] ^= 1;
        encoded[5] ^= 1;

        assert!(!validate(&encoded, parity));
        Ok(())
    }

    #[test]
    fn test_validate_short_data_length_conversion_error() {
        // 300 bytes do not fit in a `u8`, so this goes to the long path.
        let data: Vec<u8> = vec![0x42; 300];
        let parity = 2;

        assert!(!validate(&data, parity));
    }

    #[test]
    fn test_validate_short_data_unrecoverable_errors() -> Result<(), TestError> {
        let data = b"test data";
        let parity = 1;
        let rs = ReedSolomon::new(parity)?;
        let mut codeword = rs.encode(data)?;

        codeword[0] ^= 1;
        codeword[1] ^= 1;

        assert!(!validate(&codeword, parity));
        Ok(())
    }

    #[test]
    fn test_validate_long_data_corrupted_parity() -> Result<(), TestError> {
        let message = b"Corrupted parity test".repeat(15);
        let parity = 2;
        let segment_length = 15;
        let segment_distance = 12;

        let mut encoded = long::encode(&message, parity, segment_length, segment_distance)?;
        assert_takes_long_path(&encoded);

        // Assumes a 32-byte header followed by the message, making this the
        // first byte after the payload. TODO confirm against the long format.
        let parity_start = 32 + message.len();
        assert!(
            parity_start < encoded.len(),
            "no byte to corrupt at index {parity_start}"
        );
        encoded[parity_start] ^= 1;

        assert!(!validate(&encoded, parity));
        Ok(())
    }
}