genex/
lib.rs

1//! Rust library implementing a custom text generation/templating system. Genex
2//! is similar to [Tracery](https://tracery.io), but with some extra
3//! functionality around using external data.
4//!
5//! # Usage
6//!
7//! First create a grammar, then generate an expansion or multiple expansions
8//! from it.
9//!
10//! ```rust
11//! use std::collections::HashSet;
12//! use std::str::FromStr;
13//! use maplit::hashmap;
14//! use genex::Grammar;
15//!
16//! let grammar = Grammar::from_str(
17//!     r#"
18//!       RULES:
19//!       top = The <adj> <noun> #action|ed# #object|a#?:[ with gusto] in <place>.
20//!       adj = [glistening|#adj#]
21//!       noun = key
22//!       place = [the #room#|#city#]
23//!
24//!       WEIGHTS:
25//!       room = 2
26//!       city = 1
27//!     "#,
28//! )
29//! .unwrap();
30//!
31//! let data = hashmap! {
32//!     "action".to_string() => "pick".to_string(),
33//!     "object".to_string() => "lizard".to_string(),
34//!     "room".to_string() => "kitchen".to_string(),
35//!     "city".to_string() => "New York".to_string(),
36//! };
37//!
38//! // Now we find the top-scoring expansion. The score is the sum of the
39//! // weights of all variables used in an expansion. We know that the top
40//! // scoring expansion is going to end with "the kitchen" because we gave
41//! // `room` a higher weight than `city`.
42//!
43//! let best_expansion = grammar.generate("top", &data).unwrap().unwrap();
44//!
45//! assert_eq!(
46//!     best_expansion,
47//!     "The glistening key picked a lizard in the kitchen.".to_string()
48//! );
49//!
50//! // Now get all possible expansions:
51//!
52//! let all_expansions = grammar.generate_all("top", &data).unwrap();
53//!
54//! assert_eq!(
55//!     HashSet::<_>::from_iter(all_expansions),
56//!     HashSet::<_>::from_iter(vec![
57//!         "The glistening key picked a lizard in New York.".to_string(),
58//!         "The glistening key picked a lizard with gusto in New York.".to_string(),
59//!         "The glistening key picked a lizard with gusto in the kitchen.".to_string(),
60//!         "The glistening key picked a lizard in the kitchen.".to_string(),
61//!     ])
62//! );
63//! ```
64//!
65//! # Features
66//!
67//! Genex tries to make it easy to generate text based on varying amounts of
68//! external data. For example you can write a single expansion grammar that
//! works when all you know is the name of an object, but uses the additional
70//! information if you know the object's size, location, color, or other
71//! qualities.
72//!
73//! The default behavior is for genex to try to find an expansion that uses the
74//! most external data possible, but by changing the weights assigned to
75//! variables you can prioritize which variables are used, even prioritizing the
76//! use of a single important variable over the use of multiple, less important
77//! variables.
78//!
79//! # Grammar syntax
80//!
81//! ## Rules
82//!
83//! "`RULES:`" indicates the rules section of the grammar. Rules are defined by
84//! a left-hand side (LHS) and a right-hand side (RHS). The LHS is the name of
85//! the rule. The RHS is a sequence of terms.
86//!
87//! Terms: 
88//! * Sequence: `[term1 term2 ...]`
89//! * Choice: `[term1|term2|...]` (You can put a newline after a `|` character.)
90//! * Optional: `?:[term1 term2 ...]`
91//! * Variable: `#variable#` or `#variable|modifier#`
92//! * Non-terminal: `<rule-name>`
93//! * Plain text: `I am some plain text. I hope I get expanded.`
94//!
95//! ## Weights
96//! 
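//! "`WEIGHTS:`" indicates the weights section of the grammar. Each weight has
//! the form `variable = number`, where `variable` is the name of a variable (as
//! used in `#variable#`), not a rule. An expansion's score is the sum of the
//! weights of the variables it references; variables without a weight count as 1.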
99//! 
100//! ## Modifiers
101//! 
102//! Modifiers are used to transform variable values during expansion.
103//! 
104//! Modifiers:
105//! * `capitalize`: Capitalizes the first letter of the value.
106//! * `capitalizeAll`: Capitalizes the first letter of each word in the value.
107//! * `inQuotes`: Surrounds the value with double quotes.
108//! * `comma`: Adds a comma after the value, if it doesn't already end with punctuation.
109//! * `s`: Pluralizes the value.
110//! * `a`: Prefixes the value with an "a"/"an" article as appropriate.
111//! * `ed`: Changes the first word of the value to be past tense.
112//! 
113pub mod error;
114mod modifiers;
115mod parser;
116use std::{collections::HashMap, rc::Rc, str::FromStr};
117
118pub use crate::error::Error;
119use itertools::Itertools;
120use ordered_float::OrderedFloat;
121#[macro_use]
122extern crate lazy_static;
123
124/// A convenience type for a `Result` of `T` or [`Error`]
125///
126/// [`Error`]: enum.Error.html
127pub type Result<T> = ::std::result::Result<T, Error>;
128
/// One concrete result of expanding a grammar node.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
struct Expansion {
    /// Names of the data variables consumed, in order of use (used for scoring).
    varrefs: Vec<String>,
    /// The generated text.
    text: String,
}

impl Expansion {
    /// Appends `expansion` to `self`, concatenating both the text and the list
    /// of variables used.
    fn concat(mut self, expansion: Expansion) -> Self {
        // `self` is owned, so grow its buffers in place instead of cloning them.
        self.varrefs.extend(expansion.varrefs);
        self.text.push_str(&expansion.text);
        self
    }
}
144
/// A reference to a data variable, optionally transformed by a named modifier
/// (written `#var#` or `#var|modifier#` in grammar source).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct VarRef {
    /// Key looked up in the caller-supplied data map.
    var: String,
    /// Name of the modifier applied to the variable's value, if any.
    modifier: Option<String>,
}

impl VarRef {
    // Within this file these constructors are only called from the unit
    // tests, so `dead_code` is allowed to keep non-test builds warning-free.
    #[allow(dead_code)]
    fn with_variable(var: &str) -> Self {
        Self {
            var: var.to_owned(),
            modifier: None,
        }
    }

    #[allow(dead_code)]
    fn with_variable_and_modifier(var: &str, modifier: &str) -> Self {
        Self {
            modifier: Some(modifier.to_owned()),
            ..Self::with_variable(var)
        }
    }
}
168
169#[derive(Debug, Clone, PartialEq, Eq, Hash)]
170enum Node {
171    Sequence(Vec<Node>),
172    Optional(Box<Node>),
173    Choice(Vec<Node>),
174    Text(String),
175    VarRef(VarRef),
176    NonTerminal(String),
177}
178
179impl Node {
180    fn expand(&self, grammar: &Grammar, data: &HashMap<String, String>) -> Result<Vec<Expansion>> {
181        match self {
182            Node::Text(text) => Ok(vec![Expansion {
183                varrefs: vec![],
184                text: text.clone(),
185            }]),
186            Node::VarRef(var) => match data.get(&var.var) {
187                Some(value) => {
188                    let text = match &var.modifier {
189                        Some(modifier) => match grammar.get_modifier(modifier) {
190                            Some(modifier) => Ok(modifier(value)),
191                            None => Err(Error::UnknownModifierError(modifier.to_string())),
192                        },
193                        None => Ok(value.clone()),
194                    }?;
195                    Ok(vec![Expansion {
196                        varrefs: vec![var.var.clone()],
197                        text,
198                    }])
199                }
200                None => Ok(vec![]),
201            },
202            Node::NonTerminal(lhs) => match grammar.rules.get(lhs) {
203                Some(rhs) => rhs.expand(grammar, data),
204                None => Err(Error::UnknownNonTerminalError(lhs.clone())),
205            },
206            Node::Sequence(nodes) => {
207                let x: Vec<Vec<Expansion>> = nodes
208                    .iter()
209                    .map(|n| n.expand(grammar, data))
210                    .collect::<Result<Vec<_>>>()?;
211                let y: Vec<Expansion> = x
212                    .iter()
213                    .multi_cartesian_product()
214                    .map(|c| {
215                        c.into_iter()
216                            .fold(Expansion::default(), |a, b| a.concat(b.clone()))
217                    })
218                    .collect();
219                Ok(y)
220            }
221            Node::Optional(node) => {
222                let mut expansions = node.expand(grammar, data)?;
223                expansions.push(Expansion::default());
224                Ok(expansions)
225            }
226            Node::Choice(nodes) => {
227                let expansions: Vec<Expansion> = nodes
228                    .iter()
229                    // See https://stackoverflow.com/a/59852696/122762, "How to
230                    // handle Result in flat_map"
231                    .map(|n| n.expand(grammar, data))
232                    .flat_map(|result| match result {
233                        Ok(vec) => vec.into_iter().map(Ok).collect(),
234                        Err(e) => vec![Err(e)],
235                    })
236                    .collect::<Result<Vec<_>>>()?;
237                Ok(expansions)
238            }
239        }
240    }
241}
242
243impl ToString for Node {
244    fn to_string(&self) -> String {
245        match self {
246            Node::Text(text) => text.to_string(),
247            Node::Sequence(children) => {
248                format!("[{}]", children.iter().map(|n| n.to_string()).join(""))
249            }
250            Node::VarRef(var) => match &var.modifier {
251                Some(modifier) => format!("#{}|{}#", var.var, modifier),
252                None => format!("#{}#", var.var),
253            },
254            Node::NonTerminal(id) => format!("<{}>", id),
255            Node::Optional(ref node) => format!("?:[{}]", node.to_string()),
256            Node::Choice(nodes) => {
257                format!("[{}]", nodes.iter().map(|n| n.to_string()).join("|"))
258            }
259        }
260    }
261}
262
263/// A grammar is a set of expansion rules.
264#[derive(Clone)]
265pub struct Grammar {
266    rules: HashMap<String, Node>,
267    modifiers: HashMap<String, Rc<dyn Fn(&str) -> String>>,
268    default_weights: HashMap<String, f64>,
269}
270
271impl Grammar {
272    fn new() -> Grammar {
273        Grammar {
274            rules: HashMap::new(),
275            modifiers: HashMap::new(),
276            default_weights: HashMap::new(),
277        }
278    }
279
280    fn add_rule(&mut self, name: &str, node: Node) {
281        self.rules.insert(name.to_string(), node);
282    }
283
284    fn get_rule(&self, name: &str) -> Option<&Node> {
285        self.rules.get(name)
286    }
287
288    fn get_modifier(&self, modifier: &str) -> Option<&dyn Fn(&str) -> String> {
289        self.modifiers.get(modifier).map(|x| x.as_ref())
290    }
291
292    /// Returns the top-scoring expansion of the given rule, using the supplied
293    /// data.
294    pub fn generate(&self, name: &str, data: &HashMap<String, String>) -> Result<Option<String>> {
295        self.generate_with_weights(name, data, &self.default_weights)
296    }
297
298    /// Generates all possible expansions of the given rule, using the supplied
299    /// data.
300    ///
301    /// Returns expansions in descending order by score.
302    pub fn generate_all(&self, name: &str, data: &HashMap<String, String>) -> Result<Vec<String>> {
303        self.generate_all_with_weights(name, data, &self.default_weights)
304    }
305
306    /// Generates the top-scoring expansion of the given rule, using the
307    /// supplied data and weights.
308    pub fn generate_with_weights(
309        &self,
310        name: &str,
311        data: &HashMap<String, String>,
312        weights: &HashMap<String, f64>,
313    ) -> Result<Option<String>> {
314        let node = self.get_rule(name).unwrap();
315        let mut expansions = node.expand(self, data)?;
316        expansions.sort_by_cached_key(|e| OrderedFloat(score_by_varref_weights(e, weights)));
317        Ok(expansions.last().map(|e| e.text.clone()))
318    }
319
320    /// Generates all possible expansions of the given rule, using the supplied
321    /// data and weights.
322    ///
323    /// Returns expansions in descending order by score.
324    pub fn generate_all_with_weights(
325        &self,
326        name: &str,
327        data: &HashMap<String, String>,
328        weights: &HashMap<String, f64>,
329    ) -> Result<Vec<String>> {
330        let node = self
331            .get_rule(name)
332            .ok_or_else(|| Error::UnknownNonTerminalError(name.to_string()))?;
333        let mut expansions = node.expand(self, data)?;
334        expansions.sort_by_cached_key(|e| OrderedFloat(score_by_varref_weights(e, weights)));
335        Ok(expansions.into_iter().rev().map(|e| e.text).collect())
336    }
337}
338
339fn score_by_varref_weights(expansion: &Expansion, weights: &HashMap<String, f64>) -> f64 {
340    expansion
341        .varrefs
342        .iter()
343        .map(|varref| weights.get(varref).unwrap_or(&1.0))
344        .sum()
345}
346
347impl Default for Grammar {
348    fn default() -> Self {
349        let mut grammar = Grammar::new();
350        grammar.modifiers = modifiers::get_default_modifiers();
351        grammar
352    }
353}
354
355impl ToString for Grammar {
356    fn to_string(&self) -> String {
357        let mut s = String::new();
358        for (id, node) in &self.rules {
359            // If the RHS is a sequence, we take advantage of the fact that
360            // RHSes are an implicit sequence, and do not print the brackets
361            // around it.
362            match node {
363                Node::Sequence(children) => {
364                    s.push_str(&format!(
365                        "{} = {}\n",
366                        id,
367                        children.iter().map(|n| n.to_string()).join("")
368                    ));
369                }
370                _ => {
371                    s.push_str(&format!("{} = {}\n", id, node.to_string()));
372                }
373            }
374        }
375        s
376    }
377}
378
379impl FromStr for Grammar {
380    type Err = Error;
381
382    fn from_str(s: &str) -> Result<Self> {
383        let mut grammar = parser::parse_grammar(s)?;
384        grammar.modifiers = modifiers::get_default_modifiers();
385        Ok(grammar)
386    }
387}
388
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;
    use maplit::hashmap;

    /// Builds a plain-text node.
    fn text_node(s: &str) -> Node {
        Node::Text(s.to_string())
    }

    /// Builds the `Expansion` a test expects to see.
    fn expected(varrefs: &[&str], text: &str) -> Expansion {
        Expansion {
            varrefs: varrefs.iter().map(|v| v.to_string()).collect(),
            text: text.to_string(),
        }
    }

    /// Fixture: a grammar whose only rule is `location = #city|capitalize#`,
    /// plus data defining `name` and `city`.
    fn grammar_and_data() -> (Grammar, HashMap<String, String>) {
        let data = hashmap! {
            "name".to_string() => "John".to_string(),
            "city".to_string() => "london".to_string(),
        };
        let mut grammar = Grammar::default();
        grammar.add_rule(
            "location",
            Node::VarRef(VarRef::with_variable_and_modifier("city", "capitalize")),
        );
        (grammar, data)
    }

    /// Expands `node` against the fixture grammar and data, panicking on error.
    fn expand_fixture(node: &Node) -> Vec<Expansion> {
        let (grammar, data) = grammar_and_data();
        node.expand(&grammar, &data).unwrap()
    }

    #[test]
    fn test_expand_text() {
        assert_eq!(
            expand_fixture(&text_node("hello")),
            vec![expected(&[], "hello")]
        );
    }

    #[test]
    fn test_expand_varref() {
        let node = Node::VarRef(VarRef::with_variable("name"));
        assert_eq!(expand_fixture(&node), vec![expected(&["name"], "John")]);
    }

    #[test]
    fn test_expand_nonterminal() {
        let node = Node::NonTerminal("location".to_string());
        assert_eq!(expand_fixture(&node), vec![expected(&["city"], "London")]);
    }

    #[test]
    fn test_expand_sequence() {
        let node = Node::Sequence(vec![
            text_node("in "),
            Node::NonTerminal("location".to_string()),
        ]);
        assert_eq!(
            expand_fixture(&node),
            vec![expected(&["city"], "in London")]
        );
    }

    #[test]
    fn test_expand_optional() {
        let node = Node::Sequence(vec![
            text_node("Hello "),
            Node::Optional(Box::new(text_node("dear "))),
            text_node("friend"),
        ]);
        let actual: HashSet<Expansion> = expand_fixture(&node).into_iter().collect();
        let wanted: HashSet<Expansion> = vec![
            expected(&[], "Hello friend"),
            expected(&[], "Hello dear friend"),
        ]
        .into_iter()
        .collect();
        assert_eq!(actual, wanted);
    }

    #[test]
    fn test_expand_choice() {
        let node = Node::Choice(vec![
            text_node("Snoopy"),
            Node::VarRef(VarRef::with_variable("name")),
            text_node("Linus"),
        ]);
        let actual: HashSet<Expansion> = expand_fixture(&node).into_iter().collect();
        let wanted: HashSet<Expansion> = vec![
            expected(&[], "Snoopy"),
            expected(&["name"], "John"),
            expected(&[], "Linus"),
        ]
        .into_iter()
        .collect();
        assert_eq!(actual, wanted);
    }

    #[test]
    fn test_to_string() {
        let mut grammar = Grammar::default();
        grammar.add_rule(
            "top",
            Node::Sequence(vec![
                text_node("hi "),
                Node::VarRef(VarRef::with_variable("name")),
                text_node(" in "),
                Node::NonTerminal("location".to_string()),
            ]),
        );
        grammar.add_rule(
            "location",
            Node::Sequence(vec![
                text_node("city of "),
                Node::VarRef(VarRef::with_variable("city")),
            ]),
        );
        let rendered = grammar.to_string();
        let lines: HashSet<&str> = rendered
            .split('\n')
            .filter(|line| !line.is_empty())
            .collect();
        let wanted: HashSet<&str> = vec![
            "top = hi #name# in <location>",
            "location = city of #city#",
        ]
        .into_iter()
        .collect();
        assert_eq!(lines, wanted);
    }

    #[test]
    fn test_generate() {
        let grammar = Grammar::from_str(
            r#"
            top = Hi <name>?:[, my dear #gender#,] in <location>.
            name = #name#
            location = [city of #city#|#city# in #county# county]
            "#,
        )
        .unwrap();
        let data = hashmap! {
            "name".to_string() => "John".to_string(),
            "city".to_string() => "Janesville".to_string(),
            "county".to_string() => "Rock".to_string(),
        };

        let best = grammar.generate("top", &data).unwrap().unwrap();
        assert_eq!(best, "Hi John in Janesville in Rock county.");

        let all: HashSet<String> = grammar
            .generate_all("top", &data)
            .unwrap()
            .into_iter()
            .collect();
        let wanted: HashSet<String> = vec![
            "Hi John in Janesville in Rock county.".to_string(),
            "Hi John in city of Janesville.".to_string(),
        ]
        .into_iter()
        .collect();
        assert_eq!(all, wanted);
    }
}