athena_rs 3.3.0

Database gateway API
Documentation
//! Execute CQL queries and share Scylla connection helpers.
use crate::config::Config;
use crate::drivers::scylla::health::{self, HostOffline};
use scylla::{
    client::session::Session,
    client::session_builder::SessionBuilder,
    errors::{ExecutionError, RequestAttemptError},
    response::query_result::QueryResult,
    statement::{Consistency, Statement},
    value::{CqlValue, Row},
};
use serde_json::{Map, Number, Value, json};
use std::error::Error;

/// Configuration that describes how to reach the Scylla host.
///
/// Every field is an owned `String`, so the value can be cloned and handed
/// around independently of the `Config` it was built from.
// NOTE(review): `Debug` is not derived — presumably to keep the password out
// of logs; confirm before adding it.
#[derive(Clone)]
pub struct ScyllaConnectionInfo {
    /// Node address passed to the session builder as the known node; also the
    /// key used with the health tracker.
    pub host: String,
    /// User name passed to the session builder (empty when not configured).
    pub username: String,
    /// Password passed to the session builder (empty when not configured).
    pub password: String,
}

impl ScyllaConnectionInfo {
    /// Build the connection info from the `"scylladb"` entries of `config`.
    ///
    /// The host falls back to a built-in address when the config has no host
    /// entry. A missing authenticator, or a missing `username` / `password`
    /// key inside it, yields an empty string for that credential.
    pub fn from_config(config: &Config) -> Self {
        // NOTE(review): hard-coded public address used as the fallback host —
        // confirm this is intended outside of development setups.
        const FALLBACK_HOST: &str = "167.235.9.113:9042";

        // Fetch one credential from the "scylladb" authenticator, or "" if absent.
        let credential = |key: &str| {
            config
                .get_authenticator("scylladb")
                .and_then(|auth| auth.get(key))
                .cloned()
                .unwrap_or_default()
        };

        let host = match config.get_host("scylladb") {
            Some(configured) => configured.clone(),
            None => FALLBACK_HOST.to_string(),
        };
        let username = credential("username");
        let password = credential("password");

        Self {
            host,
            username,
            password,
        }
    }
}

/// Load the application config and derive the Scylla connection info from it.
///
/// # Errors
///
/// Returns whatever error `Config::load` produces, boxed.
pub fn load_connection_info() -> Result<ScyllaConnectionInfo, Box<dyn Error>> {
    Config::load()
        .map(|config| ScyllaConnectionInfo::from_config(&config))
        .map_err(Into::into)
}

/// Connect to the host described by `info`, consulting the global health tracker.
///
/// The tracker is asked first: while it reports the host as offline, no
/// connection attempt is made. A successful connection is recorded as a
/// success; a failed one is recorded as a failure.
///
/// # Errors
///
/// * [`HostOffline`] when the tracker already reports the host offline, or
///   when recording this failure makes the tracker return an offline deadline.
/// * The session builder's own error otherwise.
pub async fn open_session(info: &ScyllaConnectionInfo) -> Result<Session, Box<dyn Error>> {
    let tracker = health::global_tracker();
    let host = &info.host;

    // Skip the connection attempt entirely while the host is marked offline.
    if let Some(deadline) = tracker.offline_until(host) {
        return Err(Box::new(HostOffline::new(host.clone(), deadline)));
    }

    let attempt = SessionBuilder::new()
        .known_node(host)
        .user(&info.username, &info.password)
        .build()
        .await;

    let connect_err = match attempt {
        Ok(session) => {
            tracker.record_success(host);
            return Ok(session);
        }
        Err(err) => err,
    };

    // Prefer the tracker's offline verdict over the raw driver error.
    match tracker.record_failure(host) {
        Some(deadline) => Err(Box::new(HostOffline::new(host.clone(), deadline))),
        None => Err(Box::new(connect_err)),
    }
}

/// Report whether `err` is one of the execution errors treated as a
/// connection failure: a connection-pool error, or a last-attempt error caused
/// by a broken connection.
pub(crate) fn is_connection_error(err: &ExecutionError) -> bool {
    match err {
        ExecutionError::ConnectionPoolError(_) => true,
        ExecutionError::LastAttemptError(attempt) => {
            matches!(attempt, RequestAttemptError::BrokenConnectionError(_))
        }
        _ => false,
    }
}

/// Execute the provided CQL query and return the normalized JSON rows.
pub async fn execute_query(query: String) -> Result<(Vec<Value>, Vec<String>), Box<dyn Error>> {
    let info = load_connection_info()?;
    let session = open_session(&info).await?;
    let statement = build_statement(&query);

    match session.query_unpaged(statement, &[]).await {
        Ok(result) => map_query_result(result),
        Err(err) => handle_query_error(err, &info),
    }
}

/// Wrap `sql` in a [`Statement`] configured with consistency `One` and a page
/// size of 1.
// NOTE(review): the only caller visible here runs the statement through
// `query_unpaged`; confirm whether the page size has any effect on that path.
fn build_statement(sql: &str) -> Statement {
    let mut stmt = Statement::new(sql);
    stmt.set_page_size(1);
    stmt.set_consistency(Consistency::One);
    stmt
}

fn map_query_result(result: QueryResult) -> Result<(Vec<Value>, Vec<String>), Box<dyn Error>> {
    if !result.is_rows() {
        return Ok((
            vec![json!({
                "status": "success",
                "message": "Query executed successfully"
            })],
            Vec::new(),
        ));
    }

    let rows_result: scylla::response::query_result::QueryRowsResult = result.into_rows_result()?;
    let column_specs = rows_result.column_specs();
    let columns = column_specs
        .iter()
        .map(|spec| spec.name().to_string())
        .collect::<Vec<_>>();

    let mut rows: Vec<Value> = Vec::new();
    for row_result in rows_result.rows::<Row>()? {
        let row = row_result?;
        let mut row_map: Map<String, Value> = Map::new();

        for (i, spec) in column_specs.iter().enumerate() {
            let column_name: String = spec.name().to_string();
            let value: Value = match row.columns.get(i) {
                Some(Some(cql_value)) => convert_cql_value(cql_value),
                Some(None) | None => Value::Null,
            };
            row_map.insert(column_name, value);
        }

        rows.push(Value::Object(row_map));
    }

    Ok((rows, columns))
}

/// Convert a failed query into the error returned to the caller.
///
/// Connection-type errors are recorded as a failure with the global health
/// tracker. If the tracker answers with an offline deadline, [`HostOffline`]
/// is returned instead of the driver error. All other errors are boxed as-is
/// and do not touch the tracker.
fn handle_query_error(
    err: ExecutionError,
    info: &ScyllaConnectionInfo,
) -> Result<(Vec<Value>, Vec<String>), Box<dyn Error>> {
    if !is_connection_error(&err) {
        return Err(Box::new(err));
    }

    match health::global_tracker().record_failure(&info.host) {
        Some(deadline) => Err(Box::new(HostOffline::new(info.host.clone(), deadline))),
        None => Err(Box::new(err)),
    }
}

/// Convert a single CQL value into its JSON representation.
///
/// * `Text` becomes a string, `Int` / `BigInt` an integer, `Boolean` a bool.
/// * `Float` / `Double` become a JSON number. Non-finite values (NaN,
///   ±infinity) cannot be JSON numbers and become `0`.
/// * `Decimal` becomes a JSON number holding the `f64` nearest to the decimal
///   value. It is decoded from the unscaled integer and scale rather than
///   from the `Debug` rendering, whose layout is not a stable interface.
///   Results that do not fit a finite `f64` become `0`.
/// * Every other variant becomes a string holding its `Debug` rendering.
fn convert_cql_value(value: &CqlValue) -> Value {
    // `Number::from_f64` returns `None` for NaN and infinities; keep the
    // existing `0` fallback for those.
    let finite_or_zero =
        |float: f64| Value::Number(Number::from_f64(float).unwrap_or_else(|| Number::from(0)));

    match value {
        CqlValue::Text(s) => Value::String(s.clone()),
        CqlValue::Int(i) => Value::Number(Number::from(*i)),
        CqlValue::BigInt(b) => Value::Number(Number::from(*b)),
        CqlValue::Boolean(b) => Value::Bool(*b),
        CqlValue::Float(f) => finite_or_zero(f64::from(*f)),
        CqlValue::Double(d) => finite_or_zero(*d),
        CqlValue::Decimal(decimal) => {
            let (unscaled_be, scale) = decimal.as_signed_be_bytes_slice_and_exponent();
            finite_or_zero(decimal_parts_to_f64(unscaled_be, scale))
        }
        other => Value::String(format!("{:?}", other)),
    }
}

/// Approximate `unscaled * 10^(-scale)` as an `f64`, where `unscaled_be` is a
/// big-endian two's-complement integer (an empty slice is treated as zero).
fn decimal_parts_to_f64(unscaled_be: &[u8], scale: i32) -> f64 {
    const I128_BYTES: usize = 16;

    let negative = unscaled_be.first().is_some_and(|byte| byte & 0x80 != 0);

    if unscaled_be.len() <= I128_BYTES {
        // Sign-extend into an i128, then let the float parser do one correctly
        // rounded conversion from scientific notation (e.g. "123e-2").
        let mut buf = [if negative { 0xFF } else { 0x00 }; I128_BYTES];
        buf[I128_BYTES - unscaled_be.len()..].copy_from_slice(unscaled_be);
        let unscaled = i128::from_be_bytes(buf);
        let exponent = -i64::from(scale);
        return format!("{unscaled}e{exponent}").parse().unwrap_or(0.0);
    }

    // Wider than i128: accumulate the magnitude directly in floating point.
    // For negative numbers the magnitude is (!bytes) + 1.
    let magnitude = unscaled_be.iter().fold(0.0_f64, |acc, &byte| {
        acc * 256.0 + f64::from(if negative { !byte } else { byte })
    });
    let unscaled = if negative { -(magnitude + 1.0) } else { magnitude };
    unscaled * 10f64.powi(scale.saturating_neg())
}