use std::collections::HashMap;
use crate::{Error, Result};
/// Configuration for the optional Shield request-screening service.
///
/// Defaults to `Disabled`, in which case `Shield::validate` screens nothing.
#[derive(Clone, Default, Deserialize, Serialize)]
pub enum Shield {
    /// Do not contact any screening service; `Shield::validate` accepts
    /// every request.
    #[default]
    Disabled,
    /// Screen requests through the remote Shield service.
    ///
    /// Only available when the crate is built with the `shield` feature.
    #[cfg(feature = "shield")]
    Enabled {
        /// Credential sent verbatim in the `Authorization` header.
        api_key: String,
        /// When set, failing to reach the service rejects the request
        /// instead of letting it through.
        strict: bool,
    },
}
/// Request details forwarded to the Shield service for screening.
///
/// Serialized as the JSON body of the validation request. The identifying
/// fields are all optional, so callers can share only what they have.
#[derive(Default, Deserialize, Serialize)]
pub struct ShieldValidationInput {
    /// IP address associated with the request, if known.
    pub ip: Option<String>,
    /// Email address associated with the request, if any.
    pub email: Option<String>,
    /// Request headers as name/value pairs, if supplied.
    pub headers: Option<HashMap<String, String>>,
    /// Forwarded to the service as-is. NOTE(review): presumably asks it to
    /// evaluate without recording the request — confirm against the Shield API.
    pub dry_run: bool,
}
/// Verdict returned by the Shield service, decoded from its JSON reply.
#[derive(Deserialize, Serialize)]
pub struct ValidationResult {
    /// Whether the service asked for the request to be rejected.
    blocked: bool,
    /// Explanations supplied by the service; not inspected by this module.
    reasons: Vec<String>,
}
impl Shield {
    /// Validates an incoming request against the configured Shield service.
    ///
    /// * `Shield::Disabled` accepts every request without contacting anything.
    /// * `Shield::Enabled` POSTs `input` as JSON to the Shield validation
    ///   endpoint, authenticating with `api_key`, and acts on its verdict.
    ///
    /// # Errors
    ///
    /// * `Error::BlockedByShield` if the service answered and flagged the
    ///   request as blocked.
    /// * `Error::InternalError` if the service could not be reached, or its
    ///   reply could not be decoded, **and** `strict` is set. With `strict`
    ///   unset, any such service failure fails open and the request is accepted.
    pub async fn validate(&self, input: ShieldValidationInput) -> Result<()> {
        match self {
            #[cfg(feature = "shield")]
            Shield::Enabled { api_key, strict } => {
                match Self::request_verdict(api_key, &input).await {
                    Some(verdict) if verdict.blocked => Err(Error::BlockedByShield),
                    Some(_) => Ok(()),
                    // The service was unreachable or sent a reply we could not
                    // decode. Both cases are handled identically: reject only
                    // when configured to be strict, otherwise fail open.
                    None if *strict => Err(Error::InternalError),
                    None => Ok(()),
                }
            }
            Shield::Disabled => {
                // Without the `shield` feature this is the only arm; mention
                // `input` so that build stays free of unused-variable warnings.
                let _ = input;
                Ok(())
            }
        }
    }

    /// Performs the HTTP round-trip to the Shield service.
    ///
    /// Returns `None` when the request could not be sent or the reply body
    /// did not decode as a `ValidationResult`; the caller decides what a
    /// missing verdict means. The HTTP status is deliberately not checked, so
    /// a decodable verdict is honoured regardless of the status it came with.
    #[cfg(feature = "shield")]
    async fn request_verdict(
        api_key: &str,
        input: &ShieldValidationInput,
    ) -> Option<ValidationResult> {
        let response = reqwest::Client::new()
            .post("https://shield.authifier.com/validate")
            .json(input)
            .header("Authorization", api_key)
            .send()
            .await
            .ok()?;
        response.json().await.ok()
    }
}
#[cfg(test)]
mod tests {
    use super::{Shield, ShieldValidationInput};

    /// A disabled shield must never reject a request.
    #[async_std::test]
    async fn it_accepts_if_no_shield_service() {
        let outcome = Shield::Disabled
            .validate(ShieldValidationInput::default())
            .await;
        assert_eq!(outcome, Ok(()));
    }
}