elements_miniscript/
expression.rs

1// Written in 2019 by Andrew Poelstra <apoelstra@wpsoftware.net>
2// SPDX-License-Identifier: CC0-1.0
3
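//! # Function-like Expression Language
//!
//! Parses strings of the form `name(arg1,arg2,...)` (and brace-delimited
//! `{...}` trees) into a [`Tree`] of borrowed names and arguments, and provides
//! helpers such as [`terminal`], [`unary`] and [`binary`] for converting a
//! [`Tree`] into structured values via [`FromTree`].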
6
7use std::fmt;
8use std::str::FromStr;
9
10use bitcoin_miniscript::expression::check_valid_chars;
11
12use crate::{errstr, Error, MAX_RECURSION_DEPTH};
13
/// A parsed token of the form `x(...)` or a bare `x`.
///
/// `name` borrows directly from the parsed input string; `args` holds one
/// subtree per comma-separated item inside the brackets and is empty for a
/// bare token.
#[derive(Debug, Clone)]
pub struct Tree<'a> {
    /// The name `x` (the text before the opening bracket, or the whole token).
    pub name: &'a str,
    /// The comma-separated contents of the brackets, if any, as subtrees.
    pub args: Vec<Tree<'a>>,
}
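// Example: `or_b(pk(A),pk(B))` parses into a node named `or_b` with the two
// children `pk(A)` and `pk(B)`. A key may itself be a MuSig expression such as
// `A = musig(musig(B,C),D,E)`; the brackets and commas inside `musig(...)` are
// not split, so `pk(A)` still has exactly one (terminal) child.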
27
28/// A trait for extracting a structure from a Tree representation in token form
29pub trait FromTree: Sized {
30    /// Extract a structure from Tree representation
31    fn from_tree(top: &Tree<'_>) -> Result<Self, Error>;
32}
33
34impl<'a> fmt::Display for Tree<'a> {
35    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
36        write!(f, "({}", self.name)?;
37        for arg in &self.args {
38            write!(f, ",{}", arg)?;
39        }
40        write!(f, ")")
41    }
42}
43
/// Result of scanning for the next structural character in an expression.
///
/// Every variant except `Nothing` carries the byte offset of the character
/// that was found.
enum Found {
    /// No structural character appears before the end of the input.
    Nothing,
    /// An opening bracket: `(` or `{`, depending on the active delimiter.
    LBracket(usize),
    /// An argument-separating comma.
    Comma(usize),
    /// A closing bracket: `)` or `}`, depending on the active delimiter.
    RBracket(usize),
}
50
51fn next_expr(sl: &str, delim: char) -> Found {
52    // Decide whether we are parsing a key or not.
53    // When parsing a key ignore all the '(' and ')'.
54    // We keep count of lparan whenever we are inside a key context
55    // We exit the context whenever we find the corresponding ')'
56    // in which we entered the context. This allows to special case
57    // parse the '(' ')' inside key expressions.(key or musig(keys)).
58    let mut key_ctx = false;
59    let mut key_lparan_count = 0;
60    let mut found = Found::Nothing;
61    if delim == '(' {
62        for (n, ch) in sl.char_indices() {
63            match ch {
64                '(' => {
65                    // already inside a key context
66                    if key_ctx {
67                        key_lparan_count += 1;
68                    } else if &sl[..n] == "musig" {
69                        key_lparan_count = 1;
70                        key_ctx = true;
71                    } else {
72                        found = Found::LBracket(n);
73                        break;
74                    }
75                }
76                ',' => {
77                    if !key_ctx {
78                        found = Found::Comma(n);
79                        break;
80                    }
81                }
82                ')' => {
83                    if key_ctx {
84                        key_lparan_count -= 1;
85                        if key_lparan_count == 0 {
86                            key_ctx = false;
87                        }
88                    } else {
89                        found = Found::RBracket(n);
90                        break;
91                    }
92                }
93                _ => {}
94            }
95        }
96    } else if delim == '{' {
97        let mut new_count = 0;
98        for (n, ch) in sl.char_indices() {
99            match ch {
100                '{' => {
101                    found = Found::LBracket(n);
102                    break;
103                }
104                '(' => {
105                    new_count += 1;
106                }
107                ',' => {
108                    if new_count == 0 {
109                        found = Found::Comma(n);
110                        break;
111                    }
112                }
113                ')' => {
114                    new_count -= 1;
115                }
116                '}' => {
117                    found = Found::RBracket(n);
118                    break;
119                }
120                _ => {}
121            }
122        }
123    } else {
124        unreachable!("{}", "Internal: delimiters in parsing must be '(' or '{'");
125    }
126    found
127}
128
/// Returns the closing bracket that pairs with the opening delimiter `delim`.
///
/// # Panics
/// Panics (via `unreachable!`) if `delim` is neither `'('` nor `'{'`.
fn closing_delim(delim: char) -> char {
    if delim == '(' {
        ')'
    } else if delim == '{' {
        '}'
    } else {
        unreachable!("Unknown delimiter")
    }
}
137
138impl<'a> Tree<'a> {
139    /// Parse an expression with round brackets
140    pub fn from_slice(sl: &'a str) -> Result<(Tree<'a>, &'a str), Error> {
141        // Parsing TapTree or just miniscript
142        Self::from_slice_delim(sl, 0u32, '(')
143    }
144
145    pub(crate) fn from_slice_delim(
146        mut sl: &'a str,
147        depth: u32,
148        delim: char,
149    ) -> Result<(Tree<'a>, &'a str), Error> {
150        if depth >= MAX_RECURSION_DEPTH {
151            return Err(Error::MaxRecursiveDepthExceeded);
152        }
153
154        match next_expr(sl, delim) {
155            // String-ending terminal
156            Found::Nothing => Ok((
157                Tree {
158                    name: sl,
159                    args: vec![],
160                },
161                "",
162            )),
163            // Terminal
164            Found::Comma(n) | Found::RBracket(n) => Ok((
165                Tree {
166                    name: &sl[..n],
167                    args: vec![],
168                },
169                &sl[n..],
170            )),
171            // Function call
172            Found::LBracket(n) => {
173                let mut ret = Tree {
174                    name: &sl[..n],
175                    args: vec![],
176                };
177
178                sl = &sl[n + 1..];
179                loop {
180                    let (arg, new_sl) = Tree::from_slice_delim(sl, depth + 1, delim)?;
181                    ret.args.push(arg);
182
183                    if new_sl.is_empty() {
184                        return Err(Error::ExpectedChar(closing_delim(delim)));
185                    }
186
187                    sl = &new_sl[1..];
188                    match new_sl.as_bytes()[0] {
189                        b',' => {}
190                        last_byte => {
191                            if last_byte == closing_delim(delim) as u8 {
192                                break;
193                            } else {
194                                return Err(Error::ExpectedChar(closing_delim(delim)));
195                            }
196                        }
197                    }
198                }
199                Ok((ret, sl))
200            }
201        }
202    }
203
204    /// Parses a tree from a string
205    #[allow(clippy::should_implement_trait)] // seems to be a false positive
206    pub fn from_str(s: &'a str) -> Result<Tree<'a>, Error> {
207        check_valid_chars(s)?;
208
209        let (top, rem) = Tree::from_slice(s)?;
210        if rem.is_empty() {
211            Ok(top)
212        } else {
213            Err(errstr(rem))
214        }
215    }
216}
217
218/// Parse a string as a u32, for timelocks or thresholds
219pub fn parse_num<T: FromStr>(s: &str) -> Result<T, Error> {
220    if s.len() > 1 {
221        let ch = s.chars().next().unwrap();
222        let ch = if ch == '-' {
223            s.chars().nth(1).ok_or(Error::Unexpected(
224                "Negative number must follow dash sign".to_string(),
225            ))?
226        } else {
227            ch
228        };
229        if !('1'..='9').contains(&ch) {
230            return Err(Error::Unexpected(
231                "Number must start with a digit 1-9".to_string(),
232            ));
233        }
234    }
235    T::from_str(s).map_err(|_| errstr(s))
236}
237
238/// Attempts to parse a terminal expression
239pub fn terminal<T, F, Err>(term: &Tree<'_>, convert: F) -> Result<T, Error>
240where
241    F: FnOnce(&str) -> Result<T, Err>,
242    Err: ToString,
243{
244    if term.args.is_empty() {
245        convert(term.name).map_err(|e| Error::Unexpected(e.to_string()))
246    } else {
247        Err(errstr(term.name))
248    }
249}
250
251/// Attempts to parse an expression with exactly one child
252pub fn unary<L, T, F>(term: &Tree<'_>, convert: F) -> Result<T, Error>
253where
254    L: FromTree,
255    F: FnOnce(L) -> T,
256{
257    if term.args.len() == 1 {
258        let left = FromTree::from_tree(&term.args[0])?;
259        Ok(convert(left))
260    } else {
261        Err(errstr(term.name))
262    }
263}
264
265/// Attempts to parse an expression with exactly two children
266pub fn binary<L, R, T, F>(term: &Tree<'_>, convert: F) -> Result<T, Error>
267where
268    L: FromTree,
269    R: FromTree,
270    F: FnOnce(L, R) -> T,
271{
272    if term.args.len() == 2 {
273        let left = FromTree::from_tree(&term.args[0])?;
274        let right = FromTree::from_tree(&term.args[1])?;
275        Ok(convert(left, right))
276    } else {
277        Err(errstr(term.name))
278    }
279}
280
#[cfg(test)]
mod tests {

    use super::{parse_num, Tree};

    #[test]
    fn test_parse_num() {
        assert!(parse_num::<u32>("0").is_ok());
        assert!(parse_num::<u32>("00").is_err());
        assert!(parse_num::<u32>("0000").is_err());
        assert!(parse_num::<u32>("06").is_err());
        assert!(parse_num::<u32>("+6").is_err());
        assert!(parse_num::<u32>("-6").is_err());
        // Accepted values round-trip to the expected number.
        assert_eq!(parse_num::<u32>("0").ok(), Some(0));
        assert_eq!(parse_num::<u32>("42").ok(), Some(42));
        assert_eq!(parse_num::<i64>("-6").ok(), Some(-6));
        // A leading zero is rejected after a minus sign too.
        assert!(parse_num::<i64>("-0").is_err());
        assert!(parse_num::<i64>("-06").is_err());
    }

    #[test]
    fn test_tree_parse() {
        let tree = Tree::from_str("or_b(pk(A),pk(B))").unwrap();
        assert_eq!(tree.name, "or_b");
        assert_eq!(tree.args.len(), 2);
        assert_eq!(tree.args[0].name, "pk");
        assert_eq!(tree.args[0].args.len(), 1);
        assert_eq!(tree.args[0].args[0].name, "A");
        assert_eq!(tree.args[1].args[0].name, "B");

        // Brackets and commas inside a `musig(...)` key are not split.
        let tree = Tree::from_str("pk(musig(musig(B,C),D,E))").unwrap();
        assert_eq!(tree.name, "pk");
        assert_eq!(tree.args.len(), 1);
        assert_eq!(tree.args[0].name, "musig(musig(B,C),D,E)");
        assert!(tree.args[0].args.is_empty());

        // Unterminated argument lists and trailing input are rejected.
        assert!(Tree::from_str("pk(A").is_err());
        assert!(Tree::from_str("pk(A))").is_err());
    }

    #[test]
    fn test_tree_parse_braces() {
        let (tree, rem) = Tree::from_slice_delim("{pk(A),pk(B)}", 0, '{').unwrap();
        assert!(rem.is_empty());
        assert_eq!(tree.name, "");
        assert_eq!(tree.args.len(), 2);
        // In brace mode, round-bracketed items stay as single terminal names.
        assert_eq!(tree.args[0].name, "pk(A)");
        assert_eq!(tree.args[1].name, "pk(B)");
    }
}