use std::collections::HashSet;
use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, RwLock};
use crate::testing::AuthorizationClient;
use crate::types::{Context, Relationship};
use crate::Error;
/// An [`AuthorizationClient`] that answers checks from a relationship set held
/// in process memory (presumably used as a test double — it implements the
/// trait from `crate::testing`).
///
/// The set sits behind an [`Arc`], so every clone of a client shares the same
/// store: a relationship written through one clone is visible through all
/// others.
#[derive(Debug, Clone)]
pub struct InMemoryClient {
    /// All stored relationship tuples, shared between clones and guarded by a
    /// reader/writer lock.
    relationships: Arc<RwLock<HashSet<StoredRelationship>>>,
}
/// Owned copy of one relationship tuple, used as the key of the in-memory set.
///
/// Equality and hashing are derived over all three fields, so two tuples are
/// the same entry only when resource, relation and subject all match.
#[derive(Debug, Clone)]
#[derive(PartialEq, Eq, Hash)]
struct StoredRelationship {
    /// The object the relationship is about (e.g. `doc:1` in the tests).
    resource: String,
    /// The relation name (e.g. `viewer`).
    relation: String,
    /// The entity that holds the relation (e.g. `user:alice`).
    subject: String,
}
impl From<&Relationship<'_>> for StoredRelationship {
fn from(rel: &Relationship<'_>) -> Self {
Self {
resource: rel.resource().to_string(),
relation: rel.relation().to_string(),
subject: rel.subject().to_string(),
}
}
}
impl InMemoryClient {
pub fn new() -> Self {
Self {
relationships: Arc::new(RwLock::new(HashSet::new())),
}
}
pub fn write(&self, relationship: Relationship<'_>) {
let stored = StoredRelationship::from(&relationship);
self.relationships.write().unwrap().insert(stored);
}
pub fn write_all<'a>(&self, relationships: impl IntoIterator<Item = Relationship<'a>>) {
let mut store = self.relationships.write().unwrap();
for rel in relationships {
store.insert(StoredRelationship::from(&rel));
}
}
pub fn delete(&self, relationship: &Relationship<'_>) -> bool {
let stored = StoredRelationship::from(relationship);
self.relationships.write().unwrap().remove(&stored)
}
pub fn clear(&self) {
self.relationships.write().unwrap().clear();
}
pub fn len(&self) -> usize {
self.relationships.read().unwrap().len()
}
pub fn is_empty(&self) -> bool {
self.relationships.read().unwrap().is_empty()
}
fn has_direct_relationship(&self, resource: &str, relation: &str, subject: &str) -> bool {
let store = self.relationships.read().unwrap();
store.contains(&StoredRelationship {
resource: resource.to_string(),
relation: relation.to_string(),
subject: subject.to_string(),
})
}
}
impl Default for InMemoryClient {
fn default() -> Self {
Self::new()
}
}
impl AuthorizationClient for InMemoryClient {
fn check(
&self,
subject: &str,
permission: &str,
resource: &str,
) -> Pin<Box<dyn Future<Output = Result<bool, Error>> + Send + '_>> {
let result = self.has_direct_relationship(resource, permission, subject);
Box::pin(async move { Ok(result) })
}
fn check_with_context(
&self,
subject: &str,
permission: &str,
resource: &str,
_context: &Context,
) -> Pin<Box<dyn Future<Output = Result<bool, Error>> + Send + '_>> {
self.check(subject, permission, resource)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_in_memory_client_new() {
        let client = InMemoryClient::new();
        assert!(client.is_empty());
        assert_eq!(client.len(), 0);
    }

    #[test]
    fn test_in_memory_client_default_is_empty() {
        let client = InMemoryClient::default();
        assert!(client.is_empty());
        assert_eq!(client.len(), 0);
    }

    #[test]
    fn test_in_memory_client_write() {
        let client = InMemoryClient::new();
        client.write(Relationship::new("doc:1", "viewer", "user:alice"));
        assert_eq!(client.len(), 1);
    }

    #[test]
    fn test_in_memory_client_write_is_idempotent() {
        let client = InMemoryClient::new();
        client.write(Relationship::new("doc:1", "viewer", "user:alice"));
        client.write(Relationship::new("doc:1", "viewer", "user:alice"));
        // The store is a set, so re-writing the same tuple adds nothing.
        assert_eq!(client.len(), 1);
    }

    #[test]
    fn test_in_memory_client_write_all() {
        let client = InMemoryClient::new();
        client.write_all(vec![
            Relationship::new("doc:1", "viewer", "user:alice"),
            Relationship::new("doc:1", "editor", "user:bob"),
        ]);
        assert_eq!(client.len(), 2);
    }

    #[test]
    fn test_in_memory_client_write_all_deduplicates() {
        let client = InMemoryClient::new();
        client.write_all(vec![
            Relationship::new("doc:1", "viewer", "user:alice"),
            Relationship::new("doc:1", "viewer", "user:alice"),
        ]);
        assert_eq!(client.len(), 1);
    }

    #[test]
    fn test_in_memory_client_delete() {
        let client = InMemoryClient::new();
        let rel = Relationship::new("doc:1", "viewer", "user:alice");
        client.write(rel.as_borrowed());
        assert!(client.delete(&rel));
        assert!(client.is_empty());
        // Deleting a relationship that is no longer present reports `false`.
        assert!(!client.delete(&rel));
    }

    #[test]
    fn test_in_memory_client_clear() {
        let client = InMemoryClient::new();
        client.write_all(vec![
            Relationship::new("doc:1", "viewer", "user:alice"),
            Relationship::new("doc:2", "viewer", "user:bob"),
        ]);
        client.clear();
        assert!(client.is_empty());
    }

    #[tokio::test]
    async fn test_in_memory_client_check_direct() {
        let client = InMemoryClient::new();
        client.write(Relationship::new("doc:1", "viewer", "user:alice"));
        let result = client.check("user:alice", "viewer", "doc:1").await.unwrap();
        assert!(result);
        let result = client.check("user:bob", "viewer", "doc:1").await.unwrap();
        assert!(!result);
        let result = client.check("user:alice", "editor", "doc:1").await.unwrap();
        assert!(!result);
        let result = client.check("user:alice", "viewer", "doc:2").await.unwrap();
        assert!(!result);
    }

    #[tokio::test]
    async fn test_in_memory_client_check_with_context() {
        let client = InMemoryClient::new();
        client.write(Relationship::new("doc:1", "viewer", "user:alice"));
        let context = Context::new().with("env", "test");
        let result = client
            .check_with_context("user:alice", "viewer", "doc:1", &context)
            .await
            .unwrap();
        assert!(result);
        let result = client
            .check_with_context("user:bob", "viewer", "doc:1", &context)
            .await
            .unwrap();
        assert!(!result);
    }

    #[test]
    fn test_in_memory_client_clone() {
        let client = InMemoryClient::new();
        client.write(Relationship::new("doc:1", "viewer", "user:alice"));
        let cloned = client.clone();
        cloned.write(Relationship::new("doc:2", "viewer", "user:bob"));
        // Clones share one store, so writes through either are visible to both.
        assert_eq!(client.len(), 2);
        assert_eq!(cloned.len(), 2);
    }
}