r2_data2/
lib.rs

1mod ai;
2mod auth;
3mod config;
4mod db;
5mod error;
6mod handlers;
7mod state;
8
9use axum::{
10    Router,
11    http::{HeaderValue, StatusCode, Uri, header},
12    middleware,
13    response::{Html, IntoResponse, Response},
14    routing::{get, post},
15};
16
17pub use auth::Claims;
18pub use config::AppConfig;
19pub use db::{DatabaseInfo, DatabaseType, DbPool, TableInfo, TableType};
20pub use error::AuthError;
21use rust_embed::Embed;
22pub use state::AppState;
23use tower_http::{
24    LatencyUnit,
25    cors::{self, CorsLayer},
26    trace::{DefaultOnRequest, DefaultOnResponse, TraceLayer},
27};
28use tracing::Level;
29
/// File name of the single-page-application entry document inside the embedded bundle.
/// It is also compared against request paths, so `/index.html` is served like `/`.
static INDEX_HTML: &str = "index.html";
31
32#[derive(Embed)]
33#[folder = "ui/dist"]
34struct Assets;
35
36pub fn get_router(state: AppState) -> Router {
37    // Configure CORS
38    let allowed_origin_str = state.config.allowed_origin.clone();
39    let cors = CorsLayer::new()
40        .allow_origin(
41            allowed_origin_str
42                .parse::<HeaderValue>()
43                .unwrap_or_else(|_| panic!("Invalid ALLOWED_ORIGIN: {}", allowed_origin_str)),
44        )
45        .allow_methods(cors::Any)
46        .allow_headers(cors::Any);
47
48    // Define routes that need authentication
49    let api_routes = Router::new()
50        .route("/ping", get(handlers::ping))
51        .route("/databases", get(handlers::list_databases))
52        .route("/databases/{db_name}/tables", get(handlers::list_tables))
53        .route(
54            "/databases/{db_name}/tables/{table_name}/schema",
55            get(handlers::get_table_schema),
56        )
57        .route("/execute-query", post(handlers::execute_query))
58        .route("/schema", get(handlers::get_full_schema))
59        .route("/gen-query", post(handlers::gen_query))
60        .route_layer(middleware::from_fn_with_state(
61            state.clone(),
62            auth::auth_middleware,
63        ));
64
65    // Public routes (like root or maybe login later)
66    Router::new()
67        .nest("/api", api_routes)
68        .layer(cors)
69        .layer(
70            TraceLayer::new_for_http()
71                .on_request(DefaultOnRequest::new().level(Level::INFO))
72                .on_response(
73                    DefaultOnResponse::new()
74                        .level(Level::INFO)
75                        .latency_unit(LatencyUnit::Micros),
76                ),
77        )
78        .fallback(static_handler)
79        .with_state(state)
80}
81
82async fn static_handler(uri: Uri) -> impl IntoResponse {
83    let path = uri.path().trim_start_matches('/');
84
85    if path.is_empty() || path == INDEX_HTML {
86        return index_html().await;
87    }
88
89    match Assets::get(path) {
90        Some(content) => {
91            let mime = mime_guess::from_path(path).first_or_octet_stream();
92
93            ([(header::CONTENT_TYPE, mime.as_ref())], content.data).into_response()
94        }
95        None => {
96            if path.contains('.') {
97                return not_found().await;
98            }
99
100            index_html().await
101        }
102    }
103}
104
105async fn index_html() -> Response {
106    match Assets::get(INDEX_HTML) {
107        Some(content) => Html(content.data).into_response(),
108        None => not_found().await,
109    }
110}
111
112async fn not_found() -> Response {
113    (StatusCode::NOT_FOUND, "404").into_response()
114}
115
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: building the full router from the loaded configuration must not panic.
    ///
    /// NOTE: this test depends on its environment. `AppConfig::load` needs the project's
    /// config files, and `AppState::new` must succeed for whatever that config describes.
    /// A self-contained variant would construct an `AppConfig` directly instead of loading it.
    #[tokio::test]
    async fn test_get_router() {
        let config = AppConfig::load()
            .expect("AppConfig::load failed; this test needs the project's config files");
        let state = AppState::new(config)
            .await
            .expect("AppState::new failed for the loaded configuration");
        let _router = get_router(state);
    }
}