use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{mpsc, Mutex as AsyncMutex, RwLock};
use tokio_stream::wrappers::ReceiverStream;
use tokio_stream::StreamExt;
use tonic::service::interceptor::InterceptedService;
use tonic::transport::Channel;
use tonic::Streaming;
use tracing::{debug, instrument, warn};
use crate::auth::{ChannelAuthenticator, ChannelIdInterceptor, SaslStreamGuard};
use crate::config::GooseFsConfig;
use crate::error::{Error, Result};
use crate::proto::grpc::block::{
block_worker_client::BlockWorkerClient, write_request, ReadRequest, ReadResponse, RequestType,
WriteRequest, WriteRequestCommand, WriteResponse,
};
use crate::proto::proto::dataserver::{CreateUfsFileOptions, OpenUfsBlockOptions};
/// Caller-supplied options for opening a `WriteBlock` stream.
///
/// Both fields are copied into the initial `WriteRequestCommand` built by
/// `WorkerClient::write_block`.
#[derive(Clone, Debug)]
pub struct WriteBlockOptions {
/// Request type; sent to the server as its `i32` value in the initial command.
pub request_type: RequestType,
/// Optional UFS file-creation options, forwarded unchanged in the initial command.
pub create_ufs_file_options: Option<CreateUfsFileOptions>,
}
impl Default for WriteBlockOptions {
fn default() -> Self {
Self {
request_type: RequestType::GoosefsBlock,
create_ufs_file_options: None,
}
}
}
/// Client-side handle for one `WriteBlock` stream.
///
/// Created by `WorkerClient::write_block`, which also spawns the background
/// task that forwards server responses into `response_rx`.
pub struct WriteBlockHandle {
// Block id of this stream; used in log fields and error messages.
block_id: i64,
/// Sender that feeds follow-up `WriteRequest`s into the request stream.
pub request_tx: mpsc::Sender<WriteRequest>,
// Responses (or a terminating gRPC status) forwarded by the background task.
response_rx: mpsc::Receiver<std::result::Result<WriteResponse, tonic::Status>>,
// Join handle of the spawned forwarding task.
_task_handle: tokio::task::JoinHandle<()>,
}
impl WriteBlockHandle {
pub async fn recv_response(&mut self) -> Result<Option<WriteResponse>> {
match self.response_rx.recv().await {
Some(Ok(resp)) => Ok(Some(resp)),
Some(Err(status)) => Err(Error::GrpcError {
message: format!(
"WriteBlock server error for block_id={}: {}",
self.block_id, status
),
source: status,
}),
None => Ok(None),
}
}
pub async fn close(mut self) -> Result<()> {
drop(self.request_tx);
debug!(
block_id = self.block_id,
"closed write stream, waiting for server finalize"
);
while let Some(result) = self.response_rx.recv().await {
match result {
Ok(_resp) => {
debug!(
block_id = self.block_id,
"received final response from server"
);
}
Err(status) => {
return Err(Error::GrpcError {
message: format!(
"WriteBlock server error for block_id={}: {}",
self.block_id, status
),
source: status,
});
}
}
}
Ok(())
}
pub async fn cancel(self) {
drop(self.request_tx);
drop(self.response_rx);
debug!(block_id = self.block_id, "cancelled write stream");
}
}
// Generated block-worker client running over a channel wrapped with
// `ChannelIdInterceptor`; used for both authenticated and no-auth clients.
type AuthenticatedBlockWorkerClient =
BlockWorkerClient<InterceptedService<Channel, ChannelIdInterceptor>>;
/// gRPC client bound to a single worker address.
///
/// Clones share the SASL guard through the `Arc`.
#[derive(Clone)]
pub struct WorkerClient {
// Generated gRPC client; cloned for each RPC issued by the methods below.
inner: AuthenticatedBlockWorkerClient,
// Address this client was created for (exposed through `addr()`).
addr: String,
// Generation tag: constructors set 0; `WorkerClientPool` overwrites it with
// an increasing counter value when it installs a client.
generation: u64,
// Guard taken from the authenticated channel (`None` for no-auth clients).
// Never read; presumably held only to keep the SASL stream alive — TODO confirm.
_sasl_guard: std::sync::Arc<Option<SaslStreamGuard>>,
}
impl WorkerClient {
pub async fn connect(addr: &str, config: &GooseFsConfig) -> Result<Self> {
let endpoint = Channel::from_shared(format!("http://{}", addr))
.map_err(|e| Error::ConfigError {
message: format!("invalid worker endpoint: {}", e),
})?
.connect_timeout(config.connect_timeout);
let channel = endpoint.connect().await?;
let authenticator =
ChannelAuthenticator::new(config.auth_type, config.auth_username.clone(), None)
.with_auth_timeout(config.auth_timeout);
let mut auth_channel = authenticator.authenticate(channel).await?;
let sasl_guard = auth_channel.take_sasl_guard();
debug!(addr = %addr, auth_type = %config.auth_type, "connected to GooseFS Worker");
Ok(Self {
inner: BlockWorkerClient::new(auth_channel.channel),
addr: addr.to_string(),
generation: 0,
_sasl_guard: std::sync::Arc::new(sasl_guard),
})
}
pub async fn connect_simple(addr: &str, connect_timeout: Duration) -> Result<Self> {
let endpoint = Channel::from_shared(format!("http://{}", addr))
.map_err(|e| Error::ConfigError {
message: format!("invalid worker endpoint: {}", e),
})?
.connect_timeout(connect_timeout);
let channel = endpoint.connect().await?;
let interceptor = ChannelIdInterceptor::new(uuid::Uuid::new_v4().to_string());
let intercepted = InterceptedService::new(channel, interceptor);
debug!(addr = %addr, "connected to GooseFS Worker (no auth)");
Ok(Self {
inner: BlockWorkerClient::new(intercepted),
addr: addr.to_string(),
generation: 0,
_sasl_guard: std::sync::Arc::new(None),
})
}
pub fn from_channel(channel: Channel, addr: String) -> Self {
let interceptor = ChannelIdInterceptor::new("test-no-auth".to_string());
let intercepted = InterceptedService::new(channel, interceptor);
Self {
inner: BlockWorkerClient::new(intercepted),
addr,
generation: 0,
_sasl_guard: std::sync::Arc::new(None),
}
}
#[instrument(skip(self, open_ufs_block_options), fields(block_id = %block_id, offset = %offset, length = %length))]
pub async fn read_block(
&self,
block_id: i64,
offset: i64,
length: i64,
chunk_size: i64,
open_ufs_block_options: Option<OpenUfsBlockOptions>,
) -> Result<(mpsc::Sender<ReadRequest>, Streaming<ReadResponse>)> {
let (tx, rx) = mpsc::channel::<ReadRequest>(32);
let initial_request = ReadRequest {
block_id: Some(block_id),
offset: Some(offset),
length: Some(length),
chunk_size: Some(chunk_size),
open_ufs_block_options,
offset_received: None,
position_short: None,
request_id: None,
capability: None,
block_size: None,
prefetch_window: None,
};
tx.send(initial_request)
.await
.map_err(|_| Error::BlockIoError {
message: "failed to send initial ReadRequest".to_string(),
})?;
let stream = ReceiverStream::new(rx);
let response = self.inner.clone().read_block(stream).await?;
Ok((tx, response.into_inner()))
}
pub async fn read_block_positioned(
&self,
block_id: i64,
offset: i64,
length: i64,
chunk_size: i64,
open_ufs_block_options: Option<OpenUfsBlockOptions>,
) -> Result<(mpsc::Sender<ReadRequest>, Streaming<ReadResponse>)> {
let (tx, rx) = mpsc::channel::<ReadRequest>(32);
let initial_request = ReadRequest {
block_id: Some(block_id),
offset: Some(offset),
length: Some(length),
chunk_size: Some(chunk_size),
open_ufs_block_options,
offset_received: None,
position_short: Some(true), request_id: None,
capability: None,
block_size: None,
prefetch_window: None,
};
tx.send(initial_request)
.await
.map_err(|_| Error::BlockIoError {
message: "failed to send initial positioned ReadRequest".to_string(),
})?;
let stream = ReceiverStream::new(rx);
let response = self.inner.clone().read_block(stream).await?;
Ok((tx, response.into_inner()))
}
#[instrument(skip(self, options), fields(block_id = %block_id))]
pub async fn write_block(
&self,
block_id: i64,
space_to_reserve: i64,
options: WriteBlockOptions,
) -> Result<WriteBlockHandle> {
let (tx, rx) = mpsc::channel::<WriteRequest>(32);
let initial_command = WriteRequest {
value: Some(write_request::Value::Command(WriteRequestCommand {
r#type: Some(options.request_type as i32),
id: Some(block_id),
offset: Some(0),
flush: None,
create_ufs_file_options: options.create_ufs_file_options,
space_to_reserve: Some(space_to_reserve),
capability: None,
medium_type: None,
})),
};
let initial_stream = tokio_stream::once(initial_command);
let subsequent_stream = ReceiverStream::new(rx);
let combined_stream = initial_stream.chain(subsequent_stream);
let (resp_tx, resp_rx) =
mpsc::channel::<std::result::Result<WriteResponse, tonic::Status>>(8);
let mut client = self.inner.clone();
let addr = self.addr.clone();
let task_handle = tokio::spawn(async move {
debug!(block_id = block_id, addr = %addr, "WriteBlock gRPC task started");
let call_result = client.write_block(combined_stream).await;
match call_result {
Ok(response) => {
let mut stream = response.into_inner();
loop {
match stream.message().await {
Ok(Some(msg)) => {
if resp_tx.send(Ok(msg)).await.is_err() {
debug!(block_id = block_id, "response receiver dropped");
break;
}
}
Ok(None) => {
debug!(block_id = block_id, "server closed response stream");
break;
}
Err(status) => {
warn!(block_id = block_id, %status, "server response error");
let _ = resp_tx.send(Err(status)).await;
break;
}
}
}
}
Err(status) => {
warn!(block_id = block_id, %status, "WriteBlock RPC failed");
let _ = resp_tx.send(Err(status)).await;
}
}
debug!(block_id = block_id, "WriteBlock gRPC task finished");
});
debug!(block_id = block_id, "WriteBlock handle created");
Ok(WriteBlockHandle {
block_id,
request_tx: tx,
response_rx: resp_rx,
_task_handle: task_handle,
})
}
pub fn addr(&self) -> &str {
&self.addr
}
pub fn generation(&self) -> u64 {
self.generation
}
}
/// Cache of `WorkerClient`s keyed by worker address.
pub struct WorkerClientPool {
// Cached clients, keyed by the address string passed to `acquire`.
clients: RwLock<HashMap<String, WorkerClient>>,
// One async mutex per address, used to serialize reconnects to that address.
reconnect_locks: RwLock<HashMap<String, Arc<AsyncMutex<()>>>>,
// Source of generation numbers for installed clients; `new` starts it at 1.
next_generation: AtomicU64,
// Configuration handed to `WorkerClient::connect`.
config: GooseFsConfig,
}
impl WorkerClientPool {
    /// Creates an empty pool. Generation numbers assigned by the pool start at 1.
    pub fn new(config: GooseFsConfig) -> Self {
        Self {
            clients: RwLock::new(HashMap::new()),
            reconnect_locks: RwLock::new(HashMap::new()),
            next_generation: AtomicU64::new(1),
            config,
        }
    }

    /// Returns a clone of the cached client for `addr`, holding the read
    /// lock only for the lookup.
    async fn cached(&self, addr: &str) -> Option<WorkerClient> {
        self.clients.read().await.get(addr).cloned()
    }

    /// Connects to `addr`, stamps the client with the next generation and
    /// stores it in the cache.
    ///
    /// No map lock is held while connecting; the write lock is taken only
    /// for the final insert. Callers hold the per-address lock.
    async fn connect_and_install(&self, addr: &str) -> Result<WorkerClient> {
        let mut fresh = WorkerClient::connect(addr, &self.config).await?;
        fresh.generation = self.next_generation.fetch_add(1, Ordering::Relaxed);
        self.clients
            .write()
            .await
            .insert(addr.to_string(), fresh.clone());
        Ok(fresh)
    }

    /// Returns the cached client for `addr`, connecting first if none is cached.
    ///
    /// Connection attempts are serialized per address with the same mutex
    /// used by [`Self::reconnect_if_stale`], so the global cache lock is never
    /// held across network I/O and lookups for other workers are not stalled
    /// by a slow connect.
    ///
    /// # Errors
    /// Propagates the error from `WorkerClient::connect`.
    pub async fn acquire(&self, addr: &str) -> Result<WorkerClient> {
        if let Some(client) = self.cached(addr).await {
            debug!(addr = %addr, generation = client.generation, "reusing cached WorkerClient");
            return Ok(client);
        }

        let lock = self.reconnect_lock_for(addr).await;
        let _guard = lock.lock().await;
        // Another task may have installed a client while we waited for the lock.
        if let Some(client) = self.cached(addr).await {
            return Ok(client);
        }

        debug!(addr = %addr, "creating new WorkerClient for pool");
        self.connect_and_install(addr).await
    }

    /// Removes the cached client for `addr`, if present.
    pub async fn invalidate(&self, addr: &str) {
        let mut cache = self.clients.write().await;
        if cache.remove(addr).is_some() {
            debug!(addr = %addr, "invalidated WorkerClient from pool");
        }
    }

    /// Returns the mutex associated with `addr`, creating it on first use.
    async fn reconnect_lock_for(&self, addr: &str) -> Arc<AsyncMutex<()>> {
        {
            let locks = self.reconnect_locks.read().await;
            if let Some(m) = locks.get(addr) {
                return Arc::clone(m);
            }
        }
        // `entry` re-checks under the write lock, so a racing creator wins once.
        let mut locks = self.reconnect_locks.write().await;
        Arc::clone(
            locks
                .entry(addr.to_string())
                .or_insert_with(|| Arc::new(AsyncMutex::new(()))),
        )
    }

    /// Replaces the client for `addr` unless it has already been replaced.
    ///
    /// `stale_generation` is the generation the caller observed failing. If
    /// the cached client's generation is greater, that client is returned
    /// without reconnecting. Otherwise the cached entry is removed, a new
    /// connection is made and installed with a new generation. Calls for the
    /// same address are serialized by a per-address mutex.
    ///
    /// # Errors
    /// Propagates the error from `WorkerClient::connect`; the cache then has
    /// no entry for `addr`.
    pub async fn reconnect_if_stale(
        &self,
        addr: &str,
        stale_generation: u64,
    ) -> Result<WorkerClient> {
        let lock = self.reconnect_lock_for(addr).await;
        let _guard = lock.lock().await;

        if let Some(current) = self.cached(addr).await {
            if current.generation > stale_generation {
                debug!(
                    addr = %addr,
                    observed = stale_generation,
                    current = current.generation,
                    "reconnect coalesced — another task already refreshed this channel"
                );
                return Ok(current);
            }
        }

        debug!(
            addr = %addr,
            stale_generation = stale_generation,
            "performing single-flight reconnect"
        );
        // Drop the stale entry before connecting so it is not handed out meanwhile.
        self.clients.write().await.remove(addr);
        let fresh = self.connect_and_install(addr).await?;
        debug!(
            addr = %addr,
            new_generation = fresh.generation,
            "single-flight reconnect installed fresh WorkerClient"
        );
        Ok(fresh)
    }

    /// Unconditionally reconnects to `addr`.
    ///
    /// Passes `u64::MAX` as the stale generation, which no cached generation
    /// can exceed, so the coalescing shortcut never applies.
    pub async fn reconnect(&self, addr: &str) -> Result<WorkerClient> {
        self.reconnect_if_stale(addr, u64::MAX).await
    }

    /// Convenience constructor returning the pool wrapped in an `Arc`.
    pub fn new_shared(config: GooseFsConfig) -> Arc<Self> {
        Arc::new(Self::new(config))
    }

    /// Test helper: installs `client` under `addr` with a fresh generation and
    /// returns the entry it replaced, if any.
    #[cfg(test)]
    async fn test_install(&self, addr: &str, mut client: WorkerClient) -> Option<WorkerClient> {
        client.generation = self.next_generation.fetch_add(1, Ordering::Relaxed);
        let mut cache = self.clients.write().await;
        cache.insert(addr.to_string(), client)
    }

    /// Test helper: generation of the client currently cached for `addr`.
    #[cfg(test)]
    async fn test_current_generation(&self, addr: &str) -> Option<u64> {
        let cache = self.clients.read().await;
        cache.get(addr).map(|c| c.generation)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tonic::transport::Channel;

    /// Builds a client over a lazily-connected channel, so no network I/O
    /// happens unless an RPC is actually issued.
    fn fake_client(addr: &str) -> WorkerClient {
        let lazy = Channel::from_static("http://127.0.0.1:1").connect_lazy();
        WorkerClient::from_channel(lazy, addr.to_string())
    }

    /// A caller reporting an older generation must get the already-installed
    /// newer client back without triggering a reconnect.
    #[tokio::test]
    async fn test_reconnect_if_stale_coalesces_when_generation_advanced() {
        let worker = "test-worker:9203";
        let pool = WorkerClientPool::new(GooseFsConfig::new("127.0.0.1:9200"));

        pool.test_install(worker, fake_client(worker)).await;
        let gen_before = pool.test_current_generation(worker).await.unwrap();
        pool.test_install(worker, fake_client(worker)).await;
        let gen_after = pool.test_current_generation(worker).await.unwrap();
        assert!(gen_after > gen_before);

        let outcome = pool.reconnect_if_stale(worker, gen_before).await;
        assert!(
            outcome.is_ok(),
            "coalesced reconnect must short-circuit without network I/O, got {:?}",
            outcome.err()
        );
        let client = outcome.unwrap();
        assert_eq!(
            client.generation(),
            gen_after,
            "caller must receive the already-replaced generation"
        );
        assert_eq!(
            pool.test_current_generation(worker).await,
            Some(gen_after),
            "cached generation must not advance for a coalesced caller"
        );
    }

    /// Holding one address's lock must not block acquiring another's.
    #[tokio::test]
    async fn test_reconnect_locks_are_per_address() {
        let pool = WorkerClientPool::new(GooseFsConfig::new("127.0.0.1:9200"));
        let first_lock = pool.reconnect_lock_for("worker-a:9203").await;
        let second_lock = pool.reconnect_lock_for("worker-b:9203").await;

        let first_guard = first_lock.lock().await;
        let wait = std::time::Duration::from_millis(50);
        let second_guard = tokio::time::timeout(wait, second_lock.lock())
            .await
            .expect("lock for different address must not be blocked");
        drop(second_guard);
        drop(first_guard);
    }

    /// Each install must be stamped with a strictly larger generation.
    #[tokio::test]
    async fn test_generation_is_monotonic_across_installs() {
        let worker = "w:9203";
        let pool = WorkerClientPool::new(GooseFsConfig::new("127.0.0.1:9200"));

        pool.test_install(worker, fake_client(worker)).await;
        let g1 = pool.test_current_generation(worker).await.unwrap();
        pool.test_install(worker, fake_client(worker)).await;
        let g2 = pool.test_current_generation(worker).await.unwrap();
        pool.test_install(worker, fake_client(worker)).await;
        let g3 = pool.test_current_generation(worker).await.unwrap();

        assert!(g1 < g2, "gen {} not less than {}", g1, g2);
        assert!(g2 < g3, "gen {} not less than {}", g2, g3);
    }
}