athena_rs 3.26.1

Hyper-performant polyglot database driver
Documentation
use serde_json::json;

use crate::AppState;
use crate::api::client_context::logging_pool;
use crate::athena::contracts::{
    AthenaCapabilities, AthenaClientDefinition, AthenaClientLifecycle, AthenaConnection,
    AthenaD1Connection, AthenaEngine, AthenaPostgresConnection,
};
use crate::data::clients::{AthenaClientRecord, get_athena_client_by_name};
use crate::drivers::cloudflare_d1::client::D1ConnectionInfo;
use crate::drivers::postgresql::sqlx_driver::RegisteredClient;
use crate::drivers::scylla::client::ScyllaConnectionInfo;

/// Reasons [`resolve_query_backend`] refuses to resolve a client.
///
/// Every variant records the (trimmed) name of the client being resolved.
#[derive(Clone, Debug)]
pub enum AthenaClientResolveError {
    /// The catalog record exists but its `is_active` flag is false.
    Inactive { client_name: String },
    /// The catalog record is active but its `is_frozen` flag is true.
    Frozen { client_name: String },
    /// Parsing the record's metadata as D1 or Scylla connection info failed;
    /// `message` is the parser's error text.
    InvalidMetadata { client_name: String, message: String },
    /// The catalog lookup itself failed; `message` is the lookup error's text.
    Lookup { client_name: String, message: String },
}

impl AthenaClientResolveError {
    /// Returns the name of the client this error refers to.
    ///
    /// Every variant carries a client name, so this is always available.
    pub fn client_name(&self) -> &str {
        let name: &String = match self {
            Self::Inactive { client_name } => client_name,
            Self::Frozen { client_name } => client_name,
            Self::InvalidMetadata { client_name, .. } => client_name,
            Self::Lookup { client_name, .. } => client_name,
        };
        name.as_str()
    }
}

/// The backend chosen by [`resolve_query_backend`] for a client.
///
/// Every variant carries the client's [`AthenaClientDefinition`]. The D1 and
/// Scylla variants also carry the connection info parsed from the catalog
/// record's metadata.
#[derive(Clone, Debug)]
pub enum AthenaResolvedQueryBackend {
    /// A client found in the in-process Postgres registry.
    Postgres(AthenaClientDefinition),
    /// A catalog client whose metadata parsed as Cloudflare D1 connection info.
    D1 {
        client: AthenaClientDefinition,
        connection_info: D1ConnectionInfo,
    },
    /// A catalog client whose metadata parsed as Scylla connection info.
    Scylla {
        client: AthenaClientDefinition,
        connection_info: ScyllaConnectionInfo,
    },
}

impl AthenaResolvedQueryBackend {
    /// Returns the client definition, which every variant carries.
    pub fn client(&self) -> &AthenaClientDefinition {
        // Irrefutable or-pattern: adding a variant without a `client` forces a compile error here.
        let (Self::Postgres(client) | Self::D1 { client, .. } | Self::Scylla { client, .. }) =
            self;
        client
    }

    /// Returns the D1 connection info, or `None` for Postgres and Scylla backends.
    pub fn d1_connection(&self) -> Option<&D1ConnectionInfo> {
        match self {
            Self::Postgres(_) | Self::Scylla { .. } => None,
            Self::D1 {
                connection_info, ..
            } => Some(connection_info),
        }
    }

    /// Returns the Scylla connection info, or `None` for Postgres and D1 backends.
    pub fn scylla_connection(&self) -> Option<&ScyllaConnectionInfo> {
        match self {
            Self::Postgres(_) | Self::D1 { .. } => None,
            Self::Scylla {
                connection_info, ..
            } => Some(connection_info),
        }
    }
}

/// Resolves `client_name` to the backend that should serve its queries.
///
/// The name is trimmed first; an empty name resolves to `Ok(None)`.
/// Resolution order:
/// 1. A client registered in `state.pg_registry` resolves to
///    [`AthenaResolvedQueryBackend::Postgres`].
/// 2. Otherwise the client catalog is queried through the pool returned by
///    [`logging_pool`]. An active, unfrozen record resolves to
///    [`AthenaResolvedQueryBackend::D1`] if its metadata parses as D1 connection
///    info, else to [`AthenaResolvedQueryBackend::Scylla`] if it parses as
///    Scylla connection info.
///
/// Returns `Ok(None)` when the catalog pool cannot be obtained, when no catalog
/// record matches, or when the record's metadata is neither D1 nor Scylla.
///
/// # Errors
///
/// - [`AthenaClientResolveError::Lookup`] if the catalog query fails.
/// - [`AthenaClientResolveError::Inactive`] if the record is inactive (checked
///   before frozen).
/// - [`AthenaClientResolveError::Frozen`] if the record is active but frozen.
/// - [`AthenaClientResolveError::InvalidMetadata`] if D1 or Scylla metadata
///   parsing reports an error.
pub async fn resolve_query_backend(
    state: &AppState,
    client_name: &str,
) -> Result<Option<AthenaResolvedQueryBackend>, AthenaClientResolveError> {
    let name = client_name.trim();
    if name.is_empty() {
        return Ok(None);
    }

    // Registry clients take precedence over anything in the catalog.
    if let Some(registered) = state.pg_registry.registered_client(name) {
        let definition = definition_from_registered_postgres_client(registered);
        return Ok(Some(AthenaResolvedQueryBackend::Postgres(definition)));
    }

    // Without a catalog pool the client is reported as unknown, not as an error.
    let Ok(pool) = logging_pool(state) else {
        return Ok(None);
    };

    let record: AthenaClientRecord = match get_athena_client_by_name(&pool, name).await {
        Ok(Some(found)) => found,
        Ok(None) => return Ok(None),
        Err(err) => {
            return Err(AthenaClientResolveError::Lookup {
                client_name: name.to_string(),
                message: err.to_string(),
            });
        }
    };

    // Inactive wins over frozen when both flags disqualify the record.
    match (record.is_active, record.is_frozen) {
        (false, _) => {
            return Err(AthenaClientResolveError::Inactive {
                client_name: record.client_name,
            });
        }
        (true, true) => {
            return Err(AthenaClientResolveError::Frozen {
                client_name: record.client_name,
            });
        }
        (true, false) => {}
    }

    let invalid_metadata = |message: String| AthenaClientResolveError::InvalidMetadata {
        client_name: record.client_name.clone(),
        message,
    };

    // D1 metadata is probed before Scylla metadata.
    if let Some(d1) = D1ConnectionInfo::from_metadata(&record.metadata).map_err(&invalid_metadata)? {
        return Ok(Some(AthenaResolvedQueryBackend::D1 {
            client: definition_from_catalog_d1_client(&record, d1.clone()),
            connection_info: d1,
        }));
    }

    let scylla =
        ScyllaConnectionInfo::from_metadata(&record.metadata).map_err(&invalid_metadata)?;
    Ok(scylla.map(|info| AthenaResolvedQueryBackend::Scylla {
        client: definition_from_catalog_scylla_client(&record, info.clone()),
        connection_info: info,
    }))
}

fn definition_from_registered_postgres_client(
    registered: RegisteredClient,
) -> AthenaClientDefinition {
    let lifecycle: AthenaClientLifecycle =
        lifecycle_from_flags(registered.is_active, registered.is_frozen);
    AthenaClientDefinition {
        client_name: registered.client_name,
        description: registered.description,
        source: registered.source,
        engine: AthenaEngine::Postgres,
        lifecycle,
        capabilities: AthenaCapabilities::for_engine(AthenaEngine::Postgres),
        connection: AthenaConnection::Postgres(AthenaPostgresConnection {
            pg_uri: registered.pg_uri,
            pg_uri_env_var: registered.pg_uri_env_var,
            config_uri_template: registered.config_uri_template,
        }),
        metadata: json!({
            "pool_connected": registered.pool_connected,
        }),
    }
}

/// Builds the [`AthenaClientDefinition`] for a catalog record whose metadata
/// parsed as Scylla connection info.
///
/// The name, description, source and raw metadata are cloned from `record`.
/// The lifecycle is derived from its `is_active` / `is_frozen` flags.
fn definition_from_catalog_scylla_client(
    record: &AthenaClientRecord,
    info: ScyllaConnectionInfo,
) -> AthenaClientDefinition {
    let lifecycle = lifecycle_from_flags(record.is_active, record.is_frozen);
    AthenaClientDefinition {
        engine: AthenaEngine::Scylla,
        capabilities: AthenaCapabilities::for_engine(AthenaEngine::Scylla),
        connection: AthenaConnection::Scylla(info),
        lifecycle,
        client_name: record.client_name.clone(),
        source: record.source.clone(),
        description: record.description.clone(),
        metadata: record.metadata.clone(),
    }
}

/// Builds the [`AthenaClientDefinition`] for a catalog record whose metadata
/// parsed as Cloudflare D1 connection info.
///
/// `info` is converted into an [`AthenaD1Connection`]. The name, description,
/// source and raw metadata are cloned from `record`, and the lifecycle is
/// derived from its `is_active` / `is_frozen` flags.
fn definition_from_catalog_d1_client(
    record: &AthenaClientRecord,
    info: D1ConnectionInfo,
) -> AthenaClientDefinition {
    let lifecycle = lifecycle_from_flags(record.is_active, record.is_frozen);
    let connection = AthenaConnection::D1(AthenaD1Connection::from(info));
    AthenaClientDefinition {
        engine: AthenaEngine::D1,
        capabilities: AthenaCapabilities::for_engine(AthenaEngine::D1),
        connection,
        lifecycle,
        client_name: record.client_name.clone(),
        source: record.source.clone(),
        description: record.description.clone(),
        metadata: record.metadata.clone(),
    }
}

/// Maps the `is_active` / `is_frozen` flag pair to a lifecycle state.
///
/// An inactive client is `Inactive` regardless of `is_frozen`. An active
/// client is `Frozen` when `is_frozen` is set and `Active` otherwise.
fn lifecycle_from_flags(is_active: bool, is_frozen: bool) -> AthenaClientLifecycle {
    match (is_active, is_frozen) {
        (false, _) => AthenaClientLifecycle::Inactive,
        (true, true) => AthenaClientLifecycle::Frozen,
        (true, false) => AthenaClientLifecycle::Active,
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;

    /// Builds an active, unfrozen catalog record with the given name,
    /// description and metadata; every other field gets a fixed placeholder.
    fn catalog_record(
        client_name: &str,
        description: &str,
        metadata: serde_json::Value,
    ) -> AthenaClientRecord {
        AthenaClientRecord {
            id: "id".to_string(),
            client_name: client_name.to_string(),
            description: Some(description.to_string()),
            pg_uri: None,
            pg_uri_env_var: None,
            config_uri_template: None,
            source: "database".to_string(),
            is_active: true,
            is_frozen: false,
            last_synced_from_config_at: None,
            last_seen_at: None,
            metadata,
            created_at: Utc::now(),
            updated_at: Utc::now(),
            deleted_at: None,
        }
    }

    /// Metadata that parses as Cloudflare D1 connection info.
    fn d1_metadata() -> serde_json::Value {
        json!({
            "dbEngine": "cloudflare-d1",
            "cloudflareD1": {
                "worker_base_url": "https://athena-cloudflare-d1-proxy.xylex-group.workers.dev",
                "auth_token_env_var": "ATHENA_D1_PROXY_TOKEN",
                "database_binding": "DB"
            }
        })
    }

    #[test]
    fn postgres_capabilities_enable_sql_and_deferred_execution() {
        let caps: AthenaCapabilities = AthenaCapabilities::for_engine(AthenaEngine::Postgres);
        assert!(caps.raw_sql);
        assert!(caps.gateway_query);
        assert!(caps.deferred_query);
        assert!(!caps.raw_cql);
    }

    #[test]
    fn scylla_capabilities_enable_cql_but_not_deferred_execution() {
        let caps: AthenaCapabilities = AthenaCapabilities::for_engine(AthenaEngine::Scylla);
        assert!(caps.raw_cql);
        assert!(caps.gateway_query);
        assert!(!caps.raw_sql);
        assert!(!caps.deferred_query);
    }

    #[test]
    fn d1_capabilities_enable_gateway_sql_without_cql() {
        let caps: AthenaCapabilities = AthenaCapabilities::for_engine(AthenaEngine::D1);
        assert!(caps.raw_sql);
        assert!(caps.gateway_query);
        assert!(!caps.raw_cql);
        assert!(!caps.deferred_query);
    }

    #[test]
    fn lifecycle_flags_give_inactive_precedence_over_frozen() {
        assert_eq!(lifecycle_from_flags(true, false), AthenaClientLifecycle::Active);
        assert_eq!(lifecycle_from_flags(true, true), AthenaClientLifecycle::Frozen);
        assert_eq!(lifecycle_from_flags(false, false), AthenaClientLifecycle::Inactive);
        assert_eq!(lifecycle_from_flags(false, true), AthenaClientLifecycle::Inactive);
    }

    #[test]
    fn resolve_errors_report_the_client_name() {
        let errors = [
            AthenaClientResolveError::Inactive {
                client_name: "a".to_string(),
            },
            AthenaClientResolveError::Frozen {
                client_name: "a".to_string(),
            },
            AthenaClientResolveError::InvalidMetadata {
                client_name: "a".to_string(),
                message: "bad metadata".to_string(),
            },
            AthenaClientResolveError::Lookup {
                client_name: "a".to_string(),
                message: "lookup failed".to_string(),
            },
        ];
        for error in &errors {
            assert_eq!(error.client_name(), "a");
        }
    }

    #[test]
    fn registered_postgres_clients_become_athena_postgres_definitions() {
        let client: AthenaClientDefinition =
            definition_from_registered_postgres_client(RegisteredClient {
                client_name: "reporting".to_string(),
                source: "config".to_string(),
                description: Some("Reporting".to_string()),
                pg_uri: Some("postgres://localhost/reporting".to_string()),
                pg_uri_env_var: None,
                config_uri_template: None,
                is_active: true,
                is_frozen: false,
                pool_connected: true,
            });

        assert_eq!(client.engine, AthenaEngine::Postgres);
        assert_eq!(client.lifecycle, AthenaClientLifecycle::Active);
        assert!(matches!(client.connection, AthenaConnection::Postgres(_)));
    }

    #[test]
    fn scylla_catalog_clients_become_athena_scylla_definitions() {
        let record: AthenaClientRecord = catalog_record(
            "events",
            "Events",
            json!({
                "dbEngine": "scylladb",
                "scyllaHosts": ["10.0.0.10:9042"],
                "scyllaKeyspace": "events"
            }),
        );
        let info: ScyllaConnectionInfo = ScyllaConnectionInfo::from_metadata(&record.metadata)
            .expect("metadata should parse")
            .expect("expected Scylla info");
        let client: AthenaClientDefinition = definition_from_catalog_scylla_client(&record, info);

        assert_eq!(client.engine, AthenaEngine::Scylla);
        assert_eq!(client.lifecycle, AthenaClientLifecycle::Active);
        assert!(matches!(client.connection, AthenaConnection::Scylla(_)));
    }

    #[test]
    fn d1_catalog_clients_become_athena_d1_definitions() {
        let record: AthenaClientRecord = catalog_record("analytics", "Analytics", d1_metadata());
        let info = D1ConnectionInfo::from_metadata(&record.metadata)
            .expect("metadata should parse")
            .expect("expected D1 info");
        let client = definition_from_catalog_d1_client(&record, info);

        assert_eq!(client.engine, AthenaEngine::D1);
        assert_eq!(client.lifecycle, AthenaClientLifecycle::Active);
        assert!(matches!(client.connection, AthenaConnection::D1(_)));
    }

    #[test]
    fn resolved_backend_accessors_expose_only_the_matching_connection() {
        let record: AthenaClientRecord = catalog_record("analytics", "Analytics", d1_metadata());
        let info = D1ConnectionInfo::from_metadata(&record.metadata)
            .expect("metadata should parse")
            .expect("expected D1 info");
        let backend = AthenaResolvedQueryBackend::D1 {
            client: definition_from_catalog_d1_client(&record, info.clone()),
            connection_info: info,
        };

        assert_eq!(backend.client().client_name, "analytics");
        assert_eq!(backend.client().engine, AthenaEngine::D1);
        assert!(backend.d1_connection().is_some());
        assert!(backend.scylla_connection().is_none());
    }
}