throne/state.rs

use crate::matching::phrase_equal;
use crate::string_cache::{Atom, StringCache};
use crate::token::*;

use rand::{rngs::SmallRng, seq::SliceRandom};

use std::ops::Range;

/// References a [Phrase] in a [State].
#[derive(Clone, Copy, Eq, PartialEq, Debug)]
pub struct PhraseId {
    idx: usize,
    rev: usize,
}

/// Stores a set of [Phrases](Phrase).
#[derive(Clone, Debug)]
pub struct State {
    storage: Storage,
    match_cache: MatchCache,
    scratch_state: Option<ScratchState>,
}

impl State {
    pub(crate) fn new() -> State {
        State {
            storage: Storage::new(),
            match_cache: MatchCache::new(),
            scratch_state: None,
        }
    }

    /// Removes a phrase by [PhraseId].
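    ///
    /// A minimal usage sketch (assuming `state` is a `State` and `tokens` is a
    /// `Vec<Token>` built elsewhere in the crate):
    ///
    /// ```ignore
    /// let id = state.push(tokens);
    /// state.remove(id);
    /// ```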
    pub fn remove(&mut self, id: PhraseId) {
        assert!(id.rev == self.storage.rev);
        self.remove_idx(id.idx);
    }

    pub(crate) fn remove_idx(&mut self, idx: usize) {
        assert!(!self.is_locked());

        let remove_phrase = self.storage.phrase_ranges.swap_remove(idx);
        self.storage
            .removed_phrase_ranges
            .push(remove_phrase.token_range);

        self.storage.rev += 1;
    }

    /// Removes the first occurrence of `phrase` from the state.
    ///
    /// Returns `false` if the phrase could not be found.
    pub fn remove_phrase(&mut self, phrase: &Phrase) -> bool {
        let remove_idx =
            self.storage
                .phrase_ranges
                .iter()
                .position(|PhraseMetadata { token_range, .. }| {
                    phrase_equal(
                        &self.storage.tokens[token_range.clone()],
                        phrase,
                        (0, 0),
                        (0, 0),
                    )
                });

        if let Some(remove_idx) = remove_idx {
            self.remove_idx(remove_idx);
            true
        } else {
            false
        }
    }

    /// Removes any phrases matching the provided `pattern`.
    ///
    /// If `match_pattern_length` is `true`, only phrases whose length exactly matches the provided
    /// `pattern` will be removed. Otherwise, phrases longer than the provided `pattern` may also be
    /// removed, as long as their leading tokens match the pattern.
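    ///
    /// A rough usage sketch (assuming `atom` is an [Atom] previously interned through the
    /// crate's [StringCache]):
    ///
    /// ```ignore
    /// // Remove phrases whose first token is `atom` and whose second token can be anything,
    /// // including phrases with further trailing tokens.
    /// state.remove_pattern([Some(atom), None], false);
    /// ```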
    pub fn remove_pattern<const N: usize>(
        &mut self,
        pattern: [Option<Atom>; N],
        match_pattern_length: bool,
    ) {
        assert!(!self.is_locked());

        let tokens = &mut self.storage.tokens;
        let removed_phrase_ranges = &mut self.storage.removed_phrase_ranges;
        let mut did_remove_tokens = false;

        self.storage
            .phrase_ranges
            .retain(|PhraseMetadata { token_range, .. }| {
                let phrase = &tokens[token_range.clone()];
                if !test_phrase_pattern_match(phrase, pattern, match_pattern_length) {
                    return true;
                }

                removed_phrase_ranges.push(token_range.clone());
                did_remove_tokens = true;

                false
            });

        if did_remove_tokens {
            self.storage.rev += 1;
        }
    }

    pub(crate) fn clear_removed_tokens(&mut self) {
        // Process removed ranges from the back of the token collection towards the front,
        // so that draining one range does not invalidate the ranges still to be drained.
        self.storage
            .removed_phrase_ranges
            .sort_unstable_by_key(|range| std::cmp::Reverse(range.start));
        for remove_range in self.storage.removed_phrase_ranges.drain(..) {
            let remove_len = remove_range.end - remove_range.start;
            self.storage
                .tokens
                .drain(remove_range.start..remove_range.end);
            // Shift the ranges of all phrases stored after the drained tokens.
            for PhraseMetadata { token_range, .. } in self.storage.phrase_ranges.iter_mut() {
                if token_range.start >= remove_range.end {
                    token_range.start -= remove_len;
                    token_range.end -= remove_len;
                }
            }
        }
    }

    pub(crate) fn update_cache(&mut self) {
        self.match_cache.update_storage(&self.storage);
    }

    pub(crate) fn match_cached_state_indices_for_rule_input(
        &self,
        input_phrase: &Phrase,
        input_phrase_group_count: usize,
    ) -> &[usize] {
        assert!(self.match_cache.storage_rev == self.storage.rev);
        debug_assert_eq!(input_phrase.groups().count(), input_phrase_group_count);
        self.match_cache
            .match_rule_input(input_phrase, input_phrase_group_count)
    }

    pub(crate) fn shuffle(&mut self, rng: &mut SmallRng) {
        assert!(self.scratch_state.is_none());
        self.storage.phrase_ranges.shuffle(rng);
        self.storage.rev += 1;
    }

    /// Adds a new `phrase` to the `State` and returns a [PhraseId] referencing the newly added phrase.
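    ///
    /// A rough usage sketch (assuming `tokens` is a `Vec<Token>` built elsewhere in the crate):
    ///
    /// ```ignore
    /// let len_before = state.len();
    /// let id = state.push(tokens);
    /// assert_eq!(state.len(), len_before + 1);
    /// let phrase = state.get(id); // the tokens that were just pushed
    /// ```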
    pub fn push(&mut self, phrase: Vec<Token>) -> PhraseId {
        let group_count = phrase.groups().count();
        self.push_with_metadata(phrase, group_count)
    }

    pub(crate) fn push_with_metadata(
        &mut self,
        mut phrase: Vec<Token>,
        group_count: usize,
    ) -> PhraseId {
        // Record the first atom when the phrase is a concrete predicate whose first group
        // is a single token; the match cache uses it to index phrases by their first atom.
        let first_group_is_single_token = phrase[0].open_depth == 1;
        let first_atom = if first_group_is_single_token && is_concrete_pred(&phrase) {
            Some(phrase[0].atom)
        } else {
            None
        };

        let start = self.storage.tokens.len();
        self.storage.tokens.append(&mut phrase);
        let end = self.storage.tokens.len();

        self.storage.phrase_ranges.push(PhraseMetadata {
            token_range: Range { start, end },
            first_atom,
            group_count,
        });
        self.storage.rev += 1;

        PhraseId {
            idx: self.storage.phrase_ranges.len() - 1,
            rev: self.storage.rev,
        }
    }

    /// Returns the number of phrases in the `State`.
    pub fn len(&self) -> usize {
        self.storage.phrase_ranges.len()
    }

    /// Returns an iterator of references to phrases in the `State`.
    pub fn iter(&self) -> impl Iterator<Item = PhraseId> + '_ {
        self.storage.iter()
    }

    /// Returns the [Phrase] referenced by the provided [PhraseId].
    pub fn get(&self, id: PhraseId) -> &Phrase {
        self.storage.get(id)
    }

    /// Constructs and returns a [Vec] of all phrases in the `State`.
    pub fn get_all(&self) -> Vec<Vec<Token>> {
        self.storage
            .phrase_ranges
            .iter()
            .map(|PhraseMetadata { token_range, .. }| {
                self.storage.tokens[token_range.clone()].to_vec()
            })
            .collect::<Vec<_>>()
    }

    /// Returns an iterator of references to phrases matching the provided `pattern`.
    ///
    /// If `match_pattern_length` is `true`, only phrases whose length exactly matches the provided
    /// `pattern` will be returned. Otherwise, phrases longer than the provided `pattern` may also be
    /// returned, as long as their leading tokens match the pattern.
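    ///
    /// A rough usage sketch (assuming `atom` is an [Atom] previously interned through the
    /// crate's [StringCache]):
    ///
    /// ```ignore
    /// // Collect phrases whose first token is `atom` and whose second token can be anything,
    /// // including phrases with further trailing tokens.
    /// let matches: Vec<PhraseId> = state.iter_pattern([Some(atom), None], false).collect();
    /// ```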
    pub fn iter_pattern<const N: usize>(
        &self,
        pattern: [Option<Atom>; N],
        match_pattern_length: bool,
    ) -> impl Iterator<Item = PhraseId> + '_ {
        self.iter().filter(move |phrase_id| {
            test_phrase_pattern_match(self.get(*phrase_id), pattern, match_pattern_length)
        })
    }

    #[cfg(test)]
    pub(crate) fn from_phrases(phrases: &[Vec<Token>]) -> State {
        let mut state = State::new();
        for p in phrases {
            state.push(p.clone());
        }
        state
    }

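    // Scratch locking: `lock_scratch` records the current storage extents, `reset_scratch`
    // truncates any phrases and tokens pushed since the lock, and `unlock_scratch` resets and
    // then clears the lock. While a lock is held, destructive operations such as `remove_idx`
    // and `remove_pattern` are rejected (see the `is_locked` assertions above).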
    pub(crate) fn lock_scratch(&mut self) {
        self.scratch_state = Some(ScratchState {
            storage_phrase_ranges_len: self.storage.phrase_ranges.len(),
            storage_tokens_len: self.storage.tokens.len(),
            storage_rev: self.storage.rev,
        });
    }

    pub(crate) fn unlock_scratch(&mut self) {
        self.reset_scratch();
        self.scratch_state = None;
    }

    pub(crate) fn reset_scratch(&mut self) {
        let ScratchState {
            storage_phrase_ranges_len,
            storage_tokens_len,
            storage_rev,
            ..
        } = self.scratch_state.as_ref().expect("scratch_state");
        self.storage
            .phrase_ranges
            .drain(storage_phrase_ranges_len..);
        self.storage.tokens.drain(storage_tokens_len..);
        self.storage.rev = *storage_rev;
    }

    fn is_locked(&self) -> bool {
        self.scratch_state.is_some()
    }
}

impl std::ops::Index<usize> for State {
    type Output = [Token];

    fn index(&self, i: usize) -> &Phrase {
        self.storage.get_by_metadata(&self.storage.phrase_ranges[i])
    }
}

impl std::fmt::Display for State {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{:?}", self.get_all())
    }
}

#[derive(Clone, Debug)]
struct ScratchState {
    storage_phrase_ranges_len: usize,
    storage_tokens_len: usize,
    storage_rev: usize,
}

#[derive(Clone, Debug)]
struct Storage {
    // indexes into token collection
    phrase_ranges: Vec<PhraseMetadata>,
    removed_phrase_ranges: Vec<Range<usize>>,

    // collection of all tokens found in the state phrases
    tokens: Vec<Token>,

    // increments on mutation
    rev: usize,
}

impl Storage {
    fn new() -> Self {
        Storage {
            phrase_ranges: vec![],
            removed_phrase_ranges: vec![],
            tokens: vec![],
            rev: 0,
        }
    }

    fn iter<'a>(&'a self) -> impl Iterator<Item = PhraseId> + 'a {
        let rev = self.rev;
        self.phrase_ranges
            .iter()
            .enumerate()
            .map(move |(idx, _)| PhraseId { idx, rev })
    }

    fn get(&self, id: PhraseId) -> &Phrase {
        assert!(id.rev == self.rev);
        self.get_by_metadata(&self.phrase_ranges[id.idx])
    }

    fn get_by_metadata(&self, metadata: &PhraseMetadata) -> &Phrase {
        &self.tokens[metadata.token_range.clone()]
    }
}

#[derive(Clone, Debug)]
struct PhraseMetadata {
    // range of the phrase's tokens within `Storage::tokens`
    token_range: Range<usize>,
    // first atom of the phrase, if it is a concrete predicate starting with a single-token group
    first_atom: Option<Atom>,
    // result of `phrase.groups().count()`, cached for the match cache
    group_count: usize,
}

#[derive(Clone, Debug)]
struct MatchCache {
    // (first atom, state index) pairs, sorted by atom for binary search
    first_atom_pairs: Vec<(Atom, usize)>,
    // state indices in the same order as `first_atom_pairs`
    first_atom_indices: Vec<usize>,
    // state indices bucketed by phrase group count
    state_indices_by_length: Vec<Vec<usize>>,
    // revision of the `Storage` that this cache was built from
    storage_rev: usize,
}

impl MatchCache {
    fn new() -> Self {
        MatchCache {
            first_atom_pairs: vec![],
            first_atom_indices: vec![],
            state_indices_by_length: vec![],
            storage_rev: 0,
        }
    }

    fn clear(&mut self) {
        self.first_atom_pairs.clear();
        self.first_atom_indices.clear();
        self.state_indices_by_length.clear();
    }

    fn update_storage(&mut self, storage: &Storage) {
        if self.storage_rev == storage.rev {
            return;
        }
        self.storage_rev = storage.rev;

        self.clear();
        for (s_i, phrase_metadata) in storage.phrase_ranges.iter().enumerate() {
            if let Some(first_atom) = phrase_metadata.first_atom {
                self.first_atom_pairs.push((first_atom, s_i));
            }
            if self.state_indices_by_length.len() < phrase_metadata.group_count + 1 {
                self.state_indices_by_length
                    .resize(phrase_metadata.group_count + 1, vec![]);
            }
            self.state_indices_by_length[phrase_metadata.group_count].push(s_i);
        }
        self.first_atom_pairs.sort_unstable_by(|a, b| a.0.cmp(&b.0));
        for (_, s_i) in &self.first_atom_pairs {
            self.first_atom_indices.push(*s_i);
        }
    }

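    // Returns the indices of state phrases that could potentially match the given rule input.
    // Inputs led by a single-token concrete predicate are narrowed down by their first atom
    // via a binary search over `first_atom_pairs`; all other inputs fall back to the bucket
    // of phrases with a matching group count.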
    fn match_rule_input(&self, input_phrase: &Phrase, input_phrase_group_count: usize) -> &[usize] {
        let first_group_is_single_token = input_phrase[0].open_depth == 1;
        if first_group_is_single_token && is_concrete_pred(input_phrase) {
            let input_first_atom = input_phrase[0].atom;
            if let Ok(idx) = self
                .first_atom_pairs
                .binary_search_by(|(atom, _)| atom.cmp(&input_first_atom))
            {
                // binary search won't always find the first match,
                // so search backwards until we find it
                let start_idx = self
                    .first_atom_pairs
                    .iter()
                    .enumerate()
                    .rev()
                    .skip(self.first_atom_pairs.len() - 1 - idx)
                    .take_while(|(_, (atom, _))| *atom == input_first_atom)
                    .last()
                    .expect("start idx")
                    .0;
                let end_idx = self
                    .first_atom_pairs
                    .iter()
                    .enumerate()
                    .skip(idx)
                    .take_while(|(_, (atom, _))| *atom == input_first_atom)
                    .last()
                    .expect("end idx")
                    .0;
                return &self.first_atom_indices[start_idx..end_idx + 1];
            } else {
                return &[];
            };
        }

        if let Some(v) = &self.state_indices_by_length.get(input_phrase_group_count) {
            v
        } else {
            &[]
        }
    }
}

pub(crate) fn state_to_string(state: &State, string_cache: &StringCache) -> String {
    state
        .iter()
        .map(|phrase_id| state.get(phrase_id).to_string(string_cache))
        .collect::<Vec<_>>()
        .join("\n")
}