use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::RwLock;
use std::time::Instant;
use crate::connector::RedDBClient;
use crate::error::{ClientError, ErrorCode, Result};
use crate::router::{ClusterMembership, HealthAwareRouter, Outcome};
use crate::types::{InsertResult, JsonValue, QueryResult, ValueOut};
/// Pool size used by `GrpcClient::connect` and `GrpcClient::connect_cluster`,
/// which do not take an explicit pool size.
pub const DEFAULT_POOL_SIZE: usize = 4;
/// gRPC client bound to one primary endpoint and a replaceable set of replica endpoints.
///
/// `query` chooses its endpoint through `router`; `insert`, `bulk_insert` and `delete`
/// always go to the primary.
pub struct GrpcClient {
/// Endpoint used for writes and topology RPCs; never reassigned after construction.
primary: Endpoint,
/// Replica endpoints. A router index `i >= 1` is resolved as `replicas[i - 1]`;
/// `apply_topology` replaces the whole vector.
replicas: RwLock<Vec<Endpoint>>,
// NOTE(review): initialised to 0 and never read in this file, hence the
// `dead_code` allowance. Presumably a leftover from round-robin replica
// selection that the router now handles — confirm before removing.
#[allow(dead_code)]
next_replica: AtomicUsize,
/// Flag handed to the router at construction time; here it is only echoed in `Debug`.
force_primary: bool,
/// Pool size requested at connect time; reused when `apply_topology` connects new replicas.
pool_size: usize,
/// Chooses the endpoint index for reads and receives per-request outcomes.
router: RwLock<HealthAwareRouter>,
}
/// One server address together with a fixed pool of client handles for it.
struct Endpoint {
/// Address this endpoint was connected with; also used as its identity when
/// reconciling topology updates.
url: String,
/// Client handles handed out by `pick`; `connect` always creates at least one.
pool: Vec<RedDBClient>,
/// Monotonic counter used by `pick` to rotate through `pool`.
next: AtomicUsize,
}
impl Endpoint {
    /// Connects to `url` once and fills a pool of `pool_size` handles (minimum
    /// one) with clones of that single connected client.
    ///
    /// # Errors
    /// Returns an `ErrorCode::IoError` carrying the URL and the underlying
    /// error text when the connection attempt fails.
    async fn connect(url: String, pool_size: usize) -> Result<Self> {
        // A pool size of zero is treated as one so `pick` always has a slot.
        let slots = pool_size.max(1);
        let first = match RedDBClient::connect(&url, None).await {
            Ok(client) => client,
            Err(e) => {
                return Err(ClientError::new(
                    ErrorCode::IoError,
                    format!("connect {url}: {e}"),
                ));
            }
        };
        let mut pool: Vec<RedDBClient> = Vec::with_capacity(slots);
        // `slots - 1` clones first, then the original handle takes the last slot.
        pool.extend((1..slots).map(|_| first.clone()));
        pool.push(first);
        Ok(Self {
            url,
            pool,
            next: AtomicUsize::new(0),
        })
    }

    /// Returns a clone of the next pooled handle, rotating through the pool.
    fn pick(&self) -> RedDBClient {
        let ticket = self.next.fetch_add(1, Ordering::Relaxed);
        let slot = ticket % self.pool.len();
        self.pool[slot].clone()
    }
}
impl std::fmt::Debug for GrpcClient {
    /// Prints the primary URL, the replica URLs and the `force_primary` flag.
    ///
    /// Formatting must never panic: `Debug` output is typically produced while
    /// an error or panic is already being reported, so a poisoned replica lock
    /// is read through rather than unwrapped.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // A poisoned lock still holds a usable (if possibly stale) list; take it.
        let replicas_guard = self
            .replicas
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        let replicas: Vec<&str> = replicas_guard.iter().map(|e| e.url.as_str()).collect();
        f.debug_struct("GrpcClient")
            .field("primary", &self.primary.url)
            .field("replicas", &replicas)
            .field("force_primary", &self.force_primary)
            .finish()
    }
}
impl GrpcClient {
/// Connects to a single endpoint with `DEFAULT_POOL_SIZE` pooled handles.
///
/// Delegates to `connect_with_pool_size`; errors are whatever that returns.
pub async fn connect(endpoint: String) -> Result<Self> {
Self::connect_with_pool_size(endpoint, DEFAULT_POOL_SIZE).await
}
/// Connects to a single endpoint with an explicit pool size and no replicas.
///
/// The router is created with an empty replica list and `force_primary`
/// set to `true`.
///
/// # Errors
/// Propagates the error from `Endpoint::connect` when the endpoint cannot
/// be reached.
pub async fn connect_with_pool_size(endpoint: String, pool_size: usize) -> Result<Self> {
    let primary = Endpoint::connect(endpoint, pool_size).await?;
    let router = HealthAwareRouter::with_force_primary(
        ClusterMembership::new(primary.url.clone(), Vec::new()),
        true,
    );
    Ok(Self {
        primary,
        replicas: RwLock::new(Vec::new()),
        next_replica: AtomicUsize::new(0),
        force_primary: true,
        pool_size,
        router: RwLock::new(router),
    })
}
/// Connects to a primary and a list of replicas with `DEFAULT_POOL_SIZE`
/// pooled handles per endpoint.
///
/// Delegates to `connect_cluster_with_pool_size`; errors are whatever that returns.
pub async fn connect_cluster(
primary: String,
replicas: Vec<String>,
force_primary: bool,
) -> Result<Self> {
Self::connect_cluster_with_pool_size(primary, replicas, force_primary, DEFAULT_POOL_SIZE)
.await
}
pub async fn connect_cluster_with_pool_size(
primary: String,
replicas: Vec<String>,
force_primary: bool,
pool_size: usize,
) -> Result<Self> {
let primary_ep = Endpoint::connect(primary, pool_size).await?;
let mut replica_eps = Vec::with_capacity(replicas.len());
for url in replicas {
replica_eps.push(Endpoint::connect(url, pool_size).await?);
}
let membership = ClusterMembership::new(
primary_ep.url.clone(),
replica_eps.iter().map(|e| e.url.clone()).collect(),
);
let router = RwLock::new(HealthAwareRouter::with_force_primary(
membership,
force_primary,
));
Ok(Self {
primary: primary_ep,
replicas: RwLock::new(replica_eps),
next_replica: AtomicUsize::new(0),
force_primary,
pool_size,
router,
})
}
/// Returns the URL of the primary endpoint this client was connected with.
pub fn endpoint(&self) -> &str {
&self.primary.url
}
/// Returns a snapshot of the current replica URLs, in list order.
///
/// # Panics
/// Panics if the replica lock is poisoned.
pub fn replica_endpoints(&self) -> Vec<String> {
    let guard = self.replicas.read().unwrap();
    let mut urls = Vec::with_capacity(guard.len());
    for ep in guard.iter() {
        urls.push(ep.url.clone());
    }
    urls
}
/// Chooses the endpoint for a read and returns a client handle together with
/// the router index it was picked for.
///
/// Index `0` is the primary; index `i >= 1` is `replicas[i - 1]`. If the router
/// hands back an index with no matching replica, the primary is used and the
/// reported index is `0`.
fn read_endpoint(&self) -> (RedDBClient, usize) {
    let idx = self.router.read().unwrap().pick_read_index();
    if idx != 0 {
        let replicas = self.replicas.read().unwrap();
        if let Some(ep) = replicas.get(idx - 1) {
            return (ep.pick(), idx);
        }
    }
    (self.primary.pick(), 0)
}
/// Replaces the router's membership with `new_membership`.
///
/// Only the router is updated; the replica endpoint list held by this client
/// is not touched by this method.
///
/// # Panics
/// Panics if the router lock is poisoned.
pub fn update_membership(&self, new_membership: ClusterMembership) {
self.router
.write()
.unwrap()
.update_membership(new_membership);
}
/// Reconciles the replica endpoints and the router membership with an
/// advertised topology.
///
/// Replicas already connected are reused; addresses not seen before are
/// connected with this client's `pool_size`; replicas no longer advertised are
/// dropped. Duplicate addresses in `replica_addrs` are collapsed to their first
/// occurrence so that router index `i` always maps to `replicas[i - 1]`.
///
/// The update is all-or-nothing: every new connection is established before
/// any shared state is modified, so a failed connect leaves the previous
/// replica list and router membership untouched.
///
/// # Errors
/// * `ErrorCode::InvalidUri` when `primary_addr` differs from the connected
///   primary (primary failover is not handled here).
/// * The error from `Endpoint::connect` when a new replica cannot be reached.
///
/// # Panics
/// Panics if the replica or router lock is poisoned.
pub async fn apply_topology(&self, primary_addr: &str, replica_addrs: &[String]) -> Result<()> {
    if primary_addr != self.primary.url {
        return Err(ClientError::new(
            ErrorCode::InvalidUri,
            format!(
                "topology advertised primary {} differs from connected {}; primary failover is out of scope for #172",
                primary_addr, self.primary.url
            ),
        ));
    }

    // Collapse duplicates (first occurrence wins): an endpoint can only be
    // installed once, and the membership must list exactly what is installed.
    let mut wanted: Vec<&str> = Vec::with_capacity(replica_addrs.len());
    for url in replica_addrs {
        if !wanted.contains(&url.as_str()) {
            wanted.push(url.as_str());
        }
    }

    // Phase 1: connect to addresses we do not have yet. Nothing shared is
    // modified here, so bailing out with `?` cannot corrupt the client, and no
    // lock guard is alive across an `.await`.
    let known_urls: Vec<String> = self
        .replicas
        .read()
        .unwrap()
        .iter()
        .map(|e| e.url.clone())
        .collect();
    let mut fresh: Vec<Endpoint> = Vec::new();
    for &url in &wanted {
        if !known_urls.iter().any(|k| k.as_str() == url) {
            fresh.push(Endpoint::connect(url.to_string(), self.pool_size).await?);
        }
    }

    // Phase 2: swap the replica list and the router membership together.
    // Lock order is replicas -> router; no other method in this impl holds
    // both locks at once, so this cannot deadlock against them.
    {
        let mut guard = self.replicas.write().unwrap();
        let mut previous = std::mem::take(&mut *guard);
        let mut installed: Vec<Endpoint> = Vec::with_capacity(wanted.len());
        for &url in &wanted {
            if let Some(pos) = previous.iter().position(|e| e.url == url) {
                installed.push(previous.swap_remove(pos));
            } else if let Some(pos) = fresh.iter().position(|e| e.url == url) {
                installed.push(fresh.swap_remove(pos));
            }
            // Otherwise a concurrent update removed this replica after the
            // snapshot was taken; it is left out here and will be picked up
            // by the next topology refresh.
        }
        // Derive the membership from what was actually installed so router
        // indices and `replicas` can never disagree.
        let installed_urls: Vec<String> = installed.iter().map(|e| e.url.clone()).collect();
        *guard = installed;
        let membership = ClusterMembership::new(self.primary.url.clone(), installed_urls);
        self.router.write().unwrap().update_membership(membership);
    }
    Ok(())
}
/// Fetches the topology from the primary, decodes it and applies it through
/// `apply_topology`.
///
/// # Errors
/// * `ErrorCode::IoError` when the topology RPC fails.
/// * `ErrorCode::QueryError` when the returned bytes cannot be decoded.
/// * Whatever `apply_topology` returns for the decoded membership.
pub async fn refresh_topology(&self) -> Result<()> {
    let mut client = self.primary.pick();
    let bytes = match client.topology().await {
        Ok(payload) => payload,
        Err(e) => {
            return Err(ClientError::new(
                ErrorCode::IoError,
                format!("topology rpc: {e}"),
            ));
        }
    };
    let membership = match crate::topology::TopologyConsumer::consume_bytes(&bytes, None) {
        Ok(decoded) => decoded,
        Err(e) => {
            return Err(ClientError::new(
                ErrorCode::QueryError,
                format!("decode topology: {e}"),
            ));
        }
    };
    let replica_addrs: Vec<String> = membership
        .replicas
        .iter()
        .map(|replica| replica.addr.clone())
        .collect();
    self.apply_topology(&membership.primary.addr, &replica_addrs)
        .await
}
/// Reports the outcome of a request sent to router index `idx` to the router.
///
/// Only a read lock is taken on the router here.
///
/// # Panics
/// Panics if the router lock is poisoned.
pub(crate) fn observe(&self, idx: usize, outcome: Outcome) {
self.router.read().unwrap().observe_index(idx, outcome);
}
/// Runs `sql` on the endpoint chosen by the router and parses the JSON reply.
///
/// The router is told how the request went: the elapsed round-trip time on
/// success, and `Outcome::Timeout` for every failure regardless of its cause.
///
/// # Errors
/// `ErrorCode::QueryError` when the RPC fails or the reply is not valid JSON.
pub async fn query(&self, sql: &str) -> Result<QueryResult> {
    let (mut client, idx) = self.read_endpoint();
    let started = Instant::now();
    let outcome = client.query_reply(sql).await;
    match outcome {
        Ok(reply) => {
            self.observe(idx, Outcome::Rtt(started.elapsed()));
            parse_query_json(&reply.result_json)
        }
        Err(e) => {
            self.observe(idx, Outcome::Timeout);
            Err(ClientError::new(ErrorCode::QueryError, e.to_string()))
        }
    }
}
/// Inserts one row into `collection` on the primary.
///
/// `payload` must be a JSON object; it is serialised and sent as-is. On
/// success the result reports one affected row and the id from the reply.
///
/// # Errors
/// `ErrorCode::QueryError` when `payload` is not an object or the RPC fails.
pub async fn insert(&self, collection: &str, payload: &JsonValue) -> Result<InsertResult> {
    if payload.as_object().is_none() {
        return Err(ClientError::new(
            ErrorCode::QueryError,
            "insert payload must be a JSON object".to_string(),
        ));
    }
    let body = payload.to_json_string();
    let mut client = self.primary.pick();
    match client.create_row_entity(collection, &body).await {
        Ok(reply) => Ok(InsertResult {
            affected: 1,
            id: Some(reply.id.to_string()),
        }),
        Err(e) => Err(ClientError::new(ErrorCode::QueryError, e.to_string())),
    }
}
/// Inserts many rows into `collection` on the primary in one RPC and returns
/// the count from the reply.
///
/// Every payload must be a JSON object; nothing is sent if any of them is not.
///
/// # Errors
/// `ErrorCode::QueryError` when a payload is not an object or the RPC fails.
pub async fn bulk_insert(&self, collection: &str, payloads: &[JsonValue]) -> Result<u64> {
    // Validate everything up front so a bad payload aborts before any encoding.
    if payloads.iter().any(|p| p.as_object().is_none()) {
        return Err(ClientError::new(
            ErrorCode::QueryError,
            "bulk_insert payloads must be JSON objects".to_string(),
        ));
    }
    let encoded: Vec<_> = payloads.iter().map(|p| p.to_json_string()).collect();
    let mut client = self.primary.pick();
    match client.bulk_create_rows(collection, encoded).await {
        Ok(reply) => Ok(reply.count),
        Err(e) => Err(ClientError::new(ErrorCode::QueryError, e.to_string())),
    }
}
/// Deletes the entity with the given numeric `id` from `collection` on the
/// primary.
///
/// Returns `1` whenever the RPC succeeds; the reply itself is not inspected.
///
/// # Errors
/// * `ErrorCode::InvalidUri` when `id` does not parse as a `u64`.
/// * `ErrorCode::QueryError` when the RPC fails.
pub async fn delete(&self, collection: &str, id: &str) -> Result<u64> {
    let numeric_id = match id.parse::<u64>() {
        Ok(n) => n,
        Err(_) => {
            return Err(ClientError::new(
                ErrorCode::InvalidUri,
                "id must be a numeric string".to_string(),
            ));
        }
    };
    let mut client = self.primary.pick();
    match client.delete_entity(collection, numeric_id).await {
        Ok(_) => Ok(1),
        Err(e) => Err(ClientError::new(ErrorCode::QueryError, e.to_string())),
    }
}
/// No-op: performs no work and always returns `Ok(())`.
pub async fn close(&self) -> Result<()> {
Ok(())
}
/// Decodes a topology payload with `TopologyConsumer::consume_bytes` and
/// returns the result.
///
/// This only decodes: the client's replica list and router are not modified,
/// and no field of `self` is read.
pub fn ingest_topology_bytes(
&self,
bytes: &[u8],
uri_seed: Option<crate::topology::UriSeed>,
) -> std::result::Result<crate::topology::ClusterMembership, crate::topology::ConsumeError>
{
crate::topology::TopologyConsumer::consume_bytes(bytes, uri_seed)
}
}
/// Parses the JSON text of a query reply into a [`QueryResult`].
///
/// Missing or mistyped fields fall back to defaults instead of failing:
/// * `statement` — string, defaults to `"select"`.
/// * `affected` — number, defaults to `0`. Non-negative integers are read
///   exactly; other numbers go through `f64` and are truncated towards zero
///   (negative values become `0`).
/// * `columns` — array; non-string entries are skipped.
/// * `rows` (or `records` when `rows` is absent) — array of objects.
///
/// # Errors
/// `ErrorCode::QueryError` when `s` is not valid JSON.
fn parse_query_json(s: &str) -> Result<QueryResult> {
    let parsed: serde_json::Value = serde_json::from_str(s)
        .map_err(|e| ClientError::new(ErrorCode::QueryError, format!("bad server JSON: {e}")))?;
    let statement = parsed
        .get("statement")
        .and_then(|v| v.as_str())
        .unwrap_or("select")
        .to_string();
    // Prefer the exact integer: going through f64 first would round counts
    // above 2^53. The f64 path stays as a fallback for values like `3.0`.
    let affected = parsed
        .get("affected")
        .and_then(|v| v.as_u64().or_else(|| v.as_f64().map(|f| f as u64)))
        .unwrap_or(0);
    let columns = parsed
        .get("columns")
        .and_then(|v| v.as_array())
        .map(|cols| {
            cols.iter()
                .filter_map(|col| col.as_str().map(ToString::to_string))
                .collect::<Vec<_>>()
        })
        .unwrap_or_default();
    let rows = parsed
        .get("rows")
        .or_else(|| parsed.get("records"))
        .and_then(|v| v.as_array())
        .map(|rows| rows.iter().map(parse_row_value).collect())
        .unwrap_or_default();
    Ok(QueryResult {
        statement,
        affected,
        columns,
        rows,
    })
}
/// Converts one JSON row into `(column, value)` pairs.
///
/// A row that is not a JSON object yields an empty vector.
fn parse_row_value(value: &serde_json::Value) -> Vec<(String, ValueOut)> {
    match value.as_object() {
        Some(fields) => fields
            .iter()
            .map(|(name, cell)| (name.clone(), parse_scalar(cell)))
            .collect(),
        None => Vec::new(),
    }
}
/// Maps a JSON value onto a [`ValueOut`].
///
/// * `null` / booleans / strings map to their direct counterparts.
/// * Numbers that fit in `i64` become `Integer`.
/// * Unsigned integers above `i64::MAX` become `String` holding the exact
///   decimal digits, because neither `Integer` nor `Float` can represent them
///   without changing the value.
/// * Floats with a zero fractional part that lie inside the `i64` range become
///   `Integer`; every other float stays `Float`.
/// * Arrays and objects are rendered as their JSON text in a `String`.
fn parse_scalar(value: &serde_json::Value) -> ValueOut {
    // The i64 range expressed in f64. `i64::MAX` itself is not representable
    // as f64 (it rounds up to 2^63), so the upper bound is exclusive.
    const I64_LOWER_INCLUSIVE: f64 = -9_223_372_036_854_775_808.0;
    const I64_UPPER_EXCLUSIVE: f64 = 9_223_372_036_854_775_808.0;

    match value {
        serde_json::Value::Null => ValueOut::Null,
        serde_json::Value::Bool(b) => ValueOut::Bool(*b),
        serde_json::Value::Number(n) => {
            if let Some(i) = n.as_i64() {
                ValueOut::Integer(i)
            } else if n.is_u64() {
                // Above i64::MAX: an `as i64` cast would clamp it to i64::MAX.
                ValueOut::String(n.to_string())
            } else if let Some(f) = n.as_f64() {
                let fits_i64 = f.fract() == 0.0
                    && f >= I64_LOWER_INCLUSIVE
                    && f < I64_UPPER_EXCLUSIVE;
                if fits_i64 {
                    // In range and integral, so the cast is exact.
                    ValueOut::Integer(f as i64)
                } else {
                    // Fractional, or too large for i64 (e.g. 1e300): keep the
                    // float instead of saturating to i64::MAX / i64::MIN.
                    ValueOut::Float(f)
                }
            } else {
                ValueOut::String(n.to_string())
            }
        }
        serde_json::Value::String(s) => ValueOut::String(s.clone()),
        other => ValueOut::String(other.to_string()),
    }
}
// Unit tests for the reply-parsing helpers; none of them touch the network.
#[cfg(test)]
mod tests {
use super::*;
// A full reply: statement, affected count, column names and two object rows
// with an integer and a string cell each.
#[test]
fn parse_query_json_extracts_rows_and_columns() {
let input = r#"{"statement":"select","affected":0,"columns":["id","name"],"rows":[{"id":1,"name":"Alice"},{"id":2,"name":"Bob"}]}"#;
let qr = parse_query_json(input).unwrap();
assert_eq!(qr.statement, "select");
assert_eq!(qr.affected, 0);
assert_eq!(qr.columns, vec!["id".to_string(), "name".to_string()]);
assert_eq!(qr.rows.len(), 2);
assert_eq!(qr.rows[0][0].0, "id");
assert!(matches!(qr.rows[0][0].1, ValueOut::Integer(1)));
assert_eq!(qr.rows[1][1].0, "name");
assert!(matches!(&qr.rows[1][1].1, ValueOut::String(s) if s == "Bob"));
}
// Empty `columns` and `rows` arrays parse to empty collections.
#[test]
fn parse_query_json_handles_empty_rows() {
let input = r#"{"statement":"select","affected":0,"columns":[],"rows":[]}"#;
let qr = parse_query_json(input).unwrap();
assert!(qr.rows.is_empty());
assert!(qr.columns.is_empty());
}
// An empty object is accepted: absent fields fall back to their defaults.
#[test]
fn parse_query_json_tolerates_missing_fields() {
let qr = parse_query_json("{}").unwrap();
assert_eq!(qr.affected, 0);
assert!(qr.rows.is_empty());
}
}