#![cfg(not(target_arch = "wasm32"))]
use moltendb_auth as auth; mod rate_limit; mod route_handlers; mod server; mod ws;
use route_handlers::{
handle_delegate, handle_delete, handle_get, handle_health, handle_login, handle_metrics,
handle_rest_get, handle_rest_get_collection, handle_revoke, handle_set, handle_snapshot,
handle_update,
};
use ws::ws_handler;
use moltendb_core::engine::{self, StorageBackend};
use axum::{
http::{HeaderValue, header},
middleware,
routing::{delete, get, post},
Extension,
Router,
};
use std::net::SocketAddr;
use std::sync::Arc;
use axum::extract::DefaultBodyLimit;
use tower_http::limit::RequestBodyLimitLayer;
use tower_http::set_header::SetResponseHeaderLayer;
use tower_http::cors::{AllowOrigin, Any, CorsLayer};
use tracing::{error, info, warn};
use clap::Parser;
// Command-line / environment configuration for the MoltenDB server.
// Every option may also be supplied via the environment variable named in its
// `env = ...` attribute.
// NOTE: the derived `Debug` output includes secrets (`jwt_secret`,
// `root_password`, `encryption_key`); do not log this struct.
#[derive(Parser, Debug)]
#[command(name = "moltendb", version, about)]
struct Config {
    // Optional subcommand; when absent the server is started.
    #[command(subcommand)]
    command: Option<Commands>,
    /// IP address to bind the server to.
    #[arg(long, default_value = "0.0.0.0", env = "MOLTENDB_HOST")]
    host: String,
    /// Port to listen on.
    #[arg(long, default_value = "1538", env = "MOLTENDB_PORT")]
    port: u16,
    /// Path of the database log file (the revocation list is stored next to it).
    #[arg(long, default_value = "my_database.log", env = "MOLTENDB_DB_PATH")]
    db_path: String,
    /// Path to the TLS certificate file (unused with --dev-mode).
    #[arg(long, default_value = "cert.pem", env = "MOLTENDB_TLS_CERT")]
    cert: String,
    /// Path to the TLS private key file (unused with --dev-mode).
    #[arg(long, default_value = "key.pem", env = "MOLTENDB_TLS_KEY")]
    key: String,
    /// Password from which the at-rest encryption key is derived; a built-in default is used when unset.
    #[arg(long, env = "MOLTENDB_ENCRYPTION_KEY")]
    encryption_key: Option<String>,
    /// Write mode: `sync`, or `async` (any other value); case-insensitive.
    #[arg(long, default_value = "async", env = "MOLTENDB_WRITE_MODE")]
    write_mode: String,
    /// Storage mode: `tiered`, or `standard` (any other value); case-insensitive.
    #[arg(long, default_value = "standard", env = "MOLTENDB_STORAGE_MODE")]
    storage_mode: String,
    /// Maximum number of requests allowed per rate-limit window.
    #[arg(long, default_value = "100", env = "MOLTENDB_RATE_LIMIT_REQS")]
    rate_limit_requests: u32,
    /// Length of the rate-limit window, in seconds.
    #[arg(long, default_value = "60", env = "MOLTENDB_RATE_LIMIT_WINDOW")]
    rate_limit_window: u64,
    /// JWT secret (required; the server refuses to start without it).
    #[arg(long, env = "MOLTENDB_JWT_SECRET")]
    jwt_secret: Option<String>,
    /// Username of the root account (required).
    #[arg(long, env = "MOLTENDB_ROOT_USER")]
    root_user: Option<String>,
    /// Password of the root account (required).
    #[arg(long, env = "MOLTENDB_ROOT_PASSWORD")]
    root_password: Option<String>,
    /// Maximum request body size, in bytes.
    #[arg(long, default_value = "10485760", env = "MOLTENDB_MAX_BODY_SIZE")]
    max_body_size: usize,
    /// Maximum number of keys accepted in a single request.
    #[arg(long, default_value = "1000", env = "MOLTENDB_MAX_KEYS_PER_REQUEST")]
    max_keys_per_request: usize,
    /// Allowed CORS origins: `*` for any origin, or a comma-separated list.
    #[arg(long, default_value = "*", env = "MOLTENDB_CORS_ORIGIN")]
    cors_origin: String,
    /// Store data unencrypted (plain JSON).
    #[arg(long, default_value = "false", env = "MOLTENDB_DISABLE_ENCRYPTION")]
    disable_encryption: bool,
    /// Enable debug-level logging.
    #[arg(long, default_value = "false", env = "MOLTENDB_DEBUG")]
    debug: bool,
    /// Serve plain HTTP/WS instead of HTTPS/WSS. Never use in production.
    #[arg(long, default_value = "false", env = "MOLTENDB_DEV_MODE")]
    dev_mode: bool,
    /// Optional post-backup script passed to the storage engine.
    #[arg(long, env = "MOLTENDB_POST_BACKUP_SCRIPT")]
    post_backup_script: Option<String>,
    /// Hot-tier threshold passed to the storage engine.
    #[arg(long, default_value = "50000", env = "MOLTENDB_HOT_THRESHOLD")]
    hot_threshold: usize,
    /// Keep all data in RAM only; nothing is persisted to disk.
    #[arg(long, default_value = "false", env = "MOLTENDB_IN_MEMORY")]
    in_memory: bool,
}
// Subcommands accepted by the `moltendb` binary.
#[derive(clap::Subcommand, Debug)]
enum Commands {
    /// Run the database server (also the default when no subcommand is given).
    Serve,
    /// Rebuild a database snapshot from a log, up to a timestamp or sequence number.
    Recover {
        /// Log file to replay.
        #[arg(long)]
        log: String,
        /// Snapshot file (accepted but currently not used by the recovery tool).
        #[arg(long)]
        snapshot: Option<String>,
        /// Recover up to this timestamp.
        #[arg(long)]
        to_time: Option<u64>,
        /// Recover up to this sequence number.
        #[arg(long)]
        to_seq: Option<u64>,
        /// Output path for the recovered snapshot.
        #[arg(long)]
        out: String,
        // NOTE(review): unlike the server options, this reads `ENCRYPTION_KEY`
        // rather than `MOLTENDB_ENCRYPTION_KEY`.
        /// Password from which the log decryption key is derived.
        #[arg(long, env = "ENCRYPTION_KEY")]
        encryption_key: Option<String>,
    },
}
#[tokio::main]
async fn main() {
let cfg = Config::parse();
if let Some(Commands::Recover { log, snapshot: _, to_time, to_seq, out, encryption_key }) = &cfg.command {
tracing_subscriber::fmt().init();
info!("🕒 MoltenDB Point-in-Time Recovery Tool");
info!("📖 Reading log: {}", log);
let password = encryption_key.as_ref().map(|s| s.clone()).unwrap_or_else(|| "default_molten_password".to_string());
let master_key = engine::EncryptedStorage::derive_key(&password, "moltendb_log_salt");
let base_storage = Arc::new(engine::SyncDiskStorage::new(&log).expect("Failed to open log file"));
let storage: Arc<dyn engine::StorageBackend> = Arc::new(engine::EncryptedStorage::new(base_storage, &master_key));
match engine::Db::recover_to(&*storage, *to_time, *to_seq) {
Ok(entries) => {
info!("✅ Recovered {} entries.", entries.len());
info!("💾 Saving recovered state to: {}", out);
let temp_log = format!("{}.log", out);
{
let recovered_storage = engine::SyncDiskStorage::new(&temp_log).expect("Failed to create recovery log");
for entry in &entries {
recovered_storage.write_entry(entry).expect("Failed to write entry to recovery log");
}
recovered_storage.compact(entries).expect("Failed to compact recovery log");
}
let snapshot_path = format!("{}.snapshot.bin", temp_log);
std::fs::rename(snapshot_path, &out).expect("Failed to move snapshot to output path");
std::fs::remove_file(temp_log).ok();
info!("✨ Recovery complete! You can now use {} as your database snapshot.", out);
}
Err(e) => {
error!("❌ Recovery failed: {}", e);
std::process::exit(1);
}
}
return;
}
let log_level = if cfg.debug { "debug" } else { "info" };
tracing_subscriber::fmt()
.with_env_filter(
tracing_subscriber::EnvFilter::from_default_env()
.add_directive(log_level.parse().unwrap()),
)
.init();
if cfg.jwt_secret.is_none() {
error!("🔥 CRITICAL: --jwt-secret (JWT_SECRET) not set! This is required for security.");
std::process::exit(1);
}
if cfg.encryption_key.is_none() {
warn!("⚠️ --encryption-key not set — using built-in default key. Set it for production!");
}
if cfg.root_user.is_none() {
error!("🔥 CRITICAL: --root-user (MOLTENDB_ROOT_USER) not set! This is required for security.");
std::process::exit(1);
}
if cfg.root_password.is_none() {
error!("🔥 CRITICAL: --root-password (MOLTENDB_ROOT_PASSWORD) not set! This is required for security.");
std::process::exit(1);
}
let root_user = cfg.root_user.unwrap();
let root_password = cfg.root_password.unwrap();
let db_path = cfg.db_path;
let host = cfg.host;
let port = cfg.port;
let cert_path = cfg.cert;
let key_path = cfg.key;
let rate_limit_requests = cfg.rate_limit_requests;
let rate_limit_window = cfg.rate_limit_window;
let is_sync_mode = cfg.write_mode.to_lowercase() == "sync";
let is_tiered_mode = cfg.storage_mode.to_lowercase() == "tiered";
let is_in_memory = cfg.in_memory;
let encryption_key_storage;
let encryption_key: Option<&[u8; 32]> = if cfg.disable_encryption {
warn!("⚠️ Encryption is DISABLED — data will be stored as plain JSON!");
None
} else {
let password = cfg.encryption_key
.unwrap_or_else(|| "moltendb-default-encryption-key".to_string());
let key = engine::EncryptedStorage::derive_key(&password, &db_path);
encryption_key_storage = Some(key);
encryption_key_storage.as_ref()
};
let db_config = engine::DbConfig {
path: db_path.clone(),
sync_mode: is_sync_mode,
tiered_mode: is_tiered_mode,
hot_threshold: cfg.hot_threshold,
rate_limit_requests: Some(rate_limit_requests),
rate_limit_window: Some(rate_limit_window),
max_body_size: cfg.max_body_size,
max_keys_per_request: cfg.max_keys_per_request,
encryption_key: encryption_key.cloned(),
post_backup_script: cfg.post_backup_script,
in_memory: cfg.in_memory,
};
let db = match engine::Db::open(db_config) {
Ok(database) => database,
Err(e) => {
error!("🔥 CRITICAL: Failed to start MoltenDB! Details: {}", e);
std::process::exit(1);
}
};
if is_in_memory {
warn!("⚡ IN-MEMORY MODE — all data is stored in RAM only. Nothing will be persisted to disk. Data will be lost on exit.");
}
if !is_in_memory {
let bg_db = db.clone();
let bg_db_path = db_path.clone();
tokio::spawn(async move {
let mut interval = tokio::time::interval(std::time::Duration::from_secs(60));
let max_log_bytes: u64 = 100 * 1024 * 1024; let mut secs_since_compact: u64 = 0;
loop {
interval.tick().await;
secs_since_compact += 60;
let log_size = std::fs::metadata(&bg_db_path)
.map(|m| m.len())
.unwrap_or(0);
let should_compact = log_size >= max_log_bytes || secs_since_compact >= 3600;
if should_compact {
if let Err(e) = bg_db.compact() {
warn!("⚠️ Background compaction failed: {}", e);
} else {
info!("🗜️ Compaction complete (log was {} MB)", log_size / 1024 / 1024);
}
secs_since_compact = 0;
}
}
});
}
let users = auth::UserStore::new(root_user.clone(), root_password);
info!("👤 User authentication initialized");
let revocations_path = {
let base = std::path::Path::new(&db_path);
let stem = base.file_stem().and_then(|s| s.to_str()).unwrap_or("my_database");
let dir = base.parent().and_then(|p| p.to_str()).filter(|s| !s.is_empty()).unwrap_or(".");
format!("{}/{}.revocations.json", dir, stem)
};
let revocation_store = auth::RevocationStore::load_from_file(&revocations_path);
info!("🔒 Revocation store loaded from '{}'", revocations_path);
let prune_store = revocation_store.clone();
let prune_revocations_path = revocations_path.clone();
tokio::spawn(async move {
let mut interval = tokio::time::interval(std::time::Duration::from_secs(60));
loop {
interval.tick().await;
prune_store.prune();
if !is_in_memory {
prune_store.save_to_file(&prune_revocations_path);
}
}
});
info!("🔒 Token revocation store initialized");
let rate_limiter = rate_limit::RateLimiter::new(rate_limit_requests as usize, rate_limit_window);
info!("🚦 Rate limiting: {} requests per {} seconds", rate_limit_requests, rate_limit_window);
let cleanup_limiter = rate_limiter.clone();
tokio::spawn(async move {
let mut interval = tokio::time::interval(std::time::Duration::from_secs(300));
loop {
interval.tick().await;
cleanup_limiter.cleanup();
}
});
let app_state = (db.clone(), users, cfg.max_body_size, cfg.max_keys_per_request, root_user);
let mut protected_routes = Router::new()
.route("/set", post(handle_set)) .route("/update", post(handle_update)) .route("/delete", post(handle_delete)) .route("/snapshot", post(handle_snapshot)) .route("/get", post(handle_get)) .route("/collections/{collection}", get(handle_rest_get_collection)) .route("/collections/{collection}/docs/{key}", get(handle_rest_get)) .route("/auth/delegate", post(handle_delegate)) .route("/auth/tokens/{jti}", delete(handle_revoke)) .route("/system/metrics", get(handle_metrics));
#[cfg(feature = "schema")]
{
use route_handlers::handle_schema;
protected_routes = protected_routes.route("/schema", post(handle_schema));
}
let protected_routes = protected_routes
.layer(middleware::from_fn(auth::auth_middleware))
.layer(Extension(revocation_store.clone()))
.layer(Extension(auth::RevocationsPath(revocations_path)));
let public_routes = Router::new()
.route("/login", post(handle_login)) .route("/ws", get(ws_handler)) .route("/system/health", get(handle_health)) .layer(Extension(revocation_store));
let cors = {
let origin_str = cfg.cors_origin.trim().to_string();
if origin_str == "*" {
if !cfg.debug {
warn!("⚠️ CORS is open to any origin ('*'). Set --cors-origin for production!");
}
CorsLayer::new()
.allow_origin(Any)
.allow_methods(Any)
.allow_headers(Any)
} else {
let origins: Vec<HeaderValue> = origin_str
.split(',')
.map(|s| s.trim())
.filter(|s| !s.is_empty())
.filter_map(|s| s.parse::<HeaderValue>().ok())
.collect();
if origins.is_empty() {
error!("🔥 CRITICAL: --cors-origin value '{}' produced no valid origins.", origin_str);
std::process::exit(1);
}
info!("🔒 CORS restricted to: {}", origin_str);
CorsLayer::new()
.allow_origin(AllowOrigin::list(origins))
.allow_methods(Any)
.allow_headers(Any)
}
};
let app = public_routes
.merge(protected_routes)
.layer(cors)
.layer(SetResponseHeaderLayer::overriding(
header::X_CONTENT_TYPE_OPTIONS,
HeaderValue::from_static("nosniff"),
))
.layer(SetResponseHeaderLayer::overriding(
header::X_FRAME_OPTIONS,
HeaderValue::from_static("DENY"),
))
.layer(SetResponseHeaderLayer::overriding(
header::X_XSS_PROTECTION,
HeaderValue::from_static("1; mode=block"),
))
.layer(SetResponseHeaderLayer::overriding(
header::STRICT_TRANSPORT_SECURITY,
HeaderValue::from_static("max-age=31536000; includeSubDomains"),
))
.layer(SetResponseHeaderLayer::overriding(
header::REFERRER_POLICY,
HeaderValue::from_static("no-referrer"),
))
.layer(SetResponseHeaderLayer::overriding(
header::HeaderName::from_static("permissions-policy"),
HeaderValue::from_static("geolocation=(), microphone=(), camera=()"),
))
.layer(SetResponseHeaderLayer::overriding(
header::CONTENT_SECURITY_POLICY,
HeaderValue::from_static("default-src 'self'; script-src 'self'; object-src 'none'"),
))
.layer(DefaultBodyLimit::disable())
.layer(RequestBodyLimitLayer::new(cfg.max_body_size))
.layer(middleware::from_fn(rate_limit::rate_limit_middleware))
.layer(axum::Extension(rate_limiter))
.with_state(app_state);
let addr: SocketAddr = format!("{}:{}", host, port)
.parse()
.unwrap_or_else(|e| {
error!("🔥 Invalid --host value '{}': {}", host, e);
std::process::exit(1);
});
if cfg.dev_mode {
warn!("⚠️ DEV MODE ENABLED — server is running over plain HTTP/WS. NEVER use in production!");
} else {
info!("🔒 TLS enabled - loading certificates...");
}
info!("🛡️ Security headers enabled");
let handle = axum_server::Handle::new();
let shutdown_handle = handle.clone();
tokio::spawn(async move {
server::shutdown_signal().await;
info!("⏳ Draining in-flight requests (up to 30s)...");
shutdown_handle.graceful_shutdown(Some(std::time::Duration::from_secs(30)));
});
if cfg.dev_mode {
info!("🚀 MoltenDB running on http://{}:{} (HTTP + WS) [DEV MODE]", addr.ip(), addr.port());
axum_server::bind(addr)
.handle(handle)
.serve(app.into_make_service())
.await
.unwrap();
} else {
match server::load_tls_config(&cert_path, &key_path).await {
Ok(tls_config) => {
info!("🚀 MoltenDB running on https://{}:{} (HTTPS + WSS)", addr.ip(), addr.port());
axum_server::bind_rustls(addr, tls_config)
.handle(handle)
.serve(app.into_make_service())
.await
.unwrap();
}
Err(e) => {
error!("🔥 Failed to load TLS certificates: {}", e);
error!(" Cert path: {}", cert_path);
error!(" Key path: {}", key_path);
std::process::exit(1);
}
}
}
drop(db);
info!("✅ Database flushed. Shutdown complete.");
}