athena_rs 3.9.0

Hyper-performant polyglot database driver
Documentation
use serde_json::json;

use crate::AppState;
use crate::api::client_context::logging_pool;
use crate::athena::contracts::{
    AthenaCapabilities, AthenaClientDefinition, AthenaClientLifecycle, AthenaConnection,
    AthenaEngine, AthenaPostgresConnection,
};
use crate::data::clients::{AthenaClientRecord, get_athena_client_by_name};
use crate::drivers::postgresql::sqlx_driver::RegisteredClient;
use crate::drivers::scylla::client::ScyllaConnectionInfo;

/// Reasons a client name could not be resolved to a query backend.
///
/// Every variant carries the name of the client being resolved; use
/// [`AthenaClientResolveError::client_name`] to read it without matching on the variant.
#[derive(Debug, Clone)]
pub enum AthenaClientResolveError {
    /// The client's catalog record exists but is marked inactive.
    Inactive { client_name: String },
    /// The client's catalog record exists but is frozen.
    Frozen { client_name: String },
    /// The client's catalog metadata could not be parsed; `message` is the parser's error text.
    InvalidMetadata { client_name: String, message: String },
    /// Looking the client up in the catalog failed; `message` is the underlying error text.
    Lookup { client_name: String, message: String },
}

impl AthenaClientResolveError {
    /// Returns the name of the client this error refers to.
    pub fn client_name(&self) -> &str {
        match self {
            Self::Inactive { client_name }
            | Self::Frozen { client_name }
            | Self::InvalidMetadata { client_name, .. }
            | Self::Lookup { client_name, .. } => client_name,
        }
    }
}

/// Human-readable description of the failure, always naming the affected client.
impl std::fmt::Display for AthenaClientResolveError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Inactive { client_name } => {
                write!(f, "athena client '{client_name}' is inactive")
            }
            Self::Frozen { client_name } => {
                write!(f, "athena client '{client_name}' is frozen")
            }
            Self::InvalidMetadata {
                client_name,
                message,
            } => write!(
                f,
                "athena client '{client_name}' has invalid metadata: {message}"
            ),
            Self::Lookup {
                client_name,
                message,
            } => write!(f, "failed to look up athena client '{client_name}': {message}"),
        }
    }
}

// The underlying causes are only kept as strings, so there is no `source()` to expose;
// implementing the trait lets callers box this error or propagate it with `?`.
impl std::error::Error for AthenaClientResolveError {}

/// The backend selected for a client name, together with its engine-neutral definition.
#[derive(Debug, Clone)]
pub enum AthenaResolvedQueryBackend {
    /// A Postgres client definition.
    Postgres(AthenaClientDefinition),
    /// A Scylla client definition plus the connection details needed to reach it.
    Scylla {
        /// Engine-neutral description of the client.
        client: AthenaClientDefinition,
        /// Scylla connection details (in this file, parsed from catalog metadata).
        connection_info: ScyllaConnectionInfo,
    },
}

impl AthenaResolvedQueryBackend {
    /// Borrows the client definition, regardless of which engine serves it.
    pub fn client(&self) -> &AthenaClientDefinition {
        match self {
            Self::Postgres(definition) | Self::Scylla { client: definition, .. } => definition,
        }
    }

    /// Borrows the Scylla connection details; Postgres backends have none.
    pub fn scylla_connection(&self) -> Option<&ScyllaConnectionInfo> {
        match self {
            Self::Postgres(_) => None,
            Self::Scylla {
                connection_info, ..
            } => Some(connection_info),
        }
    }
}

/// Resolves `client_name` to the backend that should serve its queries.
///
/// The name is trimmed first, and an empty name resolves to `Ok(None)`. Lookup order:
///
/// 1. A client known to `state.pg_registry` resolves to
///    [`AthenaResolvedQueryBackend::Postgres`]. Its `is_active`/`is_frozen` flags are copied
///    into the definition's lifecycle but are **not** enforced here.
/// 2. Otherwise the client catalog is queried through the logging pool. An active, unfrozen
///    record whose metadata yields Scylla connection info resolves to
///    [`AthenaResolvedQueryBackend::Scylla`].
///
/// `Ok(None)` is also returned in three cases:
///
/// * `logging_pool` fails;
/// * the catalog has no record with this name;
/// * the record's metadata yields no Scylla connection info.
///
/// # Errors
///
/// * [`AthenaClientResolveError::Lookup`]: the catalog query failed.
/// * [`AthenaClientResolveError::Inactive`]: the catalog record is inactive (checked first).
/// * [`AthenaClientResolveError::Frozen`]: the catalog record is active but frozen.
/// * [`AthenaClientResolveError::InvalidMetadata`]: the record's metadata failed to parse.
///
/// NOTE(review): inactive or frozen registry clients are returned, not rejected, unlike
/// catalog clients. Confirm that callers check `lifecycle` on the Postgres path.
pub async fn resolve_query_backend(
    state: &AppState,
    client_name: &str,
) -> Result<Option<AthenaResolvedQueryBackend>, AthenaClientResolveError> {
    let name = client_name.trim();
    if name.is_empty() {
        return Ok(None);
    }

    // Registry-backed Postgres clients take precedence over anything in the catalog.
    if let Some(registered) = state.pg_registry.registered_client(name) {
        let definition = definition_from_registered_postgres_client(registered);
        return Ok(Some(AthenaResolvedQueryBackend::Postgres(definition)));
    }

    // Without a logging pool there is no catalog to consult, so the client is treated as unknown.
    let Ok(pool) = logging_pool(state) else {
        return Ok(None);
    };

    let found = get_athena_client_by_name(&pool, name)
        .await
        .map_err(|err| AthenaClientResolveError::Lookup {
            client_name: name.to_string(),
            message: err.to_string(),
        })?;
    let Some(record) = found else {
        return Ok(None);
    };

    // Only usable catalog clients resolve. Inactivity is reported ahead of frozenness.
    match (record.is_active, record.is_frozen) {
        (false, _) => {
            return Err(AthenaClientResolveError::Inactive {
                client_name: record.client_name,
            });
        }
        (true, true) => {
            return Err(AthenaClientResolveError::Frozen {
                client_name: record.client_name,
            });
        }
        (true, false) => {}
    }

    let connection = ScyllaConnectionInfo::from_metadata(&record.metadata).map_err(|message| {
        AthenaClientResolveError::InvalidMetadata {
            client_name: record.client_name.clone(),
            message,
        }
    })?;

    // Metadata without Scylla connection info is not served by this resolver.
    Ok(connection.map(|connection_info| AthenaResolvedQueryBackend::Scylla {
        client: definition_from_catalog_scylla_client(&record, connection_info.clone()),
        connection_info,
    }))
}

fn definition_from_registered_postgres_client(
    registered: RegisteredClient,
) -> AthenaClientDefinition {
    let lifecycle = lifecycle_from_flags(registered.is_active, registered.is_frozen);
    AthenaClientDefinition {
        client_name: registered.client_name,
        description: registered.description,
        source: registered.source,
        engine: AthenaEngine::Postgres,
        lifecycle,
        capabilities: AthenaCapabilities::for_engine(AthenaEngine::Postgres),
        connection: AthenaConnection::Postgres(AthenaPostgresConnection {
            pg_uri: registered.pg_uri,
            pg_uri_env_var: registered.pg_uri_env_var,
            config_uri_template: registered.config_uri_template,
        }),
        metadata: json!({
            "pool_connected": registered.pool_connected,
        }),
    }
}

fn definition_from_catalog_scylla_client(
    record: &AthenaClientRecord,
    info: ScyllaConnectionInfo,
) -> AthenaClientDefinition {
    AthenaClientDefinition {
        client_name: record.client_name.clone(),
        description: record.description.clone(),
        source: record.source.clone(),
        engine: AthenaEngine::Scylla,
        lifecycle: lifecycle_from_flags(record.is_active, record.is_frozen),
        capabilities: AthenaCapabilities::for_engine(AthenaEngine::Scylla),
        connection: AthenaConnection::Scylla(info),
        metadata: record.metadata.clone(),
    }
}

/// Maps a client's activity and freeze flags onto its lifecycle state.
///
/// An inactive client is reported as `Inactive` whether or not it is also frozen. Only
/// active clients can be `Frozen` or `Active`.
fn lifecycle_from_flags(is_active: bool, is_frozen: bool) -> AthenaClientLifecycle {
    match (is_active, is_frozen) {
        (false, _) => AthenaClientLifecycle::Inactive,
        (true, true) => AthenaClientLifecycle::Frozen,
        (true, false) => AthenaClientLifecycle::Active,
    }
}

/// Unit tests for client-definition construction, lifecycle mapping and resolver accessors.
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;

    /// Registry entry for the `reporting` Postgres client with the given lifecycle flags.
    fn registered_reporting_client(is_active: bool, is_frozen: bool) -> RegisteredClient {
        RegisteredClient {
            client_name: "reporting".to_string(),
            source: "config".to_string(),
            description: Some("Reporting".to_string()),
            pg_uri: Some("postgres://localhost/reporting".to_string()),
            pg_uri_env_var: None,
            config_uri_template: None,
            is_active,
            is_frozen,
            pool_connected: true,
        }
    }

    /// Active, unfrozen catalog record whose metadata describes a Scylla connection.
    fn scylla_catalog_record() -> AthenaClientRecord {
        AthenaClientRecord {
            id: "id".to_string(),
            client_name: "events".to_string(),
            description: Some("Events".to_string()),
            pg_uri: None,
            pg_uri_env_var: None,
            config_uri_template: None,
            source: "database".to_string(),
            is_active: true,
            is_frozen: false,
            last_synced_from_config_at: None,
            last_seen_at: None,
            metadata: json!({
                "dbEngine": "scylladb",
                "scyllaHosts": ["10.0.0.10:9042"],
                "scyllaKeyspace": "events"
            }),
            created_at: Utc::now(),
            updated_at: Utc::now(),
            deleted_at: None,
        }
    }

    /// Parses the Scylla connection info out of a fixture record, failing the test otherwise.
    fn parse_scylla_info(record: &AthenaClientRecord) -> ScyllaConnectionInfo {
        ScyllaConnectionInfo::from_metadata(&record.metadata)
            .expect("metadata should parse")
            .expect("expected Scylla info")
    }

    #[test]
    fn postgres_capabilities_enable_sql_and_deferred_execution() {
        let caps = AthenaCapabilities::for_engine(AthenaEngine::Postgres);
        assert!(caps.raw_sql);
        assert!(caps.gateway_query);
        assert!(caps.deferred_query);
        assert!(!caps.raw_cql);
    }

    #[test]
    fn scylla_capabilities_enable_cql_but_not_deferred_execution() {
        let caps = AthenaCapabilities::for_engine(AthenaEngine::Scylla);
        assert!(caps.raw_cql);
        assert!(caps.gateway_query);
        assert!(!caps.raw_sql);
        assert!(!caps.deferred_query);
    }

    #[test]
    fn lifecycle_reports_inactive_before_frozen() {
        assert_eq!(lifecycle_from_flags(true, false), AthenaClientLifecycle::Active);
        assert_eq!(lifecycle_from_flags(true, true), AthenaClientLifecycle::Frozen);
        assert_eq!(lifecycle_from_flags(false, false), AthenaClientLifecycle::Inactive);
        assert_eq!(lifecycle_from_flags(false, true), AthenaClientLifecycle::Inactive);
    }

    #[test]
    fn resolve_errors_expose_the_client_name() {
        let errors = [
            AthenaClientResolveError::Inactive {
                client_name: "events".to_string(),
            },
            AthenaClientResolveError::Frozen {
                client_name: "events".to_string(),
            },
            AthenaClientResolveError::InvalidMetadata {
                client_name: "events".to_string(),
                message: "bad hosts".to_string(),
            },
            AthenaClientResolveError::Lookup {
                client_name: "events".to_string(),
                message: "connection refused".to_string(),
            },
        ];
        for err in &errors {
            assert_eq!(err.client_name(), "events");
        }
    }

    #[test]
    fn registered_postgres_clients_become_athena_postgres_definitions() {
        let client = definition_from_registered_postgres_client(registered_reporting_client(
            true, false,
        ));

        assert_eq!(client.engine, AthenaEngine::Postgres);
        assert_eq!(client.lifecycle, AthenaClientLifecycle::Active);
        assert!(matches!(client.connection, AthenaConnection::Postgres(_)));
        assert_eq!(client.metadata, json!({ "pool_connected": true }));
    }

    #[test]
    fn frozen_registered_postgres_clients_keep_frozen_lifecycle() {
        let client = definition_from_registered_postgres_client(registered_reporting_client(
            true, true,
        ));

        assert_eq!(client.lifecycle, AthenaClientLifecycle::Frozen);
    }

    #[test]
    fn scylla_catalog_clients_become_athena_scylla_definitions() {
        let record = scylla_catalog_record();
        let info = parse_scylla_info(&record);
        let client = definition_from_catalog_scylla_client(&record, info);

        assert_eq!(client.engine, AthenaEngine::Scylla);
        assert_eq!(client.lifecycle, AthenaClientLifecycle::Active);
        assert!(matches!(client.connection, AthenaConnection::Scylla(_)));
        assert_eq!(client.metadata, record.metadata);
    }

    #[test]
    fn resolved_backend_exposes_scylla_connection_only_for_scylla() {
        let record = scylla_catalog_record();
        let info = parse_scylla_info(&record);
        let scylla = AthenaResolvedQueryBackend::Scylla {
            client: definition_from_catalog_scylla_client(&record, info.clone()),
            connection_info: info,
        };
        assert_eq!(scylla.client().client_name, "events");
        assert!(scylla.scylla_connection().is_some());

        let postgres = AthenaResolvedQueryBackend::Postgres(
            definition_from_registered_postgres_client(registered_reporting_client(true, false)),
        );
        assert_eq!(postgres.client().client_name, "reporting");
        assert!(postgres.scylla_connection().is_none());
    }
}