socratic/
lib.rs

1//! socratic is for dialog systems
2
3#![deny(unused, missing_docs, private_in_public)]
4
5use std::{collections::HashMap, path::Path, str::FromStr};
6
7#[cfg(feature = "cbor")]
8use std::io;
9
10pub use error::ParseError;
11use error::SocraticError;
12pub use lexing::Atom;
13use lexing::AtomOr;
14use serde::{Deserialize, Serialize};
15use tracing::{info, info_span, instrument};
16
17mod error;
18mod lexing;
19mod parsing;
20
21/// A group of atoms representing a section of text.
22#[derive(Debug, Default, Hash, PartialEq, Eq, Serialize, Deserialize, Clone)]
23pub struct Atoms<T = String>(pub Vec<Atom<T>>);
24
25impl Atoms<String> {
26    fn new<I, S>(input: &Vec<AtomOr<String, I>>, state: &mut S) -> Self
27    where
28        S: DialogState<Interpolation = I>,
29    {
30        let mut atoms = Vec::new();
31        for atom in input {
32            match atom {
33                AtomOr::Atom(a) => atoms.push(a.clone()),
34                AtomOr::Interpolate(i) => atoms.push(Atom::Text(state.interpolate(i))),
35            }
36        }
37        Self(atoms)
38    }
39}
40
41impl<T: std::fmt::Display> std::fmt::Display for Atoms<T> {
42    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43        self.0.iter().try_for_each(|a| write!(f, "{a}"))
44    }
45}
46
47impl<T> Atoms<T> {
48    /// Returns an iterator over atoms.
49    pub fn iter(&self) -> std::slice::Iter<Atom<T>> {
50        self.0.iter()
51    }
52
53    /// Returns an iterator that allows modifying each value.
54    pub fn iter_mut(&mut self) -> std::slice::IterMut<Atom<T>> {
55        self.0.iter_mut()
56    }
57}
58
59impl<T> IntoIterator for Atoms<T> {
60    type Item = Atom<T>;
61    type IntoIter = std::vec::IntoIter<Atom<T>>;
62
63    fn into_iter(self) -> Self::IntoIter {
64        self.0.into_iter()
65    }
66}
67
68impl<'a, T> IntoIterator for &'a Atoms<T> {
69    type Item = &'a Atom<T>;
70    type IntoIter = std::slice::Iter<'a, Atom<T>>;
71
72    fn into_iter(self) -> Self::IntoIter {
73        self.iter()
74    }
75}
76
77impl<'a, T> IntoIterator for &'a mut Atoms<T> {
78    type Item = &'a mut Atom<T>;
79    type IntoIter = std::slice::IterMut<'a, Atom<T>>;
80
81    fn into_iter(self) -> Self::IntoIter {
82        self.iter_mut()
83    }
84}
85
86/// Dialog stores all the dialog trees, grouped by section name.
87#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
88#[serde(rename = "d")]
89pub struct Dialog<DA, IF, TE> {
90    #[serde(rename = "s")]
91    sections: HashMap<String, DialogTree<DA, IF, TE>>,
92}
93
94impl<DA, IF, TE> Default for Dialog<DA, IF, TE> {
95    fn default() -> Self {
96        Self {
97            sections: Default::default(),
98        }
99    }
100}
101
102#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
103#[serde(rename = "dt")]
104struct DialogTree<DA, IF, TE> {
105    #[serde(rename = "n")]
106    nodes: Vec<DialogNode<DA, IF, TE>>,
107}
108
109#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
110#[serde(rename = "dn")]
111enum DialogNode<DA, IF, TE> {
112    #[serde(rename = "cs")]
113    CharacterSays(String, Vec<AtomOr<String, TE>>),
114    #[serde(rename = "m")]
115    Message(Vec<AtomOr<String, TE>>),
116    #[serde(rename = "gt")]
117    GoTo(String),
118    #[serde(rename = "r")]
119    #[allow(clippy::type_complexity)]
120    Responses(Vec<(Vec<AtomOr<String, TE>>, Option<IF>, DialogTree<DA, IF, TE>)>),
121
122    #[serde(rename = "da")]
123    DoAction(DA),
124    #[serde(rename = "c")]
125    Conditional(Vec<(Option<IF>, DialogTree<DA, IF, TE>)>),
126}
127
128/// DialogItem is a single logical dialog node to be displayed to the user.
129#[derive(Debug, Clone, PartialEq, Eq)]
130pub enum DialogItem {
131    /// A character says something.
132    CharacterSays(String, Atoms),
133    /// A simple message.
134    Message(Atoms),
135    /// Go to another section.
136    GoTo(String),
137    /// List of possible responses.
138    Responses(Vec<Atoms>),
139}
140
141impl std::fmt::Display for DialogItem {
142    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
143        use DialogItem::*;
144        match self {
145            CharacterSays(ch, atoms) => write!(f, "{ch}: {atoms}"),
146            Message(atoms) => write!(f, "{atoms}"),
147            GoTo(gt) => write!(f, "=> {gt}"),
148            Responses(resp) => write!(f, "Responses: [{resp:?}]"),
149        }
150    }
151}
152
/// Position within one level of a dialog tree.
///
/// Fields:
/// - `index`: the node being looked at.
/// - `response`: the response or conditional branch chosen at that node, if any.
/// - `inner`: the position inside that branch's sub-tree.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
struct SubIndex {
    index: usize,
    response: Option<usize>,
    inner: Box<Option<SubIndex>>,
}

/// DialogIndex is used to track where in a DialogTree a player is.
///
/// It displays as `section`, or as `section.` followed by the nested position. Each level of
/// the nested position renders as `node[response]` and levels are joined by `.`, e.g.
/// `intro.1[2].0`.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct DialogIndex {
    section: String,
    sub: Option<SubIndex>,
}

impl std::fmt::Display for DialogIndex {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.write_str(&self.section)?;
        match &self.sub {
            Some(sub) => write!(f, ".{sub}"),
            None => Ok(()),
        }
    }
}

impl std::fmt::Display for SubIndex {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(f, "{}", self.index)?;
        if let Some(chosen) = self.response {
            write!(f, "[{chosen}]")?;
        }
        match &*self.inner {
            Some(nested) => write!(f, ".{nested}"),
            None => Ok(()),
        }
    }
}

impl SubIndex {
    /// Record `r` as the chosen response at the innermost level of this position.
    fn set_response(&mut self, r: usize) {
        if let Some(nested) = &mut *self.inner {
            nested.set_response(r);
        } else {
            self.response = Some(r);
        }
    }
}

impl DialogIndex {
    /// When the dialog is on a 'Response' node, this sets the index of the response selected.
    ///
    /// The choice is stored at the innermost nested level of the position.
    ///
    /// # Panics
    ///
    /// Panics if the index has no nested position yet, for example a default-constructed
    /// `DialogIndex`.
    pub fn set_response(&mut self, r: usize) {
        let sub = self.sub.as_mut().expect("sub index to not be None");
        sub.set_response(r);
    }
}
208
209/// When merging two `Dialog` objects together, duplicate sections will trigger this error.
210#[derive(Debug, Default, Clone, PartialEq, Eq, thiserror::Error)]
211#[error("found duplicate section key: {0}")]
212pub struct DuplicateSectionKey(String);
213
214/// Error encountered while validating
215#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
216pub enum ValidationError {
217    /// All section gotos should refer to an extant section.
218    #[error("found redirect (=>) that refers to a non existent section `{0}`")]
219    UnknownSectionGoTo(String),
220}
221
222/// A list of validation errors
223#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
224pub struct ValidationErrors(Vec<ValidationError>);
225
226impl std::fmt::Display for ValidationErrors {
227    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
228        write!(
229            f,
230            "encountered {} validation error{}:",
231            self.0.len(),
232            if self.0.len() == 1 { "" } else { "s" }
233        )?;
234        for err in &self.0 {
235            write!(f, "\n\t{err}")?;
236        }
237        Ok(())
238    }
239}
240
241impl<DA, IF, TE> Dialog<DA, IF, TE> {
242    /// Build a new empty dialog object.
243    pub fn new() -> Self {
244        Self::default()
245    }
246
247    /// Validate the Dialog
248    pub fn validate(&self) -> Result<(), ValidationErrors> {
249        let sections = self.sections.keys().collect::<Vec<_>>();
250        let mut errors = Vec::new();
251        self.walk(|node| {
252            if let DialogNode::GoTo(gt) = node {
253                if !sections.contains(&gt) {
254                    errors.push(ValidationError::UnknownSectionGoTo(gt.into()));
255                }
256            }
257        });
258        if errors.is_empty() {
259            Ok(())
260        } else {
261            Err(ValidationErrors(errors))
262        }
263    }
264
265    /// Merge `other` into this dialog object.
266    pub fn merge(&mut self, other: Self) -> Result<(), DuplicateSectionKey> {
267        for (section, data) in other.sections {
268            if self.sections.contains_key(&section) {
269                return Err(DuplicateSectionKey(section));
270            }
271            self.sections.insert(section, data);
272        }
273        Ok(())
274    }
275
276    /// parse a dialog tree from a provided string.
277    #[allow(clippy::type_complexity)]
278    pub fn parse_str(s: &str) -> Result<Self, SocraticError<DA::Err, IF::Err, TE::Err>>
279    where
280        DA: FromStr,
281        IF: FromStr,
282        TE: FromStr,
283    {
284        Ok(parsing::dialog::<DA, IF, TE>(s)?)
285    }
286
287    /// Parse a dialog tree from a reader.
288    #[allow(clippy::type_complexity)]
289    pub fn parse_from_reader<R>(
290        mut reader: R,
291    ) -> Result<Self, SocraticError<DA::Err, IF::Err, TE::Err>>
292    where
293        R: std::io::Read,
294        DA: FromStr,
295        IF: FromStr,
296        TE: FromStr,
297    {
298        let mut s = String::new();
299        reader.read_to_string(&mut s)?;
300        Dialog::parse_str(&s)
301    }
302
303    /// Parse a dialog tree from a file.
304    #[allow(clippy::type_complexity)]
305    pub fn parse_from_file<P>(path: P) -> Result<Self, SocraticError<DA::Err, IF::Err, TE::Err>>
306    where
307        P: AsRef<Path>,
308        DA: FromStr,
309        IF: FromStr,
310        TE: FromStr,
311    {
312        let f = std::fs::File::open(path)?;
313        Dialog::parse_from_reader(f)
314    }
315
316    /// Write the Dialog to a writer using the CBOR format.
317    #[cfg(feature = "cbor")]
318    pub fn packed_to_writer<W>(&self, writer: W) -> Result<(), ciborium::ser::Error<W::Error>>
319    where
320        W: ciborium_io::Write,
321        W::Error: core::fmt::Debug,
322        DA: Serialize,
323        IF: Serialize,
324        TE: Serialize,
325    {
326        ciborium::ser::into_writer(&self.sections, writer)
327    }
328
329    /// Write the Dialog to a file using the CBOR format.
330    #[cfg(feature = "cbor")]
331    pub fn packed_to_file<P: AsRef<Path>>(
332        &self,
333        path: P,
334    ) -> Result<(), ciborium::ser::Error<io::Error>>
335    where
336        DA: Serialize,
337        IF: Serialize,
338        TE: Serialize,
339    {
340        let f = std::fs::File::create(path)?;
341        self.packed_to_writer(f)
342    }
343
344    /// Read the Dialog from a Reader using the CBOR format.
345    #[cfg(feature = "cbor")]
346    pub fn packed_from_reader<R>(reader: R) -> Result<Self, ciborium::de::Error<R::Error>>
347    where
348        R: ciborium_io::Read,
349        R::Error: core::fmt::Debug,
350        DA: serde::de::DeserializeOwned,
351        IF: serde::de::DeserializeOwned,
352        TE: serde::de::DeserializeOwned,
353    {
354        let sections: HashMap<String, DialogTree<DA, IF, TE>> = ciborium::de::from_reader(reader)?;
355        Ok(Dialog { sections })
356    }
357
358    /// Read the Dialog from a Reader using the CBOR format.
359    #[cfg(feature = "cbor")]
360    pub fn packed_from_file<P: AsRef<Path>>(path: P) -> Result<Self, ciborium::de::Error<io::Error>>
361    where
362        DA: serde::de::DeserializeOwned,
363        IF: serde::de::DeserializeOwned,
364        TE: serde::de::DeserializeOwned,
365    {
366        let f = std::fs::File::open(path)?;
367        Self::packed_from_reader(f)
368    }
369}
370
/// Trait for state objects that interact with dialog.
///
/// The associated types are the payloads stored in a dialog for actions, conditions and
/// interpolation placeholders.
pub trait DialogState {
    /// Payload of an action node, handed to [`DialogState::do_action`].
    type DoAction;

    /// Payload of a condition, handed to [`DialogState::check_condition`].
    type IF;

    /// Payload of an interpolation placeholder, handed to [`DialogState::interpolate`].
    type Interpolation;

    /// Perform an action on the state.
    fn do_action(&mut self, action: &Self::DoAction);

    /// Check a conditional against the current state. Returns `true` when the condition holds.
    fn check_condition(&self, condition: &Self::IF) -> bool;

    /// Produce the text that replaces an interpolation placeholder.
    fn interpolate(&self, placeholder: &Self::Interpolation) -> String;
}

// A do-nothing state:
// - actions are ignored;
// - every condition passes;
// - a placeholder's own text is returned unchanged.
impl DialogState for () {
    type DoAction = String;
    type IF = String;
    type Interpolation = String;

    fn do_action(&mut self, _action: &String) {}

    fn check_condition(&self, _condition: &String) -> bool {
        true
    }

    fn interpolate(&self, placeholder: &String) -> String {
        placeholder.clone()
    }
}
405
406impl<DA, IF, TE> Dialog<DA, IF, TE> {
407    /// Get the dialog line represented by the |DialogIndex|.
408    ///
409    /// Returns the `DialogItem` for the associated `index` along with the next index.
410    #[instrument(skip(self, state), fields(index = %index))]
411    pub fn get<S: DialogState<DoAction = DA, IF = IF, Interpolation = TE>>(
412        &self,
413        mut index: DialogIndex,
414        state: &mut S,
415    ) -> Option<(DialogItem, DialogIndex)>
416    where
417        DA: std::fmt::Debug,
418    {
419        let tree = self.sections.get(&index.section)?;
420        let (item, sub_index) = tree.get(index.sub, state)?;
421        index.sub = Some(sub_index);
422        Some((item, index))
423    }
424
425    /// Begins a new dialog session at a given section.
426    #[instrument(skip(self, state))]
427    pub fn begin<S: DialogState<DoAction = DA, IF = IF, Interpolation = TE>>(
428        &self,
429        section: &str,
430        state: &mut S,
431    ) -> Option<(DialogItem, DialogIndex)>
432    where
433        DA: std::fmt::Debug,
434    {
435        self.get(
436            DialogIndex {
437                section: section.into(),
438                sub: None,
439            },
440            state,
441        )
442    }
443
444    fn walk<F>(&self, mut cb: F)
445    where
446        F: FnMut(&DialogNode<DA, IF, TE>),
447    {
448        for tree in self.sections.values() {
449            tree.walk(&mut cb);
450        }
451    }
452}
453
454impl<DA, IF, TE> DialogTree<DA, IF, TE> {
455    fn walk<F>(&self, cb: &mut F)
456    where
457        F: FnMut(&DialogNode<DA, IF, TE>),
458    {
459        for node in &self.nodes {
460            cb(node);
461            match node {
462                DialogNode::Conditional(parts) => {
463                    for (_, tree) in parts {
464                        tree.walk(cb);
465                    }
466                }
467                DialogNode::Responses(responses) => {
468                    for (_, _, tree) in responses {
469                        tree.walk(cb);
470                    }
471                }
472                _ => {}
473            }
474        }
475    }
476
477    fn get<S: DialogState<DoAction = DA, IF = IF, Interpolation = TE>>(
478        &self,
479        index: Option<SubIndex>,
480        state: &mut S,
481    ) -> Option<(DialogItem, SubIndex)>
482    where
483        DA: std::fmt::Debug,
484    {
485        let span = match &index {
486            Some(ref i) => info_span!("get", index = %i),
487            None => info_span!("get", index = %"None"),
488        };
489        let _enter = span.enter();
490
491        let mut index = index.unwrap_or_default();
492        match self.nodes.get(index.index)? {
493            DialogNode::CharacterSays(character, says) => {
494                let says = Atoms::new(says, state);
495                info!("CharacterSays({character}, {says})");
496                index.index += 1;
497                Some((DialogItem::CharacterSays(character.clone(), says), index))
498            }
499            DialogNode::Message(msg) => {
500                let msg = Atoms::new(msg, state);
501                info!("Message({msg})");
502                index.index += 1;
503                Some((DialogItem::Message(msg), index))
504            }
505            DialogNode::GoTo(gt) => {
506                info!("GoTo({gt})");
507                index.index += 1;
508                Some((DialogItem::GoTo(gt.clone()), index))
509            }
510            DialogNode::Responses(responses) => {
511                let responses = responses
512                    .iter()
513                    .filter(|(_, i, _)| {
514                        i.as_ref().map(|i| state.check_condition(i)).unwrap_or(true)
515                    })
516                    .collect::<Vec<_>>();
517                info!("Ask...");
518                if let Some(resp) = index.response {
519                    let response_tree = &responses.get(resp)?.2;
520                    match response_tree.get(*index.inner, state) {
521                        None => {
522                            index.index += 1;
523                            index.response = None;
524                            index.inner = Box::new(None);
525                            self.get(Some(index), state)
526                        }
527                        Some((item, inner)) => {
528                            *index.inner = Some(inner);
529                            Some((item, index))
530                        }
531                    }
532                } else {
533                    Some((
534                        DialogItem::Responses(
535                            responses
536                                .iter()
537                                .map(|(q, _, _)| Atoms::new(q, state))
538                                .collect(),
539                        ),
540                        index,
541                    ))
542                }
543            }
544            DialogNode::DoAction(cmd) => {
545                info!("DoAction({cmd:?})");
546                index.index += 1;
547                state.do_action(cmd);
548                self.get(Some(index), state)
549            }
550            DialogNode::Conditional(conditions) => {
551                if index.response.is_none() {
552                    for (i, (check, _)) in conditions.iter().enumerate() {
553                        if let Some(c) = check {
554                            if state.check_condition(c) {
555                                index.response = Some(i);
556                            }
557                        } else {
558                            index.response = Some(i)
559                        }
560                    }
561                    // No checks passed.
562                    if index.response.is_none() {
563                        index.index += 1;
564                        return self.get(Some(index), state);
565                    }
566                }
567                let resp = index.response.expect("to be not none");
568                let response_tree = &conditions.get(resp)?.1;
569                match response_tree.get(*index.inner, state) {
570                    None => {
571                        index.index += 1;
572                        index.response = None;
573                        index.inner = Box::new(None);
574                        self.get(Some(index), state)
575                    }
576                    Some((item, inner)) => {
577                        *index.inner = Some(inner);
578                        Some((item, index))
579                    }
580                }
581            }
582        }
583    }
584}
585
// The only test needs the `cbor` feature. Gating the whole module (not just the test fn)
// keeps `#![deny(unused)]` from rejecting the otherwise-unused imports when the feature is off.
#[cfg(all(test, feature = "cbor"))]
mod tests {
    use super::*;
    use test_log::test;

    #[test]
    fn test_socrates() -> Result<(), anyhow::Error> {
        let s = Dialog::parse_str(
            r#"

:: section_name
doot
- hi
    test
- hello
    test2
- trust issues => section_name
boot
=> dingle

:: dingle
bingle"#,
        )?;
        let (line, ix) = s.begin("section_name", &mut ()).unwrap();
        info!("1 {line} {ix:?}");
        let (line, mut ix) = s.get(ix, &mut ()).unwrap();
        info!("2 {line} {ix:?}");
        ix.set_response(2);
        info!("3 {ix:?}");
        let (line, ix) = s.get(ix, &mut ()).unwrap();
        info!("4 {line} {ix:?}");
        let (line, ix) = s.get(ix, &mut ()).unwrap();
        info!("5 {line} {ix:?}");

        // Round-trip through the OS temp directory rather than the working directory, so the
        // test leaves nothing behind in the source tree. The process id keeps concurrent runs
        // from clobbering each other's file.
        let path =
            std::env::temp_dir().join(format!("socratic-test-{}.cbor", std::process::id()));
        s.packed_to_file(&path).unwrap();
        let loaded: Result<Dialog<String, String, String>, _> = Dialog::packed_from_file(&path);
        // Remove the file before asserting, so a failure below does not leak it.
        let _ = std::fs::remove_file(&path);
        let s2 = loaded.unwrap();
        println!("{:?}", s2);
        assert_eq!(s, s2);

        Ok(())
    }
}