1use std::{
6 collections::{BTreeMap, BTreeSet, HashMap},
7 convert::{TryFrom, TryInto},
8 fmt,
9 time::{Duration, SystemTime, UNIX_EPOCH},
10};
11
12use crate::{
13 datalog::{self, SymbolTable, TemporarySymbolTable},
14 error,
15};
16
17#[cfg(feature = "datalog-macro")]
18use super::AnyParam;
19use super::{set, Convert, Fact, ToAnyParam};
20
/// A datalog term as manipulated by the builder API, before strings and variable names
/// are interned into a symbol table.
///
/// NOTE: `PartialOrd`/`Ord` are derived, so the declaration order of the variants decides
/// how terms of different kinds compare (and therefore how they are ordered inside a
/// `Set` or as `Map` values); do not reorder the variants.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Term {
    /// A variable, identified by name and displayed as `$name`.
    Variable(String),
    /// A signed 64-bit integer.
    Integer(i64),
    /// A string; interned as a symbol id when lowered to datalog.
    Str(String),
    /// A date, as seconds since the Unix epoch.
    Date(u64),
    /// A byte array, displayed as `hex:...`.
    Bytes(Vec<u8>),
    /// A boolean.
    Bool(bool),
    /// A set of terms.
    Set(BTreeSet<Term>),
    /// A named placeholder, displayed as `{name}`; it must be substituted with a concrete
    /// value (see `apply_parameters`) before the term is lowered to datalog.
    Parameter(String),
    /// The null value.
    Null,
    /// An ordered list of terms.
    Array(Vec<Term>),
    /// A map from integer / string keys (or still-unbound parameter keys) to terms.
    Map(BTreeMap<MapKey, Term>),
}
36
37impl Term {
38 pub(super) fn extract_parameters(&self, parameters: &mut HashMap<String, Option<Term>>) {
39 match self {
40 Term::Parameter(name) => {
41 parameters.insert(name.to_string(), None);
42 }
43 Term::Set(s) => {
44 for term in s {
45 term.extract_parameters(parameters);
46 }
47 }
48 Term::Array(a) => {
49 for term in a {
50 term.extract_parameters(parameters);
51 }
52 }
53 Term::Map(m) => {
54 for (key, term) in m {
55 if let MapKey::Parameter(name) = key {
56 parameters.insert(name.to_string(), None);
57 }
58 term.extract_parameters(parameters);
59 }
60 }
61 _ => {}
62 }
63 }
64
65 pub(super) fn apply_parameters(self, parameters: &HashMap<String, Option<Term>>) -> Term {
66 match self {
67 Term::Parameter(name) => {
68 if let Some(Some(term)) = parameters.get(&name) {
69 term.clone()
70 } else {
71 Term::Parameter(name)
72 }
73 }
74 Term::Map(m) => Term::Map(
75 m.into_iter()
76 .map(|(key, term)| {
77 (
78 match key {
79 MapKey::Parameter(name) => {
80 if let Some(Some(key_term)) = parameters.get(&name) {
81 match key_term {
82 Term::Integer(i) => MapKey::Integer(*i),
83 Term::Str(s) => MapKey::Str(s.clone()),
84 _ => MapKey::Parameter(name),
86 }
87 } else {
88 MapKey::Parameter(name)
89 }
90 }
91 _ => key,
92 },
93 term.apply_parameters(parameters),
94 )
95 })
96 .collect(),
97 ),
98 Term::Array(array) => Term::Array(
99 array
100 .into_iter()
101 .map(|term| term.apply_parameters(parameters))
102 .collect(),
103 ),
104 Term::Set(set) => Term::Set(
105 set.into_iter()
106 .map(|term| term.apply_parameters(parameters))
107 .collect(),
108 ),
109 _ => self,
110 }
111 }
112}
113
/// Key of a [`Term::Map`].
///
/// Only `Integer` and `Str` keys can be lowered to datalog: converting a map that still
/// holds a `Parameter` key panics, so parameter keys must first be bound to an integer or
/// string value.
///
/// NOTE: `Ord` is derived, so variant order determines the iteration order of map keys of
/// different kinds.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum MapKey {
    /// An integer key.
    Integer(i64),
    /// A string key; interned as a symbol id when lowered to datalog.
    Str(String),
    /// A named placeholder key, displayed as `{name}`.
    Parameter(String),
}
120
121impl Term {
122 pub fn to_datalog(self, symbols: &mut TemporarySymbolTable) -> datalog::Term {
123 match self {
124 Term::Variable(s) => datalog::Term::Variable(symbols.insert(&s) as u32),
125 Term::Integer(i) => datalog::Term::Integer(i),
126 Term::Str(s) => datalog::Term::Str(symbols.insert(&s)),
127 Term::Date(d) => datalog::Term::Date(d),
128 Term::Bytes(s) => datalog::Term::Bytes(s),
129 Term::Bool(b) => datalog::Term::Bool(b),
130 Term::Set(s) => {
131 datalog::Term::Set(s.into_iter().map(|i| i.to_datalog(symbols)).collect())
132 }
133 Term::Null => datalog::Term::Null,
134 Term::Array(a) => {
135 datalog::Term::Array(a.into_iter().map(|i| i.to_datalog(symbols)).collect())
136 }
137 Term::Map(m) => datalog::Term::Map(
138 m.into_iter()
139 .map(|(k, i)| {
140 (
141 match k {
142 MapKey::Integer(i) => datalog::MapKey::Integer(i),
143 MapKey::Str(s) => datalog::MapKey::Str(symbols.insert(&s)),
144 MapKey::Parameter(s) => panic!("Remaining parameter {}", &s),
147 },
148 i.to_datalog(symbols),
149 )
150 })
151 .collect(),
152 ),
153 Term::Parameter(s) => panic!("Remaining parameter {}", &s),
156 }
157 }
158
159 pub fn from_datalog(
160 term: datalog::Term,
161 symbols: &TemporarySymbolTable,
162 ) -> Result<Self, error::Expression> {
163 Ok(match term {
164 datalog::Term::Variable(s) => Term::Variable(
165 symbols
166 .get_symbol(s as u64)
167 .ok_or(error::Expression::UnknownVariable(s))?
168 .to_string(),
169 ),
170 datalog::Term::Integer(i) => Term::Integer(i),
171 datalog::Term::Str(s) => Term::Str(
172 symbols
173 .get_symbol(s)
174 .ok_or(error::Expression::UnknownSymbol(s))?
175 .to_string(),
176 ),
177 datalog::Term::Date(d) => Term::Date(d),
178 datalog::Term::Bytes(s) => Term::Bytes(s),
179 datalog::Term::Bool(b) => Term::Bool(b),
180 datalog::Term::Set(s) => Term::Set(
181 s.into_iter()
182 .map(|i| Self::from_datalog(i, symbols))
183 .collect::<Result<_, _>>()?,
184 ),
185 datalog::Term::Null => Term::Null,
186 datalog::Term::Array(a) => Term::Array(
187 a.into_iter()
188 .map(|i| Self::from_datalog(i, symbols))
189 .collect::<Result<_, _>>()?,
190 ),
191 datalog::Term::Map(m) => Term::Map(
192 m.into_iter()
193 .map(|(k, i)| {
194 Ok((
195 match k {
196 datalog::MapKey::Integer(i) => MapKey::Integer(i),
197 datalog::MapKey::Str(s) => MapKey::Str(
198 symbols
199 .get_symbol(s)
200 .ok_or(error::Expression::UnknownSymbol(s))?
201 .to_string(),
202 ),
203 },
204 Self::from_datalog(i, symbols)?,
205 ))
206 })
207 .collect::<Result<_, _>>()?,
208 ),
209 })
210 }
211}
212
213impl Convert<datalog::Term> for Term {
214 fn convert(&self, symbols: &mut SymbolTable) -> datalog::Term {
215 match self {
216 Term::Variable(s) => datalog::Term::Variable(symbols.insert(s) as u32),
217 Term::Integer(i) => datalog::Term::Integer(*i),
218 Term::Str(s) => datalog::Term::Str(symbols.insert(s)),
219 Term::Date(d) => datalog::Term::Date(*d),
220 Term::Bytes(s) => datalog::Term::Bytes(s.clone()),
221 Term::Bool(b) => datalog::Term::Bool(*b),
222 Term::Set(s) => datalog::Term::Set(s.iter().map(|i| i.convert(symbols)).collect()),
223 Term::Null => datalog::Term::Null,
224 Term::Parameter(s) => panic!("Remaining parameter {}", &s),
227 Term::Array(a) => datalog::Term::Array(a.iter().map(|i| i.convert(symbols)).collect()),
228 Term::Map(m) => datalog::Term::Map(
229 m.iter()
230 .map(|(key, term)| {
231 let key = match key {
232 MapKey::Integer(i) => datalog::MapKey::Integer(*i),
233 MapKey::Str(s) => datalog::MapKey::Str(symbols.insert(s)),
234 MapKey::Parameter(s) => panic!("Remaining parameter {}", &s),
235 };
236
237 (key, term.convert(symbols))
238 })
239 .collect(),
240 ),
241 }
242 }
243
244 fn convert_from(f: &datalog::Term, symbols: &SymbolTable) -> Result<Self, error::Format> {
245 Ok(match f {
246 datalog::Term::Variable(s) => Term::Variable(symbols.print_symbol(*s as u64)?),
247 datalog::Term::Integer(i) => Term::Integer(*i),
248 datalog::Term::Str(s) => Term::Str(symbols.print_symbol(*s)?),
249 datalog::Term::Date(d) => Term::Date(*d),
250 datalog::Term::Bytes(s) => Term::Bytes(s.clone()),
251 datalog::Term::Bool(b) => Term::Bool(*b),
252 datalog::Term::Set(s) => Term::Set(
253 s.iter()
254 .map(|i| Term::convert_from(i, symbols))
255 .collect::<Result<BTreeSet<_>, error::Format>>()?,
256 ),
257 datalog::Term::Null => Term::Null,
258 datalog::Term::Array(a) => Term::Array(
259 a.iter()
260 .map(|i| Term::convert_from(i, symbols))
261 .collect::<Result<Vec<_>, error::Format>>()?,
262 ),
263 datalog::Term::Map(m) => Term::Map(
264 m.iter()
265 .map(|(key, term)| {
266 let key = match key {
267 datalog::MapKey::Integer(i) => Ok(MapKey::Integer(*i)),
268 datalog::MapKey::Str(s) => symbols.print_symbol(*s).map(MapKey::Str),
269 };
270
271 key.and_then(|k| Term::convert_from(term, symbols).map(|term| (k, term)))
272 })
273 .collect::<Result<BTreeMap<_, _>, error::Format>>()?,
274 ),
275 })
276 }
277}
278
279impl From<&Term> for Term {
280 fn from(i: &Term) -> Self {
281 match i {
282 Term::Variable(ref v) => Term::Variable(v.clone()),
283 Term::Integer(ref i) => Term::Integer(*i),
284 Term::Str(ref s) => Term::Str(s.clone()),
285 Term::Date(ref d) => Term::Date(*d),
286 Term::Bytes(ref s) => Term::Bytes(s.clone()),
287 Term::Bool(b) => Term::Bool(*b),
288 Term::Set(ref s) => Term::Set(s.clone()),
289 Term::Parameter(ref p) => Term::Parameter(p.clone()),
290 Term::Null => Term::Null,
291 Term::Array(ref a) => Term::Array(a.clone()),
292 Term::Map(m) => Term::Map(m.clone()),
293 }
294 }
295}
296
297impl From<biscuit_parser::builder::Term> for Term {
298 fn from(t: biscuit_parser::builder::Term) -> Self {
299 match t {
300 biscuit_parser::builder::Term::Variable(v) => Term::Variable(v),
301 biscuit_parser::builder::Term::Integer(i) => Term::Integer(i),
302 biscuit_parser::builder::Term::Str(s) => Term::Str(s),
303 biscuit_parser::builder::Term::Date(d) => Term::Date(d),
304 biscuit_parser::builder::Term::Bytes(s) => Term::Bytes(s),
305 biscuit_parser::builder::Term::Bool(b) => Term::Bool(b),
306 biscuit_parser::builder::Term::Set(s) => {
307 Term::Set(s.into_iter().map(|t| t.into()).collect())
308 }
309 biscuit_parser::builder::Term::Null => Term::Null,
310 biscuit_parser::builder::Term::Parameter(ref p) => Term::Parameter(p.clone()),
311 biscuit_parser::builder::Term::Array(a) => {
312 Term::Array(a.into_iter().map(|t| t.into()).collect())
313 }
314 biscuit_parser::builder::Term::Map(a) => Term::Map(
315 a.into_iter()
316 .map(|(key, term)| {
317 (
318 match key {
319 biscuit_parser::builder::MapKey::Parameter(s) => {
320 MapKey::Parameter(s)
321 }
322 biscuit_parser::builder::MapKey::Integer(i) => MapKey::Integer(i),
323 biscuit_parser::builder::MapKey::Str(s) => MapKey::Str(s),
324 },
325 term.into(),
326 )
327 })
328 .collect(),
329 ),
330 }
331 }
332}
333
/// Identity conversion: lets `Term` satisfy `AsRef<Term>` bounds. Through std's blanket
/// impl for references, `&Term` satisfies them too.
impl AsRef<Term> for Term {
    fn as_ref(&self) -> &Term {
        self
    }
}
339
340impl fmt::Display for Term {
341 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
342 match self {
343 Term::Variable(i) => write!(f, "${}", i),
344 Term::Integer(i) => write!(f, "{}", i),
345 Term::Str(s) => write!(f, "\"{}\"", s),
346 Term::Date(d) => {
347 let date = time::OffsetDateTime::from_unix_timestamp(*d as i64)
348 .ok()
349 .and_then(|t| {
350 t.format(&time::format_description::well_known::Rfc3339)
351 .ok()
352 })
353 .unwrap_or_else(|| "<invalid date>".to_string());
354
355 write!(f, "{}", date)
356 }
357 Term::Bytes(s) => write!(f, "hex:{}", hex::encode(s)),
358 Term::Bool(b) => {
359 if *b {
360 write!(f, "true")
361 } else {
362 write!(f, "false")
363 }
364 }
365 Term::Set(s) => {
366 if s.is_empty() {
367 write!(f, "{{,}}")
368 } else {
369 let terms = s.iter().map(|term| term.to_string()).collect::<Vec<_>>();
370 write!(f, "{{{}}}", terms.join(", "))
371 }
372 }
373 Term::Parameter(s) => {
374 write!(f, "{{{}}}", s)
375 }
376 Term::Null => write!(f, "null"),
377 Term::Array(a) => {
378 let terms = a.iter().map(|term| term.to_string()).collect::<Vec<_>>();
379 write!(f, "[{}]", terms.join(", "))
380 }
381 Term::Map(m) => {
382 let terms = m
383 .iter()
384 .map(|(key, term)| match key {
385 MapKey::Integer(i) => format!("{i}: {}", term),
386 MapKey::Str(s) => format!("\"{s}\": {}", term),
387 MapKey::Parameter(s) => format!("{{{s}}}: {}", term),
388 })
389 .collect::<Vec<_>>();
390 write!(f, "{{{}}}", terms.join(", "))
391 }
392 }
393 }
394}
395
/// Wraps a clone of this term in `AnyParam::Term`.
///
/// Only compiled with the `datalog-macro` feature.
#[cfg(feature = "datalog-macro")]
impl ToAnyParam for Term {
    fn to_any_param(&self) -> AnyParam {
        AnyParam::Term(self.clone())
    }
}
402
403impl From<i64> for Term {
404 fn from(i: i64) -> Self {
405 Term::Integer(i)
406 }
407}
408
409#[cfg(feature = "datalog-macro")]
410impl ToAnyParam for i64 {
411 fn to_any_param(&self) -> AnyParam {
412 AnyParam::Term((*self).into())
413 }
414}
415
416impl TryFrom<Term> for i64 {
417 type Error = error::Token;
418 fn try_from(value: Term) -> Result<Self, Self::Error> {
419 match value {
420 Term::Integer(i) => Ok(i),
421 _ => Err(error::Token::ConversionError(format!(
422 "expected integer, got {:?}",
423 value
424 ))),
425 }
426 }
427}
428
429impl From<bool> for Term {
430 fn from(b: bool) -> Self {
431 Term::Bool(b)
432 }
433}
434
435#[cfg(feature = "datalog-macro")]
436impl ToAnyParam for bool {
437 fn to_any_param(&self) -> AnyParam {
438 AnyParam::Term((*self).into())
439 }
440}
441
442impl TryFrom<Term> for bool {
443 type Error = error::Token;
444 fn try_from(value: Term) -> Result<Self, Self::Error> {
445 match value {
446 Term::Bool(b) => Ok(b),
447 _ => Err(error::Token::ConversionError(format!(
448 "expected boolean, got {:?}",
449 value
450 ))),
451 }
452 }
453}
454
455impl From<String> for Term {
456 fn from(s: String) -> Self {
457 Term::Str(s)
458 }
459}
460
461#[cfg(feature = "datalog-macro")]
462impl ToAnyParam for String {
463 fn to_any_param(&self) -> AnyParam {
464 AnyParam::Term((self.clone()).into())
465 }
466}
467
468impl From<&str> for Term {
469 fn from(s: &str) -> Self {
470 Term::Str(s.into())
471 }
472}
473
474#[cfg(feature = "datalog-macro")]
475impl ToAnyParam for &str {
476 fn to_any_param(&self) -> AnyParam {
477 AnyParam::Term(self.to_string().into())
478 }
479}
480
481impl TryFrom<Term> for String {
482 type Error = error::Token;
483 fn try_from(value: Term) -> Result<Self, Self::Error> {
484 match value {
485 Term::Str(s) => Ok(s),
486 _ => Err(error::Token::ConversionError(format!(
487 "expected string or symbol, got {:?}",
488 value
489 ))),
490 }
491 }
492}
493
494impl From<Vec<u8>> for Term {
495 fn from(v: Vec<u8>) -> Self {
496 Term::Bytes(v)
497 }
498}
499
500#[cfg(feature = "datalog-macro")]
501impl ToAnyParam for Vec<u8> {
502 fn to_any_param(&self) -> AnyParam {
503 AnyParam::Term((self.clone()).into())
504 }
505}
506
507impl TryFrom<Term> for Vec<u8> {
508 type Error = error::Token;
509 fn try_from(value: Term) -> Result<Self, Self::Error> {
510 match value {
511 Term::Bytes(b) => Ok(b),
512 _ => Err(error::Token::ConversionError(format!(
513 "expected byte array, got {:?}",
514 value
515 ))),
516 }
517 }
518}
519
520impl From<&[u8]> for Term {
521 fn from(v: &[u8]) -> Self {
522 Term::Bytes(v.into())
523 }
524}
525
526#[cfg(feature = "datalog-macro")]
527impl ToAnyParam for [u8] {
528 fn to_any_param(&self) -> AnyParam {
529 AnyParam::Term(self.into())
530 }
531}
532
/// Wraps a UUID's raw bytes (`as_bytes`) in a [`Term::Bytes`] parameter.
///
/// This impl is gated on the `uuid` feature alone, while the file's `use super::AnyParam`
/// is compiled only with `datalog-macro`. The type is therefore named through `super::` so
/// the impl does not depend on that conditional import.
#[cfg(feature = "uuid")]
impl ToAnyParam for uuid::Uuid {
    fn to_any_param(&self) -> super::AnyParam {
        super::AnyParam::Term(Term::Bytes(self.as_bytes().to_vec()))
    }
}
539
540impl From<SystemTime> for Term {
541 fn from(t: SystemTime) -> Self {
542 let dur = t.duration_since(UNIX_EPOCH).unwrap();
543 Term::Date(dur.as_secs())
544 }
545}
546
547#[cfg(feature = "datalog-macro")]
548impl ToAnyParam for SystemTime {
549 fn to_any_param(&self) -> AnyParam {
550 AnyParam::Term((*self).into())
551 }
552}
553
554impl TryFrom<Term> for SystemTime {
555 type Error = error::Token;
556 fn try_from(value: Term) -> Result<Self, Self::Error> {
557 match value {
558 Term::Date(d) => Ok(UNIX_EPOCH + Duration::from_secs(d)),
559 _ => Err(error::Token::ConversionError(format!(
560 "expected date, got {:?}",
561 value
562 ))),
563 }
564 }
565}
566
567impl From<BTreeSet<Term>> for Term {
568 fn from(value: BTreeSet<Term>) -> Term {
569 set(value)
570 }
571}
572
573#[cfg(feature = "datalog-macro")]
574impl ToAnyParam for BTreeSet<Term> {
575 fn to_any_param(&self) -> AnyParam {
576 AnyParam::Term((self.clone()).into())
577 }
578}
579
580impl<T: Ord + TryFrom<Term, Error = error::Token>> TryFrom<Term> for BTreeSet<T> {
581 type Error = error::Token;
582 fn try_from(value: Term) -> Result<Self, Self::Error> {
583 match value {
584 Term::Set(d) => d.iter().cloned().map(TryFrom::try_from).collect(),
585 _ => Err(error::Token::ConversionError(format!(
586 "expected set, got {:?}",
587 value
588 ))),
589 }
590 }
591}
592
593impl TryFrom<serde_json::Value> for Term {
595 type Error = &'static str;
596
597 fn try_from(value: serde_json::Value) -> Result<Self, Self::Error> {
598 match value {
599 serde_json::Value::Null => Ok(Term::Null),
600 serde_json::Value::Bool(b) => Ok(Term::Bool(b)),
601 serde_json::Value::Number(i) => match i.as_i64() {
602 Some(i) => Ok(Term::Integer(i)),
603 None => Err("Biscuit values do not support floating point numbers"),
604 },
605 serde_json::Value::String(s) => Ok(Term::Str(s)),
606 serde_json::Value::Array(array) => Ok(Term::Array(
607 array
608 .into_iter()
609 .map(|v| v.try_into())
610 .collect::<Result<_, _>>()?,
611 )),
612 serde_json::Value::Object(o) => Ok(Term::Map(
613 o.into_iter()
614 .map(|(key, value)| {
615 let value: Term = value.try_into()?;
616 Ok::<_, &'static str>((MapKey::Str(key), value))
617 })
618 .collect::<Result<_, _>>()?,
619 )),
620 }
621 }
622}
623
// Generates `TryFrom<Fact>` impls (via `tuple_try_from_impl!`) for every tuple arity from 2
// up to the number of identifiers passed in. It expects at least three identifiers; the
// 1-tuple impl is written by hand.
//
// Expansion: the public arm seeds an accumulator with the first two identifiers. Each
// `__impl` step emits an impl for the accumulated list, then moves the next identifier into
// the accumulator. When only one identifier remains, the last arm emits impls for both the
// accumulated list and the full list.
macro_rules! tuple_try_from(
    // Entry point: accumulator = first two identifiers, the rest are pending.
    ($ty1:ident, $ty2:ident, $($ty:ident),*) => (
        tuple_try_from!(__impl $ty1, $ty2; $($ty),*);
    );
    // At least two pending identifiers: emit the current arity, then grow by one.
    (__impl $($ty: ident),+; $ty1:ident, $($ty2:ident),*) => (
        tuple_try_from_impl!($($ty),+);
        tuple_try_from!(__impl $($ty),+ , $ty1; $($ty2),*);
    );
    // Exactly one pending identifier: emit the current arity and the final, largest one.
    (__impl $($ty: ident),+; $ty1:ident) => (
        tuple_try_from_impl!($($ty),+);
        tuple_try_from_impl!($($ty),+, $ty1);
    );
);
637
638impl<A: TryFrom<Term, Error = error::Token>> TryFrom<Fact> for (A,) {
639 type Error = error::Token;
640 fn try_from(fact: Fact) -> Result<Self, Self::Error> {
641 let mut terms = fact.predicate.terms;
642 let mut it = terms.drain(..);
643
644 Ok((it
645 .next()
646 .ok_or_else(|| error::Token::ConversionError("not enough terms in fact".to_string()))
647 .and_then(A::try_from)?,))
648 }
649}
650
// Implements `TryFrom<Fact>` for the tuple of the given type parameters. The fact's terms
// are converted positionally, left to right, with each element's own `TryFrom<Term>`.
// A fact with fewer terms than the tuple yields `ConversionError("not enough terms in
// fact")`; extra terms are ignored.
macro_rules! tuple_try_from_impl(
    ($($ty: ident),+) => (
        impl<$($ty: TryFrom<Term, Error = error::Token>),+> TryFrom<Fact> for ($($ty),+) {
            type Error = error::Token;
            fn try_from(fact: Fact) -> Result<Self, Self::Error> {
                let mut it = fact.predicate.terms.into_iter();

                Ok((
                    $(
                        // `ok_or_else`: only build (and allocate) the error when a term is
                        // actually missing, like the hand-written 1-tuple impl.
                        it.next()
                            .ok_or_else(|| {
                                error::Token::ConversionError(
                                    "not enough terms in fact".to_string(),
                                )
                            })
                            .and_then($ty::try_from)?
                    ),+
                ))
            }
        }
    );
);
669
// Generates `TryFrom<Fact>` for tuples of 2 to 21 elements (type parameters `A` through
// `U`). Together with the hand-written 1-tuple impl, a fact can be extracted into a tuple
// of up to 21 converted terms.
tuple_try_from!(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q, R, S, T, U);