use serde::{Deserialize, Serialize};
/// Capability settings: `allow_*` permissions plus optional `max_*` resource limits.
///
/// Every field has a serde default, so any subset may be supplied in the
/// serialized form (the tests in this file load it from a TOML `[capabilities]`
/// table). Omitted fields take the same values as
/// [`SurrealismCapabilities::default`]: everything disallowed, no limits set,
/// and `strict_timeout` enabled. This type only carries the values; nothing in
/// this file enforces them.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct SurrealismCapabilities {
/// Whether scripting is permitted. Defaults to `false`.
#[serde(default)]
pub allow_scripting: bool,
/// Whether arbitrary queries are permitted. Defaults to `false`.
#[serde(default)]
pub allow_arbitrary_queries: bool,
/// Which functions may be called. Serialized as a list of patterns
/// (`[]` = none, a list containing `"*"` = all). Defaults to
/// [`FunctionTargets::None`].
#[serde(default)]
pub allow_functions: FunctionTargets,
/// Permitted network targets; defaults to empty.
/// NOTE(review): the entry format is not validated here — confirm with the consumer.
#[serde(default)]
pub allow_net: Vec<String>,
/// Optional memory cap in bytes; `None` when unset.
#[serde(default)]
pub max_memory_bytes: Option<usize>,
/// Optional execution-time cap, serialized as whole milliseconds via
/// `optional_duration`; `None` when unset.
#[serde(default, with = "optional_duration")]
pub max_execution_time: Option<std::time::Duration>,
/// Optional pool-size cap; `None` when unset.
/// NOTE(review): what is being pooled is defined by the consumer, not here.
#[serde(default)]
pub max_pool_size: Option<usize>,
/// Optional cap on the number of key/value entries; `None` when unset.
#[serde(default)]
pub max_kv_entries: Option<usize>,
/// Optional cap on the size of a single key/value value in bytes; `None` when unset.
#[serde(default)]
pub max_kv_value_bytes: Option<usize>,
/// Defaults to `true` both when omitted from serialized input and in `Default`.
/// NOTE(review): the exact timeout semantics are implemented by the consumer.
#[serde(default = "default_true")]
pub strict_timeout: bool,
}
/// Serde `default = "..."` helper for boolean fields that must be `true` when absent.
fn default_true() -> bool { true }
impl Default for SurrealismCapabilities {
    /// Builds the most restrictive configuration.
    ///
    /// Nothing is allowed and no resource limits are set. `strict_timeout` is
    /// the only flag that starts enabled, which mirrors the serde default for
    /// that field.
    fn default() -> Self {
        Self {
            // The only flag that defaults to on.
            strict_timeout: true,

            // Permissions: deny everything.
            allow_scripting: false,
            allow_arbitrary_queries: false,
            allow_functions: FunctionTargets::default(),
            allow_net: Vec::default(),

            // Limits: all left unset.
            max_memory_bytes: None,
            max_execution_time: None,
            max_pool_size: None,
            max_kv_entries: None,
            max_kv_value_bytes: None,
        }
    }
}
/// Which functions may be called.
///
/// Serialized as a list of patterns: `[]` is [`FunctionTargets::None`] and a
/// list containing `"*"` is [`FunctionTargets::All`]. Any other list is
/// [`FunctionTargets::Some`] (see the `Serialize`/`Deserialize` impls).
///
/// Equality is structural. Two `Some` values compare their pattern lists
/// element by element, in order.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub enum FunctionTargets {
    /// No function may be called. This is the default.
    #[default]
    None,
    /// Every function may be called.
    All,
    /// Only functions matched by at least one of these patterns may be called.
    Some(Vec<String>),
}
impl FunctionTargets {
    /// Reports whether the function named `fnc` may be called.
    ///
    /// `None` rejects every name and `All` accepts every name. `Some` accepts
    /// `fnc` when at least one of its patterns matches it according to
    /// `func_pattern_matches`.
    pub fn allows(&self, fnc: &str) -> bool {
        let patterns = match self {
            Self::All => return true,
            Self::None => return false,
            Self::Some(patterns) => patterns,
        };
        patterns.iter().any(|pattern| func_pattern_matches(pattern, fnc))
    }
}
/// Returns `true` when the function name `fnc` is matched by `pattern`.
///
/// Supported pattern forms:
/// - `family::*` matches `family` itself and any name that starts with
///   `family::`. `family` may span several segments, e.g. `string::is::*`
///   matches `string::is::alphanum`.
/// - `family::name` (contains `::`, no trailing `*`) matches only the exact
///   name.
/// - `family` (no `::`) matches any name whose first `::`-separated segment is
///   `family`.
fn func_pattern_matches(pattern: &str, fnc: &str) -> bool {
    if let Some(family) = pattern.strip_suffix("::*") {
        // Prefix match on whole `::` segments. `http::*` must not match
        // `httpx::get`, while multi-segment families such as `string::is::*`
        // must still match. Comparing only the first segment of `fnc` made
        // the latter impossible.
        fnc.strip_prefix(family)
            .is_some_and(|rest| rest.is_empty() || rest.starts_with("::"))
    } else if pattern.contains("::") {
        // Splitting both strings at their first `::` and comparing the halves
        // is the same as comparing the full strings.
        pattern == fnc
    } else {
        let head = fnc.split_once("::").map_or(fnc, |(head, _)| head);
        head == pattern
    }
}
impl Serialize for FunctionTargets {
    /// Encodes the targets as a pattern list that `Deserialize` parses back.
    ///
    /// `None` becomes `[]` and `All` becomes `["*"]`. `Some` becomes its
    /// patterns, in order.
    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        // `collect_seq` is exactly what `Vec<T>::serialize` delegates to, so the
        // output (including the exact length hint) is unchanged.
        match self {
            Self::Some(patterns) => serializer.collect_seq(patterns),
            Self::All => serializer.collect_seq(std::iter::once("*")),
            Self::None => serializer.collect_seq(std::iter::empty::<&str>()),
        }
    }
}
impl<'de> Deserialize<'de> for FunctionTargets {
    /// Parses a list of function patterns.
    ///
    /// - `[]` becomes [`FunctionTargets::None`].
    /// - Any list containing `"*"` becomes [`FunctionTargets::All`]; the other
    ///   entries are ignored.
    /// - Anything else becomes [`FunctionTargets::Some`], with the entries kept
    ///   verbatim.
    ///
    /// # Errors
    /// Propagates the deserializer's error if the input is not a sequence of strings.
    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        let entries: Vec<String> = Vec::deserialize(deserializer)?;
        // A single `*` check covers both `["*"]` and longer lists containing
        // `*`. The former separate `len() == 1` case was redundant.
        if entries.iter().any(|e| e == "*") {
            Ok(Self::All)
        } else if entries.is_empty() {
            Ok(Self::None)
        } else {
            Ok(Self::Some(entries))
        }
    }
}
mod optional_duration {
use std::time::Duration;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
pub fn serialize<S: Serializer>(val: &Option<Duration>, s: S) -> Result<S::Ok, S::Error> {
match val {
Some(d) => d.as_millis().serialize(s),
None => s.serialize_none(),
}
}
pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<Option<Duration>, D::Error> {
let ms: Option<u64> = Option::deserialize(d)?;
Ok(ms.map(Duration::from_millis))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Deserializes a TOML document whose `[capabilities]` table holds a
    /// [`SurrealismCapabilities`], panicking on malformed input.
    fn parse_capabilities(toml_str: &str) -> SurrealismCapabilities {
        #[derive(Deserialize)]
        struct Wrapper {
            capabilities: SurrealismCapabilities,
        }
        toml::from_str::<Wrapper>(toml_str).unwrap().capabilities
    }

    // ---- func_pattern_matches ----------------------------------------------

    #[test]
    fn func_pattern_exact_match() {
        let pattern = "http::get";
        assert!(func_pattern_matches(pattern, "http::get"));
        assert!(!func_pattern_matches(pattern, "http::post"));
        assert!(!func_pattern_matches(pattern, "string::len"));
    }

    #[test]
    fn func_pattern_family_wildcard() {
        for name in ["http::get", "http::post"] {
            assert!(func_pattern_matches("http::*", name));
        }
        assert!(!func_pattern_matches("http::*", "string::len"));
    }

    #[test]
    fn func_pattern_bare_family() {
        for name in ["http::get", "http::post"] {
            assert!(func_pattern_matches("http", name));
        }
        assert!(!func_pattern_matches("http", "string::len"));
        assert!(func_pattern_matches("fn", "fn::user_exists"));
    }

    #[test]
    fn func_pattern_fn_prefix() {
        assert!(func_pattern_matches("fn::*", "fn::user_exists"));
        assert!(func_pattern_matches("fn::user_exists", "fn::user_exists"));
        assert!(!func_pattern_matches("fn::user_exists", "fn::other"));
    }

    // ---- FunctionTargets::allows -------------------------------------------

    #[test]
    fn function_targets_none_denies_all() {
        let targets = FunctionTargets::None;
        for name in ["http::get", "fn::anything"] {
            assert!(!targets.allows(name));
        }
    }

    #[test]
    fn function_targets_all_allows_all() {
        let targets = FunctionTargets::All;
        for name in ["http::get", "fn::anything", "string::len"] {
            assert!(targets.allows(name));
        }
    }

    #[test]
    fn function_targets_some_patterns() {
        let targets = FunctionTargets::Some(vec!["http::*".into(), "fn::user_exists".into()]);
        for name in ["http::get", "http::post", "fn::user_exists"] {
            assert!(targets.allows(name));
        }
        for name in ["fn::other", "string::len"] {
            assert!(!targets.allows(name));
        }
    }

    // ---- FunctionTargets serde ---------------------------------------------

    #[test]
    fn function_targets_serde_empty_is_none() {
        let caps = parse_capabilities(
            r#"
[capabilities]
allow_functions = []
"#,
        );
        assert!(matches!(caps.allow_functions, FunctionTargets::None));
    }

    #[test]
    fn function_targets_serde_star_is_all() {
        let caps = parse_capabilities(
            r#"
[capabilities]
allow_functions = ["*"]
"#,
        );
        assert!(matches!(caps.allow_functions, FunctionTargets::All));
    }

    #[test]
    fn function_targets_serde_patterns() {
        let caps = parse_capabilities(
            r#"
[capabilities]
allow_functions = ["http::*", "fn::user_exists"]
"#,
        );
        let targets = &caps.allow_functions;
        assert!(matches!(targets, FunctionTargets::Some(_)));
        assert!(targets.allows("http::get"));
        assert!(targets.allows("fn::user_exists"));
        assert!(!targets.allows("string::len"));
    }

    #[test]
    fn function_targets_serde_omitted_is_none() {
        let caps = parse_capabilities(
            r#"
[capabilities]
allow_scripting = false
"#,
        );
        assert!(matches!(caps.allow_functions, FunctionTargets::None));
    }

    #[test]
    fn function_targets_roundtrip() {
        #[derive(Serialize, Deserialize)]
        struct Wrapper {
            targets: FunctionTargets,
        }
        let original = Wrapper {
            targets: FunctionTargets::Some(vec!["http::*".into(), "fn::check".into()]),
        };
        let text = toml::to_string(&original).unwrap();
        let restored: Wrapper = toml::from_str(&text).unwrap();
        let targets = &restored.targets;
        assert!(targets.allows("http::get"));
        assert!(targets.allows("fn::check"));
        assert!(!targets.allows("string::len"));
    }

    // ---- limits and timeouts -----------------------------------------------

    #[test]
    fn kv_limits_parse() {
        let caps = parse_capabilities(
            r#"
[capabilities]
max_kv_entries = 1000
max_kv_value_bytes = 65536
"#,
        );
        assert_eq!(caps.max_kv_entries, Some(1000));
        assert_eq!(caps.max_kv_value_bytes, Some(65536));
    }

    #[test]
    fn duration_serde() {
        let caps = parse_capabilities(
            r#"
[capabilities]
max_execution_time = 5000
"#,
        );
        assert_eq!(caps.max_execution_time, Some(std::time::Duration::from_millis(5000)));
    }

    #[test]
    fn strict_timeout_defaults_true() {
        assert!(SurrealismCapabilities::default().strict_timeout);
    }

    #[test]
    fn strict_timeout_serde_default_true() {
        let caps = parse_capabilities(
            r#"
[capabilities]
allow_scripting = false
"#,
        );
        assert!(caps.strict_timeout);
    }

    #[test]
    fn strict_timeout_serde_false() {
        let caps = parse_capabilities(
            r#"
[capabilities]
strict_timeout = false
"#,
        );
        assert!(!caps.strict_timeout);
    }
}