biscuit_auth/token/builder/
rule.rs

1/*
2 * Copyright (c) 2019 Geoffroy Couprie <contact@geoffroycouprie.com> and Contributors to the Eclipse Foundation.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5use std::{collections::HashMap, convert::TryFrom, fmt, str::FromStr};
6
7use nom::Finish;
8
9use crate::{
10    datalog::{self, SymbolTable},
11    error, PublicKey,
12};
13
14use super::{Convert, Expression, Predicate, Scope, Term, ToAnyParam};
15
16/// Builder for a Datalog rule
17#[derive(Debug, Clone, PartialEq, Eq)]
18pub struct Rule {
19    pub head: Predicate,
20    pub body: Vec<Predicate>,
21    pub expressions: Vec<Expression>,
22    pub parameters: Option<HashMap<String, Option<Term>>>,
23    pub scopes: Vec<Scope>,
24    pub scope_parameters: Option<HashMap<String, Option<PublicKey>>>,
25}
26
27impl Rule {
28    pub fn new(
29        head: Predicate,
30        body: Vec<Predicate>,
31        expressions: Vec<Expression>,
32        scopes: Vec<Scope>,
33    ) -> Rule {
34        let mut parameters = HashMap::new();
35        let mut scope_parameters = HashMap::new();
36        for term in &head.terms {
37            term.extract_parameters(&mut parameters);
38        }
39
40        for predicate in &body {
41            for term in &predicate.terms {
42                term.extract_parameters(&mut parameters);
43            }
44        }
45
46        for expression in &expressions {
47            for op in &expression.ops {
48                op.collect_parameters(&mut parameters);
49            }
50        }
51
52        for scope in &scopes {
53            if let Scope::Parameter(name) = &scope {
54                scope_parameters.insert(name.to_string(), None);
55            }
56        }
57
58        Rule {
59            head,
60            body,
61            expressions,
62            parameters: Some(parameters),
63            scopes,
64            scope_parameters: Some(scope_parameters),
65        }
66    }
67
68    pub fn validate_parameters(&self) -> Result<(), error::Token> {
69        let mut invalid_parameters = match &self.parameters {
70            None => vec![],
71            Some(parameters) => parameters
72                .iter()
73                .filter_map(
74                    |(name, opt_term)| {
75                        if opt_term.is_none() {
76                            Some(name)
77                        } else {
78                            None
79                        }
80                    },
81                )
82                .map(|name| name.to_string())
83                .collect::<Vec<_>>(),
84        };
85        let mut invalid_scope_parameters = match &self.scope_parameters {
86            None => vec![],
87            Some(parameters) => parameters
88                .iter()
89                .filter_map(
90                    |(name, opt_key)| {
91                        if opt_key.is_none() {
92                            Some(name)
93                        } else {
94                            None
95                        }
96                    },
97                )
98                .map(|name| name.to_string())
99                .collect::<Vec<_>>(),
100        };
101        let mut all_invalid_parameters = vec![];
102        all_invalid_parameters.append(&mut invalid_parameters);
103        all_invalid_parameters.append(&mut invalid_scope_parameters);
104
105        if all_invalid_parameters.is_empty() {
106            Ok(())
107        } else {
108            Err(error::Token::Language(
109                biscuit_parser::error::LanguageError::Parameters {
110                    missing_parameters: all_invalid_parameters,
111                    unused_parameters: vec![],
112                },
113            ))
114        }
115    }
116
117    pub fn validate_variables(&self) -> Result<(), String> {
118        let mut head_variables: std::collections::HashSet<String> = self
119            .head
120            .terms
121            .iter()
122            .filter_map(|term| match term {
123                Term::Variable(s) => Some(s.to_string()),
124                _ => None,
125            })
126            .collect();
127
128        for predicate in self.body.iter() {
129            for term in predicate.terms.iter() {
130                if let Term::Variable(v) = term {
131                    head_variables.remove(v);
132                    if head_variables.is_empty() {
133                        return Ok(());
134                    }
135                }
136            }
137        }
138
139        if head_variables.is_empty() {
140            Ok(())
141        } else {
142            Err(format!(
143                    "rule head contains variables that are not used in predicates of the rule's body: {}",
144                    head_variables
145                    .iter()
146                    .map(|s| format!("${}", s))
147                    .collect::<Vec<_>>()
148                    .join(", ")
149                    ))
150        }
151    }
152
153    /// replace a parameter with the term argument
154    pub fn set<T: Into<Term>>(&mut self, name: &str, term: T) -> Result<(), error::Token> {
155        if let Some(parameters) = self.parameters.as_mut() {
156            match parameters.get_mut(name) {
157                None => Err(error::Token::Language(
158                    biscuit_parser::error::LanguageError::Parameters {
159                        missing_parameters: vec![],
160                        unused_parameters: vec![name.to_string()],
161                    },
162                )),
163                Some(v) => {
164                    *v = Some(term.into());
165                    Ok(())
166                }
167            }
168        } else {
169            Err(error::Token::Language(
170                biscuit_parser::error::LanguageError::Parameters {
171                    missing_parameters: vec![],
172                    unused_parameters: vec![name.to_string()],
173                },
174            ))
175        }
176    }
177
178    /// replace a parameter with the term argument, without raising an error if the
179    /// parameter is not present in the rule
180    pub fn set_lenient<T: Into<Term>>(&mut self, name: &str, term: T) -> Result<(), error::Token> {
181        if let Some(parameters) = self.parameters.as_mut() {
182            match parameters.get_mut(name) {
183                None => Ok(()),
184                Some(v) => {
185                    *v = Some(term.into());
186                    Ok(())
187                }
188            }
189        } else {
190            Err(error::Token::Language(
191                biscuit_parser::error::LanguageError::Parameters {
192                    missing_parameters: vec![],
193                    unused_parameters: vec![name.to_string()],
194                },
195            ))
196        }
197    }
198
199    /// replace a scope parameter with the pubkey argument
200    pub fn set_scope(&mut self, name: &str, pubkey: PublicKey) -> Result<(), error::Token> {
201        if let Some(parameters) = self.scope_parameters.as_mut() {
202            match parameters.get_mut(name) {
203                None => Err(error::Token::Language(
204                    biscuit_parser::error::LanguageError::Parameters {
205                        missing_parameters: vec![],
206                        unused_parameters: vec![name.to_string()],
207                    },
208                )),
209                Some(v) => {
210                    *v = Some(pubkey);
211                    Ok(())
212                }
213            }
214        } else {
215            Err(error::Token::Language(
216                biscuit_parser::error::LanguageError::Parameters {
217                    missing_parameters: vec![],
218                    unused_parameters: vec![name.to_string()],
219                },
220            ))
221        }
222    }
223
224    /// replace a scope parameter with the public key argument, without raising an error if the
225    /// parameter is not present in the rule scope
226    pub fn set_scope_lenient(&mut self, name: &str, pubkey: PublicKey) -> Result<(), error::Token> {
227        if let Some(parameters) = self.scope_parameters.as_mut() {
228            match parameters.get_mut(name) {
229                None => Ok(()),
230                Some(v) => {
231                    *v = Some(pubkey);
232                    Ok(())
233                }
234            }
235        } else {
236            Err(error::Token::Language(
237                biscuit_parser::error::LanguageError::Parameters {
238                    missing_parameters: vec![],
239                    unused_parameters: vec![name.to_string()],
240                },
241            ))
242        }
243    }
244
245    #[cfg(feature = "datalog-macro")]
246    pub fn set_macro_param<T: ToAnyParam>(
247        &mut self,
248        name: &str,
249        param: T,
250    ) -> Result<(), error::Token> {
251        use super::AnyParam;
252
253        match param.to_any_param() {
254            AnyParam::Term(t) => self.set_lenient(name, t),
255            AnyParam::PublicKey(pubkey) => self.set_scope_lenient(name, pubkey),
256        }
257    }
258
259    pub(super) fn apply_parameters(&mut self) {
260        if let Some(parameters) = self.parameters.clone() {
261            self.head.terms = self
262                .head
263                .terms
264                .drain(..)
265                .map(|t| {
266                    if let Term::Parameter(name) = &t {
267                        if let Some(Some(term)) = parameters.get(name) {
268                            return term.clone();
269                        }
270                    }
271                    t
272                })
273                .collect();
274
275            for predicate in &mut self.body {
276                predicate.terms = predicate
277                    .terms
278                    .drain(..)
279                    .map(|t| {
280                        if let Term::Parameter(name) = &t {
281                            if let Some(Some(term)) = parameters.get(name) {
282                                return term.clone();
283                            }
284                        }
285                        t
286                    })
287                    .collect();
288            }
289
290            for expression in &mut self.expressions {
291                expression.ops = expression
292                    .ops
293                    .drain(..)
294                    .map(|op| op.apply_parameters(&parameters))
295                    .collect();
296            }
297        }
298
299        if let Some(parameters) = self.scope_parameters.clone() {
300            self.scopes = self
301                .scopes
302                .drain(..)
303                .map(|scope| {
304                    if let Scope::Parameter(name) = &scope {
305                        if let Some(Some(pubkey)) = parameters.get(name) {
306                            return Scope::PublicKey(*pubkey);
307                        }
308                    }
309                    scope
310                })
311                .collect();
312        }
313    }
314}
315
316impl Convert<datalog::Rule> for Rule {
317    fn convert(&self, symbols: &mut SymbolTable) -> datalog::Rule {
318        let mut r = self.clone();
319        r.apply_parameters();
320
321        let head = r.head.convert(symbols);
322        let mut body = vec![];
323        let mut expressions = vec![];
324        let mut scopes = vec![];
325
326        for p in r.body.iter() {
327            body.push(p.convert(symbols));
328        }
329
330        for c in r.expressions.iter() {
331            expressions.push(c.convert(symbols));
332        }
333
334        for scope in r.scopes.iter() {
335            scopes.push(match scope {
336                Scope::Authority => crate::token::Scope::Authority,
337                Scope::Previous => crate::token::Scope::Previous,
338                Scope::PublicKey(key) => {
339                    crate::token::Scope::PublicKey(symbols.public_keys.insert(key))
340                }
341                // The error is caught in the `add_xxx` functions, so this should
342                // not happen™
343                Scope::Parameter(s) => panic!("Remaining parameter {}", &s),
344            })
345        }
346        datalog::Rule {
347            head,
348            body,
349            expressions,
350            scopes,
351        }
352    }
353
354    fn convert_from(r: &datalog::Rule, symbols: &SymbolTable) -> Result<Self, error::Format> {
355        Ok(Rule {
356            head: Predicate::convert_from(&r.head, symbols)?,
357            body: r
358                .body
359                .iter()
360                .map(|p| Predicate::convert_from(p, symbols))
361                .collect::<Result<Vec<Predicate>, error::Format>>()?,
362            expressions: r
363                .expressions
364                .iter()
365                .map(|c| Expression::convert_from(c, symbols))
366                .collect::<Result<Vec<_>, error::Format>>()?,
367            parameters: None,
368            scopes: r
369                .scopes
370                .iter()
371                .map(|scope| Scope::convert_from(scope, symbols))
372                .collect::<Result<Vec<Scope>, error::Format>>()?,
373            scope_parameters: None,
374        })
375    }
376}
377
378pub(super) fn display_rule_body(r: &Rule, f: &mut fmt::Formatter<'_>) -> fmt::Result {
379    let mut rule = r.clone();
380    rule.apply_parameters();
381    if !rule.body.is_empty() {
382        write!(f, "{}", rule.body[0])?;
383
384        if rule.body.len() > 1 {
385            for i in 1..rule.body.len() {
386                write!(f, ", {}", rule.body[i])?;
387            }
388        }
389    }
390
391    if !rule.expressions.is_empty() {
392        if !rule.body.is_empty() {
393            write!(f, ", ")?;
394        }
395
396        write!(f, "{}", rule.expressions[0])?;
397
398        if rule.expressions.len() > 1 {
399            for i in 1..rule.expressions.len() {
400                write!(f, ", {}", rule.expressions[i])?;
401            }
402        }
403    }
404
405    if !rule.scopes.is_empty() {
406        write!(f, " trusting {}", rule.scopes[0])?;
407        if rule.scopes.len() > 1 {
408            for i in 1..rule.scopes.len() {
409                write!(f, ", {}", rule.scopes[i])?;
410            }
411        }
412    }
413
414    Ok(())
415}
416
417impl fmt::Display for Rule {
418    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
419        let mut r = self.clone();
420        r.apply_parameters();
421
422        write!(f, "{} <- ", r.head)?;
423
424        display_rule_body(&r, f)
425    }
426}
427
428impl From<biscuit_parser::builder::Rule> for Rule {
429    fn from(r: biscuit_parser::builder::Rule) -> Self {
430        Rule {
431            head: r.head.into(),
432            body: r.body.into_iter().map(|p| p.into()).collect(),
433            expressions: r.expressions.into_iter().map(|e| e.into()).collect(),
434            parameters: r.parameters.map(|h| {
435                h.into_iter()
436                    .map(|(k, v)| (k, v.map(|term| term.into())))
437                    .collect()
438            }),
439            scopes: r.scopes.into_iter().map(|s| s.into()).collect(),
440            scope_parameters: r.scope_parameters.map(|h| {
441                h.into_iter()
442                    .map(|(k, v)| {
443                        (
444                            k,
445                            v.map(|pk| {
446                                PublicKey::from_bytes(&pk.key, pk.algorithm.into())
447                                    .expect("invalid public key")
448                            }),
449                        )
450                    })
451                    .collect()
452            }),
453        }
454    }
455}
456
457impl TryFrom<&str> for Rule {
458    type Error = error::Token;
459
460    fn try_from(value: &str) -> Result<Self, Self::Error> {
461        Ok(biscuit_parser::parser::rule(value)
462            .finish()
463            .map(|(_, o)| o.into())
464            .map_err(biscuit_parser::error::LanguageError::from)?)
465    }
466}
467
468impl FromStr for Rule {
469    type Err = error::Token;
470
471    fn from_str(s: &str) -> Result<Self, Self::Err> {
472        Ok(biscuit_parser::parser::rule(s)
473            .finish()
474            .map(|(_, o)| o.into())
475            .map_err(biscuit_parser::error::LanguageError::from)?)
476    }
477}