ps_ecc/
methods.rs

1use ps_buffer::Buffer;
2
3use crate::{codeword::Codeword, long, DecodeError, EncodeError, ReedSolomon};
4
5/// Encodes a message by adding an error-correcting code.
6/// # Errors
7/// - `RSConstructorError` is returned if `len(message) + 2 * parity` > `255`.
8/// - `RSEncodeError` is returned if encoding fails for any reason.
9pub fn encode(message: &[u8], parity: u8) -> Result<Buffer, EncodeError> {
10    if message.len() + (usize::from(parity) << 1) > 0xff {
11        let segment_length = 0xFF - (parity << 1);
12        let codeword = long::encode(message, parity, segment_length, segment_length)?;
13
14        return Ok(codeword);
15    }
16
17    let rs = ReedSolomon::new(parity)?;
18
19    Ok(rs.encode(message)?)
20}
21
22/// Verifies the error-correcting code and returns the message.
23/// # Errors
24/// - `InputTooLarge` is returned if `len(received)` > 255 bytes.
25/// - `InsufficientParityBytes` is returned if `parity > length / 2`.
26/// - `RSDecodeError` is returned if decoding fails for any reason.
27pub fn decode(received: &[u8], parity: u8) -> Result<Codeword<'_>, DecodeError> {
28    if let Ok(length) = u8::try_from(received.len()) {
29        if parity > length >> 1 {
30            return Err(DecodeError::InsufficientParityBytes(parity, length));
31        }
32
33        let rs = ReedSolomon::new(parity)?;
34
35        Ok(rs.decode(received)?)
36    } else {
37        Ok(long::decode(received)?)
38    }
39}
40
41#[must_use]
42/// Validates that a received codeword isn't corrupted.
43pub fn validate(received: &[u8], parity: u8) -> bool {
44    if let Ok(length) = u8::try_from(received.len()) {
45        if parity > length >> 1 {
46            return false;
47        }
48
49        let Ok(rs) = ReedSolomon::new(parity) else {
50            return false;
51        };
52
53        match rs.validate(received) {
54            Ok(None) => true,
55            Ok(Some(_)) | Err(_) => false,
56        }
57    } else {
58        long::fast_validate(received).unwrap_or_default()
59    }
60}
61
#[cfg(test)]
mod tests {
    use super::{decode, encode};
    use crate::EccError;

    /// Round-trips a multi-byte UTF-8 message through `encode`/`decode`.
    ///
    /// One more byte is corrupted before every decode and the corruption is
    /// never undone. Errors therefore accumulate up to one per round, and
    /// each decode must still recover the original message.
    #[test]
    fn ecc_works() -> Result<(), EccError> {
        const PARITY: u8 = 13;

        let original = "Strč prst skrz krk! ¯\\_(ツ)_/¯".as_bytes();
        let mut codeword = encode(original, PARITY)?;

        for round in 0..usize::from(PARITY) {
            // Stride through the codeword; the XOR mask changes every round.
            let position = (round * 37) % codeword.len();
            codeword[position] ^= (round * position + 13).to_le_bytes()[0];

            let recovered = decode(&codeword, PARITY)?;
            assert_eq!(original, &recovered[..]);
        }

        Ok(())
    }
}
84
#[cfg(test)]
mod validate_tests {
    use crate::{
        long, validate, LongEccConstructorError, LongEccDecodeError, LongEccEncodeError,
        LongEccToBytesError, RSEncodeError, ReedSolomon,
    };

    use ps_buffer::ToBuffer;

    /// Inputs longer than this are routed by `validate` to `long::fast_validate`;
    /// anything not longer is checked as a single Reed-Solomon codeword.
    const MAX_SINGLE_CODEWORD_LEN: usize = 255;

    /// Every error the fixtures below can raise, so each test can use `?`.
    #[derive(thiserror::Error, Debug)]
    enum TestError {
        #[error(transparent)]
        LongEccConstructor(#[from] LongEccConstructorError),
        #[error(transparent)]
        LongEccEncode(#[from] LongEccEncodeError),
        #[error(transparent)]
        LongEccDecode(#[from] LongEccDecodeError),
        #[error(transparent)]
        LongEccToBytes(#[from] LongEccToBytesError),
        #[error(transparent)]
        Buffer(#[from] ps_buffer::BufferError),
        #[error(transparent)]
        RSConstructorError(#[from] crate::RSConstructorError),
        #[error(transparent)]
        RSEncodeError(#[from] RSEncodeError),
    }

    #[test]
    fn test_validate_short_data_valid_no_errors() -> Result<(), TestError> {
        let data = b"test";
        let parity = 2;
        let rs = ReedSolomon::new(parity)?;
        let codeword = rs.encode(data)?;

        assert!(validate(&codeword, parity));
        Ok(())
    }

    #[test]
    fn test_validate_short_data_invalid_with_errors() -> Result<(), TestError> {
        let data = b"test";
        let parity = 2;
        let rs = ReedSolomon::new(parity)?;
        let mut codeword = rs.encode(data)?;

        // Introduce errors
        codeword[0] ^= 1;
        codeword[1] ^= 1;
        codeword[2] ^= 1;

        assert!(!validate(&codeword, parity));
        Ok(())
    }

    #[test]
    fn test_validate_short_data_parity_too_large() {
        let data = b"test"; // 4 bytes
        let parity = 3; // 3 > 4/2 (2) - parity is too large

        // Should return false when parity is too large
        assert!(!validate(data, parity));
    }

    #[test]
    fn test_validate_short_data_rs_constructor_error() {
        let data = b"test";
        let parity = 255;

        // A short input is at most 255 bytes, so `len / 2 <= 127`. Every parity
        // of 128 or more is therefore rejected by the `parity > len / 2` guard
        // before `ReedSolomon::new` is called, and this test pins that guard.
        // The constructor-failure branch is reachable only if `ReedSolomon::new`
        // rejects some parity <= 127.
        assert!(!validate(data, parity));
    }

    #[test]
    fn test_validate_short_data_correctable_errors() -> Result<(), TestError> {
        let data = b"test";
        let parity = 2;
        let rs = ReedSolomon::new(parity)?;
        let mut codeword = rs.encode(data)?;

        // Introduce correctable errors (1 error with parity=2)
        codeword[0] ^= 1;

        // Validation should return false because there are errors (even if correctable)
        assert!(!validate(&codeword, parity));

        Ok(())
    }

    #[test]
    fn test_validate_long_data_valid_no_errors() -> Result<(), TestError> {
        let message = b"This is a longer message that will use long ECC".repeat(7);
        let parity = 2;
        let segment_length = 20;
        let segment_distance = 16;

        let encoded = long::encode(&message, parity, segment_length, segment_distance)?;

        // Guard that this really exercises the long-format path of `validate`.
        assert!(encoded.len() > MAX_SINGLE_CODEWORD_LEN);

        // Valid data should pass validation
        assert!(validate(&encoded, parity));

        Ok(())
    }

    #[test]
    fn test_validate_long_data_invalid_with_errors() -> Result<(), TestError> {
        let message = b"This is a longer message that will use long ECC".to_buffer()?;
        let parity = 2;
        let segment_length = 20;
        let segment_distance = 16;

        let mut encoded = long::encode(&message, parity, segment_length, segment_distance)?;

        // NOTE(review): `validate` only calls `long::fast_validate` for inputs
        // longer than MAX_SINGLE_CODEWORD_LEN. With a 47-byte message this
        // encoding is probably shorter than that. If so, it is checked as a
        // single Reed-Solomon codeword instead, and this test does not cover
        // the long path. Confirm the length, or lengthen the message if the
        // long path is the target.

        // Introduce errors in the message
        encoded[32] ^= 1;
        encoded[37] ^= 1;

        // Invalid data should fail validation
        assert!(!validate(&encoded, parity));
        Ok(())
    }

    #[test]
    fn test_validate_long_data_fast_path_valid() -> Result<(), TestError> {
        let message = b"Fast path validation test".repeat(12);
        let parity = 1;
        let segment_length = 15;
        let segment_distance = 12;

        let encoded = long::encode(&message, parity, segment_length, segment_distance)?;

        // Guard that this really exercises the long-format path of `validate`.
        assert!(encoded.len() > MAX_SINGLE_CODEWORD_LEN);

        // Valid data should pass fast validation
        assert!(validate(&encoded, parity));

        Ok(())
    }

    #[test]
    fn test_validate_empty_data() {
        let data = b"";
        let parity = 0;

        // Empty data should be handled gracefully
        assert!(validate(data, parity));
    }

    #[test]
    fn test_validate_single_byte() -> Result<(), TestError> {
        let data = b"A";
        let parity = 1;
        let rs = ReedSolomon::new(parity)?;
        let codeword = rs.encode(data)?;

        assert!(validate(&codeword, parity));
        Ok(())
    }

    #[test]
    fn test_validate_large_short_data() -> Result<(), TestError> {
        let data = b"A short message that is 39 bytes long!!";
        assert_eq!(data.len(), 39);
        let parity = 4;
        let rs = ReedSolomon::new(parity)?;
        let codeword = rs.encode(data)?;

        assert!(validate(&codeword, parity));
        Ok(())
    }

    #[test]
    fn test_validate_edge_case_parity_equals_length_div_2() -> Result<(), TestError> {
        let data = b"test"; // 4 bytes
        let parity = 2; // 2 == 4/2 - edge case

        let rs = ReedSolomon::new(parity)?;
        let codeword = rs.encode(data)?;

        assert!(validate(&codeword, parity));
        Ok(())
    }

    #[test]
    fn test_validate_edge_case_parity_just_over_length_div_2() {
        // Odd length: `5 >> 1 == 2`, so parity 3 (6 parity bytes in a 5-byte
        // input) must be rejected. A rounding-up bound (`(5 + 1) / 2 == 3`)
        // would wrongly accept it. That makes this case distinct from
        // `test_validate_short_data_parity_too_large`, which uses an even length.
        let data = b"tests"; // 5 bytes
        let parity = 3;

        assert!(!validate(data, parity));
    }

    #[test]
    fn test_validate_long_data_with_zero_parity() -> Result<(), TestError> {
        let message = b"Zero parity test".to_buffer()?;
        let parity = 0;
        let segment_length = 15;
        let segment_distance = 12;

        let encoded = long::encode(&message, parity, segment_length, segment_distance)?;

        // NOTE(review): with a 16-byte message this encoding is probably no
        // longer than MAX_SINGLE_CODEWORD_LEN. If so, `validate` checks it as a
        // single zero-parity Reed-Solomon codeword rather than via
        // `long::fast_validate`. Confirm the intended path.

        // Zero parity should still validate correctly
        assert!(validate(&encoded, parity));
        Ok(())
    }

    #[test]
    fn test_validate_long_data_header_corrupted() -> Result<(), TestError> {
        let message = b"Header corruption test".to_buffer()?;
        let parity = 2;
        let segment_length = 15;
        let segment_distance = 12;

        let mut encoded = long::encode(&message, parity, segment_length, segment_distance)?;

        // NOTE(review): with a 22-byte message this encoding is probably no
        // longer than MAX_SINGLE_CODEWORD_LEN. If so, `validate` checks it as a
        // single Reed-Solomon codeword and never looks at the long-format
        // header. Confirm the length, or lengthen the message.

        // Corrupt header data
        encoded[0] ^= 1;
        encoded[5] ^= 1;

        // Corrupted header should fail validation
        assert!(!validate(&encoded, parity));
        Ok(())
    }

    #[test]
    fn test_validate_short_data_length_conversion_error() {
        let data: Vec<u8> = vec![0x42; 300]; // 300 bytes > 255

        let parity = 2;

        // Too long for `u8::try_from`, so this goes to `long::fast_validate`.
        // The filler bytes are not a long-format encoding.
        assert!(!validate(&data, parity));
    }

    #[test]
    fn test_validate_short_data_unrecoverable_errors() -> Result<(), TestError> {
        let data = b"test data";
        let parity = 1;
        let rs = ReedSolomon::new(parity)?;
        let mut codeword = rs.encode(data)?;

        // Introduce unrecoverable errors (more errors than parity can correct)
        codeword[0] ^= 1;
        codeword[1] ^= 1; // 2 errors with parity=1 - unrecoverable

        assert!(!validate(&codeword, parity));
        Ok(())
    }

    #[test]
    fn test_validate_long_data_corrupted_parity() -> Result<(), TestError> {
        let message = b"Corrupted parity test".to_buffer()?;
        let parity = 2;
        let segment_length = 15;
        let segment_distance = 12;

        let mut encoded = long::encode(&message, parity, segment_length, segment_distance)?;

        // NOTE(review): with a 21-byte message this encoding is probably no
        // longer than MAX_SINGLE_CODEWORD_LEN. If so, `validate` checks it as a
        // single Reed-Solomon codeword rather than via `long::fast_validate`.
        // The offset below also assumes a 32-byte header followed by the
        // message; confirm it against the long format.

        // Corrupt parity bytes
        let parity_start = 32 + message.len();
        if parity_start < encoded.len() {
            encoded[parity_start] ^= 1;
        }

        // Corrupted parity should fail validation
        assert!(!validate(&encoded, parity));
        Ok(())
    }
}
348}