1use std::collections::{HashMap, HashSet};
2
3use itertools::Itertools;
4use serde::{Deserialize, Serialize};
5
6use super::{TagAcceptor, TagSet};
7
/// Settings for a single tag, deserialized from the tagger configuration.
///
/// Every field is marked `#[serde(default)]`, so any of them may be omitted in the serialized
/// form and will then take its `Default` value.
#[derive(Default, Serialize, Deserialize)]
pub struct TagConfig {
    // NOTE(review): `host` and `hosts` are not read anywhere in the visible part of this file.
    // Presumably they associate the tag with story host names (the test config uses
    // "youtube.com" / "vimeo.com") — confirm against the consumer.
    #[serde(default)]
    host: Option<String>,
    #[serde(default)]
    hosts: Vec<String>,
    /// A single alternative spelling that maps to this tag.
    #[serde(default)]
    alt: Option<String>,
    /// Further alternative spellings that map to this tag.
    #[serde(default)]
    alts: Vec<String>,
    /// Another tag that is emitted whenever this tag is emitted.
    #[serde(default)]
    implies: Option<String>,
    /// Name emitted for this tag in place of its primary spelling.
    #[serde(default)]
    internal: Option<String>,
    /// Phrases that suppress this tag when they occur in the text being tagged.
    #[serde(default)]
    excludes: Vec<String>,
    /// When `true`, the tag's spellings are matched as substrings of the text before it is
    /// split into tokens, rather than as whole tokens.
    #[serde(default)]
    symbol: bool,
}
27
/// Top-level tagger configuration.
#[derive(Default, Serialize, Deserialize)]
pub struct TaggerConfig {
    /// Tag definitions keyed first by a group name and then by tag pattern.
    ///
    /// `StoryTagger::new` iterates only over the groups' values; the group name itself is not
    /// used there.
    tags: HashMap<String, HashMap<String, TagConfig>>,
}
32
/// What gets emitted when a configured tag is recognized.
#[derive(Debug)]
struct TagRecord {
    /// Name reported to the `TagAcceptor` for this tag.
    output: String,
    /// Additional tag names reported together with `output`.
    implies: Vec<String>,
}
38
/// A tag spelling made of several tokens, stored one token per element.
#[derive(Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
struct MultiTokenTag {
    tag: Vec<String>,
}

impl MultiTokenTag {
    /// Returns `true` when `slice` starts with every token of this tag, in order.
    ///
    /// An empty tag never matches. A zero-length match would let [`Self::chomp`] report success
    /// without consuming any input, which makes a caller that loops until the slice is empty
    /// spin forever.
    pub fn matches<T: AsRef<str> + std::cmp::PartialEq<String>>(&self, slice: &[T]) -> bool {
        if self.tag.is_empty() {
            return false;
        }
        // `get` yields `None` when `slice` is shorter than the tag, so no separate length
        // check is needed.
        match slice.get(..self.tag.len()) {
            Some(prefix) => prefix
                .iter()
                .map(T::as_ref)
                .eq(self.tag.iter().map(String::as_str)),
            None => false,
        }
    }

    /// If `slice` starts with this tag, advances `slice` past the matched tokens and returns
    /// `true`; otherwise leaves `slice` untouched and returns `false`.
    pub fn chomp<T: AsRef<str> + std::cmp::PartialEq<String>>(&self, slice: &mut &[T]) -> bool {
        if !self.matches(slice) {
            return false;
        }
        *slice = &slice[self.tag.len()..];
        true
    }
}
65
/// Recognizes configured tags in free text and converts between tag names.
#[derive(Debug)]
pub struct StoryTagger {
    /// One record per configured tag; the maps below store indices into this vector.
    records: Vec<TagRecord>,
    /// Single-token spelling -> index into `records`.
    forward: HashMap<String, usize>,
    /// Multi-token spelling -> index into `records`.
    forward_multi: HashMap<MultiTokenTag, usize>,
    /// Excluded phrase -> primary spelling of the tag it suppresses.
    exclusions: HashMap<MultiTokenTag, String>,
    /// Emitted tag name -> form returned by `make_display_tag`.
    backward: HashMap<String, String>,
    /// Spelling of a `symbol` tag (matched as a substring) -> index into `records`.
    symbols: HashMap<String, usize>,
}
81
82impl StoryTagger {
83 fn compute_tag(tag: &str) -> Vec<String> {
85 if tag.contains("(-)") {
87 let mut v = Self::compute_tag(&tag.replace("(-)", "-"));
88 v.extend(Self::compute_tag(&tag.replace("(-)", " ")));
89 return v;
90 }
91 if let Some(tag) = tag.strip_suffix("(s)") {
92 vec![tag.to_owned(), tag.to_owned() + "s"]
93 } else {
94 vec![tag.to_owned()]
95 }
96 }
97
98 fn compute_all_tags(
100 tag: &str,
101 alt: &Option<String>,
102 alts: &Vec<String>,
103 ) -> (String, HashSet<String>) {
104 let mut tags = HashSet::new();
105 let v = Self::compute_tag(tag);
106 let primary = v[0].clone();
107 tags.extend(v);
108 if let Some(alt) = alt {
109 tags.extend(Self::compute_tag(alt));
110 }
111 for alt in alts {
112 tags.extend(Self::compute_tag(alt));
113 }
114 (primary, tags)
115 }
116
117 pub fn new(config: &TaggerConfig) -> Self {
118 let mut new = Self {
119 forward: HashMap::new(),
120 forward_multi: HashMap::new(),
121 backward: HashMap::new(),
122 records: vec![],
123 symbols: HashMap::new(),
124 exclusions: HashMap::new(),
125 };
126 for tags in config.tags.values() {
127 for (tag, tags) in tags {
128 let (primary, all_tags) = Self::compute_all_tags(tag, &tags.alt, &tags.alts);
129 let excludes = tags
130 .excludes
131 .iter()
132 .flat_map(|s| Self::compute_tag(s))
133 .map(|s| MultiTokenTag {
134 tag: s.split_ascii_whitespace().map(str::to_owned).collect(),
135 });
136 for exclude in excludes {
137 new.exclusions.insert(exclude, primary.clone());
138 }
139 let record = TagRecord {
140 output: match tags.internal {
141 Some(ref s) => s.clone(),
142 None => primary,
143 },
144 implies: tags.implies.clone().into_iter().collect(),
145 };
146 if let Some(internal) = &tags.internal {
147 new.backward.insert(internal.clone(), tag.clone());
148 }
149 for tag in all_tags {
150 if tags.symbol {
151 new.backward.insert(record.output.clone(), tag.clone());
152 new.symbols.insert(tag, new.records.len());
153 } else if tag.contains(' ') {
154 let tag = MultiTokenTag {
155 tag: tag.split_ascii_whitespace().map(str::to_owned).collect(),
156 };
157 new.forward_multi.insert(tag, new.records.len());
158 } else {
159 new.forward.insert(tag, new.records.len());
160 }
161 }
162
163 new.records.push(record);
164 }
165 }
166
167 new
168 }
169
170 pub fn tag<T: TagAcceptor>(&self, s: &str, tags: &mut T) {
171 let s = s.to_lowercase();
172
173 let s = s.replace(
175 |c| {
176 c == '`' || c == '\u{2018}' || c == '\u{2019}' || c == '\u{201a}' || c == '\u{201b}'
177 },
178 "'",
179 );
180
181 let mut s = s.replace("'s", "");
183
184 for (symbol, rec) in &self.symbols {
186 if s.contains(symbol) {
187 s = s.replace(symbol, " ");
188 tags.tag(&self.records[*rec].output);
189 for implies in &self.records[*rec].implies {
190 tags.tag(implies);
191 }
192 }
193 }
194
195 let tokens_vec = s
197 .split_ascii_whitespace()
198 .map(|s| s.replace(|c: char| !c.is_alphanumeric() && c != '-', ""))
199 .filter(|s| !s.is_empty())
200 .collect_vec();
201 let mut tokens = tokens_vec.as_slice();
202
203 let mut mutes = HashMap::new();
204
205 'outer: while !tokens.is_empty() {
206 mutes.retain(|_k, v| {
207 if *v == 0 {
208 false
209 } else {
210 *v -= 1;
211 true
212 }
213 });
214 for (exclusion, tag) in &self.exclusions {
215 if exclusion.matches(tokens) {
216 mutes.insert(tag.clone(), exclusion.tag.len() - 1);
217 }
218 }
219 for (multi, rec) in &self.forward_multi {
220 if multi.chomp(&mut tokens) {
221 let rec = &self.records[*rec];
222 tags.tag(&rec.output);
223 for implies in &rec.implies {
224 tags.tag(implies);
225 }
226 continue 'outer;
227 }
228 }
229 if let Some(rec) = self.forward.get(&tokens[0]) {
230 if !mutes.contains_key(&tokens[0]) {
231 let rec = &self.records[*rec];
232 tags.tag(&rec.output);
233 for implies in &rec.implies {
234 tags.tag(implies);
235 }
236 }
237 }
238 tokens = &tokens[1..];
239 }
240 }
241
242 pub fn check_tag_search(&self, search: &str) -> Option<&str> {
245 let lowercase = search.to_lowercase();
246 if let Some(idx) = self.symbols.get(&lowercase) {
247 return Some(&self.records[*idx].output);
248 }
249 if let Some(idx) = self.forward.get(&lowercase) {
250 return Some(&self.records[*idx].output);
251 }
252 if let Some((k, _)) = self.backward.get_key_value(&lowercase) {
253 return Some(k.as_str());
254 }
255
256 None
257 }
258
259 pub fn make_display_tag<'a, S: AsRef<str> + 'a>(&'a self, s: S) -> String {
261 let lowercase = s.as_ref().to_lowercase();
262 if let Some(backward) = self.backward.get(&lowercase) {
263 backward.clone()
264 } else {
265 lowercase
266 }
267 }
268
269 pub fn make_display_tags<'a, S: AsRef<str>, I: IntoIterator<Item = S> + 'a>(
271 &'a self,
272 iter: I,
273 ) -> impl Iterator<Item = String> + 'a {
274 iter.into_iter().map(|s| self.make_display_tag(s))
275 }
276
277 pub fn tag_details() -> Vec<(String, TagSet)> {
278 Default::default()
288 }
289}
290
#[cfg(test)]
pub(crate) mod test {
    use itertools::Itertools;
    use rstest::*;
    use serde_json::json;

    use crate::story::TagSet;

    use super::{StoryTagger, TaggerConfig};

    /// Configuration shared by the tagger tests.
    #[fixture]
    pub(crate) fn tagger_config() -> TaggerConfig {
        let config = json!({
            "tags": {
                "testing": {
                    "video(s)": {"hosts": ["youtube.com", "vimeo.com"]},
                    "rust": {},
                    "chrome": {"alt": "chromium"},
                    "neovim": {"implies": "vim"},
                    "vim": {},
                    "3d": {"alts": ["3(-)d", "3(-)dimension(s)", "three(-)d", "three(-)dimension(s)", "three(-)dimensional", "3(-)dimensional"]},
                    "usbc": {"alt": "usb(-)c"},
                    "at&t": {"internal": "atandt", "symbol": true},
                    "angular": {"alt": "angularjs"},
                    "vi": {"internal": "vieditor"},
                    "go": {"alt": "golang", "internal": "golang", "excludes": ["can go", "will go", "to go", "go to", "go in", "go into", "let go", "letting go", "go home"]},
                    "c": {"internal": "clanguage"},
                    "d": {"internal": "dlanguage", "excludes": ["vitamin d", "d wave", "d waves"]},
                    "c++": {"internal": "cplusplus", "symbol": true},
                    "c#": {"internal": "csharp", "symbol": true},
                    "f#": {"internal": "fsharp", "symbol": true},
                    ".net": {"internal": "dotnet", "symbol": true},
                }
            }
        });
        serde_json::from_value(config).expect("Failed to parse test config")
    }

    /// A tagger built from the shared test configuration.
    #[fixture]
    fn tagger(tagger_config: TaggerConfig) -> StoryTagger {
        StoryTagger::new(&tagger_config)
    }

    /// Tags `title` with `tagger` and asserts that the collected tags equal `expected`.
    fn assert_tags(tagger: &StoryTagger, title: &str, expected: &[&str]) {
        let mut tag_set = TagSet::new();
        tagger.tag(title, &mut tag_set);
        assert_eq!(
            tag_set.collect(),
            expected.to_vec(),
            "while checking tags for {}",
            title
        );
    }

    /// Emitted names are converted back to their display form.
    #[rstest]
    fn test_display_tags(tagger: StoryTagger) {
        let displayed = tagger
            .make_display_tags(["atandt", "cplusplus", "clanguage", "rust"])
            .collect_vec();
        assert_eq!(displayed, vec!["at&t", "c++", "c", "rust"]);
    }

    /// Every search term in `b` resolves to the emitted name `a`.
    #[rstest]
    #[case("cplusplus", &["c++", "cplusplus"])]
    #[case("clanguage", &["c", "clanguage"])]
    #[case("atandt", &["at&t", "atandt"])]
    #[case("angular", &["angular", "angularjs"])]
    #[case("golang", &["go", "golang"])]
    #[case("dotnet", &[".net", "dotnet"])]
    fn test_search_mapping(tagger: StoryTagger, #[case] a: &str, #[case] b: &[&str]) {
        for query in b {
            assert_eq!(
                tagger.check_tag_search(query),
                Some(a),
                "Didn't match for '{}'",
                query
            );
        }
    }

    /// Basic extraction: plain tags, alternatives, implied tags and symbol tags.
    #[rstest]
    #[case("I love rust!", &["rust"])]
    #[case("Good old video", &["video"])]
    #[case("Good old videos", &["video"])]
    #[case("Chromium is a project", &["chrome"])]
    #[case("AngularJS is fun", &["angular"])]
    #[case("Chromium is the open Chrome", &["chrome"])]
    #[case("Neovim is kind of cool", &["neovim", "vim"])]
    #[case("Neovim is a kind of vim", &["neovim", "vim"])]
    #[case("C is hard", &["clanguage"])]
    #[case("D is hard", &["dlanguage"])]
    #[case("C# is hard", &["csharp"])]
    #[case("C++ is hard", &["cplusplus"])]
    #[case("AT&T has an ampersand", &["atandt"])]
    fn test_tag_extraction(tagger: StoryTagger, #[case] s: &str, #[case] tags: &[&str]) {
        assert_tags(&tagger, s, tags);
    }

    /// Titles in which a stray single letter must not be mistaken for the "c" or "d" tag.
    #[rstest]
    #[case("Usbc.wtf - an article and quiz to find the right USB-C cable", &["usbc"])]
    #[case("D&D Publisher Addresses Backlash Over Controversial License", &[])]
    #[case("Microfeatures I'd like to see in more languages", &[])]
    #[case("What are companies doing with D-Wave's quantum hardware?", &[])]
    #[case("What are companies doing with D Wave's quantum hardware?", &[])]
    #[case("Conserving Dürer's Triumphal Arch: coming apart at the seams (2016)", &[])]
    #[case("J.D. Vance Is Coming for You", &[])]
    #[case("Rewriting TypeScript in Rust? You'd have to be crazy", &["rust"])]
    #[case("Vitamin D Supplementation Does Not Influence Growth in Children", &[])]
    #[case("Vitamin-D Supplementation Does Not Influence Growth in Children", &[])]
    #[case("They'd rather not", &[])]
    #[case("Apple Music deletes your original songs and replaces them with DRM'd versions", &[])]
    fn test_c_and_d_cases(tagger: StoryTagger, #[case] s: &str, #[case] tags: &[&str]) {
        assert_tags(&tagger, s, tags);
    }

    /// The hyphenated, spaced and plural spellings of "3d" all resolve to the same tag.
    #[rstest]
    #[case("New Process Allows 3-D Printing of Microscale Metallic Parts", &["3d"])]
    #[case("3D printing is wild", &["3d"])]
    #[case("3 D printing is hard", &["3d"])]
    #[case("3-D printing is hard", &["3d"])]
    #[case("three-d printing is hard", &["3d"])]
    #[case("three d printing is hard", &["3d"])]
    #[case("three dimensional printing is hard", &["3d"])]
    #[case("3 dimensional printing is hard", &["3d"])]
    #[case("3-dimensional printing is hard", &["3d"])]
    #[case("I love printing in three dimensions", &["3d"])]
    fn test_3d_cases(tagger: StoryTagger, #[case] s: &str, #[case] tags: &[&str]) {
        assert_tags(&tagger, s, tags);
    }
}