1pub mod error;
114mod modifiers;
115mod parser;
116use std::{collections::HashMap, rc::Rc, str::FromStr};
117
118pub use crate::error::Error;
119use itertools::Itertools;
120use ordered_float::OrderedFloat;
121#[macro_use]
122extern crate lazy_static;
123
/// Result alias used throughout this crate; the error type is always [`Error`].
pub type Result<T> = ::std::result::Result<T, Error>;
128
/// One concrete expansion of a grammar node: the produced text together with
/// the names of the data variables that were substituted to produce it.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
struct Expansion {
    /// Names of the variables referenced by this expansion, in order of appearance.
    varrefs: Vec<String>,
    /// The generated text.
    text: String,
}

impl Expansion {
    /// Appends `expansion` to `self` and returns the combined expansion.
    ///
    /// The variable references of `expansion` are placed after those of `self`
    /// and its text is appended to the text of `self`. Both operands are
    /// consumed, so the buffers of `self` are reused instead of being cloned.
    fn concat(mut self, expansion: Expansion) -> Self {
        self.varrefs.extend(expansion.varrefs);
        self.text.push_str(&expansion.text);
        self
    }
}
144
/// A reference to a data variable, optionally piped through a named modifier.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct VarRef {
    /// Name of the referenced variable.
    var: String,
    /// Name of the modifier applied to the variable's value, if any.
    modifier: Option<String>,
}

impl VarRef {
    /// Builds a reference to `var` without any modifier.
    // Used by the unit tests in this file; other call sites are not visible here.
    #[allow(dead_code)]
    fn with_variable(var: &str) -> Self {
        Self {
            var: String::from(var),
            modifier: None,
        }
    }

    /// Builds a reference to `var` whose value is transformed by `modifier`.
    // Used by the unit tests in this file; other call sites are not visible here.
    #[allow(dead_code)]
    fn with_variable_and_modifier(var: &str, modifier: &str) -> Self {
        Self {
            modifier: Some(String::from(modifier)),
            ..Self::with_variable(var)
        }
    }
}
168
/// A node of a grammar rule's right-hand side.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum Node {
    /// Children expanded one after another; every combination of the
    /// children's expansions is concatenated.
    Sequence(Vec<Node>),
    /// The inner node's expansions plus one additional empty expansion.
    Optional(Box<Node>),
    /// Alternatives; the expansions of all children are collected together.
    Choice(Vec<Node>),
    /// Literal text, emitted unchanged.
    Text(String),
    /// A data variable lookup, optionally passed through a modifier.
    VarRef(VarRef),
    /// A reference to another rule of the grammar, by name.
    NonTerminal(String),
}
178
179impl Node {
180 fn expand(&self, grammar: &Grammar, data: &HashMap<String, String>) -> Result<Vec<Expansion>> {
181 match self {
182 Node::Text(text) => Ok(vec![Expansion {
183 varrefs: vec![],
184 text: text.clone(),
185 }]),
186 Node::VarRef(var) => match data.get(&var.var) {
187 Some(value) => {
188 let text = match &var.modifier {
189 Some(modifier) => match grammar.get_modifier(modifier) {
190 Some(modifier) => Ok(modifier(value)),
191 None => Err(Error::UnknownModifierError(modifier.to_string())),
192 },
193 None => Ok(value.clone()),
194 }?;
195 Ok(vec![Expansion {
196 varrefs: vec![var.var.clone()],
197 text,
198 }])
199 }
200 None => Ok(vec![]),
201 },
202 Node::NonTerminal(lhs) => match grammar.rules.get(lhs) {
203 Some(rhs) => rhs.expand(grammar, data),
204 None => Err(Error::UnknownNonTerminalError(lhs.clone())),
205 },
206 Node::Sequence(nodes) => {
207 let x: Vec<Vec<Expansion>> = nodes
208 .iter()
209 .map(|n| n.expand(grammar, data))
210 .collect::<Result<Vec<_>>>()?;
211 let y: Vec<Expansion> = x
212 .iter()
213 .multi_cartesian_product()
214 .map(|c| {
215 c.into_iter()
216 .fold(Expansion::default(), |a, b| a.concat(b.clone()))
217 })
218 .collect();
219 Ok(y)
220 }
221 Node::Optional(node) => {
222 let mut expansions = node.expand(grammar, data)?;
223 expansions.push(Expansion::default());
224 Ok(expansions)
225 }
226 Node::Choice(nodes) => {
227 let expansions: Vec<Expansion> = nodes
228 .iter()
229 .map(|n| n.expand(grammar, data))
232 .flat_map(|result| match result {
233 Ok(vec) => vec.into_iter().map(Ok).collect(),
234 Err(e) => vec![Err(e)],
235 })
236 .collect::<Result<Vec<_>>>()?;
237 Ok(expansions)
238 }
239 }
240 }
241}
242
243impl ToString for Node {
244 fn to_string(&self) -> String {
245 match self {
246 Node::Text(text) => text.to_string(),
247 Node::Sequence(children) => {
248 format!("[{}]", children.iter().map(|n| n.to_string()).join(""))
249 }
250 Node::VarRef(var) => match &var.modifier {
251 Some(modifier) => format!("#{}|{}#", var.var, modifier),
252 None => format!("#{}#", var.var),
253 },
254 Node::NonTerminal(id) => format!("<{}>", id),
255 Node::Optional(ref node) => format!("?:[{}]", node.to_string()),
256 Node::Choice(nodes) => {
257 format!("[{}]", nodes.iter().map(|n| n.to_string()).join("|"))
258 }
259 }
260 }
261}
262
/// A set of named rules together with the modifiers and default variable
/// weights used when generating text from them.
#[derive(Clone)]
pub struct Grammar {
    /// Rule name mapped to the rule's right-hand side.
    rules: HashMap<String, Node>,
    /// Modifier name mapped to the function applied to a variable's value.
    modifiers: HashMap<String, Rc<dyn Fn(&str) -> String>>,
    /// Variable weights used by `generate` and `generate_all`.
    default_weights: HashMap<String, f64>,
}
270
271impl Grammar {
272 fn new() -> Grammar {
273 Grammar {
274 rules: HashMap::new(),
275 modifiers: HashMap::new(),
276 default_weights: HashMap::new(),
277 }
278 }
279
280 fn add_rule(&mut self, name: &str, node: Node) {
281 self.rules.insert(name.to_string(), node);
282 }
283
284 fn get_rule(&self, name: &str) -> Option<&Node> {
285 self.rules.get(name)
286 }
287
288 fn get_modifier(&self, modifier: &str) -> Option<&dyn Fn(&str) -> String> {
289 self.modifiers.get(modifier).map(|x| x.as_ref())
290 }
291
292 pub fn generate(&self, name: &str, data: &HashMap<String, String>) -> Result<Option<String>> {
295 self.generate_with_weights(name, data, &self.default_weights)
296 }
297
298 pub fn generate_all(&self, name: &str, data: &HashMap<String, String>) -> Result<Vec<String>> {
303 self.generate_all_with_weights(name, data, &self.default_weights)
304 }
305
306 pub fn generate_with_weights(
309 &self,
310 name: &str,
311 data: &HashMap<String, String>,
312 weights: &HashMap<String, f64>,
313 ) -> Result<Option<String>> {
314 let node = self.get_rule(name).unwrap();
315 let mut expansions = node.expand(self, data)?;
316 expansions.sort_by_cached_key(|e| OrderedFloat(score_by_varref_weights(e, weights)));
317 Ok(expansions.last().map(|e| e.text.clone()))
318 }
319
320 pub fn generate_all_with_weights(
325 &self,
326 name: &str,
327 data: &HashMap<String, String>,
328 weights: &HashMap<String, f64>,
329 ) -> Result<Vec<String>> {
330 let node = self
331 .get_rule(name)
332 .ok_or_else(|| Error::UnknownNonTerminalError(name.to_string()))?;
333 let mut expansions = node.expand(self, data)?;
334 expansions.sort_by_cached_key(|e| OrderedFloat(score_by_varref_weights(e, weights)));
335 Ok(expansions.into_iter().rev().map(|e| e.text).collect())
336 }
337}
338
339fn score_by_varref_weights(expansion: &Expansion, weights: &HashMap<String, f64>) -> f64 {
340 expansion
341 .varrefs
342 .iter()
343 .map(|varref| weights.get(varref).unwrap_or(&1.0))
344 .sum()
345}
346
347impl Default for Grammar {
348 fn default() -> Self {
349 let mut grammar = Grammar::new();
350 grammar.modifiers = modifiers::get_default_modifiers();
351 grammar
352 }
353}
354
355impl ToString for Grammar {
356 fn to_string(&self) -> String {
357 let mut s = String::new();
358 for (id, node) in &self.rules {
359 match node {
363 Node::Sequence(children) => {
364 s.push_str(&format!(
365 "{} = {}\n",
366 id,
367 children.iter().map(|n| n.to_string()).join("")
368 ));
369 }
370 _ => {
371 s.push_str(&format!("{} = {}\n", id, node.to_string()));
372 }
373 }
374 }
375 s
376 }
377}
378
379impl FromStr for Grammar {
380 type Err = Error;
381
382 fn from_str(s: &str) -> Result<Self> {
383 let mut grammar = parser::parse_grammar(s)?;
384 grammar.modifiers = modifiers::get_default_modifiers();
385 Ok(grammar)
386 }
387}
388
// Unit tests for node expansion, grammar rendering and text generation.
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;
    use maplit::hashmap;

    /// Shared fixture: a default grammar with a single `location` rule that
    /// capitalizes the `city` variable, plus data for `name` and `city`.
    fn grammar_and_data() -> (Grammar, HashMap<String, String>) {
        let mut grammar = Grammar::default();
        grammar.add_rule(
            "location",
            Node::VarRef(VarRef::with_variable_and_modifier("city", "capitalize")),
        );
        let data = hashmap! {
            "name".to_string() => "John".to_string(),
            "city".to_string() => "london".to_string(),
        };
        (grammar, data)
    }

    // Literal text expands to itself and references no variables.
    #[test]
    fn test_expand_text() {
        let (grammar, data) = grammar_and_data();
        let node = Node::Text("hello".to_string());
        let expansions = node.expand(&grammar, &data).unwrap();
        assert_eq!(
            expansions,
            vec![Expansion {
                varrefs: vec![],
                text: "hello".to_string(),
            }]
        );
    }

    // A variable without modifier expands to its value and records its name.
    #[test]
    fn test_expand_varref() {
        let (grammar, data) = grammar_and_data();
        let node = Node::VarRef(VarRef::with_variable("name"));
        let expansions = node.expand(&grammar, &data).unwrap();
        assert_eq!(
            expansions,
            vec![Expansion {
                varrefs: vec!["name".to_string()],
                text: "John".to_string(),
            }]
        );
    }

    // A non-terminal expands through its rule; here the rule applies the
    // `capitalize` modifier to `city`.
    #[test]
    fn test_expand_nonterminal() {
        let (grammar, data) = grammar_and_data();
        let node = Node::NonTerminal("location".to_string());
        let expansions = node.expand(&grammar, &data).unwrap();
        assert_eq!(
            expansions,
            vec![Expansion {
                varrefs: vec!["city".to_string()],
                text: "London".to_string(),
            }]
        );
    }

    // A sequence concatenates the expansions of its children in order.
    #[test]
    fn test_expand_sequence() {
        let (grammar, data) = grammar_and_data();
        let c1 = Node::Text("in ".to_string());
        let c2 = Node::NonTerminal("location".to_string());
        let node = Node::Sequence(vec![c1, c2]);
        let expansions = node.expand(&grammar, &data).unwrap();
        assert_eq!(
            expansions,
            vec![Expansion {
                varrefs: vec!["city".to_string()],
                text: "in London".to_string(),
            }]
        );
    }

    // An optional child doubles the alternatives: with and without it.
    // Compared as sets because the order of expansions is not asserted.
    #[test]
    fn test_expand_optional() {
        let (grammar, data) = grammar_and_data();
        let hello = Node::Text("Hello ".to_string());
        let dear = Node::Text("dear ".to_string());
        let maybe_dear = Node::Optional(Box::new(dear));
        let friend = Node::Text("friend".to_string());
        let seq = Node::Sequence(vec![hello, maybe_dear, friend]);
        let expansions = seq.expand(&grammar, &data).unwrap();
        assert_eq!(
            HashSet::<_>::from_iter(expansions),
            HashSet::from_iter(vec![
                Expansion {
                    varrefs: vec![],
                    text: "Hello friend".to_string(),
                },
                Expansion {
                    varrefs: vec![],
                    text: "Hello dear friend".to_string(),
                }
            ])
        );
    }

    // A choice yields the expansions of every alternative.
    // Compared as sets because the order of expansions is not asserted.
    #[test]
    fn test_expand_choice() {
        let (grammar, data) = grammar_and_data();
        let snoopy = Node::Text("Snoopy".to_string());
        let name = Node::VarRef(VarRef::with_variable("name"));
        let linus = Node::Text("Linus".to_string());
        let choice = Node::Choice(vec![snoopy, name, linus]);
        let expansions = choice.expand(&grammar, &data).unwrap();
        assert_eq!(
            HashSet::<_>::from_iter(expansions),
            HashSet::from_iter(vec![
                Expansion {
                    varrefs: vec![],
                    text: "Snoopy".to_string(),
                },
                Expansion {
                    varrefs: vec!["name".to_string()],
                    text: "John".to_string(),
                },
                Expansion {
                    varrefs: vec![],
                    text: "Linus".to_string(),
                },
            ])
        );
    }

    // Rendering a grammar writes one `name = rhs` line per rule, with
    // top-level sequences printed without brackets. Lines are compared as a
    // set because rules are stored in a `HashMap`.
    #[test]
    fn test_to_string() {
        let mut grammar = Grammar::default();
        grammar.add_rule(
            "top",
            Node::Sequence(vec![
                Node::Text("hi ".to_string()),
                Node::VarRef(VarRef::with_variable("name")),
                Node::Text(" in ".to_string()),
                Node::NonTerminal("location".to_string()),
            ]),
        );
        grammar.add_rule(
            "location",
            Node::Sequence(vec![
                Node::Text("city of ".to_string()),
                Node::VarRef(VarRef::with_variable("city")),
            ]),
        );
        assert_eq!(
            HashSet::<_>::from_iter(grammar.to_string().split('\n').filter(|s| !s.is_empty())),
            HashSet::from_iter(vec![
                "top = hi #name# in <location>",
                "location = city of #city#",
            ])
        );
    }

    // End-to-end: parse a grammar, then generate. `gender` is absent from the
    // data, so the optional part never appears; the best result is the
    // alternative that references the most variables.
    #[test]
    fn test_generate() {
        let grammar = Grammar::from_str(
            r#"
            top = Hi <name>?:[, my dear #gender#,] in <location>.
            name = #name#
            location = [city of #city#|#city# in #county# county]
            "#,
        )
        .unwrap();
        let data = hashmap! {
            "name".to_string() => "John".to_string(),
            "city".to_string() => "Janesville".to_string(),
            "county".to_string() => "Rock".to_string(),
        };
        let r = grammar.generate("top", &data).unwrap().unwrap();
        assert_eq!(r, "Hi John in Janesville in Rock county.");

        let exps = HashSet::<_>::from_iter(grammar.generate_all("top", &data).unwrap());
        assert_eq!(
            exps,
            HashSet::from_iter(vec![
                "Hi John in Janesville in Rock county.".to_string(),
                "Hi John in city of Janesville.".to_string(),
            ])
        );
    }
}
573}