use super::types::{Object, Relation, RelationTuple, Subject};
use crate::error::{Error, Result};
use lru::LruCache;
use std::collections::{HashMap, HashSet, VecDeque};
use std::num::NonZeroUsize;
use std::sync::{Arc, RwLock};
/// Key of the check-result cache: the string forms of (subject, relation name, object),
/// as built by `RelationshipGraph::check`.
type CacheKey = (String, String, String);
/// In-memory store of relation tuples with forward/reverse adjacency indexes and an
/// LRU cache of `check` results.
#[derive(Debug)]
pub struct RelationshipGraph {
// Object key -> list of (subject key, relation name) for every stored tuple on that object.
forward: HashMap<String, Vec<(String, String)>>,
// Subject key -> list of (relation name, object key) for every stored tuple from that subject.
reverse: HashMap<String, Vec<(String, String)>>,
// The canonical set of stored tuples; used to ignore duplicate adds and reject unknown removes.
tuples: HashSet<RelationTuple>,
// Memoised `check` results; cleared whenever a tuple is added or removed.
cache: Arc<RwLock<LruCache<CacheKey, bool>>>,
// Capacity value reported by `cache_stats`.
cache_size: usize,
}
impl RelationshipGraph {
pub fn new() -> Self {
Self::with_cache_size(1000)
}
pub fn with_cache_size(size: usize) -> Self {
let cache_size =
NonZeroUsize::new(size).unwrap_or(unsafe { NonZeroUsize::new_unchecked(1) });
RelationshipGraph {
forward: HashMap::new(),
reverse: HashMap::new(),
tuples: HashSet::new(),
cache: Arc::new(RwLock::new(LruCache::new(cache_size))),
cache_size: size,
}
}
pub fn add_tuple(&mut self, tuple: RelationTuple) -> Result<()> {
if self.tuples.contains(&tuple) {
return Ok(()); }
let subject_key = tuple.subject.to_string_format();
let relation_key = tuple.relation.name.clone();
let object_key = tuple.object.to_string_format();
self.forward
.entry(object_key.clone())
.or_insert_with(Vec::new)
.push((subject_key.clone(), relation_key.clone()));
self.reverse
.entry(subject_key)
.or_insert_with(Vec::new)
.push((relation_key, object_key));
self.tuples.insert(tuple);
self.clear_cache();
Ok(())
}
pub fn remove_tuple(&mut self, tuple: &RelationTuple) -> Result<()> {
if !self.tuples.contains(tuple) {
return Err(Error::InvalidInput("Tuple not found".to_string()));
}
let subject_key = tuple.subject.to_string_format();
let relation_key = &tuple.relation.name;
let object_key = tuple.object.to_string_format();
if let Some(entries) = self.forward.get_mut(&object_key) {
entries.retain(|(s, r)| s != &subject_key || r != relation_key);
if entries.is_empty() {
self.forward.remove(&object_key);
}
}
if let Some(entries) = self.reverse.get_mut(&subject_key) {
entries.retain(|(r, o)| r != relation_key || o != &object_key);
if entries.is_empty() {
self.reverse.remove(&subject_key);
}
}
self.tuples.remove(tuple);
self.clear_cache();
Ok(())
}
pub fn check(&self, subject: &Subject, relation: &Relation, object: &Object) -> Result<bool> {
let cache_key = (
subject.to_string_format(),
relation.name.clone(),
object.to_string_format(),
);
if let Ok(cache_guard) = self.cache.read() {
if let Some(&result) = cache_guard.peek(&cache_key) {
return Ok(result);
}
}
let result = self.check_internal(subject, relation, object)?;
if let Ok(mut cache_guard) = self.cache.write() {
cache_guard.put(cache_key, result);
}
Ok(result)
}
fn check_internal(
&self,
subject: &Subject,
relation: &Relation,
object: &Object,
) -> Result<bool> {
let object_key = object.to_string_format();
let relation_name = &relation.name;
if self.has_direct_relationship(subject, relation, object) {
return Ok(true);
}
let mut visited = HashSet::new();
let mut queue = VecDeque::new();
queue.push_back(object_key.clone());
visited.insert(object_key.clone());
while let Some(current_obj) = queue.pop_front() {
if let Some(entries) = self.forward.get(¤t_obj) {
for (subj_key, rel_key) in entries {
if rel_key == relation_name {
if let Ok(current_subj) = Subject::parse(subj_key) {
if self.subject_matches(subject, ¤t_subj) {
return Ok(true);
}
if current_subj.relation.is_some() {
if self.is_member_of_set(subject, ¤t_subj)? {
return Ok(true);
}
}
}
}
}
}
if let Some(parent_entries) = self.forward.get(¤t_obj) {
for (_, rel_key) in parent_entries {
if rel_key == "parent" {
if let Some(reverse_entries) = self.reverse.get(¤t_obj) {
for (parent_rel, parent_obj) in reverse_entries {
if parent_rel == "parent" && !visited.contains(parent_obj) {
visited.insert(parent_obj.clone());
queue.push_back(parent_obj.clone());
}
}
}
}
}
}
}
Ok(false)
}
fn has_direct_relationship(
&self,
subject: &Subject,
relation: &Relation,
object: &Object,
) -> bool {
let subject_key = subject.to_string_format();
let relation_name = &relation.name;
let object_key = object.to_string_format();
if let Some(entries) = self.reverse.get(&subject_key) {
for (rel, obj) in entries {
if rel == relation_name && obj == &object_key {
return true;
}
}
}
false
}
fn subject_matches(&self, target: &Subject, candidate: &Subject) -> bool {
if target.subject_type != candidate.subject_type {
return false;
}
if target.subject_id != candidate.subject_id {
return false;
}
target.relation == candidate.relation
}
fn is_member_of_set(&self, subject: &Subject, set: &Subject) -> Result<bool> {
if let Some(ref set_relation) = set.relation {
let rel = Relation::new(set_relation);
let obj = Object::new(&set.subject_type, &set.subject_id);
self.check(subject, &rel, &obj)
} else {
Ok(false)
}
}
pub fn expand(&self, relation: &Relation, object: &Object) -> Result<Vec<Subject>> {
let object_key = object.to_string_format();
let relation_name = &relation.name;
let mut subjects = Vec::new();
if let Some(entries) = self.forward.get(&object_key) {
for (subj_key, rel_key) in entries {
if rel_key == relation_name {
if let Ok(subject) = Subject::parse(subj_key) {
subjects.push(subject);
}
}
}
}
Ok(subjects)
}
pub fn list_objects(&self, subject: &Subject, relation: &Relation) -> Result<Vec<Object>> {
let subject_key = subject.to_string_format();
let relation_name = &relation.name;
let mut objects = Vec::new();
if let Some(entries) = self.reverse.get(&subject_key) {
for (rel, obj_key) in entries {
if rel == relation_name {
if let Ok(object) = Object::parse(obj_key) {
objects.push(object);
}
}
}
}
Ok(objects)
}
pub fn get_all_tuples(&self) -> Vec<RelationTuple> {
self.tuples.iter().cloned().collect()
}
pub fn clear_cache(&self) {
if let Ok(mut cache_guard) = self.cache.write() {
cache_guard.clear();
}
}
pub fn cache_stats(&self) -> Result<(usize, usize)> {
if let Ok(cache_guard) = self.cache.read() {
Ok((cache_guard.len(), self.cache_size))
} else {
Err(Error::InvalidOperation(
"Failed to acquire cache lock".to_string(),
))
}
}
}
impl Default for RelationshipGraph {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the tuple `user:<user> --<relation>--> document:<doc>`.
    fn user_doc_tuple(
        user: &'static str,
        relation: &'static str,
        doc: &'static str,
    ) -> RelationTuple {
        RelationTuple::new(
            Subject::new("user", user),
            Relation::new(relation),
            Object::new("document", doc),
        )
    }

    /// Asks whether `user:<user>` is `owner` of `document:<doc>`; fails the test on error.
    fn owns(graph: &RelationshipGraph, user: &'static str, doc: &'static str) -> bool {
        graph
            .check(
                &Subject::new("user", user),
                &Relation::new("owner"),
                &Object::new("document", doc),
            )
            .expect("check should succeed")
    }

    #[test]
    fn test_add_and_check_tuple() {
        let mut graph = RelationshipGraph::new();
        graph
            .add_tuple(user_doc_tuple("alice", "owner", "123"))
            .expect("add should succeed");
        assert!(owns(&graph, "alice", "123"));
        assert!(!owns(&graph, "bob", "123"));
    }

    #[test]
    fn test_remove_tuple() {
        let mut graph = RelationshipGraph::new();
        let tuple = user_doc_tuple("alice", "owner", "123");
        graph.add_tuple(tuple.clone()).expect("add should succeed");
        graph.remove_tuple(&tuple).expect("remove should succeed");
        assert!(!owns(&graph, "alice", "123"));
    }

    #[test]
    fn test_expand_subjects() {
        let mut graph = RelationshipGraph::new();
        for &user in &["alice", "bob"] {
            graph
                .add_tuple(user_doc_tuple(user, "viewer", "123"))
                .expect("add should succeed");
        }
        let subjects = graph
            .expand(&Relation::new("viewer"), &Object::new("document", "123"))
            .expect("expand should succeed");
        assert_eq!(subjects.len(), 2);
    }

    #[test]
    fn test_list_objects() {
        let mut graph = RelationshipGraph::new();
        for &doc in &["123", "456"] {
            graph
                .add_tuple(user_doc_tuple("alice", "owner", doc))
                .expect("add should succeed");
        }
        let objects = graph
            .list_objects(&Subject::new("user", "alice"), &Relation::new("owner"))
            .expect("list should succeed");
        assert_eq!(objects.len(), 2);
    }

    #[test]
    fn test_cache() {
        let mut graph = RelationshipGraph::with_cache_size(10);
        graph
            .add_tuple(user_doc_tuple("alice", "owner", "123"))
            .expect("add should succeed");
        // The first check fills the cache; the repeat must agree with it.
        assert!(owns(&graph, "alice", "123"));
        assert!(owns(&graph, "alice", "123"));
        let (cache_len, cache_cap) = graph.cache_stats().expect("stats should succeed");
        assert_eq!(cache_len, 1);
        assert_eq!(cache_cap, 10);
    }
}