llm_chain/parsing.rs
1//! Functions for parsing the output of LLMs, including YAML and Markdown.
2//!
3//! This module provides a set of functions that allow you to parse and extract useful information from the output of Large Language Models (LLMs) in YAML and Markdown formats. These functions can be used to transform the LLM output into a more structured and usable format, enabling seamless integration with your applications.
4//!
5//! Key features include:
6//! - Parsing YAML and Markdown content produced by LLMs
7//! - Handling common edge cases and being lenient with LLM outputs
8//! - Extracting and deserializing YAML objects from text
9//!
10//! With these functions, you can easily work with the outputs of LLMs, simplifying the process of integrating LLMs into your applications and workflows.
11
12use markdown::{
13 mdast::{Code, Node, Text},
14 to_mdast, ParseOptions,
15};
16use serde::de::DeserializeOwned;
17use serde_yaml::Value;
18use std::collections::VecDeque;
19use thiserror::Error;
20
/// Errors occurring when parsing.
///
/// This is the public error type of the module. It is an opaque wrapper around the private
/// `ExtractionErrorImpl`: `#[error(transparent)]` forwards the display message and error
/// source to the wrapped value, and `#[from]` provides the conversion from it.
#[derive(Error, Debug)]
#[error(transparent)]
pub struct ExtractionError(#[from] ExtractionErrorImpl);
25
/// Enum representing the different error cases that can occur during YAML extraction.
///
/// The enum is private; it reaches callers only when wrapped in `ExtractionError`.
#[derive(Error, Debug)]
enum ExtractionErrorImpl {
    /// The YAML content was valid, but it did not match the expected format.
    ///
    /// Carries the `serde_yaml` error raised while deserializing into the target type.
    #[error("The YAML was valid, but it didn't match the expected format: {0}")]
    YamlFoundButFormatWrong(serde_yaml::Error),

    /// An error occurred while parsing the YAML content.
    ///
    /// `#[from]` lets the `?` operator convert a `serde_yaml::Error` into this variant.
    #[error("YAML parsing failed with: {0}")]
    ParseError(#[from] serde_yaml::Error),

    /// No YAML content was found to parse (the string handed to the parser was empty).
    #[error("The string to parse was empty")]
    NoneFound,
}
41
42impl ExtractionErrorImpl {
43 /// Determines the most representative error between two instances of `ExtractionErrorImpl`.
44 ///
45 /// The function prefers `YamlFoundButFormatWrong` errors over `ParseError` errors,
46 /// and `ParseError` errors over `NoneFound` errors.
47 fn most_representative(a: Self, b: Self) -> Self {
48 match (&a, &b) {
49 (Self::YamlFoundButFormatWrong(_), _) => a,
50 (_, Self::YamlFoundButFormatWrong(_)) => b,
51 (Self::ParseError(_), _) => a,
52 (_, Self::ParseError(_)) => b,
53 _ => a,
54 }
55 }
56}
57
58/// Attempts to extract YAML content from a given code block and deserialize it into the specified type.
59///
60/// # Arguments
61///
62/// * `code_block` - A string slice containing the YAML content to be extracted and deserialized.
63///
64/// # Returns
65///
66/// * `Ok(T)` - If the YAML content is successfully extracted and deserialized into the specified type.
67/// * `Err(ExtractionErrorImpl)` - If an error occurs during extraction or deserialization.
68///
69/// # Type Parameters
70///
71/// * `T: DeserializeOwned` - The type into which the extracted YAML content should be deserialized.
72fn extract_yaml<T: DeserializeOwned>(code_block: &str) -> Result<T, ExtractionErrorImpl> {
73 // Ensure that the input code block is not empty.
74 let code_block = Some(code_block)
75 .filter(|s| !s.is_empty())
76 .ok_or_else(|| ExtractionErrorImpl::NoneFound)?;
77
78 // Parse the code block as YAML.
79 let yaml: Value = serde_yaml::from_str(code_block)?;
80
81 // Attempt to deserialize the YAML into the specified type, handling any format issues.
82 serde_yaml::from_value(yaml).map_err(ExtractionErrorImpl::YamlFoundButFormatWrong)
83}
84
85/// Attempts to find a YAML object in a string and deserialize it into the specified type.
86/// # Arguments
87///
88/// * `text` - A string slice containing the document which may contain YAML content.
89///
90/// # Returns
91///
92/// * `Ok(Vec<T>)` - If the YAML content is successfully extracted and deserialized into the specified type.
93/// * `Err(ExtractionError)` - If an error occurs during extraction or deserialization.
94///
95/// # Type Parameters
96///
97/// * `T: DeserializeOwned` - The type into which the extracted YAML content should be deserialized.
98///
99/// # Examples
100///
101/// It handles the obvious case where it is just YAML.
102///
103/// ```
104/// #[derive(serde::Deserialize)]
105/// struct Dummy {
106/// hello: String
107/// }
108/// use llm_chain::parsing::find_yaml;
109/// let data = "
110/// hello: world
111/// ";
112/// let data: Vec<Dummy> = find_yaml(data).unwrap();
113/// assert_eq!(data[0].hello, "world");
114/// ```
115///
116/// It handles the case where it is in a code block.
117///
118/// ```
119/// use llm_chain::parsing::find_yaml;
120/// // NOTE: we are escaping the backticks because this is a doc test.
121/// let data = "
122/// \u{60}``yaml
123/// hello: world
124/// \u{60}``
125/// ";
126/// find_yaml::<serde_yaml::Value>(data).unwrap();
127/// ```
128pub fn find_yaml<T: DeserializeOwned>(text: &str) -> Result<Vec<T>, ExtractionError> {
129 let mut current_error = ExtractionErrorImpl::NoneFound;
130 if text.is_empty() {
131 return Err(current_error.into());
132 }
133
134 // Attempt YAML parsing if it doesn't look like markdown output.
135 if !text.starts_with("```") {
136 match extract_yaml(text) {
137 Ok(o) => return Ok(vec![o]),
138 Err(e) => current_error = ExtractionErrorImpl::most_representative(current_error, e),
139 }
140 }
141
142 // Parse the input text as markdown.
143 let options = ParseOptions::default();
144 let ast = to_mdast(text, &options).expect("we're not using MDX, so this should never fail");
145
146 // Nodes to visit.
147 let mut nodes = vec![ast];
148
149 let mut found: VecDeque<_> = VecDeque::new();
150 while let Some(node) = nodes.pop() {
151 if let Some(children) = node.children() {
152 children.iter().for_each(|child| nodes.push(child.clone()));
153 }
154
155 // Check for code blocks containing YAML.
156 if let Node::Code(Code { value, lang, .. }) = node {
157 let lang = lang.unwrap_or_default();
158 match lang.as_str() {
159 "yaml" | "yml" | "json" | "" => {
160 let code_block = value.as_str();
161 match extract_yaml(code_block) {
162 Ok(o) => found.push_front(o),
163 Err(e) => {
164 current_error =
165 ExtractionErrorImpl::most_representative(current_error, e)
166 }
167 }
168 }
169 _ => {}
170 }
171 }
172 }
173 if !found.is_empty() {
174 Ok(found.into())
175 } else {
176 Err(current_error.into())
177 }
178}
179
180/// Extracts labeled text from markdown
181///
182/// LLMs often generate text that looks something like this
183/// ```markdown
184/// - *foo*: bar
185/// - hello: world
186/// ```
187/// Which we want to parse as key value pairs (foo, bar), (hello, world).
188///
189/// # Parameters
190/// - `text` the text to parse
191///
192/// # Returns
193/// Vec<(String, String)> A vector of key value pairs.
194///
195/// # Examples
196///
197/// ```
198/// use llm_chain::parsing::extract_labeled_text;
199/// let data = "
200/// - alpha: beta
201/// - *gamma*: delta
202/// ";
203/// let labs = extract_labeled_text(data);
204/// println!("{:?}", labs);
205/// assert_eq!(labs[0], ("alpha".to_string(), "beta".to_string()));
206/// ```
207pub fn extract_labeled_text(text: &str) -> Vec<(String, String)> {
208 let options = ParseOptions::default();
209 let ast = to_mdast(text, &options).expect("markdown parsing can't fail");
210 let mut nodes = VecDeque::new();
211 nodes.push_back(ast);
212 let mut extracted_labels = Vec::new();
213
214 while let Some(node) = nodes.pop_front() {
215 let found = match &node {
216 Node::Text(Text { value, .. }) => {
217 extract_label_and_text(value.to_owned()).map(|(label, text)| (label, text))
218 }
219 Node::Paragraph(_) | Node::ListItem(_) => {
220 find_labeled_text(&node).map(|(label, text)| (label, text))
221 }
222 _ => None,
223 };
224 if let Some(kv) = found {
225 // If found push to found
226 extracted_labels.push(kv)
227 } else if let Some(children) = node.children() {
228 // If not found recur into it.
229 for (index, child) in children.iter().cloned().enumerate() {
230 nodes.insert(index, child);
231 }
232 }
233 }
234 extracted_labels
235}
236
237/// Finds labeled text
238///
239/// This function looks for patterns such as `**label**: text.
240///
241/// Returns an option indicating whether a label was found, and if so, the label and text.
242fn find_labeled_text(n: &Node) -> Option<(String, String)> {
243 if let Node::Text(Text { value, .. }) = n {
244 extract_label_and_text(value.to_owned())
245 } else {
246 let children = n.children()?;
247 // There should be exactly two children...
248 if children.len() == 2 {
249 let key = children
250 .get(0)
251 .map(inner_text)
252 .map(format_key)
253 .filter(|k| !k.is_empty());
254 let value = children.get(1).map(inner_text).map(format_value);
255 key.and_then(|key| value.map(|value| (key, value)))
256 } else {
257 None
258 }
259 }
260}
261
/// Splits `text` at its first `:` into a trimmed `(label, text)` pair.
///
/// Returns `None` when `text` contains no `:` or when the label is empty after trimming.
/// Only the first `:` separates; any later ones stay in the text part.
fn extract_label_and_text(text: String) -> Option<(String, String)> {
    let (raw_label, raw_body) = text.split_once(':')?;

    let label = raw_label.trim();
    if label.is_empty() {
        return None;
    }

    Some((label.to_string(), raw_body.trim().to_string()))
}
276
277/// Returns the inner text
278fn inner_text(n: &Node) -> String {
279 if let Node::Text(Text { value, .. }) = n {
280 return value.to_owned();
281 }
282 let mut deq = VecDeque::new();
283 deq.push_back(n.clone());
284 let mut text = String::new();
285 while let Some(node) = deq.pop_front() {
286 if let Some(children) = node.children() {
287 deq.extend(children.iter().cloned());
288 }
289 if let Node::Text(Text { value, .. }) = node {
290 text.push_str(value.as_str());
291 }
292 }
293 text
294}
295
/// Formats the key: trims surrounding whitespace, then removes one trailing `:` if present.
fn format_key(s: String) -> String {
    let trimmed = s.trim();
    match trimmed.strip_suffix(':') {
        Some(without_colon) => without_colon.to_owned(),
        None => trimmed.to_owned(),
    }
}
301
/// Formats the value: trims surrounding whitespace, removes one leading `:` if present, and
/// then trims the start again to drop any whitespace that followed the `:`.
fn format_value(s: String) -> String {
    let trimmed = s.trim();
    trimmed
        .strip_prefix(':')
        // Fall back to the trimmed slice, not the raw input, so that values without a
        // leading `:` also lose their trailing whitespace.
        .unwrap_or(trimmed)
        .trim_start()
        .to_owned()
}