use async_trait::async_trait;
use casbin::{Adapter, Filter, Model, Result as CasbinResult};
use surrealdb::{Surreal, engine::any::Any};
use surrealdb_types::{RecordId, SurrealValue};
/// Default name of the table that stores casbin rules; used by [`SurrealAdapter::new`].
pub const TABLE: &str = "casbin_rule";
/// One stored casbin rule: its section, policy type and up to six positional values.
#[derive(Debug, Clone, SurrealValue)]
struct CasbinRule {
// Record id; `None` for rows built locally via `CasbinRule::new`
// (presumably filled in when rows are read back from the database).
id: Option<RecordId>,
// Model section the rule belongs to (this adapter writes `"p"` or `"g"`).
sec: String,
// Assertion key inside the section (the model's section-map key).
ptype: String,
// Rule values by position; `None` when the rule has fewer values.
v0: Option<String>,
v1: Option<String>,
v2: Option<String>,
v3: Option<String>,
v4: Option<String>,
v5: Option<String>,
}
impl CasbinRule {
    /// Builds an unsaved row (`id: None`) for `rule` in section `sec` with policy
    /// type `ptype`.
    ///
    /// Values map positionally onto `v0`..`v5`; missing positions become `None`
    /// and any values beyond the sixth are dropped.
    fn new(sec: &str, ptype: &str, rule: &[String]) -> Self {
        let get = |i: usize| rule.get(i).cloned();
        Self {
            id: None,
            sec: sec.to_owned(),
            ptype: ptype.to_owned(),
            v0: get(0),
            v1: get(1),
            v2: get(2),
            v3: get(3),
            v4: get(4),
            v5: get(5),
        }
    }

    /// Converts the stored columns back into a casbin rule vector.
    ///
    /// Column positions are preserved: trailing unset columns are trimmed, and an
    /// unset column that is followed by a set one is emitted as an empty string.
    /// Skipping it instead would shift every later value into the wrong position.
    fn to_rule(&self) -> Vec<String> {
        let fields = [&self.v0, &self.v1, &self.v2, &self.v3, &self.v4, &self.v5];
        // Number of leading columns up to and including the last set one.
        let len = fields
            .iter()
            .rposition(|v| v.is_some())
            .map_or(0, |last| last + 1);
        fields[..len]
            .iter()
            .map(|v| v.as_deref().unwrap_or_default().to_owned())
            .collect()
    }

    /// Binds this row's `v0`..`v5` to the `$v0`..`$v5` parameters of `q`.
    ///
    /// Unset columns are bound as `None`.
    fn bind_values<'a>(
        &self,
        q: surrealdb::method::Query<'a, Any>,
    ) -> surrealdb::method::Query<'a, Any> {
        q.bind(("v0", self.v0.clone()))
            .bind(("v1", self.v1.clone()))
            .bind(("v2", self.v2.clone()))
            .bind(("v3", self.v3.clone()))
            .bind(("v4", self.v4.clone()))
            .bind(("v5", self.v5.clone()))
    }
}
/// Inserts one stored rule into the matching assertion of `m`.
///
/// Rows with no values, or whose section / policy type is not present in the
/// model, are ignored.
fn load_policy_line(m: &mut dyn Model, rule: &CasbinRule) {
    let policy_values = rule.to_rule();
    if policy_values.is_empty() {
        return;
    }
    let Some(target) = m
        .get_mut_model()
        .get_mut(&rule.sec)
        .and_then(|section| section.get_mut(&rule.ptype))
    else {
        return;
    };
    target.get_mut_policy().insert(policy_values);
}
/// Casbin [`Adapter`] that persists policy rules in a SurrealDB table.
pub struct SurrealAdapter {
// Database handle used for every query.
db: Surreal<Any>,
// Name of the rule table (`TABLE` when built with `new`).
table: String,
// Set by `load_filtered_policy`, cleared by `load_policy`; reported by `is_filtered`.
is_filtered: bool,
}
impl SurrealAdapter {
    /// Creates an adapter that stores rules in the default [`TABLE`].
    pub fn new(db: Surreal<Any>) -> Self {
        Self::with_table(db, TABLE)
    }

    /// Creates an adapter that stores rules in the table named `table`.
    pub fn with_table(db: Surreal<Any>, table: impl Into<String>) -> Self {
        Self {
            db,
            table: table.into(),
            is_filtered: false,
        }
    }

    /// Defines the rule table as `SCHEMALESS` if it does not exist yet.
    ///
    /// SurrealQL's `DEFINE TABLE` expects an identifier rather than a bound
    /// `$parameter`, so the configured name is embedded as an escaped identifier
    /// (see `quoted_table`) instead of being passed via `bind`.
    ///
    /// # Errors
    ///
    /// Returns the database error if the request fails or the statement errors.
    pub async fn create_table(&self) -> Result<(), surrealdb::Error> {
        let statement = format!(
            "DEFINE TABLE IF NOT EXISTS {} SCHEMALESS;",
            self.quoted_table()
        );
        self.db.query(&statement).await?.check()?;
        Ok(())
    }

    /// Returns the table name as a backtick-quoted SurrealQL identifier.
    ///
    /// Backslashes and backticks are escaped so an arbitrary, caller-supplied
    /// name cannot terminate the identifier and inject extra SurrealQL.
    fn quoted_table(&self) -> String {
        let mut quoted = String::with_capacity(self.table.len() + 2);
        quoted.push('`');
        for ch in self.table.chars() {
            if matches!(ch, '`' | '\\') {
                quoted.push('\\');
            }
            quoted.push(ch);
        }
        quoted.push('`');
        quoted
    }
}
#[async_trait]
impl Adapter for SurrealAdapter {
    /// Loads every stored rule into `m` and marks the adapter as unfiltered.
    ///
    /// Rows whose section or policy type is not defined in `m` are skipped.
    async fn load_policy(&mut self, m: &mut dyn Model) -> CasbinResult<()> {
        for rule in self.get_all_rules().await? {
            load_policy_line(m, &rule);
        }
        self.is_filtered = false;
        Ok(())
    }

    /// Loads only the rules selected by `f` into `m` and marks the adapter as
    /// filtered.
    ///
    /// `f.p` filters the `p` section and `f.g` the `g` section. Entry `i` of a
    /// filter must equal column `v{i}`; empty entries match anything. A section
    /// whose filter has no non-empty entry is loaded in full.
    async fn load_filtered_policy<'a>(
        &mut self,
        m: &mut dyn Model,
        f: Filter<'a>,
    ) -> CasbinResult<()> {
        for (sec, filter) in [("p", &f.p), ("g", &f.g)] {
            let has_filter = filter.iter().any(|fv| !fv.is_empty());
            let rules = if has_filter {
                self.get_filtered_rules(sec, filter).await?
            } else {
                self.get_rules_by_sec(sec).await?
            };
            for rule in &rules {
                load_policy_line(m, rule);
            }
        }
        self.is_filtered = true;
        Ok(())
    }

    /// Overwrites the stored policy with the `p` and `g` sections of `m`.
    ///
    /// The table is cleared first and re-filled by a separate request. The two
    /// steps are not atomic: if the insert fails, the table is left empty.
    async fn save_policy(&mut self, m: &mut dyn Model) -> CasbinResult<()> {
        self.clear_policy().await?;
        let mut all_rules: Vec<CasbinRule> = Vec::new();
        for sec in ["p", "g"] {
            if let Some(sec_map) = m.get_model().get(sec) {
                for (ptype, assertion) in sec_map {
                    for policy in assertion.get_policy() {
                        all_rules.push(CasbinRule::new(sec, ptype, policy));
                    }
                }
            }
        }
        if !all_rules.is_empty() {
            self.insert_entries(all_rules).await?;
        }
        Ok(())
    }

    /// Deletes every row of the rule table.
    async fn clear_policy(&mut self) -> CasbinResult<()> {
        self.db
            .query("DELETE type::table($table);")
            .bind(("table", self.table.clone()))
            .await
            .map_err(io_err)?
            .check()
            .map_err(io_err)?;
        Ok(())
    }

    /// Returns whether the last load went through `load_filtered_policy`.
    fn is_filtered(&self) -> bool {
        self.is_filtered
    }

    /// Stores `rule` unless an identical row already exists.
    ///
    /// Returns `Ok(false)` for a duplicate. The existence check and the insert
    /// are separate requests.
    async fn add_policy(
        &mut self,
        sec: &str,
        ptype: &str,
        rule: Vec<String>,
    ) -> CasbinResult<bool> {
        if self.rule_exists(sec, ptype, &rule).await? {
            return Ok(false);
        }
        let entry = CasbinRule::new(sec, ptype, &rule);
        let _: Option<CasbinRule> = self
            .db
            .create(&*self.table)
            .content(entry)
            .await
            .map_err(io_err)?;
        Ok(true)
    }

    /// Stores all of `rules`, but only if none of them already exists.
    ///
    /// The duplicate check is all-or-nothing. An empty `rules` is a no-op that
    /// returns `Ok(false)`, matching `remove_policies`, without touching the
    /// database.
    async fn add_policies(
        &mut self,
        sec: &str,
        ptype: &str,
        rules: Vec<Vec<String>>,
    ) -> CasbinResult<bool> {
        if rules.is_empty() {
            return Ok(false);
        }
        if self.any_rules_exist(sec, ptype, &rules).await? {
            return Ok(false);
        }
        let entries: Vec<CasbinRule> = rules
            .iter()
            .map(|r| CasbinRule::new(sec, ptype, r))
            .collect();
        self.insert_entries(entries).await?;
        Ok(true)
    }

    /// Deletes the row exactly matching `rule`; `Ok(true)` if one was removed.
    async fn remove_policy(
        &mut self,
        sec: &str,
        ptype: &str,
        rule: Vec<String>,
    ) -> CasbinResult<bool> {
        self.delete_exact(sec, ptype, &rule).await
    }

    /// Deletes every row exactly matching one of `rules`.
    ///
    /// Returns `Ok(true)` if any row was removed, and `Ok(false)` for empty input.
    async fn remove_policies(
        &mut self,
        sec: &str,
        ptype: &str,
        rules: Vec<Vec<String>>,
    ) -> CasbinResult<bool> {
        if rules.is_empty() {
            return Ok(false);
        }
        self.delete_exact_batch(sec, ptype, &rules).await
    }

    /// Deletes rows of `ptype` whose columns starting at `v{field_index}` match
    /// the non-empty entries of `field_values`.
    async fn remove_filtered_policy(
        &mut self,
        sec: &str,
        ptype: &str,
        field_index: usize,
        field_values: Vec<String>,
    ) -> CasbinResult<bool> {
        self.delete_filtered(sec, ptype, field_index, &field_values)
            .await
    }
}
impl SurrealAdapter {
async fn insert_entries(&self, entries: Vec<CasbinRule>) -> CasbinResult<bool> {
let _: Vec<CasbinRule> = self
.db
.insert(&*self.table)
.content(entries)
.await
.map_err(io_err)?;
Ok(true)
}
async fn get_all_rules(&self) -> CasbinResult<Vec<CasbinRule>> {
self.db.select(&*self.table).await.map_err(io_err)
}
async fn get_rules_by_sec(&self, sec: &str) -> CasbinResult<Vec<CasbinRule>> {
let rules: Vec<CasbinRule> = self
.db
.query("SELECT * FROM type::table($table) WHERE sec = $sec")
.bind(("table", self.table.clone()))
.bind(("sec", sec.to_owned()))
.await
.map_err(io_err)?
.check()
.map_err(io_err)?
.take(0)
.map_err(io_err)?;
Ok(rules)
}
async fn get_filtered_rules(
&self,
sec: &str,
filter: &[&str],
) -> CasbinResult<Vec<CasbinRule>> {
let mut conditions = vec!["sec = $sec".to_owned()];
let mut binds: Vec<(String, String)> = Vec::new();
for (i, fv) in filter.iter().enumerate() {
if !fv.is_empty() {
let param = format!("fv{i}");
conditions.push(format!("v{i} = ${param}"));
binds.push((param, (*fv).to_owned()));
}
}
let query = format!(
"SELECT * FROM type::table($table) WHERE {}",
conditions.join(" AND ")
);
let mut q = self
.db
.query(&query)
.bind(("table", self.table.clone()))
.bind(("sec", sec.to_owned()));
for (k, v) in binds {
q = q.bind((k, v));
}
let rules: Vec<CasbinRule> = q
.await
.map_err(io_err)?
.check()
.map_err(io_err)?
.take(0)
.map_err(io_err)?;
Ok(rules)
}
async fn rule_exists(&self, sec: &str, ptype: &str, rule: &[String]) -> CasbinResult<bool> {
let entry = CasbinRule::new(sec, ptype, rule);
let q = self
.db
.query(
"SELECT * FROM type::table($table)
WHERE sec = $sec AND ptype = $ptype
AND v0 = $v0 AND v1 = $v1 AND v2 = $v2
AND v3 = $v3 AND v4 = $v4 AND v5 = $v5
LIMIT 1",
)
.bind(("table", self.table.clone()))
.bind(("sec", sec.to_owned()))
.bind(("ptype", ptype.to_owned()));
let found: Vec<CasbinRule> = entry
.bind_values(q)
.await
.map_err(io_err)?
.check()
.map_err(io_err)?
.take(0)
.map_err(io_err)?;
Ok(!found.is_empty())
}
async fn any_rules_exist(
&self,
sec: &str,
ptype: &str,
rules: &[Vec<String>],
) -> CasbinResult<bool> {
if rules.is_empty() {
return Ok(false);
}
let mut or_clauses = Vec::new();
let mut binds: Vec<(String, Option<String>)> = Vec::new();
for (ri, rule) in rules.iter().enumerate() {
let entry = CasbinRule::new(sec, ptype, rule);
let fields = [
&entry.v0, &entry.v1, &entry.v2, &entry.v3, &entry.v4, &entry.v5,
];
let mut field_conditions = Vec::new();
for (fi, val) in fields.iter().enumerate() {
let param = format!("r{ri}v{fi}");
field_conditions.push(format!("v{fi} = ${param}"));
binds.push((param, (*val).clone()));
}
or_clauses.push(format!("({})", field_conditions.join(" AND ")));
}
let query = format!(
"SELECT * FROM type::table($table) WHERE sec = $sec AND ptype = $ptype AND ({}) LIMIT 1",
or_clauses.join(" OR ")
);
let mut q = self
.db
.query(&query)
.bind(("table", self.table.clone()))
.bind(("sec", sec.to_owned()))
.bind(("ptype", ptype.to_owned()));
for (k, v) in binds {
q = q.bind((k, v));
}
let found: Vec<CasbinRule> = q
.await
.map_err(io_err)?
.check()
.map_err(io_err)?
.take(0)
.map_err(io_err)?;
Ok(!found.is_empty())
}
async fn delete_exact(&self, sec: &str, ptype: &str, rule: &[String]) -> CasbinResult<bool> {
let entry = CasbinRule::new(sec, ptype, rule);
let q = self
.db
.query(
"DELETE type::table($table)
WHERE sec = $sec AND ptype = $ptype
AND v0 = $v0 AND v1 = $v1 AND v2 = $v2
AND v3 = $v3 AND v4 = $v4 AND v5 = $v5
RETURN BEFORE",
)
.bind(("table", self.table.clone()))
.bind(("sec", sec.to_owned()))
.bind(("ptype", ptype.to_owned()));
let deleted: Vec<CasbinRule> = entry
.bind_values(q)
.await
.map_err(io_err)?
.check()
.map_err(io_err)?
.take(0)
.map_err(io_err)?;
Ok(!deleted.is_empty())
}
async fn delete_exact_batch(
&self,
sec: &str,
ptype: &str,
rules: &[Vec<String>],
) -> CasbinResult<bool> {
let mut or_clauses = Vec::new();
let mut binds: Vec<(String, Option<String>)> = Vec::new();
for (ri, rule) in rules.iter().enumerate() {
let entry = CasbinRule::new(sec, ptype, rule);
let fields = [
&entry.v0, &entry.v1, &entry.v2, &entry.v3, &entry.v4, &entry.v5,
];
let mut field_conditions = Vec::new();
for (fi, val) in fields.iter().enumerate() {
let param = format!("r{ri}v{fi}");
field_conditions.push(format!("v{fi} = ${param}"));
binds.push((param, (*val).clone()));
}
or_clauses.push(format!("({})", field_conditions.join(" AND ")));
}
let query = format!(
"DELETE type::table($table) WHERE sec = $sec AND ptype = $ptype AND ({}) RETURN BEFORE",
or_clauses.join(" OR ")
);
let mut q = self
.db
.query(&query)
.bind(("table", self.table.clone()))
.bind(("sec", sec.to_owned()))
.bind(("ptype", ptype.to_owned()));
for (k, v) in binds {
q = q.bind((k, v));
}
let deleted: Vec<CasbinRule> = q
.await
.map_err(io_err)?
.check()
.map_err(io_err)?
.take(0)
.map_err(io_err)?;
Ok(!deleted.is_empty())
}
async fn delete_filtered(
&self,
sec: &str,
ptype: &str,
field_index: usize,
field_values: &[String],
) -> CasbinResult<bool> {
let mut col_conditions = Vec::new();
let mut binds: Vec<(String, String)> = Vec::new();
for (offset, v) in field_values.iter().enumerate() {
if !v.is_empty() {
let col = field_index + offset;
let param = format!("fv{offset}");
col_conditions.push(format!("v{col} = ${param}"));
binds.push((param, v.clone()));
}
}
let where_clause = if col_conditions.is_empty() {
"sec = $sec AND ptype = $ptype".to_owned()
} else {
format!(
"sec = $sec AND ptype = $ptype AND {}",
col_conditions.join(" AND ")
)
};
let query = format!("DELETE type::table($table) WHERE {where_clause} RETURN BEFORE");
let mut q = self
.db
.query(&query)
.bind(("table", self.table.clone()))
.bind(("sec", sec.to_owned()))
.bind(("ptype", ptype.to_owned()));
for (k, v) in binds {
q = q.bind((k, v));
}
let deleted: Vec<CasbinRule> = q
.await
.map_err(io_err)?
.check()
.map_err(io_err)?
.take(0)
.map_err(io_err)?;
Ok(!deleted.is_empty())
}
}
fn io_err(e: impl std::fmt::Display) -> casbin::Error {
casbin::Error::IoError(std::io::Error::other(e.to_string()))
}