use crate::error::ScimError;
use crate::providers::ResourceProvider;
use crate::resource::ScimOperation;
use crate::schema::{AttributeDefinition, SchemaRegistry};
use crate::schema_discovery::{AuthenticationScheme, ServiceProviderConfig};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::convert::TryFrom;
/// Complete capability description of a SCIM provider.
///
/// Assembled by `CapabilityDiscovery::discover_capabilities*` and condensed into a
/// [`ServiceProviderConfig`] by `CapabilityDiscovery::generate_service_provider_config`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderCapabilities {
    /// Enabled operations per key. Presumably the keys are resource-type names, matching
    /// `supported_resource_types` — NOTE(review): confirm against callers.
    pub supported_operations: HashMap<String, Vec<ScimOperation>>,
    /// Schema ids; discovery fills this with the ids of every schema in the registry.
    pub supported_schemas: Vec<String>,
    /// Resource-type names; discovery fills this with the keys of the resource-handler map.
    pub supported_resource_types: Vec<String>,
    /// Bulk-request support and limits.
    pub bulk_capabilities: BulkCapabilities,
    /// Filtering support, result limit, operators and per-resource-type filterable attributes.
    pub filter_capabilities: FilterCapabilities,
    /// Page-size defaults and limits.
    pub pagination_capabilities: PaginationCapabilities,
    /// Advertised authentication schemes and related flags.
    pub authentication_capabilities: AuthenticationCapabilities,
    /// ETag, PATCH, password-change, sort and free-form provider-specific capabilities.
    pub extended_capabilities: ExtendedCapabilities,
}
/// Bulk-request capabilities of a provider.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BulkCapabilities {
    /// Whether bulk requests are advertised (becomes `bulk_supported` in the service
    /// provider config).
    pub supported: bool,
    /// Maximum number of operations per bulk request; `None` leaves it unspecified.
    pub max_operations: Option<usize>,
    /// Maximum bulk payload size; `None` leaves it unspecified.
    /// NOTE(review): unit is presumably bytes — confirm.
    pub max_payload_size: Option<usize>,
    /// Presumably whether the provider honors a bulk `failOnErrors` threshold — confirm.
    pub fail_on_errors_supported: bool,
}
/// Filtering capabilities of a provider.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FilterCapabilities {
    /// Whether filtering is advertised (becomes `filter_supported` in the service provider
    /// config).
    pub supported: bool,
    /// Maximum number of results a filter query returns (becomes `filter_max_results`).
    pub max_results: Option<usize>,
    /// Resource-type name -> dotted attribute paths (e.g. `parent.child`) usable in filters.
    pub filterable_attributes: HashMap<String, Vec<String>>,
    /// Filter comparison operators the provider accepts.
    pub supported_operators: Vec<FilterOperator>,
    /// Presumably whether logical (`and`/`or`/`not`) and grouped filters are accepted — confirm.
    pub complex_filters_supported: bool,
}
/// Pagination capabilities of a provider.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PaginationCapabilities {
    /// Whether paged list responses are supported.
    pub supported: bool,
    /// Page size used when a request does not specify one; `None` leaves it unspecified.
    pub default_page_size: Option<usize>,
    /// Largest page size the provider will return; `None` leaves it unspecified.
    pub max_page_size: Option<usize>,
    /// Whether cursor-based (as opposed to index-based) paging is supported.
    pub cursor_based_supported: bool,
}
/// Authentication-related capabilities of a provider.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthenticationCapabilities {
    /// Authentication schemes copied into the service provider config.
    pub schemes: Vec<AuthenticationScheme>,
    /// Whether multi-factor authentication is supported.
    pub mfa_supported: bool,
    /// Whether credential/token refresh is supported.
    pub token_refresh_supported: bool,
}
/// Optional SCIM features plus free-form provider-specific capabilities.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtendedCapabilities {
    /// Whether ETags are supported (becomes `etag_supported` in the service provider config).
    pub etag_supported: bool,
    /// Whether PATCH is supported (becomes `patch_supported`).
    pub patch_supported: bool,
    /// Whether password change is supported (becomes `change_password_supported`).
    pub change_password_supported: bool,
    /// Whether result sorting is supported (becomes `sort_supported`).
    pub sort_supported: bool,
    /// Provider-defined capability values keyed by name; not interpreted by discovery.
    pub custom_capabilities: HashMap<String, serde_json::Value>,
}
/// SCIM filter comparison operators; each variant serializes as its two-letter filter token.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum FilterOperator {
    /// `eq` — equal.
    #[serde(rename = "eq")]
    Equal,
    /// `ne` — not equal.
    #[serde(rename = "ne")]
    NotEqual,
    /// `co` — contains (substring).
    #[serde(rename = "co")]
    Contains,
    /// `sw` — starts with.
    #[serde(rename = "sw")]
    StartsWith,
    /// `ew` — ends with.
    #[serde(rename = "ew")]
    EndsWith,
    /// `pr` — present (has a value).
    #[serde(rename = "pr")]
    Present,
    /// `gt` — greater than.
    #[serde(rename = "gt")]
    GreaterThan,
    /// `ge` — greater than or equal.
    #[serde(rename = "ge")]
    GreaterThanOrEqual,
    /// `lt` — less than.
    #[serde(rename = "lt")]
    LessThan,
    /// `le` — less than or equal.
    #[serde(rename = "le")]
    LessThanOrEqual,
}
pub trait CapabilityIntrospectable {
fn get_provider_specific_capabilities(&self) -> ExtendedCapabilities {
ExtendedCapabilities::default()
}
fn get_bulk_limits(&self) -> Option<BulkCapabilities> {
None
}
fn get_pagination_limits(&self) -> Option<PaginationCapabilities> {
None
}
fn get_authentication_capabilities(&self) -> Option<AuthenticationCapabilities> {
None
}
}
pub struct CapabilityDiscovery;
impl CapabilityDiscovery {
pub fn discover_capabilities<P>(
schema_registry: &SchemaRegistry,
resource_handlers: &HashMap<String, std::sync::Arc<crate::resource::ResourceHandler>>,
supported_operations: &HashMap<String, Vec<ScimOperation>>,
_provider: &P,
) -> Result<ProviderCapabilities, ScimError>
where
P: ResourceProvider,
{
let supported_schemas = Self::discover_schemas(schema_registry);
let supported_resource_types = Self::discover_resource_types(resource_handlers);
let supported_operations_map = supported_operations.clone();
let filter_capabilities =
Self::discover_filter_capabilities(schema_registry, resource_handlers)?;
let bulk_capabilities = Self::default_bulk_capabilities();
let pagination_capabilities = Self::default_pagination_capabilities();
let authentication_capabilities = Self::default_authentication_capabilities();
let mut extended_capabilities = ExtendedCapabilities::default();
extended_capabilities.etag_supported = true;
extended_capabilities.patch_supported = supported_operations
.values()
.any(|ops| ops.contains(&ScimOperation::Patch));
Ok(ProviderCapabilities {
supported_operations: supported_operations_map,
supported_schemas,
supported_resource_types,
bulk_capabilities,
filter_capabilities,
pagination_capabilities,
authentication_capabilities,
extended_capabilities,
})
}
pub fn discover_capabilities_with_introspection<P>(
schema_registry: &SchemaRegistry,
resource_handlers: &HashMap<String, std::sync::Arc<crate::resource::ResourceHandler>>,
supported_operations: &HashMap<String, Vec<ScimOperation>>,
provider: &P,
) -> Result<ProviderCapabilities, ScimError>
where
P: ResourceProvider + CapabilityIntrospectable,
{
let supported_schemas = Self::discover_schemas(schema_registry);
let supported_resource_types = Self::discover_resource_types(resource_handlers);
let supported_operations_map = supported_operations.clone();
let filter_capabilities =
Self::discover_filter_capabilities(schema_registry, resource_handlers)?;
let bulk_capabilities = provider
.get_bulk_limits()
.unwrap_or_else(|| Self::default_bulk_capabilities());
let pagination_capabilities = provider
.get_pagination_limits()
.unwrap_or_else(|| Self::default_pagination_capabilities());
let authentication_capabilities = provider
.get_authentication_capabilities()
.unwrap_or_else(|| Self::default_authentication_capabilities());
let extended_capabilities = provider.get_provider_specific_capabilities();
Ok(ProviderCapabilities {
supported_operations: supported_operations_map,
supported_schemas,
supported_resource_types,
bulk_capabilities,
filter_capabilities,
pagination_capabilities,
authentication_capabilities,
extended_capabilities,
})
}
fn discover_schemas(schema_registry: &SchemaRegistry) -> Vec<String> {
schema_registry
.get_schemas()
.iter()
.map(|schema| schema.id.clone())
.collect()
}
fn discover_resource_types(
resource_handlers: &HashMap<String, std::sync::Arc<crate::resource::ResourceHandler>>,
) -> Vec<String> {
resource_handlers.keys().cloned().collect()
}
fn discover_filter_capabilities(
schema_registry: &SchemaRegistry,
resource_handlers: &HashMap<String, std::sync::Arc<crate::resource::ResourceHandler>>,
) -> Result<FilterCapabilities, ScimError> {
let mut filterable_attributes = HashMap::new();
for (resource_type, handler) in resource_handlers {
if let Some(schema) = schema_registry.get_schema(&handler.schema.id) {
let attrs = Self::collect_filterable_attributes(&schema.attributes, "");
filterable_attributes.insert(resource_type.clone(), attrs);
}
}
let supported_operators = Self::determine_supported_operators(schema_registry);
Ok(FilterCapabilities {
supported: !filterable_attributes.is_empty(),
max_results: Some(200), filterable_attributes,
supported_operators,
complex_filters_supported: true, })
}
fn is_attribute_filterable(attr: &AttributeDefinition) -> bool {
match attr.data_type {
crate::schema::AttributeType::Complex => false, _ => true, }
}
fn collect_filterable_attributes(
attributes: &[AttributeDefinition],
prefix: &str,
) -> Vec<String> {
let mut filterable = Vec::new();
for attr in attributes {
let attr_name = if prefix.is_empty() {
attr.name.clone()
} else {
format!("{}.{}", prefix, attr.name)
};
if Self::is_attribute_filterable(attr) {
filterable.push(attr_name.clone());
}
if !attr.sub_attributes.is_empty() {
filterable.extend(Self::collect_filterable_attributes(
&attr.sub_attributes,
&attr_name,
));
}
}
filterable
}
fn determine_supported_operators(schema_registry: &SchemaRegistry) -> Vec<FilterOperator> {
let mut operators = HashSet::new();
operators.insert(FilterOperator::Equal);
operators.insert(FilterOperator::NotEqual);
operators.insert(FilterOperator::Present);
if Self::has_string_attributes(schema_registry) {
operators.insert(FilterOperator::Contains);
operators.insert(FilterOperator::StartsWith);
operators.insert(FilterOperator::EndsWith);
}
if Self::has_comparable_attributes(schema_registry) {
operators.insert(FilterOperator::GreaterThan);
operators.insert(FilterOperator::GreaterThanOrEqual);
operators.insert(FilterOperator::LessThan);
operators.insert(FilterOperator::LessThanOrEqual);
}
operators.into_iter().collect()
}
fn has_string_attributes(schema_registry: &SchemaRegistry) -> bool {
fn has_string_in_attributes(attributes: &[AttributeDefinition]) -> bool {
attributes.iter().any(|attr| {
matches!(attr.data_type, crate::schema::AttributeType::String)
|| has_string_in_attributes(&attr.sub_attributes)
})
}
schema_registry
.get_schemas()
.iter()
.any(|schema| has_string_in_attributes(&schema.attributes))
}
fn has_comparable_attributes(schema_registry: &SchemaRegistry) -> bool {
fn has_comparable_in_attributes(attributes: &[AttributeDefinition]) -> bool {
attributes.iter().any(|attr| {
matches!(
attr.data_type,
crate::schema::AttributeType::Integer
| crate::schema::AttributeType::Decimal
| crate::schema::AttributeType::DateTime
) || has_comparable_in_attributes(&attr.sub_attributes)
})
}
schema_registry
.get_schemas()
.iter()
.any(|schema| has_comparable_in_attributes(&schema.attributes))
}
fn default_bulk_capabilities() -> BulkCapabilities {
BulkCapabilities {
supported: false, max_operations: None,
max_payload_size: None,
fail_on_errors_supported: false,
}
}
fn default_pagination_capabilities() -> PaginationCapabilities {
PaginationCapabilities {
supported: true, default_page_size: Some(20),
max_page_size: Some(200),
cursor_based_supported: false, }
}
fn default_authentication_capabilities() -> AuthenticationCapabilities {
AuthenticationCapabilities {
schemes: vec![], mfa_supported: false,
token_refresh_supported: false,
}
}
pub fn generate_service_provider_config(
capabilities: &ProviderCapabilities,
) -> ServiceProviderConfig {
ServiceProviderConfig {
patch_supported: capabilities.extended_capabilities.patch_supported,
bulk_supported: capabilities.bulk_capabilities.supported,
filter_supported: capabilities.filter_capabilities.supported,
change_password_supported: capabilities.extended_capabilities.change_password_supported,
sort_supported: capabilities.extended_capabilities.sort_supported,
etag_supported: capabilities.extended_capabilities.etag_supported,
authentication_schemes: capabilities.authentication_capabilities.schemes.clone(),
bulk_max_operations: capabilities
.bulk_capabilities
.max_operations
.map(|n| n as u32),
bulk_max_payload_size: capabilities
.bulk_capabilities
.max_payload_size
.map(|n| n as u64),
filter_max_results: capabilities
.filter_capabilities
.max_results
.map(|n| n as u32),
}
}
}
impl Default for BulkCapabilities {
fn default() -> Self {
Self {
supported: false,
max_operations: None,
max_payload_size: None,
fail_on_errors_supported: false,
}
}
}
impl Default for FilterCapabilities {
fn default() -> Self {
Self {
supported: false,
max_results: Some(200),
filterable_attributes: HashMap::new(),
supported_operators: vec![FilterOperator::Equal, FilterOperator::Present],
complex_filters_supported: false,
}
}
}
impl Default for PaginationCapabilities {
fn default() -> Self {
Self {
supported: true,
default_page_size: Some(20),
max_page_size: Some(200),
cursor_based_supported: false,
}
}
}
impl Default for AuthenticationCapabilities {
fn default() -> Self {
Self {
schemes: vec![],
mfa_supported: false,
token_refresh_supported: false,
}
}
}
impl Default for ExtendedCapabilities {
fn default() -> Self {
Self {
etag_supported: true, patch_supported: false,
change_password_supported: false,
sort_supported: false,
custom_capabilities: HashMap::new(),
}
}
}
#[cfg(test)]
mod tests {
    //! Unit tests for capability discovery against the default schema registry.
    use super::*;
    use crate::schema::SchemaRegistry;
    use std::collections::HashMap;

    #[test]
    fn test_discover_schemas() {
        let registry = SchemaRegistry::new().expect("Failed to create schema registry");
        let schemas = CapabilityDiscovery::discover_schemas(&registry);
        assert!(!schemas.is_empty());
        assert!(schemas.contains(&"urn:ietf:params:scim:schemas:core:2.0:User".to_string()));
    }

    #[test]
    fn test_has_string_attributes() {
        let registry = SchemaRegistry::new().expect("Failed to create schema registry");
        assert!(CapabilityDiscovery::has_string_attributes(&registry));
    }

    #[test]
    fn test_has_comparable_attributes() {
        let registry = SchemaRegistry::new().expect("Failed to create schema registry");
        assert!(CapabilityDiscovery::has_comparable_attributes(&registry));
    }

    #[test]
    fn test_service_provider_config_generation() {
        let capabilities = ProviderCapabilities {
            supported_operations: HashMap::new(),
            supported_schemas: vec!["urn:ietf:params:scim:schemas:core:2.0:User".to_string()],
            supported_resource_types: vec!["User".to_string()],
            bulk_capabilities: BulkCapabilities {
                supported: true,
                max_operations: Some(100),
                max_payload_size: Some(1024 * 1024),
                fail_on_errors_supported: true,
            },
            filter_capabilities: FilterCapabilities::default(),
            pagination_capabilities: PaginationCapabilities::default(),
            authentication_capabilities: AuthenticationCapabilities::default(),
            extended_capabilities: ExtendedCapabilities {
                patch_supported: true,
                ..Default::default()
            },
        };
        let config = CapabilityDiscovery::generate_service_provider_config(&capabilities);
        assert!(config.bulk_supported);
        assert!(config.patch_supported);
        assert!(config.etag_supported);
        assert!(!config.filter_supported);
        assert_eq!(config.bulk_max_operations, Some(100));
        assert_eq!(config.bulk_max_payload_size, Some(1024 * 1024));
        assert_eq!(config.filter_max_results, Some(200));
    }

    #[test]
    fn test_filter_operators() {
        let registry = SchemaRegistry::new().expect("Failed to create schema registry");
        let operators = CapabilityDiscovery::determine_supported_operators(&registry);
        log::debug!("Discovered filter operators: {:?}", operators);
        assert!(operators.contains(&FilterOperator::Equal));
        assert!(operators.contains(&FilterOperator::NotEqual));
        assert!(operators.contains(&FilterOperator::Present));
        assert!(operators.contains(&FilterOperator::Contains));
        assert!(operators.contains(&FilterOperator::StartsWith));
        assert!(operators.contains(&FilterOperator::EndsWith));
        assert!(operators.contains(&FilterOperator::GreaterThan));
        assert!(operators.contains(&FilterOperator::GreaterThanOrEqual));
        assert!(operators.contains(&FilterOperator::LessThan));
        assert!(operators.contains(&FilterOperator::LessThanOrEqual));

        // Each operator is reported at most once.
        let unique: std::collections::HashSet<&FilterOperator> = operators.iter().collect();
        assert_eq!(unique.len(), operators.len());
    }
}