1mod ai;
2mod auth;
3mod config;
4mod db;
5mod error;
6mod handlers;
7mod state;
8
9use axum::{
10 Router,
11 http::{HeaderValue, StatusCode, Uri, header},
12 middleware,
13 response::{Html, IntoResponse, Response},
14 routing::{get, post},
15};
16
17pub use auth::Claims;
18pub use config::AppConfig;
19pub use db::{DatabaseInfo, DatabaseType, DbPool, TableInfo, TableType};
20pub use error::AuthError;
21use rust_embed::Embed;
22pub use state::AppState;
23use tower_http::{
24 LatencyUnit,
25 cors::{self, CorsLayer},
26 trace::{DefaultOnRequest, DefaultOnResponse, TraceLayer},
27};
28use tracing::Level;
29
/// Path (relative to the embedded `ui/dist` folder) of the SPA entry page.
const INDEX_HTML: &str = "index.html";
31
/// Static UI bundle embedded via `rust_embed`.
///
/// Files are looked up with `Assets::get(path)`, where `path` is relative to
/// the `ui/dist` folder. NOTE(review): `ui/dist` presumably holds the built
/// frontend output — confirm the UI is built before compiling this crate.
#[derive(Embed)]
#[folder = "ui/dist"]
struct Assets;
35
36pub fn get_router(state: AppState) -> Router {
37 let allowed_origin_str = state.config.allowed_origin.clone();
39 let cors = CorsLayer::new()
40 .allow_origin(
41 allowed_origin_str
42 .parse::<HeaderValue>()
43 .unwrap_or_else(|_| panic!("Invalid ALLOWED_ORIGIN: {}", allowed_origin_str)),
44 )
45 .allow_methods(cors::Any)
46 .allow_headers(cors::Any);
47
48 let api_routes = Router::new()
50 .route("/ping", get(handlers::ping))
51 .route("/databases", get(handlers::list_databases))
52 .route("/databases/{db_name}/tables", get(handlers::list_tables))
53 .route(
54 "/databases/{db_name}/tables/{table_name}/schema",
55 get(handlers::get_table_schema),
56 )
57 .route("/execute-query", post(handlers::execute_query))
58 .route("/schema", get(handlers::get_full_schema))
59 .route("/gen-query", post(handlers::gen_query))
60 .route_layer(middleware::from_fn_with_state(
61 state.clone(),
62 auth::auth_middleware,
63 ));
64
65 Router::new()
67 .nest("/api", api_routes)
68 .layer(cors)
69 .layer(
70 TraceLayer::new_for_http()
71 .on_request(DefaultOnRequest::new().level(Level::INFO))
72 .on_response(
73 DefaultOnResponse::new()
74 .level(Level::INFO)
75 .latency_unit(LatencyUnit::Micros),
76 ),
77 )
78 .fallback(static_handler)
79 .with_state(state)
80}
81
82async fn static_handler(uri: Uri) -> impl IntoResponse {
83 let path = uri.path().trim_start_matches('/');
84
85 if path.is_empty() || path == INDEX_HTML {
86 return index_html().await;
87 }
88
89 match Assets::get(path) {
90 Some(content) => {
91 let mime = mime_guess::from_path(path).first_or_octet_stream();
92
93 ([(header::CONTENT_TYPE, mime.as_ref())], content.data).into_response()
94 }
95 None => {
96 if path.contains('.') {
97 return not_found().await;
98 }
99
100 index_html().await
101 }
102 }
103}
104
105async fn index_html() -> Response {
106 match Assets::get(INDEX_HTML) {
107 Some(content) => Html(content.data).into_response(),
108 None => not_found().await,
109 }
110}
111
112async fn not_found() -> Response {
113 (StatusCode::NOT_FOUND, "404").into_response()
114}
115
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: building the router from a freshly loaded configuration
    /// must not panic (e.g. on an unparsable `allowed_origin`).
    ///
    /// NOTE(review): this relies on `AppConfig::load` and `AppState::new`
    /// succeeding in the test environment, so it presumably needs the same
    /// configuration as a real run — confirm before relying on it in CI.
    #[tokio::test]
    async fn test_get_router() {
        let state = AppState::new(AppConfig::load().unwrap()).await.unwrap();
        let _router = get_router(state);
    }
}