// redoubt_codec_core/collections/helpers.rs

// Copyright (c) 2025-2026 Federico Hoerth <memparanoid@gmail.com>
// SPDX-License-Identifier: GPL-3.0-only
// See LICENSE in the repository root for full license text.

5#[cfg(feature = "zeroize")]
6use core::sync::atomic::{Ordering, compiler_fence};
7#[cfg(feature = "zeroize")]
8use redoubt_zero::FastZeroizable;
9#[cfg(feature = "zeroize")]
10use smallvec::SmallVec;
11
12use crate::codec_buffer::RedoubtCodecBuffer;
13use crate::error::{DecodeError, EncodeError, OverflowError, RedoubtCodecBufferError};
14use crate::traits::{BytesRequired, Decode, DecodeBuffer, DecodeZeroize, Encode, EncodeZeroize};
15use crate::zeroizing::Zeroizing;
16
/// Size in bytes of a collection header.
///
/// A header consists of two `usize` words (the size and the total `bytes_required`, as
/// written by `write_header` and read back by `process_header`), so this is
/// `2 * size_of::<usize>()`, e.g. 16 bytes on a 64-bit target.
///
/// Declared `const` so the value can also be used in constant contexts
/// (array lengths, `const` items).
#[inline]
#[must_use]
pub const fn header_size() -> usize {
    2 * size_of::<usize>()
}
20
21#[inline(always)]
22pub fn write_header(
23    buf: &mut RedoubtCodecBuffer,
24    size: &mut usize,
25    bytes_required: &mut usize,
26) -> Result<(), RedoubtCodecBufferError> {
27    buf.write(size)?;
28    buf.write(bytes_required)?;
29
30    Ok(())
31}
32
33#[inline(always)]
34pub fn process_header(buf: &mut &mut [u8], output_size: &mut usize) -> Result<(), DecodeError> {
35    let header_size = Zeroizing::from(&mut header_size());
36
37    if buf.len() < *header_size {
38        return Err(DecodeError::PreconditionViolated);
39    }
40
41    // Infallible: precondition ensures buf.len() >= header_size (2 * usize)
42    // Error branch kept for panic-free guarantees, cannot be tested
43    buf.read_usize(output_size)?;
44
45    // bytes_required is only used internally for validation
46    let mut bytes_required = Zeroizing::from(&mut 0usize);
47
48    // Infallible: precondition ensures buf.len() >= header_size (2 * usize)
49    // Error branch kept for panic-free guarantees, cannot be tested
50    buf.read_usize(&mut bytes_required)?;
51
52    if *header_size > *bytes_required {
53        return Err(DecodeError::PreconditionViolated);
54    }
55
56    let expected_len = Zeroizing::from(&mut (*bytes_required - *header_size));
57
58    if buf.len() < *expected_len {
59        return Err(DecodeError::PreconditionViolated);
60    }
61
62    Ok(())
63}
64
// =============================================================================
// Derive macro helpers
// =============================================================================
68
69/// Convert a reference to `&dyn BytesRequired`.
70#[inline(always)]
71pub fn to_bytes_required_dyn_ref<T: BytesRequired>(x: &T) -> &dyn BytesRequired {
72    x
73}
74
75/// Convert a mutable reference to `&mut dyn Encode`.
76#[inline(always)]
77pub fn to_encode_dyn_mut<T: Encode>(x: &mut T) -> &mut dyn Encode {
78    x
79}
80
81/// Convert a mutable reference to `&mut dyn Decode`.
82#[inline(always)]
83pub fn to_decode_dyn_mut<T: Decode>(x: &mut T) -> &mut dyn Decode {
84    x
85}
86
87/// Convert a mutable reference to `&mut dyn EncodeZeroize`.
88#[inline(always)]
89pub fn to_encode_zeroize_dyn_mut<T: EncodeZeroize>(x: &mut T) -> &mut dyn EncodeZeroize {
90    x
91}
92
93/// Convert a mutable reference to `&mut dyn DecodeZeroize`.
94#[inline(always)]
95pub fn to_decode_zeroize_dyn_mut<T: DecodeZeroize>(x: &mut T) -> &mut dyn DecodeZeroize {
96    x
97}
98
99/// Sum bytes required from an iterator of `&dyn BytesRequired`.
100#[inline(always)]
101pub fn bytes_required_sum<'a>(
102    iter: impl Iterator<Item = &'a dyn BytesRequired>,
103) -> Result<usize, OverflowError> {
104    let mut total = Zeroizing::from(&mut 0usize);
105
106    for elem in iter {
107        let new_total = Zeroizing::from(&mut total.wrapping_add(elem.encode_bytes_required()?));
108
109        if *new_total < *total {
110            return Err(OverflowError {
111                reason: "bytes_required_sum overflow".into(),
112            });
113        }
114
115        *total = *new_total;
116    }
117
118    Ok(*total)
119}
120
121/// Encode fields from an iterator of `&mut dyn EncodeZeroize`.
122/// On error with zeroize feature, zeroizes all fields and the buffer.
123#[inline(always)]
124pub fn encode_fields<'a>(
125    fields: impl Iterator<Item = &'a mut dyn EncodeZeroize>,
126    buf: &mut RedoubtCodecBuffer,
127) -> Result<(), EncodeError> {
128    let mut result = Ok(());
129
130    for field in fields {
131        #[cfg(feature = "zeroize")]
132        if result.is_err() {
133            field.fast_zeroize();
134            compiler_fence(Ordering::SeqCst);
135            continue;
136        }
137
138        if let Err(e) = field.encode_into(buf) {
139            result = Err(e);
140            #[cfg(feature = "zeroize")]
141            {
142                field.fast_zeroize();
143                compiler_fence(Ordering::SeqCst);
144                buf.fast_zeroize();
145            }
146
147            #[cfg(not(feature = "zeroize"))]
148            break;
149        }
150    }
151
152    result
153}
154
155/// Decode fields from an iterator of `&mut dyn DecodeZeroize`.
156/// On error with zeroize feature, zeroizes all fields and the buffer.
157#[inline(always)]
158pub fn decode_fields<'a>(
159    fields: impl Iterator<Item = &'a mut dyn DecodeZeroize>,
160    buf: &mut &mut [u8],
161) -> Result<(), DecodeError> {
162    #[cfg(feature = "zeroize")]
163    let mut decoded: SmallVec<[&'a mut dyn DecodeZeroize; 32]> = SmallVec::new();
164    let mut result = Ok(());
165
166    for field in fields {
167        #[cfg(feature = "zeroize")]
168        if result.is_err() {
169            field.fast_zeroize();
170            compiler_fence(Ordering::SeqCst);
171            continue;
172        }
173
174        if let Err(e) = field.decode_from(buf) {
175            result = Err(e);
176
177            #[cfg(feature = "zeroize")]
178            {
179                field.fast_zeroize();
180                compiler_fence(Ordering::SeqCst);
181
182                // Zeroize all previously decoded fields
183                for decoded_field in decoded.iter_mut() {
184                    decoded_field.fast_zeroize();
185                    compiler_fence(Ordering::SeqCst);
186                }
187
188                redoubt_util::fast_zeroize_slice(buf);
189            }
190
191            #[cfg(not(feature = "zeroize"))]
192            break;
193        } else {
194            #[cfg(feature = "zeroize")]
195            decoded.push(field);
196        }
197    }
198
199    result
200}