1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
use crate::{engine::bytecode::Register, Error};
use core::mem;
use std::{
    collections::{btree_map, BTreeMap},
    vec::Vec,
};

#[cfg(doc)]
use super::ProviderStack;

/// The index of a `local.get` on the [`ProviderStack`].
///
/// Stored in occupied entries and handed back by [`LocalRefs::pop_at`],
/// [`LocalRefs::drain_at`] and [`LocalRefs::drain_all`].
pub type StackIndex = usize;

/// The index of an entry in the [`LocalRefs`] data structure.
///
/// Entries are linked into per-local lists and into a free list of vacant
/// entries through `Option<EntryIndex>` links.
type EntryIndex = usize;

/// Data structure to store local indices on the compilation stack for large stacks.
///
/// # Note
///
/// - The main purpose is to provide efficient implementations to preserve locals on
///   the compilation stack for single local and mass local preservation.
/// - Also this data structure is critical to not be attackable by malicious actors
///   when operating on very large stacks or local variable quantities.
/// - For every local variable the pushed stack indices form a singly linked list
///   threaded through the shared entry storage: the list head is recorded in
///   `locals_last` and each entry links to the previously pushed entry of the
///   same local variable.
#[derive(Debug, Default)]
pub struct LocalRefs {
    /// The index of the most recently pushed entry per local variable.
    ///
    /// A local variable without any pushed `local.get` has no key in this map.
    locals_last: BTreeMap<Register, EntryIndex>,
    /// The entries of the [`LocalRefs`] data structure.
    entries: LocalRefsEntries,
}

/// The entries of the [`LocalRefs`] data structure.
///
/// # Note
///
/// This type mostly exists to gracefully resolve some borrow-checking issues
/// when operating on parts of the fields of the [`LocalRefs`] while `locals_last`
/// is borrowed.
///
/// Vacant entries are chained into a free list (headed by `next_free`) so that
/// they can be reused before the `entries` vector has to grow.
#[derive(Debug, Default)]
pub struct LocalRefsEntries {
    /// The index of the next free (vacant) entry, i.e. the head of the free list.
    next_free: Option<EntryIndex>,
    /// All entries of the [`LocalRefs`] data structure, both occupied and vacant.
    entries: Vec<LocalRefEntry>,
}

impl LocalRefsEntries {
    /// Resets the [`LocalRefs`].
    pub fn reset(&mut self) {
        self.next_free = None;
        self.entries.clear();
    }

    /// Returns the next free [`EntryIndex`] for reusing vacant entries.
    #[inline]
    pub fn next_free(&self) -> Option<EntryIndex> {
        self.next_free
    }

    /// Returns the next [`EntryIndex`] for the next new non-reused entry.
    #[inline]
    pub fn next_index(&self) -> EntryIndex {
        self.entries.len()
    }

    /// Pushes an occupied entry to the [`LocalRefsEntries`].
    #[inline]
    pub fn push_occupied(&mut self, slot: StackIndex, prev: Option<EntryIndex>) -> EntryIndex {
        let index = self.next_index();
        self.entries.push(LocalRefEntry::Occupied { slot, prev });
        index
    }

    /// Reuses the vacant entry at `index` for a new occupied entry.
    ///
    /// # Panics
    ///
    /// If the entry at `index` is not vacant.
    #[inline]
    pub fn reuse_vacant(&mut self, index: EntryIndex, slot: StackIndex, prev: Option<EntryIndex>) {
        let old_entry = mem::replace(
            &mut self.entries[index],
            LocalRefEntry::Occupied { slot, prev },
        );
        self.next_free = match old_entry {
            LocalRefEntry::Vacant { next_free } => next_free,
            occupied @ LocalRefEntry::Occupied { .. } => {
                panic!("tried to reuse occupied entry at index {index}: {occupied:?}")
            }
        };
    }

    /// Removes the entry at the given `index`.
    ///
    /// Returns the entry index of the next entry in the list and the
    /// [`StackIndex`] associated to the removed entry.
    #[inline]
    fn remove_entry(&mut self, index: EntryIndex) -> (Option<EntryIndex>, StackIndex) {
        let next_free = self.next_free();
        let old_entry = mem::replace(
            &mut self.entries[index],
            LocalRefEntry::Vacant { next_free },
        );
        let LocalRefEntry::Occupied { prev, slot } = old_entry else {
            panic!("expected occupied entry but found vacant: {old_entry:?}");
        };
        self.next_free = Some(index);
        (prev, slot)
    }
}

/// An entry representing a local variable on the compilation stack or a vacant entry.
#[derive(Debug, Copy, Clone)]
enum LocalRefEntry {
    /// An unused entry that is part of the free list.
    Vacant {
        /// The next free slot of the [`LocalRefs`] data structure.
        next_free: Option<EntryIndex>,
    },
    /// An entry holding the stack index of a `local.get`.
    Occupied {
        /// The slot index of the local variable on the compilation stack.
        slot: StackIndex,
        /// The previously pushed [`LocalRefEntry`] referencing the same local variable if any.
        prev: Option<EntryIndex>,
    },
}

impl LocalRefs {
    /// Resets the [`LocalRefs`] to its empty state.
    pub fn reset(&mut self) {
        self.locals_last.clear();
        self.entries.reset();
    }

    /// Registers an `amount` of function inputs or local variables.
    ///
    /// # Note
    ///
    /// This is a no-op: per-local bookkeeping lives in a [`BTreeMap`] that
    /// [`LocalRefs::push_at`] populates on demand, so nothing has to be reserved up front.
    pub fn register_locals(&mut self, _amount: u32) {
        // Nothing to do here.
    }

    /// Makes `index` the most recently pushed entry of `local` and returns the
    /// previously most recent entry of `local`, if any.
    fn update_last(&mut self, index: EntryIndex, local: Register) -> Option<EntryIndex> {
        // `BTreeMap::insert` hands back the value previously stored for `local`.
        self.locals_last.insert(local, index)
    }

    /// Pushes the stack index of a `local.get` on the [`ProviderStack`].
    ///
    /// The new entry becomes the most recent entry of `local` and links to the
    /// previously most recent one. A vacant entry is reused if available,
    /// otherwise the entry storage grows by one.
    pub fn push_at(&mut self, local: Register, slot: StackIndex) {
        match self.entries.next_free() {
            Some(index) => {
                let prev = self.update_last(index, local);
                self.entries.reuse_vacant(index, slot, prev);
            }
            None => {
                let index = self.entries.next_index();
                let prev = self.update_last(index, local);
                let pushed = self.entries.push_occupied(slot, prev);
                debug_assert_eq!(pushed, index);
            }
        }
    }

    /// Returns `true` if no local variable has a pushed stack index.
    #[inline]
    fn is_empty(&self) -> bool {
        self.locals_last.is_empty()
    }

    /// Resets the entry storage if no local variable has a pushed stack index.
    ///
    /// With `locals_last` empty no entry is reachable anymore, so all of them
    /// can be discarded at once.
    #[inline]
    fn reset_if_empty(&mut self) {
        if self.is_empty() {
            self.entries.reset();
        }
    }

    /// Pops the most recently pushed stack index of a `local.get` of `local`
    /// on the [`ProviderStack`].
    ///
    /// # Panics
    ///
    /// If there is no `local.get` stack index for `local` on the stack.
    pub fn pop_at(&mut self, local: Register) -> StackIndex {
        let btree_map::Entry::Occupied(mut last) = self.locals_last.entry(local) else {
            panic!("missing stack index for local on the provider stack: {local:?}")
        };
        let index = *last.get();
        let (prev, slot) = self.entries.remove_entry(index);
        match prev {
            Some(prev) => {
                last.insert(prev);
            }
            None => {
                last.remove();
            }
        }
        self.reset_if_empty();
        slot
    }

    /// Drains all local indices of the `local` variable on the [`ProviderStack`].
    ///
    /// # Note
    ///
    /// Calls `f` with the index of each local on the [`ProviderStack`] that matches `local`,
    /// from the most recently pushed to the oldest. Does nothing if `local` has no entries.
    ///
    /// # Errors
    ///
    /// Returns the first error returned by `f`. `local` is unregistered even in that case.
    pub fn drain_at(
        &mut self,
        local: Register,
        f: impl FnMut(StackIndex) -> Result<(), Error>,
    ) -> Result<(), Error> {
        let Some(last) = self.locals_last.remove(&local) else {
            return Ok(());
        };
        let result = self.drain_list_at(last, f);
        // Also run on the error path: entries left behind by an aborted drain are unreachable.
        self.reset_if_empty();
        result
    }

    /// Drains all local indices on the [`ProviderStack`].
    ///
    /// # Note
    ///
    /// Calls `f` with the pair of local and its index of each local on the [`ProviderStack`],
    /// visiting locals in ascending [`Register`] order and, per local, from the most
    /// recently pushed to the oldest index.
    ///
    /// # Errors
    ///
    /// Returns the first error returned by `f`. `self` is left empty in either case.
    pub fn drain_all(
        &mut self,
        mut f: impl FnMut(Register, StackIndex) -> Result<(), Error>,
    ) -> Result<(), Error> {
        let locals_last = mem::take(&mut self.locals_last);
        let result = locals_last
            .into_iter()
            .try_for_each(|(local, last)| self.drain_list_at(last, |slot| f(local, slot)));
        // Every list became unreachable once `locals_last` was taken, so the entries can be
        // discarded wholesale, including when `f` failed part way through.
        self.entries.reset();
        result
    }

    /// Drains the list of entries starting at `index`, calling `f` with the
    /// [`StackIndex`] of each entry.
    ///
    /// The list is walked along the `prev` links, i.e. from the most recently pushed
    /// to the oldest entry. Each entry is vacated before `f` is called for it.
    ///
    /// # Errors
    ///
    /// Returns the first error returned by `f`; the rest of the list is then left unvisited.
    #[inline]
    fn drain_list_at(
        &mut self,
        index: EntryIndex,
        mut f: impl FnMut(StackIndex) -> Result<(), Error>,
    ) -> Result<(), Error> {
        let mut cursor = Some(index);
        while let Some(current) = cursor {
            let (prev, slot) = self.entries.remove_entry(current);
            cursor = prev;
            f(slot)?;
        }
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand to construct a [`Register`] from its raw index.
    fn reg(index: i16) -> Register {
        Register::from_i16(index)
    }

    #[test]
    fn push_pop_works() {
        let mut locals = LocalRefs::default();
        locals.push_at(reg(0), 2);
        locals.push_at(reg(0), 4);
        locals.push_at(reg(1), 6);
        locals.push_at(reg(2), 8);
        locals.push_at(reg(5), 10);
        locals.push_at(reg(1), 12);
        locals.push_at(reg(0), 14);
        assert_eq!(locals.pop_at(reg(0)), 14);
        assert_eq!(locals.pop_at(reg(0)), 4);
        assert_eq!(locals.pop_at(reg(0)), 2);
        assert_eq!(locals.pop_at(reg(1)), 12);
        assert_eq!(locals.pop_at(reg(1)), 6);
        assert_eq!(locals.pop_at(reg(2)), 8);
        assert_eq!(locals.pop_at(reg(5)), 10);
        assert!(locals.is_empty());
    }

    #[test]
    fn vacant_entries_are_reused() {
        let mut locals = LocalRefs::default();
        locals.push_at(reg(0), 1);
        locals.push_at(reg(1), 2);
        assert_eq!(locals.pop_at(reg(0)), 1);
        // The entry vacated by the pop above is reused instead of growing the storage.
        locals.push_at(reg(1), 3);
        assert_eq!(locals.entries.next_index(), 2);
        assert_eq!(locals.pop_at(reg(1)), 3);
        assert_eq!(locals.pop_at(reg(1)), 2);
        assert!(locals.is_empty());
    }

    #[test]
    fn drain_at_works() {
        let mut locals = LocalRefs::default();
        locals.push_at(reg(0), 1);
        locals.push_at(reg(1), 2);
        locals.push_at(reg(0), 3);
        let mut drained = Vec::new();
        let result = locals.drain_at(reg(0), |slot| {
            drained.push(slot);
            Ok(())
        });
        assert!(result.is_ok());
        // Stack indices of a local are drained from the most recently pushed one.
        assert_eq!(drained, [3, 1]);
        // Draining a local without entries never calls the closure.
        assert!(locals
            .drain_at(reg(0), |_| panic!("unexpected call"))
            .is_ok());
        assert_eq!(locals.pop_at(reg(1)), 2);
        assert!(locals.is_empty());
    }

    #[test]
    fn drain_all_works() {
        let mut locals = LocalRefs::default();
        locals.push_at(reg(2), 10);
        locals.push_at(reg(1), 11);
        locals.push_at(reg(2), 12);
        let mut drained = Vec::new();
        let result = locals.drain_all(|local, slot| {
            drained.push((local, slot));
            Ok(())
        });
        assert!(result.is_ok());
        let slots_of = |local: Register| -> Vec<StackIndex> {
            drained
                .iter()
                .filter(|(l, _)| *l == local)
                .map(|(_, slot)| *slot)
                .collect()
        };
        assert_eq!(drained.len(), 3);
        assert_eq!(slots_of(reg(2)), [12, 10]);
        assert_eq!(slots_of(reg(1)), [11]);
        assert!(locals.is_empty());
        assert_eq!(locals.entries.next_index(), 0);
    }
}