use async_trait::async_trait;
use toolkit_odata::filter::{FilterNode, FilterOp, ODataValue};
use toolkit_odata::{CursorV1, ODataOrderBy, OrderKey, Page, SortDir};
use toolkit_security::SecurityContext;
use account_management_sdk::{
IdpDeprovisionFailure, IdpDeprovisionTenantRequest, IdpDeprovisionUserRequest,
IdpListUsersRequest, IdpPluginClient, IdpProvisionFailure, IdpProvisionResult,
IdpProvisionTenantRequest, IdpProvisionUserRequest, IdpUser, IdpUserFilterField,
IdpUserOperationFailure,
};
use super::service::Service;
/// Evaluates an OData `$filter` tree against a single user.
///
/// `And` / `Or` composites short-circuit over their children, `Not` negates its operand, and
/// `InList` holds when the field equals any of the listed values. Leaf comparisons are
/// delegated to [`eval_binary`].
///
/// # Panics
///
/// Panics if a composite node carries an operator other than `And` / `Or`; the upstream OData
/// parser is expected never to produce one.
fn matches_filter(user: &IdpUser, filter: &FilterNode<IdpUserFilterField>) -> bool {
    match filter {
        FilterNode::Binary { field, op, value } => eval_binary(user, *field, *op, value),
        FilterNode::InList { field, values } => values
            .iter()
            .any(|candidate| eval_binary(user, *field, FilterOp::Eq, candidate)),
        FilterNode::Not(inner) => !matches_filter(user, inner),
        FilterNode::Composite { op, children } => {
            let mut nodes = children.iter();
            match op {
                FilterOp::And => nodes.all(|child| matches_filter(user, child)),
                FilterOp::Or => nodes.any(|child| matches_filter(user, child)),
                _ => unreachable!(
                    "the OData parser only emits And/Or as composite ops; everything else \
                     surfaces as Binary/InList/Not - reaching this arm signals a bug \
                     upstream of the plugin SPI"
                ),
            }
        }
    }
}
fn eval_binary(
user: &IdpUser,
field: IdpUserFilterField,
op: FilterOp,
value: &ODataValue,
) -> bool {
let lhs: Option<String> = match field {
IdpUserFilterField::Id => Some(user.id.to_string()),
IdpUserFilterField::Username => Some(user.username.clone()),
IdpUserFilterField::Email => user.email.clone(),
IdpUserFilterField::DisplayName => user.display_name.clone(),
IdpUserFilterField::FirstName => user.first_name.clone(),
IdpUserFilterField::LastName => user.last_name.clone(),
};
let rhs: String = match value {
ODataValue::String(s) => s.clone(),
ODataValue::Uuid(u) => u.to_string(),
other => unreachable!(
"IdpUserFilterField declares only String and Uuid kinds - the REST parser \
rejects every other ODataValue at the boundary; got {other:?}"
),
};
let Some(lhs) = lhs else {
return matches!(op, FilterOp::Ne);
};
let lo = |s: &str| s.to_lowercase();
match op {
FilterOp::Eq => lhs == rhs,
FilterOp::Ne => lhs != rhs,
FilterOp::Contains => lo(&lhs).contains(&lo(&rhs)),
FilterOp::StartsWith => lo(&lhs).starts_with(&lo(&rhs)),
FilterOp::EndsWith => lo(&lhs).ends_with(&lo(&rhs)),
other => unreachable!(
"Gt/Ge/Lt/Le/In/And/Or are not legal on the String/Uuid IdpUserFilterField \
surface - REST parser rejects upstream; got {other:?}"
),
}
}
/// Compares two users by the keys of `order`, highest-priority key first.
///
/// Each key's field is projected to a string via [`project_field`] and compared
/// lexicographically, reversed for `Desc`. The first non-equal key decides the result. Once a
/// key has decided, later keys are never projected.
fn compare_by_order(a: &IdpUser, b: &IdpUser, order: &ODataOrderBy) -> std::cmp::Ordering {
    use std::cmp::Ordering;

    order.0.iter().fold(Ordering::Equal, |decided, key| {
        // `then_with` only runs the closure while everything so far compared equal.
        decided.then_with(|| {
            let natural = project_field(a, &key.field).cmp(&project_field(b, &key.field));
            match key.dir {
                SortDir::Asc => natural,
                SortDir::Desc => natural.reverse(),
            }
        })
    })
}
/// Projects the sortable field named `field` of `u` to an owned string.
///
/// Absent optional attributes project to the empty string, which orders before every
/// non-empty value.
///
/// # Panics
///
/// Panics on any field name other than `id`, `username`, `email`, `display_name`,
/// `first_name`, or `last_name`.
fn project_field(u: &IdpUser, field: &str) -> String {
    let or_empty = |value: &Option<String>| value.as_deref().unwrap_or_default().to_owned();
    match field {
        "id" => u.id.to_string(),
        "username" => u.username.clone(),
        "email" => or_empty(&u.email),
        "display_name" => or_empty(&u.display_name),
        "first_name" => or_empty(&u.first_name),
        "last_name" => or_empty(&u.last_name),
        other => unreachable!(
            "unknown order field {other:?}; REST parser whitelists \
             IdpUserFilterField only - reaching this arm signals a bug \
             upstream of the plugin SPI"
        ),
    }
}
/// Projects `u` onto the key tuple of `order`: one string per order key, in key order.
///
/// This is the value stored in a cursor's `k` field. It is compared against a cursor with
/// [`compare_key_to_cursor`].
fn project_key_tuple(u: &IdpUser, order: &ODataOrderBy) -> Vec<String> {
    let mut tuple = Vec::with_capacity(order.0.len());
    for key in &order.0 {
        tuple.push(project_field(u, &key.field));
    }
    tuple
}
/// Compares a row's projected key tuple with a cursor's key tuple under `order`.
///
/// Position `i` of both tuples corresponds to the `i`-th key of `order`. A position missing
/// on either side is treated as the empty string. Positions are compared lexicographically,
/// reversed for `Desc`, and the first non-equal position decides.
fn compare_key_to_cursor(
    item_keys: &[String],
    cursor_keys: &[String],
    order: &ODataOrderBy,
) -> std::cmp::Ordering {
    use std::cmp::Ordering;

    /// Returns the key at `idx`, or `""` when the tuple is shorter than the order.
    fn key_at(keys: &[String], idx: usize) -> &str {
        keys.get(idx).map_or("", String::as_str)
    }

    order
        .0
        .iter()
        .enumerate()
        .map(|(idx, key)| {
            let natural = key_at(item_keys, idx).cmp(key_at(cursor_keys, idx));
            match key.dir {
                SortDir::Asc => natural,
                SortDir::Desc => natural.reverse(),
            }
        })
        .find(|ord| ord.is_ne())
        .unwrap_or(Ordering::Equal)
}
/// [`IdpPluginClient`] implementation backed by the in-process [`Service`].
///
/// Tenant operations do not touch any stored state. Users are recorded in, forgotten from,
/// and listed out of the service's per-tenant user store.
#[async_trait]
impl IdpPluginClient for Service {
    /// Accepts every tenant and echoes its metadata back via `Service::echo_tenant_metadata`.
    async fn provision_tenant(
        &self,
        _ctx: &SecurityContext,
        req: &IdpProvisionTenantRequest,
    ) -> Result<IdpProvisionResult, IdpProvisionFailure> {
        Ok(IdpProvisionResult::new(Some(Self::echo_tenant_metadata(req))))
    }

    /// Always succeeds without touching stored state.
    ///
    /// Users recorded for the tenant are not removed here.
    async fn deprovision_tenant(
        &self,
        _ctx: &SecurityContext,
        _req: &IdpDeprovisionTenantRequest,
    ) -> Result<(), IdpDeprovisionFailure> {
        Ok(())
    }

    /// Builds the user via `Service::echo_user`, records it under the request's tenant, and
    /// returns it.
    async fn provision_user(
        &self,
        _ctx: &SecurityContext,
        req: &IdpProvisionUserRequest,
    ) -> Result<IdpUser, IdpUserOperationFailure> {
        let tenant_id = req.tenant_context.tenant_id;
        let user = Self::echo_user(tenant_id, &req.payload);
        self.record_user(tenant_id, user.clone());
        Ok(user)
    }

    /// Removes the user from the tenant's store and always reports success.
    async fn deprovision_user(
        &self,
        _ctx: &SecurityContext,
        req: &IdpDeprovisionUserRequest,
    ) -> Result<(), IdpUserOperationFailure> {
        // The outcome of `forget_user` is deliberately discarded. NOTE(review): presumably it
        // reports whether anything was removed, making deprovision idempotent; confirm.
        let _ = self.forget_user(req.tenant_context.tenant_id, req.user_id);
        Ok(())
    }

    /// Lists the tenant's users with OData filtering, ordering and keyset pagination.
    ///
    /// Steps:
    /// 1. Apply `$filter`.
    /// 2. Sort by `$orderby`, defaulting to `username asc`, plus an `id asc` tiebreaker.
    /// 3. Resume strictly after the cursor's key tuple, if one was supplied.
    /// 4. Return at most `top` users. A next cursor is issued when more rows remain.
    ///
    /// # Errors
    ///
    /// Returns [`IdpUserOperationFailure::Rejected`] in three cases:
    /// - the cursor cannot be decoded;
    /// - the cursor fails validation against the effective ordering;
    /// - the next cursor cannot be encoded.
    async fn list_users(
        &self,
        _ctx: &SecurityContext,
        req: &IdpListUsersRequest,
    ) -> Result<Page<IdpUser>, IdpUserOperationFailure> {
        let mut snapshot = self.snapshot_users(req.tenant_context.tenant_id);
        if let Some(filter) = req.filter.as_ref() {
            snapshot.retain(|u| matches_filter(u, filter));
        }

        // The `id` tiebreaker makes the order total. Keyset pagination then never skips or
        // repeats users that share the leading sort keys.
        let effective_order = req
            .order
            .clone()
            .unwrap_or_else(|| {
                ODataOrderBy(vec![OrderKey {
                    field: "username".into(),
                    dir: SortDir::Asc,
                }])
            })
            .ensure_tiebreaker("id", SortDir::Asc);
        snapshot.sort_by(|a, b| compare_by_order(a, b, &effective_order));

        let cursor: Option<CursorV1> = match req.pagination.cursor() {
            None => None,
            Some(raw) => Some(CursorV1::decode(raw).map_err(|err| {
                IdpUserOperationFailure::Rejected {
                    detail: format!("static-idp-plugin: invalid cursor: {err}"),
                }
            })?),
        };
        if let Some(c) = cursor.as_ref()
            && let Err(err) = toolkit_odata::validate_cursor_against(c, &effective_order, None)
        {
            return Err(IdpUserOperationFailure::Rejected {
                detail: format!(
                    "static-idp-plugin: cursor was issued for a different \
                     $filter / $orderby than the current request: {err}"
                ),
            });
        }

        // The snapshot is sorted by `effective_order`. Keeping only rows strictly after the
        // cursor's key tuple resumes exactly where the previous page stopped.
        let remaining: Vec<IdpUser> = match cursor.as_ref() {
            Some(c) => snapshot
                .into_iter()
                .filter(|u| {
                    compare_key_to_cursor(
                        &project_key_tuple(u, &effective_order),
                        &c.k,
                        &effective_order,
                    )
                    .is_gt()
                })
                .collect(),
            None => snapshot,
        };

        let limit = u64::from(req.pagination.top());
        // Avoid a truncating `as` cast and a `top + 1` overflow on targets with a narrow `usize`.
        let top = usize::try_from(req.pagination.top()).unwrap_or(usize::MAX);
        // Fetch one extra row: its presence is what signals that another page exists.
        let mut page_items: Vec<IdpUser> =
            remaining.into_iter().take(top.saturating_add(1)).collect();
        let next_cursor = if page_items.len() > top {
            page_items.pop();
            // `last()` is `None` only when `top == 0`; in that case no cursor is issued.
            page_items
                .last()
                .map(|last| {
                    let next = CursorV1 {
                        k: project_key_tuple(last, &effective_order),
                        o: effective_order.0.first().map_or(SortDir::Asc, |k| k.dir),
                        s: effective_order.to_signed_tokens(),
                        f: None,
                        d: "fwd".to_owned(),
                    };
                    next.encode()
                        .map_err(|err| IdpUserOperationFailure::Rejected {
                            detail: format!(
                                "static-idp-plugin: failed to encode next cursor: {err}"
                            ),
                        })
                })
                .transpose()?
        } else {
            None
        };

        Ok(Page::new(
            page_items,
            toolkit_odata::PageInfo {
                next_cursor,
                prev_cursor: None,
                limit,
            },
        ))
    }
}
#[cfg(test)]
#[path = "client_tests.rs"]
mod pagination_tests;