use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use tokio::sync::watch;
use tokio::sync::{Mutex, RwLock};
use tonic::transport::Channel;
use tracing::{debug, info, warn};
use crate::config::GooseFsConfig;
use crate::error::{Error, Result};
use crate::proto::grpc::version::{
service_version_client_service_client::ServiceVersionClientServiceClient,
GetServiceVersionPRequest, ServiceType,
};
use crate::retry::{ExponentialTimeBoundedRetry, RetryPolicy};
/// Strategy for locating the primary master among the configured master addresses.
#[async_trait]
pub trait MasterInquireClient: Send + Sync {
/// Returns the RPC address of the master currently considered primary.
async fn get_primary_rpc_address(&self) -> Result<String>;
/// Returns every master RPC address this client knows about.
fn get_master_rpc_addresses(&self) -> Vec<String>;
/// Forgets any remembered primary address (a no-op for implementations without a cache).
async fn reset_cached_primary(&self);
}
/// Inquire client backed by a single fixed master address.
pub struct SingleMasterInquireClient {
/// The one master address; it is always reported as the primary.
address: String,
}
impl SingleMasterInquireClient {
pub fn new(address: String) -> Self {
Self { address }
}
}
#[async_trait]
impl MasterInquireClient for SingleMasterInquireClient {
    /// Always succeeds with a copy of the configured address; no network access is made.
    async fn get_primary_rpc_address(&self) -> Result<String> {
        let primary = self.address.to_owned();
        Ok(primary)
    }

    /// Returns a one-element list holding the configured address.
    fn get_master_rpc_addresses(&self) -> Vec<String> {
        std::iter::once(self.address.clone()).collect()
    }

    /// Nothing is cached for a single master, so there is nothing to reset.
    async fn reset_cached_primary(&self) {}
}
/// Cloneable outcome of one primary poll, sent over the singleflight watch channel.
#[derive(Debug, Clone)]
enum PollResult {
/// Address of the primary that the poll discovered.
Ok(String),
/// Message describing why the poll found no primary.
Err(String),
}
/// Singleflight slot: `Some(receiver)` makes callers wait on that receiver for a poll result,
/// while `None` lets the next caller start a poll itself.
type PollGate = Mutex<Option<watch::Receiver<Option<PollResult>>>>;
/// Inquire client that finds the primary among several masters by pinging each address.
pub struct PollingMasterInquireClient {
/// Candidate master addresses, pinged in this order during a poll round.
addresses: Vec<String>,
/// Address stored by the last successful poll; cleared by `reset_primary`.
cached_primary: Arc<RwLock<Option<String>>>,
/// First argument passed to `ExponentialTimeBoundedRetry::new`
/// (presumably the overall retry time budget — confirm against the retry module).
max_duration: Duration,
/// Second argument passed to `ExponentialTimeBoundedRetry::new`
/// (presumably the first back-off sleep — confirm against the retry module).
initial_sleep: Duration,
/// Third argument passed to `ExponentialTimeBoundedRetry::new`
/// (presumably the back-off sleep cap — confirm against the retry module).
max_sleep: Duration,
/// Used as both the connect timeout and the request timeout of every ping channel.
polling_timeout: Duration,
/// Singleflight gate shared by concurrent `get_primary_rpc_address` callers.
poll_gate: Arc<PollGate>,
}
impl PollingMasterInquireClient {
pub fn new(
addresses: Vec<String>,
max_duration: Duration,
initial_sleep: Duration,
max_sleep: Duration,
polling_timeout: Duration,
) -> Self {
Self {
addresses,
cached_primary: Arc::new(RwLock::new(None)),
max_duration,
initial_sleep,
max_sleep,
polling_timeout,
poll_gate: Arc::new(Mutex::new(None)),
}
}
async fn ping_meta_service(&self, addr: &str) -> std::result::Result<(), PingError> {
let endpoint_uri = format!("http://{}", addr);
let endpoint = Channel::from_shared(endpoint_uri)
.map_err(|e| PingError::Fatal(format!("invalid endpoint for {}: {}", addr, e)))?
.connect_timeout(self.polling_timeout)
.timeout(self.polling_timeout);
let channel = endpoint
.connect()
.await
.map_err(|e| PingError::Unavailable(format!("{}: connection failed: {}", addr, e)))?;
let mut client = ServiceVersionClientServiceClient::new(channel);
let req = GetServiceVersionPRequest {
service_type: Some(ServiceType::MetaMasterClientService as i32),
allowed_on_standby_masters: Some(false),
};
match client.get_service_version(req).await {
Ok(resp) => {
let version = resp.into_inner().version.unwrap_or(0);
debug!(addr = %addr, version = version, "primary master detected");
Ok(())
}
Err(status) => match status.code() {
tonic::Code::NotFound => {
debug!(addr = %addr, "standby master (NotFound)");
Err(PingError::Standby)
}
tonic::Code::Unavailable
| tonic::Code::DeadlineExceeded
| tonic::Code::Cancelled => {
debug!(addr = %addr, code = ?status.code(), "master unavailable or timed out");
Err(PingError::Unavailable(format!(
"{}: [{}] {}",
addr,
status.code(),
status.message()
)))
}
_ => {
warn!(addr = %addr, code = ?status.code(), msg = %status.message(), "unexpected error pinging master");
Err(PingError::Fatal(format!(
"{}: [{}] {}",
addr,
status.code(),
status.message()
)))
}
},
}
}
pub async fn reset_primary(&self) {
let mut cache = self.cached_primary.write().await;
*cache = None;
}
async fn poll_for_primary(&self) -> std::result::Result<String, String> {
let mut retry =
ExponentialTimeBoundedRetry::new(self.max_duration, self.initial_sleep, self.max_sleep);
let mut last_errors: Vec<String> = Vec::new();
while retry.should_retry() {
last_errors.clear();
for addr in &self.addresses {
match self.ping_meta_service(addr).await {
Ok(()) => {
info!(addr = %addr, attempts = retry.attempt_count(), "discovered primary master");
let mut cache = self.cached_primary.write().await;
*cache = Some(addr.clone());
return Ok(addr.clone());
}
Err(PingError::Standby) => {
last_errors.push(format!("{}: standby", addr));
continue;
}
Err(PingError::Unavailable(msg)) => {
last_errors.push(msg);
continue;
}
Err(PingError::Fatal(msg)) => {
last_errors.push(msg);
break;
}
}
}
let sleep_dur = retry.next_sleep();
debug!(
attempt = retry.attempt_count(),
sleep_ms = sleep_dur.as_millis(),
"no primary found this round, sleeping"
);
tokio::time::sleep(sleep_dur).await;
}
Err(format!(
"failed to find primary master after {} attempts across {} addresses. Last round errors: [{}]",
retry.attempt_count(),
self.addresses.len(),
last_errors.join("; "),
))
}
}
#[async_trait]
impl MasterInquireClient for PollingMasterInquireClient {
async fn get_primary_rpc_address(&self) -> Result<String> {
{
let cache = self.cached_primary.read().await;
if let Some(ref addr) = *cache {
if self.ping_meta_service(addr).await.is_ok() {
debug!(addr = %addr, "cached primary still valid");
return Ok(addr.clone());
}
debug!(addr = %addr, "cached primary stale, re-polling");
}
}
let rx_opt: Option<watch::Receiver<Option<PollResult>>> = {
let mut gate = self.poll_gate.lock().await;
match &*gate {
Some(existing_rx) => {
debug!("singleflight follower: waiting for in-flight poll");
Some(existing_rx.clone())
}
None => {
let (tx, rx) = watch::channel::<Option<PollResult>>(None);
*gate = Some(rx);
drop(gate);
debug!("singleflight leader: starting primary poll");
let result = self.poll_for_primary().await;
let broadcast = match &result {
Ok(addr) => PollResult::Ok(addr.clone()),
Err(msg) => PollResult::Err(msg.clone()),
};
let _ = tx.send(Some(broadcast));
let mut gate2 = self.poll_gate.lock().await;
*gate2 = None;
return result.map_err(|msg| Error::Internal {
message: msg,
source: None,
});
}
}
};
if let Some(mut rx) = rx_opt {
loop {
if rx.changed().await.is_err() {
warn!("singleflight leader dropped channel, follower retrying");
return self.get_primary_rpc_address().await;
}
let value = rx.borrow().clone();
match value {
Some(PollResult::Ok(addr)) => {
debug!(addr = %addr, "singleflight follower received primary");
return Ok(addr);
}
Some(PollResult::Err(msg)) => {
return Err(Error::Internal {
message: msg,
source: None,
});
}
None => {
continue;
}
}
}
}
Err(Error::Internal {
message: "singleflight logic error: neither leader nor follower path returned"
.to_string(),
source: None,
})
}
fn get_master_rpc_addresses(&self) -> Vec<String> {
self.addresses.clone()
}
async fn reset_cached_primary(&self) {
self.reset_primary().await;
}
}
/// Classification of a failed ping against a single master.
enum PingError {
/// The master answered `NotFound`, which is treated as "this is a standby".
Standby,
/// The connection failed, or the status was `Unavailable`, `DeadlineExceeded`
/// or `Cancelled`; the payload is the formatted error message.
Unavailable(String),
/// The endpoint URI was invalid or an unexpected gRPC status came back;
/// the payload is the formatted error message.
Fatal(String),
}
/// Picks the inquire client implementation that matches the configuration.
///
/// With two or more master addresses a `PollingMasterInquireClient` is built from
/// the inquire/polling settings in `config`. With one address, or with none (in
/// which case `config.master_addr` is used), a `SingleMasterInquireClient` is returned.
pub fn create_master_inquire_client(config: &GooseFsConfig) -> Arc<dyn MasterInquireClient> {
    let addrs = config.master_addresses();

    if addrs.len() > 1 {
        debug!(addresses = ?addrs, "using PollingMasterInquireClient");
        return Arc::new(PollingMasterInquireClient::new(
            addrs,
            config.master_inquire_retry_max_duration,
            config.master_inquire_initial_sleep,
            config.master_inquire_max_sleep,
            config.master_polling_timeout,
        ));
    }

    let addr = match addrs.into_iter().next() {
        Some(only) => only,
        None => config.master_addr.clone(),
    };
    debug!(addr = %addr, "using SingleMasterInquireClient");
    Arc::new(SingleMasterInquireClient::new(addr))
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Builds a polling client with the short timings shared by these tests.
    fn polling_client(addresses: Vec<String>) -> PollingMasterInquireClient {
        PollingMasterInquireClient::new(
            addresses,
            Duration::from_millis(100),
            Duration::from_millis(10),
            Duration::from_millis(50),
            Duration::from_millis(50),
        )
    }

    /// The single-master client reports its one address as primary and as the full list.
    #[tokio::test]
    async fn test_single_master_returns_address() {
        let client = SingleMasterInquireClient::new("master:19998".to_string());

        let primary = client.get_primary_rpc_address().await.unwrap();
        assert_eq!(primary, "master:19998");

        let all = client.get_master_rpc_addresses();
        assert_eq!(all, vec!["master:19998".to_string()]);
    }

    /// Resetting the single-master client does not change its answer.
    #[tokio::test]
    async fn test_single_master_reset_is_noop() {
        let client = SingleMasterInquireClient::new("master:19998".to_string());
        client.reset_cached_primary().await;

        let primary = client.get_primary_rpc_address().await.unwrap();
        assert_eq!(primary, "master:19998");
    }

    /// A freshly built polling client has no in-flight poll registered.
    #[tokio::test]
    async fn test_polling_client_gate_starts_empty() {
        let client = polling_client(vec!["a:1".to_string(), "b:2".to_string()]);

        let gate = client.poll_gate.lock().await;
        assert!(gate.is_none(), "gate should start empty");
    }

    /// The polling client hands back exactly the addresses it was built with.
    #[tokio::test]
    async fn test_polling_client_addresses() {
        let addrs = vec!["host1:19998".to_string(), "host2:19998".to_string()];
        let client = polling_client(addrs.clone());

        assert_eq!(client.get_master_rpc_addresses(), addrs);
    }

    /// `reset_cached_primary` wipes a previously cached primary.
    #[tokio::test]
    async fn test_polling_client_reset_clears_cache() {
        let client = polling_client(vec!["host:19998".to_string()]);

        // Seed the cache directly; the write guard is released at the end of the statement.
        *client.cached_primary.write().await = Some("host:19998".to_string());

        client.reset_cached_primary().await;

        let cache = client.cached_primary.read().await;
        assert!(cache.is_none(), "cache should be cleared after reset");
    }

    /// A cloned watch receiver observes the value published by the sender.
    #[tokio::test]
    async fn test_singleflight_gate_broadcast() {
        let (tx, rx) = watch::channel::<Option<PollResult>>(None);
        let deliveries = Arc::new(AtomicUsize::new(0));

        let follower = {
            let mut follower_rx = rx.clone();
            let deliveries = Arc::clone(&deliveries);
            tokio::spawn(async move {
                follower_rx.changed().await.unwrap();
                let observed = follower_rx.borrow().clone();
                match observed {
                    Some(PollResult::Ok(addr)) => {
                        deliveries.fetch_add(1, Ordering::SeqCst);
                        addr
                    }
                    _ => panic!("expected Ok result"),
                }
            })
        };

        tokio::time::sleep(Duration::from_millis(5)).await;
        tx.send(Some(PollResult::Ok("primary:19998".to_string())))
            .unwrap();

        let addr = follower.await.unwrap();
        assert_eq!(addr, "primary:19998");
        assert_eq!(deliveries.load(Ordering::SeqCst), 1);
    }
}