athena-driver 3.18.0

Backend driver primitives for Athena, starting with Scylla and Supabase health-aware clients
Documentation
//! Execute CQL queries and share Scylla connection helpers.
use super::health::{self, HostOffline};

use scylla::client::session_builder::GenericSessionBuilder;
use scylla::{
    client::session::Session,
    client::session_builder::SessionBuilder,
    errors::{ExecutionError, RequestAttemptError},
    response::query_result::QueryResult,
    statement::{Consistency, Statement},
    value::{CqlValue, Row},
};

use serde_json::{Map, Number, Value, json};
use std::collections::HashMap;
use std::error::Error;
use std::fmt::{Debug, Formatter, Result as FmtResult};

/// Configuration that describes how to reach the Scylla host.
///
/// The `Debug` implementation for this type prints every field except
/// `password`, which is rendered as `<redacted>`.
#[derive(Clone)]
pub struct ScyllaConnectionInfo {
    /// Primary contact point; also the key used when consulting and updating
    /// the host health tracker in `open_session`.
    pub host: String,
    /// Additional contact points registered as known nodes after `host`.
    pub extra_hosts: Vec<String>,
    /// Keyspace selected on the session right after it is built, when present.
    pub keyspace: Option<String>,
    /// User name for authentication; credentials are only passed to the
    /// session builder when `username` or `password` is non-empty.
    pub username: String,
    /// Password for authentication (see `username`).
    pub password: String,
}

impl Debug for ScyllaConnectionInfo {
    /// Format the connection info for diagnostics without leaking the password.
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        // Destructure exhaustively so a newly added field cannot be forgotten here.
        let Self {
            host,
            extra_hosts,
            keyspace,
            username,
            password: _,
        } = self;

        let mut out = f.debug_struct("ScyllaConnectionInfo");
        out.field("host", host);
        out.field("extra_hosts", extra_hosts);
        out.field("keyspace", keyspace);
        out.field("username", username);
        // The real password is never written to the formatter.
        out.field("password", &"<redacted>");
        out.finish()
    }
}

impl ScyllaConnectionInfo {
    /// Contact point used by [`Self::from_resolved_parts`] when no hosts are supplied.
    ///
    /// NOTE(review): this is a hard-coded address baked into the driver; confirm
    /// that silently falling back to it (instead of failing) is still intended.
    const FALLBACK_HOST: &'static str = "167.235.9.113:9042";

    /// Build a Scylla connection target from resolved host/auth settings.
    ///
    /// The first entry of `hosts` becomes [`Self::host`] and the remaining
    /// entries become [`Self::extra_hosts`]. When `hosts` is empty the built-in
    /// fallback contact point is used. `auth` is consulted for the `"username"`,
    /// `"password"` and `"keyspace"` keys; missing credentials default to empty
    /// strings and a missing keyspace stays `None`.
    pub fn from_resolved_parts(hosts: Vec<String>, auth: Option<HashMap<String, String>>) -> Self {
        // Consume the vector once: first element is the primary host, the rest are extras.
        let mut hosts = hosts.into_iter();
        let host: String = hosts
            .next()
            .unwrap_or_else(|| Self::FALLBACK_HOST.to_string());
        let extra_hosts: Vec<String> = hosts.collect();

        let auth_value = |key: &str| -> Option<String> {
            auth.as_ref()
                .and_then(|resolved| resolved.get(key))
                .cloned()
        };
        let username: String = auth_value("username").unwrap_or_default();
        let password: String = auth_value("password").unwrap_or_default();
        let keyspace: Option<String> = auth_value("keyspace");

        Self {
            host,
            extra_hosts,
            keyspace,
            username,
            password,
        }
    }

    /// Extract Scylla connection settings from a client metadata JSON value.
    ///
    /// Settings are read from the first of the `"scylla"`, `"scylladb"` or
    /// `"connection"` keys that holds a JSON object (`hosts`, `keyspace`,
    /// `username`, `password`), falling back per setting to the flat
    /// `scyllaHosts`, `scyllaKeyspace`, `scyllaUsername` and `scyllaPassword`
    /// keys. String settings are trimmed and empty values are treated as absent.
    /// An empty nested `hosts` value likewise falls back to `scyllaHosts`.
    ///
    /// Returns `Ok(None)` when `metadata` is not an object, or when it has no
    /// hosts and does not declare itself Scylla-backed.
    ///
    /// # Errors
    ///
    /// Returns an error message when the metadata declares a Scylla backend
    /// (see `metadata_declares_scylla`) but no host could be resolved.
    pub fn from_metadata(metadata: &Value) -> Result<Option<Self>, String> {
        let Some(object) = metadata.as_object() else {
            return Ok(None);
        };

        let declares_scylla: bool = metadata_declares_scylla(object);
        let nested: Option<&Map<String, Value>> = ["scylla", "scylladb", "connection"]
            .iter()
            .find_map(|key| object.get(*key).and_then(Value::as_object));

        // An empty nested host list must not shadow a populated flat `scyllaHosts`,
        // mirroring how the string settings below skip empty values.
        let hosts: Option<Vec<String>> = nested
            .and_then(|section| section.get("hosts"))
            .and_then(parse_hosts_value)
            .filter(|parsed| !parsed.is_empty())
            .or_else(|| object.get("scyllaHosts").and_then(parse_hosts_value));

        // Nested value first, flat key second; both trimmed and non-empty.
        let setting = |nested_key: &str, flat_key: &str| -> Option<String> {
            nested
                .and_then(|section| section.get(nested_key))
                .and_then(optional_trimmed_string)
                .or_else(|| object.get(flat_key).and_then(optional_trimmed_string))
        };
        let keyspace: Option<String> = setting("keyspace", "scyllaKeyspace");
        let username: String = setting("username", "scyllaUsername").unwrap_or_default();
        let password: String = setting("password", "scyllaPassword").unwrap_or_default();

        let mut hosts = hosts.unwrap_or_default().into_iter();
        match hosts.next() {
            Some(host) => Ok(Some(Self {
                host,
                extra_hosts: hosts.collect(),
                keyspace,
                username,
                password,
            })),
            None if declares_scylla => Err(
                "metadata marks the client as Scylla-backed but does not include any Scylla hosts"
                    .to_string(),
            ),
            None => Ok(None),
        }
    }

    /// Return every configured contact point, primary host first.
    pub fn known_hosts(&self) -> Vec<&str> {
        std::iter::once(self.host.as_str())
            .chain(self.extra_hosts.iter().map(String::as_str))
            .collect()
    }
}

/// Open a session to the configured hosts, honouring the global health tracker.
///
/// If the tracker reports the primary host as offline, no connection attempt is
/// made and a [`HostOffline`] error carrying the deadline is returned. Otherwise
/// every known host is registered with the builder, credentials are attached
/// when either the username or the password is non-empty, and the configured
/// keyspace (if any) is selected before success is recorded for the primary host.
///
/// # Errors
///
/// Returns [`HostOffline`] when the host is (or has just been marked) offline,
/// the builder error when the session cannot be created, or the error produced
/// while selecting the keyspace.
pub async fn open_session(info: &ScyllaConnectionInfo) -> Result<Session, Box<dyn Error>> {
    let tracker: &health::HostHealthTracker<health::SystemClock> = health::global_tracker();

    if let Some(deadline) = tracker.offline_until(&info.host) {
        return Err(Box::new(HostOffline::new(info.host.clone(), deadline)));
    }

    let with_nodes: GenericSessionBuilder<scylla::client::session_builder::DefaultMode> = info
        .known_hosts()
        .into_iter()
        .fold(SessionBuilder::new(), |pending, node| {
            pending.known_node(node)
        });
    let has_credentials = !(info.username.is_empty() && info.password.is_empty());
    let builder = if has_credentials {
        with_nodes.user(&info.username, &info.password)
    } else {
        with_nodes
    };

    let session: Session = match builder.build().await {
        Ok(session) => session,
        Err(err) => {
            // A failure may push the host over the offline threshold; surface that instead.
            let failure: Box<dyn Error> = match tracker.record_failure(&info.host) {
                Some(deadline) => Box::new(HostOffline::new(info.host.clone(), deadline)),
                None => Box::new(err),
            };
            return Err(failure);
        }
    };

    if let Some(keyspace) = info.keyspace.as_deref() {
        session.use_keyspace(keyspace, false).await?;
    }
    tracker.record_success(&info.host);
    Ok(session)
}

/// Report whether an execution error stems from the connection itself.
///
/// Only connection-pool errors and a broken connection on the last request
/// attempt qualify; every other error yields `false`.
pub(crate) fn is_connection_error(err: &ExecutionError) -> bool {
    match err {
        ExecutionError::ConnectionPoolError(_) => true,
        ExecutionError::LastAttemptError(attempt) => {
            matches!(attempt, RequestAttemptError::BrokenConnectionError(_))
        }
        _ => false,
    }
}

/// Execute the provided CQL query against an explicit Scylla connection target.
///
/// Opens a fresh session via [`open_session`], runs `query` unpaged with no
/// bound values, and converts the response with `map_query_result`.
///
/// Returns the result rows as JSON values together with the column names.
///
/// # Errors
///
/// Propagates any error from opening the session or mapping the result. Query
/// failures are routed through `handle_query_error`, which may replace a
/// connection error with a [`HostOffline`] error.
pub async fn execute_query_with_info(
    query: String,
    info: &ScyllaConnectionInfo,
) -> Result<(Vec<Value>, Vec<String>), Box<dyn Error>> {
    // `info` is already a reference; pass it through without re-borrowing.
    let session: Session = open_session(info).await?;
    let statement: Statement = build_statement(&query);

    match session.query_unpaged(statement, &[]).await {
        Ok(result) => map_query_result(result),
        Err(err) => handle_query_error(err, info),
    }
}

/// Wrap raw CQL text in a [`Statement`] configured with consistency `One`
/// and a page size of 1.
fn build_statement(sql: &str) -> Statement {
    let mut configured: Statement = Statement::new(sql);
    configured.set_page_size(1);
    configured.set_consistency(Consistency::One);
    configured
}

fn map_query_result(result: QueryResult) -> Result<(Vec<Value>, Vec<String>), Box<dyn Error>> {
    if !result.is_rows() {
        return Ok((
            vec![json!({
                "status": "success",
                "message": "Query executed successfully"
            })],
            Vec::new(),
        ));
    }

    let rows_result: scylla::response::query_result::QueryRowsResult = result.into_rows_result()?;
    let column_specs: scylla::response::query_result::ColumnSpecs<'_, '_> =
        rows_result.column_specs();
    let columns: Vec<String> = column_specs
        .iter()
        .map(|spec| spec.name().to_string())
        .collect::<Vec<_>>();

    let mut rows: Vec<Value> = Vec::new();
    for row_result in rows_result.rows::<Row>()? {
        let row: Row = row_result?;
        let mut row_map: Map<String, Value> = Map::new();

        for (i, spec) in column_specs.iter().enumerate() {
            let column_name: String = spec.name().to_string();
            let value: Value = match row.columns.get(i) {
                Some(Some(cql_value)) => convert_cql_value(cql_value),
                Some(None) | None => Value::Null,
            };
            row_map.insert(column_name, value);
        }

        rows.push(Value::Object(row_map));
    }

    Ok((rows, columns))
}

/// Split a comma-separated host list, trimming whitespace and dropping empty entries.
fn parse_hosts(raw: &str) -> Vec<String> {
    let mut hosts: Vec<String> = Vec::new();
    for segment in raw.split(',') {
        let candidate = segment.trim();
        if candidate.is_empty() {
            continue;
        }
        hosts.push(candidate.to_string());
    }
    hosts
}

/// Read a host list from a JSON value.
///
/// A string is treated as a comma-separated list (see `parse_hosts`). An array
/// contributes each string element, trimmed, with empty and non-string entries
/// skipped; array elements are not split on commas. Any other JSON type yields
/// `None`.
fn parse_hosts_value(value: &Value) -> Option<Vec<String>> {
    if let Value::String(raw) = value {
        return Some(parse_hosts(raw));
    }

    let entries = value.as_array()?;
    let mut hosts: Vec<String> = Vec::new();
    for entry in entries {
        let Some(text) = entry.as_str() else {
            continue;
        };
        let candidate = text.trim();
        if !candidate.is_empty() {
            hosts.push(candidate.to_string());
        }
    }
    Some(hosts)
}

/// Return the trimmed string held by `value`, or `None` when it is not a
/// string or is empty after trimming.
fn optional_trimmed_string(value: &Value) -> Option<String> {
    let trimmed = value.as_str()?.trim();
    if trimmed.is_empty() {
        None
    } else {
        Some(trimmed.to_string())
    }
}

/// Decide whether client metadata identifies itself as Scylla-backed.
///
/// This is the case when a `scylla`, `scylladb` or `scyllaHosts` key is present,
/// or when the first present key among `dbEngine`, `backend` and `engine` holds
/// the string `scylla` or `scylladb` (compared trimmed and ASCII case-insensitively).
fn metadata_declares_scylla(metadata: &Map<String, Value>) -> bool {
    let has_scylla_key = ["scylla", "scylladb", "scyllaHosts"]
        .iter()
        .any(|key| metadata.contains_key(*key));
    if has_scylla_key {
        return true;
    }

    // Only the first engine key that exists is inspected, even if it is not a string.
    let engine = ["dbEngine", "backend", "engine"]
        .iter()
        .find_map(|key| metadata.get(*key));
    match engine.and_then(Value::as_str) {
        Some(name) => {
            let normalized = name.trim().to_ascii_lowercase();
            matches!(normalized.as_str(), "scylla" | "scylladb")
        }
        None => false,
    }
}

/// Turn a failed query into the error reported to the caller.
///
/// Connection errors are recorded against the primary host in the global health
/// tracker; if that recording yields an offline deadline, a [`HostOffline`]
/// error is returned instead of the original one. This function always returns
/// `Err`.
fn handle_query_error(
    err: ExecutionError,
    info: &ScyllaConnectionInfo,
) -> Result<(Vec<Value>, Vec<String>), Box<dyn Error>> {
    // Only connection-level failures count against the host's health.
    let offline_deadline = if is_connection_error(&err) {
        health::global_tracker().record_failure(&info.host)
    } else {
        None
    };

    let failure: Box<dyn Error> = match offline_deadline {
        Some(deadline) => Box::new(HostOffline::new(info.host.clone(), deadline)),
        None => Box::new(err),
    };
    Err(failure)
}

/// Convert a CQL value into its JSON representation.
///
/// Text, integers, booleans and floating-point values map to the matching JSON
/// type. Decimals are converted to a JSON number through `f64`, so very large
/// or very precise values lose precision. Floating-point results that JSON
/// cannot represent (NaN, infinities) are emitted as `0`. Every other CQL type
/// is rendered as a string using its `Debug` output.
fn convert_cql_value(value: &CqlValue) -> Value {
    // `Number::from_f64` rejects NaN/infinity; keep the historical fallback of 0.
    let finite_or_zero =
        |number: f64| Value::Number(Number::from_f64(number).unwrap_or_else(|| Number::from(0)));

    match value {
        CqlValue::Text(s) => Value::String(s.clone()),
        CqlValue::Int(i) => Value::Number(Number::from(*i)),
        CqlValue::BigInt(b) => Value::Number(Number::from(*b)),
        CqlValue::Boolean(b) => Value::Bool(*b),
        CqlValue::Float(f) => finite_or_zero(f64::from(*f)),
        CqlValue::Double(d) => finite_or_zero(*d),
        CqlValue::Decimal(decimal) => {
            // Read the unscaled integer and scale through the driver API rather
            // than scraping the `Debug` output, whose format is not a contract.
            let (unscaled, scale) = decimal.as_signed_be_bytes_slice_and_exponent();
            finite_or_zero(cql_decimal_parts_to_f64(unscaled, scale))
        }
        other => Value::String(format!("{:?}", other)),
    }
}

/// Interpret `unscaled` as a big-endian two's-complement integer and divide it
/// by `10^scale`, yielding an approximate `f64`.
fn cql_decimal_parts_to_f64(unscaled: &[u8], scale: i32) -> f64 {
    let negative = unscaled.first().is_some_and(|&byte| (byte & 0x80) != 0);

    // For negative numbers accumulate the bitwise complement; the magnitude is
    // then `complement + 1`, which avoids subtracting two huge, nearly equal floats.
    let accumulated = unscaled.iter().fold(0.0_f64, |acc, &byte| {
        let digit = if negative { !byte } else { byte };
        acc * 256.0 + f64::from(digit)
    });
    let integer = if negative {
        -(accumulated + 1.0)
    } else {
        accumulated
    };

    integer / 10f64.powi(scale)
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Flat `scylla*` keys alone are enough to describe a connection.
    #[test]
    fn metadata_parsing_accepts_flat_scylla_fields() {
        let metadata = json!({
            "dbEngine": "scylladb",
            "scyllaHosts": "10.0.0.10:9042,10.0.0.11:9042",
            "scyllaKeyspace": "analytics",
            "scyllaUsername": "cassandra",
            "scyllaPassword": "secret"
        });

        let parsed = ScyllaConnectionInfo::from_metadata(&metadata)
            .expect("metadata should parse")
            .expect("expected Scylla info");

        assert_eq!(parsed.host, "10.0.0.10:9042");
        assert_eq!(parsed.extra_hosts, ["10.0.0.11:9042"]);
        assert_eq!(parsed.keyspace.as_deref(), Some("analytics"));
        assert_eq!(parsed.username, "cassandra");
        assert_eq!(parsed.password, "secret");
    }

    /// A nested `scylla` object with an array of hosts is understood as well.
    #[test]
    fn metadata_parsing_accepts_nested_connection_shape() {
        let metadata = json!({
            "backend": "scylla",
            "scylla": {
                "hosts": ["10.0.0.20:9042", "10.0.0.21:9042"],
                "keyspace": "events"
            }
        });

        let parsed = ScyllaConnectionInfo::from_metadata(&metadata)
            .expect("metadata should parse")
            .expect("expected Scylla info");

        assert_eq!(parsed.host, "10.0.0.20:9042");
        assert_eq!(parsed.extra_hosts, ["10.0.0.21:9042"]);
        assert_eq!(parsed.keyspace.as_deref(), Some("events"));
    }

    /// Declaring a Scylla engine without any hosts is a configuration error.
    #[test]
    fn metadata_parsing_rejects_missing_hosts_for_scylla_clients() {
        let metadata = json!({
            "dbEngine": "scylladb"
        });

        let Err(error) = ScyllaConnectionInfo::from_metadata(&metadata) else {
            panic!("missing hosts should fail")
        };

        assert!(error.contains("does not include any Scylla hosts"));
    }

    /// Metadata for other engines yields no Scylla connection info.
    #[test]
    fn metadata_parsing_ignores_non_scylla_clients() {
        let metadata = json!({
            "dbEngine": "postgres",
            "pg_uri": "postgres://localhost/db"
        });

        let parsed = ScyllaConnectionInfo::from_metadata(&metadata)
            .expect("metadata should not fail");

        assert!(parsed.is_none());
    }
}