use crate::cluster::worker_registry::{LoadBalancingStrategy, WorkerRegistry};
use crate::{RpcClient, RpcConfig, RpcError};
use std::collections::HashMap;
use std::sync::Arc;
/// RPC client that dispatches calls to workers known to a [`WorkerRegistry`].
///
/// One [`RpcClient`] is kept per worker address so repeated calls to the same
/// worker can share it.
pub struct ClusterClient {
    /// Source of worker selection and of the full worker list.
    registry: Arc<WorkerRegistry>,
    /// Configuration handed to `RpcClient::connect` whenever a new client is
    /// opened.
    config: RpcConfig,
    /// Clients opened so far, keyed by worker socket address. Held behind an
    /// `Arc` so spawned tasks can share the same map.
    clients: Arc<tokio::sync::RwLock<HashMap<std::net::SocketAddr, Arc<RpcClient>>>>,
}
impl ClusterClient {
pub fn new(registry: Arc<WorkerRegistry>, config: RpcConfig) -> Self {
Self {
registry,
config,
clients: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
}
}
pub async fn call_worker(
&self,
method: &str,
params: Vec<u8>,
tag_filter: Option<&HashMap<String, String>>,
) -> Result<Vec<u8>, RpcError> {
let worker = self
.registry
.select_worker(tag_filter)
.await
.ok_or_else(|| RpcError::ConnectionError("No available workers".to_string()))?;
let client = self.get_or_create_client(worker.addr).await?;
worker.increment_connections();
let result = client.call(method, params).await;
worker.decrement_connections();
result
}
async fn get_or_create_client(
&self,
addr: std::net::SocketAddr,
) -> Result<Arc<RpcClient>, RpcError> {
{
let clients = self.clients.read().await;
if let Some(client) = clients.get(&addr) {
return Ok(client.clone());
}
}
let client = Arc::new(RpcClient::connect(addr, self.config.clone()).await?);
let mut clients = self.clients.write().await;
clients.insert(addr, client.clone());
Ok(client)
}
pub async fn call_all_workers(
&self,
method: &str,
params: Vec<u8>,
tag_filter: Option<&HashMap<String, String>>,
) -> Vec<Result<Vec<u8>, RpcError>> {
let workers = if let Some(filter) = tag_filter {
let mut filtered = Vec::new();
for worker in self.registry.all_workers().await {
if filter
.iter()
.all(|(k, v)| worker.tags.get(k).map(|val| val == v).unwrap_or(false))
{
filtered.push(worker);
}
}
filtered
} else {
self.registry.all_workers().await
};
let mut tasks = Vec::new();
for worker in workers {
let addr = worker.addr;
let method = method.to_string();
let params = params.clone();
let clients = self.clients.clone();
let config = self.config.clone();
let task = tokio::spawn(async move {
let client = {
let clients_read = clients.read().await;
if let Some(client) = clients_read.get(&addr) {
client.clone()
} else {
drop(clients_read);
let new_client = Arc::new(RpcClient::connect(addr, config).await?);
let mut clients_write = clients.write().await;
clients_write.insert(addr, new_client.clone());
new_client
}
};
worker.increment_connections();
let result = client.call(&method, params).await;
worker.decrement_connections();
result
});
tasks.push(task);
}
let mut results = Vec::new();
for task in tasks {
match task.await {
Ok(result) => results.push(result),
Err(e) => results.push(Err(RpcError::ConnectionError(e.to_string()))),
}
}
results
}
pub fn strategy(&self) -> LoadBalancingStrategy {
self.registry.strategy()
}
pub async fn worker_count(&self) -> usize {
self.registry.worker_count().await
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cluster::{ClusterConfig, ClusterMembership, LoadBalancingStrategy, WorkerRegistry};
    use s2n_quic::Client as QuicClient;
    use std::net::SocketAddr;
    use std::path::Path;

    /// Starts a QUIC client using the test certificate, bound to an
    /// OS-assigned port, and wraps it in an `Arc`.
    async fn create_test_client() -> Arc<QuicClient> {
        let tls_cert = Path::new("certs/test_cert.pem");
        let with_tls = QuicClient::builder().with_tls(tls_cert).unwrap();
        let with_io = with_tls.with_io("0.0.0.0:0").unwrap();
        Arc::new(with_io.start().unwrap())
    }

    /// A freshly built `ClusterClient` reports zero workers and the strategy
    /// its registry was created with.
    #[tokio::test]
    async fn test_cluster_client_creation() {
        let cluster_config = ClusterConfig::default();
        let bind_addr = "127.0.0.1:10000".parse::<SocketAddr>().unwrap();
        let quic = create_test_client().await;

        let membership = ClusterMembership::new(bind_addr, cluster_config, quic)
            .await
            .unwrap();
        let registry = Arc::new(WorkerRegistry::new(
            Arc::new(membership),
            LoadBalancingStrategy::RoundRobin,
        ));

        let client = ClusterClient::new(
            registry,
            RpcConfig::new("certs/test_cert.pem", "127.0.0.1:0"),
        );

        assert_eq!(client.worker_count().await, 0);
        assert_eq!(client.strategy(), LoadBalancingStrategy::RoundRobin);
    }
}