biscuit_auth/token/builder/
policy.rs1use std::{convert::TryFrom, fmt, str::FromStr};
6
7use nom::Finish;
8
9use crate::{error, PublicKey};
10
11use super::{display_rule_body, Rule, Term, ToAnyParam};
12
/// Outcome selected by a [`Policy`] whose queries match.
///
/// The enum carries no data, so it is `Copy` and `Hash` in addition to the
/// comparison traits: it can be passed by value and used as a map or set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PolicyKind {
    /// Rendered as `allow` by the policy's `Display` implementation.
    Allow,
    /// Rendered as `deny` by the policy's `Display` implementation.
    Deny,
}
18
19#[derive(Debug, Clone, PartialEq, Eq)]
21pub struct Policy {
22 pub queries: Vec<Rule>,
23 pub kind: PolicyKind,
24}
25
26impl Policy {
27 pub fn set<T: Into<Term>>(&mut self, name: &str, term: T) -> Result<(), error::Token> {
29 let term = term.into();
30 self.set_inner(name, term)
31 }
32
33 pub fn set_inner(&mut self, name: &str, term: Term) -> Result<(), error::Token> {
34 let mut found = false;
35 for query in &mut self.queries {
36 if query.set(name, term.clone()).is_ok() {
37 found = true;
38 }
39 }
40
41 if found {
42 Ok(())
43 } else {
44 Err(error::Token::Language(
45 biscuit_parser::error::LanguageError::Parameters {
46 missing_parameters: vec![],
47 unused_parameters: vec![name.to_string()],
48 },
49 ))
50 }
51 }
52
53 pub fn set_scope(&mut self, name: &str, pubkey: PublicKey) -> Result<(), error::Token> {
55 let mut found = false;
56 for query in &mut self.queries {
57 if query.set_scope(name, pubkey).is_ok() {
58 found = true;
59 }
60 }
61
62 if found {
63 Ok(())
64 } else {
65 Err(error::Token::Language(
66 biscuit_parser::error::LanguageError::Parameters {
67 missing_parameters: vec![],
68 unused_parameters: vec![name.to_string()],
69 },
70 ))
71 }
72 }
73
74 pub fn set_lenient<T: Into<Term>>(&mut self, name: &str, term: T) -> Result<(), error::Token> {
76 let term = term.into();
77 for query in &mut self.queries {
78 query.set_lenient(name, term.clone())?;
79 }
80 Ok(())
81 }
82
83 pub fn set_scope_lenient(&mut self, name: &str, pubkey: PublicKey) -> Result<(), error::Token> {
85 for query in &mut self.queries {
86 query.set_scope_lenient(name, pubkey)?;
87 }
88 Ok(())
89 }
90
91 #[cfg(feature = "datalog-macro")]
92 pub fn set_macro_param<T: ToAnyParam>(
93 &mut self,
94 name: &str,
95 param: T,
96 ) -> Result<(), error::Token> {
97 use super::AnyParam;
98
99 match param.to_any_param() {
100 AnyParam::Term(t) => self.set_lenient(name, t),
101 AnyParam::PublicKey(p) => self.set_scope_lenient(name, p),
102 }
103 }
104
105 pub fn validate_parameters(&self) -> Result<(), error::Token> {
106 for query in &self.queries {
107 query.validate_parameters()?;
108 }
109
110 Ok(())
111 }
112
113 pub fn apply_parameters(&mut self) {
114 for rule in self.queries.iter_mut() {
115 rule.apply_parameters();
116 }
117 }
118}
119
120impl fmt::Display for Policy {
121 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
122 if !self.queries.is_empty() {
123 match self.kind {
124 PolicyKind::Allow => write!(f, "allow if ")?,
125 PolicyKind::Deny => write!(f, "deny if ")?,
126 }
127
128 if !self.queries.is_empty() {
129 display_rule_body(&self.queries[0], f)?;
130
131 if self.queries.len() > 1 {
132 for i in 1..self.queries.len() {
133 write!(f, " or ")?;
134 display_rule_body(&self.queries[i], f)?;
135 }
136 }
137 }
138 } else {
139 match self.kind {
140 PolicyKind::Allow => write!(f, "allow")?,
141 PolicyKind::Deny => write!(f, "deny")?,
142 }
143 }
144
145 Ok(())
146 }
147}
148
149impl From<biscuit_parser::builder::Policy> for Policy {
150 fn from(p: biscuit_parser::builder::Policy) -> Self {
151 Policy {
152 queries: p.queries.into_iter().map(|q| q.into()).collect(),
153 kind: match p.kind {
154 biscuit_parser::builder::PolicyKind::Allow => PolicyKind::Allow,
155 biscuit_parser::builder::PolicyKind::Deny => PolicyKind::Deny,
156 },
157 }
158 }
159}
160
161impl TryFrom<&str> for Policy {
162 type Error = error::Token;
163
164 fn try_from(value: &str) -> Result<Self, Self::Error> {
165 Ok(biscuit_parser::parser::policy(value)
166 .finish()
167 .map(|(_, o)| o.into())
168 .map_err(biscuit_parser::error::LanguageError::from)?)
169 }
170}
171
172impl FromStr for Policy {
173 type Err = error::Token;
174
175 fn from_str(s: &str) -> Result<Self, Self::Err> {
176 Ok(biscuit_parser::parser::policy(s)
177 .finish()
178 .map(|(_, o)| o.into())
179 .map_err(biscuit_parser::error::LanguageError::from)?)
180 }
181}