use anyhow::Result;
use crate::multitenancy::TenantContext;
/// One access-control violation found while checking a query against a
/// tenant's allow-list of custom types.
#[derive(Debug, Clone)]
#[derive(PartialEq, Eq)]
pub struct AccessViolation {
    /// What was rejected: a bare type name when produced by
    /// `validate_tenant_access`, or `Type.field` when produced by
    /// `validate_field_access`.
    pub field_path: String,
    /// Human-readable explanation, including the tenant id.
    pub reason: String,
}
/// Stateless checker/filter that restricts queries to the types and fields a
/// tenant's configuration allows. It carries no data; all policy is read from
/// the tenant context passed to each method.
#[derive(Default)]
#[derive(Debug)]
pub struct TenantQueryFilter;
impl TenantQueryFilter {
    /// Creates a new filter. The filter holds no state; all policy comes from
    /// the [`TenantContext`] passed to each method.
    pub fn new() -> Self {
        Self
    }

    /// Returns `query` with selections of disallowed types removed.
    ///
    /// The tenant's allow-list is the set of `type_name`s in
    /// `context.config.custom_types`. An empty list means the tenant is
    /// unrestricted, and `query` is returned unchanged.
    ///
    /// Otherwise the query is processed line by line. A line whose leading
    /// token is a type name (see `extract_type_name`) that is not on the
    /// allow-list is dropped. The selection set it opens is dropped with it,
    /// up to and including the matching `}`. This keeps the result
    /// brace-balanced, and the disallowed type's fields are not left behind
    /// under the enclosing selection. All other lines are kept and re-joined
    /// with `\n`, so a trailing newline and `\r\n` line endings are not
    /// preserved on this path.
    ///
    /// This is a line-oriented heuristic, not a GraphQL parser. A type
    /// selection that does not start its own line is not recognised and is
    /// left in place. Examples: `{ Admin { secret } }` written on one line, or
    /// an aliased `a: Admin {`.
    ///
    /// # Errors
    ///
    /// Never returns an error at present; the `Result` is part of the public
    /// signature.
    pub fn filter_query(&self, query: &str, context: &TenantContext) -> Result<String> {
        let allowed = &context.config.custom_types;
        if allowed.is_empty() {
            return Ok(query.to_string());
        }
        let allowed_names: Vec<&str> = allowed.iter().map(|t| t.type_name.as_str()).collect();

        let mut kept: Vec<&str> = Vec::new();
        // Number of `{` still open inside a selection set that is being
        // dropped; while positive, every line belongs to that set and is skipped.
        let mut skip_depth: i64 = 0;
        for line in query.lines() {
            let trimmed = line.trim();
            if skip_depth > 0 {
                skip_depth = (skip_depth + Self::brace_delta(trimmed)).max(0);
                continue;
            }
            match Self::extract_type_name(trimmed) {
                Some(type_name) if !allowed_names.contains(&type_name) => {
                    // Drop the header line. If it leaves a `{` open, also drop
                    // the following lines until that selection set is closed.
                    skip_depth = Self::brace_delta(trimmed).max(0);
                }
                // Blank lines, lone braces, comments, lowercase fields and
                // allowed types are all kept verbatim.
                _ => kept.push(line),
            }
        }
        Ok(kept.join("\n"))
    }

    /// Reports every line of `query` that starts with a type name the tenant
    /// is not allowed to query.
    ///
    /// Uses the same allow-list and line heuristic as [`Self::filter_query`]:
    /// - An empty `custom_types` list yields no violations.
    /// - Selections that do not start their own line are not inspected.
    ///
    /// Each violation's `field_path` is the bare type name. Every offending
    /// line is reported independently, including lines nested inside another
    /// disallowed selection.
    pub fn validate_tenant_access(
        &self,
        query: &str,
        context: &TenantContext,
    ) -> Vec<AccessViolation> {
        let allowed = &context.config.custom_types;
        if allowed.is_empty() {
            return vec![];
        }
        let allowed_names: Vec<&str> = allowed.iter().map(|t| t.type_name.as_str()).collect();
        query
            .lines()
            .filter_map(|line| Self::extract_type_name(line.trim()))
            .filter(|type_name| !allowed_names.contains(type_name))
            .map(|type_name| AccessViolation {
                field_path: type_name.to_string(),
                reason: format!(
                    "Type '{}' is not in the allowed types list for tenant '{}'",
                    type_name, context.tenant_id
                ),
            })
            .collect()
    }

    /// Checks whether the tenant may access `type_name.field_name`.
    ///
    /// Returns `None` when access is allowed. That happens in two cases:
    /// - The tenant's `custom_types` list is empty (unrestricted).
    /// - `type_name` is on the list and declares a field named `field_name`.
    ///
    /// Otherwise returns a violation whose `field_path` is
    /// `"{type_name}.{field_name}"`. Its reason says whether the type itself
    /// or only the field is disallowed.
    pub fn validate_field_access(
        &self,
        type_name: &str,
        field_name: &str,
        context: &TenantContext,
    ) -> Option<AccessViolation> {
        let allowed = &context.config.custom_types;
        if allowed.is_empty() {
            return None;
        }
        let type_def = allowed.iter().find(|t| t.type_name == type_name);
        match type_def {
            None => Some(AccessViolation {
                field_path: format!("{type_name}.{field_name}"),
                reason: format!(
                    "Type '{type_name}' is not allowed for tenant '{}'",
                    context.tenant_id
                ),
            }),
            Some(td) => {
                let field_allowed = td.fields.iter().any(|f| f.field_name == field_name);
                if field_allowed {
                    None
                } else {
                    Some(AccessViolation {
                        field_path: format!("{type_name}.{field_name}"),
                        reason: format!(
                            "Field '{field_name}' on type '{type_name}' is not allowed for tenant '{}'",
                            context.tenant_id
                        ),
                    })
                }
            }
        }
    }

    /// Returns the type-like name that begins `line`, if any.
    ///
    /// Only the first whitespace-separated token is considered. It counts as a
    /// type name only if its first character is uppercase. Only its leading
    /// run of alphanumeric/`_` characters is returned, so `User {`, `User{`
    /// and `User(id: 1) {` all yield `User`.
    fn extract_type_name(line: &str) -> Option<&str> {
        let token = line.split_whitespace().next()?;
        if !token.starts_with(|c: char| c.is_uppercase()) {
            return None;
        }
        // The first char is uppercase, hence alphanumeric, so `end > 0`.
        let end = token
            .char_indices()
            .find(|&(_, c)| !(c.is_alphanumeric() || c == '_'))
            .map_or(token.len(), |(i, _)| i);
        Some(&token[..end])
    }

    /// Net change in `{` nesting contributed by one line: each `{` counts +1
    /// and each `}` counts -1.
    ///
    /// Braces inside double-quoted strings (honouring `\` escapes) and after
    /// a `#` comment marker are ignored. Strings that span several lines are
    /// not tracked.
    fn brace_delta(line: &str) -> i64 {
        let mut delta = 0;
        let mut in_string = false;
        let mut escaped = false;
        for c in line.chars() {
            if in_string {
                if escaped {
                    escaped = false;
                } else if c == '\\' {
                    escaped = true;
                } else if c == '"' {
                    in_string = false;
                }
                continue;
            }
            match c {
                '"' => in_string = true,
                // The rest of the line is a comment.
                '#' => break,
                '{' => delta += 1,
                '}' => delta -= 1,
                _ => {}
            }
        }
        delta
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::multitenancy::{TenantConfig, TenantContext, TenantCustomType, TenantField};
    use std::sync::Arc;

    /// Builds a context for tenant `test-tenant` whose allow-list holds one
    /// custom type per entry in `type_names`. Each type declares the fields
    /// `id` and `name`. An empty slice yields an unrestricted (open-policy)
    /// tenant.
    fn make_context(type_names: &[&str]) -> TenantContext {
        let custom_types: Vec<TenantCustomType> = type_names
            .iter()
            .map(|name| TenantCustomType {
                type_name: name.to_string(),
                rdf_class: format!("http://example.org/{name}"),
                fields: vec![
                    TenantField {
                        field_name: "id".to_string(),
                        rdf_predicate: "http://example.org/id".to_string(),
                        field_type: "ID".to_string(),
                        is_required: true,
                        is_list: false,
                    },
                    TenantField {
                        field_name: "name".to_string(),
                        rdf_predicate: "http://example.org/name".to_string(),
                        field_type: "String".to_string(),
                        is_required: false,
                        is_list: false,
                    },
                ],
            })
            .collect();
        let config = TenantConfig {
            tenant_id: "test-tenant".to_string(),
            display_name: "Test Tenant".to_string(),
            datasets: vec![],
            max_query_depth: 10,
            max_query_complexity: 1000,
            rate_limit_rpm: 60,
            allowed_operations: vec![crate::multitenancy::TenantOperation::Query],
            custom_types,
        };
        TenantContext::new(
            "test-tenant",
            Arc::new(config),
            "req-test".to_string(),
            None,
        )
    }

    #[test]
    fn test_filter_query_open_policy() {
        let filter = TenantQueryFilter::new();
        let ctx = make_context(&[]);
        let query = "{ User { id name } }";
        let filtered = filter
            .filter_query(query, &ctx)
            .expect("open-policy filtering should succeed");
        assert_eq!(filtered, query);
    }

    #[test]
    fn test_filter_query_removes_disallowed_type() {
        let filter = TenantQueryFilter::new();
        let ctx = make_context(&["User"]);
        let query = "query {\n User {\n id\n }\n Admin {\n secret\n }\n}";
        let filtered = filter
            .filter_query(query, &ctx)
            .expect("filtering should succeed");
        assert!(!filtered.contains("Admin"), "Admin should be stripped");
        assert!(filtered.contains("User"), "User should remain");
        assert!(filtered.starts_with("query {"), "operation header should remain");
    }

    #[test]
    fn test_validate_tenant_access_no_violations() {
        let filter = TenantQueryFilter::new();
        let ctx = make_context(&["User", "Product"]);
        let query = "{\n User {\n id\n }\n}";
        let violations = filter.validate_tenant_access(query, &ctx);
        assert!(violations.is_empty(), "unexpected violations: {violations:?}");
    }

    #[test]
    fn test_validate_tenant_access_detects_violation() {
        let filter = TenantQueryFilter::new();
        let ctx = make_context(&["User"]);
        let query = "{\n User {\n id\n }\n Admin {\n secret\n }\n}";
        let violations = filter.validate_tenant_access(query, &ctx);
        assert_eq!(violations.len(), 1, "only Admin is disallowed: {violations:?}");
        assert_eq!(violations[0].field_path, "Admin");
    }

    #[test]
    fn test_validate_field_access_allowed() {
        let filter = TenantQueryFilter::new();
        let ctx = make_context(&["User"]);
        assert!(filter.validate_field_access("User", "name", &ctx).is_none());
        assert!(filter.validate_field_access("User", "id", &ctx).is_none());
    }

    #[test]
    fn test_validate_field_access_disallowed_field() {
        let filter = TenantQueryFilter::new();
        let ctx = make_context(&["User"]);
        let v = filter
            .validate_field_access("User", "passwordHash", &ctx)
            .expect("passwordHash is not declared on User");
        assert_eq!(v.field_path, "User.passwordHash");
        assert!(v.reason.contains("not allowed"));
        assert!(v.reason.contains("Field 'passwordHash'"));
    }

    #[test]
    fn test_validate_field_access_disallowed_type() {
        let filter = TenantQueryFilter::new();
        let ctx = make_context(&["User"]);
        let v = filter
            .validate_field_access("Admin", "id", &ctx)
            .expect("Admin is not on the allow-list");
        assert_eq!(v.field_path, "Admin.id");
        assert!(v.reason.starts_with("Type 'Admin'"));
    }

    #[test]
    fn test_validate_open_policy_no_violations() {
        let filter = TenantQueryFilter::new();
        let ctx = make_context(&[]);
        let violations = filter.validate_tenant_access("{ Admin { secret } }", &ctx);
        assert!(violations.is_empty());
    }

    #[test]
    fn test_violation_reason_mentions_tenant() {
        let filter = TenantQueryFilter::new();
        let ctx = make_context(&["User"]);
        let violations = filter.validate_tenant_access("Admin {", &ctx);
        assert_eq!(violations.len(), 1);
        assert!(violations[0].reason.contains("test-tenant"));
    }

    #[test]
    fn test_multiple_violations_detected() {
        let filter = TenantQueryFilter::new();
        let ctx = make_context(&["User"]);
        let query = "{\n Admin {\n x\n }\n Superuser {\n y\n }\n}";
        let violations = filter.validate_tenant_access(query, &ctx);
        let paths: Vec<&str> = violations.iter().map(|v| v.field_path.as_str()).collect();
        assert_eq!(paths, ["Admin", "Superuser"]);
    }
}