use crate::{
Result, api,
api::system::SystemState,
auth::{AuthState, auth_filter, handle_auth_rejection},
config::DashboardConfig,
websocket::WebSocketState,
};
use hammerwork::JobQueue;
use std::{net::SocketAddr, sync::Arc};
use tokio::sync::RwLock;
use tracing::{error, info};
use warp::{Filter, Reply};
/// Web dashboard server: owns the configuration plus the authentication and
/// WebSocket state that are shared with the HTTP routes and background tasks
/// created in `start`.
pub struct WebDashboard {
    /// Dashboard configuration (bind address, database URL, static dir, CORS, ...).
    config: DashboardConfig,
    /// Authentication state; cloned into the route filters and the cleanup task.
    auth_state: AuthState,
    /// Shared WebSocket state, used by the `/ws` route, the ping task and the
    /// broadcast listener.
    websocket_state: Arc<RwLock<WebSocketState>>,
}
impl WebDashboard {
pub async fn new(config: DashboardConfig) -> Result<Self> {
let auth_state = AuthState::new(config.auth.clone());
let websocket_state = Arc::new(RwLock::new(WebSocketState::new(config.websocket.clone())));
Ok(Self {
config,
auth_state,
websocket_state,
})
}
pub async fn start(self) -> Result<()> {
let bind_addr: SocketAddr = self.config.bind_addr().parse()?;
let (queue, database_type) = self.create_job_queue_with_type().await?;
let queue = Arc::new(queue);
let system_state = Arc::new(RwLock::new(SystemState::new(
self.config.clone(),
database_type,
self.config.pool_size,
)));
let api_routes = Self::create_api_routes_static(
queue.clone(),
self.auth_state.clone(),
system_state.clone(),
);
let websocket_routes = Self::create_websocket_routes_static(
self.websocket_state.clone(),
self.auth_state.clone(),
);
let static_routes = Self::create_static_routes_static(self.config.static_dir.clone())?;
let routes = api_routes
.or(websocket_routes)
.or(static_routes)
.recover(handle_auth_rejection);
let routes = routes.with(if self.config.enable_cors {
warp::cors()
.allow_any_origin()
.allow_headers(vec!["content-type", "authorization"])
.allow_methods(vec!["GET", "POST", "PUT", "DELETE", "OPTIONS"])
} else {
warp::cors().allow_origin("none") });
info!("Starting web server on {}", bind_addr);
let auth_state_cleanup = self.auth_state.clone();
tokio::spawn(async move {
let mut interval = tokio::time::interval(std::time::Duration::from_secs(300)); loop {
interval.tick().await;
auth_state_cleanup.cleanup_expired_attempts().await;
}
});
let websocket_state_ping = self.websocket_state.clone();
let ping_interval = self.config.websocket.ping_interval;
tokio::spawn(async move {
let mut interval = tokio::time::interval(ping_interval);
loop {
interval.tick().await;
let state = websocket_state_ping.read().await;
state.ping_all_connections().await;
}
});
let websocket_state_broadcast = self.websocket_state.clone();
WebSocketState::start_broadcast_listener(websocket_state_broadcast).await?;
warp::serve(routes).run(bind_addr).await;
Ok(())
}
async fn create_job_queue_with_type(&self) -> Result<(QueueType, String)> {
if self.config.database_url.starts_with("postgres") {
#[cfg(feature = "postgres")]
{
let pg_pool = sqlx::PgPool::connect(&self.config.database_url).await?;
info!(
"Connected to PostgreSQL with {} connections",
self.config.pool_size
);
Ok((JobQueue::new(pg_pool), "PostgreSQL".to_string()))
}
#[cfg(not(feature = "postgres"))]
{
return Err(anyhow::anyhow!(
"PostgreSQL support not enabled. Rebuild with --features postgres"
));
}
} else if self.config.database_url.starts_with("mysql") {
#[cfg(feature = "mysql")]
{
#[cfg(all(feature = "mysql", not(feature = "postgres")))]
{
let mysql_pool = sqlx::MySqlPool::connect(&self.config.database_url).await?;
info!(
"Connected to MySQL with {} connections",
self.config.pool_size
);
Ok((JobQueue::new(mysql_pool), "MySQL".to_string()))
}
#[cfg(all(feature = "postgres", feature = "mysql"))]
{
return Err(anyhow::anyhow!(
"MySQL database URL provided but PostgreSQL is the default when both features are enabled"
));
}
}
#[cfg(not(feature = "mysql"))]
{
return Err(anyhow::anyhow!(
"MySQL support not enabled. Rebuild with --features mysql"
));
}
} else {
Err(anyhow::anyhow!("Unsupported database URL format"))
}
}
fn create_api_routes_static(
queue: Arc<QueueType>,
auth_state: AuthState,
system_state: Arc<RwLock<SystemState>>,
) -> impl Filter<Extract = impl Reply, Error = warp::Rejection> + Clone {
let health = warp::path("health")
.and(warp::path::end())
.and(warp::get())
.map(|| {
warp::reply::json(&serde_json::json!({
"status": "healthy",
"timestamp": chrono::Utc::now().to_rfc3339(),
"version": env!("CARGO_PKG_VERSION")
}))
});
let api_routes = api::queues::routes(queue.clone())
.or(api::jobs::routes(queue.clone()))
.or(api::stats::routes(queue.clone(), system_state.clone()))
.or(api::system::routes(queue.clone(), system_state))
.or(api::archive::archive_routes(queue.clone()))
.or(api::spawn::spawn_routes(queue));
let authenticated_api = warp::path("api")
.and(auth_filter(auth_state))
.untuple_one()
.and(api_routes);
health.or(authenticated_api)
}
fn create_websocket_routes_static(
websocket_state: Arc<RwLock<WebSocketState>>,
auth_state: AuthState,
) -> impl Filter<Extract = impl Reply, Error = warp::Rejection> + Clone {
warp::path("ws")
.and(warp::path::end())
.and(auth_filter(auth_state))
.and(warp::ws())
.and(warp::any().map(move || websocket_state.clone()))
.map(
|_: (), ws: warp::ws::Ws, websocket_state: Arc<RwLock<WebSocketState>>| {
ws.on_upgrade(move |socket| async move {
let mut state = websocket_state.write().await;
if let Err(e) = state.handle_connection(socket).await {
error!("WebSocket error: {}", e);
}
})
},
)
}
fn create_static_routes_static(
static_dir: std::path::PathBuf,
) -> Result<impl Filter<Extract = impl Reply, Error = warp::Rejection> + Clone> {
let static_files = warp::path("static").and(warp::fs::dir(static_dir.clone()));
let index = warp::path::end().and(warp::fs::file(static_dir.join("index.html")));
let spa_routes = warp::any().and(warp::fs::file(static_dir.join("index.html")));
Ok(index.or(static_files).or(spa_routes))
}
}
// Concrete job queue type served by the dashboard. PostgreSQL wins whenever its
// feature is enabled, even if `mysql` is enabled too. MySQL is used only when
// it is the sole database feature. Building with neither feature is rejected.
#[cfg(feature = "postgres")]
type QueueType = JobQueue<sqlx::Postgres>;
#[cfg(all(feature = "mysql", not(feature = "postgres")))]
type QueueType = JobQueue<sqlx::MySql>;
#[cfg(not(any(feature = "postgres", feature = "mysql")))]
compile_error!("At least one database feature (postgres or mysql) must be enabled");
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::DashboardConfig;
    use tempfile::tempdir;

    /// Constructing a dashboard that points at a throwaway static directory
    /// succeeds.
    #[tokio::test]
    async fn test_dashboard_creation() {
        // Keep the temp dir alive until construction has finished.
        let static_root = tempdir().unwrap();
        let outcome = WebDashboard::new(
            DashboardConfig::new().with_static_dir(static_root.path().to_path_buf()),
        )
        .await;
        assert!(outcome.is_ok());
    }

    /// Enabling CORS through the builder sets the config flag.
    #[test]
    fn test_cors_configuration() {
        assert!(DashboardConfig::new().with_cors(true).enable_cors);
    }
}