Skip to main content

mlt_core/
decoder.rs

1use std::mem::size_of;
2
3use crate::errors::AsMltError as _;
4use crate::{Layer, MltError, MltResult};
5
/// Memory budget (in bytes) used whenever a budget is created through `Default`:
/// 10 MiB, i.e. 10 × 2^20 bytes.
const DEFAULT_MAX_BYTES: u32 = 10 << 20;
8
9/// Stateful decoder that enforces a per-tile memory budget during decoding.
10///
11/// Pass a `Decoder` to every `raw.decode()` / `into_tile()` call and to
12/// `from_bytes`-style parsers. Each method charges the budget before
13/// performing heap allocations, so the total heap used never exceeds `max_bytes`
14/// (in bytes).
15///
16/// ```
17/// use mlt_core::Decoder;
18///
19/// // Default: 10 MiB budget.
20/// let mut dec = Decoder::default();
21///
22/// // Custom budget.
23/// let mut dec = Decoder::with_max_size(64 * 1024 * 1024);
24/// ```
25#[derive(Debug, Clone, PartialEq, Eq, Default)]
26pub struct Decoder {
27    /// Keep track of the memory used when decoding a tile: raw->parsed transition
28    budget: MemBudget,
29    /// Reusable scratch buffer for the physical u32 decode pass.
30    /// Held here so its heap allocation is reused across streams without extra cost.
31    pub(crate) buffer_u32: Vec<u32>,
32    /// Reusable scratch buffer for the physical u64 decode pass.
33    /// Held here so its heap allocation is reused across streams without extra cost.
34    pub(crate) buffer_u64: Vec<u64>,
35}
36
37impl Decoder {
38    /// Create a decoder with a custom memory budget (in bytes).
39    #[must_use]
40    pub fn with_max_size(max_bytes: u32) -> Self {
41        Self {
42            budget: MemBudget::with_max_size(max_bytes),
43            ..Default::default()
44        }
45    }
46
47    /// Allocate a `Vec<T>` with the given capacity, charging the decoder's budget for
48    /// `capacity * size_of::<T>()` bytes. Use this instead of `Vec::with_capacity` in decode paths.
49    #[inline]
50    pub(crate) fn alloc<T>(&mut self, capacity: usize) -> Result<Vec<T>, MltError> {
51        let bytes = capacity.checked_mul(size_of::<T>()).or_overflow()?;
52        let bytes_u32 = u32::try_from(bytes).or_overflow()?;
53        self.budget.consume(bytes_u32)?;
54        Ok(Vec::with_capacity(capacity))
55    }
56
57    /// Charge the budget for `size` raw bytes. Prefer [`consume_items`][Self::consume_items]
58    /// when charging for a known-type collection.
59    #[inline]
60    pub(crate) fn consume(&mut self, size: u32) -> MltResult<()> {
61        self.budget.consume(size)
62    }
63
64    /// Charge the budget for `count` items of type `T` (`count * size_of::<T>()` bytes).
65    #[inline]
66    pub(crate) fn consume_items<T>(&mut self, count: usize) -> MltResult<()> {
67        let bytes = count.checked_mul(size_of::<T>()).or_overflow()?;
68        self.budget.consume(u32::try_from(bytes).or_overflow()?)
69    }
70
71    #[inline]
72    pub(crate) fn adjust(&mut self, adjustment: u32) {
73        self.budget.adjust(adjustment);
74    }
75
76    /// Assert (in debug builds) that `buf` has not grown beyond `alloc_size`, then adjust the
77    /// budget to return any bytes that were pre-charged but not actually used.
78    ///
79    /// Call this after fully populating a `Vec<T>` that was pre-allocated with [`Decoder::alloc`],
80    /// passing the same `alloc_size` that was given to `alloc`.
81    ///
82    /// - Panics in debug builds if `buf.capacity() > alloc_size` (unexpected reallocation).
83    /// - Subtracts `(alloc_size - buf.len()) * size_of::<T>()` from the budget (the pre-charged
84    ///   bytes that correspond to capacity that was never filled).
85    #[inline]
86    pub(crate) fn adjust_alloc<T>(&mut self, buf: &Vec<T>, alloc_size: usize) {
87        debug_assert!(
88            buf.capacity() <= alloc_size,
89            "Vector reallocated beyond initial allocation size ({alloc_size}); final capacity: {}",
90            buf.capacity()
91        );
92        // Return the unused portion of the pre-charged budget.
93        // alloc_size >= buf.len() is guaranteed by the assert above (capacity >= len always).
94        let unused = (alloc_size - buf.len()) * size_of::<T>();
95        // unused fits in u32: it's at most alloc_size * size_of::<T>(), which was checked to fit
96        // in u32 when alloc() was called. Using saturating_cast to avoid a fallible conversion.
97        #[expect(
98            clippy::cast_possible_truncation,
99            reason = "unused <= alloc_size * size_of::<T>() which was verified to fit in u32 by alloc()"
100        )]
101        self.budget.adjust(unused as u32);
102    }
103
104    #[must_use]
105    pub fn consumed(&self) -> u32 {
106        self.budget.consumed()
107    }
108}
109
110/// Stateful parser that enforces a memory budget during parsing (binary → raw structures).
111///
112/// The parse chain reserves memory before allocations so total heap stays within the limit.
113///
114/// ```
115/// use mlt_core::Parser;
116///
117/// # let bytes: &[u8] = &[];
118/// let mut parser = Parser::default();
119/// let layers = parser.parse_layers(bytes).expect("parse");
120///
121/// // Or with a custom limit:
122/// let mut parser = Parser::with_max_size(64 * 1024 * 1024);
123/// ```
124#[derive(Debug, Clone, PartialEq, Eq, Default)]
125pub struct Parser {
126    budget: MemBudget,
127}
128
129impl Parser {
130    /// Create a parser with a custom memory budget (in bytes).
131    #[must_use]
132    pub fn with_max_size(max_bytes: u32) -> Self {
133        Self {
134            budget: MemBudget::with_max_size(max_bytes),
135        }
136    }
137
138    /// Parse a sequence of binary layers, reserving decoded memory against this parser's budget.
139    pub fn parse_layers<'a>(&mut self, mut input: &'a [u8]) -> Result<Vec<Layer<'a>>, MltError> {
140        let mut result = Vec::new();
141        while !input.is_empty() {
142            let layer;
143            (input, layer) = Layer::from_bytes(input, self)?;
144            result.push(layer);
145        }
146        Ok(result)
147    }
148
149    /// Reserve `size` bytes from the parse budget. Used internally by the parse chain.
150    #[inline]
151    pub(crate) fn reserve(&mut self, size: u32) -> MltResult<()> {
152        self.budget.consume(size)
153    }
154
155    #[must_use]
156    pub fn reserved(&self) -> u32 {
157        self.budget.consumed()
158    }
159}
160
/// Byte counter with a hard upper limit, shared by [`Decoder`] and [`Parser`].
#[derive(Debug, Clone)]
#[derive(PartialEq, Eq)]
struct MemBudget {
    /// Upper bound on `bytes_used`; a charge that would go past it is rejected.
    pub max_bytes: u32,
    /// Bytes charged so far, net of refunds.
    pub bytes_used: u32,
}
168
169impl Default for MemBudget {
170    /// Create a decoder with the default 10 MiB memory budget.
171    fn default() -> Self {
172        Self::with_max_size(DEFAULT_MAX_BYTES)
173    }
174}
175
176impl MemBudget {
177    /// Create a decoder with a custom memory budget (in bytes).
178    #[must_use]
179    fn with_max_size(max_bytes: u32) -> Self {
180        Self {
181            max_bytes,
182            bytes_used: 0,
183        }
184    }
185
186    /// Adjust previous consumption by `- adjustment` bytes.  Will panic if used incorrectly.
187    #[inline]
188    fn adjust(&mut self, adjustment: u32) {
189        self.bytes_used = self.bytes_used.checked_sub(adjustment).unwrap();
190    }
191
192    /// Take `size` bytes from the allocation budget. Call this before the actual allocation.
193    #[inline]
194    fn consume(&mut self, size: u32) -> MltResult<()> {
195        let accumulator = &mut self.bytes_used;
196        let max_bytes = self.max_bytes;
197        if let Some(new_value) = accumulator
198            .checked_add(size)
199            .and_then(|v| if v > max_bytes { None } else { Some(v) })
200        {
201            *accumulator = new_value;
202            Ok(())
203        } else {
204            Err(MltError::MemoryLimitExceeded {
205                limit: max_bytes,
206                used: *accumulator,
207                requested: size,
208            })
209        }
210    }
211
212    fn consumed(&self) -> u32 {
213        self.bytes_used
214    }
215}
216
/// Debug-only check that `buffer` holds exactly `expected_len` elements.
///
/// Does nothing unless debug assertions are enabled. `#[track_caller]` makes a failure
/// report the caller's source location instead of this helper's, which is where the
/// actual mistake is.
///
/// # Panics
///
/// With debug assertions enabled, panics if `buffer.len() != expected_len`.
#[inline]
#[track_caller]
pub fn debug_assert_length<T>(buffer: &[T], expected_len: usize) {
    debug_assert_eq!(
        buffer.len(),
        expected_len,
        "Expected buffer to have exact length"
    );
}