use aes_gcm::{
aead::{Aead, KeyInit},
Aes256Gcm, Nonce,
};
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use rand::Rng;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use thiserror::Error;
use tracing::info;
use uuid::Uuid;
use uvb_core::{TenantId, UserId};
/// Errors produced by the push-notification challenge flow.
#[derive(Debug, Error)]
pub enum PushError {
/// Failure reported by a push delivery backend; the payload is a free-form message.
#[error("Push provider error: {0}")]
Provider(String),
/// Encrypting a payload failed (no cipher configured, serialization error, or cipher error).
#[error("Encryption failed: {0}")]
Encryption(String),
/// Decrypting a payload failed. Not constructed by the code visible in this module.
#[error("Decryption failed: {0}")]
Decryption(String),
/// No stored challenge matches the requested id; the payload is that id.
#[error("Challenge not found: {0}")]
ChallengeNotFound(String),
/// The challenge's `expires_at` has passed; the payload is the expiry timestamp as text.
#[error("Challenge expired: {0}")]
ChallengeExpired(String),
/// A supplied number-match code was rejected.
/// NOTE(review): no visible code path returns this variant; `verify_number_match`
/// reports a mismatch as `Ok(false)` instead — confirm whether that is intended.
#[error("Invalid number match")]
InvalidNumberMatch,
/// Storage holds no device token for the requested user and platform.
#[error("Device token not found")]
DeviceTokenNotFound,
}
/// Identifies the push delivery backend selected in `PushNotificationConfig::provider`.
///
/// The variant is plain configuration data; nothing visible in this module dispatches on it.
#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub enum PushProvider {
/// Presumably Firebase Cloud Messaging — confirm against the provider implementation.
Firebase,
/// Presumably the Apple Push Notification service — confirm.
ApnsPush,
/// Presumably the OneSignal service — confirm.
OneSignal,
/// Any other backend supplied by the integrator.
Custom,
}
/// Platform of the device that receives or originates a push challenge.
///
/// Used as half of the `(UserId, DevicePlatform)` key for device-token lookup (hence `Hash`)
/// and rendered with `{:?}` in the notification body, so variant names are user-visible.
#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Eq, Hash)]
// `iOS` deliberately keeps its conventional spelling instead of `Ios`.
#[allow(non_camel_case_types)]
pub enum DevicePlatform {
iOS,
Android,
Web,
}
/// A pending or answered push-approval request issued to one user.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PushChallenge {
/// Unique identifier; `create_challenge` fills this with a random UUID v4 string.
pub challenge_id: String,
/// User the challenge was issued for.
pub user_id: UserId,
/// Tenant the challenge was issued under.
pub tenant_id: TenantId,
/// Code the user must echo back; `None` when number matching is disabled in the config.
pub number_match_code: Option<String>,
/// Details about the request that triggered the challenge.
pub context: PushContext,
/// When the challenge was created.
pub created_at: DateTime<Utc>,
/// Instant after which the manager rejects responses with `PushError::ChallengeExpired`.
pub expires_at: DateTime<Utc>,
/// Starts as `false`; the in-memory storage sets it to `true` only when an
/// `Approved` response is recorded.
pub is_approved: bool,
/// When the most recent response was recorded, if any.
pub responded_at: Option<DateTime<Utc>>,
/// The most recent recorded response, if any.
pub response: Option<PushResponse>,
}
/// Information about the request that triggered a challenge.
///
/// Parts of it are shown to the user in the notification, and the whole value is what
/// `encrypt_payload` serializes and encrypts.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PushContext {
/// Device the request came from; its platform appears in the notification body.
pub device_info: DeviceInfo,
/// Approximate location of the request; city/country appear in the notification body.
pub location: LocationInfo,
/// Source IP address as text; included verbatim in the notification body.
pub ip_address: String,
/// Time of the request; formatted as `%Y-%m-%d %H:%M UTC` in the notification body.
pub timestamp: DateTime<Utc>,
/// Application name; used in the notification title and data map.
pub app_name: String,
/// Operation being authorized (the tests use `"login"`); copied into the data map.
pub operation: String,
/// Recent account activity. Not read by any code visible in this module.
pub recent_activity: Vec<ActivitySummary>,
}
/// Description of the device associated with a request.
///
/// Only `platform` is read by code visible in this module; the optional fields are
/// carried along as data.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct DeviceInfo {
/// Device platform; rendered with `{:?}` in the notification body.
pub platform: DevicePlatform,
/// Operating-system version string, when known.
pub os_version: Option<String>,
/// Browser name, when known.
pub browser: Option<String>,
/// Application version string, when known.
pub app_version: Option<String>,
/// Device model name, when known.
pub device_model: Option<String>,
}
/// Approximate location attached to a request; every part is optional.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct LocationInfo {
/// City name, when known.
pub city: Option<String>,
/// Country name, when known.
pub country: Option<String>,
/// Coordinate pair, when known. Presumably (latitude, longitude) — the order is not
/// established anywhere in this module; confirm with the producer.
pub coordinates: Option<(f64, f64)>,
}
/// One entry in `PushContext::recent_activity`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ActivitySummary {
/// Free-form label for the kind of activity.
pub activity_type: String,
/// When the activity happened.
pub timestamp: DateTime<Utc>,
/// Free-form location description, when known.
pub location: Option<String>,
}
/// Outcome recorded against a challenge.
#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub enum PushResponse {
/// The request was approved; the only variant that makes `is_approved` true.
Approved,
/// The request was rejected.
Denied,
/// No answer was received in time. Nothing visible in this module produces this
/// variant on its own; it must be supplied by a caller.
Timeout,
}
/// Encrypted form of a challenge's context, as produced by `encrypt_payload`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct EncryptedPushPayload {
/// Identifier of the challenge the payload belongs to (stored in the clear).
pub challenge_id: String,
/// AES-256-GCM output for the JSON-serialized `PushContext`.
pub encrypted_context: Vec<u8>,
/// The 12 random nonce bytes used for this encryption; needed again to decrypt.
pub nonce: Vec<u8>,
/// When the payload was encrypted.
pub encrypted_at: DateTime<Utc>,
}
/// What is handed to a `PushNotificationProvider` for delivery.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PushNotificationContent {
/// Notification title.
pub title: String,
/// Human-readable notification text.
pub body: String,
/// Number-match code to deliver alongside the notification, if any.
pub number_match: Option<String>,
/// Extra string key/value pairs; `from_challenge` fills in `challenge_id`,
/// `app_name` and `operation`.
pub data: HashMap<String, String>,
}
impl PushNotificationContent {
pub fn from_challenge(challenge: &PushChallenge) -> Self {
let location_str = if let Some(ref city) = challenge.context.location.city {
if let Some(ref country) = challenge.context.location.country {
format!("{}, {}", city, country)
} else {
city.clone()
}
} else {
"Unknown location".to_string()
};
let body = format!(
"Login attempt from {} at {}. Device: {:?}, IP: {}",
location_str,
challenge.context.timestamp.format("%Y-%m-%d %H:%M UTC"),
challenge.context.device_info.platform,
challenge.context.ip_address
);
let mut data = HashMap::new();
data.insert("challenge_id".to_string(), challenge.challenge_id.clone());
data.insert("app_name".to_string(), challenge.context.app_name.clone());
data.insert("operation".to_string(), challenge.context.operation.clone());
Self {
title: format!("Authentication Request - {}", challenge.context.app_name),
body,
number_match: challenge.number_match_code.clone(),
data,
}
}
}
/// Tunables for the push-notification challenge flow.
///
/// NOTE(review): `enabled`, `provider`, `include_recent_activity`,
/// `max_recent_activities` and `require_client_cert` are not read by any code visible
/// in this module; presumably they are enforced by callers — confirm.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PushNotificationConfig {
/// Master switch for the feature.
pub enabled: bool,
/// Which delivery backend is in use.
pub provider: PushProvider,
/// When `false`, the manager discards any encryption key passed to its constructor.
pub enable_encryption: bool,
/// When `true`, new challenges get a number-match code.
pub enable_number_matching: bool,
/// Number of decimal digits in a generated number-match code.
pub number_match_length: usize,
/// Lifetime of a challenge in seconds, added to its creation time to get `expires_at`.
pub challenge_timeout_secs: i64,
/// Whether recent activity should accompany a challenge.
pub include_recent_activity: bool,
/// Upper bound on the number of recent-activity entries.
pub max_recent_activities: usize,
/// Policy flag consulted by `send_push_notification` before a push is handed to the provider.
pub never_include_otp: bool,
/// Whether a client certificate is required.
pub require_client_cert: bool,
}
impl Default for PushNotificationConfig {
    /// Returns the baseline configuration: feature enabled on the Firebase backend,
    /// encryption and 3-digit number matching on, a 60-second challenge lifetime,
    /// up to three recent-activity entries, and both hardening flags set.
    fn default() -> Self {
        let provider = PushProvider::Firebase;
        let number_match_length: usize = 3;
        let challenge_timeout_secs: i64 = 60;
        let max_recent_activities: usize = 3;

        Self {
            enabled: true,
            provider,
            // Protections that are on unless explicitly disabled.
            enable_encryption: true,
            enable_number_matching: true,
            number_match_length,
            challenge_timeout_secs,
            // Context shown to the user.
            include_recent_activity: true,
            max_recent_activities,
            // Hardening flags.
            never_include_otp: true,
            require_client_cert: true,
        }
    }
}
/// Abstraction over a push delivery backend.
///
/// Implementations must be `Send + Sync` so a manager holding one can be shared
/// across tasks.
#[async_trait]
pub trait PushNotificationProvider: Send + Sync {
/// Delivers `content` to the device identified by `device_token`.
///
/// On success returns a string that the manager treats as the backend's message id.
async fn send_push(
&self,
device_token: &str,
content: &PushNotificationContent,
) -> Result<String, PushError>;
/// Reports the delivery state of a message, given its id.
async fn get_delivery_status(&self, message_id: &str) -> Result<DeliveryStatus, PushError>;
}
/// Delivery state of a push message as reported by
/// `PushNotificationProvider::get_delivery_status`.
#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub enum DeliveryStatus {
/// Not yet delivered.
Pending,
/// Delivered to the device.
Delivered,
/// Delivery failed.
Failed,
/// The message expired before delivery.
Expired,
}
/// Persistence used by the manager for challenges and device tokens.
#[async_trait]
pub trait PushChallengeStorage: Send + Sync {
/// Stores `challenge` under its `challenge_id`.
async fn save_challenge(&self, challenge: &PushChallenge) -> Result<(), PushError>;
/// Looks up a challenge by id; `Ok(None)` means no such challenge is stored.
async fn get_challenge(&self, challenge_id: &str) -> Result<Option<PushChallenge>, PushError>;
/// Records `response` against the challenge with the given id.
async fn update_response(
&self,
challenge_id: &str,
response: PushResponse,
) -> Result<(), PushError>;
/// Returns the device token registered for `user_id` on `platform`, if any.
async fn get_device_token(
&self,
user_id: &UserId,
platform: DevicePlatform,
) -> Result<Option<String>, PushError>;
}
/// `PushChallengeStorage` implementation that keeps everything in process memory.
///
/// Nothing is persisted and entries are never evicted.
pub struct InMemoryPushStorage {
// Challenges keyed by `challenge_id`.
challenges: tokio::sync::RwLock<HashMap<String, PushChallenge>>,
// Device tokens keyed by (user, platform).
device_tokens: tokio::sync::RwLock<HashMap<(UserId, DevicePlatform), String>>,
}
impl InMemoryPushStorage {
pub fn new() -> Self {
Self {
challenges: tokio::sync::RwLock::new(HashMap::new()),
device_tokens: tokio::sync::RwLock::new(HashMap::new()),
}
}
}
impl Default for InMemoryPushStorage {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl PushChallengeStorage for InMemoryPushStorage {
async fn save_challenge(&self, challenge: &PushChallenge) -> Result<(), PushError> {
let mut challenges = self.challenges.write().await;
challenges.insert(challenge.challenge_id.clone(), challenge.clone());
Ok(())
}
async fn get_challenge(&self, challenge_id: &str) -> Result<Option<PushChallenge>, PushError> {
let challenges = self.challenges.read().await;
Ok(challenges.get(challenge_id).cloned())
}
async fn update_response(
&self,
challenge_id: &str,
response: PushResponse,
) -> Result<(), PushError> {
let mut challenges = self.challenges.write().await;
if let Some(challenge) = challenges.get_mut(challenge_id) {
challenge.is_approved = matches!(response, PushResponse::Approved);
challenge.response = Some(response);
challenge.responded_at = Some(Utc::now());
}
Ok(())
}
async fn get_device_token(
&self,
user_id: &UserId,
platform: DevicePlatform,
) -> Result<Option<String>, PushError> {
let tokens = self.device_tokens.read().await;
Ok(tokens.get(&(user_id.clone(), platform)).cloned())
}
}
/// Coordinates challenge creation, push delivery and response handling on top of a
/// storage backend `S` and a delivery provider `P`.
pub struct PushNotificationManager<S: PushChallengeStorage, P: PushNotificationProvider> {
// Where challenges and device tokens live.
storage: S,
// Backend that actually delivers notifications.
provider: P,
// Behavior switches and limits.
config: PushNotificationConfig,
// Despite the name this holds the initialized AES-256-GCM cipher, not raw key bytes;
// `None` when encryption is disabled or no key was supplied.
encryption_key: Option<Aes256Gcm>,
}
impl<S: PushChallengeStorage, P: PushNotificationProvider> PushNotificationManager<S, P> {
    /// Builds a manager from its storage backend, delivery provider and configuration.
    ///
    /// A cipher is kept only when `config.enable_encryption` is set *and* a key is
    /// supplied; in every other case [`Self::encrypt_payload`] fails with
    /// `PushError::Encryption`.
    pub fn new(
        storage: S,
        provider: P,
        config: PushNotificationConfig,
        encryption_key: Option<&[u8; 32]>,
    ) -> Self {
        let cipher = encryption_key
            .filter(|_| config.enable_encryption)
            .map(|key| Aes256Gcm::new(key.into()));
        Self {
            storage,
            provider,
            config,
            encryption_key: cipher,
        }
    }

    /// Creates a new challenge for `user_id`, stores it and returns it.
    ///
    /// The challenge gets a random UUID v4 id, a number-match code when number
    /// matching is enabled, and expires `challenge_timeout_secs` after creation.
    ///
    /// # Errors
    ///
    /// Propagates any error from the storage backend.
    pub async fn create_challenge(
        &self,
        user_id: UserId,
        tenant_id: TenantId,
        context: PushContext,
    ) -> Result<PushChallenge, PushError> {
        // One clock reading, so the lifetime is exactly the configured timeout.
        let now = Utc::now();
        let challenge = PushChallenge {
            challenge_id: Uuid::new_v4().to_string(),
            user_id,
            tenant_id,
            number_match_code: self
                .config
                .enable_number_matching
                .then(|| self.generate_number_match_code()),
            context,
            created_at: now,
            expires_at: now + chrono::Duration::seconds(self.config.challenge_timeout_secs),
            is_approved: false,
            responded_at: None,
            response: None,
        };

        self.storage.save_challenge(&challenge).await?;
        info!(
            "Created push challenge {} for user {:?}",
            challenge.challenge_id, challenge.user_id
        );
        Ok(challenge)
    }

    /// Sends the notification for `challenge` to the user's device on `platform` and
    /// returns the provider's message id.
    ///
    /// When `never_include_otp` is set, the number-match code is removed before the
    /// content reaches the provider: the user is expected to type the code they see
    /// on the sign-in screen, so shipping it inside the push would defeat the check.
    ///
    /// # Errors
    ///
    /// * `PushError::DeviceTokenNotFound` when no token is registered for the user
    ///   and platform.
    /// * Any error from the storage backend or the provider.
    pub async fn send_push_notification(
        &self,
        challenge: &PushChallenge,
        platform: DevicePlatform,
    ) -> Result<String, PushError> {
        let device_token = self
            .storage
            .get_device_token(&challenge.user_id, platform)
            .await?
            .ok_or(PushError::DeviceTokenNotFound)?;

        let mut content = PushNotificationContent::from_challenge(challenge);
        if self.config.never_include_otp {
            content.number_match = None;
        }

        let message_id = self.provider.send_push(&device_token, &content).await?;
        info!(
            "Sent push notification {} for challenge {}",
            message_id, challenge.challenge_id
        );
        Ok(message_id)
    }

    /// Checks `provided_code` against the challenge's number-match code and, on a
    /// match, records the challenge as approved.
    ///
    /// Returns `Ok(true)` on a match and `Ok(false)` on a mismatch or when the
    /// challenge has no number-match code.
    ///
    /// NOTE(review): failed attempts are neither counted nor recorded, so a short
    /// code can be guessed repeatedly until the challenge expires — consider
    /// rate-limiting at the call site.
    ///
    /// # Errors
    ///
    /// * `PushError::ChallengeNotFound` when the id is unknown.
    /// * `PushError::ChallengeExpired` when the challenge is past `expires_at`.
    /// * Any error from the storage backend.
    pub async fn verify_number_match(
        &self,
        challenge_id: &str,
        provided_code: &str,
    ) -> Result<bool, PushError> {
        let challenge = self.load_active_challenge(challenge_id).await?;

        let matched = challenge
            .number_match_code
            .as_deref()
            .map_or(false, |expected| Self::codes_match(expected, provided_code));
        if !matched {
            return Ok(false);
        }

        self.storage
            .update_response(challenge_id, PushResponse::Approved)
            .await?;
        info!("Number match verified for challenge {}", challenge_id);
        Ok(true)
    }

    /// Records `response` against an existing, unexpired challenge.
    ///
    /// NOTE(review): this accepts `Approved` even for challenges that carry a
    /// number-match code, and it overwrites any earlier response — confirm callers
    /// gate on that.
    ///
    /// # Errors
    ///
    /// * `PushError::ChallengeNotFound` when the id is unknown.
    /// * `PushError::ChallengeExpired` when the challenge is past `expires_at`.
    /// * Any error from the storage backend.
    pub async fn respond_to_challenge(
        &self,
        challenge_id: &str,
        response: PushResponse,
    ) -> Result<(), PushError> {
        self.load_active_challenge(challenge_id).await?;
        self.storage.update_response(challenge_id, response).await?;
        info!("Challenge {} responded: {:?}", challenge_id, response);
        Ok(())
    }

    /// Fetches a challenge and rejects it if it is missing or already expired.
    async fn load_active_challenge(&self, challenge_id: &str) -> Result<PushChallenge, PushError> {
        let challenge = self
            .storage
            .get_challenge(challenge_id)
            .await?
            .ok_or_else(|| PushError::ChallengeNotFound(challenge_id.to_string()))?;

        if Utc::now() > challenge.expires_at {
            return Err(PushError::ChallengeExpired(
                challenge.expires_at.to_string(),
            ));
        }
        Ok(challenge)
    }

    /// Compares two codes without stopping at the first differing byte (best effort;
    /// only the length check returns early).
    fn codes_match(expected: &str, provided: &str) -> bool {
        let (expected, provided) = (expected.as_bytes(), provided.as_bytes());
        expected.len() == provided.len()
            && expected
                .iter()
                .zip(provided)
                .fold(0u8, |diff, (a, b)| diff | (a ^ b))
                == 0
    }

    /// Generates a random decimal code of `number_match_length` digits, leading
    /// zeros included.
    ///
    /// Digits are drawn one at a time, so any configured length works; computing
    /// `10^length` as a `u32` would overflow from ten digits upward. A configured
    /// length of zero is treated as one digit so the code is never empty.
    fn generate_number_match_code(&self) -> String {
        let digit_count = self.config.number_match_length.max(1);
        let mut rng = rand::thread_rng();
        (0..digit_count)
            .map(|_| char::from(b'0' + rng.gen_range(0..10u8)))
            .collect()
    }

    /// Encrypts the challenge's context with AES-256-GCM under a fresh random
    /// 12-byte nonce.
    ///
    /// The context is serialized to JSON first; the challenge id is returned in the
    /// clear next to the ciphertext and nonce.
    ///
    /// # Errors
    ///
    /// Returns `PushError::Encryption` when no cipher is configured, when the context
    /// cannot be serialized, or when the cipher reports a failure.
    pub fn encrypt_payload(
        &self,
        challenge: &PushChallenge,
    ) -> Result<EncryptedPushPayload, PushError> {
        let cipher = self
            .encryption_key
            .as_ref()
            .ok_or_else(|| PushError::Encryption("Encryption not enabled".to_string()))?;

        let plaintext = serde_json::to_vec(&challenge.context)
            .map_err(|e| PushError::Encryption(e.to_string()))?;

        // A nonce must never repeat under the same key, hence a fresh random one per call.
        let mut nonce_bytes = [0u8; 12];
        rand::thread_rng().fill(&mut nonce_bytes);

        let ciphertext = cipher
            .encrypt(Nonce::from_slice(&nonce_bytes), plaintext.as_ref())
            .map_err(|e| PushError::Encryption(e.to_string()))?;

        Ok(EncryptedPushPayload {
            challenge_id: challenge.challenge_id.clone(),
            encrypted_context: ciphertext,
            nonce: nonce_bytes.to_vec(),
            encrypted_at: Utc::now(),
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Provider double: accepts every push and reports every message as delivered.
    pub struct MockPushProvider;

    impl MockPushProvider {
        pub fn new() -> Self {
            Self
        }
    }

    impl Default for MockPushProvider {
        fn default() -> Self {
            Self::new()
        }
    }

    #[async_trait]
    impl PushNotificationProvider for MockPushProvider {
        async fn send_push(
            &self,
            _device_token: &str,
            _content: &PushNotificationContent,
        ) -> Result<String, PushError> {
            // A fresh random id stands in for the backend's message id.
            Ok(Uuid::new_v4().to_string())
        }

        async fn get_delivery_status(
            &self,
            _message_id: &str,
        ) -> Result<DeliveryStatus, PushError> {
            Ok(DeliveryStatus::Delivered)
        }
    }

    /// Manager with default config, empty in-memory storage, the mock provider and no key.
    fn default_manager() -> PushNotificationManager<InMemoryPushStorage, MockPushProvider> {
        PushNotificationManager::new(
            InMemoryPushStorage::new(),
            MockPushProvider::new(),
            PushNotificationConfig::default(),
            None,
        )
    }

    /// Fully populated context for an iOS login from San Francisco.
    fn sample_context() -> PushContext {
        let device_info = DeviceInfo {
            platform: DevicePlatform::iOS,
            os_version: Some("15.0".to_string()),
            browser: None,
            app_version: Some("1.0".to_string()),
            device_model: Some("iPhone 13".to_string()),
        };
        let location = LocationInfo {
            city: Some("San Francisco".to_string()),
            country: Some("USA".to_string()),
            coordinates: None,
        };
        PushContext {
            device_info,
            location,
            ip_address: "192.0.2.1".to_string(),
            timestamp: Utc::now(),
            app_name: "MyApp".to_string(),
            operation: "login".to_string(),
            recent_activity: vec![],
        }
    }

    #[test]
    fn test_number_match_generation() {
        let code = default_manager().generate_number_match_code();

        // The default configuration asks for three decimal digits.
        assert_eq!(code.len(), 3);
        assert!(code.chars().all(|c| c.is_ascii_digit()));
    }

    #[tokio::test]
    async fn test_challenge_creation() {
        let manager = default_manager();

        let challenge = manager
            .create_challenge(
                UserId::new("test-user"),
                TenantId::new("test-tenant"),
                sample_context(),
            )
            .await
            .unwrap();

        // Number matching is on by default and a new challenge starts unapproved.
        assert!(challenge.number_match_code.is_some());
        assert!(!challenge.is_approved);
    }
}