use crate::encryption::{ClientSecretEncryption, ConnectionStringEncryption, EncryptionError};
use serde::{Deserialize, Serialize};
use std::time::{Duration, Instant};
/// Selects which authentication mechanism a configuration uses.
///
/// Variants are (de)serialized in `snake_case`: `"connection_string"` and `"azure_ad"`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum AuthType {
    /// Authenticate with a connection string.
    ConnectionString,
    /// Authenticate through Azure AD.
    AzureAd,
}
/// Top-level authentication settings: the preferred method plus the optional
/// configuration section for each supported method.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthConfig {
    /// The method to use first.
    pub primary_method: AuthType,
    /// Whether another method may be tried when the primary one fails.
    /// NOTE(review): the fallback logic itself lives outside this file — confirm semantics there.
    pub fallback_enabled: bool,
    /// Connection-string settings, if configured.
    pub connection_string: Option<ConnectionStringConfig>,
    /// Azure AD settings, if configured.
    pub azure_ad: Option<AzureAdAuthConfig>,
}
impl AuthConfig {
    /// Reports whether any configured section stores password-encrypted data.
    ///
    /// Returns `true` when the connection string is encrypted or the Azure AD
    /// section reports encrypted data; absent sections count as unencrypted.
    pub fn has_encrypted_data(&self) -> bool {
        if matches!(&self.connection_string, Some(cs) if cs.is_encrypted()) {
            return true;
        }
        matches!(&self.azure_ad, Some(ad) if ad.has_encrypted_data())
    }

    /// Lists human-readable labels for every auth method whose secret is encrypted.
    ///
    /// The connection string (if encrypted) is listed first, followed by the
    /// Azure AD client secret (if encrypted). An empty vector means nothing is
    /// encrypted.
    pub fn get_encrypted_auth_methods(&self) -> Vec<String> {
        let connection_string_label = self
            .connection_string
            .as_ref()
            .filter(|cs| cs.is_encrypted())
            .map(|_| "Connection String".to_string());
        let client_secret_label = self
            .azure_ad
            .as_ref()
            .filter(|ad| ad.has_encrypted_client_secret())
            .map(|_| "Azure AD Client Secret".to_string());
        connection_string_label
            .into_iter()
            .chain(client_secret_label)
            .collect()
    }
}
/// Connection-string settings, stored either as plaintext or password-encrypted.
///
/// When encrypted, `encrypted_value` and `encryption_salt` are both populated and
/// `value` has been cleared by `encrypt_with_password`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ConnectionStringConfig {
    /// Plaintext connection string (empty once encrypted).
    pub value: String,
    /// Ciphertext produced by `ConnectionStringEncryption::encrypt_connection_string`.
    pub encrypted_value: Option<String>,
    /// Base64-encoded salt that was used to produce `encrypted_value`.
    pub encryption_salt: Option<String>,
}
impl ConnectionStringConfig {
    /// Returns the connection string in plaintext, decrypting it if necessary.
    ///
    /// When both `encrypted_value` and `encryption_salt` are set, the ciphertext is
    /// decrypted with `password`. Otherwise the plaintext `value` field is returned.
    ///
    /// # Errors
    ///
    /// - [`EncryptionError::InvalidData`] if the value is encrypted and `password` is `None`.
    /// - [`EncryptionError::InvalidData`] if only one of `encrypted_value` /
    ///   `encryption_salt` is set and `value` is empty or whitespace-only, i.e. there
    ///   is nothing usable to return.
    /// - Any error from `ConnectionStringEncryption::from_salt_base64` or
    ///   `decrypt_connection_string`.
    pub fn get_connection_string(&self, password: Option<&str>) -> Result<String, EncryptionError> {
        match (&self.encrypted_value, &self.encryption_salt) {
            (Some(encrypted), Some(salt)) => {
                let password = password.ok_or_else(|| {
                    EncryptionError::InvalidData(
                        "Password required for encrypted connection string".to_string(),
                    )
                })?;
                let encryption = ConnectionStringEncryption::from_salt_base64(salt)?;
                encryption.decrypt_connection_string(encrypted, password)
            }
            // Half-present encryption metadata with the plaintext already cleared (as
            // `encrypt_with_password` does) cannot be decrypted; returning "" here would
            // only surface later as a confusing connection failure.
            (Some(_), None) | (None, Some(_)) if self.value.trim().is_empty() => {
                Err(EncryptionError::InvalidData(
                    "Incomplete encryption data for connection string: \
                     encrypted value and salt must both be set"
                        .to_string(),
                ))
            }
            _ => Ok(self.value.clone()),
        }
    }

    /// Returns `true` when both the ciphertext and its salt are present.
    pub fn is_encrypted(&self) -> bool {
        self.encrypted_value.is_some() && self.encryption_salt.is_some()
    }

    /// Encrypts `value` with `password`, storing the ciphertext and salt and clearing `value`.
    ///
    /// A fresh `ConnectionStringEncryption` (and therefore a fresh salt) is created on
    /// every call.
    ///
    /// # Errors
    ///
    /// - [`EncryptionError::InvalidData`] if `value` is empty or whitespace-only
    ///   (which is also the case after a previous successful encryption).
    /// - Any error from `encrypt_connection_string`.
    pub fn encrypt_with_password(&mut self, password: &str) -> Result<(), EncryptionError> {
        if self.value.trim().is_empty() {
            return Err(EncryptionError::InvalidData(
                "Connection string cannot be empty".to_string(),
            ));
        }
        let encryption = ConnectionStringEncryption::new();
        let encrypted = encryption.encrypt_connection_string(&self.value, password)?;
        self.encrypted_value = Some(encrypted);
        self.encryption_salt = Some(encryption.salt_base64());
        // NOTE(review): `String::clear` only truncates; the plaintext bytes are not
        // zeroed in the heap buffer.
        self.value.clear();
        Ok(())
    }
}
/// Azure AD authentication settings.
///
/// The client secret is stored either in plaintext (`client_secret`) or
/// password-encrypted (`encrypted_client_secret` + `client_secret_encryption_salt`).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AzureAdAuthConfig {
    /// Authentication flow name. Defaults to `"device_code"` both when the field is
    /// missing from serialized input and when the value is built via [`Default`].
    #[serde(default = "default_auth_method")]
    pub auth_method: String,
    pub tenant_id: Option<String>,
    pub client_id: Option<String>,
    /// Plaintext client secret (set to `None` once encrypted).
    pub client_secret: Option<String>,
    /// Ciphertext of the client secret.
    pub encrypted_client_secret: Option<String>,
    /// Base64-encoded salt that was used to produce `encrypted_client_secret`.
    pub client_secret_encryption_salt: Option<String>,
    pub subscription_id: Option<String>,
    pub resource_group: Option<String>,
    pub namespace: Option<String>,
    pub authority_host: Option<String>,
    pub scope: Option<String>,
}

/// Default value for [`AzureAdAuthConfig::auth_method`].
fn default_auth_method() -> String {
    "device_code".to_string()
}

// Hand-written instead of derived so `AzureAdAuthConfig::default()` agrees with the
// serde field default above; a derived impl would leave `auth_method` as "".
impl Default for AzureAdAuthConfig {
    fn default() -> Self {
        Self {
            auth_method: default_auth_method(),
            tenant_id: None,
            client_id: None,
            client_secret: None,
            encrypted_client_secret: None,
            client_secret_encryption_salt: None,
            subscription_id: None,
            resource_group: None,
            namespace: None,
            authority_host: None,
            scope: None,
        }
    }
}
impl AzureAdAuthConfig {
    /// Returns the client secret in plaintext, decrypting it if necessary.
    ///
    /// When both `encrypted_client_secret` and `client_secret_encryption_salt` are
    /// set, the ciphertext is decrypted with `password`. Otherwise a clone of the
    /// plaintext `client_secret` is returned (which may be `None`).
    ///
    /// # Errors
    ///
    /// - [`EncryptionError::InvalidData`] if the secret is encrypted and `password` is `None`.
    /// - [`EncryptionError::InvalidData`] if only one of `encrypted_client_secret` /
    ///   `client_secret_encryption_salt` is set and there is no non-blank plaintext
    ///   `client_secret` to return.
    /// - Any error from `ClientSecretEncryption::from_salt_base64` or
    ///   `decrypt_client_secret`.
    pub fn get_client_secret(
        &self,
        password: Option<&str>,
    ) -> Result<Option<String>, EncryptionError> {
        match (
            &self.encrypted_client_secret,
            &self.client_secret_encryption_salt,
        ) {
            (Some(encrypted), Some(salt)) => {
                let password = password.ok_or_else(|| {
                    EncryptionError::InvalidData(
                        "Password required for encrypted client secret".to_string(),
                    )
                })?;
                let encryption = ClientSecretEncryption::from_salt_base64(salt)?;
                let decrypted = encryption.decrypt_client_secret(encrypted, password)?;
                Ok(Some(decrypted))
            }
            // Half-present encryption metadata after the plaintext was dropped (as
            // `encrypt_client_secret_with_password` does) would otherwise come back as
            // `Ok(None)`, indistinguishable from "no secret configured".
            (Some(_), None) | (None, Some(_)) if !self.has_plaintext_client_secret() => {
                Err(EncryptionError::InvalidData(
                    "Incomplete encryption data for client secret: \
                     encrypted secret and salt must both be set"
                        .to_string(),
                ))
            }
            _ => Ok(self.client_secret.clone()),
        }
    }

    /// Returns `true` when both the encrypted secret and its salt are present.
    pub fn has_encrypted_client_secret(&self) -> bool {
        self.encrypted_client_secret.is_some() && self.client_secret_encryption_salt.is_some()
    }

    /// Returns `true` if this section stores any encrypted data (currently only the client secret).
    pub fn has_encrypted_data(&self) -> bool {
        self.has_encrypted_client_secret()
    }

    /// Encrypts `client_secret` with `password`, storing ciphertext and salt and
    /// setting `client_secret` to `None`.
    ///
    /// A fresh `ClientSecretEncryption` (and therefore a fresh salt) is created on
    /// every call.
    ///
    /// # Errors
    ///
    /// - [`EncryptionError::InvalidData`] if `client_secret` is `None`, empty, or
    ///   whitespace-only.
    /// - Any error from `encrypt_client_secret`.
    pub fn encrypt_client_secret_with_password(
        &mut self,
        password: &str,
    ) -> Result<(), EncryptionError> {
        let client_secret = match &self.client_secret {
            Some(secret) if !secret.trim().is_empty() => secret,
            _ => {
                return Err(EncryptionError::InvalidData(
                    "Client secret cannot be empty".to_string(),
                ));
            }
        };
        let encryption = ClientSecretEncryption::new();
        let encrypted = encryption.encrypt_client_secret(client_secret, password)?;
        self.encrypted_client_secret = Some(encrypted);
        self.client_secret_encryption_salt = Some(encryption.salt_base64());
        self.client_secret = None;
        Ok(())
    }

    /// Returns `true` if a non-blank plaintext client secret is stored.
    fn has_plaintext_client_secret(&self) -> bool {
        matches!(self.client_secret.as_deref(), Some(secret) if !secret.trim().is_empty())
    }
}
/// An access token held in memory together with the instant it stops being valid.
#[derive(Clone, Debug)]
pub struct CachedToken {
    /// The raw token string.
    pub token: String,
    /// Monotonic instant at (or after) which the token counts as expired.
    pub expires_at: Instant,
    /// Token type reported alongside the token (presumably e.g. "Bearer" — confirm with the issuer).
    pub token_type: String,
}

impl CachedToken {
    /// How long before `expires_at` [`needs_refresh`](Self::needs_refresh) starts
    /// returning `true`.
    pub const REFRESH_BUFFER: Duration = Duration::from_secs(300);

    /// Lifetime used when `Instant::now() + expires_in` is not representable.
    const MAX_LIFETIME: Duration = Duration::from_secs(365 * 24 * 60 * 60);

    /// Creates a token that expires `expires_in` from now.
    ///
    /// Never panics: if `expires_in` is so large that the expiry instant cannot be
    /// represented, the lifetime is capped at `MAX_LIFETIME` (one year).
    pub fn new(token: String, expires_in: Duration, token_type: String) -> Self {
        let now = Instant::now();
        // `Instant + Duration` panics on overflow, and `expires_in` may be derived from
        // a server-supplied value, so use the checked form with a bounded fallback.
        let expires_at = now
            .checked_add(expires_in)
            .or_else(|| now.checked_add(Self::MAX_LIFETIME))
            .unwrap_or(now);
        Self {
            token,
            expires_at,
            token_type,
        }
    }

    /// Returns `true` once the current instant has reached `expires_at`.
    pub fn is_expired(&self) -> bool {
        Instant::now() >= self.expires_at
    }

    /// Returns `true` if the token expires within [`REFRESH_BUFFER`](Self::REFRESH_BUFFER)
    /// (or has already expired).
    pub fn needs_refresh(&self) -> bool {
        self.needs_refresh_within(Self::REFRESH_BUFFER)
    }

    /// Returns `true` if the token expires within `buffer` from now (or has already expired).
    ///
    /// Equivalent to `Instant::now() + buffer >= expires_at`, but computed via the
    /// remaining lifetime so it cannot overflow.
    pub fn needs_refresh_within(&self, buffer: Duration) -> bool {
        self.expires_at.saturating_duration_since(Instant::now()) <= buffer
    }
}
/// Information to present to the user during a device-code sign-in.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeviceCodeInfo {
    /// Code the user enters at `verification_uri`.
    pub user_code: String,
    /// URI the user visits to complete sign-in.
    pub verification_uri: String,
    /// Human-readable instructions text.
    pub message: String,
}