use serde_json::json;
use crate::AppState;
use crate::api::client_context::logging_pool;
use crate::athena::contracts::{
AthenaCapabilities, AthenaClientDefinition, AthenaClientLifecycle, AthenaConnection,
AthenaD1Connection, AthenaEngine, AthenaPostgresConnection,
};
use crate::data::clients::{AthenaClientRecord, get_athena_client_by_name};
use crate::drivers::cloudflare_d1::client::D1ConnectionInfo;
use crate::drivers::postgresql::sqlx_driver::RegisteredClient;
use crate::drivers::scylla::client::ScyllaConnectionInfo;
/// Reasons why a client name could not be turned into a usable query backend.
///
/// Every variant carries the name of the client it refers to; see
/// [`AthenaClientResolveError::client_name`].
#[derive(Debug, Clone)]
pub enum AthenaClientResolveError {
    /// The catalog record exists but is marked inactive.
    Inactive { client_name: String },
    /// The catalog record exists but is marked frozen.
    Frozen { client_name: String },
    /// The record's metadata could not be parsed into connection info;
    /// `message` is the parser's error text.
    InvalidMetadata { client_name: String, message: String },
    /// Looking the client up in the catalog failed; `message` is the
    /// underlying error rendered as text.
    Lookup { client_name: String, message: String },
}

impl AthenaClientResolveError {
    /// Returns the name of the client this error refers to.
    pub fn client_name(&self) -> &str {
        match self {
            Self::Inactive { client_name }
            | Self::Frozen { client_name }
            | Self::InvalidMetadata { client_name, .. }
            | Self::Lookup { client_name, .. } => client_name,
        }
    }
}

// A human-readable rendering, so the error can be logged, shown to API callers,
// or propagated through `Box<dyn Error>` / `?` into generic error types.
impl std::fmt::Display for AthenaClientResolveError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Inactive { client_name } => {
                write!(f, "athena client '{client_name}' is inactive")
            }
            Self::Frozen { client_name } => {
                write!(f, "athena client '{client_name}' is frozen")
            }
            Self::InvalidMetadata {
                client_name,
                message,
            } => write!(
                f,
                "athena client '{client_name}' has invalid metadata: {message}"
            ),
            Self::Lookup {
                client_name,
                message,
            } => write!(
                f,
                "failed to look up athena client '{client_name}': {message}"
            ),
        }
    }
}

impl std::error::Error for AthenaClientResolveError {}
/// The backend chosen for a client, together with its [`AthenaClientDefinition`].
///
/// The `D1` and `Scylla` variants also carry the connection info parsed from the
/// client's catalog metadata; the `Postgres` variant carries only the definition.
#[derive(Debug, Clone)]
pub enum AthenaResolvedQueryBackend {
    /// A Postgres client definition.
    Postgres(AthenaClientDefinition),
    /// A Cloudflare D1 client and its parsed connection info.
    D1 {
        client: AthenaClientDefinition,
        connection_info: D1ConnectionInfo,
    },
    /// A Scylla client and its parsed connection info.
    Scylla {
        client: AthenaClientDefinition,
        connection_info: ScyllaConnectionInfo,
    },
}

impl AthenaResolvedQueryBackend {
    /// Returns the client definition, regardless of which backend was chosen.
    pub fn client(&self) -> &AthenaClientDefinition {
        match self {
            Self::Postgres(client) | Self::D1 { client, .. } | Self::Scylla { client, .. } => {
                client
            }
        }
    }

    /// Returns the D1 connection info, or `None` for non-D1 backends.
    pub fn d1_connection(&self) -> Option<&D1ConnectionInfo> {
        match self {
            Self::D1 {
                connection_info: d1_info,
                ..
            } => Some(d1_info),
            Self::Postgres(..) | Self::Scylla { .. } => None,
        }
    }

    /// Returns the Scylla connection info, or `None` for non-Scylla backends.
    pub fn scylla_connection(&self) -> Option<&ScyllaConnectionInfo> {
        match self {
            Self::Scylla {
                connection_info: scylla_info,
                ..
            } => Some(scylla_info),
            Self::Postgres(..) | Self::D1 { .. } => None,
        }
    }
}
/// Resolves `client_name` to the backend that should serve its queries.
///
/// The name is trimmed first; an empty name resolves to `Ok(None)`. Resolution
/// then proceeds in this order:
///
/// 1. `state.pg_registry`: a registered client is returned immediately as
///    [`AthenaResolvedQueryBackend::Postgres`] without consulting the catalog.
/// 2. The client catalog reached through [`logging_pool`]. If the pool is not
///    available, or the catalog has no record with this name, the result is
///    `Ok(None)`.
/// 3. The record's metadata is parsed as D1 connection info first and as
///    Scylla connection info second. If neither yields connection info, the
///    result is `Ok(None)`.
///
/// # Errors
///
/// * [`AthenaClientResolveError::Lookup`] if the catalog query fails.
/// * [`AthenaClientResolveError::Inactive`] if the record is inactive (this is
///   checked before the frozen flag).
/// * [`AthenaClientResolveError::Frozen`] if the record is frozen.
/// * [`AthenaClientResolveError::InvalidMetadata`] if the metadata cannot be
///   parsed as D1 or Scylla connection info.
///
/// NOTE(review): clients found in the Postgres registry skip the inactive/frozen
/// rejection applied to catalog records; their state only shows up in the
/// returned definition's `lifecycle`. Confirm callers enforce it.
pub async fn resolve_query_backend(
    state: &AppState,
    client_name: &str,
) -> Result<Option<AthenaResolvedQueryBackend>, AthenaClientResolveError> {
    let name = client_name.trim();
    if name.is_empty() {
        return Ok(None);
    }

    if let Some(registered) = state.pg_registry.registered_client(name) {
        let definition = definition_from_registered_postgres_client(registered);
        return Ok(Some(AthenaResolvedQueryBackend::Postgres(definition)));
    }

    // Without a logging pool there is no catalog to consult, so the client is
    // simply unknown rather than an error.
    let Ok(pool) = logging_pool(state) else {
        return Ok(None);
    };

    let lookup = get_athena_client_by_name(&pool, name).await;
    let Some(record) = lookup.map_err(|err| AthenaClientResolveError::Lookup {
        client_name: name.to_string(),
        message: err.to_string(),
    })?
    else {
        return Ok(None);
    };

    // Inactive takes precedence over frozen.
    match (record.is_active, record.is_frozen) {
        (false, _) => {
            return Err(AthenaClientResolveError::Inactive {
                client_name: record.client_name,
            });
        }
        (true, true) => {
            return Err(AthenaClientResolveError::Frozen {
                client_name: record.client_name,
            });
        }
        (true, false) => {}
    }

    // Shared error builder for both metadata parsers.
    let invalid_metadata = |message: String| AthenaClientResolveError::InvalidMetadata {
        client_name: record.client_name.clone(),
        message,
    };

    let d1 = D1ConnectionInfo::from_metadata(&record.metadata).map_err(invalid_metadata)?;
    if let Some(connection_info) = d1 {
        return Ok(Some(AthenaResolvedQueryBackend::D1 {
            client: definition_from_catalog_d1_client(&record, connection_info.clone()),
            connection_info,
        }));
    }

    let scylla =
        ScyllaConnectionInfo::from_metadata(&record.metadata).map_err(invalid_metadata)?;
    Ok(scylla.map(|connection_info| AthenaResolvedQueryBackend::Scylla {
        client: definition_from_catalog_scylla_client(&record, connection_info.clone()),
        connection_info,
    }))
}
fn definition_from_registered_postgres_client(
registered: RegisteredClient,
) -> AthenaClientDefinition {
let lifecycle: AthenaClientLifecycle =
lifecycle_from_flags(registered.is_active, registered.is_frozen);
AthenaClientDefinition {
client_name: registered.client_name,
description: registered.description,
source: registered.source,
engine: AthenaEngine::Postgres,
lifecycle,
capabilities: AthenaCapabilities::for_engine(AthenaEngine::Postgres),
connection: AthenaConnection::Postgres(AthenaPostgresConnection {
pg_uri: registered.pg_uri,
pg_uri_env_var: registered.pg_uri_env_var,
config_uri_template: registered.config_uri_template,
}),
metadata: json!({
"pool_connected": registered.pool_connected,
}),
}
}
fn definition_from_catalog_scylla_client(
record: &AthenaClientRecord,
info: ScyllaConnectionInfo,
) -> AthenaClientDefinition {
AthenaClientDefinition {
client_name: record.client_name.clone(),
description: record.description.clone(),
source: record.source.clone(),
engine: AthenaEngine::Scylla,
lifecycle: lifecycle_from_flags(record.is_active, record.is_frozen),
capabilities: AthenaCapabilities::for_engine(AthenaEngine::Scylla),
connection: AthenaConnection::Scylla(info),
metadata: record.metadata.clone(),
}
}
/// Builds a Cloudflare D1 [`AthenaClientDefinition`] from a catalog record and
/// the D1 connection info already parsed from that record's metadata.
///
/// The connection info is converted into an [`AthenaD1Connection`]; identity
/// fields and the raw `metadata` JSON are cloned from `record`.
fn definition_from_catalog_d1_client(
    record: &AthenaClientRecord,
    info: D1ConnectionInfo,
) -> AthenaClientDefinition {
    let lifecycle = lifecycle_from_flags(record.is_active, record.is_frozen);
    let connection = AthenaConnection::D1(AthenaD1Connection::from(info));

    AthenaClientDefinition {
        engine: AthenaEngine::D1,
        capabilities: AthenaCapabilities::for_engine(AthenaEngine::D1),
        lifecycle,
        connection,
        client_name: record.client_name.clone(),
        description: record.description.clone(),
        source: record.source.clone(),
        metadata: record.metadata.clone(),
    }
}
/// Maps a client's `is_active` / `is_frozen` flags to its lifecycle.
///
/// An inactive client is reported as `Inactive` even if it is also frozen.
fn lifecycle_from_flags(is_active: bool, is_frozen: bool) -> AthenaClientLifecycle {
    match (is_active, is_frozen) {
        (false, _) => AthenaClientLifecycle::Inactive,
        (true, true) => AthenaClientLifecycle::Frozen,
        (true, false) => AthenaClientLifecycle::Active,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;

    /// Builds an active, unfrozen catalog record with the given name,
    /// description and metadata; the optional Postgres/sync fields are empty.
    fn catalog_record(
        client_name: &str,
        description: &str,
        metadata: serde_json::Value,
    ) -> AthenaClientRecord {
        let now = Utc::now();
        AthenaClientRecord {
            id: "id".to_string(),
            client_name: client_name.to_string(),
            description: Some(description.to_string()),
            pg_uri: None,
            pg_uri_env_var: None,
            config_uri_template: None,
            source: "database".to_string(),
            is_active: true,
            is_frozen: false,
            last_synced_from_config_at: None,
            last_seen_at: None,
            metadata,
            created_at: now,
            updated_at: now,
            deleted_at: None,
        }
    }

    /// Catalog record whose metadata describes a Scylla cluster.
    fn scylla_record() -> AthenaClientRecord {
        catalog_record(
            "events",
            "Events",
            json!({
                "dbEngine": "scylladb",
                "scyllaHosts": ["10.0.0.10:9042"],
                "scyllaKeyspace": "events"
            }),
        )
    }

    #[test]
    fn postgres_capabilities_enable_sql_and_deferred_execution() {
        let caps: AthenaCapabilities = AthenaCapabilities::for_engine(AthenaEngine::Postgres);
        assert!(caps.raw_sql);
        assert!(caps.gateway_query);
        assert!(caps.deferred_query);
        assert!(!caps.raw_cql);
    }

    #[test]
    fn scylla_capabilities_enable_cql_but_not_deferred_execution() {
        let caps: AthenaCapabilities = AthenaCapabilities::for_engine(AthenaEngine::Scylla);
        assert!(caps.raw_cql);
        assert!(caps.gateway_query);
        assert!(!caps.raw_sql);
        assert!(!caps.deferred_query);
    }

    #[test]
    fn d1_capabilities_enable_gateway_sql_without_cql() {
        let caps: AthenaCapabilities = AthenaCapabilities::for_engine(AthenaEngine::D1);
        assert!(caps.raw_sql);
        assert!(caps.gateway_query);
        assert!(!caps.raw_cql);
        assert!(!caps.deferred_query);
    }

    #[test]
    fn lifecycle_flags_give_inactive_precedence_over_frozen() {
        assert_eq!(lifecycle_from_flags(true, false), AthenaClientLifecycle::Active);
        assert_eq!(lifecycle_from_flags(true, true), AthenaClientLifecycle::Frozen);
        assert_eq!(lifecycle_from_flags(false, false), AthenaClientLifecycle::Inactive);
        assert_eq!(lifecycle_from_flags(false, true), AthenaClientLifecycle::Inactive);
    }

    #[test]
    fn resolve_errors_expose_client_name_for_every_variant() {
        let errors = [
            AthenaClientResolveError::Inactive {
                client_name: "a".to_string(),
            },
            AthenaClientResolveError::Frozen {
                client_name: "a".to_string(),
            },
            AthenaClientResolveError::InvalidMetadata {
                client_name: "a".to_string(),
                message: "bad".to_string(),
            },
            AthenaClientResolveError::Lookup {
                client_name: "a".to_string(),
                message: "down".to_string(),
            },
        ];
        for err in &errors {
            assert_eq!(err.client_name(), "a");
        }
    }

    #[test]
    fn registered_postgres_clients_become_athena_postgres_definitions() {
        let client: AthenaClientDefinition =
            definition_from_registered_postgres_client(RegisteredClient {
                client_name: "reporting".to_string(),
                source: "config".to_string(),
                description: Some("Reporting".to_string()),
                pg_uri: Some("postgres://localhost/reporting".to_string()),
                pg_uri_env_var: None,
                config_uri_template: None,
                is_active: true,
                is_frozen: false,
                pool_connected: true,
            });
        assert_eq!(client.engine, AthenaEngine::Postgres);
        assert_eq!(client.lifecycle, AthenaClientLifecycle::Active);
        assert!(matches!(client.connection, AthenaConnection::Postgres(_)));
    }

    #[test]
    fn scylla_catalog_clients_become_athena_scylla_definitions() {
        let record = scylla_record();
        let info: ScyllaConnectionInfo = ScyllaConnectionInfo::from_metadata(&record.metadata)
            .expect("metadata should parse")
            .expect("expected Scylla info");
        let client: AthenaClientDefinition = definition_from_catalog_scylla_client(&record, info);
        assert_eq!(client.engine, AthenaEngine::Scylla);
        assert_eq!(client.lifecycle, AthenaClientLifecycle::Active);
        assert!(matches!(client.connection, AthenaConnection::Scylla(_)));
    }

    #[test]
    fn frozen_catalog_clients_keep_frozen_lifecycle() {
        let mut record = scylla_record();
        record.is_frozen = true;
        let info = ScyllaConnectionInfo::from_metadata(&record.metadata)
            .expect("metadata should parse")
            .expect("expected Scylla info");
        let client = definition_from_catalog_scylla_client(&record, info);
        assert_eq!(client.lifecycle, AthenaClientLifecycle::Frozen);
    }

    #[test]
    fn d1_catalog_clients_become_athena_d1_definitions() {
        let record = catalog_record(
            "analytics",
            "Analytics",
            json!({
                "dbEngine": "cloudflare-d1",
                "cloudflareD1": {
                    "worker_base_url": "https://athena-cloudflare-d1-proxy.xylex-group.workers.dev",
                    "auth_token_env_var": "ATHENA_D1_PROXY_TOKEN",
                    "database_binding": "DB"
                }
            }),
        );
        let info = D1ConnectionInfo::from_metadata(&record.metadata)
            .expect("metadata should parse")
            .expect("expected D1 info");
        let client = definition_from_catalog_d1_client(&record, info);
        assert_eq!(client.engine, AthenaEngine::D1);
        assert_eq!(client.lifecycle, AthenaClientLifecycle::Active);
        assert!(matches!(client.connection, AthenaConnection::D1(_)));
    }
}