biscuit_auth/token/builder/
check.rs1use std::{convert::TryFrom, fmt, str::FromStr};
6
7use nom::Finish;
8
9use crate::{
10 datalog::{self, SymbolTable},
11 error, PublicKey,
12};
13
14use super::{display_rule_body, Convert, Rule, Term, ToAnyParam};
15
16#[derive(Debug, Clone, PartialEq, Eq)]
18pub struct Check {
19 pub queries: Vec<Rule>,
20 pub kind: CheckKind,
21}
22
/// The form a [`Check`] takes; it selects the keyword the check is printed with.
#[derive(Debug, Clone)]
#[derive(PartialEq, Eq)]
pub enum CheckKind {
    /// Displayed as `check if …`.
    One,
    /// Displayed as `check all …`.
    All,
    /// Displayed as `reject if …`.
    Reject,
}
30
31impl Check {
32 pub fn set<T: Into<Term>>(&mut self, name: &str, term: T) -> Result<(), error::Token> {
34 let term = term.into();
35 self.set_inner(name, term)
36 }
37
38 fn set_inner(&mut self, name: &str, term: Term) -> Result<(), error::Token> {
39 let mut found = false;
40 for query in &mut self.queries {
41 if query.set(name, term.clone()).is_ok() {
42 found = true;
43 }
44 }
45
46 if found {
47 Ok(())
48 } else {
49 Err(error::Token::Language(
50 biscuit_parser::error::LanguageError::Parameters {
51 missing_parameters: vec![],
52 unused_parameters: vec![name.to_string()],
53 },
54 ))
55 }
56 }
57
58 pub fn set_scope(&mut self, name: &str, pubkey: PublicKey) -> Result<(), error::Token> {
60 let mut found = false;
61 for query in &mut self.queries {
62 if query.set_scope(name, pubkey).is_ok() {
63 found = true;
64 }
65 }
66
67 if found {
68 Ok(())
69 } else {
70 Err(error::Token::Language(
71 biscuit_parser::error::LanguageError::Parameters {
72 missing_parameters: vec![],
73 unused_parameters: vec![name.to_string()],
74 },
75 ))
76 }
77 }
78
79 pub fn set_lenient<T: Into<Term>>(&mut self, name: &str, term: T) -> Result<(), error::Token> {
82 let term = term.into();
83 for query in &mut self.queries {
84 query.set_lenient(name, term.clone())?;
85 }
86 Ok(())
87 }
88
89 pub fn set_scope_lenient(&mut self, name: &str, pubkey: PublicKey) -> Result<(), error::Token> {
92 for query in &mut self.queries {
93 query.set_scope_lenient(name, pubkey)?;
94 }
95 Ok(())
96 }
97
98 #[cfg(feature = "datalog-macro")]
99 pub fn set_macro_param<T: ToAnyParam>(
100 &mut self,
101 name: &str,
102 param: T,
103 ) -> Result<(), error::Token> {
104 use super::AnyParam;
105
106 match param.to_any_param() {
107 AnyParam::Term(t) => self.set_lenient(name, t),
108 AnyParam::PublicKey(p) => self.set_scope_lenient(name, p),
109 }
110 }
111
112 pub fn validate_parameters(&self) -> Result<(), error::Token> {
113 for rule in &self.queries {
114 rule.validate_parameters()?;
115 }
116
117 Ok(())
118 }
119
120 pub(super) fn apply_parameters(&mut self) {
121 for rule in self.queries.iter_mut() {
122 rule.apply_parameters();
123 }
124 }
125}
126
127impl Convert<datalog::Check> for Check {
128 fn convert(&self, symbols: &mut SymbolTable) -> datalog::Check {
129 let mut queries = vec![];
130 for q in self.queries.iter() {
131 queries.push(q.convert(symbols));
132 }
133
134 datalog::Check {
135 queries,
136 kind: self.kind.clone(),
137 }
138 }
139
140 fn convert_from(r: &datalog::Check, symbols: &SymbolTable) -> Result<Self, error::Format> {
141 let mut queries = vec![];
142 for q in r.queries.iter() {
143 queries.push(Rule::convert_from(q, symbols)?);
144 }
145
146 Ok(Check {
147 queries,
148 kind: r.kind.clone(),
149 })
150 }
151}
152
153impl TryFrom<Rule> for Check {
154 type Error = error::Token;
155
156 fn try_from(value: Rule) -> Result<Self, Self::Error> {
157 Ok(Check {
158 queries: vec![value],
159 kind: CheckKind::One,
160 })
161 }
162}
163
164impl TryFrom<&[Rule]> for Check {
165 type Error = error::Token;
166
167 fn try_from(values: &[Rule]) -> Result<Self, Self::Error> {
168 Ok(Check {
169 queries: values.to_vec(),
170 kind: CheckKind::One,
171 })
172 }
173}
174
175impl fmt::Display for Check {
176 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
177 match self.kind {
178 CheckKind::One => write!(f, "check if ")?,
179 CheckKind::All => write!(f, "check all ")?,
180 CheckKind::Reject => write!(f, "reject if ")?,
181 };
182
183 if !self.queries.is_empty() {
184 let mut q0 = self.queries[0].clone();
185 q0.apply_parameters();
186 display_rule_body(&q0, f)?;
187
188 if self.queries.len() > 1 {
189 for i in 1..self.queries.len() {
190 write!(f, " or ")?;
191 let mut qn = self.queries[i].clone();
192 qn.apply_parameters();
193 display_rule_body(&qn, f)?;
194 }
195 }
196 }
197
198 Ok(())
199 }
200}
201
202impl From<biscuit_parser::builder::Check> for Check {
203 fn from(c: biscuit_parser::builder::Check) -> Self {
204 Check {
205 queries: c.queries.into_iter().map(|q| q.into()).collect(),
206 kind: match c.kind {
207 biscuit_parser::builder::CheckKind::One => CheckKind::One,
208 biscuit_parser::builder::CheckKind::All => CheckKind::All,
209 biscuit_parser::builder::CheckKind::Reject => CheckKind::Reject,
210 },
211 }
212 }
213}
214
215impl TryFrom<&str> for Check {
216 type Error = error::Token;
217
218 fn try_from(value: &str) -> Result<Self, Self::Error> {
219 Ok(biscuit_parser::parser::check(value)
220 .finish()
221 .map(|(_, o)| o.into())
222 .map_err(biscuit_parser::error::LanguageError::from)?)
223 }
224}
225
226impl FromStr for Check {
227 type Err = error::Token;
228
229 fn from_str(s: &str) -> Result<Self, Self::Err> {
230 Ok(biscuit_parser::parser::check(s)
231 .finish()
232 .map(|(_, o)| o.into())
233 .map_err(biscuit_parser::error::LanguageError::from)?)
234 }
235}