Skip to main content

goosefs_sdk/client/
worker.rs

1//! Goosefs Worker gRPC client for block data read/write.
2//!
3//! Wraps `BlockWorker` service (Worker:9203) providing:
4//! - `read_block` — bidirectional streaming block read
5//! - `write_block` — bidirectional streaming block write
6//!
7//! ## Write Protocol
8//!
9//! Goosefs Worker's `WriteBlock` is a bidirectional streaming RPC but the server
10//! does **not** send HTTP/2 response headers until the client sends a `flush`
11//! command or closes the stream. This means tonic's
12//! `client.write_block(stream).await` will block until the first server response.
13//!
//! To work around this, `write_block()` returns a [`WriteBlockHandle`] that
//! runs the gRPC call in a background tokio task. The caller sends data chunks
//! and flush commands through the handle's request sender, then calls
//! `recv_response()` (for flush acknowledgements) or `close()` (to finish the
//! write) on the handle to receive server responses.
18
19use std::collections::HashMap;
20use std::sync::atomic::{AtomicU64, Ordering};
21use std::sync::Arc;
22use std::time::Duration;
23
24use tokio::sync::{mpsc, Mutex as AsyncMutex, RwLock};
25use tokio_stream::wrappers::ReceiverStream;
26use tokio_stream::StreamExt;
27use tonic::service::interceptor::InterceptedService;
28use tonic::transport::Channel;
29use tonic::Streaming;
30use tracing::{debug, instrument, warn};
31
32use crate::auth::{ChannelAuthenticator, ChannelIdInterceptor, SaslStreamGuard};
33use crate::config::GoosefsConfig;
34use crate::error::{Error, Result};
35use crate::proto::grpc::block::{
36    block_worker_client::BlockWorkerClient, write_request, ReadRequest, ReadResponse, RequestType,
37    WriteRequest, WriteRequestCommand, WriteResponse,
38};
39use crate::proto::proto::dataserver::{CreateUfsFileOptions, OpenUfsBlockOptions};
40
41/// Options for a `write_block` RPC that control *where* the Worker writes data.
42///
43/// - `GoosefsBlock` (default): write to Goosefs cache (MUST_CACHE / CACHE_THROUGH / ASYNC_THROUGH)
44/// - `UfsFile`: write directly to UFS (THROUGH mode), requires `create_ufs_file_options`
45/// - `UfsFallbackBlock`: cache-full fallback to UFS (TRY_CACHE)
46#[derive(Clone, Debug)]
47pub struct WriteBlockOptions {
48    /// The request type sent in the initial `WriteRequestCommand`.
49    pub request_type: RequestType,
50    /// UFS file creation options (required when `request_type == UfsFile`).
51    pub create_ufs_file_options: Option<CreateUfsFileOptions>,
52}
53
54impl Default for WriteBlockOptions {
55    fn default() -> Self {
56        Self {
57            request_type: RequestType::GoosefsBlock,
58            create_ufs_file_options: None,
59        }
60    }
61}
62
63/// Handle for an in-progress `WriteBlock` bidirectional streaming RPC.
64///
65/// The gRPC call runs in a background tokio task. The caller sends data through
66/// `request_tx` and receives responses via `recv_response()`. When done, call
67/// `close()` to drop the request channel and wait for the server to finalize.
68pub struct WriteBlockHandle {
69    /// Block being written.
70    block_id: i64,
71    /// Sender for client → server WriteRequest messages (data chunks, flush commands).
72    pub request_tx: mpsc::Sender<WriteRequest>,
73    /// Receiver for server → client WriteResponse messages, forwarded from the background task.
74    response_rx: mpsc::Receiver<std::result::Result<WriteResponse, tonic::Status>>,
75    /// Handle to the background gRPC task.
76    _task_handle: tokio::task::JoinHandle<()>,
77}
78
79impl WriteBlockHandle {
80    /// Receive the next `WriteResponse` from the server (e.g., flush ack).
81    ///
82    /// Returns `None` if the server has closed the response stream.
83    pub async fn recv_response(&mut self) -> Result<Option<WriteResponse>> {
84        match self.response_rx.recv().await {
85            Some(Ok(resp)) => Ok(Some(resp)),
86            Some(Err(status)) => Err(Error::GrpcError {
87                message: format!(
88                    "WriteBlock server error for block_id={}: {}",
89                    self.block_id, status
90                ),
91                source: status,
92            }),
93            None => Ok(None),
94        }
95    }
96
97    /// Close the write stream by dropping the request sender and wait for
98    /// any final response from the server.
99    pub async fn close(mut self) -> Result<()> {
100        // Drop the request sender to close the client→server half of the stream.
101        // The server will then call onCompleted → commitBlock → replySuccess.
102        drop(self.request_tx);
103        debug!(
104            block_id = self.block_id,
105            "closed write stream, waiting for server finalize"
106        );
107        // Wait for the server's final response (or stream close).
108        // This ensures the background task finishes before we return,
109        // preventing the Channel from being dropped while the task is still running.
110        while let Some(result) = self.response_rx.recv().await {
111            match result {
112                Ok(_resp) => {
113                    debug!(
114                        block_id = self.block_id,
115                        "received final response from server"
116                    );
117                }
118                Err(status) => {
119                    return Err(Error::GrpcError {
120                        message: format!(
121                            "WriteBlock server error for block_id={}: {}",
122                            self.block_id, status
123                        ),
124                        source: status,
125                    });
126                }
127            }
128        }
129        Ok(())
130    }
131
132    /// Cancel the write stream without waiting for server finalization.
133    ///
134    /// Drops the request sender and response receiver immediately.
135    /// The server will detect the stream cancellation and clean up.
136    /// Matches Java's `GrpcBlockingStream.cancel()`.
137    pub async fn cancel(self) {
138        drop(self.request_tx);
139        drop(self.response_rx);
140        debug!(block_id = self.block_id, "cancelled write stream");
141    }
142}
143
144/// Type alias for the authenticated Worker gRPC client.
145type AuthenticatedBlockWorkerClient =
146    BlockWorkerClient<InterceptedService<Channel, ChannelIdInterceptor>>;
147
148/// Client for `BlockWorker` service on a single worker node.
149///
150/// Each `WorkerClient` carries a monotonic `generation` tag assigned by
151/// [`WorkerClientPool`] at construction time.  The generation allows callers
152/// that observed a failure on a specific client to request a **single-flight
153/// reconnect** via [`WorkerClientPool::reconnect_if_stale`]: only the first
154/// observer of generation `N` actually re-establishes the TCP+SASL
155/// connection; all concurrent observers with the same (or older) generation
156/// simply receive the already-replaced client.  This collapses the
157/// "thundering-herd reconnect" that previously produced hundreds of duplicate
158/// `authentication failed` warnings when a SASL session expired.
159#[derive(Clone)]
160pub struct WorkerClient {
161    inner: AuthenticatedBlockWorkerClient,
162    addr: String,
163    /// Monotonic tag identifying this exact connection instance.
164    ///
165    /// Two clients cached for the same address must have different
166    /// generations; a caller that observes a failure on generation `N` can
167    /// ask the pool to reconnect *only if* generation has not advanced yet.
168    generation: u64,
169    /// Keeps the SASL authentication stream alive for the channel's lifetime.
170    _sasl_guard: std::sync::Arc<Option<SaslStreamGuard>>,
171}
172
173impl WorkerClient {
174    /// Connect to a Goosefs Worker at the given address with authentication.
175    ///
176    /// Authentication is performed according to `config.auth_type`.
177    pub async fn connect(addr: &str, config: &GoosefsConfig) -> Result<Self> {
178        let endpoint = Channel::from_shared(format!("http://{}", addr))
179            .map_err(|e| Error::ConfigError {
180                message: format!("invalid worker endpoint: {}", e),
181            })?
182            .connect_timeout(config.connect_timeout);
183
184        let channel = endpoint.connect().await?;
185
186        // Perform SASL authentication based on the configured auth type
187        let authenticator =
188            ChannelAuthenticator::new(config.auth_type, config.auth_username.clone(), None)
189                .with_auth_timeout(config.auth_timeout);
190
191        let mut auth_channel = authenticator.authenticate(channel).await?;
192        let sasl_guard = auth_channel.take_sasl_guard();
193        debug!(addr = %addr, auth_type = %config.auth_type, "connected to Goosefs Worker");
194
195        Ok(Self {
196            inner: BlockWorkerClient::new(auth_channel.channel),
197            addr: addr.to_string(),
198            generation: 0,
199            _sasl_guard: std::sync::Arc::new(sasl_guard),
200        })
201    }
202
203    /// Connect to a Goosefs Worker with only connect_timeout (backward compatible, NOSASL).
204    ///
205    /// **Deprecated**: Use `connect(addr, config)` instead for proper authentication.
206    pub async fn connect_simple(addr: &str, connect_timeout: Duration) -> Result<Self> {
207        let endpoint = Channel::from_shared(format!("http://{}", addr))
208            .map_err(|e| Error::ConfigError {
209                message: format!("invalid worker endpoint: {}", e),
210            })?
211            .connect_timeout(connect_timeout);
212
213        let channel = endpoint.connect().await?;
214        let interceptor = ChannelIdInterceptor::new(uuid::Uuid::new_v4().to_string());
215        let intercepted = InterceptedService::new(channel, interceptor);
216        debug!(addr = %addr, "connected to Goosefs Worker (no auth)");
217
218        Ok(Self {
219            inner: BlockWorkerClient::new(intercepted),
220            addr: addr.to_string(),
221            generation: 0,
222            _sasl_guard: std::sync::Arc::new(None),
223        })
224    }
225
226    /// Create from an existing tonic channel (useful for testing / channel sharing).
227    ///
228    /// **Note**: This bypasses authentication.
229    pub fn from_channel(channel: Channel, addr: String) -> Self {
230        let interceptor = ChannelIdInterceptor::new("test-no-auth".to_string());
231        let intercepted = InterceptedService::new(channel, interceptor);
232        Self {
233            inner: BlockWorkerClient::new(intercepted),
234            addr,
235            generation: 0,
236            _sasl_guard: std::sync::Arc::new(None),
237        }
238    }
239
240    /// Start a bidirectional streaming ReadBlock RPC.
241    ///
242    /// Returns: (request_sender, response_stream)
243    ///
244    /// The caller sends an initial `ReadRequest` with block_id/offset/length,
245    /// then sends periodic `offset_received` ACKs. The response stream yields
246    /// `ReadResponse` containing `Chunk` data.
247    ///
248    /// When the block is only stored in UFS (e.g. written with `THROUGH` mode),
249    /// `open_ufs_block_options` must be provided so the Worker knows how to
250    /// locate and read the data from the underlying storage.
251    #[instrument(skip(self, open_ufs_block_options), fields(block_id = %block_id, offset = %offset, length = %length))]
252    pub async fn read_block(
253        &self,
254        block_id: i64,
255        offset: i64,
256        length: i64,
257        chunk_size: i64,
258        open_ufs_block_options: Option<OpenUfsBlockOptions>,
259    ) -> Result<(mpsc::Sender<ReadRequest>, Streaming<ReadResponse>)> {
260        let (tx, rx) = mpsc::channel::<ReadRequest>(32);
261
262        // Send the initial read request
263        let initial_request = ReadRequest {
264            block_id: Some(block_id),
265            offset: Some(offset),
266            length: Some(length),
267            chunk_size: Some(chunk_size),
268            open_ufs_block_options,
269            offset_received: None,
270            position_short: None,
271            request_id: None,
272            capability: None,
273            block_size: None,
274            prefetch_window: None,
275        };
276        tx.send(initial_request)
277            .await
278            .map_err(|_| Error::BlockIoError {
279                message: "failed to send initial ReadRequest".to_string(),
280            })?;
281
282        let stream = ReceiverStream::new(rx);
283        let response = self.inner.clone().read_block(stream).await?;
284
285        Ok((tx, response.into_inner()))
286    }
287
288    /// Open a positioned (random-access) block read stream.
289    ///
290    /// Identical to [`read_block`](Self::read_block) but sets `position_short = true` in the
291    /// initial `ReadRequest`, instructing the worker to skip prefetch and
292    /// serve the exact requested byte range.
293    ///
294    /// Used by [`crate::io::reader::GrpcBlockReader::positioned_read`].
295    pub async fn read_block_positioned(
296        &self,
297        block_id: i64,
298        offset: i64,
299        length: i64,
300        chunk_size: i64,
301        open_ufs_block_options: Option<OpenUfsBlockOptions>,
302    ) -> Result<(mpsc::Sender<ReadRequest>, Streaming<ReadResponse>)> {
303        let (tx, rx) = mpsc::channel::<ReadRequest>(32);
304
305        let initial_request = ReadRequest {
306            block_id: Some(block_id),
307            offset: Some(offset),
308            length: Some(length),
309            chunk_size: Some(chunk_size),
310            open_ufs_block_options,
311            offset_received: None,
312            position_short: Some(true), // positioned-read hint to worker
313            request_id: None,
314            capability: None,
315            block_size: None,
316            prefetch_window: None,
317        };
318        tx.send(initial_request)
319            .await
320            .map_err(|_| Error::BlockIoError {
321                message: "failed to send initial positioned ReadRequest".to_string(),
322            })?;
323
324        let stream = ReceiverStream::new(rx);
325        let response = self.inner.clone().read_block(stream).await?;
326
327        Ok((tx, response.into_inner()))
328    }
329
330    /// Start a bidirectional streaming WriteBlock RPC.
331    ///
332    /// Returns a [`WriteBlockHandle`] that manages the background gRPC task.
333    /// The caller sends data chunks through `handle.request_tx`, then calls
334    /// `handle.recv_response()` to get flush acknowledgements.
335    ///
336    /// ## Why a background task?
337    ///
338    /// Goosefs Worker's `WriteBlock` RPC does **not** send HTTP/2 response
339    /// headers until the client sends a `flush` command or closes the stream.
340    /// tonic's `client.write_block(stream).await` waits for response headers
341    /// before resolving, so calling it inline would deadlock — we'd need the
342    /// returned sender to send flush, but we can't get the sender until the
343    /// call resolves.
344    ///
345    /// By spawning the gRPC call in a background task and forwarding responses
346    /// through an mpsc channel, we decouple request sending from response
347    /// receiving.
348    #[instrument(skip(self, options), fields(block_id = %block_id))]
349    pub async fn write_block(
350        &self,
351        block_id: i64,
352        space_to_reserve: i64,
353        options: WriteBlockOptions,
354    ) -> Result<WriteBlockHandle> {
355        let (tx, rx) = mpsc::channel::<WriteRequest>(32);
356
357        // Build the initial write command
358        let initial_command = WriteRequest {
359            value: Some(write_request::Value::Command(WriteRequestCommand {
360                r#type: Some(options.request_type as i32),
361                id: Some(block_id),
362                offset: Some(0),
363                flush: None,
364                create_ufs_file_options: options.create_ufs_file_options,
365                space_to_reserve: Some(space_to_reserve),
366                capability: None,
367                medium_type: None,
368            })),
369        };
370
371        // Build a composite stream: initial command first, then channel messages.
372        let initial_stream = tokio_stream::once(initial_command);
373        let subsequent_stream = ReceiverStream::new(rx);
374        let combined_stream = initial_stream.chain(subsequent_stream);
375
376        // Channel for forwarding server responses from the background task.
377        let (resp_tx, resp_rx) =
378            mpsc::channel::<std::result::Result<WriteResponse, tonic::Status>>(8);
379
380        let mut client = self.inner.clone();
381        let addr = self.addr.clone();
382
383        let task_handle = tokio::spawn(async move {
384            debug!(block_id = block_id, addr = %addr, "WriteBlock gRPC task started");
385
386            // This call blocks until the server sends response headers,
387            // which happens on the first flush or stream close.
388            let call_result = client.write_block(combined_stream).await;
389
390            match call_result {
391                Ok(response) => {
392                    let mut stream = response.into_inner();
393                    // Forward all server responses to the caller.
394                    loop {
395                        match stream.message().await {
396                            Ok(Some(msg)) => {
397                                if resp_tx.send(Ok(msg)).await.is_err() {
398                                    debug!(block_id = block_id, "response receiver dropped");
399                                    break;
400                                }
401                            }
402                            Ok(None) => {
403                                debug!(block_id = block_id, "server closed response stream");
404                                break;
405                            }
406                            Err(status) => {
407                                warn!(block_id = block_id, %status, "server response error");
408                                let _ = resp_tx.send(Err(status)).await;
409                                break;
410                            }
411                        }
412                    }
413                }
414                Err(status) => {
415                    warn!(block_id = block_id, %status, "WriteBlock RPC failed");
416                    let _ = resp_tx.send(Err(status)).await;
417                }
418            }
419
420            debug!(block_id = block_id, "WriteBlock gRPC task finished");
421        });
422
423        debug!(block_id = block_id, "WriteBlock handle created");
424
425        Ok(WriteBlockHandle {
426            block_id,
427            request_tx: tx,
428            response_rx: resp_rx,
429            _task_handle: task_handle,
430        })
431    }
432
433    /// The worker address this client is connected to.
434    pub fn addr(&self) -> &str {
435        &self.addr
436    }
437
438    /// The monotonic generation tag assigned by the pool.
439    ///
440    /// Callers should save this value alongside the `WorkerClient` when
441    /// starting an RPC; if the RPC fails with an authentication error they
442    /// pass the saved generation back to
443    /// [`WorkerClientPool::reconnect_if_stale`] to trigger a single-flight
444    /// reconnect (de-duplicating concurrent observers of the same failure).
445    pub fn generation(&self) -> u64 {
446        self.generation
447    }
448}
449
450/// Connection pool for `WorkerClient` instances.
451///
452/// Caches authenticated gRPC channels by worker address, avoiding the overhead
453/// of re-establishing connections and re-authenticating for every block I/O.
454/// Matches Java's `FileSystemContext.acquireBlockWorkerClient()` pattern.
455///
456/// The pool is thread-safe and can be shared across concurrent workers.
457///
458/// ## Single-Flight Reconnect
459///
460/// When a SASL stream silently expires server-side, many concurrent RPCs on
461/// the same cached channel will fail simultaneously with UNAUTHENTICATED.
462/// Without coordination each observer would independently invoke `reconnect`,
463/// producing a "thundering herd" that serialises through the pool's write
464/// lock and wastes CPU/RTT on duplicate TCP+SASL handshakes.
465///
466/// To collapse this herd, each [`WorkerClient`] carries a monotonic
467/// `generation` tag.  Callers pass the observed generation back into
468/// [`reconnect_if_stale`](Self::reconnect_if_stale) after an auth failure;
469/// only the **first** observer of a given generation actually performs the
470/// reconnect, all other concurrent observers receive the already-replaced
471/// client.  This reduces N concurrent reconnects to exactly 1.
472pub struct WorkerClientPool {
473    /// Cached worker clients keyed by `"host:port"` address.
474    ///
475    /// The stored client carries its own `generation` in-band; readers simply
476    /// clone it and inspect `client.generation()`.
477    clients: RwLock<HashMap<String, WorkerClient>>,
478    /// Per-address async mutex guarding the reconnect critical section.
479    ///
480    /// Separated from `clients` so the reconnect handshake (which performs
481    /// network I/O) does not hold the clients-map write lock.  Acquiring this
482    /// mutex for one address does not block other addresses' reconnects.
483    reconnect_locks: RwLock<HashMap<String, Arc<AsyncMutex<()>>>>,
484    /// Monotonic counter used to hand out a unique `generation` for every
485    /// freshly-created `WorkerClient`.
486    next_generation: AtomicU64,
487    /// Config used to create new connections.
488    config: GoosefsConfig,
489}
490
491impl WorkerClientPool {
492    /// Create a new empty connection pool.
493    pub fn new(config: GoosefsConfig) -> Self {
494        Self {
495            clients: RwLock::new(HashMap::new()),
496            reconnect_locks: RwLock::new(HashMap::new()),
497            // Start generations at 1 so `0` (the default on constructed-but-
498            // never-pooled clients) is always "stale" relative to any pooled
499            // client — this makes `reconnect_if_stale(addr, 0)` always force
500            // a fresh connection when needed.
501            next_generation: AtomicU64::new(1),
502            config,
503        }
504    }
505
506    /// Acquire a `WorkerClient` for the given address.
507    ///
508    /// Returns a cached client if one exists, otherwise creates a new connection.
509    /// The tonic `Channel` supports multiplexing, so a single cached client can
510    /// handle multiple concurrent RPCs.
511    pub async fn acquire(&self, addr: &str) -> Result<WorkerClient> {
512        // Fast path: check read lock first
513        {
514            let cache = self.clients.read().await;
515            if let Some(client) = cache.get(addr) {
516                debug!(addr = %addr, generation = client.generation, "reusing cached WorkerClient");
517                return Ok(client.clone());
518            }
519        }
520
521        // Slow path: create new connection under write lock
522        let mut cache = self.clients.write().await;
523        // Double-check after acquiring write lock (another task may have inserted)
524        if let Some(client) = cache.get(addr) {
525            return Ok(client.clone());
526        }
527
528        debug!(addr = %addr, "creating new WorkerClient for pool");
529        let mut client = WorkerClient::connect(addr, &self.config).await?;
530        client.generation = self.next_generation.fetch_add(1, Ordering::Relaxed);
531        cache.insert(addr.to_string(), client.clone());
532        Ok(client)
533    }
534
535    /// Remove a worker from the pool (e.g., after a connection failure).
536    ///
537    /// The next `acquire()` call for this address will create a fresh connection.
538    pub async fn invalidate(&self, addr: &str) {
539        let mut cache = self.clients.write().await;
540        if cache.remove(addr).is_some() {
541            debug!(addr = %addr, "invalidated WorkerClient from pool");
542        }
543    }
544
545    /// Get (or lazily create) the per-address reconnect mutex.
546    async fn reconnect_lock_for(&self, addr: &str) -> Arc<AsyncMutex<()>> {
547        {
548            let locks = self.reconnect_locks.read().await;
549            if let Some(m) = locks.get(addr) {
550                return Arc::clone(m);
551            }
552        }
553        let mut locks = self.reconnect_locks.write().await;
554        Arc::clone(
555            locks
556                .entry(addr.to_string())
557                .or_insert_with(|| Arc::new(AsyncMutex::new(()))),
558        )
559    }
560
561    /// **Single-flight reconnect**: invalidate + reconnect only if the
562    /// currently cached client's generation still matches `stale_generation`.
563    ///
564    /// This is the preferred recovery path on authentication failure.  The
565    /// caller passes the `generation()` of the client that just failed;
566    /// because every `WorkerClient` carries a unique monotonic generation
567    /// allocated by this pool:
568    ///
569    /// - If another concurrent task has **already** reconnected in response
570    ///   to the same underlying SASL expiry, the cached generation will have
571    ///   advanced past `stale_generation` and this call returns the
572    ///   already-replaced client **without** performing another
573    ///   TCP+SASL handshake.
574    /// - Otherwise, this call performs exactly one reconnect under the
575    ///   per-address mutex.
576    ///
577    /// Net effect: N concurrent `AuthenticationFailed` observers on the
578    /// same channel trigger exactly **one** reconnect instead of N.
579    pub async fn reconnect_if_stale(
580        &self,
581        addr: &str,
582        stale_generation: u64,
583    ) -> Result<WorkerClient> {
584        // Take the per-address reconnect mutex.  Concurrent callers for the
585        // same address serialise here; callers for *different* addresses do
586        // not block each other.
587        let lock = self.reconnect_lock_for(addr).await;
588        let _guard = lock.lock().await;
589
590        // Under the mutex, re-check the cache.  If another task already
591        // replaced the stale client while we were queuing, skip the
592        // reconnect entirely.
593        {
594            let cache = self.clients.read().await;
595            if let Some(client) = cache.get(addr) {
596                if client.generation > stale_generation {
597                    debug!(
598                        addr = %addr,
599                        observed = stale_generation,
600                        current = client.generation,
601                        "reconnect coalesced — another task already refreshed this channel"
602                    );
603                    return Ok(client.clone());
604                }
605            }
606        }
607
608        // We are the designated reconnect-er: drop the stale entry, then
609        // build and install a new one.
610        debug!(
611            addr = %addr,
612            stale_generation = stale_generation,
613            "performing single-flight reconnect"
614        );
615        {
616            let mut cache = self.clients.write().await;
617            cache.remove(addr);
618        }
619        let mut fresh = WorkerClient::connect(addr, &self.config).await?;
620        fresh.generation = self.next_generation.fetch_add(1, Ordering::Relaxed);
621        {
622            let mut cache = self.clients.write().await;
623            cache.insert(addr.to_string(), fresh.clone());
624        }
625        debug!(
626            addr = %addr,
627            new_generation = fresh.generation,
628            "single-flight reconnect installed fresh WorkerClient"
629        );
630        Ok(fresh)
631    }
632
633    /// Invalidate a cached worker connection and immediately reconnect.
634    ///
635    /// **Prefer [`reconnect_if_stale`](Self::reconnect_if_stale) whenever the
636    /// caller holds a reference to the failing `WorkerClient`** — it
637    /// deduplicates concurrent reconnects triggered by the same underlying
638    /// SASL expiry.
639    ///
640    /// This unconditional variant is kept for paths where the caller does
641    /// not know the generation of the failing client (e.g. a stand-alone
642    /// `connect()` failure that never produced a `WorkerClient`).  It
643    /// acquires the same per-address reconnect mutex so it still coalesces
644    /// against any in-flight `reconnect_if_stale`.
645    pub async fn reconnect(&self, addr: &str) -> Result<WorkerClient> {
646        // Use `u64::MAX` as "stale" so `reconnect_if_stale` always proceeds
647        // with the handshake (current generation can never exceed MAX).
648        // This still passes through the per-address mutex so concurrent
649        // callers on the same address share a single handshake.
650        self.reconnect_if_stale(addr, u64::MAX).await
651    }
652
653    /// Create a new pool wrapped in `Arc` for shared ownership.
654    pub fn new_shared(config: GoosefsConfig) -> Arc<Self> {
655        Arc::new(Self::new(config))
656    }
657
658    // ── Test-only helpers ────────────────────────────────────────────
659    //
660    // These helpers are gated on `cfg(test)` so downstream code cannot
661    // accidentally inject bypass-auth clients into the pool.  They exist
662    // purely to let the unit tests in this module drive the single-flight
663    // reconnect logic without needing a live Worker process to handshake
664    // against.
665
666    /// Manually insert a client with a specific `generation` into the
667    /// pool for testing.  Returns the previously-cached client, if any.
668    #[cfg(test)]
669    async fn test_install(&self, addr: &str, mut client: WorkerClient) -> Option<WorkerClient> {
670        client.generation = self.next_generation.fetch_add(1, Ordering::Relaxed);
671        let mut cache = self.clients.write().await;
672        cache.insert(addr.to_string(), client)
673    }
674
675    /// Snapshot the current cached generation for `addr` (if any).
676    #[cfg(test)]
677    async fn test_current_generation(&self, addr: &str) -> Option<u64> {
678        let cache = self.clients.read().await;
679        cache.get(addr).map(|c| c.generation)
680    }
681}
682
#[cfg(test)]
mod tests {
    use super::*;
    use tonic::transport::Channel;

    /// Build a `WorkerClient` on top of a lazily-connected channel.
    ///
    /// `connect_lazy` does not open a TCP connection up front, so the client
    /// is only good for in-memory operations (addr/generation lookups, clone,
    /// drop).  Any real RPC on it would fail, but none of the tests below
    /// issues one.
    fn fake_client(addr: &str) -> WorkerClient {
        let lazy_channel = Channel::from_static("http://127.0.0.1:1").connect_lazy();
        WorkerClient::from_channel(lazy_channel, addr.to_owned())
    }

    /// A caller reporting an already-superseded generation must be handed
    /// the newer cached client without any reconnect being attempted.
    #[tokio::test]
    async fn test_reconnect_if_stale_coalesces_when_generation_advanced() {
        let pool = WorkerClientPool::new(GoosefsConfig::new("127.0.0.1:9200"));
        let addr = "test-worker:9203";

        // First install plays the role of the client that later "failed".
        pool.test_install(addr, fake_client(addr)).await;
        let stale_generation = pool.test_current_generation(addr).await.unwrap();

        // Second install simulates another task having already reconnected.
        pool.test_install(addr, fake_client(addr)).await;
        let current_generation = pool.test_current_generation(addr).await.unwrap();
        assert!(current_generation > stale_generation);

        // Reporting the old generation must short-circuit: a real
        // `WorkerClient::connect` against this non-existent host would fail
        // and surface here as an error.
        let returned = match pool.reconnect_if_stale(addr, stale_generation).await {
            Ok(client) => client,
            Err(err) => panic!(
                "coalesced reconnect must short-circuit without network I/O, got {:?}",
                Some(err)
            ),
        };

        assert_eq!(
            returned.generation(),
            current_generation,
            "caller must receive the already-replaced generation"
        );
        let cached_now = pool.test_current_generation(addr).await;
        assert_eq!(
            cached_now,
            Some(current_generation),
            "cached generation must not advance for a coalesced caller"
        );
    }

    /// Reconnect mutexes are keyed by address: holding the one for worker A
    /// must leave the one for worker B immediately available.
    #[tokio::test]
    async fn test_reconnect_locks_are_per_address() {
        let pool = WorkerClientPool::new(GoosefsConfig::new("127.0.0.1:9200"));
        let lock_a = pool.reconnect_lock_for("worker-a:9203").await;
        let lock_b = pool.reconnect_lock_for("worker-b:9203").await;

        let guard_a = lock_a.lock().await;
        // With a single global mutex this would time out while A is held.
        let patience = Duration::from_millis(50);
        let guard_b = tokio::time::timeout(patience, lock_b.lock())
            .await
            .expect("lock for different address must not be blocked");

        drop(guard_b);
        drop(guard_a);
    }

    /// Each successive install must be stamped with a strictly larger
    /// generation than the one before it.
    #[tokio::test]
    async fn test_generation_is_monotonic_across_installs() {
        let pool = WorkerClientPool::new(GoosefsConfig::new("127.0.0.1:9200"));
        let addr = "w:9203";

        let mut observed = Vec::with_capacity(3);
        for _ in 0..3 {
            pool.test_install(addr, fake_client(addr)).await;
            observed.push(pool.test_current_generation(addr).await.unwrap());
        }

        for pair in observed.windows(2) {
            let (earlier, later) = (pair[0], pair[1]);
            assert!(earlier < later, "gen {} not less than {}", earlier, later);
        }
    }
}