use crate::config::Config;
use crate::drivers::scylla::health::{self, HostOffline};
use scylla::{
client::session::Session,
client::session_builder::SessionBuilder,
errors::{ExecutionError, RequestAttemptError},
response::query_result::QueryResult,
statement::{Consistency, Statement},
value::{CqlValue, Row},
};
use serde_json::{Map, Number, Value, json};
use std::error::Error;
/// Address and credentials used to open a ScyllaDB session.
#[derive(Clone)]
pub struct ScyllaConnectionInfo {
    /// Contact point handed to the driver as a known node (e.g. `"host:9042"`).
    pub host: String,
    /// Authentication user name; may be empty when none is configured.
    pub username: String,
    /// Authentication password; may be empty when none is configured.
    pub password: String,
}
impl ScyllaConnectionInfo {
    /// Builds connection settings from the `"scylladb"` entries of `config`.
    ///
    /// The host falls back to a hard-coded contact point when none is configured;
    /// a missing username or password becomes an empty string.
    pub fn from_config(config: &Config) -> Self {
        // NOTE(review): a fixed public IP as the fallback contact point is surprising;
        // consider making the host mandatory instead.
        const FALLBACK_HOST: &str = "167.235.9.113:9042";

        // Reads one field of the "scylladb" authenticator, defaulting to "".
        let credential = |field: &str| {
            config
                .get_authenticator("scylladb")
                .and_then(|auth| auth.get(field))
                .cloned()
                .unwrap_or_default()
        };

        let host = match config.get_host("scylladb") {
            Some(configured) => configured.clone(),
            None => FALLBACK_HOST.to_string(),
        };

        Self {
            host,
            username: credential("username"),
            password: credential("password"),
        }
    }
}
/// Loads the application configuration and extracts the ScyllaDB connection settings.
///
/// # Errors
/// Propagates any failure from `Config::load`, converted into a boxed error.
pub fn load_connection_info() -> Result<ScyllaConnectionInfo, Box<dyn Error>> {
    Config::load()
        .map(|config| ScyllaConnectionInfo::from_config(&config))
        .map_err(Into::into)
}
/// Opens a new driver session to `info.host`, consulting the shared host-health tracker.
///
/// If the tracker already reports the host as offline, no connection is attempted.
/// A successful connection is recorded as a success. A failed one is recorded as a
/// failure; if recording it yields an offline deadline, the caller receives
/// `HostOffline` instead of the driver error.
///
/// # Errors
/// `HostOffline`, or the driver's session-construction error, boxed.
pub async fn open_session(info: &ScyllaConnectionInfo) -> Result<Session, Box<dyn Error>> {
    let tracker = health::global_tracker();
    let host = &info.host;

    // Skip the connection attempt entirely while the host is in its offline window.
    if let Some(deadline) = tracker.offline_until(host) {
        return Err(HostOffline::new(host.clone(), deadline).into());
    }

    let attempt = SessionBuilder::new()
        .known_node(host)
        .user(&info.username, &info.password)
        .build()
        .await;

    match attempt {
        Ok(session) => {
            tracker.record_success(host);
            Ok(session)
        }
        // This failure may push the host offline; if so, report that instead.
        Err(err) => match tracker.record_failure(host) {
            Some(deadline) => Err(HostOffline::new(host.clone(), deadline).into()),
            None => Err(err.into()),
        },
    }
}
/// Reports whether `err` stems from a broken or unavailable connection rather than
/// from the request itself.
///
/// Only connection-pool errors and a final attempt that failed on a broken connection
/// qualify; every other execution error returns `false`.
pub(crate) fn is_connection_error(err: &ExecutionError) -> bool {
    match err {
        ExecutionError::ConnectionPoolError(_) => true,
        ExecutionError::LastAttemptError(attempt) => {
            matches!(attempt, RequestAttemptError::BrokenConnectionError(_))
        }
        _ => false,
    }
}
pub async fn execute_query(query: String) -> Result<(Vec<Value>, Vec<String>), Box<dyn Error>> {
let info = load_connection_info()?;
let session = open_session(&info).await?;
let statement = build_statement(&query);
match session.query_unpaged(statement, &[]).await {
Ok(result) => map_query_result(result),
Err(err) => handle_query_error(err, &info),
}
}
/// Wraps raw CQL text in an unprepared `Statement` executed at consistency `ONE`.
///
/// NOTE(review): the page size of 1 appears to be ignored by `query_unpaged`; confirm
/// before reusing this statement for paged execution, where it would fetch one row per page.
fn build_statement(sql: &str) -> Statement {
    let mut stmt: Statement = sql.into();
    stmt.set_page_size(1);
    stmt.set_consistency(Consistency::One);
    stmt
}
/// Converts a driver `QueryResult` into JSON row objects plus the ordered column names.
///
/// Each row becomes a JSON object keyed by column name, with cells converted by
/// `convert_cql_value`; a CQL `null` or a missing cell becomes JSON `null`. If two
/// columns share a name, the later one wins in the row object. Results that carry no
/// rows yield a single `{"status": "success", ...}` object and an empty column list.
///
/// # Errors
/// Fails if the row metadata or any individual row cannot be deserialized.
fn map_query_result(result: QueryResult) -> Result<(Vec<Value>, Vec<String>), Box<dyn Error>> {
    // Statements without a row set (writes, DDL, ...) get a synthetic acknowledgement.
    if !result.is_rows() {
        let ack = json!({
            "status": "success",
            "message": "Query executed successfully"
        });
        return Ok((vec![ack], Vec::new()));
    }

    let rows_result = result.into_rows_result()?;
    let columns: Vec<String> = rows_result
        .column_specs()
        .iter()
        .map(|spec| spec.name().to_string())
        .collect();

    let mut rows = Vec::new();
    for decoded in rows_result.rows::<Row>()? {
        let row = decoded?;
        let object: Map<String, Value> = columns
            .iter()
            .enumerate()
            .map(|(idx, name)| {
                let cell = match row.columns.get(idx) {
                    Some(Some(cql)) => convert_cql_value(cql),
                    _ => Value::Null,
                };
                (name.clone(), cell)
            })
            .collect();
        rows.push(Value::Object(object));
    }

    Ok((rows, columns))
}
/// Turns a failed query into the boxed error returned to callers, updating host health.
///
/// Connection-level failures (as classified by `is_connection_error`) are recorded
/// against `info.host`; if that marks the host offline, the caller receives
/// `HostOffline`. In every other case the original `ExecutionError` is returned.
fn handle_query_error(
    err: ExecutionError,
    info: &ScyllaConnectionInfo,
) -> Result<(Vec<Value>, Vec<String>), Box<dyn Error>> {
    let offline_until = if is_connection_error(&err) {
        health::global_tracker().record_failure(&info.host)
    } else {
        None
    };

    match offline_until {
        Some(deadline) => Err(HostOffline::new(info.host.clone(), deadline).into()),
        None => Err(err.into()),
    }
}
/// Converts a single CQL cell value into its JSON representation.
///
/// - Text-like types (`text`, `ascii`) become JSON strings.
/// - Integer types (`tinyint`, `smallint`, `int`, `bigint`, `counter`) become JSON numbers.
/// - `boolean` becomes a JSON bool.
/// - `float`, `double` and `decimal` become the nearest `f64`. Values JSON cannot encode
///   (NaN, ±infinity, or a decimal outside `f64` range) become `null`.
///
/// Every other CQL type falls back to its `Debug` rendering as a JSON string.
fn convert_cql_value(value: &CqlValue) -> Value {
    match value {
        CqlValue::Text(s) | CqlValue::Ascii(s) => Value::String(s.clone()),
        CqlValue::TinyInt(i) => Value::Number(Number::from(*i)),
        CqlValue::SmallInt(i) => Value::Number(Number::from(*i)),
        CqlValue::Int(i) => Value::Number(Number::from(*i)),
        CqlValue::BigInt(b) => Value::Number(Number::from(*b)),
        CqlValue::Counter(c) => Value::Number(Number::from(c.0)),
        CqlValue::Boolean(b) => Value::Bool(*b),
        CqlValue::Float(f) => f64_to_json(f64::from(*f)),
        CqlValue::Double(d) => f64_to_json(*d),
        CqlValue::Decimal(decimal) => {
            // Use the structured accessor: `CqlDecimal`'s derived `Debug` output has no
            // parseable numeric field, so the previous text-scraping always produced 0.
            let (unscaled, scale) = decimal.as_signed_be_bytes_slice_and_exponent();
            decimal_to_f64(unscaled, scale).map_or(Value::Null, f64_to_json)
        }
        other => Value::String(format!("{:?}", other)),
    }
}

/// Wraps a float as a JSON number, or `null` when it is NaN or infinite.
fn f64_to_json(v: f64) -> Value {
    Number::from_f64(v).map_or(Value::Null, Value::Number)
}

/// Converts a CQL `decimal` to the nearest `f64`.
///
/// The decimal is given as its unscaled value (two's-complement big-endian bytes) and a
/// base-10 scale, representing `unscaled * 10^(-scale)`. Returns `None` if the result is
/// not finite.
fn decimal_to_f64(unscaled_be: &[u8], scale: i32) -> Option<f64> {
    // An empty varint encodes zero.
    let Some(&lead) = unscaled_be.first() else {
        return Some(0.0);
    };

    // Render the unscaled integer as decimal text and let the standard float parser apply
    // the exponent, so the final value is rounded once rather than step by step.
    let mantissa = if unscaled_be.len() <= 16 {
        // Sign-extend into an i128 so the integer part is exact.
        let fill = if lead & 0x80 != 0 { 0xFF } else { 0x00 };
        let mut buf = [fill; 16];
        buf[16 - unscaled_be.len()..].copy_from_slice(unscaled_be);
        i128::from_be_bytes(buf).to_string()
    } else {
        // Too wide for i128: accumulate in f64. The leading byte carries the sign.
        let approx = unscaled_be[1..].iter().fold(
            f64::from(i8::from_be_bytes([lead])),
            |acc, &byte| acc * 256.0 + f64::from(byte),
        );
        if !approx.is_finite() {
            return None;
        }
        // `f64`'s Display never uses exponent notation, so this is a plain integer string.
        approx.to_string()
    };

    // Widen before negating so `i32::MIN` cannot overflow.
    format!("{}e{}", mantissa, -i64::from(scale))
        .parse::<f64>()
        .ok()
        .filter(|v| v.is_finite())
}