progscrape_application/story/
tagger.rs

1use std::collections::{HashMap, HashSet};
2
3use itertools::Itertools;
4use serde::{Deserialize, Serialize};
5
6use super::{TagAcceptor, TagSet};
7
8#[derive(Default, Serialize, Deserialize)]
9pub struct TagConfig {
10    #[serde(default)]
11    host: Option<String>,
12    #[serde(default)]
13    hosts: Vec<String>,
14    #[serde(default)]
15    alt: Option<String>,
16    #[serde(default)]
17    alts: Vec<String>,
18    #[serde(default)]
19    implies: Option<String>,
20    #[serde(default)]
21    internal: Option<String>,
22    #[serde(default)]
23    excludes: Vec<String>,
24    #[serde(default)]
25    symbol: bool,
26}
27
28#[derive(Default, Serialize, Deserialize)]
29pub struct TaggerConfig {
30    tags: HashMap<String, HashMap<String, TagConfig>>,
31}
32
/// What gets emitted when a tag matches. `StoryTagger` stores these in a `Vec` and refers to
/// them by index from its lookup tables.
#[derive(Debug)]
struct TagRecord {
    /// The tag string handed to the `TagAcceptor` on a match.
    output: String,
    /// Further tag strings handed to the `TagAcceptor` right after `output`.
    implies: Vec<String>,
}
38
/// A tag (or exclusion phrase) made of several tokens that must appear consecutively.
#[derive(Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
struct MultiTokenTag {
    /// The tokens of the phrase, in the order they must appear.
    tag: Vec<String>,
}

impl MultiTokenTag {
    /// Returns `true` if `slice` begins with every token of this tag, in order.
    ///
    /// Items in `slice` beyond the length of the tag are ignored. A tag with no tokens never
    /// matches: treating it as a match would let [`chomp`](Self::chomp) "succeed" without
    /// consuming anything, which makes a scanning loop spin forever.
    pub fn matches<T: AsRef<str> + std::cmp::PartialEq<String>>(&self, slice: &[T]) -> bool {
        // `get` yields `None` when `slice` is shorter than the tag.
        match slice.get(..self.tag.len()) {
            Some(prefix) if !prefix.is_empty() => prefix
                .iter()
                .zip(&self.tag)
                .all(|(actual, expected)| T::as_ref(actual) == expected.as_str()),
            _ => false,
        }
    }

    /// If `slice` begins with this tag, advances `slice` past the matched tokens and returns
    /// `true`; otherwise leaves `slice` untouched and returns `false`.
    pub fn chomp<T: AsRef<str> + std::cmp::PartialEq<String>>(&self, slice: &mut &[T]) -> bool {
        if !self.matches(slice) {
            return false;
        }
        *slice = &slice[self.tag.len()..];
        true
    }
}
65
#[derive(Debug)]
/// The `StoryTagger` creates a list of tag symbols from a story.
///
/// All lookup tables below are filled in by `StoryTagger::new`; the `usize` values are indices
/// into `records`.
pub struct StoryTagger {
    /// One entry per configured tag, holding what to emit when that tag matches.
    records: Vec<TagRecord>,
    /// Maps single-token tag spellings to the index of their record.
    forward: HashMap<String, usize>,
    /// Maps multi-token tag spellings (those containing a space) to the index of their record.
    forward_multi: HashMap<MultiTokenTag, usize>,
    /// Exclusion phrases that mute other tags; the value is the primary spelling of the tag
    /// being muted.
    exclusions: HashMap<MultiTokenTag, String>,
    /// Maps internal symbols to tags (only required in a handful of cases)
    backward: HashMap<String, String>,
    /// Maps the spellings of tags configured with `symbol: true` to the index of their record.
    /// These are matched as substrings of the whole text rather than as tokens.
    symbols: HashMap<String, usize>,
}
81
82impl StoryTagger {
83    // TODO: These methods allocate a lot of temporaries that probably don't need to be allocated
/// Expands the pattern markers in `tag` into every literal spelling the pattern stands for.
///
/// Two markers are recognised:
/// - `(-)` is a separator that may be written as a hyphen or as a space;
/// - a trailing `(s)` is an optional plural, producing the bare form and the form with `s`.
///
/// Every `(-)` is expanded independently, so `a(-)b(-)c` produces all four hyphen/space
/// combinations rather than only the all-hyphen and all-space ones. The result is never empty
/// and its first element is always the all-hyphen, singular spelling.
fn compute_tag(tag: &str) -> Vec<String> {
    let (stem, plural) = match tag.strip_suffix("(s)") {
        Some(stem) => (stem, true),
        None => (tag, false),
    };

    // `split` always yields at least one (possibly empty) segment.
    let mut segments = stem.split("(-)");
    let mut spellings = vec![segments.next().unwrap_or("").to_owned()];

    // Each further segment was preceded by a `(-)`: fan every spelling built so far out into a
    // hyphenated and a spaced continuation (hyphen first, so it stays at the front).
    for segment in segments {
        let mut expanded = Vec::with_capacity(spellings.len() * 2);
        for prefix in &spellings {
            for &separator in &['-', ' '] {
                let mut spelling = String::with_capacity(prefix.len() + 1 + segment.len());
                spelling.push_str(prefix);
                spelling.push(separator);
                spelling.push_str(segment);
                expanded.push(spelling);
            }
        }
        spellings = expanded;
    }

    if !plural {
        return spellings;
    }

    // Emit each singular immediately followed by its plural.
    let mut with_plurals = Vec::with_capacity(spellings.len() * 2);
    for singular in spellings {
        let mut pluralized = String::with_capacity(singular.len() + 1);
        pluralized.push_str(&singular);
        pluralized.push('s');
        with_plurals.push(singular);
        with_plurals.push(pluralized);
    }
    with_plurals
}
97
98    /// From a tag and list of alts, compute all possible permutations
99    fn compute_all_tags(
100        tag: &str,
101        alt: &Option<String>,
102        alts: &Vec<String>,
103    ) -> (String, HashSet<String>) {
104        let mut tags = HashSet::new();
105        let v = Self::compute_tag(tag);
106        let primary = v[0].clone();
107        tags.extend(v);
108        if let Some(alt) = alt {
109            tags.extend(Self::compute_tag(alt));
110        }
111        for alt in alts {
112            tags.extend(Self::compute_tag(alt));
113        }
114        (primary, tags)
115    }
116
117    pub fn new(config: &TaggerConfig) -> Self {
118        let mut new = Self {
119            forward: HashMap::new(),
120            forward_multi: HashMap::new(),
121            backward: HashMap::new(),
122            records: vec![],
123            symbols: HashMap::new(),
124            exclusions: HashMap::new(),
125        };
126        for tags in config.tags.values() {
127            for (tag, tags) in tags {
128                let (primary, all_tags) = Self::compute_all_tags(tag, &tags.alt, &tags.alts);
129                let excludes = tags
130                    .excludes
131                    .iter()
132                    .flat_map(|s| Self::compute_tag(s))
133                    .map(|s| MultiTokenTag {
134                        tag: s.split_ascii_whitespace().map(str::to_owned).collect(),
135                    });
136                for exclude in excludes {
137                    new.exclusions.insert(exclude, primary.clone());
138                }
139                let record = TagRecord {
140                    output: match tags.internal {
141                        Some(ref s) => s.clone(),
142                        None => primary,
143                    },
144                    implies: tags.implies.clone().into_iter().collect(),
145                };
146                if let Some(internal) = &tags.internal {
147                    new.backward.insert(internal.clone(), tag.clone());
148                }
149                for tag in all_tags {
150                    if tags.symbol {
151                        new.backward.insert(record.output.clone(), tag.clone());
152                        new.symbols.insert(tag, new.records.len());
153                    } else if tag.contains(' ') {
154                        let tag = MultiTokenTag {
155                            tag: tag.split_ascii_whitespace().map(str::to_owned).collect(),
156                        };
157                        new.forward_multi.insert(tag, new.records.len());
158                    } else {
159                        new.forward.insert(tag, new.records.len());
160                    }
161                }
162
163                new.records.push(record);
164            }
165        }
166
167        new
168    }
169
170    pub fn tag<T: TagAcceptor>(&self, s: &str, tags: &mut T) {
171        let s = s.to_lowercase();
172
173        // Clean up single quotes to a standard type
174        let s = s.replace(
175            |c| {
176                c == '`' || c == '\u{2018}' || c == '\u{2019}' || c == '\u{201a}' || c == '\u{201b}'
177            },
178            "'",
179        );
180
181        // Replace possessive with non-possessive
182        let mut s = s.replace("'s", "");
183
184        // First, we replace all symbols and generate tags
185        for (symbol, rec) in &self.symbols {
186            if s.contains(symbol) {
187                s = s.replace(symbol, " ");
188                tags.tag(&self.records[*rec].output);
189                for implies in &self.records[*rec].implies {
190                    tags.tag(implies);
191                }
192            }
193        }
194
195        // Next, we check all the word-like tokens for potential matches
196        let tokens_vec = s
197            .split_ascii_whitespace()
198            .map(|s| s.replace(|c: char| !c.is_alphanumeric() && c != '-', ""))
199            .filter(|s| !s.is_empty())
200            .collect_vec();
201        let mut tokens = tokens_vec.as_slice();
202
203        let mut mutes = HashMap::new();
204
205        'outer: while !tokens.is_empty() {
206            mutes.retain(|_k, v| {
207                if *v == 0 {
208                    false
209                } else {
210                    *v -= 1;
211                    true
212                }
213            });
214            for (exclusion, tag) in &self.exclusions {
215                if exclusion.matches(tokens) {
216                    mutes.insert(tag.clone(), exclusion.tag.len() - 1);
217                }
218            }
219            for (multi, rec) in &self.forward_multi {
220                if multi.chomp(&mut tokens) {
221                    let rec = &self.records[*rec];
222                    tags.tag(&rec.output);
223                    for implies in &rec.implies {
224                        tags.tag(implies);
225                    }
226                    continue 'outer;
227                }
228            }
229            if let Some(rec) = self.forward.get(&tokens[0]) {
230                if !mutes.contains_key(&tokens[0]) {
231                    let rec = &self.records[*rec];
232                    tags.tag(&rec.output);
233                    for implies in &rec.implies {
234                        tags.tag(implies);
235                    }
236                }
237            }
238            tokens = &tokens[1..];
239        }
240    }
241
242    /// Identify any tags in the search term and return the appropriate search term to use. If the search term is a symbol,
243    /// we must use its internal version (ie: cplusplus -> c++, c -> clanguage).
244    pub fn check_tag_search(&self, search: &str) -> Option<&str> {
245        let lowercase = search.to_lowercase();
246        if let Some(idx) = self.symbols.get(&lowercase) {
247            return Some(&self.records[*idx].output);
248        }
249        if let Some(idx) = self.forward.get(&lowercase) {
250            return Some(&self.records[*idx].output);
251        }
252        if let Some((k, _)) = self.backward.get_key_value(&lowercase) {
253            return Some(k.as_str());
254        }
255
256        None
257    }
258
259    /// Given a raw, indexed tag, output a tag that is suitable for display purposes (ie: cplusplus -> c++).
260    pub fn make_display_tag<'a, S: AsRef<str> + 'a>(&'a self, s: S) -> String {
261        let lowercase = s.as_ref().to_lowercase();
262        if let Some(backward) = self.backward.get(&lowercase) {
263            backward.clone()
264        } else {
265            lowercase
266        }
267    }
268
269    /// Given an iterator of raw, indexed tags, output an iterator that is suitable for display purposes (ie: cplusplus -> c++).
270    pub fn make_display_tags<'a, S: AsRef<str>, I: IntoIterator<Item = S> + 'a>(
271        &'a self,
272        iter: I,
273    ) -> impl Iterator<Item = String> + 'a {
274        iter.into_iter().map(|s| self.make_display_tag(s))
275    }
276
277    pub fn tag_details() -> Vec<(String, TagSet)> {
278        // let mut tags = HashMap::new();
279        // let mut tag_set = TagSet::new();
280        // resources.tagger().tag(story.title(), &mut tag_set);
281        // tags.insert("title".to_owned(), tag_set.collect());
282        // for (id, scrape) in &story.scrapes {
283        //     let mut tag_set = TagSet::new();
284        //     scrape.tag(&resources.config().scrape, &mut tag_set)?;
285        //     tags.insert(format!("scrape {:?}", id), tag_set.collect());
286        // }
287        Default::default()
288    }
289}
290
291#[cfg(test)]
292pub(crate) mod test {
293    use itertools::Itertools;
294    use rstest::*;
295    use serde_json::json;
296
297    use crate::story::TagSet;
298
299    use super::{StoryTagger, TaggerConfig};
300
/// Create a tagger configuration with a wide variety of cases: plurals (`(s)`), optional
/// separators (`(-)`), alternates, implied tags, internal names, exclusion phrases and
/// symbol tags.
///
/// Note that this is `pub(crate)` because it is also used outside this module, by the story
/// evaluator's test mode.
#[fixture]
pub(crate) fn tagger_config() -> TaggerConfig {
    serde_json::from_value(json!({
        "tags": {
            "testing": {
                "video(s)": {"hosts": ["youtube.com", "vimeo.com"]},
                "rust": {},
                "chrome": {"alt": "chromium"},
                "neovim": {"implies": "vim"},
                "vim": {},
                "3d": {"alts": ["3(-)d", "3(-)dimension(s)", "three(-)d", "three(-)dimension(s)", "three(-)dimensional", "3(-)dimensional"]},
                "usbc": {"alt": "usb(-)c"},
                "at&t": {"internal": "atandt", "symbol": true},
                "angular": {"alt": "angularjs"},
                "vi": {"internal": "vieditor"},
                "go": {"alt": "golang", "internal": "golang", "excludes": ["can go", "will go", "to go", "go to", "go in", "go into", "let go", "letting go", "go home"]},
                "c": {"internal": "clanguage"},
                "d": {"internal": "dlanguage", "excludes": ["vitamin d", "d wave", "d waves"]},
                "c++": {"internal": "cplusplus", "symbol": true},
                "c#": {"internal": "csharp", "symbol": true},
                "f#": {"internal": "fsharp", "symbol": true},
                ".net": {"internal": "dotnet", "symbol": true},
            }
        }
    })).expect("Failed to parse test config")
}
328
/// Builds a `StoryTagger` from the shared `tagger_config` fixture.
#[fixture]
fn tagger(tagger_config: TaggerConfig) -> StoryTagger {
    // Debugging aid, left disabled:
    // println!("{:?}", tagger);
    StoryTagger::new(&tagger_config)
}
334
335    /// Ensure that symbol-like tags are reverse-lookup'd properly for display purposes.
336    #[rstest]
337    fn test_display_tags(tagger: StoryTagger) {
338        assert_eq!(
339            tagger
340                .make_display_tags(["atandt", "cplusplus", "clanguage", "rust"])
341                .collect_vec(),
342            vec!["at&t", "c++", "c", "rust"]
343        );
344    }
345
346    /// Esnure that we can detect when symbol-like tags are passed to a search function.
347    #[rstest]
348    #[case("cplusplus", &["c++", "cplusplus"])]
349    #[case("clanguage", &["c", "clanguage"])]
350    #[case("atandt", &["at&t", "atandt"])]
351    #[case("angular", &["angular", "angularjs"])]
352    #[case("golang", &["go", "golang"])]
353    #[case("dotnet", &[".net", "dotnet"])]
354    fn test_search_mapping(tagger: StoryTagger, #[case] a: &str, #[case] b: &[&str]) {
355        for b in b {
356            assert_eq!(
357                tagger.check_tag_search(b),
358                Some(a),
359                "Didn't match for '{}'",
360                b
361            );
362        }
363    }
364
365    #[rstest]
366    #[case("I love rust!", &["rust"])]
367    #[case("Good old video", &["video"])]
368    #[case("Good old videos", &["video"])]
369    #[case("Chromium is a project", &["chrome"])]
370    #[case("AngularJS is fun", &["angular"])]
371    #[case("Chromium is the open Chrome", &["chrome"])]
372    #[case("Neovim is kind of cool", &["neovim", "vim"])]
373    #[case("Neovim is a kind of vim", &["neovim", "vim"])]
374    #[case("C is hard", &["clanguage"])]
375    #[case("D is hard", &["dlanguage"])]
376    #[case("C# is hard", &["csharp"])]
377    #[case("C++ is hard", &["cplusplus"])]
378    #[case("AT&T has an ampersand", &["atandt"])]
379    fn test_tag_extraction(tagger: StoryTagger, #[case] s: &str, #[case] tags: &[&str]) {
380        let mut tag_set = TagSet::new();
381        tagger.tag(s, &mut tag_set);
382        assert_eq!(
383            tag_set.collect(),
384            tags.to_vec(),
385            "while checking tags for {}",
386            s
387        );
388    }
389
390    #[rstest]
391    #[case("Usbc.wtf - an article and quiz to find the right USB-C cable", &["usbc"])]
392    #[case("D&D Publisher Addresses Backlash Over Controversial License", &[])]
393    #[case("Microfeatures I'd like to see in more languages", &[])]
394    #[case("What are companies doing with D-Wave's quantum hardware?", &[])]
395    #[case("What are companies doing with D Wave's quantum hardware?", &[])]
396    #[case("Conserving Dürer's Triumphal Arch: coming apart at the seams (2016)", &[])]
397    #[case("J.D. Vance Is Coming for You", &[])]
398    #[case("Rewriting TypeScript in Rust? You'd have to be crazy", &["rust"])]
399    #[case("Vitamin D Supplementation Does Not Influence Growth in Children", &[])]
400    #[case("Vitamin-D Supplementation Does Not Influence Growth in Children", &[])]
401    #[case("They'd rather not", &[])]
402    #[case("Apple Music deletes your original songs and replaces them with DRM'd versions", &[])]
403    fn test_c_and_d_cases(tagger: StoryTagger, #[case] s: &str, #[case] tags: &[&str]) {
404        let mut tag_set = TagSet::new();
405        tagger.tag(s, &mut tag_set);
406        assert_eq!(
407            tag_set.collect(),
408            tags.to_vec(),
409            "while checking tags for {}",
410            s
411        );
412    }
413
414    #[rstest]
415    #[case("New Process Allows 3-D Printing of Microscale Metallic Parts", &["3d"])]
416    #[case("3D printing is wild", &["3d"])]
417    #[case("3 D printing is hard", &["3d"])]
418    #[case("3-D printing is hard", &["3d"])]
419    #[case("three-d printing is hard", &["3d"])]
420    #[case("three d printing is hard", &["3d"])]
421    #[case("three dimensional printing is hard", &["3d"])]
422    #[case("3 dimensional printing is hard", &["3d"])]
423    #[case("3-dimensional printing is hard", &["3d"])]
424    // Multi-word token at the end
425    #[case("I love printing in three dimensions", &["3d"])]
426    fn test_3d_cases(tagger: StoryTagger, #[case] s: &str, #[case] tags: &[&str]) {
427        let mut tag_set = TagSet::new();
428        tagger.tag(s, &mut tag_set);
429        assert_eq!(
430            tag_set.collect(),
431            tags.to_vec(),
432            "while checking tags for {}",
433            s
434        );
435    }
436}