biscuit_auth/token/builder/
rule.rs1use std::{collections::HashMap, convert::TryFrom, fmt, str::FromStr};
6
7use nom::Finish;
8
9use crate::{
10 datalog::{self, SymbolTable},
11 error, PublicKey,
12};
13
14use super::{Convert, Expression, Predicate, Scope, Term, ToAnyParam};
15
16#[derive(Debug, Clone, PartialEq, Eq)]
18pub struct Rule {
19 pub head: Predicate,
20 pub body: Vec<Predicate>,
21 pub expressions: Vec<Expression>,
22 pub parameters: Option<HashMap<String, Option<Term>>>,
23 pub scopes: Vec<Scope>,
24 pub scope_parameters: Option<HashMap<String, Option<PublicKey>>>,
25}
26
27impl Rule {
28 pub fn new(
29 head: Predicate,
30 body: Vec<Predicate>,
31 expressions: Vec<Expression>,
32 scopes: Vec<Scope>,
33 ) -> Rule {
34 let mut parameters = HashMap::new();
35 let mut scope_parameters = HashMap::new();
36 for term in &head.terms {
37 term.extract_parameters(&mut parameters);
38 }
39
40 for predicate in &body {
41 for term in &predicate.terms {
42 term.extract_parameters(&mut parameters);
43 }
44 }
45
46 for expression in &expressions {
47 for op in &expression.ops {
48 op.collect_parameters(&mut parameters);
49 }
50 }
51
52 for scope in &scopes {
53 if let Scope::Parameter(name) = &scope {
54 scope_parameters.insert(name.to_string(), None);
55 }
56 }
57
58 Rule {
59 head,
60 body,
61 expressions,
62 parameters: Some(parameters),
63 scopes,
64 scope_parameters: Some(scope_parameters),
65 }
66 }
67
68 pub fn validate_parameters(&self) -> Result<(), error::Token> {
69 let mut invalid_parameters = match &self.parameters {
70 None => vec![],
71 Some(parameters) => parameters
72 .iter()
73 .filter_map(
74 |(name, opt_term)| {
75 if opt_term.is_none() {
76 Some(name)
77 } else {
78 None
79 }
80 },
81 )
82 .map(|name| name.to_string())
83 .collect::<Vec<_>>(),
84 };
85 let mut invalid_scope_parameters = match &self.scope_parameters {
86 None => vec![],
87 Some(parameters) => parameters
88 .iter()
89 .filter_map(
90 |(name, opt_key)| {
91 if opt_key.is_none() {
92 Some(name)
93 } else {
94 None
95 }
96 },
97 )
98 .map(|name| name.to_string())
99 .collect::<Vec<_>>(),
100 };
101 let mut all_invalid_parameters = vec![];
102 all_invalid_parameters.append(&mut invalid_parameters);
103 all_invalid_parameters.append(&mut invalid_scope_parameters);
104
105 if all_invalid_parameters.is_empty() {
106 Ok(())
107 } else {
108 Err(error::Token::Language(
109 biscuit_parser::error::LanguageError::Parameters {
110 missing_parameters: all_invalid_parameters,
111 unused_parameters: vec![],
112 },
113 ))
114 }
115 }
116
117 pub fn validate_variables(&self) -> Result<(), String> {
118 let mut head_variables: std::collections::HashSet<String> = self
119 .head
120 .terms
121 .iter()
122 .filter_map(|term| match term {
123 Term::Variable(s) => Some(s.to_string()),
124 _ => None,
125 })
126 .collect();
127
128 for predicate in self.body.iter() {
129 for term in predicate.terms.iter() {
130 if let Term::Variable(v) = term {
131 head_variables.remove(v);
132 if head_variables.is_empty() {
133 return Ok(());
134 }
135 }
136 }
137 }
138
139 if head_variables.is_empty() {
140 Ok(())
141 } else {
142 Err(format!(
143 "rule head contains variables that are not used in predicates of the rule's body: {}",
144 head_variables
145 .iter()
146 .map(|s| format!("${}", s))
147 .collect::<Vec<_>>()
148 .join(", ")
149 ))
150 }
151 }
152
153 pub fn set<T: Into<Term>>(&mut self, name: &str, term: T) -> Result<(), error::Token> {
155 if let Some(parameters) = self.parameters.as_mut() {
156 match parameters.get_mut(name) {
157 None => Err(error::Token::Language(
158 biscuit_parser::error::LanguageError::Parameters {
159 missing_parameters: vec![],
160 unused_parameters: vec![name.to_string()],
161 },
162 )),
163 Some(v) => {
164 *v = Some(term.into());
165 Ok(())
166 }
167 }
168 } else {
169 Err(error::Token::Language(
170 biscuit_parser::error::LanguageError::Parameters {
171 missing_parameters: vec![],
172 unused_parameters: vec![name.to_string()],
173 },
174 ))
175 }
176 }
177
178 pub fn set_lenient<T: Into<Term>>(&mut self, name: &str, term: T) -> Result<(), error::Token> {
181 if let Some(parameters) = self.parameters.as_mut() {
182 match parameters.get_mut(name) {
183 None => Ok(()),
184 Some(v) => {
185 *v = Some(term.into());
186 Ok(())
187 }
188 }
189 } else {
190 Err(error::Token::Language(
191 biscuit_parser::error::LanguageError::Parameters {
192 missing_parameters: vec![],
193 unused_parameters: vec![name.to_string()],
194 },
195 ))
196 }
197 }
198
199 pub fn set_scope(&mut self, name: &str, pubkey: PublicKey) -> Result<(), error::Token> {
201 if let Some(parameters) = self.scope_parameters.as_mut() {
202 match parameters.get_mut(name) {
203 None => Err(error::Token::Language(
204 biscuit_parser::error::LanguageError::Parameters {
205 missing_parameters: vec![],
206 unused_parameters: vec![name.to_string()],
207 },
208 )),
209 Some(v) => {
210 *v = Some(pubkey);
211 Ok(())
212 }
213 }
214 } else {
215 Err(error::Token::Language(
216 biscuit_parser::error::LanguageError::Parameters {
217 missing_parameters: vec![],
218 unused_parameters: vec![name.to_string()],
219 },
220 ))
221 }
222 }
223
224 pub fn set_scope_lenient(&mut self, name: &str, pubkey: PublicKey) -> Result<(), error::Token> {
227 if let Some(parameters) = self.scope_parameters.as_mut() {
228 match parameters.get_mut(name) {
229 None => Ok(()),
230 Some(v) => {
231 *v = Some(pubkey);
232 Ok(())
233 }
234 }
235 } else {
236 Err(error::Token::Language(
237 biscuit_parser::error::LanguageError::Parameters {
238 missing_parameters: vec![],
239 unused_parameters: vec![name.to_string()],
240 },
241 ))
242 }
243 }
244
245 #[cfg(feature = "datalog-macro")]
246 pub fn set_macro_param<T: ToAnyParam>(
247 &mut self,
248 name: &str,
249 param: T,
250 ) -> Result<(), error::Token> {
251 use super::AnyParam;
252
253 match param.to_any_param() {
254 AnyParam::Term(t) => self.set_lenient(name, t),
255 AnyParam::PublicKey(pubkey) => self.set_scope_lenient(name, pubkey),
256 }
257 }
258
259 pub(super) fn apply_parameters(&mut self) {
260 if let Some(parameters) = self.parameters.clone() {
261 self.head.terms = self
262 .head
263 .terms
264 .drain(..)
265 .map(|t| {
266 if let Term::Parameter(name) = &t {
267 if let Some(Some(term)) = parameters.get(name) {
268 return term.clone();
269 }
270 }
271 t
272 })
273 .collect();
274
275 for predicate in &mut self.body {
276 predicate.terms = predicate
277 .terms
278 .drain(..)
279 .map(|t| {
280 if let Term::Parameter(name) = &t {
281 if let Some(Some(term)) = parameters.get(name) {
282 return term.clone();
283 }
284 }
285 t
286 })
287 .collect();
288 }
289
290 for expression in &mut self.expressions {
291 expression.ops = expression
292 .ops
293 .drain(..)
294 .map(|op| op.apply_parameters(¶meters))
295 .collect();
296 }
297 }
298
299 if let Some(parameters) = self.scope_parameters.clone() {
300 self.scopes = self
301 .scopes
302 .drain(..)
303 .map(|scope| {
304 if let Scope::Parameter(name) = &scope {
305 if let Some(Some(pubkey)) = parameters.get(name) {
306 return Scope::PublicKey(*pubkey);
307 }
308 }
309 scope
310 })
311 .collect();
312 }
313 }
314}
315
316impl Convert<datalog::Rule> for Rule {
317 fn convert(&self, symbols: &mut SymbolTable) -> datalog::Rule {
318 let mut r = self.clone();
319 r.apply_parameters();
320
321 let head = r.head.convert(symbols);
322 let mut body = vec![];
323 let mut expressions = vec![];
324 let mut scopes = vec![];
325
326 for p in r.body.iter() {
327 body.push(p.convert(symbols));
328 }
329
330 for c in r.expressions.iter() {
331 expressions.push(c.convert(symbols));
332 }
333
334 for scope in r.scopes.iter() {
335 scopes.push(match scope {
336 Scope::Authority => crate::token::Scope::Authority,
337 Scope::Previous => crate::token::Scope::Previous,
338 Scope::PublicKey(key) => {
339 crate::token::Scope::PublicKey(symbols.public_keys.insert(key))
340 }
341 Scope::Parameter(s) => panic!("Remaining parameter {}", &s),
344 })
345 }
346 datalog::Rule {
347 head,
348 body,
349 expressions,
350 scopes,
351 }
352 }
353
354 fn convert_from(r: &datalog::Rule, symbols: &SymbolTable) -> Result<Self, error::Format> {
355 Ok(Rule {
356 head: Predicate::convert_from(&r.head, symbols)?,
357 body: r
358 .body
359 .iter()
360 .map(|p| Predicate::convert_from(p, symbols))
361 .collect::<Result<Vec<Predicate>, error::Format>>()?,
362 expressions: r
363 .expressions
364 .iter()
365 .map(|c| Expression::convert_from(c, symbols))
366 .collect::<Result<Vec<_>, error::Format>>()?,
367 parameters: None,
368 scopes: r
369 .scopes
370 .iter()
371 .map(|scope| Scope::convert_from(scope, symbols))
372 .collect::<Result<Vec<Scope>, error::Format>>()?,
373 scope_parameters: None,
374 })
375 }
376}
377
378pub(super) fn display_rule_body(r: &Rule, f: &mut fmt::Formatter<'_>) -> fmt::Result {
379 let mut rule = r.clone();
380 rule.apply_parameters();
381 if !rule.body.is_empty() {
382 write!(f, "{}", rule.body[0])?;
383
384 if rule.body.len() > 1 {
385 for i in 1..rule.body.len() {
386 write!(f, ", {}", rule.body[i])?;
387 }
388 }
389 }
390
391 if !rule.expressions.is_empty() {
392 if !rule.body.is_empty() {
393 write!(f, ", ")?;
394 }
395
396 write!(f, "{}", rule.expressions[0])?;
397
398 if rule.expressions.len() > 1 {
399 for i in 1..rule.expressions.len() {
400 write!(f, ", {}", rule.expressions[i])?;
401 }
402 }
403 }
404
405 if !rule.scopes.is_empty() {
406 write!(f, " trusting {}", rule.scopes[0])?;
407 if rule.scopes.len() > 1 {
408 for i in 1..rule.scopes.len() {
409 write!(f, ", {}", rule.scopes[i])?;
410 }
411 }
412 }
413
414 Ok(())
415}
416
417impl fmt::Display for Rule {
418 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
419 let mut r = self.clone();
420 r.apply_parameters();
421
422 write!(f, "{} <- ", r.head)?;
423
424 display_rule_body(&r, f)
425 }
426}
427
428impl From<biscuit_parser::builder::Rule> for Rule {
429 fn from(r: biscuit_parser::builder::Rule) -> Self {
430 Rule {
431 head: r.head.into(),
432 body: r.body.into_iter().map(|p| p.into()).collect(),
433 expressions: r.expressions.into_iter().map(|e| e.into()).collect(),
434 parameters: r.parameters.map(|h| {
435 h.into_iter()
436 .map(|(k, v)| (k, v.map(|term| term.into())))
437 .collect()
438 }),
439 scopes: r.scopes.into_iter().map(|s| s.into()).collect(),
440 scope_parameters: r.scope_parameters.map(|h| {
441 h.into_iter()
442 .map(|(k, v)| {
443 (
444 k,
445 v.map(|pk| {
446 PublicKey::from_bytes(&pk.key, pk.algorithm.into())
447 .expect("invalid public key")
448 }),
449 )
450 })
451 .collect()
452 }),
453 }
454 }
455}
456
457impl TryFrom<&str> for Rule {
458 type Error = error::Token;
459
460 fn try_from(value: &str) -> Result<Self, Self::Error> {
461 Ok(biscuit_parser::parser::rule(value)
462 .finish()
463 .map(|(_, o)| o.into())
464 .map_err(biscuit_parser::error::LanguageError::from)?)
465 }
466}
467
468impl FromStr for Rule {
469 type Err = error::Token;
470
471 fn from_str(s: &str) -> Result<Self, Self::Err> {
472 Ok(biscuit_parser::parser::rule(s)
473 .finish()
474 .map(|(_, o)| o.into())
475 .map_err(biscuit_parser::error::LanguageError::from)?)
476 }
477}