athena_rs 0.75.4

WIP Database API gateway
Documentation
//! Athena RS binary.
//!
//! Starts the Actix Web server, wires endpoints, configures CORS and tracing,
//! and exposes convenience endpoints for Scylla demo queries. The root
//! endpoint (`GET /`) doubles as a health check and lists the available routes.
use actix_cors::Cors;
use actix_web::body::{BoxBody, EitherBody};
use actix_web::dev::{Service, ServiceResponse};
use actix_web::http::header;
use actix_web::{App, HttpResponse, HttpServer, Responder, get, web};
use dotenv::dotenv;
use moka::future::Cache;
use reqwest::Client;
use scylla::client::session::Session;
use scylla::client::session_builder::SessionBuilder;
use serde_json::{Value, json};
use std::collections::HashMap;
use std::future::Future;
use std::io::Error;
use std::io::ErrorKind::Other;
use std::io::Result as IoResult;
use tracing_subscriber::fmt::time::ChronoLocal;

use socket2::{Domain, Protocol, Socket, TcpKeepalive, Type};
use std::net::{SocketAddr, TcpListener};
use std::sync::Arc;
use std::thread::available_parallelism;
use std::time::Duration;
use tracing::info;
use tracing_subscriber::EnvFilter;
use web::Data;

use athena_rs::AppState;
use athena_rs::api::athena_router_registry;
use athena_rs::api::gateway::delete::delete_data;
use athena_rs::api::gateway::fetch::{
    fetch_data_route, gateway_update_route, get_data_route, proxy_fetch_data_route,
};
use athena_rs::api::gateway::insert::insert_data;
use athena_rs::api::gateway::query::gateway_query_route;
use athena_rs::api::query::sql::sql_query;
use athena_rs::api::pipelines::{load_registry_from_path, run_pipeline};
use athena_rs::api::registry::{api_registry, api_registry_by_id};
use athena_rs::api::schema;
use athena_rs::api::supabase::ssl_enforcement;
use athena_rs::config::Config;
use athena_rs::drivers::postgresql::sqlx_driver::PostgresClientRegistry;
use athena_rs::parser::{parse_secs_or_default, parse_usize, resolve_postgres_uri};

#[cfg(test)]
mod tests {
    use super::*;
    use actix_web::{App, test};
    use athena_rs::drivers::postgresql::sqlx_driver::PostgresClientRegistry;
    use serde_json::Value;

    /// Builds a cache whose entries expire `ttl` after insertion.
    fn ttl_cache(ttl: Duration) -> Arc<Cache<String, Value>> {
        Arc::new(Cache::builder().time_to_live(ttl).build())
    }

    /// Assembles an `AppState` around `cache`, with an empty Postgres
    /// registry, no pipeline registry, and camel-case conversion disabled.
    fn test_state(cache: Arc<Cache<String, Value>>) -> AppState {
        let immortal_cache: Arc<Cache<String, Value>> = Arc::new(Cache::builder().build());
        let client: Client = Client::builder()
            .pool_idle_timeout(Duration::from_secs(90))
            .build()
            .expect("Failed to build HTTP client");

        AppState {
            cache,
            immortal_cache,
            client,
            pg_registry: Arc::new(PostgresClientRegistry::empty()),
            gateway_force_camel_case_to_snake_case: false,
            pipeline_registry: None,
        }
    }

    /// The root endpoint answers 200 with the online banner and crate version.
    #[actix_web::test]
    async fn test_ping_endpoint() {
        let cache = ttl_cache(Duration::from_secs(60));
        // Pre-seed the health entry so the handler takes the cached path
        // instead of loading config.yaml and dialing ScyllaDB from a unit test.
        cache
            .insert("scylla_health".to_string(), Value::Bool(false))
            .await;

        let app_state: AppState = test_state(cache);

        let app = test::init_service(App::new().app_data(Data::new(app_state)).service(ping)).await;
        let req = test::TestRequest::get().uri("/").to_request();
        let resp: ServiceResponse = test::call_service(&app, req).await;

        assert!(resp.status().is_success());

        let body: Value = test::read_body_json(resp).await;
        assert_eq!(body["message"], "athena is online");
        assert_eq!(body["cargo_toml_version"], env!("CARGO_PKG_VERSION"));
        assert_eq!(body["athena_scylladb"], "offline");
    }

    /// A freshly built state starts with both caches empty.
    // The attribute is spelled out in full: a bare `#[test]` on an `async fn`
    // is rejected by the standard test attribute, and inside this module the
    // name `test` is also imported from `actix_web`, which made the original
    // spelling ambiguous to readers.
    #[actix_web::test]
    async fn test_app_state_creation() {
        let app_state: AppState = test_state(ttl_cache(Duration::from_secs(60)));

        assert_eq!(app_state.cache.weighted_size(), 0);
        assert_eq!(app_state.immortal_cache.weighted_size(), 0);
    }

    /// Entries are readable immediately and gone once the TTL has elapsed.
    #[tokio::test]
    async fn test_cache_expiration() {
        let cache = ttl_cache(Duration::from_millis(100));

        let key: String = "expiring_key".to_string();
        let value: Value = json!({"expires": true});

        cache.insert(key.clone(), value.clone()).await;

        // Should exist immediately
        assert!(cache.get(&key).await.is_some());

        // Wait past the 100 ms TTL
        tokio::time::sleep(Duration::from_millis(150)).await;

        // Should be expired now
        assert!(cache.get(&key).await.is_none());
    }
}

/// Route entry for the root API listing (PostgREST-style).
fn api_routes() -> Value {
    json!([
        { "path": "/", "methods": ["GET"], "summary": "API root and route listing" },
        { "path": "/query/sql", "methods": ["POST"], "summary": "Execute SQL using selected driver" },
        { "path": "/gateway/data", "methods": ["POST"], "summary": "Fetch data with conditions (table, columns, conditions)" },
        { "path": "/gateway/fetch", "methods": ["POST"], "summary": "Alias for POST /gateway/data" },
        { "path": "/gateway/update", "methods": ["POST"], "summary": "Fetch/update (same payload as /gateway/data)" },
        { "path": "/data", "methods": ["GET"], "summary": "Fetch data via GET (view, eq_column, eq_value, etc.)" },
        { "path": "/gateway/insert", "methods": ["PUT"], "summary": "Insert a row into a table" },
        { "path": "/gateway/delete", "methods": ["DELETE"], "summary": "Delete a row by table_name and resource_id" },
        { "path": "/gateway/query", "methods": ["POST"], "summary": "Run raw SQL against the selected PostgreSQL pool" },
        { "path": "/pipelines", "methods": ["POST"], "summary": "Run a config-driven pipeline (source → transform → sink)" },
        { "path": "/router/registry", "methods": ["GET"], "summary": "List Athena router registry entries" },
        { "path": "/registry", "methods": ["GET"], "summary": "List API registry entries" },
        { "path": "/registry/{api_registry_id}", "methods": ["GET"], "summary": "Get API registry entry by id" },
        { "path": "/schema/clients", "methods": ["GET"], "summary": "List configured Postgres clients" },
        { "path": "/schema/tables", "methods": ["GET"], "summary": "List tables (requires X-Athena-Client)" },
        { "path": "/schema/columns", "methods": ["GET"], "summary": "List columns for a table (requires X-Athena-Client, query: table_name)" },
        { "path": "/api/v2/supabase/ssl_enforcement", "methods": ["POST"], "summary": "Toggle Supabase SSL enforcement for a project" }
    ])
}

#[get("/")]
async fn ping(data: web::Data<AppState>) -> impl Responder {
    info!("Received health check request");

    let scy_ok: bool = cached_health_status(&data.cache, "scylla_health", || check_scylla()).await;
    let scy_status: &str = if scy_ok { "online" } else { "offline" };
    let cargo_version: &str = env!("CARGO_PKG_VERSION");

    HttpResponse::Ok().json(json!({
        "message": "athena is online",
        "version": cargo_version,
        "athena_api": "online",
        "athena_deadpool": "online",
        "athena_scylladb": scy_status,
        "cargo_toml_version": cargo_version,
        "routes": api_routes()
    }))
}

/// Returns the boolean health status stored under `cache_key`, running
/// `check_fn` and caching its result when no boolean entry is present.
///
/// An entry that exists but is not a JSON boolean is treated like a miss and
/// overwritten with the fresh result.
async fn cached_health_status<F, Fut>(
    cache: &Arc<Cache<String, Value>>,
    cache_key: &str,
    check_fn: F,
) -> bool
where
    F: FnOnce() -> Fut,
    Fut: Future<Output = bool>,
{
    match cache.get(cache_key).await {
        Some(Value::Bool(known)) => known,
        _ => {
            let fresh: bool = check_fn().await;
            cache
                .insert(cache_key.to_owned(), Value::Bool(fresh))
                .await;
            fresh
        }
    }
}

async fn check_scylla() -> bool {
    let config: Config = Config::load().expect("Failed to load config.yaml");
    let uri: String = config
        .get_host("scylladb")
        .map(|s| s.clone())
        .unwrap_or_else(|| "127.0.0.1:9042".to_string());
    let empty_map: HashMap<String, String> = HashMap::new();
    let authenticator: &HashMap<String, String> =
        config.get_authenticator("scylladb").unwrap_or(&empty_map);
    let empty_string: String = String::new();
    let session: Session = match SessionBuilder::new()
        .known_node(uri)
        .user(
            authenticator.get("username").unwrap_or(&empty_string),
            authenticator.get("password").unwrap_or(&empty_string),
        )
        .build()
        .await
    {
        Ok(session) => session,
        Err(_) => return false,
    };
    session.query_unpaged("SELECT 1", &[]).await.is_ok()
}

#[actix_web::main]
async fn main() -> IoResult<()> {
    dotenv().ok();
    init_tracing();
    let config: Config = Config::load().expect("Failed to load config.yaml");
    let port: u16 = config
        .get_api()
        .ok_or("No API port configured")
        .and_then(|port_str| port_str.parse().map_err(|_| "Invalid port number"))
        .expect("Failed to parse API port");

    let cache_ttl: u64 = config
        .get_cache_ttl()
        .ok_or("No cache TTL configured")
        .and_then(|ttl_str| ttl_str.parse().map_err(|_| "Invalid cache TTL"))
        .expect("Failed to parse cache TTL");

    let pool_idle_timeout: u64 = config
        .get_pool_idle_timeout()
        .ok_or("No pool idle timeout configured")
        .and_then(|timeout_str| timeout_str.parse().map_err(|_| "Invalid pool idle timeout"))
        .expect("Failed to parse pool idle timeout");

    let cache: Arc<Cache<String, Value>> = Arc::new(
        Cache::builder()
            .time_to_live(Duration::from_secs(cache_ttl))
            .build(),
    );
    let immortal_cache: Arc<Cache<String, Value>> = Arc::new(Cache::builder().build());
    let client: Client = Client::builder()
        .pool_idle_timeout(Duration::from_secs(pool_idle_timeout))
        .build()
        .expect("Failed to build HTTP client");
    let postgres_entries: Vec<(String, String)> = config
        .postgres_clients
        .iter()
        .flat_map(|map| {
            map.iter()
                .map(|(key, uri)| (key.clone(), resolve_postgres_uri(uri)))
        })
        .collect::<Vec<_>>();

    let (registry, failed_connections) = PostgresClientRegistry::from_entries(postgres_entries)
        .await
        .map_err(|err| {
            tracing::error!(error = %err, "Failed to build Postgres registry");
            Error::new(
                Other,
                format!("failed to initialize Postgres clients: {}", err),
            )
        })?;

    if !failed_connections.is_empty() {
        for (client, err) in &failed_connections {
            tracing::warn!(
                client = %client,
                error = %err,
                "Postgres client unavailable, continuing without it"
            );
        }
    }

    if registry.is_empty() {
        tracing::warn!(
            "No Postgres clients connected; Athena API will start without Postgres support"
        );
    }

    let force_camel_case: bool = config.get_gateway_force_camel_case_to_snake_case();
    let pg_registry: Arc<PostgresClientRegistry> = Arc::new(registry);
    let pipeline_registry = load_registry_from_path("config/pipelines.yaml")
        .ok()
        .map(Arc::new);
    let app_state: Data<AppState> = Data::new(AppState {
        cache,
        immortal_cache,
        client,
        pg_registry,
        gateway_force_camel_case_to_snake_case: force_camel_case,
        pipeline_registry,
    });

    let keep_alive: Duration = parse_secs_or_default(config.get_http_keep_alive_secs(), 15);
    let client_disconnect_timeout =
        parse_secs_or_default(config.get_client_disconnect_timeout_secs(), 60);
    let client_request_timeout =
        parse_secs_or_default(config.get_client_request_timeout_secs(), 60);
    let worker_count: usize = config
        .get_http_workers()
        .and_then(parse_usize)
        .unwrap_or_else(|| available_parallelism().map(|n| n.get()).unwrap_or(4));
    let max_connections: usize = config
        .get_http_max_connections()
        .and_then(parse_usize)
        .unwrap_or(10_000);
    let backlog: usize = config
        .get_http_backlog()
        .and_then(parse_usize)
        .unwrap_or(2_048);
    let tcp_keepalive: Duration = parse_secs_or_default(config.get_tcp_keepalive_secs(), 75);

    let addr: SocketAddr = SocketAddr::from(([0, 0, 0, 0], port));
    let socket: Socket = Socket::new(Domain::IPV4, Type::STREAM, Some(Protocol::TCP))?;
    socket.set_nonblocking(true)?;
    socket.set_keepalive(true)?;
    let keepalive_cfg: TcpKeepalive = TcpKeepalive::new().with_time(tcp_keepalive);
    socket.set_tcp_keepalive(&keepalive_cfg)?;
    socket.bind(&addr.into())?;
    let listen_backlog: i32 = backlog.min(i32::MAX as usize) as i32;
    socket.listen(listen_backlog)?;
    let listener: TcpListener = socket.into();

    HttpServer::new(move || {
        let cors = Cors::default()
            .allow_any_origin()
            .allow_any_method()
            .allow_any_header();
        App::new()
            .wrap(cors)
            .wrap_fn(|req, srv| {
                let fut = srv.call(req);
                async move {
                    let mut res: ServiceResponse<EitherBody<BoxBody>> = fut.await?;
                    res.headers_mut()
                        .insert(header::SERVER, "XYLEX/0".parse().unwrap());
                    Ok(res)
                }
            })
            .app_data(app_state.clone())
            .service(ping)
            .service(sql_query)
            .service(fetch_data_route)
            .service(get_data_route)
            .service(proxy_fetch_data_route)
            .service(gateway_update_route)
            .service(gateway_query_route)
            .service(insert_data)
            .service(delete_data)
            .service(run_pipeline)
            .service(athena_router_registry)
            .service(api_registry)
            .service(api_registry_by_id)
            .configure(schema::services)
            .service(ssl_enforcement)
    })
    .workers(worker_count)
    .keep_alive(keep_alive)
    .client_disconnect_timeout(client_disconnect_timeout)
    .client_request_timeout(client_request_timeout)
    .max_connections(max_connections)
    .backlog(backlog as u32)
    .listen(listener)?
    .run()
    .await
}

/// Installs the global tracing subscriber.
///
/// The log filter comes from the default environment variable when it is set
/// and valid, otherwise it falls back to `info`. Timestamps are local time,
/// formatted as `HH:MM:SS.mmm` and wrapped in ANSI blue.
fn init_tracing() {
    let env_filter: EnvFilter = match EnvFilter::try_from_default_env() {
        Ok(from_env) => from_env,
        Err(_) => EnvFilter::new("info"),
    };
    let timer = ChronoLocal::new(String::from("\x1b[34m%H:%M:%S%.3f\x1b[0m"));

    tracing_subscriber::fmt()
        .with_env_filter(env_filter)
        .with_timer(timer)
        .init();
}