use std::sync::atomic::{AtomicU8, Ordering};
use std::sync::Arc;
use std::sync::Mutex as StdMutex;
use std::sync::RwLock;
use std::time::Duration;
use tokio_util::sync::CancellationToken;
use tonic::service::interceptor::InterceptedService;
use tonic::transport::Channel;
use crate::config::ClientConfig;
use crate::connection::ConnectionState;
use crate::error::KubemqError;
use crate::server_info::ServerInfo;
use crate::transport::interceptor::AuthInterceptor;
/// gRPC client type used by this transport: `KubemqClient` over a tonic `Channel`
/// wrapped in an `InterceptedService` that applies `AuthInterceptor`.
type ProtoClient =
crate::proto::kubemq::kubemq_client::KubemqClient<InterceptedService<Channel, AuthInterceptor>>;
fn state_to_u8(s: ConnectionState) -> u8 {
match s {
ConnectionState::Idle => 0,
ConnectionState::Ready => 2,
ConnectionState::Closed => 4,
}
}
/// Decodes a `u8` tag produced by `state_to_u8` back into a `ConnectionState`.
///
/// `0` is idle and `2` is ready. Every other value — the canonical closed tag `4`
/// as well as any unrecognised tag — decodes to `ConnectionState::Closed`.
fn u8_to_state(v: u8) -> ConnectionState {
    match v {
        0 => ConnectionState::Idle,
        2 => ConnectionState::Ready,
        // Covers the closed tag (4) and treats unknown tags as closed too.
        _ => ConnectionState::Closed,
    }
}
/// gRPC transport: owns the client handle, the configuration it was built from,
/// the connection state tag, the shared auth token and the shutdown signal.
pub(crate) struct GrpcTransport {
/// Client handle. `close()` replaces it with `None`, after which `client()`
/// returns `KubemqError::ClientClosed`.
client: StdMutex<Option<ProtoClient>>,
/// Configuration passed to `connect()`; exposed through `config()`.
config: ClientConfig,
/// Current `ConnectionState`, stored as the `u8` tag produced by `state_to_u8`.
state: AtomicU8,
/// Auth token slot. Clones of this `Arc` are handed to the `AuthInterceptor` and,
/// when a credential provider reports an expiry, to the background refresh task.
// NOTE(review): nothing in this file reads the field after construction, which is
// why the lint is silenced — confirm whether it is kept for future use.
#[allow(dead_code)]
token: Arc<RwLock<Option<String>>>,
/// Cancelled by `close()`; the token refresh task exits when it fires. Clones are
/// handed out by `cancel_token()`.
cancel: CancellationToken,
}
impl GrpcTransport {
pub(crate) async fn connect(config: ClientConfig) -> crate::Result<Self> {
let scheme = if config.tls_config.is_some() {
"https"
} else {
"http"
};
let uri = format!("{}://{}:{}", scheme, config.host, config.port);
let mut endpoint =
Channel::from_shared(uri.clone()).map_err(|e| KubemqError::Transport {
operation: "connect".to_string(),
source: Box::new(e),
})?;
endpoint = endpoint
.tcp_nodelay(true)
.keep_alive_timeout(config.keepalive_timeout)
.http2_keep_alive_interval(config.keepalive_time)
.keep_alive_while_idle(config.permit_keepalive_without_stream)
.connect_timeout(config.connection_timeout);
if let Some(ref tls_config) = config.tls_config {
let tonic_tls = tls_config.to_tonic_tls_config().await?;
endpoint = endpoint
.tls_config(tonic_tls)
.map_err(|e| KubemqError::Transport {
operation: "tls_config".to_string(),
source: Box::new(e),
})?;
}
let token = Arc::new(RwLock::new(config.auth_token.clone()));
let cancel = CancellationToken::new();
if let Some(ref provider) = config.credential_provider {
let (token_str, expires_at) = provider.get_token().await?;
*token.write().unwrap() = Some(token_str);
if let Some(expiry) = expires_at {
let provider_clone = Arc::clone(provider);
let token_clone = token.clone();
let cancel_clone = cancel.clone();
const REFRESH_MARGIN: Duration = Duration::from_secs(30);
tokio::spawn(async move {
let mut next_expiry = expiry;
loop {
let now = std::time::SystemTime::now();
let sleep_duration = next_expiry
.duration_since(now)
.unwrap_or(Duration::from_secs(1))
.checked_sub(REFRESH_MARGIN)
.unwrap_or(Duration::from_secs(1));
tokio::select! {
_ = cancel_clone.cancelled() => break,
_ = tokio::time::sleep(sleep_duration) => {
let mut retry_delay = Duration::from_secs(1);
const MAX_RETRY_DELAY: Duration = Duration::from_secs(30);
const MAX_RETRIES: u32 = 5;
for attempt in 0..MAX_RETRIES {
match provider_clone.get_token().await {
Ok((new_token, new_expiry)) => {
*token_clone.write().unwrap() = Some(new_token);
tracing::debug!("Token refreshed successfully");
if let Some(exp) = new_expiry {
next_expiry = exp;
}
break;
}
Err(e) => {
tracing::warn!(
"Token refresh failed (attempt {}/{}): {}",
attempt + 1, MAX_RETRIES, e
);
if attempt + 1 < MAX_RETRIES {
tokio::select! {
_ = cancel_clone.cancelled() => break,
_ = tokio::time::sleep(retry_delay) => {}
}
retry_delay = (retry_delay * 2).min(MAX_RETRY_DELAY);
} else {
tracing::error!(
"Token refresh exhausted all {} retries", MAX_RETRIES
);
}
}
}
}
}
}
}
});
}
}
let channel = if config.check_connection {
endpoint
.connect()
.await
.map_err(|e| KubemqError::Transport {
operation: "connect".to_string(),
source: Box::new(e),
})?
} else {
endpoint.connect_lazy()
};
let interceptor = AuthInterceptor::new(token.clone());
let client = crate::proto::kubemq::kubemq_client::KubemqClient::with_interceptor(
channel,
interceptor,
)
.max_decoding_message_size(config.max_receive_message_size)
.max_encoding_message_size(config.max_send_message_size);
let initial_state = if config.check_connection {
ConnectionState::Ready
} else {
ConnectionState::Idle
};
if let Some(cb) = &config.on_connected {
cb().await;
}
Ok(Self {
client: StdMutex::new(Some(client)),
config,
state: AtomicU8::new(state_to_u8(initial_state)),
token,
cancel,
})
}
pub(crate) fn state(&self) -> ConnectionState {
u8_to_state(self.state.load(Ordering::Acquire))
}
pub(crate) fn client(&self) -> crate::Result<ProtoClient> {
let guard = self.client.lock().unwrap();
guard.clone().ok_or(KubemqError::ClientClosed)
}
pub(crate) fn config(&self) -> &ClientConfig {
&self.config
}
pub(crate) async fn ping(&self) -> crate::Result<ServerInfo> {
let mut client = self.client()?;
let mut request = tonic::Request::new(crate::proto::kubemq::Empty {});
request.set_timeout(self.config.rpc_timeout);
let response = client
.ping(request)
.await
.map_err(|s| KubemqError::from_grpc_status(s, "ping", ""))?;
let result = response.into_inner();
Ok(ServerInfo {
host: result.host,
version: result.version,
server_start_time: result.server_start_time,
server_up_time_seconds: result.server_up_time_seconds,
})
}
pub(crate) async fn close(&self) {
self.state
.store(state_to_u8(ConnectionState::Closed), Ordering::Release);
self.cancel.cancel();
{
let mut guard = self.client.lock().unwrap();
*guard = None;
}
}
pub(crate) fn cancel_token(&self) -> CancellationToken {
self.cancel.clone()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::connection::ConnectionState;

    // --- state_to_u8: one test per variant ---

    #[test]
    fn test_state_to_u8_idle() {
        let tag = state_to_u8(ConnectionState::Idle);
        assert_eq!(tag, 0);
    }

    #[test]
    fn test_state_to_u8_ready() {
        let tag = state_to_u8(ConnectionState::Ready);
        assert_eq!(tag, 2);
    }

    #[test]
    fn test_state_to_u8_closed() {
        let tag = state_to_u8(ConnectionState::Closed);
        assert_eq!(tag, 4);
    }

    // --- u8_to_state ---

    /// Encoding then decoding must give back the state we started with.
    #[test]
    fn test_u8_to_state_roundtrip() {
        let every_state = [
            ConnectionState::Idle,
            ConnectionState::Ready,
            ConnectionState::Closed,
        ];
        for state in every_state.iter().copied() {
            let decoded = u8_to_state(state_to_u8(state));
            assert_eq!(decoded, state);
        }
    }

    /// Tags that no state encodes to are decoded as `Closed`.
    #[test]
    fn test_u8_to_state_unknown_defaults_to_closed() {
        for raw in [5u8, 100, 255].iter().copied() {
            assert_eq!(u8_to_state(raw), ConnectionState::Closed);
        }
    }

    /// Each tag in use decodes to its own state.
    #[test]
    fn test_u8_to_state_all_valid_values() {
        let expected = [
            (0u8, ConnectionState::Idle),
            (2u8, ConnectionState::Ready),
            (4u8, ConnectionState::Closed),
        ];
        for (raw, state) in expected.iter().copied() {
            assert_eq!(u8_to_state(raw), state);
        }
    }
}