#![cfg(not(target_arch = "wasm32"))]
use moltendb_auth as auth; mod handlers; mod rate_limit; mod validation;
use moltendb_core::engine;
use axum::extract::Path;
use axum::extract::Query as AxumQuery;
use std::collections::HashMap as QueryMap;
use axum::extract::ws::Utf8Bytes;
use axum::{
extract::{
ws::{Message, WebSocket, WebSocketUpgrade},
State,
},
http::{StatusCode, HeaderValue, header},
middleware,
routing::{get, post},
Json,
Router,
};
use axum_server::tls_rustls::RustlsConfig;
use futures::{sink::SinkExt, stream::StreamExt};
use serde_json::{json, Value};
use std::net::SocketAddr;
use std::path::PathBuf;
use tokio::signal;
use tower_http::limit::RequestBodyLimitLayer;
use tower_http::set_header::SetResponseHeaderLayer;
use tower_http::cors::{AllowOrigin, Any, CorsLayer};
use tracing::{error, info, warn};
use clap::Parser;
// Server configuration, read from command-line flags with environment-variable fallbacks.
//
// Secret-bearing fields set `hide_env_values = true`: by default clap prints the current
// value of an `env` variable in `--help` output, which would leak the JWT secret, the
// admin password and the encryption password to anyone able to run `--help` on the host.
#[derive(Parser, Debug)]
#[command(name = "moltendb", version, about)]
struct Config {
    /// TCP port to listen on for HTTPS and WSS traffic.
    #[arg(long, default_value = "1538", env = "PORT")]
    port: u16,
    /// Path of the database file (also used when deriving the encryption key).
    #[arg(long, default_value = "my_database.log", env = "DB_PATH")]
    db_path: String,
    /// PEM certificate file used for TLS.
    #[arg(long, default_value = "cert.pem", env = "TLS_CERT")]
    cert: String,
    /// PEM private-key file used for TLS.
    #[arg(long, default_value = "key.pem", env = "TLS_KEY")]
    key: String,
    /// Password the storage encryption key is derived from (a built-in default is used when unset).
    #[arg(long, env = "ENCRYPTION_KEY", hide_env_values = true)]
    encryption_key: Option<String>,
    /// Write mode: "sync" for synchronous writes; any other value selects async mode.
    #[arg(long, default_value = "async", env = "WRITE_MODE")]
    write_mode: String,
    /// Storage mode: "tiered" enables tiered storage; any other value selects standard mode.
    #[arg(long, default_value = "standard", env = "STORAGE_MODE")]
    storage_mode: String,
    /// Maximum number of requests allowed per rate-limit window.
    #[arg(long, default_value = "100", env = "RATE_LIMIT_REQUESTS")]
    rate_limit_requests: usize,
    /// Length of the rate-limit window in seconds.
    #[arg(long, default_value = "60", env = "RATE_LIMIT_WINDOW_SECS")]
    rate_limit_window: u64,
    /// Secret used for JWT authentication (required).
    #[arg(long, env = "JWT_SECRET", hide_env_values = true)]
    jwt_secret: Option<String>,
    /// Admin username (required).
    #[arg(long, env = "MOLTENDB_ADMIN_USER")]
    admin_user: Option<String>,
    /// Admin password (required).
    #[arg(long, env = "MOLTENDB_ADMIN_PASSWORD", hide_env_values = true)]
    admin_password: Option<String>,
    /// Maximum accepted request body size in bytes.
    #[arg(long, default_value = "10485760", env = "MAX_BODY_SIZE")]
    max_body_size: usize,
    /// Allowed CORS origins: "*" for any origin, or a comma-separated list of origins.
    #[arg(long, default_value = "*", env = "CORS_ORIGIN")]
    cors_origin: String,
    /// Store data without encryption (plain JSON).
    #[arg(long, default_value = "false", env = "DISABLE_ENCRYPTION")]
    disable_encryption: bool,
    /// Enable debug-level logging.
    #[arg(long, default_value = "false", env = "DEBUG")]
    debug: bool,
}
#[tokio::main]
async fn main() {
let cfg = Config::parse();
let log_level = if cfg.debug { "debug" } else { "info" };
tracing_subscriber::fmt()
.with_env_filter(
tracing_subscriber::EnvFilter::from_default_env()
.add_directive(log_level.parse().unwrap()),
)
.init();
if cfg.jwt_secret.is_none() {
error!("🔥 CRITICAL: --jwt-secret (JWT_SECRET) not set! This is required for security.");
std::process::exit(1);
}
if cfg.encryption_key.is_none() {
warn!("⚠️ --encryption-key not set — using built-in default key. Set it for production!");
}
if cfg.admin_user.is_none() {
error!("🔥 CRITICAL: --admin-user (MOLTENDB_ADMIN_USER) not set! This is required for security.");
std::process::exit(1);
}
if cfg.admin_password.is_none() {
error!("🔥 CRITICAL: --admin-password (MOLTENDB_ADMIN_PASSWORD) not set! This is required for security.");
std::process::exit(1);
}
let admin_user = cfg.admin_user.unwrap();
let admin_password = cfg.admin_password.unwrap();
let db_path = cfg.db_path;
let port = cfg.port;
let cert_path = cfg.cert;
let key_path = cfg.key;
let rate_limit_requests = cfg.rate_limit_requests;
let rate_limit_window = cfg.rate_limit_window;
let is_sync_mode = cfg.write_mode.to_lowercase() == "sync";
let is_tiered_mode = cfg.storage_mode.to_lowercase() == "tiered";
let encryption_key_storage;
let encryption_key: Option<&[u8; 32]> = if cfg.disable_encryption {
warn!("⚠️ Encryption is DISABLED — data will be stored as plain JSON!");
None
} else {
let password = cfg.encryption_key
.unwrap_or_else(|| "moltendb-default-encryption-key".to_string());
let key = engine::EncryptedStorage::derive_key(&password, &db_path);
encryption_key_storage = Some(key);
encryption_key_storage.as_ref()
};
let db = match engine::Db::open(&db_path, is_sync_mode, is_tiered_mode, encryption_key) {
Ok(database) => database,
Err(e) => {
error!("🔥 CRITICAL: Failed to start MoltenDB! Details: {}", e);
std::process::exit(1);
}
};
let bg_db = db.clone();
let bg_db_path = db_path.clone();
tokio::spawn(async move {
let mut interval = tokio::time::interval(std::time::Duration::from_secs(60));
let max_log_bytes: u64 = 100 * 1024 * 1024; let mut secs_since_compact: u64 = 0;
loop {
interval.tick().await;
secs_since_compact += 60;
let log_size = std::fs::metadata(&bg_db_path)
.map(|m| m.len())
.unwrap_or(0);
let should_compact = log_size >= max_log_bytes || secs_since_compact >= 3600;
if should_compact {
if let Err(e) = bg_db.compact() {
warn!("⚠️ Background compaction failed: {}", e);
} else {
info!("🗜️ Compaction complete (log was {} MB)", log_size / 1024 / 1024);
}
secs_since_compact = 0;
}
}
});
let users = auth::UserStore::new(admin_user, admin_password);
info!("👤 User authentication initialized");
let rate_limiter = rate_limit::RateLimiter::new(rate_limit_requests, rate_limit_window);
info!("🚦 Rate limiting: {} requests per {} seconds", rate_limit_requests, rate_limit_window);
let cleanup_limiter = rate_limiter.clone();
tokio::spawn(async move {
let mut interval = tokio::time::interval(std::time::Duration::from_secs(300));
loop {
interval.tick().await;
cleanup_limiter.cleanup();
}
});
let app_state = (db.clone(), users, cfg.max_body_size);
let protected_routes = Router::new()
.route("/set", post(handle_set)) .route("/update", post(handle_update)) .route("/delete", post(handle_delete)) .route("/get", post(handle_get)) .route("/collections/{collection}", get(handle_rest_get_collection)) .route("/collections/{collection}/docs/{key}", get(handle_rest_get)) .layer(middleware::from_fn(auth::auth_middleware));
let public_routes = Router::new()
.route("/login", post(handle_login)) .route("/ws", get(ws_handler));
let cors = {
let origin_str = cfg.cors_origin.trim().to_string();
if origin_str == "*" {
if !cfg.debug {
warn!("⚠️ CORS is open to any origin ('*'). Set --cors-origin for production!");
}
CorsLayer::new()
.allow_origin(Any)
.allow_methods(Any)
.allow_headers(Any)
} else {
let origins: Vec<HeaderValue> = origin_str
.split(',')
.map(|s| s.trim())
.filter(|s| !s.is_empty())
.filter_map(|s| s.parse::<HeaderValue>().ok())
.collect();
if origins.is_empty() {
error!("🔥 CRITICAL: --cors-origin value '{}' produced no valid origins.", origin_str);
std::process::exit(1);
}
info!("🔒 CORS restricted to: {}", origin_str);
CorsLayer::new()
.allow_origin(AllowOrigin::list(origins))
.allow_methods(Any)
.allow_headers(Any)
}
};
let app = public_routes
.merge(protected_routes)
.layer(cors)
.layer(SetResponseHeaderLayer::overriding(
header::X_CONTENT_TYPE_OPTIONS,
HeaderValue::from_static("nosniff"),
))
.layer(SetResponseHeaderLayer::overriding(
header::X_FRAME_OPTIONS,
HeaderValue::from_static("DENY"),
))
.layer(SetResponseHeaderLayer::overriding(
header::X_XSS_PROTECTION,
HeaderValue::from_static("1; mode=block"),
))
.layer(SetResponseHeaderLayer::overriding(
header::STRICT_TRANSPORT_SECURITY,
HeaderValue::from_static("max-age=31536000; includeSubDomains"),
))
.layer(SetResponseHeaderLayer::overriding(
header::REFERRER_POLICY,
HeaderValue::from_static("no-referrer"),
))
.layer(SetResponseHeaderLayer::overriding(
header::HeaderName::from_static("permissions-policy"),
HeaderValue::from_static("geolocation=(), microphone=(), camera=()"),
))
.layer(SetResponseHeaderLayer::overriding(
header::CONTENT_SECURITY_POLICY,
HeaderValue::from_static("default-src 'self'; script-src 'self'; object-src 'none'"),
))
.layer(RequestBodyLimitLayer::new(cfg.max_body_size))
.layer(middleware::from_fn(rate_limit::rate_limit_middleware))
.layer(axum::Extension(rate_limiter))
.with_state(app_state);
let addr = SocketAddr::from(([0, 0, 0, 0], port));
info!("🔒 TLS enabled - loading certificates...");
info!("🛡️ Security headers enabled");
let handle = axum_server::Handle::new();
let shutdown_handle = handle.clone();
tokio::spawn(async move {
shutdown_signal().await;
info!("⏳ Draining in-flight requests (up to 30s)...");
shutdown_handle.graceful_shutdown(Some(std::time::Duration::from_secs(30)));
});
match load_tls_config(&cert_path, &key_path).await {
Ok(tls_config) => {
info!("🚀 MoltenDB running on https://{}:{} (HTTPS + WSS)", addr.ip(), addr.port());
axum_server::bind_rustls(addr, tls_config)
.handle(handle)
.serve(app.into_make_service())
.await
.unwrap();
drop(db);
info!("✅ Database flushed. Shutdown complete.");
}
Err(e) => {
error!("🔥 Failed to load TLS certificates: {}", e);
error!(" Cert path: {}", cert_path);
error!(" Key path: {}", key_path);
std::process::exit(1);
}
}
}
/// Loads a rustls server configuration from PEM-encoded certificate and key files.
///
/// # Errors
///
/// Returns an error naming the offending path when the certificate (checked first) or
/// the key file does not exist, or the error from `RustlsConfig::from_pem_file` when
/// the files cannot be read or parsed.
async fn load_tls_config(
    cert_path: &str,
    key_path: &str,
) -> Result<RustlsConfig, Box<dyn std::error::Error>> {
    let cert_file = PathBuf::from(cert_path);
    if !cert_file.exists() {
        return Err(format!("Certificate file not found: {}", cert_path).into());
    }

    let key_file = PathBuf::from(key_path);
    if !key_file.exists() {
        return Err(format!("Key file not found: {}", key_path).into());
    }

    RustlsConfig::from_pem_file(cert_file, key_file)
        .await
        .map_err(Into::into)
}
/// Resolves when the process is asked to stop: Ctrl+C everywhere, or SIGTERM on Unix.
///
/// On non-Unix targets the SIGTERM branch is a future that never completes, so only
/// Ctrl+C can end the wait.
///
/// # Panics
///
/// Panics if a signal handler cannot be installed.
async fn shutdown_signal() {
    #[cfg(unix)]
    let sigterm = async {
        let mut term_stream = signal::unix::signal(signal::unix::SignalKind::terminate())
            .expect("failed to install signal handler");
        term_stream.recv().await;
    };
    #[cfg(not(unix))]
    let sigterm = std::future::pending::<()>();

    let interrupt = async {
        signal::ctrl_c()
            .await
            .expect("failed to install Ctrl+C handler");
    };

    tokio::select! {
        () = interrupt => {}
        () = sigterm => {}
    }

    info!("🛑 Shutting down gracefully...");
}
/// `POST /login`: exchanges valid credentials for a token from `auth::create_token`.
///
/// Responds 401 with `{"error": "Invalid credentials"}` when `users.verify_user` rejects
/// the username/password pair, and 500 with `{"error": "Failed to create token"}` when
/// token creation fails.
async fn handle_login(
    State((_, users, _)): State<(engine::Db, auth::UserStore, usize)>,
    Json(payload): Json<auth::LoginRequest>,
) -> Result<Json<auth::LoginResponse>, (StatusCode, Json<Value>)> {
    let credentials_ok = users.verify_user(&payload.username, &payload.password);
    if !credentials_ok {
        return Err((
            StatusCode::UNAUTHORIZED,
            Json(json!({"error": "Invalid credentials"})),
        ));
    }

    auth::create_token(&payload.username)
        .map(|token| Json(auth::LoginResponse { token }))
        .map_err(|_| {
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(json!({"error": "Failed to create token"})),
            )
        })
}
/// Converts a `(status, body)` pair returned by the `handlers::process_*` functions into
/// an axum response.
///
/// Codes that `StatusCode::from_u16` rejects are reported as 500 Internal Server Error.
fn json_reply(code: u16, body: Value) -> (StatusCode, Json<Value>) {
    let status = StatusCode::from_u16(code).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
    (status, Json(body))
}

/// `POST /set`: forwards the JSON payload to `handlers::process_set`.
async fn handle_set(
    State((db, _, max_body_size)): State<(engine::Db, auth::UserStore, usize)>,
    Json(payload): Json<Value>,
) -> (StatusCode, Json<Value>) {
    let (code, body) = handlers::process_set(&db, &payload, max_body_size);
    json_reply(code, body)
}

/// `POST /update`: forwards the JSON payload to `handlers::process_update`.
async fn handle_update(
    State((db, _, max_body_size)): State<(engine::Db, auth::UserStore, usize)>,
    Json(payload): Json<Value>,
) -> (StatusCode, Json<Value>) {
    let (code, body) = handlers::process_update(&db, &payload, max_body_size);
    json_reply(code, body)
}

/// `POST /get`: forwards the JSON payload to `handlers::process_get`.
async fn handle_get(
    State((db, _, max_body_size)): State<(engine::Db, auth::UserStore, usize)>,
    Json(payload): Json<Value>,
) -> (StatusCode, Json<Value>) {
    let (code, body) = handlers::process_get(&db, &payload, max_body_size);
    json_reply(code, body)
}

/// `POST /delete`: forwards the JSON payload to `handlers::process_delete`.
async fn handle_delete(
    State((db, _, max_body_size)): State<(engine::Db, auth::UserStore, usize)>,
    Json(payload): Json<Value>,
) -> (StatusCode, Json<Value>) {
    let (code, body) = handlers::process_delete(&db, &payload, max_body_size);
    json_reply(code, body)
}

/// `GET /collections/{collection}/docs/{key}`: REST-style single-document fetch.
///
/// Builds a `{"collection": .., "keys": <key>}` payload (a single key string) and
/// delegates to `handlers::process_get`.
async fn handle_rest_get(
    State((db, _, max_body_size)): State<(engine::Db, auth::UserStore, usize)>,
    Path((collection, key)): Path<(String, String)>,
) -> (StatusCode, Json<Value>) {
    let payload = json!({
        "collection": collection,
        "keys": key
    });
    let (code, body) = handlers::process_get(&db, &payload, max_body_size);
    json_reply(code, body)
}

/// `GET /collections/{collection}`: REST-style collection listing.
///
/// The optional `limit` and `offset` query parameters are mapped to the payload's
/// `count` and `offset` fields. Values that do not parse as `u64` are silently ignored
/// (the field is simply omitted) before delegating to `handlers::process_get`.
async fn handle_rest_get_collection(
    State((db, _, max_body_size)): State<(engine::Db, auth::UserStore, usize)>,
    Path(collection): Path<String>,
    AxumQuery(params): AxumQuery<QueryMap<String, String>>,
) -> (StatusCode, Json<Value>) {
    let mut payload = json!({ "collection": collection });
    if let Some(limit) = params.get("limit").and_then(|v| v.parse::<u64>().ok()) {
        payload["count"] = json!(limit);
    }
    if let Some(offset) = params.get("offset").and_then(|v| v.parse::<u64>().ok()) {
        payload["offset"] = json!(offset);
    }
    let (code, body) = handlers::process_get(&db, &payload, max_body_size);
    json_reply(code, body)
}
async fn ws_handler(
ws: WebSocketUpgrade,
State((db, _, _max_body_size)): State<(engine::Db, auth::UserStore, usize)>,
) -> impl axum::response::IntoResponse {
ws.on_upgrade(|socket| handle_socket(socket, db))
}
/// How long a freshly upgraded WebSocket may take to send its AUTH message before it is
/// rejected. Without a bound, unauthenticated clients could hold sockets open forever.
const WS_AUTH_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);

/// Returns `true` when `text` is a JSON `{"action":"AUTH","token":"<jwt>"}` message whose
/// token is accepted by `auth::verify_token`.
fn is_valid_ws_auth(text: &str) -> bool {
    let Ok(payload) = serde_json::from_str::<Value>(text) else {
        return false;
    };
    if payload["action"].as_str() != Some("AUTH") {
        return false;
    }
    payload["token"]
        .as_str()
        .is_some_and(|token| auth::verify_token(token).is_ok())
}

/// Serves one real-time WebSocket connection.
///
/// The first frame must be a text AUTH message accepted by `is_valid_ws_auth` and must
/// arrive within `WS_AUTH_TIMEOUT`. Otherwise an error frame is sent and the socket is
/// closed. Once authenticated, every message received from `db.subscribe()` is forwarded
/// to the client as a text frame until either side goes away.
async fn handle_socket(mut socket: WebSocket, db: engine::Db) {
    let first_frame = tokio::time::timeout(WS_AUTH_TIMEOUT, socket.next()).await;
    let is_authenticated = match first_frame {
        Ok(Some(Ok(Message::Text(text)))) => is_valid_ws_auth(&text),
        // Timeout, closed stream, transport error, or a non-text first frame.
        _ => false,
    };

    if !is_authenticated {
        let _ = socket
            .send(Message::Text(Utf8Bytes::from(
                r#"{"error":"Authentication required. Send {\"action\":\"AUTH\",\"token\":\"<jwt>\"} as the first message."}"#,
            )))
            .await;
        let _ = socket.close().await;
        warn!("🔒 Rejected unauthenticated WebSocket connection.");
        return;
    }

    let _ = socket
        .send(Message::Text(Utf8Bytes::from(
            r#"{"status":"authenticated","message":"Connected to MoltenDB real-time feed. Use HTTP endpoints for CRUD. Send {\"action\":\"SUBSCRIBE\",\"collection\":\"<name>\"} to register interest."}"#,
        )))
        .await;

    let (mut sender, mut receiver) = socket.split();
    let mut rx = db.subscribe();

    // Drain client frames so the connection stays alive. Ping/Pong/Binary frames must not
    // end the loop (previously any non-text frame tore down the whole feed). Only a Close
    // frame, a transport error, or end of stream stops it.
    // NOTE(review): SUBSCRIBE requests are read but not acted on; every event from
    // `db.subscribe()` is forwarded regardless of collection.
    let mut recv_task = tokio::spawn(async move {
        while let Some(Ok(frame)) = receiver.next().await {
            if matches!(frame, Message::Close(_)) {
                break;
            }
        }
    });

    // Forward database events until the feed ends or the client can no longer be written to.
    // NOTE(review): any `Err` from `recv()` ends the feed. If this is a tokio broadcast
    // receiver, a slow client that merely lagged (`RecvError::Lagged`) is disconnected too;
    // consider continuing on `Lagged` — confirm the receiver type first.
    let mut send_task = tokio::spawn(async move {
        while let Ok(msg) = rx.recv().await {
            if sender.send(Message::Text(Utf8Bytes::from(msg))).await.is_err() {
                break;
            }
        }
    });

    // Whichever half finishes first tears down the other.
    tokio::select! {
        _ = &mut recv_task => send_task.abort(),
        _ = &mut send_task => recv_task.abort(),
    }
}