use std::collections::{BTreeMap, BTreeSet};
use std::iter::FromIterator;
use crate::payload::Payload;
use crate::specs::trust_task_discovery::v0_1 as wire;
use crate::type_uri::TypeUri;
/// Framework version that `DiscoveryRegistry::new` advertises in every
/// discovery response, unless it is overridden (`framework_version`) or
/// suppressed (`no_framework_version`).
const DEFAULT_FRAMEWORK_VERSION: &str = "0.1";
/// Returns `true` when `slug` matches a single discovery `pattern`.
///
/// Supported pattern forms:
/// - `"*"` matches every slug.
/// - `"prefix/*"` matches any slug that starts with `prefix/`, at any depth.
///   It does not match the bare `prefix` itself. An empty prefix (`"/*"`)
///   matches nothing.
/// - Any other pattern must equal `slug` exactly. A `*` anywhere else in the
///   pattern is treated as a literal character.
pub fn match_slug(pattern: &str, slug: &str) -> bool {
    if pattern == "*" {
        return true;
    }
    match pattern.strip_suffix("/*") {
        // Check for `prefix` followed by `/` by slicing, rather than building a
        // `prefix/` String, so matching never allocates.
        Some(prefix) if !prefix.is_empty() => matches!(
            slug.strip_prefix(prefix),
            Some(rest) if rest.starts_with('/')
        ),
        // "/*" has an empty prefix and deliberately matches nothing.
        Some(_) => false,
        None => pattern == slug,
    }
}
/// Returns `true` when `slug` satisfies the query `patterns`.
///
/// An empty pattern list means "no filter" and matches every slug. Otherwise
/// the patterns are OR-ed together, each evaluated with [`match_slug`].
pub fn query_matches<S: AsRef<str>>(patterns: &[S], slug: &str) -> bool {
    patterns.is_empty()
        || patterns
            .iter()
            .any(|pattern| match_slug(pattern.as_ref(), slug))
}
/// Registry of the type URIs this endpoint supports. It answers
/// trust-task-discovery queries via `DiscoveryRegistry::respond_to`.
///
/// Note: the derived `Default` leaves `framework_version` unset (`None`),
/// whereas `DiscoveryRegistry::new` sets it to `DEFAULT_FRAMEWORK_VERSION`.
#[derive(Debug, Clone, Default)]
pub struct DiscoveryRegistry {
    /// Registered type URI strings. Stored in a `BTreeSet`, so they are
    /// deduplicated and iterated in sorted order.
    type_uris: BTreeSet<String>,
    /// Extension namespaces advertised per type URI in the expanded response
    /// form, keyed by the same bare URI string that is stored in `type_uris`.
    required_ext: BTreeMap<String, BTreeSet<String>>,
    /// Framework version emitted in responses; `None` omits the field.
    framework_version: Option<String>,
}
impl DiscoveryRegistry {
    /// Creates an empty registry that advertises `DEFAULT_FRAMEWORK_VERSION`
    /// as its framework version.
    pub fn new() -> Self {
        Self {
            framework_version: Some(DEFAULT_FRAMEWORK_VERSION.to_string()),
            ..Self::default()
        }
    }

    /// Overrides the framework version emitted in responses (builder form).
    ///
    /// The value is not validated here. It is parsed in
    /// [`respond_to`](Self::respond_to), which panics if it does not match the
    /// spec's `MAJOR.MINOR` pattern.
    #[must_use]
    pub fn framework_version(mut self, version: impl Into<String>) -> Self {
        self.framework_version = Some(version.into());
        self
    }

    /// Omits the framework version from responses (builder form).
    #[must_use]
    pub fn no_framework_version(mut self) -> Self {
        self.framework_version = None;
        self
    }

    /// Builder form of [`require_ext`](Self::require_ext). Registers `uri` and
    /// records `namespaces` as the extension namespaces it requires.
    #[must_use]
    pub fn with_required_ext<I, S>(mut self, uri: TypeUri, namespaces: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.require_ext(uri, namespaces);
        self
    }

    /// Registers `uri` (by its bare form) and adds `namespaces` to the set of
    /// extension namespaces advertised for it. Repeated calls for the same URI
    /// accumulate namespaces.
    ///
    /// Namespaces are not validated here. They are parsed in
    /// [`respond_to`](Self::respond_to), which panics if one does not match the
    /// reverse-DNS pattern.
    pub fn require_ext<I, S>(&mut self, uri: TypeUri, namespaces: I)
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        let bare = uri.bare().to_string();
        self.type_uris.insert(bare.clone());
        self.required_ext
            .entry(bare)
            .or_default()
            .extend(namespaces.into_iter().map(Into::into));
    }

    /// Builder form of [`register_payload`](Self::register_payload).
    #[must_use]
    pub fn with<P: Payload>(self) -> Self {
        self.with_type_uri(P::type_uri())
    }

    /// Builder form of [`register`](Self::register).
    #[must_use]
    pub fn with_type_uri(mut self, uri: TypeUri) -> Self {
        self.register(uri);
        self
    }

    /// Registers `uri` by its bare string form. Duplicate registrations are
    /// ignored.
    pub fn register(&mut self, uri: TypeUri) {
        self.type_uris.insert(uri.bare().to_string());
    }

    /// Registers the type URI of payload type `P`.
    pub fn register_payload<P: Payload>(&mut self) {
        self.register(P::type_uri());
    }

    /// Registers a raw type URI string verbatim. Unlike
    /// [`register`](Self::register), no bare-form conversion is applied.
    ///
    /// Strings with no extractable slug are stored but never advertised by
    /// [`respond_to`](Self::respond_to). A slug requires a `/spec/` segment
    /// followed by at least one further `/`.
    pub fn register_str(&mut self, uri: impl Into<String>) {
        self.type_uris.insert(uri.into());
    }

    /// Builder form of [`register_str`](Self::register_str).
    #[must_use]
    pub fn with_str(mut self, uri: impl Into<String>) -> Self {
        self.register_str(uri);
        self
    }

    /// Returns every registered type URI, in sorted order.
    pub fn supported_types(&self) -> Vec<&str> {
        self.type_uris.iter().map(String::as_str).collect()
    }

    /// Answers a discovery query.
    ///
    /// A registered URI is included when its slug matches the query patterns
    /// (see [`query_matches`]). The slug is the path between `/spec/` and the
    /// final path segment. Results are in sorted order. URIs with registered
    /// extension namespaces use the expanded object form; all others use the
    /// plain URI form.
    ///
    /// # Panics
    ///
    /// Panics in either of these cases:
    /// - the configured framework version does not match the spec's
    ///   `MAJOR.MINOR` pattern;
    /// - a namespace registered via [`require_ext`](Self::require_ext) or
    ///   [`with_required_ext`](Self::with_required_ext) does not match the
    ///   reverse-DNS pattern.
    pub fn respond_to(&self, query: &wire::Payload) -> wire::Response {
        let patterns: Vec<&str> = query.patterns.iter().map(|p| p.as_str()).collect();
        let supported_types: Vec<wire::ResponseSupportedTypesItem> = self
            .type_uris
            .iter()
            // URIs without a recognisable slug never match, even with an
            // empty (match-everything) pattern list.
            .filter(|uri| match parse_slug(uri) {
                Some(slug) => query_matches(&patterns, slug),
                None => false,
            })
            .map(|uri| self.entry_for(uri))
            .collect();
        let framework_version = self
            .framework_version
            .as_deref()
            .map(|v| {
                v.parse::<wire::ResponseFrameworkVersion>()
                    .expect("framework_version was set to a value that does not match the spec's MAJOR.MINOR pattern")
            });
        wire::Response {
            supported_types,
            framework_version,
        }
    }

    /// Builds the response entry for one registered URI.
    ///
    /// URIs with a non-empty namespace set become the expanded `Object` form
    /// carrying `required_ext`. All others become the plain `Uri` form.
    fn entry_for(&self, uri: &str) -> wire::ResponseSupportedTypesItem {
        match self.required_ext.get(uri) {
            Some(namespaces) if !namespaces.is_empty() => {
                let required_ext: Vec<wire::ResponseSupportedTypesItemObjectRequiredExtItem> =
                    namespaces
                        .iter()
                        .map(|ns| {
                            ns.parse().expect(
                                "required_ext namespace must match the reverse-DNS pattern; \
                                 was set via with_required_ext / require_ext",
                            )
                        })
                        .collect();
                wire::ResponseSupportedTypesItem::Object {
                    type_: uri.to_string(),
                    required_ext: Some(required_ext),
                }
            }
            _ => wire::ResponseSupportedTypesItem::Uri(uri.to_string()),
        }
    }
}
/// Collects raw type URI strings into a registry.
///
/// The registry starts from `DiscoveryRegistry::new()`, so it carries the
/// default framework version. Each string is registered verbatim, exactly as
/// `register_str` does.
impl<S: Into<String>> FromIterator<S> for DiscoveryRegistry {
    fn from_iter<I: IntoIterator<Item = S>>(iter: I) -> Self {
        iter.into_iter()
            .fold(Self::new(), |registry, uri| registry.with_str(uri))
    }
}
/// Extracts the slug from a type URI.
///
/// The slug is the text after the first `/spec/`, up to but excluding the
/// last `/`. For example, `https://trusttasks.org/spec/acl/grant/0.1` yields
/// `acl/grant`.
///
/// Returns `None` if there is no `/spec/`, or no `/` after it.
fn parse_slug(uri: &str) -> Option<&str> {
    let (_, tail) = uri.split_once("/spec/")?;
    tail.rsplit_once('/').map(|(slug, _last_segment)| slug)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn star_matches_anything() {
        assert!(match_slug("*", "acl/grant"));
        assert!(match_slug("*", "kyc-handoff"));
        assert!(match_slug("*", "trust-task-discovery"));
    }

    #[test]
    fn prefix_wildcard_matches_descendants() {
        assert!(match_slug("acl/*", "acl/grant"));
        assert!(match_slug("acl/*", "acl/revoke"));
        assert!(match_slug("acl/*", "acl/grant/sub"));
    }

    #[test]
    fn prefix_wildcard_does_not_match_bare_prefix() {
        assert!(!match_slug("acl/*", "acl"));
        assert!(!match_slug("acl/*", "aclx"));
        assert!(!match_slug("acl/*", "kyc-handoff"));
    }

    #[test]
    fn empty_prefix_wildcard_matches_nothing() {
        assert!(!match_slug("/*", "acl/grant"));
        assert!(!match_slug("/*", "/acl"));
    }

    #[test]
    fn exact_pattern_matches_exact_slug_only() {
        assert!(match_slug("kyc-handoff", "kyc-handoff"));
        assert!(!match_slug("kyc-handoff", "kyc-handoff/v2"));
        assert!(!match_slug("kyc-handoff", "kyc"));
    }

    #[test]
    fn interior_wildcards_are_treated_literally() {
        assert!(!match_slug("acl/*/grant", "acl/grant"));
        assert!(!match_slug("a*b", "ab"));
    }

    #[test]
    fn empty_patterns_match_everything() {
        let patterns: &[&str] = &[];
        assert!(query_matches(patterns, "acl/grant"));
        assert!(query_matches(patterns, "anything"));
    }

    #[test]
    fn or_semantics_across_patterns() {
        let patterns = ["acl/*", "kyc-handoff"];
        assert!(query_matches(&patterns, "acl/grant"));
        assert!(query_matches(&patterns, "kyc-handoff"));
        assert!(!query_matches(&patterns, "consent/give"));
    }

    #[test]
    fn parse_slug_strips_spec_prefix_and_last_segment() {
        assert_eq!(
            parse_slug("https://trusttasks.org/spec/acl/grant/0.1"),
            Some("acl/grant")
        );
        assert_eq!(
            parse_slug("https://trusttasks.org/spec/kyc-handoff/0.1"),
            Some("kyc-handoff")
        );
        assert_eq!(parse_slug("https://trusttasks.org/other/acl/0.1"), None);
        assert_eq!(parse_slug("https://trusttasks.org/spec/acl"), None);
    }

    #[test]
    fn registry_dedupes_and_sorts() {
        let registry = DiscoveryRegistry::new()
            .with_type_uri(TypeUri::canonical("acl/revoke", 0, 1).unwrap())
            .with_type_uri(TypeUri::canonical("acl/grant", 0, 1).unwrap())
            .with_type_uri(TypeUri::canonical("acl/grant", 0, 1).unwrap());
        let types = registry.supported_types();
        assert_eq!(
            types,
            vec![
                "https://trusttasks.org/spec/acl/grant/0.1",
                "https://trusttasks.org/spec/acl/revoke/0.1",
            ]
        );
    }

    #[test]
    fn from_iterator_registers_strings_with_default_framework_version() {
        let registry: DiscoveryRegistry = vec![
            "https://trusttasks.org/spec/acl/revoke/0.1",
            "https://trusttasks.org/spec/acl/grant/0.1",
        ]
        .into_iter()
        .collect();
        assert_eq!(
            registry.supported_types(),
            vec![
                "https://trusttasks.org/spec/acl/grant/0.1",
                "https://trusttasks.org/spec/acl/revoke/0.1",
            ]
        );
        let response = registry.respond_to(&wire::Payload { patterns: vec![] });
        assert_eq!(
            response.framework_version.as_ref().map(|v| v.to_string()),
            Some("0.1".to_string())
        );
    }

    #[test]
    fn str_uris_without_slug_are_not_advertised() {
        let registry = DiscoveryRegistry::new()
            .with_str("https://trusttasks.org/spec/acl/grant/0.1")
            .with_str("urn:not-a-trust-task");
        let response = registry.respond_to(&wire::Payload { patterns: vec![] });
        assert_eq!(
            uris_in(&response),
            vec!["https://trusttasks.org/spec/acl/grant/0.1"]
        );
    }

    #[test]
    fn registry_responds_to_query_with_filtered_subset() {
        use crate::specs::acl::{change_role, grant, list, revoke, show};
        let registry = DiscoveryRegistry::new()
            .with::<grant::v0_1::Payload>()
            .with::<revoke::v0_1::Payload>()
            .with::<show::v0_1::Payload>()
            .with::<list::v0_1::Payload>()
            .with::<change_role::v0_1::Payload>();

        let acl_only = wire::Payload {
            patterns: vec!["acl/*".parse().unwrap()],
        };
        let response = registry.respond_to(&acl_only);
        assert_eq!(response.supported_types.len(), 5);

        let only_grant = wire::Payload {
            patterns: vec!["acl/grant".parse().unwrap()],
        };
        let response = registry.respond_to(&only_grant);
        assert_eq!(
            uris_in(&response),
            vec!["https://trusttasks.org/spec/acl/grant/0.1"]
        );

        let everything = wire::Payload { patterns: vec![] };
        let response = registry.respond_to(&everything);
        assert_eq!(response.supported_types.len(), 5);

        let nothing = wire::Payload {
            patterns: vec!["does-not-exist/*".parse().unwrap()],
        };
        let response = registry.respond_to(&nothing);
        assert!(response.supported_types.is_empty());

        let response = registry.respond_to(&wire::Payload { patterns: vec![] });
        assert_eq!(
            response.framework_version.as_ref().map(|v| v.to_string()),
            Some("0.1".to_string())
        );
    }

    #[test]
    fn no_framework_version_suppresses_field() {
        let registry = DiscoveryRegistry::new().no_framework_version();
        let response = registry.respond_to(&wire::Payload { patterns: vec![] });
        assert!(response.framework_version.is_none());
    }

    #[test]
    fn override_framework_version_is_emitted_verbatim() {
        let registry = DiscoveryRegistry::new().framework_version("0.2");
        let response = registry.respond_to(&wire::Payload { patterns: vec![] });
        assert_eq!(
            response.framework_version.as_ref().map(|v| v.to_string()),
            Some("0.2".to_string())
        );
    }

    #[test]
    fn with_required_ext_advertises_namespace_policy_in_expanded_form() {
        let grant_uri = TypeUri::canonical("acl/grant", 0, 1).unwrap();
        let registry = DiscoveryRegistry::new()
            .with::<crate::specs::acl::revoke::v0_1::Payload>()
            .with_required_ext(grant_uri, ["vnd.affinidi.webvh"]);
        let response = registry.respond_to(&wire::Payload { patterns: vec![] });

        let grant_entry = response
            .supported_types
            .iter()
            .find(|e| uri_of(e) == "https://trusttasks.org/spec/acl/grant/0.1")
            .expect("acl/grant entry present");
        match grant_entry {
            wire::ResponseSupportedTypesItem::Object { required_ext, .. } => {
                let namespaces: Vec<String> = required_ext
                    .as_ref()
                    .expect("requiredExt populated")
                    .iter()
                    .map(|n| n.to_string())
                    .collect();
                assert_eq!(namespaces, vec!["vnd.affinidi.webvh".to_string()]);
            }
            other => panic!("expected expanded Object form, got {other:?}"),
        }

        // Types without required extensions stay in the compact URI form.
        let revoke_entry = response
            .supported_types
            .iter()
            .find(|e| uri_of(e) == "https://trusttasks.org/spec/acl/revoke/0.1")
            .expect("acl/revoke entry present");
        assert!(matches!(
            revoke_entry,
            wire::ResponseSupportedTypesItem::Uri(_)
        ));
    }

    /// Returns the type URI of every entry in `response`, in response order.
    fn uris_in(response: &wire::Response) -> Vec<&str> {
        response.supported_types.iter().map(uri_of).collect()
    }

    /// Returns the type URI of a response entry, whichever form it uses.
    fn uri_of(entry: &wire::ResponseSupportedTypesItem) -> &str {
        match entry {
            wire::ResponseSupportedTypesItem::Uri(s) => s.as_str(),
            wire::ResponseSupportedTypesItem::Object { type_, .. } => type_.as_str(),
        }
    }
}