use std::{net::SocketAddr, path::PathBuf, sync::Arc};
use bpaf::Bpaf;
use cognitox::{api, config::StorageConfig, storage::Storage};
use tower_http::{
cors::{Any, CorsLayer},
trace::TraceLayer,
};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
const DEFAULT_PORT: u16 = 9229;
const DEFAULT_STORAGE_MODE: &str = "memory";
const DEFAULT_LOG_FILTER: &str = "cognitox=info,tower_http=info";
/// Cognito emulator HTTP server.
// NOTE: bpaf derive turns `///` comments on this struct and its fields into `--help`
// text, so maintainer-only notes in this block use `//` comments instead.
#[derive(Debug, Clone, Bpaf)]
#[bpaf(options, version)]
struct Cli {
    /// Port for the HTTP server to listen on
    #[bpaf(
        short,
        long,
        env("COGNITOX_PORT"),
        argument("PORT"),
        fallback(DEFAULT_PORT),
        display_fallback
    )]
    port: u16,
    /// Storage backend: memory or persistent (persistent requires --data-file)
    #[bpaf(
        long,
        env("COGNITOX_STORAGE_MODE"),
        argument("MODE"),
        fallback(String::from(DEFAULT_STORAGE_MODE)),
        display_fallback
    )]
    // Kept as a free-form string; `Cli::storage_config` validates it.
    storage_mode: String,
    /// Data file used for persistence; when set, memory mode is persisted to it as well
    #[bpaf(short, long, env("COGNITOX_DATA_FILE"), argument("FILE"))]
    data_file: Option<PathBuf>,
    /// Log filter directives, e.g. cognitox=debug,tower_http=info
    #[bpaf(short, long, env("RUST_LOG"), argument("FILTER"))]
    log_level: Option<String>,
}
impl Cli {
    /// Returns the tracing filter directives to install.
    ///
    /// Uses the parsed `log_level` option (from `--log-level` or `RUST_LOG`) when present,
    /// and falls back to [`DEFAULT_LOG_FILTER`] otherwise.
    fn log_filter(&self) -> String {
        match &self.log_level {
            Some(directives) => directives.to_owned(),
            None => DEFAULT_LOG_FILTER.to_owned(),
        }
    }

    /// Translates `--storage-mode` / `--data-file` into a [`StorageConfig`].
    ///
    /// * `"memory"` without a data file yields `StorageConfig::memory()`.
    /// * `"memory"` or `"persistent"` with a data file yields
    ///   `StorageConfig::persistent(path)`. Supplying a data file upgrades the default
    ///   memory mode to persistence.
    ///
    /// # Errors
    ///
    /// Returns a user-facing message when `"persistent"` is requested without a data
    /// file, or when the mode is neither `"memory"` nor `"persistent"`.
    fn storage_config(&self) -> Result<StorageConfig, String> {
        let mode = self.storage_mode.as_str();
        match (mode, self.data_file.as_ref()) {
            ("memory" | "persistent", Some(file)) => Ok(StorageConfig::persistent(file.clone())),
            ("memory", None) => Ok(StorageConfig::memory()),
            ("persistent", None) => {
                Err(String::from("--data-file is required when --storage-mode=persistent"))
            }
            (unknown, _) => Err(format!(
                "Unknown storage mode: {unknown}. Expected \"memory\" or \"persistent\"."
            )),
        }
    }
}
#[tokio::main]
async fn main() {
dotenvy::dotenv().ok();
let cli = cli().run();
let log_filter = cli.log_filter();
tracing_subscriber::registry()
.with(
tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| log_filter.into()),
)
.with(tracing_subscriber::fmt::layer())
.init();
let storage_config = cli.storage_config().unwrap_or_else(|e| {
tracing::error!("{e}");
std::process::exit(1);
});
let storage = Arc::new(Storage::with_config(storage_config).unwrap_or_else(|e| {
tracing::error!("Failed to initialize storage: {e}");
std::process::exit(1);
}));
tracing::info!("Storage backend: {}", storage.backend_description());
let app = api::create_router((*storage).clone())
.layer(TraceLayer::new_for_http())
.layer(
CorsLayer::new()
.allow_origin(Any)
.allow_methods(Any)
.allow_headers(Any),
);
let addr = SocketAddr::from(([0, 0, 0, 0], cli.port));
tracing::info!("Starting Cognito emulator on http://{}", addr);
let listener = tokio::net::TcpListener::bind(addr)
.await
.expect("failed to bind TCP listener");
let shutdown_storage = storage.clone();
axum::serve(listener, app)
.with_graceful_shutdown(async move {
tokio::signal::ctrl_c()
.await
.expect("failed to listen for ctrl+c");
tracing::info!("Shutdown signal received, flushing persistence...");
if let Err(e) = shutdown_storage.flush_persistence().await {
tracing::error!("Failed to flush persistence on shutdown: {e}");
}
tracing::info!("Shutdown complete");
})
.await
.expect("server exited with error");
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `Cli` equal to the parser defaults with no optional flags supplied.
    fn cli_for_test() -> Cli {
        Cli {
            port: DEFAULT_PORT,
            storage_mode: DEFAULT_STORAGE_MODE.to_string(),
            data_file: None,
            log_level: None,
        }
    }

    #[test]
    fn default_log_filter_is_info() {
        assert_eq!(cli_for_test().log_filter(), DEFAULT_LOG_FILTER);
    }

    #[test]
    fn explicit_log_filter_overrides_default() {
        let mut cli = cli_for_test();
        cli.log_level = Some("cognitox=warn".to_string());
        assert_eq!(cli.log_filter(), "cognitox=warn");
    }

    #[test]
    fn default_storage_mode_is_accepted() {
        assert!(cli_for_test().storage_config().is_ok());
    }

    #[test]
    fn memory_mode_with_data_file_is_accepted() {
        let mut cli = cli_for_test();
        cli.data_file = Some(PathBuf::from("data.json"));
        assert!(cli.storage_config().is_ok());
    }

    #[test]
    fn persistent_mode_requires_data_file() {
        let mut cli = cli_for_test();
        cli.storage_mode = "persistent".to_string();
        // `.err()` avoids requiring `Debug` on `StorageConfig`.
        assert_eq!(
            cli.storage_config().err().as_deref(),
            Some("--data-file is required when --storage-mode=persistent")
        );

        cli.data_file = Some(PathBuf::from("data.json"));
        assert!(cli.storage_config().is_ok());
    }

    #[test]
    fn unknown_storage_mode_is_rejected() {
        let mut cli = cli_for_test();
        cli.storage_mode = "disk".to_string();
        let err = cli
            .storage_config()
            .err()
            .expect("an unknown storage mode must be rejected");
        assert!(
            err.contains("Unknown storage mode: disk"),
            "unexpected error message: {err}"
        );
    }
}