use super::health::{self, HostOffline};
use scylla::client::session_builder::GenericSessionBuilder;
use scylla::{
client::session::Session,
client::session_builder::SessionBuilder,
errors::{ExecutionError, RequestAttemptError},
response::query_result::QueryResult,
statement::{Consistency, Statement},
value::{CqlValue, Row},
};
use serde_json::{Map, Number, Value, json};
use std::collections::HashMap;
use std::error::Error;
use std::fmt::{Debug, Formatter, Result as FmtResult};
/// Connection settings for a Scylla cluster: a primary host, any additional contact points, an
/// optional keyspace, and credentials.
///
/// `Debug` is implemented by hand (not derived) so that the password is printed as `<redacted>`.
#[derive(Clone)]
pub struct ScyllaConnectionInfo {
/// Primary contact point. `open_session` consults and updates the host-health tracker using
/// this value only.
pub host: String,
/// Additional contact points handed to the driver alongside `host`.
pub extra_hosts: Vec<String>,
/// Keyspace to switch to after the session is built, if any.
pub keyspace: Option<String>,
/// Login name; authentication is configured only if this or `password` is non-empty.
pub username: String,
/// Login password; never shown by the `Debug` implementation.
pub password: String,
}
impl Debug for ScyllaConnectionInfo {
    /// Formats every field for diagnostics except the password, which is replaced by the literal
    /// `"<redacted>"` so credentials cannot leak into logs.
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        // Exhaustive destructuring: adding a field to the struct fails to compile here until the
        // new field is either printed or deliberately skipped.
        let Self {
            host,
            extra_hosts,
            keyspace,
            username,
            password: _,
        } = self;
        let mut out = f.debug_struct("ScyllaConnectionInfo");
        out.field("host", host);
        out.field("extra_hosts", extra_hosts);
        out.field("keyspace", keyspace);
        out.field("username", username);
        out.field("password", &"<redacted>");
        out.finish()
    }
}
impl ScyllaConnectionInfo {
    /// Primary host used by `from_resolved_parts` when the resolved host list has no usable entry.
    const DEFAULT_HOST: &'static str = "167.235.9.113:9042";

    /// Builds connection info from an already-resolved host list and an optional auth map.
    ///
    /// Hosts are trimmed and blank entries are dropped. The first remaining host becomes the
    /// primary `host` (falling back to `DEFAULT_HOST` when none remain) and the rest become
    /// `extra_hosts`, preserving order.
    ///
    /// From `auth`, the `username` and `password` entries are taken verbatim (missing → empty
    /// string). The `keyspace` entry is trimmed, and a blank value is treated as absent, matching
    /// `from_metadata`, so `open_session` never attempts to switch to an empty keyspace name.
    pub fn from_resolved_parts(hosts: Vec<String>, auth: Option<HashMap<String, String>>) -> Self {
        let mut usable_hosts = hosts
            .into_iter()
            .map(|candidate| candidate.trim().to_string())
            .filter(|candidate| !candidate.is_empty());
        let host: String = usable_hosts
            .next()
            .unwrap_or_else(|| Self::DEFAULT_HOST.to_string());
        let extra_hosts: Vec<String> = usable_hosts.collect();

        // We own the map, so move entries out instead of cloning them.
        let mut auth: HashMap<String, String> = auth.unwrap_or_default();
        let username: String = auth.remove("username").unwrap_or_default();
        let password: String = auth.remove("password").unwrap_or_default();
        let keyspace: Option<String> = auth
            .remove("keyspace")
            .map(|value| value.trim().to_string())
            .filter(|value| !value.is_empty());

        Self {
            host,
            extra_hosts,
            keyspace,
            username,
            password,
        }
    }

    /// Extracts Scylla connection settings from a client's JSON metadata.
    ///
    /// Settings are read from the first of the `scylla`, `scylladb`, or `connection` keys whose
    /// value is a JSON object ("nested" settings), falling back to the flat `scyllaHosts`,
    /// `scyllaKeyspace`, `scyllaUsername`, and `scyllaPassword` keys. Hosts may be a
    /// comma-separated string or an array of strings. String fields are trimmed, and blank values
    /// count as absent. A missing username or password becomes an empty string.
    ///
    /// # Returns
    /// - `Ok(Some(info))` when at least one host is found.
    /// - `Ok(None)` when `metadata` is not an object, or no hosts are found and the metadata does
    ///   not declare a Scylla backend.
    ///
    /// # Errors
    /// Returns a message when the metadata declares a Scylla backend but lists no hosts.
    pub fn from_metadata(metadata: &Value) -> Result<Option<Self>, String> {
        let Some(object) = metadata.as_object() else {
            return Ok(None);
        };
        let declares_scylla: bool = metadata_declares_scylla(object);
        let nested: Option<&Map<String, Value>> = ["scylla", "scylladb", "connection"]
            .iter()
            .find_map(|key| object.get(*key).and_then(Value::as_object));
        let hosts: Option<Vec<String>> = nested
            .and_then(|section| section.get("hosts"))
            .and_then(parse_hosts_value)
            .or_else(|| object.get("scyllaHosts").and_then(parse_hosts_value));

        // A nested key wins over its flat `scylla*` counterpart; blank strings count as absent.
        let text_field = |nested_key: &str, flat_key: &str| -> Option<String> {
            nested
                .and_then(|section| section.get(nested_key))
                .and_then(optional_trimmed_string)
                .or_else(|| object.get(flat_key).and_then(optional_trimmed_string))
        };
        let keyspace: Option<String> = text_field("keyspace", "scyllaKeyspace");
        let username: String = text_field("username", "scyllaUsername").unwrap_or_default();
        let password: String = text_field("password", "scyllaPassword").unwrap_or_default();

        let mut host_iter = hosts.unwrap_or_default().into_iter();
        match host_iter.next() {
            Some(host) => Ok(Some(Self {
                host,
                extra_hosts: host_iter.collect(),
                keyspace,
                username,
                password,
            })),
            None if declares_scylla => Err(
                "metadata marks the client as Scylla-backed but does not include any Scylla hosts"
                    .to_string(),
            ),
            None => Ok(None),
        }
    }

    /// Returns every contact point as borrowed strings: `host` first, then `extra_hosts` in order.
    pub fn known_hosts(&self) -> Vec<&str> {
        std::iter::once(self.host.as_str())
            .chain(self.extra_hosts.iter().map(String::as_str))
            .collect()
    }
}
/// Opens a driver session that knows about every host in `info`, gated by the global
/// host-health tracker.
///
/// Only `info.host` (the primary) is checked against and reported to the tracker, even though all
/// of `info.known_hosts()` are handed to the driver. Credentials are set only when the username or
/// the password is non-empty. After a successful build, the session switches to `info.keyspace`
/// (case-insensitive) if one is set. Success is recorded only after that switch succeeds.
///
/// # Errors
/// - `HostOffline` if the tracker already marks `info.host` offline, or if this build failure
///   makes the tracker return an offline deadline.
/// - The driver's build error for any other build failure.
/// - The driver's error if switching to the keyspace fails.
pub async fn open_session(info: &ScyllaConnectionInfo) -> Result<Session, Box<dyn Error>> {
    let tracker: &health::HostHealthTracker<health::SystemClock> = health::global_tracker();
    if let Some(deadline) = tracker.offline_until(&info.host) {
        return Err(Box::new(HostOffline::new(info.host.clone(), deadline)) as Box<dyn Error>);
    }

    let mut builder: GenericSessionBuilder<scylla::client::session_builder::DefaultMode> = info
        .known_hosts()
        .into_iter()
        .fold(SessionBuilder::new(), |pending, host| pending.known_node(host));
    let has_credentials: bool = !(info.username.is_empty() && info.password.is_empty());
    if has_credentials {
        builder = builder.user(&info.username, &info.password);
    }

    let session: Session = match builder.build().await {
        Ok(session) => session,
        Err(err) => {
            let failure: Box<dyn Error> = match tracker.record_failure(&info.host) {
                Some(deadline) => Box::new(HostOffline::new(info.host.clone(), deadline)),
                None => Box::new(err),
            };
            return Err(failure);
        }
    };

    if let Some(keyspace) = info.keyspace.as_deref() {
        session.use_keyspace(keyspace, false).await?;
    }
    tracker.record_success(&info.host);
    Ok(session)
}
/// Reports whether an execution error is a connection-level failure, meaning either a
/// connection-pool error or a last attempt that failed with a broken connection.
///
/// Callers use this to decide whether a failure should count against the host's health.
pub(crate) fn is_connection_error(err: &ExecutionError) -> bool {
    match err {
        ExecutionError::ConnectionPoolError(_) => true,
        ExecutionError::LastAttemptError(attempt) => {
            matches!(attempt, RequestAttemptError::BrokenConnectionError(_))
        }
        _ => false,
    }
}
/// Runs a single CQL statement against the cluster described by `info` and returns JSON rows.
///
/// A fresh session is opened for each call (see `open_session`). The statement is built by
/// `build_statement` (consistency ONE) and executed unpaged.
///
/// # Returns
/// `(rows, column_names)`. For results without rows this is a single status object and an empty
/// column list.
///
/// # Errors
/// Session-opening errors, driver execution errors (converted to `HostOffline` when a connection
/// failure takes the primary host offline), and row-deserialization errors.
pub async fn execute_query_with_info(
    query: String,
    info: &ScyllaConnectionInfo,
) -> Result<(Vec<Value>, Vec<String>), Box<dyn Error>> {
    // `info` is already a reference; borrowing it again would only create a `&&` that gets
    // auto-dereferenced (clippy `needless_borrow`).
    let session: Session = open_session(info).await?;
    let statement: Statement = build_statement(&query);
    match session.query_unpaged(statement, &[]).await {
        Ok(result) => map_query_result(result),
        Err(err) => handle_query_error(err, info),
    }
}
/// Wraps raw CQL text in a `Statement` with consistency ONE and a page size of 1.
// NOTE(review): the only caller runs this through `query_unpaged`; confirm whether the page size
// of 1 has any effect there, and that it is intended if the statement is ever executed paged.
fn build_statement(sql: &str) -> Statement {
    let mut prepared_text = Statement::new(sql);
    prepared_text.set_consistency(Consistency::One);
    prepared_text.set_page_size(1);
    prepared_text
}
/// Converts a driver `QueryResult` into JSON rows plus the ordered list of column names.
///
/// Results that carry no rows produce a single
/// `{"status": "success", "message": "Query executed successfully"}` object and no columns.
/// Otherwise each row becomes a JSON object keyed by column name, with values converted by
/// `convert_cql_value` and CQL nulls (or missing trailing cells) mapped to JSON `null`.
///
/// # Errors
/// Propagates errors from turning the result into rows or from deserializing a row.
fn map_query_result(result: QueryResult) -> Result<(Vec<Value>, Vec<String>), Box<dyn Error>> {
    if !result.is_rows() {
        return Ok((
            vec![json!({
                "status": "success",
                "message": "Query executed successfully"
            })],
            Vec::new(),
        ));
    }
    let rows_result = result.into_rows_result()?;
    // Resolve the column names once; every row reuses them instead of re-allocating per cell.
    let columns: Vec<String> = rows_result
        .column_specs()
        .iter()
        .map(|spec| spec.name().to_string())
        .collect();
    let mut rows: Vec<Value> = Vec::with_capacity(rows_result.rows_num());
    for row_result in rows_result.rows::<Row>()? {
        let row: Row = row_result?;
        let row_map: Map<String, Value> = columns
            .iter()
            .enumerate()
            .map(|(index, name)| {
                let value: Value = match row.columns.get(index) {
                    Some(Some(cql_value)) => convert_cql_value(cql_value),
                    Some(None) | None => Value::Null,
                };
                (name.clone(), value)
            })
            .collect();
        rows.push(Value::Object(row_map));
    }
    Ok((rows, columns))
}
/// Splits a comma-separated host list, trimming whitespace around each entry and discarding
/// entries that end up empty. Order is preserved.
fn parse_hosts(raw: &str) -> Vec<String> {
    let mut hosts: Vec<String> = Vec::new();
    for piece in raw.split(',') {
        let candidate = piece.trim();
        if candidate.is_empty() {
            continue;
        }
        hosts.push(candidate.to_owned());
    }
    hosts
}
/// Interprets a JSON value as a host list.
///
/// A string is split on commas by `parse_hosts`. An array contributes its string elements only
/// (non-strings are skipped), each trimmed, with blanks dropped; elements are not comma-split.
/// Any other JSON type yields `None`.
fn parse_hosts_value(value: &Value) -> Option<Vec<String>> {
    if let Value::String(raw) = value {
        return Some(parse_hosts(raw));
    }
    let entries = value.as_array()?;
    let mut hosts: Vec<String> = Vec::new();
    for entry in entries.iter().filter_map(Value::as_str) {
        let trimmed = entry.trim();
        if !trimmed.is_empty() {
            hosts.push(trimmed.to_string());
        }
    }
    Some(hosts)
}
/// Returns the trimmed contents of a JSON string, or `None` when the value is not a string or is
/// blank after trimming.
fn optional_trimmed_string(value: &Value) -> Option<String> {
    let trimmed: &str = value.as_str()?.trim();
    (!trimmed.is_empty()).then(|| trimmed.to_string())
}
/// Returns `true` when client metadata identifies the client as Scylla-backed.
///
/// Two independent signals count:
/// - the engine name, taken from the first of `dbEngine`, `backend`, or `engine` whose value is a
///   string, equals `scylla` or `scylladb` (trimmed, ASCII case-insensitive);
/// - a `scylla`, `scylladb`, or `scyllaHosts` key is present (with any value).
fn metadata_declares_scylla(metadata: &Map<String, Value>) -> bool {
    const ENGINE_KEYS: [&str; 3] = ["dbEngine", "backend", "engine"];
    const SCYLLA_KEYS: [&str; 3] = ["scylla", "scylladb", "scyllaHosts"];

    // Skip keys whose value is not a string (e.g. `"dbEngine": null`). Previously such a key
    // stopped the fallback and hid a valid `backend`/`engine` value.
    let engine_is_scylla: bool = ENGINE_KEYS
        .iter()
        .find_map(|key| metadata.get(*key).and_then(Value::as_str))
        .is_some_and(|engine| {
            let normalized = engine.trim();
            normalized.eq_ignore_ascii_case("scylla") || normalized.eq_ignore_ascii_case("scylladb")
        });
    engine_is_scylla || SCYLLA_KEYS.iter().any(|key| metadata.contains_key(*key))
}
/// Converts a failed query execution into this module's error result.
///
/// Only connection-level failures (per `is_connection_error`) are reported to the global
/// host-health tracker for `info.host`. If the tracker responds with an offline deadline, a
/// `HostOffline` error is returned; otherwise the original driver error is returned unchanged.
fn handle_query_error(
    err: ExecutionError,
    info: &ScyllaConnectionInfo,
) -> Result<(Vec<Value>, Vec<String>), Box<dyn Error>> {
    let offline_deadline = if is_connection_error(&err) {
        health::global_tracker().record_failure(&info.host)
    } else {
        None
    };
    let failure: Box<dyn Error> = match offline_deadline {
        Some(deadline) => Box::new(HostOffline::new(info.host.clone(), deadline)),
        None => Box::new(err),
    };
    Err(failure)
}
/// Converts a single non-null CQL cell into JSON.
///
/// - `text`/`ascii` become JSON strings.
/// - `tinyint`/`smallint`/`int`/`bigint` become JSON integers.
/// - `boolean` becomes a JSON boolean.
/// - `float`/`double` become JSON numbers. Non-finite values (NaN, ±infinity) have no JSON number
///   form and become `null`, following serde_json's `From<f64>` convention, instead of a fake `0`
///   that would be indistinguishable from real data.
/// - `decimal` is decoded from its unscaled integer and scale into the nearest `f64` (see
///   `decimal_to_json`). Decimals that cannot be decoded that way fall back to their `Debug`
///   text as a string.
/// - Every other CQL type is rendered with its `Debug` representation as a JSON string.
fn convert_cql_value(value: &CqlValue) -> Value {
    match value {
        CqlValue::Text(s) | CqlValue::Ascii(s) => Value::String(s.clone()),
        CqlValue::TinyInt(i) => Value::from(*i),
        CqlValue::SmallInt(i) => Value::from(*i),
        CqlValue::Int(i) => Value::from(*i),
        CqlValue::BigInt(b) => Value::from(*b),
        CqlValue::Boolean(b) => Value::Bool(*b),
        CqlValue::Float(f) => Value::from(*f),
        CqlValue::Double(d) => Value::from(*d),
        CqlValue::Decimal(decimal) => {
            // The old code scraped a "value: " field out of the Debug text. CqlDecimal's Debug
            // has no such field, so every decimal silently became 0. Decode the raw parts instead.
            let (unscaled_be, scale) = decimal.as_signed_be_bytes_slice_and_exponent();
            decimal_to_json(unscaled_be, scale)
                .unwrap_or_else(|| Value::String(format!("{:?}", value)))
        }
        other => Value::String(format!("{:?}", other)),
    }
}

/// Decodes a CQL decimal, given as a two's-complement big-endian unscaled integer and a scale
/// (value = unscaled × 10^(−scale)), into a JSON number holding the nearest `f64`.
///
/// Returns `None` when the unscaled integer is wider than 128 bits or the value overflows to a
/// non-finite `f64`, so the caller can choose a fallback representation.
fn decimal_to_json(unscaled_be: &[u8], scale: i32) -> Option<Value> {
    const WIDTH: usize = 16;
    if unscaled_be.len() > WIDTH {
        return None;
    }
    // Sign-extend into a fixed 16-byte buffer so the bytes can be read as an i128.
    let negative: bool = unscaled_be.first().is_some_and(|byte| byte & 0x80 != 0);
    let mut buffer: [u8; WIDTH] = if negative { [0xFF; WIDTH] } else { [0x00; WIDTH] };
    buffer[WIDTH - unscaled_be.len()..].copy_from_slice(unscaled_be);
    let unscaled: i128 = i128::from_be_bytes(buffer);
    // Let the float parser apply the power of ten: `str::parse::<f64>` rounds correctly, whereas
    // multiplying by a computed power of ten would round twice. Widen before negating so that
    // `i32::MIN` cannot overflow.
    let exponent: i64 = -i64::from(scale);
    let literal: String = format!("{unscaled}e{exponent}");
    literal
        .parse::<f64>()
        .ok()
        .and_then(Number::from_f64)
        .map(Value::Number)
}
#[cfg(test)]
mod tests {
    //! Unit tests for `ScyllaConnectionInfo::from_metadata`.
    use super::*;

    /// Parses `metadata`, panicking unless it yields Scylla connection info.
    fn parse_expecting_info(metadata: Value) -> ScyllaConnectionInfo {
        ScyllaConnectionInfo::from_metadata(&metadata)
            .expect("metadata should parse")
            .expect("expected Scylla info")
    }

    #[test]
    fn metadata_parsing_accepts_flat_scylla_fields() {
        let info = parse_expecting_info(json!({
            "dbEngine": "scylladb",
            "scyllaHosts": "10.0.0.10:9042,10.0.0.11:9042",
            "scyllaKeyspace": "analytics",
            "scyllaUsername": "cassandra",
            "scyllaPassword": "secret"
        }));
        assert_eq!(info.host, "10.0.0.10:9042");
        assert_eq!(info.extra_hosts, vec!["10.0.0.11:9042".to_string()]);
        assert_eq!(info.keyspace.as_deref(), Some("analytics"));
        assert_eq!(info.username, "cassandra");
        assert_eq!(info.password, "secret");
    }

    #[test]
    fn metadata_parsing_accepts_nested_connection_shape() {
        let info = parse_expecting_info(json!({
            "backend": "scylla",
            "scylla": {
                "hosts": ["10.0.0.20:9042", "10.0.0.21:9042"],
                "keyspace": "events"
            }
        }));
        assert_eq!(info.host, "10.0.0.20:9042");
        assert_eq!(info.extra_hosts, vec!["10.0.0.21:9042".to_string()]);
        assert_eq!(info.keyspace.as_deref(), Some("events"));
    }

    #[test]
    fn metadata_parsing_rejects_missing_hosts_for_scylla_clients() {
        let error: String =
            ScyllaConnectionInfo::from_metadata(&json!({ "dbEngine": "scylladb" }))
                .expect_err("missing hosts should fail");
        assert!(error.contains("does not include any Scylla hosts"));
    }

    #[test]
    fn metadata_parsing_ignores_non_scylla_clients() {
        let outcome: Option<ScyllaConnectionInfo> = ScyllaConnectionInfo::from_metadata(&json!({
            "dbEngine": "postgres",
            "pg_uri": "postgres://localhost/db"
        }))
        .expect("metadata should not fail");
        assert!(outcome.is_none());
    }
}