layered_nlp/resolvers/
text_match.rs1use std::collections::HashMap;
2use std::fmt::Debug;
3
4use crate::{x, Resolver};
5
/// Resolver that assigns attributes to tokens whose text is a key of a
/// lookup table.
///
/// Each key maps to a list of attribute values. A matching token receives one
/// assignment per value, in list order. Matching is case-sensitive when the
/// resolver is built with [`TextMatchAssignResolver::new`], and
/// case-insensitive when built with one of the `new_case_insensitive*`
/// constructors.
pub struct TextMatchAssignResolver<T> {
    /// When `false`, `lookup` keys are stored lowercased and token text is
    /// lowercased before the lookup.
    case_sensitive: bool,
    /// Token text -> attribute values to assign to each matching token.
    lookup: HashMap<String, Vec<T>>,
}

impl<T: Clone> TextMatchAssignResolver<T> {
    /// Creates a case-sensitive resolver: token text must equal a key exactly.
    pub fn new(lookup: HashMap<String, Vec<T>>) -> Self {
        TextMatchAssignResolver {
            case_sensitive: true,
            lookup,
        }
    }

    /// Creates a case-insensitive resolver.
    ///
    /// Keys are lowercased with [`str::to_lowercase`]. Keys that become equal
    /// after lowercasing (e.g. `"Slack"` and `"slack"`) have their value lists
    /// concatenated instead of one silently replacing the other. `lookup` is a
    /// `HashMap`, so the relative order of the merged lists is unspecified.
    pub fn new_case_insensitive(lookup: HashMap<String, Vec<T>>) -> Self {
        TextMatchAssignResolver {
            case_sensitive: false,
            lookup: Self::lowercase_merged(lookup),
        }
    }

    /// Creates a case-insensitive resolver from a fixed array of
    /// `(key, value)` pairs.
    ///
    /// Each pair contributes a single value. Pairs whose keys are equal after
    /// lowercasing are collected into the same value list, in array order.
    pub fn new_case_insensitive_str_arr<const N: usize>(lookup: [(&'static str, T); N]) -> Self {
        TextMatchAssignResolver {
            case_sensitive: false,
            lookup: Self::lowercase_merged(
                lookup
                    .iter()
                    .map(|(key, val)| (key.to_string(), vec![val.clone()])),
            ),
        }
    }

    /// Builds a lookup table with lowercased keys.
    ///
    /// Values of keys that collide after lowercasing are appended rather than
    /// overwritten, so no registered value is dropped.
    fn lowercase_merged<I>(entries: I) -> HashMap<String, Vec<T>>
    where
        I: IntoIterator<Item = (String, Vec<T>)>,
    {
        let mut merged: HashMap<String, Vec<T>> = HashMap::new();
        for (key, values) in entries {
            merged.entry(key.to_lowercase()).or_default().extend(values);
        }
        merged
    }
}
41
42impl<T: Debug + Clone + 'static> Resolver for TextMatchAssignResolver<T> {
43 type Attr = T;
44
45 fn go(&self, selection: crate::LLSelection) -> Vec<crate::LLCursorAssignment<Self::Attr>> {
46 selection
47 .find_by(&x::token_text())
48 .into_iter()
49 .flat_map(|(selection, text)| {
50 if self.case_sensitive {
51 self.lookup.get(text)
52 } else {
53 self.lookup.get(&text.to_lowercase())
54 }
55 .map(|values| {
56 values
57 .iter()
58 .cloned()
59 .map(move |attr: T| selection.finish_with_attr(attr))
60 })
61 })
62 .flatten()
63 .collect()
64 }
65}
66
// Snapshot test for the case-insensitive constructor. Four service names are
// registered with capitalized keys. Only the two that occur in the input text
// ("Slack" and "Algolia") are expected to be tagged; "Magic" and "Wolfram"
// have no matching token, so they produce no assignments.
67#[test]
68fn test() {
69 use crate::{create_line_from_input_tokens, InputToken, LLLineDisplay};
70
71 #[derive(Debug, Clone)]
72 enum Service {
73 Slack,
74 Algolia,
75 Magic,
76 Wolfram,
77 }
78
// Build a single line from two text input tokens.
// NOTE(review): the closure presumably reports each token's length in UTF-16
// code units for offset bookkeeping — confirm against create_line_from_input_tokens.
79 let ll_line = create_line_from_input_tokens(
80 vec![
81 InputToken::text("when Slack hears a message in #general".to_string(), vec![]),
82 InputToken::text("Algolia search query: message, table".to_string(), vec![]),
83 ],
84 |text| text.encode_utf16().count(),
85 );
86
// Keys are given with their original capitalization. new_case_insensitive
// lowercases them, and the resolver lowercases token text before each lookup.
87 let ll_line = ll_line.run(&TextMatchAssignResolver::new_case_insensitive({
88 [
89 ("Slack".to_string(), vec![Service::Slack]),
90 ("Algolia".to_string(), vec![Service::Algolia]),
91 ("Magic".to_string(), vec![Service::Magic]),
92 ("Wolfram".to_string(), vec![Service::Wolfram]),
93 ]
94 .iter()
95 .cloned()
96 .collect()
97 }));
98
// Render the line with Service attributes shown (`include` presumably opts an
// attribute type into the display output).
99 let mut ll_display = LLLineDisplay::new(&ll_line);
100 ll_display.include::<Service>();
101
// Inline snapshot: each `╰…╯` marker appears to span a tagged token and is
// followed by the attribute's Debug output.
// NOTE(review): inline snapshots are whitespace-sensitive; re-record with
// `cargo insta review` rather than hand-editing the expected text.
102 insta::assert_display_snapshot!(ll_display, @r###"
103 when Slack hears a message in # general Algolia search query : message , table
104 ╰───╯Slack
105 ╰─────╯Algolia
106 "###);
107}