use auditor::configuration::{TLSParams, get_configuration};
use auditor::metrics::DatabaseMetricsWatcher;
use auditor::startup::run;
use auditor::telemetry::{get_subscriber, init_subscriber};
use sqlx::postgres::PgPoolOptions;
use rustls::{
RootCertStore, ServerConfig, pki_types::CertificateDer, pki_types::PrivateKeyDer,
pki_types::PrivatePkcs8KeyDer, pki_types::pem::PemObject, server::WebPkiClientVerifier,
};
use std::{fs::File, io::BufReader, sync::Arc};
use std::env;
use casbin::{CoreApi, DefaultModel, Enforcer, FileAdapter, MgmtApi, RbacApi};
use tokio::signal;
#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
let configuration = get_configuration().expect("Failed to read configuration.");
let mut sigterm = signal::unix::signal(signal::unix::SignalKind::terminate())
.expect("failed to install SIGTERM handler");
let subscriber = get_subscriber("AUDITOR".into(), configuration.log_level, std::io::stdout);
init_subscriber(subscriber);
let connection_pool = PgPoolOptions::new()
.acquire_timeout(std::time::Duration::from_secs(2))
.connect_lazy_with(configuration.database.with_db());
let db_metrics_watcher = DatabaseMetricsWatcher::new(connection_pool.clone(), &configuration)?;
let db_metrics_watcher_task = db_metrics_watcher.clone();
tokio::spawn(async move {
db_metrics_watcher_task.monitor().await.unwrap();
});
if let Some(tls) = configuration.tls_config {
if tls.use_tls {
let mut cert_store = RootCertStore::empty();
rustls::crypto::aws_lc_rs::default_provider()
.install_default()
.unwrap();
if let Err(e) = tls.validate_tls_paths() {
eprintln!("Configuration error: {e}");
}
let ca_cert_path = tls.ca_cert_path.as_ref().unwrap();
let server_key_path = tls.server_key_path.as_ref().unwrap();
let server_cert_path = tls.server_cert_path.as_ref().unwrap();
match env::current_dir() {
Ok(path) => println!("Current directory: {}", path.display()),
Err(e) => eprintln!("Error getting current directory: {e}"),
}
let ca_cert = &mut BufReader::new(File::open(ca_cert_path)?);
let ca_cert = CertificateDer::pem_reader_iter(ca_cert)
.collect::<Result<Vec<_>, _>>()
.unwrap();
for cert in ca_cert {
cert_store.add(cert).expect("root CA not added to store");
}
let client_auth = WebPkiClientVerifier::builder(Arc::new(cert_store))
.build()
.unwrap();
let config = ServerConfig::builder().with_client_cert_verifier(client_auth);
let cert_file = &mut BufReader::new(File::open(server_cert_path)?);
let key_file = &mut BufReader::new(File::open(server_key_path)?);
let cert_chain = CertificateDer::pem_reader_iter(cert_file)
.collect::<Result<Vec<_>, _>>()
.unwrap();
let mut keys: Vec<PrivatePkcs8KeyDer<'_>> =
PrivatePkcs8KeyDer::pem_reader_iter(key_file)
.collect::<Result<Vec<_>, _>>()
.unwrap();
let config = config
.with_single_cert(cert_chain, PrivateKeyDer::Pkcs8(keys.remove(0)))
.unwrap();
let tls_params = TLSParams {
config,
https_addr: tls.https_addr,
https_port: tls.https_port,
use_tls: tls.use_tls,
};
let enforce_rbac: bool = configuration
.rbac_config
.as_ref()
.map(|r| r.enforce_rbac)
.unwrap_or(false);
if let Some(rbac_config) = &configuration.rbac_config
&& rbac_config.enforce_rbac
{
let m = DefaultModel::from_file("model.conf").await.unwrap();
let _writer_file = File::create("policy.csv")?;
let adapter = FileAdapter::new("policy.csv");
let mut enforcer = Enforcer::new(m, adapter).await.unwrap();
enforcer
.add_policies(rbac_config.base_policies.clone())
.await
.unwrap();
enforcer.save_policy().await.unwrap();
if let Some(monitoring_role_cn) = &rbac_config.monitoring_role_cn {
for item in monitoring_role_cn.iter() {
enforcer
.add_role_for_user(item, "monitoring_role", None)
.await
.unwrap();
enforcer.save_policy().await.unwrap();
}
}
if let Some(write_access_cn) = &rbac_config.write_access_cn {
for item in write_access_cn.iter() {
enforcer
.add_role_for_user(item, "write_access_base", None)
.await
.unwrap();
enforcer.save_policy().await.unwrap();
}
}
if let Some(read_access_cn) = &rbac_config.read_access_cn {
for item in read_access_cn.iter() {
enforcer
.add_role_for_user(item, "read_access_base", None)
.await
.unwrap();
enforcer.save_policy().await.unwrap();
}
}
if let Some(data_access_rules) = &rbac_config.data_access_rules {
for item in data_access_rules {
for (meta_id, meta_values) in item.meta_info.iter() {
for meta_value in meta_values {
enforcer
.add_policy(vec![
"meta".to_string(),
meta_id.to_string(),
meta_value.to_string(),
])
.await
.unwrap();
enforcer.save_policy().await.unwrap();
enforcer
.add_role_for_user(
&item.reader_cn.clone(),
"read_access_base",
None,
)
.await
.unwrap();
enforcer.save_policy().await.unwrap();
}
}
}
}
}
run(
configuration.application.addr,
configuration.application.port,
configuration.application.web_workers,
connection_pool,
db_metrics_watcher,
Some(tls_params),
enforce_rbac,
configuration.ignore_record_exists_error,
configuration.archival_config,
)
.await?
.await?;
} else {
run(
configuration.application.addr,
configuration.application.port,
configuration.application.web_workers,
connection_pool,
db_metrics_watcher,
None,
false,
configuration.ignore_record_exists_error,
configuration.archival_config,
)
.await?
.await?;
}
} else {
run(
configuration.application.addr,
configuration.application.port,
configuration.application.web_workers,
connection_pool,
db_metrics_watcher,
None,
false,
configuration.ignore_record_exists_error,
configuration.archival_config,
)
.await?
.await?;
tokio::select! {
_ = signal::ctrl_c() => {
tracing::info!("Received SIGINT (Ctrl-C)");
}
_ = sigterm.recv() => {
tracing::info!("Received SIGTERM");
}
}
}
Ok(())
}