use super::IAMStatement;
use crate::{
core::IAMVersion,
validation::{Validate, ValidationContext, ValidationError, ValidationResult, helpers},
};
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
/// Serde adapter for fields stored as `Vec<T>` whose JSON form may be either a
/// single bare value or an array of values.
///
/// On input both shapes are accepted. On output a vector with exactly one
/// element is written as the bare element; any other length (including zero)
/// is written as a JSON array.
mod one_or_many {
    use serde::de::DeserializeOwned;
    use serde::{Deserialize, Deserializer, Serialize, Serializer};

    /// Serializes `value` as its sole element when it has length one,
    /// otherwise as a sequence.
    pub fn serialize<T, S>(value: &Vec<T>, serializer: S) -> Result<S::Ok, S::Error>
    where
        T: Serialize,
        S: Serializer,
    {
        match value.as_slice() {
            [only] => only.serialize(serializer),
            _ => value.serialize(serializer),
        }
    }

    /// Deserializes either a single `T` or a sequence of `T`, always
    /// producing a `Vec<T>`.
    pub fn deserialize<'de, T, D>(deserializer: D) -> Result<Vec<T>, D::Error>
    where
        T: DeserializeOwned,
        D: Deserializer<'de>,
    {
        // Untagged: serde tries the variants in declaration order, so a lone
        // value is attempted first and a sequence is the fallback.
        #[derive(Deserialize)]
        #[serde(untagged)]
        enum OneOrMany<T> {
            One(T),
            Many(Vec<T>),
        }

        let items = match OneOrMany::<T>::deserialize(deserializer)? {
            OneOrMany::One(single) => vec![single],
            OneOrMany::Many(all) => all,
        };
        Ok(items)
    }
}
/// An IAM policy document: a language version, an optional identifier and
/// the list of statements it contains.
///
/// Fields are (de)serialized under PascalCase keys: `Version`, `Id`,
/// `Statement`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[serde(rename_all = "PascalCase")]
pub struct IAMPolicy {
    /// Policy language version (`Version`).
    pub version: IAMVersion,
    /// Optional policy identifier (`Id`); omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    /// Policy statements (`Statement`). Accepts a single object or an array
    /// on input; a one-element list is written back as a bare object.
    #[serde(with = "one_or_many")]
    #[cfg_attr(feature = "utoipa", schema(value_type = IAMStatements))]
    pub statement: Vec<IAMStatement>,
}
/// Schema-only description of the `Statement` field's JSON shape: either a
/// single statement object or an array of statements.
///
/// Used as `schema(value_type = IAMStatements)` on `IAMPolicy::statement`,
/// whose runtime (de)serialization is handled by `one_or_many`. Only compiled
/// when the `utoipa` feature is enabled.
#[cfg(feature = "utoipa")]
#[derive(utoipa::ToSchema, Serialize, Deserialize)]
#[serde(untagged)]
// dead_code: not constructed anywhere in this file; it exists to describe the schema.
// large_enum_variant: variant size imbalance is irrelevant for a schema-only type.
#[allow(dead_code, clippy::large_enum_variant)]
enum IAMStatements {
    Single(IAMStatement),
    Multiple(Vec<IAMStatement>),
}
impl IAMPolicy {
    /// Creates an empty policy with the default [`IAMVersion`], no `Id` and
    /// no statements.
    #[must_use]
    pub fn new() -> Self {
        Self::with_version(IAMVersion::default())
    }

    /// Creates an empty policy (no `Id`, no statements) using `version`.
    #[must_use]
    pub fn with_version(version: IAMVersion) -> Self {
        let statement = Vec::new();
        Self {
            version,
            id: None,
            statement,
        }
    }

    /// Appends `statement` and returns the policy for builder-style chaining.
    #[must_use]
    pub fn add_statement(mut self, statement: IAMStatement) -> Self {
        self.statement.push(statement);
        self
    }

    /// Sets the policy `Id` and returns the policy for builder-style chaining.
    #[must_use]
    pub fn with_id<S: Into<String>>(self, id: S) -> Self {
        Self {
            id: Some(id.into()),
            ..self
        }
    }

    /// Parses a policy from a JSON string.
    ///
    /// # Errors
    ///
    /// Returns the [`serde_json::Error`] produced when `json` is malformed or
    /// does not match the policy structure.
    pub fn from_json(json: &str) -> Result<Self, serde_json::Error> {
        serde_json::from_str::<Self>(json)
    }

    /// Serializes the policy to pretty-printed JSON.
    ///
    /// # Errors
    ///
    /// Returns the [`serde_json::Error`] produced if serialization fails.
    pub fn to_json(&self) -> Result<String, serde_json::Error> {
        serde_json::to_string_pretty(self)
    }
}
impl Default for IAMPolicy {
fn default() -> Self {
Self::new()
}
}
impl Validate for IAMPolicy {
    /// Validates the policy under a `Policy` path segment.
    ///
    /// Every problem found is collected (via `helpers::collect_errors`) rather
    /// than stopping at the first one:
    /// - the `Statement` list is empty;
    /// - any statement fails its own validation (run under `Statement[i]`);
    /// - a statement `Sid` repeats an earlier one (each repeat is reported);
    /// - the `Version` is not 2012-10-17;
    /// - `Id` is present but empty.
    fn validate(&self, context: &mut ValidationContext) -> ValidationResult {
        context.with_segment("Policy", |ctx| {
            let mut results = Vec::new();

            // Report a missing Statement, but keep going so Version/Id problems
            // are surfaced in the same pass instead of being hidden.
            if self.statement.is_empty() {
                results.push(Err(ValidationError::MissingField {
                    field: "Statement".to_string(),
                    context: ctx.current_path(),
                }));
            }

            for (i, statement) in self.statement.iter().enumerate() {
                ctx.with_segment(&format!("Statement[{i}]"), |stmt_ctx| {
                    results.push(statement.validate(stmt_ctx));
                });
            }

            // `insert` returns false when the sid was already recorded, so a single
            // lookup both tracks and detects duplicates; storing borrows avoids clones.
            let mut seen_sids = HashSet::new();
            for (i, statement) in self.statement.iter().enumerate() {
                if let Some(sid) = &statement.sid
                    && !seen_sids.insert(sid)
                {
                    results.push(Err(ValidationError::LogicalError {
                        message: format!("Duplicate statement ID '{sid}' found at position {i}"),
                    }));
                }
            }

            match self.version {
                IAMVersion::V20121017 => {}
                #[allow(deprecated)]
                IAMVersion::V20081017 => {
                    results.push(Err(ValidationError::InvalidValue {
                        field: "Version".to_string(),
                        value: format!("{:?}", self.version),
                        reason: "Only IAM version 2012-10-17 is supported".to_string(),
                    }));
                }
            }

            if let Some(ref id) = self.id
                && id.is_empty()
            {
                results.push(Err(ValidationError::InvalidValue {
                    field: "Id".to_string(),
                    value: id.clone(),
                    reason: "Policy ID cannot be empty".to_string(),
                }));
            }

            helpers::collect_errors(results)
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{IAMAction, IAMEffect, IAMResource, IAMVersion};

    #[test]
    fn test_policy_validation() {
        let valid_policy = IAMPolicy::new()
            .with_id("550e8400-e29b-41d4-a716-446655440000")
            .add_statement(
                IAMStatement::new(IAMEffect::Allow)
                    .with_sid("AllowS3Read")
                    .with_action(IAMAction::Single("s3:GetObject".to_string()))
                    .with_resource(IAMResource::Single("arn:aws:s3:::bucket/*".to_string())),
            );
        assert!(valid_policy.is_valid());

        let empty_policy = IAMPolicy::new();
        assert!(!empty_policy.is_valid());

        let duplicate_sid_policy = IAMPolicy::new()
            .with_id("550e8400-e29b-41d4-a716-446655440001")
            .add_statement(
                IAMStatement::new(IAMEffect::Allow)
                    .with_sid("DuplicateId")
                    .with_action(IAMAction::Single("s3:GetObject".to_string()))
                    .with_resource(IAMResource::Single("*".to_string())),
            )
            .add_statement(
                IAMStatement::new(IAMEffect::Deny)
                    .with_sid("DuplicateId")
                    .with_action(IAMAction::Single("s3:DeleteObject".to_string()))
                    .with_resource(IAMResource::Single("*".to_string())),
            );
        assert!(!duplicate_sid_policy.is_valid());
    }

    #[test]
    fn test_policy_id_validation() {
        let mut empty_id_policy = IAMPolicy::new();
        empty_id_policy.id = Some(String::new());
        empty_id_policy.statement.push(
            IAMStatement::new(IAMEffect::Allow)
                .with_action(IAMAction::Single("s3:GetObject".to_string()))
                .with_resource(IAMResource::Single("*".to_string())),
        );
        assert!(!empty_id_policy.is_valid());

        let short_id_policy = IAMPolicy::new().with_id("short").add_statement(
            IAMStatement::new(IAMEffect::Allow)
                .with_action(IAMAction::Single("s3:GetObject".to_string()))
                .with_resource(IAMResource::Single("*".to_string())),
        );
        assert!(short_id_policy.is_valid());
    }

    #[test]
    fn test_iam_policy_creation() {
        let policy = IAMPolicy::new().with_id("test-policy").add_statement(
            IAMStatement::new(IAMEffect::Allow)
                .with_sid("AllowS3Access")
                .with_action(IAMAction::Single("s3:GetObject".to_string()))
                .with_resource(IAMResource::Single("arn:aws:s3:::mybucket/*".to_string())),
        );
        assert_eq!(policy.version, IAMVersion::V20121017);
        assert_eq!(policy.id, Some("test-policy".to_string()));
        assert_eq!(policy.statement.len(), 1);
        assert_eq!(policy.statement[0].effect, IAMEffect::Allow);
    }

    #[test]
    fn test_policy_serialization() {
        let policy = IAMPolicy::new().add_statement(
            IAMStatement::new(IAMEffect::Allow)
                .with_action(IAMAction::Single("s3:GetObject".to_string()))
                .with_resource(IAMResource::Single("*".to_string())),
        );
        let json = policy.to_json().unwrap();
        let parsed_policy = IAMPolicy::from_json(&json).unwrap();
        assert_eq!(policy, parsed_policy);
    }

    /// Parses, validates and re-serializes every `*.json` file in
    /// `tests/policies`, requiring the pretty-printed output to match the file
    /// byte-for-byte (ignoring trailing newlines).
    #[test]
    fn test_policy_roundtrip_from_files() {
        let policies_dir = "tests/policies";
        let mut policy_files = std::fs::read_dir(policies_dir)
            .unwrap_or_else(|e| panic!("Failed to read policies directory '{policies_dir}': {e}"))
            .filter_map(|entry| {
                let entry = entry.ok()?;
                let path = entry.path();
                if path.extension()? == "json" {
                    Some(path)
                } else {
                    None
                }
            })
            .collect::<Vec<_>>();
        assert!(
            !policy_files.is_empty(),
            "No policy JSON files found in {policies_dir}/"
        );

        // Order numerically by the leading number in the file name, so "10.json"
        // comes after "9.json". Files without a numeric prefix sort last by name
        // instead of panicking inside the comparator.
        policy_files.sort_by_cached_key(|p| {
            let name = p
                .file_name()
                .and_then(|n| n.to_str())
                .unwrap_or_default()
                .to_owned();
            let number = name
                .split('.')
                .next()
                .and_then(|stem| stem.parse::<u32>().ok());
            (number.is_none(), number, name)
        });

        println!(
            "Testing {} policy files from {}/",
            policy_files.len(),
            policies_dir
        );
        for (index, policy_file) in policy_files.iter().enumerate() {
            let filename = policy_file
                .file_name()
                .and_then(|n| n.to_str())
                .unwrap_or("unknown");
            println!("Testing policy #{}: {} ... ", index + 1, filename);
            let json_content = std::fs::read_to_string(policy_file).unwrap_or_else(|e| {
                panic!("Failed to read file '{}': {}", policy_file.display(), e)
            });
            let original_policy = IAMPolicy::from_json(&json_content)
                .unwrap_or_else(|e| panic!("Failed to parse JSON policy: {e:?}"));
            assert!(
                original_policy.is_valid(),
                "Policy {} is invalid: {:?}",
                filename,
                original_policy.validate(&mut ValidationContext::new())
            );
            let serialized_json = original_policy
                .to_json()
                .unwrap_or_else(|e| panic!("Failed to serialize policy to JSON: {e:?}"));
            assert_eq!(
                serialized_json,
                json_content.trim_end_matches('\n'),
                "Serialized JSON does not match original prettified JSON for file '{filename}'"
            );
        }
    }
}