1use crate::matching::phrase_equal;
2use crate::string_cache::{Atom, StringCache};
3use crate::token::*;
4
5use rand::{rngs::SmallRng, seq::SliceRandom};
6
7use std::ops::Range;
8
/// Handle to a single phrase stored in a `State`.
///
/// A handle pairs a position with the storage revision it was created for, so
/// that lookups can detect handles that outlived a mutation of the state.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PhraseId {
    /// Position of the phrase in the storage's list of phrase ranges.
    idx: usize,
    /// Storage revision this handle was issued for.
    rev: usize,
}
15
16#[derive(Clone, Debug)]
18pub struct State {
19 storage: Storage,
20 match_cache: MatchCache,
21 scratch_state: Option<ScratchState>,
22}
23
24impl State {
25 pub(crate) fn new() -> State {
26 State {
27 storage: Storage::new(),
28 match_cache: MatchCache::new(),
29 scratch_state: None,
30 }
31 }
32
33 pub fn remove(&mut self, id: PhraseId) {
35 assert!(id.rev == self.storage.rev);
36 self.remove_idx(id.idx);
37 }
38
39 pub(crate) fn remove_idx(&mut self, idx: usize) {
40 assert!(!self.is_locked());
41
42 let remove_phrase = self.storage.phrase_ranges.swap_remove(idx);
43 self.storage
44 .removed_phrase_ranges
45 .push(remove_phrase.token_range);
46
47 self.storage.rev += 1;
48 }
49
50 pub fn remove_phrase(&mut self, phrase: &Phrase) -> bool {
54 let remove_idx =
55 self.storage
56 .phrase_ranges
57 .iter()
58 .position(|PhraseMetadata { token_range, .. }| {
59 phrase_equal(
60 &self.storage.tokens[token_range.clone()],
61 phrase,
62 (0, 0),
63 (0, 0),
64 )
65 });
66
67 if let Some(remove_idx) = remove_idx {
68 self.remove_idx(remove_idx);
69 true
70 } else {
71 false
72 }
73 }
74
75 pub fn remove_pattern<const N: usize>(
81 &mut self,
82 pattern: [Option<Atom>; N],
83 match_pattern_length: bool,
84 ) {
85 assert!(!self.is_locked());
86
87 let tokens = &mut self.storage.tokens;
88 let removed_phrase_ranges = &mut self.storage.removed_phrase_ranges;
89 let mut did_remove_tokens = false;
90
91 self.storage
92 .phrase_ranges
93 .retain(|PhraseMetadata { token_range, .. }| {
94 let phrase = &tokens[token_range.clone()];
95 if !test_phrase_pattern_match(phrase, pattern, match_pattern_length) {
96 return true;
97 }
98
99 removed_phrase_ranges.push(token_range.clone());
100 did_remove_tokens = true;
101
102 false
103 });
104
105 if did_remove_tokens {
106 self.storage.rev += 1;
107 }
108 }
109
110 pub(crate) fn clear_removed_tokens(&mut self) {
111 self.storage
112 .removed_phrase_ranges
113 .sort_unstable_by_key(|range| std::cmp::Reverse(range.start));
114 for remove_range in self.storage.removed_phrase_ranges.drain(..) {
115 let remove_len = remove_range.end - remove_range.start;
116 self.storage
117 .tokens
118 .drain(remove_range.start..remove_range.end);
119 for PhraseMetadata { token_range, .. } in self.storage.phrase_ranges.iter_mut() {
120 if token_range.start >= remove_range.end {
121 token_range.start -= remove_len;
122 token_range.end -= remove_len;
123 }
124 }
125 }
126 }
127
128 pub(crate) fn update_cache(&mut self) {
129 self.match_cache.update_storage(&self.storage);
130 }
131
132 pub(crate) fn match_cached_state_indices_for_rule_input(
133 &self,
134 input_phrase: &Phrase,
135 input_phrase_group_count: usize,
136 ) -> &[usize] {
137 assert!(self.match_cache.storage_rev == self.storage.rev);
138 debug_assert_eq!(input_phrase.groups().count(), input_phrase_group_count);
139 self.match_cache
140 .match_rule_input(input_phrase, input_phrase_group_count)
141 }
142
143 pub(crate) fn shuffle(&mut self, rng: &mut SmallRng) {
144 assert!(self.scratch_state.is_none());
145 self.storage.phrase_ranges.shuffle(rng);
146 self.storage.rev += 1;
147 }
148
149 pub fn push(&mut self, phrase: Vec<Token>) -> PhraseId {
151 let group_count = phrase.groups().count();
152 self.push_with_metadata(phrase, group_count)
153 }
154
155 pub(crate) fn push_with_metadata(
156 &mut self,
157 mut phrase: Vec<Token>,
158 group_count: usize,
159 ) -> PhraseId {
160 let first_group_is_single_token = phrase[0].open_depth == 1;
161 let first_atom = if first_group_is_single_token && is_concrete_pred(&phrase) {
162 Some(phrase[0].atom)
163 } else {
164 None
165 };
166
167 let start = self.storage.tokens.len();
168 self.storage.tokens.append(&mut phrase);
169 let end = self.storage.tokens.len();
170
171 self.storage.phrase_ranges.push(PhraseMetadata {
172 token_range: Range { start, end },
173 first_atom,
174 group_count,
175 });
176 self.storage.rev += 1;
177
178 let id = PhraseId {
179 idx: self.storage.phrase_ranges.len() - 1,
180 rev: self.storage.rev,
181 };
182
183 id
184 }
185
186 pub fn len(&self) -> usize {
188 self.storage.phrase_ranges.len()
189 }
190
191 pub fn iter(&self) -> impl Iterator<Item = PhraseId> + '_ {
193 self.storage.iter()
194 }
195
196 pub fn get(&self, id: PhraseId) -> &Phrase {
198 self.storage.get(id)
199 }
200
201 pub fn get_all(&self) -> Vec<Vec<Token>> {
203 self.storage
204 .phrase_ranges
205 .iter()
206 .map(|PhraseMetadata { token_range, .. }| {
207 self.storage.tokens[token_range.clone()].to_vec()
208 })
209 .collect::<Vec<_>>()
210 }
211
212 pub fn iter_pattern<const N: usize>(
218 &self,
219 pattern: [Option<Atom>; N],
220 match_pattern_length: bool,
221 ) -> impl Iterator<Item = PhraseId> + '_ {
222 self.iter().filter(move |phrase_id| {
223 test_phrase_pattern_match(self.get(*phrase_id), pattern, match_pattern_length)
224 })
225 }
226
227 #[cfg(test)]
228 pub(crate) fn from_phrases(phrases: &[Vec<Token>]) -> State {
229 let mut state = State::new();
230 for p in phrases {
231 state.push(p.clone());
232 }
233 state
234 }
235
236 pub(crate) fn lock_scratch(&mut self) {
237 self.scratch_state = Some(ScratchState {
238 storage_phrase_ranges_len: self.storage.phrase_ranges.len(),
239 storage_tokens_len: self.storage.tokens.len(),
240 storage_rev: self.storage.rev,
241 });
242 }
243
244 pub(crate) fn unlock_scratch(&mut self) {
245 self.reset_scratch();
246 self.scratch_state = None;
247 }
248
249 pub(crate) fn reset_scratch(&mut self) {
250 let ScratchState {
251 storage_phrase_ranges_len,
252 storage_tokens_len,
253 storage_rev,
254 ..
255 } = self.scratch_state.as_ref().expect("scratch_state");
256 self.storage
257 .phrase_ranges
258 .drain(storage_phrase_ranges_len..);
259 self.storage.tokens.drain(storage_tokens_len..);
260 self.storage.rev = *storage_rev;
261 }
262
263 fn is_locked(&self) -> bool {
264 self.scratch_state.is_some()
265 }
266}
267
268impl std::ops::Index<usize> for State {
269 type Output = [Token];
270
271 fn index(&self, i: usize) -> &Phrase {
272 self.storage.get_by_metadata(&self.storage.phrase_ranges[i])
273 }
274}
275
276impl std::fmt::Display for State {
277 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
278 write!(f, "{:?}", self.get_all())
279 }
280}
281
/// Snapshot of the storage taken when the scratch lock is acquired.
#[derive(Debug, Clone)]
struct ScratchState {
    /// Number of phrases in storage at lock time.
    storage_phrase_ranges_len: usize,
    /// Length of the token buffer at lock time.
    storage_tokens_len: usize,
    /// Storage revision at lock time.
    storage_rev: usize,
}
288
289#[derive(Clone, Debug)]
290struct Storage {
291 phrase_ranges: Vec<PhraseMetadata>,
293 removed_phrase_ranges: Vec<Range<usize>>,
294
295 tokens: Vec<Token>,
297
298 rev: usize,
300}
301
302impl Storage {
303 fn new() -> Self {
304 Storage {
305 phrase_ranges: vec![],
306 removed_phrase_ranges: vec![],
307 tokens: vec![],
308 rev: 0,
309 }
310 }
311
312 fn iter<'a>(&'a self) -> impl Iterator<Item = PhraseId> + 'a {
313 let rev = self.rev;
314 self.phrase_ranges
315 .iter()
316 .enumerate()
317 .map(move |(idx, _)| PhraseId { idx, rev })
318 }
319
320 fn get(&self, id: PhraseId) -> &Phrase {
321 assert!(id.rev == self.rev);
322 self.get_by_metadata(&self.phrase_ranges[id.idx])
323 }
324
325 fn get_by_metadata(&self, metadata: &PhraseMetadata) -> &Phrase {
326 &self.tokens[metadata.token_range.clone()]
327 }
328}
329
330#[derive(Clone, Debug)]
331struct PhraseMetadata {
332 token_range: Range<usize>,
333 first_atom: Option<Atom>,
334 group_count: usize,
335}
336
337#[derive(Clone, Debug)]
338struct MatchCache {
339 first_atom_pairs: Vec<(Atom, usize)>,
340 first_atom_indices: Vec<usize>,
341 state_indices_by_length: Vec<Vec<usize>>,
342 storage_rev: usize,
343}
344
345impl MatchCache {
346 fn new() -> Self {
347 MatchCache {
348 first_atom_pairs: vec![],
349 first_atom_indices: vec![],
350 state_indices_by_length: vec![],
351 storage_rev: 0,
352 }
353 }
354
355 fn clear(&mut self) {
356 self.first_atom_pairs.clear();
357 self.first_atom_indices.clear();
358 self.state_indices_by_length.clear();
359 }
360
361 fn update_storage(&mut self, storage: &Storage) {
362 if self.storage_rev == storage.rev {
363 return;
364 }
365 self.storage_rev = storage.rev;
366
367 self.clear();
368 for (s_i, phrase_metadata) in storage.phrase_ranges.iter().enumerate() {
369 if let Some(first_atom) = phrase_metadata.first_atom {
370 self.first_atom_pairs.push((first_atom, s_i));
371 }
372 if self.state_indices_by_length.len() < phrase_metadata.group_count + 1 {
373 self.state_indices_by_length
374 .resize(phrase_metadata.group_count + 1, vec![]);
375 }
376 self.state_indices_by_length[phrase_metadata.group_count].push(s_i);
377 }
378 self.first_atom_pairs.sort_unstable_by(|a, b| a.0.cmp(&b.0));
379 for (_, s_i) in &self.first_atom_pairs {
380 self.first_atom_indices.push(*s_i);
381 }
382 }
383
384 fn match_rule_input(&self, input_phrase: &Phrase, input_phrase_group_count: usize) -> &[usize] {
385 let first_group_is_single_token = input_phrase[0].open_depth == 1;
386 if first_group_is_single_token && is_concrete_pred(input_phrase) {
387 let input_first_atom = input_phrase[0].atom;
388 if let Ok(idx) = self
389 .first_atom_pairs
390 .binary_search_by(|(atom, _)| atom.cmp(&input_first_atom))
391 {
392 let start_idx = self
395 .first_atom_pairs
396 .iter()
397 .enumerate()
398 .rev()
399 .skip(self.first_atom_pairs.len() - 1 - idx)
400 .take_while(|(_, (atom, _))| *atom == input_first_atom)
401 .last()
402 .expect("start idx")
403 .0;
404 let end_idx = self
405 .first_atom_pairs
406 .iter()
407 .enumerate()
408 .skip(idx)
409 .take_while(|(_, (atom, _))| *atom == input_first_atom)
410 .last()
411 .expect("end idx")
412 .0;
413 return &self.first_atom_indices[start_idx..end_idx + 1];
414 } else {
415 return &[];
416 };
417 }
418
419 if let Some(v) = &self.state_indices_by_length.get(input_phrase_group_count) {
420 v
421 } else {
422 &[]
423 }
424 }
425}
426
427pub(crate) fn state_to_string(state: &State, string_cache: &StringCache) -> String {
428 state
429 .iter()
430 .map(|phrase_id| state.get(phrase_id).to_string(string_cache))
431 .collect::<Vec<_>>()
432 .join("\n")
433}