1use aura_core::{domain::journal::FactValue, types::identifiers::AuthorityId, AuraError};
32use biscuit_auth::Authorizer;
33use std::collections::HashMap;
34use thiserror::Error;
35
36#[derive(Debug, Error)]
38pub enum QueryError {
39 #[error("Failed to create authorizer: {0}")]
41 AuthorizerCreation(String),
42
43 #[error("Failed to add fact: {0}")]
45 FactAddition(String),
46
47 #[error("Query execution failed: {0}")]
49 QueryExecution(String),
50
51 #[error("Invalid fact format: {0}")]
53 InvalidFact(String),
54
55 #[error("Invalid query syntax: {0}")]
57 InvalidQuery(String),
58}
59
60impl From<QueryError> for AuraError {
61 fn from(err: QueryError) -> Self {
62 AuraError::Internal {
63 message: err.to_string(),
64 }
65 }
66}
67
/// Outcome of a Datalog query.
#[derive(Debug, Clone)]
pub struct QueryResult {
    /// Matching facts. As produced by `AuraQuery::query`, each inner vector holds a
    /// single rendered fact.
    pub facts: Vec<Vec<String>>,
    /// Number of matching facts (equal to `facts.len()` when produced by `AuraQuery::query`).
    pub count: usize,
}

impl QueryResult {
    /// Builds a result that contains no facts.
    pub fn empty() -> Self {
        Self {
            count: 0,
            facts: vec![],
        }
    }

    /// Reports whether `count` is zero.
    pub fn is_empty(&self) -> bool {
        matches!(self.count, 0)
    }
}
91
92pub struct AuraQuery {
108 facts: Vec<(String, Vec<FactTerm>)>,
110 authority_context: Option<AuthorityId>,
112 context_facts: HashMap<String, String>,
114}
115
/// A single argument of a Datalog fact.
///
/// `PartialEq`/`Eq` are derived so terms can be compared directly (for example in
/// assertions) instead of via `match` + `panic!`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FactTerm {
    /// A string term, rendered as a quoted and escaped Datalog string.
    String(String),
    /// A 64-bit signed integer term.
    Integer(i64),
    /// A byte-string term, rendered as `hex:<lowercase hex>`.
    Bytes(Vec<u8>),
}

impl From<&str> for FactTerm {
    fn from(text: &str) -> Self {
        Self::String(text.to_owned())
    }
}

impl From<String> for FactTerm {
    fn from(text: String) -> Self {
        Self::String(text)
    }
}

impl From<i64> for FactTerm {
    fn from(value: i64) -> Self {
        Self::Integer(value)
    }
}

impl From<Vec<u8>> for FactTerm {
    fn from(bytes: Vec<u8>) -> Self {
        Self::Bytes(bytes)
    }
}
150
151impl Default for AuraQuery {
152 fn default() -> Self {
153 Self::new()
154 }
155}
156
/// Encodes `bytes` as a lowercase hexadecimal string (two characters per byte).
///
/// Uses one preallocated buffer and a nibble lookup table rather than a `format!`
/// allocation per byte.
fn bytes_to_hex(bytes: &[u8]) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789abcdef";

    let mut hex = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        hex.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
        hex.push(char::from(HEX_DIGITS[usize::from(byte & 0x0f)]));
    }
    hex
}
161
162impl AuraQuery {
163 pub fn new() -> Self {
165 Self {
166 facts: Vec::new(),
167 authority_context: None,
168 context_facts: HashMap::new(),
169 }
170 }
171
172 pub fn add_fact(&mut self, predicate: &str, terms: Vec<FactTerm>) -> Result<(), QueryError> {
186 self.facts.push((predicate.to_string(), terms));
187 Ok(())
188 }
189
190 pub fn add_journal_fact(
194 &mut self,
195 predicate: &str,
196 key: &str,
197 value: &str,
198 ) -> Result<(), QueryError> {
199 self.add_fact(
200 predicate,
201 vec![
202 FactTerm::String(key.to_string()),
203 FactTerm::String(value.to_string()),
204 ],
205 )
206 }
207
208 pub fn add_fact_value(
212 &mut self,
213 predicate: &str,
214 key: &str,
215 value: &FactValue,
216 ) -> Result<(), QueryError> {
217 match value {
218 FactValue::String(s) => self.add_fact(
219 predicate,
220 vec![
221 FactTerm::String(key.to_string()),
222 FactTerm::String(s.clone()),
223 ],
224 ),
225 FactValue::Number(n) => self.add_fact(
226 predicate,
227 vec![FactTerm::String(key.to_string()), FactTerm::Integer(*n)],
228 ),
229 FactValue::Bytes(b) => self.add_fact(
230 predicate,
231 vec![
232 FactTerm::String(key.to_string()),
233 FactTerm::Bytes(b.clone()),
234 ],
235 ),
236 FactValue::Set(set) => {
237 for item in set {
239 self.add_fact(
240 predicate,
241 vec![
242 FactTerm::String(key.to_string()),
243 FactTerm::String(item.clone()),
244 ],
245 )?;
246 }
247 Ok(())
248 }
249 FactValue::Nested(nested_fact) => {
250 if let Ok(serialized) = aura_core::util::serialization::to_vec(nested_fact.as_ref())
253 {
254 let hash = aura_core::hash::hash(&serialized);
255 let hash_hex = bytes_to_hex(&hash);
256 self.add_fact(
257 predicate,
258 vec![
259 FactTerm::String(format!("{key}.nested")),
260 FactTerm::String(hash_hex),
261 ],
262 )
263 } else {
264 Ok(()) }
266 }
267 }
268 }
269
270 pub fn add_authority_context(&mut self, authority: AuthorityId) -> Result<(), QueryError> {
275 self.authority_context = Some(authority);
276 Ok(())
277 }
278
279 pub fn add_context(&mut self, key: &str, value: &str) {
283 self.context_facts
284 .insert(key.to_string(), value.to_string());
285 }
286
287 fn build_authorizer(&self) -> Result<Authorizer, QueryError> {
289 let mut authorizer = Authorizer::new();
290
291 for (predicate, terms) in &self.facts {
293 let fact_string = self.format_fact(predicate, terms);
294 authorizer
295 .add_code(fact_string)
296 .map_err(|e| QueryError::FactAddition(e.to_string()))?;
297 }
298
299 if let Some(ref authority) = self.authority_context {
301 let auth_fact = format!("authority(\"{authority}\");");
302 authorizer
303 .add_code(auth_fact)
304 .map_err(|e| QueryError::FactAddition(e.to_string()))?;
305 }
306
307 for (key, value) in &self.context_facts {
309 let context_fact = format!("context(\"{key}\", \"{value}\");");
310 authorizer
311 .add_code(context_fact)
312 .map_err(|e| QueryError::FactAddition(e.to_string()))?;
313 }
314
315 Ok(authorizer)
316 }
317
318 fn format_fact(&self, predicate: &str, terms: &[FactTerm]) -> String {
320 let term_strings: Vec<String> = terms
321 .iter()
322 .map(|term| match term {
323 FactTerm::String(s) => {
324 format!("\"{}\"", s.replace('\\', "\\\\").replace('"', "\\\""))
325 }
326 FactTerm::Integer(n) => n.to_string(),
327 FactTerm::Bytes(b) => format!("hex:{}", bytes_to_hex(b)),
328 })
329 .collect();
330
331 format!("{}({});", predicate, term_strings.join(", "))
332 }
333
334 pub fn query(&self, rule: &str) -> Result<QueryResult, QueryError> {
353 let mut authorizer = self.build_authorizer()?;
354
355 authorizer
357 .add_code(rule)
358 .map_err(|e| QueryError::InvalidQuery(e.to_string()))?;
359
360 let _ = authorizer.authorize();
364
365 let head_predicate = extract_rule_head(rule)?;
367
368 let (world_facts, _rules, _checks, _policies) = authorizer.dump();
370
371 let results: Vec<Vec<String>> = world_facts
373 .into_iter()
374 .filter(|f| {
375 let fact_str = format!("{f}");
377 fact_str.starts_with(&format!("{head_predicate}("))
378 })
379 .map(|f| {
380 vec![format!("{}", f)]
382 })
383 .collect();
384
385 Ok(QueryResult {
386 count: results.len(),
387 facts: results,
388 })
389 }
390
391 pub fn query_multi(&self, rule: &str) -> Result<QueryResult, QueryError> {
403 self.query(rule)
406 }
407
408 pub fn exists(&self, rule: &str) -> Result<bool, QueryError> {
412 let result = self.query(rule)?;
413 Ok(!result.is_empty())
414 }
415
416 pub fn count(&self, rule: &str) -> Result<usize, QueryError> {
418 let result = self.query(rule)?;
419 Ok(result.count)
420 }
421
422 pub fn facts_for_predicate(&self, predicate: &str) -> Vec<&Vec<FactTerm>> {
424 self.facts
425 .iter()
426 .filter(|(p, _)| p == predicate)
427 .map(|(_, terms)| terms)
428 .collect()
429 }
430
431 pub fn clear(&mut self) {
433 self.facts.clear();
434 self.authority_context = None;
435 self.context_facts.clear();
436 }
437
438 pub fn fact_count(&self) -> usize {
440 self.facts.len()
441 }
442}
443
444fn extract_rule_head(rule: &str) -> Result<String, QueryError> {
448 let parts: Vec<&str> = rule.split("<-").collect();
450 if parts.is_empty() {
451 return Err(QueryError::InvalidQuery(
452 "Rule must contain <- separator".to_string(),
453 ));
454 }
455
456 let head = parts[0].trim();
457
458 if let Some(paren_pos) = head.find('(') {
460 Ok(head[..paren_pos].trim().to_string())
461 } else {
462 Err(QueryError::InvalidQuery(
463 "Rule head must have predicate with arguments".to_string(),
464 ))
465 }
466}
467
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_new_query() {
        let query = AuraQuery::new();
        assert_eq!(query.fact_count(), 0);
    }

    #[test]
    fn test_add_simple_fact() {
        let mut query = AuraQuery::new();
        query.add_journal_fact("user", "name", "alice").unwrap();
        assert_eq!(query.fact_count(), 1);
    }

    #[test]
    fn test_add_multiple_facts() {
        let mut query = AuraQuery::new();
        query.add_journal_fact("user", "name", "alice").unwrap();
        query.add_journal_fact("user", "role", "admin").unwrap();
        query
            .add_journal_fact("device", "id", "device-123")
            .unwrap();
        assert_eq!(query.fact_count(), 3);
    }

    #[test]
    fn test_add_authority_context() {
        let mut query = AuraQuery::new();
        let authority = AuthorityId::new_from_entropy([1u8; 32]);
        query.add_authority_context(authority).unwrap();
        assert!(query.authority_context.is_some());
    }

    #[test]
    fn test_add_context() {
        let mut query = AuraQuery::new();
        query.add_context("time", "12345");
        query.add_context("device", "mobile");
        assert_eq!(query.context_facts.len(), 2);
    }

    #[test]
    fn test_clear() {
        let mut query = AuraQuery::new();
        query.add_journal_fact("user", "name", "alice").unwrap();
        query
            .add_authority_context(AuthorityId::new_from_entropy([2u8; 32]))
            .unwrap();
        query.add_context("key", "value");

        query.clear();

        assert_eq!(query.fact_count(), 0);
        assert!(query.authority_context.is_none());
        assert!(query.context_facts.is_empty());
    }

    #[test]
    fn test_fact_term_from_str() {
        let term: FactTerm = "hello".into();
        match term {
            FactTerm::String(s) => assert_eq!(s, "hello"),
            _ => panic!("Expected string term"),
        }
    }

    #[test]
    fn test_fact_term_from_i64() {
        let term: FactTerm = 42i64.into();
        match term {
            FactTerm::Integer(n) => assert_eq!(n, 42),
            _ => panic!("Expected integer term"),
        }
    }

    #[test]
    fn test_fact_term_from_bytes() {
        let bytes = vec![1, 2, 3, 4];
        let term: FactTerm = bytes.clone().into();
        match term {
            FactTerm::Bytes(b) => assert_eq!(b, bytes),
            _ => panic!("Expected bytes term"),
        }
    }

    #[test]
    fn test_format_fact_string() {
        let query = AuraQuery::new();
        let terms = vec![
            FactTerm::String("key".to_string()),
            FactTerm::String("value".to_string()),
        ];
        let formatted = query.format_fact("test", &terms);
        assert_eq!(formatted, "test(\"key\", \"value\");");
    }

    #[test]
    fn test_format_fact_integer() {
        let query = AuraQuery::new();
        let terms = vec![FactTerm::String("count".to_string()), FactTerm::Integer(42)];
        let formatted = query.format_fact("metric", &terms);
        assert_eq!(formatted, "metric(\"count\", 42);");
    }

    #[test]
    fn test_format_fact_bytes() {
        let query = AuraQuery::new();
        let terms = vec![FactTerm::Bytes(vec![0xab, 0x01])];
        let formatted = query.format_fact("blob", &terms);
        assert_eq!(formatted, "blob(hex:ab01);");
    }

    #[test]
    fn test_format_fact_escaped_string() {
        let query = AuraQuery::new();
        let terms = vec![FactTerm::String("value with \"quotes\"".to_string())];
        let formatted = query.format_fact("test", &terms);
        assert!(formatted.contains("\\\"quotes\\\""));
    }

    #[test]
    fn test_format_fact_escaped_backslash() {
        let query = AuraQuery::new();
        let terms = vec![FactTerm::String("a\\b".to_string())];
        let formatted = query.format_fact("test", &terms);
        assert_eq!(formatted, "test(\"a\\\\b\");");
    }

    #[test]
    fn test_extract_rule_head() {
        let head = extract_rule_head("result($x) <- input($x)").unwrap();
        assert_eq!(head, "result");

        let head2 = extract_rule_head("admin($name, $role) <- user($name), role($role)").unwrap();
        assert_eq!(head2, "admin");

        let head3 = extract_rule_head("  spaced ( $x ) <- input($x)").unwrap();
        assert_eq!(head3, "spaced");
    }

    #[test]
    fn test_extract_rule_head_error() {
        let result = extract_rule_head("invalid rule");
        assert!(result.is_err());
    }

    #[test]
    fn test_add_fact_value_string() {
        let mut query = AuraQuery::new();
        let value = FactValue::String("test_value".to_string());
        query.add_fact_value("data", "key", &value).unwrap();
        assert_eq!(query.fact_count(), 1);
    }

    #[test]
    fn test_add_fact_value_number() {
        let mut query = AuraQuery::new();
        let value = FactValue::Number(42);
        query.add_fact_value("metric", "count", &value).unwrap();
        assert_eq!(query.fact_count(), 1);
    }

    #[test]
    fn test_add_fact_value_bytes() {
        let mut query = AuraQuery::new();
        let value = FactValue::Bytes(vec![1, 2, 3]);
        query.add_fact_value("blob", "payload", &value).unwrap();
        assert_eq!(query.fact_count(), 1);
    }

    #[test]
    fn test_add_fact_value_set() {
        let mut query = AuraQuery::new();
        let mut set = std::collections::BTreeSet::new();
        set.insert("a".to_string());
        set.insert("b".to_string());
        set.insert("c".to_string());
        let value = FactValue::Set(set);
        query.add_fact_value("items", "list", &value).unwrap();
        // One fact per set element.
        assert_eq!(query.fact_count(), 3);
    }

    #[test]
    fn test_facts_for_predicate() {
        let mut query = AuraQuery::new();
        query.add_journal_fact("user", "name", "alice").unwrap();
        query.add_journal_fact("user", "role", "admin").unwrap();
        query.add_journal_fact("device", "id", "123").unwrap();

        let user_facts = query.facts_for_predicate("user");
        assert_eq!(user_facts.len(), 2);

        let device_facts = query.facts_for_predicate("device");
        assert_eq!(device_facts.len(), 1);
    }

    #[test]
    fn test_build_authorizer() {
        let mut query = AuraQuery::new();
        query.add_journal_fact("user", "name", "alice").unwrap();
        query
            .add_authority_context(AuthorityId::new_from_entropy([3u8; 32]))
            .unwrap();

        let authorizer = query.build_authorizer();
        assert!(authorizer.is_ok());
    }

    #[test]
    fn test_build_authorizer_with_context() {
        let mut query = AuraQuery::new();
        query.add_context("time", "12345");
        assert!(query.build_authorizer().is_ok());
    }

    #[test]
    fn test_query_result_empty() {
        let result = QueryResult::empty();
        assert!(result.is_empty());
        assert_eq!(result.count, 0);
    }

    #[test]
    fn test_bytes_to_hex() {
        let bytes = vec![0xde, 0xad, 0xbe, 0xef];
        let hex = bytes_to_hex(&bytes);
        assert_eq!(hex, "deadbeef");
    }

    #[test]
    fn test_query_simple() {
        let mut query = AuraQuery::new();
        query.add_fact("user", vec!["alice".into()]).unwrap();
        query.add_fact("user", vec!["bob".into()]).unwrap();

        let result = query.query("all_users($name) <- user($name)");
        // Whether biscuit accepts this exact rule text is not asserted here. Whenever a
        // result is produced, it must be internally consistent and match the head.
        if let Ok(result) = result {
            assert_eq!(result.count, result.facts.len());
            for fact in &result.facts {
                assert!(fact[0].starts_with("all_users("));
            }
        }
    }

    #[test]
    fn test_query_rejects_malformed_rule() {
        let query = AuraQuery::new();
        assert!(query.query("this is not datalog").is_err());
    }

    #[test]
    fn test_exists_logic() {
        let mut query = AuraQuery::new();
        query.add_fact("user", vec!["alice".into()]).unwrap();

        assert_eq!(query.fact_count(), 1);
        // `exists` must propagate query errors rather than report `false`.
        assert!(query.exists("this is not datalog").is_err());
    }

    #[test]
    fn test_count_logic() {
        let mut query = AuraQuery::new();
        query.add_fact("item", vec!["a".into()]).unwrap();
        query.add_fact("item", vec!["b".into()]).unwrap();
        query.add_fact("item", vec!["c".into()]).unwrap();

        assert_eq!(query.fact_count(), 3);
        // `count` must propagate query errors rather than report 0.
        assert!(query.count("this is not datalog").is_err());
    }
}