llm_chain/
parsing.rs

1//! Functions for parsing the output of LLMs, including YAML and Markdown.
2//!
3//! This module provides a set of functions that allow you to parse and extract useful information from the output of Large Language Models (LLMs) in YAML and Markdown formats. These functions can be used to transform the LLM output into a more structured and usable format, enabling seamless integration with your applications.
4//!
5//! Key features include:
6//! - Parsing YAML and Markdown content produced by LLMs
7//! - Handling common edge cases and being lenient with LLM outputs
8//! - Extracting and deserializing YAML objects from text
9//!
10//! With these functions, you can easily work with the outputs of LLMs, simplifying the process of integrating LLMs into your applications and workflows.
11
12use markdown::{
13    mdast::{Code, Node, Text},
14    to_mdast, ParseOptions,
15};
16use serde::de::DeserializeOwned;
17use serde_yaml::Value;
18use std::collections::VecDeque;
19use thiserror::Error;
20
21/// Errors occuring when parsing
22#[derive(Error, Debug)]
23#[error(transparent)]
24pub struct ExtractionError(#[from] ExtractionErrorImpl);
25
26/// Enum representing the different error cases that can occur during YAML extraction.
27#[derive(Error, Debug)]
28enum ExtractionErrorImpl {
29    /// The YAML content was valid, but it did not match the expected format.
30    #[error("The YAML was valid, but it didn't match the expected format: {0}")]
31    YamlFoundButFormatWrong(serde_yaml::Error),
32
33    /// An error occurred while parsing the YAML content.
34    #[error("YAML parsing failed with: {0}")]
35    ParseError(#[from] serde_yaml::Error),
36
37    /// No YAML content was found to parse.
38    #[error("The string to parse was empty")]
39    NoneFound,
40}
41
42impl ExtractionErrorImpl {
43    /// Determines the most representative error between two instances of `ExtractionErrorImpl`.
44    ///
45    /// The function prefers `YamlFoundButFormatWrong` errors over `ParseError` errors,
46    /// and `ParseError` errors over `NoneFound` errors.
47    fn most_representative(a: Self, b: Self) -> Self {
48        match (&a, &b) {
49            (Self::YamlFoundButFormatWrong(_), _) => a,
50            (_, Self::YamlFoundButFormatWrong(_)) => b,
51            (Self::ParseError(_), _) => a,
52            (_, Self::ParseError(_)) => b,
53            _ => a,
54        }
55    }
56}
57
58/// Attempts to extract YAML content from a given code block and deserialize it into the specified type.
59///
60/// # Arguments
61///
62/// * `code_block` - A string slice containing the YAML content to be extracted and deserialized.
63///
64/// # Returns
65///
66/// * `Ok(T)` - If the YAML content is successfully extracted and deserialized into the specified type.
67/// * `Err(ExtractionErrorImpl)` - If an error occurs during extraction or deserialization.
68///
69/// # Type Parameters
70///
71/// * `T: DeserializeOwned` - The type into which the extracted YAML content should be deserialized.
72fn extract_yaml<T: DeserializeOwned>(code_block: &str) -> Result<T, ExtractionErrorImpl> {
73    // Ensure that the input code block is not empty.
74    let code_block = Some(code_block)
75        .filter(|s| !s.is_empty())
76        .ok_or_else(|| ExtractionErrorImpl::NoneFound)?;
77
78    // Parse the code block as YAML.
79    let yaml: Value = serde_yaml::from_str(code_block)?;
80
81    // Attempt to deserialize the YAML into the specified type, handling any format issues.
82    serde_yaml::from_value(yaml).map_err(ExtractionErrorImpl::YamlFoundButFormatWrong)
83}
84
85/// Attempts to find a YAML object in a string and deserialize it into the specified type.
86/// # Arguments
87///
88/// * `text` - A string slice containing the document which may contain YAML content.
89///
90/// # Returns
91///
92/// * `Ok(Vec<T>)` - If the YAML content is successfully extracted and deserialized into the specified type.
93/// * `Err(ExtractionError)` - If an error occurs during extraction or deserialization.
94///
95/// # Type Parameters
96///
97/// * `T: DeserializeOwned` - The type into which the extracted YAML content should be deserialized.
98///
99/// # Examples
100///
101/// It handles the obvious case where it is just YAML.
102///
103/// ```
104/// #[derive(serde::Deserialize)]
105/// struct Dummy {
106///    hello: String
107/// }
108/// use llm_chain::parsing::find_yaml;
109/// let data = "
110/// hello: world
111/// ";
112/// let data: Vec<Dummy> = find_yaml(data).unwrap();
113/// assert_eq!(data[0].hello, "world");
114/// ```
115///
116/// It handles the case where it is in a code block.
117///
118/// ```
119/// use llm_chain::parsing::find_yaml;
120/// // NOTE: we are escaping the backticks because this is a doc test.
121/// let data = "
122/// \u{60}``yaml
123/// hello: world
124/// \u{60}``
125/// ";
126/// find_yaml::<serde_yaml::Value>(data).unwrap();
127/// ```
128pub fn find_yaml<T: DeserializeOwned>(text: &str) -> Result<Vec<T>, ExtractionError> {
129    let mut current_error = ExtractionErrorImpl::NoneFound;
130    if text.is_empty() {
131        return Err(current_error.into());
132    }
133
134    // Attempt YAML parsing if it doesn't look like markdown output.
135    if !text.starts_with("```") {
136        match extract_yaml(text) {
137            Ok(o) => return Ok(vec![o]),
138            Err(e) => current_error = ExtractionErrorImpl::most_representative(current_error, e),
139        }
140    }
141
142    // Parse the input text as markdown.
143    let options = ParseOptions::default();
144    let ast = to_mdast(text, &options).expect("we're not using MDX, so this should never fail");
145
146    // Nodes to visit.
147    let mut nodes = vec![ast];
148
149    let mut found: VecDeque<_> = VecDeque::new();
150    while let Some(node) = nodes.pop() {
151        if let Some(children) = node.children() {
152            children.iter().for_each(|child| nodes.push(child.clone()));
153        }
154
155        // Check for code blocks containing YAML.
156        if let Node::Code(Code { value, lang, .. }) = node {
157            let lang = lang.unwrap_or_default();
158            match lang.as_str() {
159                "yaml" | "yml" | "json" | "" => {
160                    let code_block = value.as_str();
161                    match extract_yaml(code_block) {
162                        Ok(o) => found.push_front(o),
163                        Err(e) => {
164                            current_error =
165                                ExtractionErrorImpl::most_representative(current_error, e)
166                        }
167                    }
168                }
169                _ => {}
170            }
171        }
172    }
173    if !found.is_empty() {
174        Ok(found.into())
175    } else {
176        Err(current_error.into())
177    }
178}
179
180/// Extracts labeled text from markdown
181///
182/// LLMs often generate text that looks something like this
183/// ```markdown
184/// - *foo*: bar
185/// - hello: world
186/// ```
187/// Which we want to parse as key value pairs (foo, bar), (hello, world).
188///
189/// # Parameters
190/// - `text` the text to parse
191///
192/// # Returns
193/// Vec<(String, String)> A vector of key value pairs.
194///
195/// # Examples
196///
197/// ```
198/// use llm_chain::parsing::extract_labeled_text;
199/// let data = "
200/// - alpha: beta
201/// - *gamma*: delta
202/// ";
203/// let labs = extract_labeled_text(data);
204/// println!("{:?}", labs);
205/// assert_eq!(labs[0], ("alpha".to_string(), "beta".to_string()));
206/// ```
207pub fn extract_labeled_text(text: &str) -> Vec<(String, String)> {
208    let options = ParseOptions::default();
209    let ast = to_mdast(text, &options).expect("markdown parsing can't fail");
210    let mut nodes = VecDeque::new();
211    nodes.push_back(ast);
212    let mut extracted_labels = Vec::new();
213
214    while let Some(node) = nodes.pop_front() {
215        let found = match &node {
216            Node::Text(Text { value, .. }) => {
217                extract_label_and_text(value.to_owned()).map(|(label, text)| (label, text))
218            }
219            Node::Paragraph(_) | Node::ListItem(_) => {
220                find_labeled_text(&node).map(|(label, text)| (label, text))
221            }
222            _ => None,
223        };
224        if let Some(kv) = found {
225            // If found push to found
226            extracted_labels.push(kv)
227        } else if let Some(children) = node.children() {
228            // If not found recur into it.
229            for (index, child) in children.iter().cloned().enumerate() {
230                nodes.insert(index, child);
231            }
232        }
233    }
234    extracted_labels
235}
236
237/// Finds labeled text
238///
239/// This function looks for patterns such as `**label**: text.
240///
241/// Returns an option indicating whether a label was found, and if so, the label and text.
242fn find_labeled_text(n: &Node) -> Option<(String, String)> {
243    if let Node::Text(Text { value, .. }) = n {
244        extract_label_and_text(value.to_owned())
245    } else {
246        let children = n.children()?;
247        // There should be exactly two children...
248        if children.len() == 2 {
249            let key = children
250                .get(0)
251                .map(inner_text)
252                .map(format_key)
253                .filter(|k| !k.is_empty());
254            let value = children.get(1).map(inner_text).map(format_value);
255            key.and_then(|key| value.map(|value| (key, value)))
256        } else {
257            None
258        }
259    }
260}
261
/// Splits `label: text` at the first `:` and trims both halves.
///
/// Returns `None` when there is no `:` in the input or when the label is empty
/// after trimming. The text part may be empty.
fn extract_label_and_text(text: String) -> Option<(String, String)> {
    let (raw_label, raw_text) = text.split_once(':')?;

    let label = raw_label.trim();
    if label.is_empty() {
        return None;
    }

    Some((label.to_string(), raw_text.trim().to_string()))
}
276
277/// Returns the inner text
278fn inner_text(n: &Node) -> String {
279    if let Node::Text(Text { value, .. }) = n {
280        return value.to_owned();
281    }
282    let mut deq = VecDeque::new();
283    deq.push_back(n.clone());
284    let mut text = String::new();
285    while let Some(node) = deq.pop_front() {
286        if let Some(children) = node.children() {
287            deq.extend(children.iter().cloned());
288        }
289        if let Node::Text(Text { value, .. }) = node {
290            text.push_str(value.as_str());
291        }
292    }
293    text
294}
295
// Formats the key: trims surrounding whitespace, then removes one trailing ":" if
// present. Whatever precedes that ":" is kept as is (it is not trimmed again).
fn format_key(s: String) -> String {
    let trimmed = s.trim();
    match trimmed.strip_suffix(':') {
        Some(without_colon) => without_colon.to_owned(),
        None => trimmed.to_owned(),
    }
}
301
// Formats the value: trims surrounding whitespace, strips one leading ":" if
// present, then trims the start again.
//
// The fallback when there is no ":" must be the trimmed slice, not the raw input;
// otherwise trailing whitespace would survive only in the no-colon case.
fn format_value(s: String) -> String {
    let trimmed = s.trim();
    trimmed
        .strip_prefix(':')
        .unwrap_or(trimmed)
        .trim_start()
        .to_owned()
}