use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::RwLock;
use std::time::Instant;
use crate::connector::RedDBClient;
use crate::error::{ClientError, ErrorCode, Result};
use crate::params::Value as ParamValue;
use crate::router::{ClusterMembership, HealthAwareRouter, Outcome};
use crate::types::{BulkInsertResult, InsertResult, JsonValue, QueryResult, ValueOut};
/// Pool size used per endpoint by [`GrpcClient::connect`] and
/// [`GrpcClient::connect_cluster`] when the caller does not pick one.
pub const DEFAULT_POOL_SIZE: usize = 4;
/// gRPC client that sends inserts and deletes to the primary endpoint and
/// routes queries through a [`HealthAwareRouter`] over the primary and its
/// replicas.
pub struct GrpcClient {
/// Pooled handles to the primary; fixed for the lifetime of the client.
primary: Endpoint,
/// Pooled handles to each replica. Replaced as a whole by `apply_topology`.
replicas: RwLock<Vec<Endpoint>>,
// Initialised to 0 by the constructors and not read anywhere in the visible
// code, hence the lint suppression.
#[allow(dead_code)]
next_replica: AtomicUsize,
/// Value handed to `HealthAwareRouter::with_force_primary` at construction;
/// also shown in the `Debug` output.
force_primary: bool,
/// Pool size requested at construction; reused when `apply_topology`
/// connects to newly advertised replicas.
pool_size: usize,
/// Chooses the endpoint index for reads and receives per-request outcomes.
router: RwLock<HealthAwareRouter>,
}
/// One server address together with the client handles used to talk to it.
struct Endpoint {
/// Address this endpoint was connected with.
url: String,
/// Client handles handed out by `pick`; never empty once built by `connect`.
pool: Vec<RedDBClient>,
/// Monotonic counter used by `pick` to rotate through `pool`.
next: AtomicUsize,
}
impl Endpoint {
    /// Connects to `url` and builds a pool of `pool_size` client handles.
    ///
    /// A `pool_size` of zero is treated as one. Only a single
    /// `RedDBClient::connect` call is made; the remaining pool slots are
    /// clones of that first client.
    ///
    /// # Errors
    /// Returns `ErrorCode::IoError` when the connection attempt fails.
    async fn connect(url: String, pool_size: usize) -> Result<Self> {
        let slots = pool_size.max(1);
        let first = match RedDBClient::connect(&url, None).await {
            Ok(client) => client,
            Err(e) => {
                return Err(ClientError::new(
                    ErrorCode::IoError,
                    format!("connect {url}: {e}"),
                ))
            }
        };

        // Fill all but the last slot with clones, then move the original in
        // so no extra clone is made.
        let mut pool: Vec<RedDBClient> = Vec::with_capacity(slots);
        pool.extend(std::iter::repeat_with(|| first.clone()).take(slots - 1));
        pool.push(first);

        Ok(Self {
            url,
            pool,
            next: AtomicUsize::new(0),
        })
    }

    /// Returns a clone of the next pooled client, rotating through the pool
    /// in order.
    fn pick(&self) -> RedDBClient {
        let ticket = self.next.fetch_add(1, Ordering::Relaxed);
        let slot = ticket % self.pool.len();
        self.pool[slot].clone()
    }
}
impl std::fmt::Debug for GrpcClient {
    /// Formats the client as its primary URL, replica URLs and the
    /// `force_primary` flag.
    ///
    /// Debug output is typically produced on error paths, so a poisoned
    /// replicas lock is recovered instead of panicking inside `fmt`; the URLs
    /// are only read, never mutated.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let replicas_guard = self
            .replicas
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        let replicas: Vec<&str> = replicas_guard.iter().map(|e| e.url.as_str()).collect();
        f.debug_struct("GrpcClient")
            .field("primary", &self.primary.url)
            .field("replicas", &replicas)
            .field("force_primary", &self.force_primary)
            .finish()
    }
}
impl GrpcClient {
/// Connects to a single endpoint using [`DEFAULT_POOL_SIZE`] pooled handles.
///
/// # Errors
/// Propagates any error from `connect_with_pool_size`.
pub async fn connect(endpoint: String) -> Result<Self> {
    let pool_size = DEFAULT_POOL_SIZE;
    Self::connect_with_pool_size(endpoint, pool_size).await
}
/// Connects to a single endpoint with an explicit pool size.
///
/// The resulting client has no replicas and is created with
/// `force_primary` set to `true`.
///
/// # Errors
/// Propagates the error from `Endpoint::connect` when the endpoint cannot be
/// reached.
pub async fn connect_with_pool_size(endpoint: String, pool_size: usize) -> Result<Self> {
    let force_primary = true;
    let primary = Endpoint::connect(endpoint, pool_size).await?;

    let single_node = ClusterMembership::new(primary.url.clone(), Vec::new());
    let router = HealthAwareRouter::with_force_primary(single_node, force_primary);

    Ok(Self {
        primary,
        replicas: RwLock::new(Vec::new()),
        next_replica: AtomicUsize::new(0),
        force_primary,
        pool_size,
        router: RwLock::new(router),
    })
}
/// Connects to a primary and a list of replicas using
/// [`DEFAULT_POOL_SIZE`] pooled handles per endpoint.
///
/// # Errors
/// Propagates any error from `connect_cluster_with_pool_size`.
pub async fn connect_cluster(
    primary: String,
    replicas: Vec<String>,
    force_primary: bool,
) -> Result<Self> {
    let pool_size = DEFAULT_POOL_SIZE;
    Self::connect_cluster_with_pool_size(primary, replicas, force_primary, pool_size).await
}
pub async fn connect_cluster_with_pool_size(
primary: String,
replicas: Vec<String>,
force_primary: bool,
pool_size: usize,
) -> Result<Self> {
let primary_ep = Endpoint::connect(primary, pool_size).await?;
let mut replica_eps = Vec::with_capacity(replicas.len());
for url in replicas {
replica_eps.push(Endpoint::connect(url, pool_size).await?);
}
let membership = ClusterMembership::new(
primary_ep.url.clone(),
replica_eps.iter().map(|e| e.url.clone()).collect(),
);
let router = RwLock::new(HealthAwareRouter::with_force_primary(
membership,
force_primary,
));
Ok(Self {
primary: primary_ep,
replicas: RwLock::new(replica_eps),
next_replica: AtomicUsize::new(0),
force_primary,
pool_size,
router,
})
}
/// Returns the URL of the primary endpoint.
pub fn endpoint(&self) -> &str {
    self.primary.url.as_str()
}
/// Returns a snapshot of the current replica URLs, in stored order.
///
/// # Panics
/// Panics if the replicas lock is poisoned.
pub fn replica_endpoints(&self) -> Vec<String> {
    let guard = self.replicas.read().unwrap();
    let mut urls = Vec::with_capacity(guard.len());
    for ep in guard.iter() {
        urls.push(ep.url.clone());
    }
    urls
}
/// Picks the client to use for a read and the router index it belongs to.
///
/// Index `0` denotes the primary; index `i > 0` denotes `replicas[i - 1]`.
/// When the router names a replica slot that does not exist, the primary is
/// used and index `0` is reported instead.
fn read_endpoint(&self) -> (RedDBClient, usize) {
    let chosen = self.router.read().unwrap().pick_read_index();

    // `checked_sub` yields `None` exactly when the router chose the primary.
    if let Some(replica_pos) = chosen.checked_sub(1) {
        let guard = self.replicas.read().unwrap();
        if let Some(ep) = guard.get(replica_pos) {
            return (ep.pick(), chosen);
        }
    }

    (self.primary.pick(), 0)
}
/// Forwards a new membership to the router.
///
/// Only the router is updated; the pooled replica endpoints held by this
/// client are left as they are.
///
/// # Panics
/// Panics if the router lock is poisoned.
pub fn update_membership(&self, new_membership: ClusterMembership) {
    let mut router = self.router.write().unwrap();
    router.update_membership(new_membership);
}
/// Reconciles the replica set with an advertised topology.
///
/// Replicas that are already connected keep their existing pools, newly
/// advertised replicas are connected with the client's configured pool size,
/// and replicas that are no longer advertised are dropped. Duplicate
/// addresses in `replica_addrs` are collapsed to their first occurrence so
/// the router's indexes always line up with the stored endpoints.
///
/// All new connections are established *before* any state is touched, and
/// the replica list is then swapped under a single write lock. A failed
/// connection therefore leaves the previous replica set and router
/// membership fully intact, and readers never observe a half-built list.
///
/// # Errors
/// * `ErrorCode::InvalidUri` when `primary_addr` differs from the connected
///   primary.
/// * The error from `Endpoint::connect` when a new replica cannot be reached.
///
/// # Panics
/// Panics if the replicas or router lock is poisoned.
pub async fn apply_topology(&self, primary_addr: &str, replica_addrs: &[String]) -> Result<()> {
    if primary_addr != self.primary.url {
        return Err(ClientError::new(
            ErrorCode::InvalidUri,
            format!(
                "topology advertised primary {} differs from connected {}; primary failover is out of scope for #172",
                primary_addr, self.primary.url
            ),
        ));
    }

    // Order-preserving de-duplication of the advertised addresses.
    let mut wanted: Vec<&String> = Vec::with_capacity(replica_addrs.len());
    for url in replica_addrs {
        if !wanted.contains(&url) {
            wanted.push(url);
        }
    }

    // Phase 1: connect to replicas we do not have yet. No shared state is
    // modified here, so an early return via `?` cannot corrupt anything.
    let known = self.replica_endpoints();
    let mut fresh: Vec<Endpoint> = Vec::new();
    for url in &wanted {
        if !known.iter().any(|k| k == *url) {
            fresh.push(Endpoint::connect((*url).clone(), self.pool_size).await?);
        }
    }

    // Phase 2: build the new list and install it under one write lock. The
    // guard lives in a block without any `.await`.
    let applied: Vec<String> = {
        let mut guard = self.replicas.write().unwrap();
        let mut previous = std::mem::take(&mut *guard);
        let mut next: Vec<Endpoint> = Vec::with_capacity(wanted.len());
        for url in &wanted {
            if let Some(pos) = previous.iter().position(|e| e.url == **url) {
                next.push(previous.swap_remove(pos));
            } else if let Some(pos) = fresh.iter().position(|e| e.url == **url) {
                next.push(fresh.swap_remove(pos));
            }
            // Otherwise the endpoint was removed by a concurrent update after
            // the snapshot was taken; it is left out of the membership too.
        }
        let urls = next.iter().map(|e| e.url.clone()).collect();
        *guard = next;
        urls
    };

    // Membership is derived from what was actually installed so router
    // index `i` always maps to `replicas[i - 1]`.
    let membership = ClusterMembership::new(self.primary.url.clone(), applied);
    self.router.write().unwrap().update_membership(membership);
    Ok(())
}
/// Fetches the topology from the primary and applies it.
///
/// The raw topology bytes are decoded with
/// `TopologyConsumer::consume_bytes` (no URI seed) and the decoded primary
/// and replica addresses are handed to `apply_topology`.
///
/// # Errors
/// * `ErrorCode::IoError` when the topology RPC fails.
/// * `ErrorCode::QueryError` when the bytes cannot be decoded.
/// * Any error from `apply_topology`.
pub async fn refresh_topology(&self) -> Result<()> {
    let mut primary_client = self.primary.pick();

    let raw = match primary_client.topology().await {
        Ok(bytes) => bytes,
        Err(e) => {
            return Err(ClientError::new(
                ErrorCode::IoError,
                format!("topology rpc: {e}"),
            ))
        }
    };

    let decoded = match crate::topology::TopologyConsumer::consume_bytes(&raw, None) {
        Ok(membership) => membership,
        Err(e) => {
            return Err(ClientError::new(
                ErrorCode::QueryError,
                format!("decode topology: {e}"),
            ))
        }
    };

    let mut replica_addrs: Vec<String> = Vec::new();
    for replica in decoded.replicas.iter() {
        replica_addrs.push(replica.addr.clone());
    }

    self.apply_topology(&decoded.primary.addr, &replica_addrs)
        .await
}
/// Reports the outcome of a request sent to router index `idx`.
///
/// # Panics
/// Panics if the router lock is poisoned.
pub(crate) fn observe(&self, idx: usize, outcome: Outcome) {
    let router = self.router.read().unwrap();
    router.observe_index(idx, outcome);
}
/// Runs `sql` on a router-selected read endpoint and parses the reply.
///
/// The round-trip time is reported to the router on success. Every failure
/// of the RPC, whatever its cause, is reported as `Outcome::Timeout`.
///
/// # Errors
/// Returns `ErrorCode::QueryError` when the RPC fails or the reply JSON
/// cannot be parsed.
pub async fn query(&self, sql: &str) -> Result<QueryResult> {
    let (mut conn, endpoint_idx) = self.read_endpoint();
    let sent_at = Instant::now();

    match conn.query_reply(sql).await {
        Ok(reply) => {
            self.observe(endpoint_idx, Outcome::Rtt(sent_at.elapsed()));
            parse_query_json(&reply.result_json)
        }
        Err(err) => {
            self.observe(endpoint_idx, Outcome::Timeout);
            Err(ClientError::new(ErrorCode::QueryError, err.to_string()))
        }
    }
}
/// Runs a parameterised query on a router-selected read endpoint.
///
/// With an empty `params` slice this is exactly `query(sql)`. Otherwise the
/// parameters are converted to their gRPC representation and sent along
/// with the statement. Success reports the round-trip time to the router;
/// any RPC failure is reported as `Outcome::Timeout`.
///
/// # Errors
/// Returns `ErrorCode::QueryError` when the RPC fails or the reply JSON
/// cannot be parsed.
pub async fn query_with(&self, sql: &str, params: &[ParamValue]) -> Result<QueryResult> {
    if params.is_empty() {
        return self.query(sql).await;
    }

    let wire_params = params_to_grpc_values(params);
    let (mut conn, endpoint_idx) = self.read_endpoint();
    let sent_at = Instant::now();

    match conn.query_reply_with_params(sql, wire_params).await {
        Ok(reply) => {
            self.observe(endpoint_idx, Outcome::Rtt(sent_at.elapsed()));
            parse_query_json(&reply.result_json)
        }
        Err(err) => {
            self.observe(endpoint_idx, Outcome::Timeout);
            Err(ClientError::new(ErrorCode::QueryError, err.to_string()))
        }
    }
}
/// Inserts one JSON object into `collection` via the primary.
///
/// On success the result always reports `affected: 1` together with the id
/// returned by the server, rendered as a string.
///
/// # Errors
/// Returns `ErrorCode::QueryError` when `payload` is not a JSON object or
/// when the RPC fails.
pub async fn insert(&self, collection: &str, payload: &JsonValue) -> Result<InsertResult> {
    match payload.as_object() {
        Some(_) => {}
        None => {
            return Err(ClientError::new(
                ErrorCode::QueryError,
                "insert payload must be a JSON object".to_string(),
            ))
        }
    }

    let body = payload.to_json_string();
    let mut conn = self.primary.pick();
    match conn.create_row_entity(collection, &body).await {
        Ok(reply) => Ok(InsertResult {
            affected: 1,
            id: Some(reply.id.to_string()),
        }),
        Err(e) => Err(ClientError::new(ErrorCode::QueryError, e.to_string())),
    }
}
/// Inserts several JSON objects into `collection` in one RPC to the primary.
///
/// Every payload is validated and serialised before anything is sent; the
/// first non-object payload aborts the call. The result carries the count
/// and ids reported by the server, with ids rendered as strings.
///
/// # Errors
/// Returns `ErrorCode::QueryError` when a payload is not a JSON object or
/// when the RPC fails.
pub async fn bulk_insert(
    &self,
    collection: &str,
    payloads: &[JsonValue],
) -> Result<BulkInsertResult> {
    let mut rows = Vec::with_capacity(payloads.len());
    for item in payloads {
        match item.as_object() {
            Some(_) => rows.push(item.to_json_string()),
            None => {
                return Err(ClientError::new(
                    ErrorCode::QueryError,
                    "bulk_insert payloads must be JSON objects".to_string(),
                ))
            }
        }
    }

    let mut conn = self.primary.pick();
    match conn.bulk_create_rows(collection, rows).await {
        Ok(reply) => {
            let ids = reply.ids.into_iter().map(|id| id.to_string()).collect();
            Ok(BulkInsertResult {
                affected: reply.count,
                ids,
            })
        }
        Err(e) => Err(ClientError::new(ErrorCode::QueryError, e.to_string())),
    }
}
/// Deletes the entity with the given numeric `id` from `collection` via the
/// primary.
///
/// Always returns `1` on success; the server reply is not inspected.
///
/// # Errors
/// * `ErrorCode::InvalidUri` when `id` does not parse as a `u64`.
/// * `ErrorCode::QueryError` when the RPC fails.
pub async fn delete(&self, collection: &str, id: &str) -> Result<u64> {
    let numeric_id: u64 = match id.parse() {
        Ok(n) => n,
        Err(_) => {
            return Err(ClientError::new(
                ErrorCode::InvalidUri,
                "id must be a numeric string".to_string(),
            ))
        }
    };

    let mut conn = self.primary.pick();
    if let Err(e) = conn.delete_entity(collection, numeric_id).await {
        return Err(ClientError::new(ErrorCode::QueryError, e.to_string()));
    }
    Ok(1)
}
/// Does nothing and always returns `Ok(())`; no connection is torn down
/// here.
pub async fn close(&self) -> Result<()> {
Ok(())
}
pub fn ingest_topology_bytes(
&self,
bytes: &[u8],
uri_seed: Option<crate::topology::UriSeed>,
) -> std::result::Result<crate::topology::ClusterMembership, crate::topology::ConsumeError>
{
crate::topology::TopologyConsumer::consume_bytes(bytes, uri_seed)
}
}
/// Parses the JSON envelope returned by the server into a [`QueryResult`].
///
/// Missing or mistyped fields fall back to defaults:
/// * `statement` defaults to `"select"`;
/// * `affected` defaults to `0`;
/// * `columns` defaults to empty, and non-string entries are skipped;
/// * rows are read from `rows`, or from `records` when `rows` is absent,
///   and default to empty.
///
/// # Errors
/// Returns `ErrorCode::QueryError` when `s` is not valid JSON.
fn parse_query_json(s: &str) -> Result<QueryResult> {
    let parsed: serde_json::Value = serde_json::from_str(s)
        .map_err(|e| ClientError::new(ErrorCode::QueryError, format!("bad server JSON: {e}")))?;
    let statement = parsed
        .get("statement")
        .and_then(|v| v.as_str())
        .unwrap_or("select")
        .to_string();
    // Read integral counts exactly; going through f64 first would round
    // anything above 2^53. Non-integer encodings (e.g. `3.0`) still use the
    // saturating float conversion, so negatives become 0.
    let affected = parsed
        .get("affected")
        .and_then(|v| v.as_u64().or_else(|| v.as_f64().map(|f| f as u64)))
        .unwrap_or(0);
    let columns = parsed
        .get("columns")
        .and_then(|v| v.as_array())
        .map(|cols| {
            cols.iter()
                .filter_map(|col| col.as_str().map(ToString::to_string))
                .collect::<Vec<_>>()
        })
        .unwrap_or_default();
    let rows = parsed
        .get("rows")
        .or_else(|| parsed.get("records"))
        .and_then(|v| v.as_array())
        .map(|rows| rows.iter().map(parse_row_value).collect())
        .unwrap_or_default();
    Ok(QueryResult {
        statement,
        affected,
        columns,
        rows,
    })
}
/// Converts one JSON row object into `(column, value)` pairs.
///
/// Pairs follow the iteration order of the JSON object. Anything that is
/// not an object yields an empty row.
fn parse_row_value(value: &serde_json::Value) -> Vec<(String, ValueOut)> {
    match value.as_object() {
        Some(fields) => {
            let mut cells = Vec::with_capacity(fields.len());
            for (column, raw) in fields {
                cells.push((column.clone(), parse_scalar(raw)));
            }
            cells
        }
        None => Vec::new(),
    }
}
/// Maps a JSON value onto the client's scalar output type.
///
/// * `null` and booleans map directly.
/// * Numbers that fit `i64` become `Integer`. Whole-valued floats such as
///   `1.0` also become `Integer`, but only when they lie inside the `i64`
///   range. Anything else that is representable as `f64` becomes `Float`.
/// * Strings are copied; arrays and objects are rendered as their JSON text.
fn parse_scalar(value: &serde_json::Value) -> ValueOut {
    // Bounds of the i64 range as f64. `i64::MAX` itself is not representable
    // in f64, so the upper bound is 2^63 and must be exclusive.
    const I64_LOWER_INCLUSIVE: f64 = -9_223_372_036_854_775_808.0;
    const I64_UPPER_EXCLUSIVE: f64 = 9_223_372_036_854_775_808.0;

    match value {
        serde_json::Value::Null => ValueOut::Null,
        serde_json::Value::Bool(b) => ValueOut::Bool(*b),
        serde_json::Value::Number(n) => {
            if let Some(i) = n.as_i64() {
                ValueOut::Integer(i)
            } else if let Some(f) = n.as_f64() {
                // `as i64` saturates, so an unchecked cast would silently turn
                // e.g. 1e300 or a u64 above i64::MAX into i64::MAX.
                let fits_i64 = f >= I64_LOWER_INCLUSIVE && f < I64_UPPER_EXCLUSIVE;
                if f.fract() == 0.0 && fits_i64 {
                    ValueOut::Integer(f as i64)
                } else {
                    ValueOut::Float(f)
                }
            } else {
                ValueOut::String(n.to_string())
            }
        }
        serde_json::Value::String(s) => ValueOut::String(s.clone()),
        other => ValueOut::String(other.to_string()),
    }
}
/// Converts bound query parameters into their gRPC `QueryValue` form.
///
/// The output has one entry per input parameter, in the same order, and
/// every entry has its `kind` set. JSON parameters are sent as their
/// serialised text and UUIDs as a byte vector.
fn params_to_grpc_values(params: &[ParamValue]) -> Vec<reddb_grpc_proto::QueryValue> {
    use reddb_grpc_proto::query_value::Kind;
    use reddb_grpc_proto::{QueryNull, QueryValue, QueryVector};

    let mut encoded = Vec::with_capacity(params.len());
    for param in params {
        let kind = match param {
            ParamValue::Null => Kind::NullValue(QueryNull {}),
            ParamValue::Bool(flag) => Kind::BoolValue(*flag),
            ParamValue::Int(number) => Kind::IntValue(*number),
            ParamValue::Float(number) => Kind::FloatValue(*number),
            ParamValue::Text(text) => Kind::TextValue(text.clone()),
            ParamValue::Bytes(bytes) => Kind::BytesValue(bytes.clone()),
            ParamValue::Vector(values) => Kind::VectorValue(QueryVector {
                values: values.clone(),
            }),
            ParamValue::Json(document) => Kind::JsonValue(document.to_json_string()),
            ParamValue::Timestamp(stamp) => Kind::TimestampValue(*stamp),
            ParamValue::Uuid(raw) => Kind::UuidValue(raw.to_vec()),
        };
        encoded.push(QueryValue { kind: Some(kind) });
    }
    encoded
}
#[cfg(test)]
mod tests {
use super::*;
use prost::Message;
/// A full envelope yields the statement, column names and typed row cells.
#[test]
fn parse_query_json_extracts_rows_and_columns() {
    let payload = r#"{"statement":"select","affected":0,"columns":["id","name"],"rows":[{"id":1,"name":"Alice"},{"id":2,"name":"Bob"}]}"#;
    let result = parse_query_json(payload).unwrap();

    assert_eq!(result.statement, "select");
    assert_eq!(result.affected, 0);
    assert_eq!(result.columns, vec!["id".to_string(), "name".to_string()]);
    assert_eq!(result.rows.len(), 2);

    let (first_key, first_value) = &result.rows[0][0];
    assert_eq!(first_key.as_str(), "id");
    assert!(matches!(first_value, ValueOut::Integer(1)));

    let (last_key, last_value) = &result.rows[1][1];
    assert_eq!(last_key.as_str(), "name");
    assert!(matches!(last_value, ValueOut::String(s) if s == "Bob"));
}
/// Empty `columns` and `rows` arrays parse to empty collections.
#[test]
fn parse_query_json_handles_empty_rows() {
    let payload = r#"{"statement":"select","affected":0,"columns":[],"rows":[]}"#;
    let result = parse_query_json(payload).unwrap();
    assert!(result.rows.is_empty());
    assert!(result.columns.is_empty());
}
/// Encoded parameters and whole query requests must match the hex fixtures
/// listed in the `reddb-wire` parameter manifest.
#[test]
fn grpc_params_match_shared_fixtures() {
    let manifest: serde_json::Value = serde_json::from_str(include_str!(
        "../../reddb-wire/tests/fixtures/params/manifest.json"
    ))
    .expect("manifest json");
    // Each single value must encode to the fixture's bytes.
    for fixture in manifest["values"].as_array().expect("values array") {
        let name = fixture["name"].as_str().expect("fixture name");
        let expected = fixture["grpc_hex"].as_str().expect("fixture grpc_hex");
        let encoded = params_to_grpc_values(&[fixture_param_value(name)]);
        assert_eq!(expected, hex(&encoded[0].encode_to_vec()), "{name}");
    }
    // Each full request (SQL plus parameters) must encode to the fixture's bytes.
    for query in manifest["queries"].as_array().expect("queries array") {
        let params = query["params"]
            .as_array()
            .expect("query params")
            .iter()
            .map(|param| fixture_param_value(param.as_str().expect("param name")))
            .collect::<Vec<_>>();
        let request = reddb_grpc_proto::QueryRequest {
            query: query["sql"].as_str().expect("query sql").to_string(),
            entity_types: Vec::new(),
            capabilities: Vec::new(),
            params: params_to_grpc_values(&params),
        };
        assert_eq!(
            query["grpc_request_hex"]
                .as_str()
                .expect("query grpc_request_hex"),
            hex(&request.encode_to_vec()),
            "{}",
            query["name"].as_str().unwrap()
        );
    }
}
/// Builds the parameter value that corresponds to a fixture name.
///
/// # Panics
/// Panics on a name that has no mapping here.
fn fixture_param_value(name: &str) -> ParamValue {
    match name {
        "null" => ParamValue::Null,
        "bool_true" => ParamValue::Bool(true),
        "bool_false" => ParamValue::Bool(false),
        "int_min" => ParamValue::Int(i64::MIN),
        "int_max" => ParamValue::Int(i64::MAX),
        "int_42" => ParamValue::Int(42),
        // Explicit bit pattern so the encoded NaN is stable.
        "float_nan" => ParamValue::Float(f64::from_bits(0x7ff8000000000000)),
        "float_pos_inf" => ParamValue::Float(f64::INFINITY),
        "float_neg_inf" => ParamValue::Float(f64::NEG_INFINITY),
        "float_subnormal_min" => ParamValue::Float(f64::from_bits(1)),
        "text_unicode" => ParamValue::Text(String::from("h\u{e9}llo")),
        "text_x" => ParamValue::Text(String::from("x")),
        "bytes_empty" => ParamValue::Bytes(vec![]),
        "bytes_deadbeef" => ParamValue::Bytes(vec![0xde, 0xad, 0xbe, 0xef]),
        "bytes_256" => ParamValue::Bytes((0..=255u8).collect()),
        "json_nested" => {
            let deep = JsonValue::array([JsonValue::Bool(true), JsonValue::Bool(false)]);
            let z = JsonValue::array([
                JsonValue::Number(1.0),
                JsonValue::object([("deep", deep)]),
            ]);
            ParamValue::Json(JsonValue::object([("a", JsonValue::Null), ("z", z)]))
        }
        "timestamp_zero" => ParamValue::Timestamp(0),
        "timestamp_max" => ParamValue::Timestamp(i64::MAX),
        "uuid_001122" => ParamValue::Uuid([
            0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd,
            0xee, 0xff,
        ]),
        "vector_empty" => ParamValue::Vector(vec![]),
        "vector_three" => ParamValue::Vector(vec![1.0, 2.0, -0.5]),
        "vector_128" => ParamValue::Vector((0u8..128).map(f32::from).collect()),
        other => panic!("unknown fixture {other}"),
    }
}
/// Renders bytes as lowercase hexadecimal, two digits per byte.
fn hex(bytes: &[u8]) -> String {
    let mut out = String::with_capacity(bytes.len() * 2);
    for byte in bytes {
        out.push_str(&format!("{byte:02x}"));
    }
    out
}
/// An empty JSON object parses with default values instead of failing.
#[test]
fn parse_query_json_tolerates_missing_fields() {
    let result = parse_query_json("{}").unwrap();
    assert!(result.rows.is_empty());
    assert_eq!(result.affected, 0);
}
/// Every parameter variant maps to the gRPC kind of the same name and keeps
/// its payload.
#[test]
fn grpc_params_preserve_wire_value_variants() {
    use reddb_grpc_proto::query_value::Kind;
    let uuid = [0x11; 16];
    let params = vec![
        crate::params::Value::Null,
        crate::params::Value::Bool(true),
        crate::params::Value::Int(42),
        crate::params::Value::Float(1.5),
        crate::params::Value::Text("alice".into()),
        crate::params::Value::Bytes(vec![0, 1, 2]),
        crate::params::Value::Vector(vec![0.25, 0.5]),
        crate::params::Value::Json(crate::types::JsonValue::object([(
            "role",
            crate::types::JsonValue::string("admin"),
        )])),
        crate::params::Value::Timestamp(1_779_999_000),
        crate::params::Value::Uuid(uuid),
    ];
    let encoded = params_to_grpc_values(&params);
    assert_eq!(encoded.len(), 10);
    assert!(matches!(encoded[0].kind, Some(Kind::NullValue(_))));
    assert!(matches!(encoded[1].kind, Some(Kind::BoolValue(true))));
    assert!(matches!(encoded[2].kind, Some(Kind::IntValue(42))));
    assert!(matches!(encoded[3].kind, Some(Kind::FloatValue(1.5))));
    assert!(matches!(&encoded[4].kind, Some(Kind::TextValue(v)) if v == "alice"));
    assert!(matches!(&encoded[5].kind, Some(Kind::BytesValue(v)) if v == &[0, 1, 2]));
    assert!(
        matches!(&encoded[6].kind, Some(Kind::VectorValue(v)) if v.values == vec![0.25, 0.5])
    );
    assert!(
        matches!(&encoded[7].kind, Some(Kind::JsonValue(v)) if v == "{\"role\":\"admin\"}")
    );
    assert!(matches!(
        encoded[8].kind,
        Some(Kind::TimestampValue(1_779_999_000))
    ));
    assert!(matches!(&encoded[9].kind, Some(Kind::UuidValue(v)) if v == &uuid));
}
}