1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
//!
//! The KeyGroups module contains all of the logic for working with key groups. Nothing
//! from here should be re-exported.
//!
use core::cmp::Ordering;
use std::collections::{HashMap, HashSet};
use serde::{Serialize, Deserialize};
use super::key::{*};
use super::database::{*};
use super::records::{*};
use super::sym_spell::{*};
use super::table_config::{*};
use super::perf_counters::{*};
use crate::Coder;
/// A unique identifier for a key group, which includes its RecordID
///
/// Lower 44 bits are the RecordID, upper 20 bits are the GroupID
///
/// Both components are packed into the single wrapped `usize`, so the 44 + 20 bit layout
/// presumes a 64-bit `usize`. NOTE(review): confirm that 32-bit targets are unsupported.
#[derive(Copy, Debug, Clone, Hash, Eq, PartialEq, Ord, PartialOrd, derive_more::Display, Serialize, Deserialize)]
pub struct KeyGroupID(usize);
impl KeyGroupID {
    /// Number of low-order bits that hold the RecordID
    const RECORD_ID_BITS : u32 = 44;
    /// Mask that selects the RecordID bits
    const RECORD_ID_MASK : usize = (1 << Self::RECORD_ID_BITS) - 1;
    /// Largest group index that fits in the bits above the RecordID
    const MAX_GROUP_IDX : usize = usize::MAX >> Self::RECORD_ID_BITS;

    /// Packs a RecordID and a group index into a single KeyGroupID
    ///
    /// # Panics
    ///
    /// Panics if `record_id` doesn't fit in the lower 44 bits, or if `group_idx` doesn't fit
    /// in the remaining upper bits. Without the second check an oversized `group_idx` would
    /// silently lose its high bits in the shift and collide with another group's ID.
    pub fn from_record_and_idx(record_id : RecordID, group_idx : usize) -> Self {
        //Panic if we have more than the allowed number of records
        if record_id.0 > Self::RECORD_ID_MASK {
            panic!("too many records!");
        }
        //Panic if the group index would overflow the bits reserved for it
        if group_idx > Self::MAX_GROUP_IDX {
            panic!("too many key groups!");
        }
        //The two components occupy disjoint bits, so a bitwise OR combines them
        Self(record_id.0 | (group_idx << Self::RECORD_ID_BITS))
    }
    /// Returns the RecordID component, stored in the lower 44 bits
    pub fn record_id(&self) -> RecordID {
        RecordID::from(self.0 & Self::RECORD_ID_MASK)
    }
    /// Returns the group index component, stored in the bits above the RecordID
    pub fn group_idx(&self) -> usize {
        self.0 >> Self::RECORD_ID_BITS
    }
    /// Returns the packed identifier as little-endian bytes
    pub fn to_le_bytes(self) -> [u8; 8] {
        self.0.to_le_bytes()
    }
}
/// A transient struct to keep track of the multiple key groups, reflecting what's in the DB.
/// Used when assembling key groups or adding new keys to existing groups
///
/// `key_group_variants`, `key_group_keys` and `group_ids` are parallel vectors; the entries at
/// the same position in each of them describe the same key group.
#[derive(Clone)]
pub struct KeyGroups<OwnedKeyT, const UTF8_KEYS : bool> {
/// Maps a variant's bytes to the position, in the vectors below, of a group that contains it
variant_reverse_lookup_map : HashMap<Vec<u8>, usize>,
/// For each group, the merged variants of every key in that group
pub key_group_variants : Vec<HashSet<Vec<u8>>>,
/// For each group, the keys that belong to that group
pub key_group_keys : Vec<HashSet<OwnedKeyT>>,
/// For each group, the group index component of its KeyGroupID
pub group_ids : Vec<usize> //The contents of this vec correspond to the KeyGroupID
}
impl <OwnedKeyT, const UTF8_KEYS : bool>KeyGroups<OwnedKeyT, UTF8_KEYS> {
    /// Creates an empty KeyGroups, with no groups, keys, or variants
    fn new() -> Self {
        Self{
            variant_reverse_lookup_map : HashMap::new(),
            key_group_variants : Vec::new(),
            key_group_keys : Vec::new(),
            group_ids : Vec::new()
        }
    }
    /// Returns a group id that isn't used by any group currently in `group_ids`
    fn next_available_group_id(&self) -> usize {
        //It doesn't matter if we leave some holes, but we must not collide, therefore we'll
        //start at the length of the vec, and search forward from there
        let mut candidate_id = self.group_ids.len();
        loop {
            if !self.group_ids.contains(&candidate_id) {
                return candidate_id;
            }
            candidate_id += 1;
        }
    }
    /// Adds a new key to a KeyGroups transient structure.  Doesn't touch the DB
    ///
    /// This function is the owner of the decision whether or not to add a key to an existing
    /// group or to create a new group for a key
    ///
    /// When `update_reverse_map` is true, the key's variants are also recorded in the reverse
    /// lookup map so that keys added later can be matched against them.  Callers can pass false
    /// for the final key in a batch, because nothing will be matched against it afterwards.
    ///
    /// # Errors
    ///
    /// Returns an error if the key has more than `MAX_KEY_LENGTH` characters.
    pub fn add_key_to_groups<KeyCharT : Clone, K, ConfigT : TableConfig>(&mut self, key : &K, update_reverse_map : bool, config : &ConfigT) -> Result<(), String>
        where
        OwnedKeyT : OwnedKey<KeyCharT = KeyCharT>,
        K : Key<KeyCharT = KeyCharT>
    {
        //Make sure the key is within the maximum allowable MAX_KEY_LENGTH
        if key.num_chars() > MAX_KEY_LENGTH {
            return Err("key length exceeds MAX_KEY_LENGTH".to_string());
        }

        //Compute the variants for the key
        let key_variants = SymSpell::<OwnedKeyT, UTF8_KEYS>::variants(key, config);

        //Decide which group we merge into, or whether we create a new key group.
        // Some(idx) means "merge into the existing group at idx", None means "create a new group"
        let merge_target : Option<usize> = if let Some(&existing_group) = self.variant_reverse_lookup_map.get(key.as_bytes()) {
            //We already have exactly this key as a variant, so we add the key to that key group
            Some(existing_group)
        } else if ConfigT::GROUP_VARIANT_OVERLAP_THRESHOLD > 0 {
            //Count the number of overlapping variants the key has with each existing group
            // NOTE: It's possible the variant_reverse_lookup_map doesn't capture all of the
            // different groups containing a given variant.  This could happen if we chose not
            // to merge variants for any reason, like exceeding a max number of keys in a key group,
            // It could also happen if a variant set ends up overlapping two previously disjoint
            // variant sets.  The only way to avoid that would be to merge the two existing key
            // groups into a single key group, but we don't have logic to merge existing key groups,
            // only to append new keys and the key's variants to a group.
            // Since the whole key groups logic is just an optimization, this edge case will not
            // affect the correctness of the results.
            let mut overlap_counts = vec![0usize; self.key_group_keys.len()];
            for variant in key_variants.iter() {
                //See if it's already part of another key's variant list
                if let Some(&existing_group) = self.variant_reverse_lookup_map.get(&variant[..]) {
                    overlap_counts[existing_group] += 1;
                }
            }

            //Unless we have at least GROUP_VARIANT_OVERLAP_THRESHOLD variant overlaps with the
            // best-matching group, we'll make a new key group.  When several groups tie for the
            // most overlaps, the last of them is chosen.
            overlap_counts.into_iter()
                .enumerate()
                .max_by_key(|&(_, overlaps)| overlaps)
                .filter(|&(_, overlaps)| overlaps >= ConfigT::GROUP_VARIANT_OVERLAP_THRESHOLD)
                .map(|(best_group_idx, _)| best_group_idx)
        } else if self.key_group_keys.is_empty() {
            //If ConfigT::GROUP_VARIANT_OVERLAP_THRESHOLD == 0, then we always add the key to group 0,
            //which is the record's only group, so checking the overlap with existing groups is a
            //waste of time.  Group 0 doesn't exist yet, so this key will create it.
            None
        } else {
            Some(0)
        };

        //A new group is appended at the end of the group vectors, so its index is the current length
        let group_idx = merge_target.unwrap_or_else(|| self.key_group_keys.len());

        //If we're not at the last key in the list, add the variants to the variant_reverse_lookup_map.
        // This is done before the variants are moved into the group, so each variant is cloned at
        // most once and the last key in a list requires no clones at all.
        if update_reverse_map {
            self.variant_reverse_lookup_map.extend(
                key_variants.iter().map(|variant| (variant.clone(), group_idx))
            );
        }

        let owned_key = OwnedKeyT::from_key(key);
        match merge_target {
            Some(existing_idx) => {
                //We will append the key to the existing group, and merge the variants
                self.key_group_keys[existing_idx].insert(owned_key);
                self.key_group_variants[existing_idx].extend(key_variants);
            },
            None => {
                //We have too little overlap with any existing group, so we will create a new
                // group for this key
                let mut new_set = HashSet::with_capacity(1);
                new_set.insert(owned_key);
                self.key_group_keys.push(new_set);
                self.key_group_variants.push(key_variants);

                //We can't count on the KeyGroupIDs not having holes so we need to use a function to
                //find a unique ID.
                let new_group_id = self.next_available_group_id();
                self.group_ids.push(new_group_id);
            }
        }

        Ok(())
    }
    /// Divides a list of keys up into one or more key groups based on some criteria; the primary
    /// of which is the overlap between key variants.  Keys with more overlapping variants are more
    /// likely to belong in the same group and keys with fewer or none are less likely.
    ///
    /// `num_keys` is the number of keys `keys_iter` will produce.  It is only used to recognize
    /// the final key, whose variants don't need to be added to the reverse lookup map.
    ///
    /// # Errors
    ///
    /// Returns an error if any key has more than `MAX_KEY_LENGTH` characters.
    pub fn make_groups_from_keys<'a, KeyCharT : Clone, K, KeysIterT : Iterator<Item=&'a K>, ConfigT : TableConfig>(keys_iter : KeysIterT, num_keys : usize, config : &ConfigT) -> Result<Self, String>
        where
        OwnedKeyT : OwnedKey<KeyCharT = KeyCharT>,
        K : Key<KeyCharT = KeyCharT> + 'a
    {
        //Start with empty key groups, and add the keys one at a time
        let mut groups = Self::new();
        for (key_idx, key) in keys_iter.enumerate() {
            //Comparing against `key_idx + 1` rather than `num_keys - 1` means a `num_keys` of
            // zero can't underflow.  For a correct `num_keys` the result is the same.
            let is_last_key = key_idx + 1 == num_keys;
            groups.add_key_to_groups(key, !is_last_key, config)?;
        }

        Ok(groups)
    }
    /// Loads the existing key groups for a record in the [Table]
    ///
    /// This function is used when adding new keys to a record, and figuring out which groups to
    /// merge the keys into
    ///
    /// The variants of every loaded key are recomputed and recorded in the reverse lookup map.
    ///
    /// # Errors
    ///
    /// Returns an error if the record's key groups, or the keys of any group, can't be read
    /// from the database.
    pub fn load_key_groups<KeyCharT : Clone, ConfigT : TableConfig, C>(db : &DBConnection<C>, record_id : RecordID, config : &ConfigT, perf_counters : &PerfCounters) -> Result<Self, String>
        where
        OwnedKeyT : OwnedKey<KeyCharT = KeyCharT>,
        C: Coder + Send + Sync + 'static,
    {
        let mut groups = Self::new();

        //Load the group indices from the rec_data table and loop over each key group.
        // `group_idx` is the group's position in our vectors; the group index stored inside the
        // KeyGroupID is kept separately in `group_ids`.
        for (group_idx, key_group) in db.get_record_key_groups(record_id)?.enumerate() {

            let mut group_keys = HashSet::new();
            let mut group_variants = HashSet::new();

            //Load the group's keys and loop over each one
            for key in db.get_keys_in_group::<OwnedKeyT>(key_group, perf_counters)? {

                //Compute the variants for the key
                let key_variants = SymSpell::<OwnedKeyT, UTF8_KEYS>::variants(&key, config);

                //Update the reverse_lookup_map with every variant
                groups.variant_reverse_lookup_map.extend(
                    key_variants.iter().map(|variant| (variant.clone(), group_idx))
                );

                //Push this key into the group's key list
                group_keys.insert(key);

                //Merge this key's variants with the other variants in this group
                group_variants.extend(key_variants);
            }

            groups.key_group_variants.push(group_variants);
            groups.key_group_keys.push(group_keys);
            groups.group_ids.push(key_group.group_idx());
        }

        Ok(groups)
    }
}