goosefs_sdk/client/worker.rs
1//! Goosefs Worker gRPC client for block data read/write.
2//!
3//! Wraps `BlockWorker` service (Worker:9203) providing:
4//! - `read_block` — bidirectional streaming block read
5//! - `write_block` — bidirectional streaming block write
6//!
7//! ## Write Protocol
8//!
9//! Goosefs Worker's `WriteBlock` is a bidirectional streaming RPC but the server
10//! does **not** send HTTP/2 response headers until the client sends a `flush`
11//! command or closes the stream. This means tonic's
12//! `client.write_block(stream).await` will block until the first server response.
13//!
14//! To work around this, `write_block()` returns a [`WriteBlockHandle`] that
15//! runs the gRPC call in a background tokio task. The caller sends data chunks
16//! through the request sender, then calls `flush()` or `close()` on the handle
17//! to receive server responses.
18
19use std::collections::HashMap;
20use std::sync::atomic::{AtomicU64, Ordering};
21use std::sync::Arc;
22use std::time::Duration;
23
24use tokio::sync::{mpsc, Mutex as AsyncMutex, RwLock};
25use tokio_stream::wrappers::ReceiverStream;
26use tokio_stream::StreamExt;
27use tonic::service::interceptor::InterceptedService;
28use tonic::transport::Channel;
29use tonic::Streaming;
30use tracing::{debug, instrument, warn};
31
32use crate::auth::{ChannelAuthenticator, ChannelIdInterceptor, SaslStreamGuard};
33use crate::config::GoosefsConfig;
34use crate::error::{Error, Result};
35use crate::proto::grpc::block::{
36 block_worker_client::BlockWorkerClient, write_request, ReadRequest, ReadResponse, RequestType,
37 WriteRequest, WriteRequestCommand, WriteResponse,
38};
39use crate::proto::proto::dataserver::{CreateUfsFileOptions, OpenUfsBlockOptions};
40
41/// Options for a `write_block` RPC that control *where* the Worker writes data.
42///
43/// - `GoosefsBlock` (default): write to Goosefs cache (MUST_CACHE / CACHE_THROUGH / ASYNC_THROUGH)
44/// - `UfsFile`: write directly to UFS (THROUGH mode), requires `create_ufs_file_options`
45/// - `UfsFallbackBlock`: cache-full fallback to UFS (TRY_CACHE)
46#[derive(Clone, Debug)]
47pub struct WriteBlockOptions {
48 /// The request type sent in the initial `WriteRequestCommand`.
49 pub request_type: RequestType,
50 /// UFS file creation options (required when `request_type == UfsFile`).
51 pub create_ufs_file_options: Option<CreateUfsFileOptions>,
52}
53
54impl Default for WriteBlockOptions {
55 fn default() -> Self {
56 Self {
57 request_type: RequestType::GoosefsBlock,
58 create_ufs_file_options: None,
59 }
60 }
61}
62
63/// Handle for an in-progress `WriteBlock` bidirectional streaming RPC.
64///
65/// The gRPC call runs in a background tokio task. The caller sends data through
66/// `request_tx` and receives responses via `recv_response()`. When done, call
67/// `close()` to drop the request channel and wait for the server to finalize.
68pub struct WriteBlockHandle {
69 /// Block being written.
70 block_id: i64,
71 /// Sender for client → server WriteRequest messages (data chunks, flush commands).
72 pub request_tx: mpsc::Sender<WriteRequest>,
73 /// Receiver for server → client WriteResponse messages, forwarded from the background task.
74 response_rx: mpsc::Receiver<std::result::Result<WriteResponse, tonic::Status>>,
75 /// Handle to the background gRPC task.
76 _task_handle: tokio::task::JoinHandle<()>,
77}
78
79impl WriteBlockHandle {
80 /// Receive the next `WriteResponse` from the server (e.g., flush ack).
81 ///
82 /// Returns `None` if the server has closed the response stream.
83 pub async fn recv_response(&mut self) -> Result<Option<WriteResponse>> {
84 match self.response_rx.recv().await {
85 Some(Ok(resp)) => Ok(Some(resp)),
86 Some(Err(status)) => Err(Error::GrpcError {
87 message: format!(
88 "WriteBlock server error for block_id={}: {}",
89 self.block_id, status
90 ),
91 source: status,
92 }),
93 None => Ok(None),
94 }
95 }
96
97 /// Close the write stream by dropping the request sender and wait for
98 /// any final response from the server.
99 pub async fn close(mut self) -> Result<()> {
100 // Drop the request sender to close the client→server half of the stream.
101 // The server will then call onCompleted → commitBlock → replySuccess.
102 drop(self.request_tx);
103 debug!(
104 block_id = self.block_id,
105 "closed write stream, waiting for server finalize"
106 );
107 // Wait for the server's final response (or stream close).
108 // This ensures the background task finishes before we return,
109 // preventing the Channel from being dropped while the task is still running.
110 while let Some(result) = self.response_rx.recv().await {
111 match result {
112 Ok(_resp) => {
113 debug!(
114 block_id = self.block_id,
115 "received final response from server"
116 );
117 }
118 Err(status) => {
119 return Err(Error::GrpcError {
120 message: format!(
121 "WriteBlock server error for block_id={}: {}",
122 self.block_id, status
123 ),
124 source: status,
125 });
126 }
127 }
128 }
129 Ok(())
130 }
131
132 /// Cancel the write stream without waiting for server finalization.
133 ///
134 /// Drops the request sender and response receiver immediately.
135 /// The server will detect the stream cancellation and clean up.
136 /// Matches Java's `GrpcBlockingStream.cancel()`.
137 pub async fn cancel(self) {
138 drop(self.request_tx);
139 drop(self.response_rx);
140 debug!(block_id = self.block_id, "cancelled write stream");
141 }
142}
143
144/// Type alias for the authenticated Worker gRPC client.
145type AuthenticatedBlockWorkerClient =
146 BlockWorkerClient<InterceptedService<Channel, ChannelIdInterceptor>>;
147
148/// Client for `BlockWorker` service on a single worker node.
149///
150/// Each `WorkerClient` carries a monotonic `generation` tag assigned by
151/// [`WorkerClientPool`] at construction time. The generation allows callers
152/// that observed a failure on a specific client to request a **single-flight
153/// reconnect** via [`WorkerClientPool::reconnect_if_stale`]: only the first
154/// observer of generation `N` actually re-establishes the TCP+SASL
155/// connection; all concurrent observers with the same (or older) generation
156/// simply receive the already-replaced client. This collapses the
157/// "thundering-herd reconnect" that previously produced hundreds of duplicate
158/// `authentication failed` warnings when a SASL session expired.
159#[derive(Clone)]
160pub struct WorkerClient {
161 inner: AuthenticatedBlockWorkerClient,
162 addr: String,
163 /// Monotonic tag identifying this exact connection instance.
164 ///
165 /// Two clients cached for the same address must have different
166 /// generations; a caller that observes a failure on generation `N` can
167 /// ask the pool to reconnect *only if* generation has not advanced yet.
168 generation: u64,
169 /// Keeps the SASL authentication stream alive for the channel's lifetime.
170 _sasl_guard: std::sync::Arc<Option<SaslStreamGuard>>,
171}
172
173impl WorkerClient {
174 /// Connect to a Goosefs Worker at the given address with authentication.
175 ///
176 /// Authentication is performed according to `config.auth_type`.
177 pub async fn connect(addr: &str, config: &GoosefsConfig) -> Result<Self> {
178 let endpoint = Channel::from_shared(format!("http://{}", addr))
179 .map_err(|e| Error::ConfigError {
180 message: format!("invalid worker endpoint: {}", e),
181 })?
182 .connect_timeout(config.connect_timeout);
183
184 let channel = endpoint.connect().await?;
185
186 // Perform SASL authentication based on the configured auth type
187 let authenticator =
188 ChannelAuthenticator::new(config.auth_type, config.auth_username.clone(), None)
189 .with_auth_timeout(config.auth_timeout);
190
191 let mut auth_channel = authenticator.authenticate(channel).await?;
192 let sasl_guard = auth_channel.take_sasl_guard();
193 debug!(addr = %addr, auth_type = %config.auth_type, "connected to Goosefs Worker");
194
195 Ok(Self {
196 inner: BlockWorkerClient::new(auth_channel.channel),
197 addr: addr.to_string(),
198 generation: 0,
199 _sasl_guard: std::sync::Arc::new(sasl_guard),
200 })
201 }
202
203 /// Connect to a Goosefs Worker with only connect_timeout (backward compatible, NOSASL).
204 ///
205 /// **Deprecated**: Use `connect(addr, config)` instead for proper authentication.
206 pub async fn connect_simple(addr: &str, connect_timeout: Duration) -> Result<Self> {
207 let endpoint = Channel::from_shared(format!("http://{}", addr))
208 .map_err(|e| Error::ConfigError {
209 message: format!("invalid worker endpoint: {}", e),
210 })?
211 .connect_timeout(connect_timeout);
212
213 let channel = endpoint.connect().await?;
214 let interceptor = ChannelIdInterceptor::new(uuid::Uuid::new_v4().to_string());
215 let intercepted = InterceptedService::new(channel, interceptor);
216 debug!(addr = %addr, "connected to Goosefs Worker (no auth)");
217
218 Ok(Self {
219 inner: BlockWorkerClient::new(intercepted),
220 addr: addr.to_string(),
221 generation: 0,
222 _sasl_guard: std::sync::Arc::new(None),
223 })
224 }
225
226 /// Create from an existing tonic channel (useful for testing / channel sharing).
227 ///
228 /// **Note**: This bypasses authentication.
229 pub fn from_channel(channel: Channel, addr: String) -> Self {
230 let interceptor = ChannelIdInterceptor::new("test-no-auth".to_string());
231 let intercepted = InterceptedService::new(channel, interceptor);
232 Self {
233 inner: BlockWorkerClient::new(intercepted),
234 addr,
235 generation: 0,
236 _sasl_guard: std::sync::Arc::new(None),
237 }
238 }
239
240 /// Start a bidirectional streaming ReadBlock RPC.
241 ///
242 /// Returns: (request_sender, response_stream)
243 ///
244 /// The caller sends an initial `ReadRequest` with block_id/offset/length,
245 /// then sends periodic `offset_received` ACKs. The response stream yields
246 /// `ReadResponse` containing `Chunk` data.
247 ///
248 /// When the block is only stored in UFS (e.g. written with `THROUGH` mode),
249 /// `open_ufs_block_options` must be provided so the Worker knows how to
250 /// locate and read the data from the underlying storage.
251 #[instrument(skip(self, open_ufs_block_options), fields(block_id = %block_id, offset = %offset, length = %length))]
252 pub async fn read_block(
253 &self,
254 block_id: i64,
255 offset: i64,
256 length: i64,
257 chunk_size: i64,
258 open_ufs_block_options: Option<OpenUfsBlockOptions>,
259 ) -> Result<(mpsc::Sender<ReadRequest>, Streaming<ReadResponse>)> {
260 let (tx, rx) = mpsc::channel::<ReadRequest>(32);
261
262 // Send the initial read request
263 let initial_request = ReadRequest {
264 block_id: Some(block_id),
265 offset: Some(offset),
266 length: Some(length),
267 chunk_size: Some(chunk_size),
268 open_ufs_block_options,
269 offset_received: None,
270 position_short: None,
271 request_id: None,
272 capability: None,
273 block_size: None,
274 prefetch_window: None,
275 };
276 tx.send(initial_request)
277 .await
278 .map_err(|_| Error::BlockIoError {
279 message: "failed to send initial ReadRequest".to_string(),
280 })?;
281
282 let stream = ReceiverStream::new(rx);
283 let response = self.inner.clone().read_block(stream).await?;
284
285 Ok((tx, response.into_inner()))
286 }
287
288 /// Open a positioned (random-access) block read stream.
289 ///
290 /// Identical to [`read_block`](Self::read_block) but sets `position_short = true` in the
291 /// initial `ReadRequest`, instructing the worker to skip prefetch and
292 /// serve the exact requested byte range.
293 ///
294 /// Used by [`crate::io::reader::GrpcBlockReader::positioned_read`].
295 pub async fn read_block_positioned(
296 &self,
297 block_id: i64,
298 offset: i64,
299 length: i64,
300 chunk_size: i64,
301 open_ufs_block_options: Option<OpenUfsBlockOptions>,
302 ) -> Result<(mpsc::Sender<ReadRequest>, Streaming<ReadResponse>)> {
303 let (tx, rx) = mpsc::channel::<ReadRequest>(32);
304
305 let initial_request = ReadRequest {
306 block_id: Some(block_id),
307 offset: Some(offset),
308 length: Some(length),
309 chunk_size: Some(chunk_size),
310 open_ufs_block_options,
311 offset_received: None,
312 position_short: Some(true), // positioned-read hint to worker
313 request_id: None,
314 capability: None,
315 block_size: None,
316 prefetch_window: None,
317 };
318 tx.send(initial_request)
319 .await
320 .map_err(|_| Error::BlockIoError {
321 message: "failed to send initial positioned ReadRequest".to_string(),
322 })?;
323
324 let stream = ReceiverStream::new(rx);
325 let response = self.inner.clone().read_block(stream).await?;
326
327 Ok((tx, response.into_inner()))
328 }
329
330 /// Start a bidirectional streaming WriteBlock RPC.
331 ///
332 /// Returns a [`WriteBlockHandle`] that manages the background gRPC task.
333 /// The caller sends data chunks through `handle.request_tx`, then calls
334 /// `handle.recv_response()` to get flush acknowledgements.
335 ///
336 /// ## Why a background task?
337 ///
338 /// Goosefs Worker's `WriteBlock` RPC does **not** send HTTP/2 response
339 /// headers until the client sends a `flush` command or closes the stream.
340 /// tonic's `client.write_block(stream).await` waits for response headers
341 /// before resolving, so calling it inline would deadlock — we'd need the
342 /// returned sender to send flush, but we can't get the sender until the
343 /// call resolves.
344 ///
345 /// By spawning the gRPC call in a background task and forwarding responses
346 /// through an mpsc channel, we decouple request sending from response
347 /// receiving.
348 #[instrument(skip(self, options), fields(block_id = %block_id))]
349 pub async fn write_block(
350 &self,
351 block_id: i64,
352 space_to_reserve: i64,
353 options: WriteBlockOptions,
354 ) -> Result<WriteBlockHandle> {
355 let (tx, rx) = mpsc::channel::<WriteRequest>(32);
356
357 // Build the initial write command
358 let initial_command = WriteRequest {
359 value: Some(write_request::Value::Command(WriteRequestCommand {
360 r#type: Some(options.request_type as i32),
361 id: Some(block_id),
362 offset: Some(0),
363 flush: None,
364 create_ufs_file_options: options.create_ufs_file_options,
365 space_to_reserve: Some(space_to_reserve),
366 capability: None,
367 medium_type: None,
368 })),
369 };
370
371 // Build a composite stream: initial command first, then channel messages.
372 let initial_stream = tokio_stream::once(initial_command);
373 let subsequent_stream = ReceiverStream::new(rx);
374 let combined_stream = initial_stream.chain(subsequent_stream);
375
376 // Channel for forwarding server responses from the background task.
377 let (resp_tx, resp_rx) =
378 mpsc::channel::<std::result::Result<WriteResponse, tonic::Status>>(8);
379
380 let mut client = self.inner.clone();
381 let addr = self.addr.clone();
382
383 let task_handle = tokio::spawn(async move {
384 debug!(block_id = block_id, addr = %addr, "WriteBlock gRPC task started");
385
386 // This call blocks until the server sends response headers,
387 // which happens on the first flush or stream close.
388 let call_result = client.write_block(combined_stream).await;
389
390 match call_result {
391 Ok(response) => {
392 let mut stream = response.into_inner();
393 // Forward all server responses to the caller.
394 loop {
395 match stream.message().await {
396 Ok(Some(msg)) => {
397 if resp_tx.send(Ok(msg)).await.is_err() {
398 debug!(block_id = block_id, "response receiver dropped");
399 break;
400 }
401 }
402 Ok(None) => {
403 debug!(block_id = block_id, "server closed response stream");
404 break;
405 }
406 Err(status) => {
407 warn!(block_id = block_id, %status, "server response error");
408 let _ = resp_tx.send(Err(status)).await;
409 break;
410 }
411 }
412 }
413 }
414 Err(status) => {
415 warn!(block_id = block_id, %status, "WriteBlock RPC failed");
416 let _ = resp_tx.send(Err(status)).await;
417 }
418 }
419
420 debug!(block_id = block_id, "WriteBlock gRPC task finished");
421 });
422
423 debug!(block_id = block_id, "WriteBlock handle created");
424
425 Ok(WriteBlockHandle {
426 block_id,
427 request_tx: tx,
428 response_rx: resp_rx,
429 _task_handle: task_handle,
430 })
431 }
432
433 /// The worker address this client is connected to.
434 pub fn addr(&self) -> &str {
435 &self.addr
436 }
437
438 /// The monotonic generation tag assigned by the pool.
439 ///
440 /// Callers should save this value alongside the `WorkerClient` when
441 /// starting an RPC; if the RPC fails with an authentication error they
442 /// pass the saved generation back to
443 /// [`WorkerClientPool::reconnect_if_stale`] to trigger a single-flight
444 /// reconnect (de-duplicating concurrent observers of the same failure).
445 pub fn generation(&self) -> u64 {
446 self.generation
447 }
448}
449
450/// Connection pool for `WorkerClient` instances.
451///
452/// Caches authenticated gRPC channels by worker address, avoiding the overhead
453/// of re-establishing connections and re-authenticating for every block I/O.
454/// Matches Java's `FileSystemContext.acquireBlockWorkerClient()` pattern.
455///
456/// The pool is thread-safe and can be shared across concurrent workers.
457///
458/// ## Single-Flight Reconnect
459///
460/// When a SASL stream silently expires server-side, many concurrent RPCs on
461/// the same cached channel will fail simultaneously with UNAUTHENTICATED.
462/// Without coordination each observer would independently invoke `reconnect`,
463/// producing a "thundering herd" that serialises through the pool's write
464/// lock and wastes CPU/RTT on duplicate TCP+SASL handshakes.
465///
466/// To collapse this herd, each [`WorkerClient`] carries a monotonic
467/// `generation` tag. Callers pass the observed generation back into
468/// [`reconnect_if_stale`](Self::reconnect_if_stale) after an auth failure;
469/// only the **first** observer of a given generation actually performs the
470/// reconnect, all other concurrent observers receive the already-replaced
471/// client. This reduces N concurrent reconnects to exactly 1.
472pub struct WorkerClientPool {
473 /// Cached worker clients keyed by `"host:port"` address.
474 ///
475 /// The stored client carries its own `generation` in-band; readers simply
476 /// clone it and inspect `client.generation()`.
477 clients: RwLock<HashMap<String, WorkerClient>>,
478 /// Per-address async mutex guarding the reconnect critical section.
479 ///
480 /// Separated from `clients` so the reconnect handshake (which performs
481 /// network I/O) does not hold the clients-map write lock. Acquiring this
482 /// mutex for one address does not block other addresses' reconnects.
483 reconnect_locks: RwLock<HashMap<String, Arc<AsyncMutex<()>>>>,
484 /// Monotonic counter used to hand out a unique `generation` for every
485 /// freshly-created `WorkerClient`.
486 next_generation: AtomicU64,
487 /// Config used to create new connections.
488 config: GoosefsConfig,
489}
490
491impl WorkerClientPool {
492 /// Create a new empty connection pool.
493 pub fn new(config: GoosefsConfig) -> Self {
494 Self {
495 clients: RwLock::new(HashMap::new()),
496 reconnect_locks: RwLock::new(HashMap::new()),
497 // Start generations at 1 so `0` (the default on constructed-but-
498 // never-pooled clients) is always "stale" relative to any pooled
499 // client — this makes `reconnect_if_stale(addr, 0)` always force
500 // a fresh connection when needed.
501 next_generation: AtomicU64::new(1),
502 config,
503 }
504 }
505
506 /// Acquire a `WorkerClient` for the given address.
507 ///
508 /// Returns a cached client if one exists, otherwise creates a new connection.
509 /// The tonic `Channel` supports multiplexing, so a single cached client can
510 /// handle multiple concurrent RPCs.
511 pub async fn acquire(&self, addr: &str) -> Result<WorkerClient> {
512 // Fast path: check read lock first
513 {
514 let cache = self.clients.read().await;
515 if let Some(client) = cache.get(addr) {
516 debug!(addr = %addr, generation = client.generation, "reusing cached WorkerClient");
517 return Ok(client.clone());
518 }
519 }
520
521 // Slow path: create new connection under write lock
522 let mut cache = self.clients.write().await;
523 // Double-check after acquiring write lock (another task may have inserted)
524 if let Some(client) = cache.get(addr) {
525 return Ok(client.clone());
526 }
527
528 debug!(addr = %addr, "creating new WorkerClient for pool");
529 let mut client = WorkerClient::connect(addr, &self.config).await?;
530 client.generation = self.next_generation.fetch_add(1, Ordering::Relaxed);
531 cache.insert(addr.to_string(), client.clone());
532 Ok(client)
533 }
534
535 /// Remove a worker from the pool (e.g., after a connection failure).
536 ///
537 /// The next `acquire()` call for this address will create a fresh connection.
538 pub async fn invalidate(&self, addr: &str) {
539 let mut cache = self.clients.write().await;
540 if cache.remove(addr).is_some() {
541 debug!(addr = %addr, "invalidated WorkerClient from pool");
542 }
543 }
544
545 /// Get (or lazily create) the per-address reconnect mutex.
546 async fn reconnect_lock_for(&self, addr: &str) -> Arc<AsyncMutex<()>> {
547 {
548 let locks = self.reconnect_locks.read().await;
549 if let Some(m) = locks.get(addr) {
550 return Arc::clone(m);
551 }
552 }
553 let mut locks = self.reconnect_locks.write().await;
554 Arc::clone(
555 locks
556 .entry(addr.to_string())
557 .or_insert_with(|| Arc::new(AsyncMutex::new(()))),
558 )
559 }
560
561 /// **Single-flight reconnect**: invalidate + reconnect only if the
562 /// currently cached client's generation still matches `stale_generation`.
563 ///
564 /// This is the preferred recovery path on authentication failure. The
565 /// caller passes the `generation()` of the client that just failed;
566 /// because every `WorkerClient` carries a unique monotonic generation
567 /// allocated by this pool:
568 ///
569 /// - If another concurrent task has **already** reconnected in response
570 /// to the same underlying SASL expiry, the cached generation will have
571 /// advanced past `stale_generation` and this call returns the
572 /// already-replaced client **without** performing another
573 /// TCP+SASL handshake.
574 /// - Otherwise, this call performs exactly one reconnect under the
575 /// per-address mutex.
576 ///
577 /// Net effect: N concurrent `AuthenticationFailed` observers on the
578 /// same channel trigger exactly **one** reconnect instead of N.
579 pub async fn reconnect_if_stale(
580 &self,
581 addr: &str,
582 stale_generation: u64,
583 ) -> Result<WorkerClient> {
584 // Take the per-address reconnect mutex. Concurrent callers for the
585 // same address serialise here; callers for *different* addresses do
586 // not block each other.
587 let lock = self.reconnect_lock_for(addr).await;
588 let _guard = lock.lock().await;
589
590 // Under the mutex, re-check the cache. If another task already
591 // replaced the stale client while we were queuing, skip the
592 // reconnect entirely.
593 {
594 let cache = self.clients.read().await;
595 if let Some(client) = cache.get(addr) {
596 if client.generation > stale_generation {
597 debug!(
598 addr = %addr,
599 observed = stale_generation,
600 current = client.generation,
601 "reconnect coalesced — another task already refreshed this channel"
602 );
603 return Ok(client.clone());
604 }
605 }
606 }
607
608 // We are the designated reconnect-er: drop the stale entry, then
609 // build and install a new one.
610 debug!(
611 addr = %addr,
612 stale_generation = stale_generation,
613 "performing single-flight reconnect"
614 );
615 {
616 let mut cache = self.clients.write().await;
617 cache.remove(addr);
618 }
619 let mut fresh = WorkerClient::connect(addr, &self.config).await?;
620 fresh.generation = self.next_generation.fetch_add(1, Ordering::Relaxed);
621 {
622 let mut cache = self.clients.write().await;
623 cache.insert(addr.to_string(), fresh.clone());
624 }
625 debug!(
626 addr = %addr,
627 new_generation = fresh.generation,
628 "single-flight reconnect installed fresh WorkerClient"
629 );
630 Ok(fresh)
631 }
632
633 /// Invalidate a cached worker connection and immediately reconnect.
634 ///
635 /// **Prefer [`reconnect_if_stale`](Self::reconnect_if_stale) whenever the
636 /// caller holds a reference to the failing `WorkerClient`** — it
637 /// deduplicates concurrent reconnects triggered by the same underlying
638 /// SASL expiry.
639 ///
640 /// This unconditional variant is kept for paths where the caller does
641 /// not know the generation of the failing client (e.g. a stand-alone
642 /// `connect()` failure that never produced a `WorkerClient`). It
643 /// acquires the same per-address reconnect mutex so it still coalesces
644 /// against any in-flight `reconnect_if_stale`.
645 pub async fn reconnect(&self, addr: &str) -> Result<WorkerClient> {
646 // Use `u64::MAX` as "stale" so `reconnect_if_stale` always proceeds
647 // with the handshake (current generation can never exceed MAX).
648 // This still passes through the per-address mutex so concurrent
649 // callers on the same address share a single handshake.
650 self.reconnect_if_stale(addr, u64::MAX).await
651 }
652
653 /// Create a new pool wrapped in `Arc` for shared ownership.
654 pub fn new_shared(config: GoosefsConfig) -> Arc<Self> {
655 Arc::new(Self::new(config))
656 }
657
658 // ── Test-only helpers ────────────────────────────────────────────
659 //
660 // These helpers are gated on `cfg(test)` so downstream code cannot
661 // accidentally inject bypass-auth clients into the pool. They exist
662 // purely to let the unit tests in this module drive the single-flight
663 // reconnect logic without needing a live Worker process to handshake
664 // against.
665
666 /// Manually insert a client with a specific `generation` into the
667 /// pool for testing. Returns the previously-cached client, if any.
668 #[cfg(test)]
669 async fn test_install(&self, addr: &str, mut client: WorkerClient) -> Option<WorkerClient> {
670 client.generation = self.next_generation.fetch_add(1, Ordering::Relaxed);
671 let mut cache = self.clients.write().await;
672 cache.insert(addr.to_string(), client)
673 }
674
675 /// Snapshot the current cached generation for `addr` (if any).
676 #[cfg(test)]
677 async fn test_current_generation(&self, addr: &str) -> Option<u64> {
678 let cache = self.clients.read().await;
679 cache.get(addr).map(|c| c.generation)
680 }
681}
682
683#[cfg(test)]
684mod tests {
685 use super::*;
686 use tonic::transport::Channel;
687
688 /// Fabricate a `WorkerClient` from a *never-connected* channel. The
689 /// client is fully usable for anything that only touches the in-memory
690 /// struct (addr/generation lookups, clone, drop), which is all the
691 /// coalesce tests need.
692 fn fake_client(addr: &str) -> WorkerClient {
693 // `Channel::from_static` is synchronous and does not open a TCP
694 // connection; any actual RPC on this channel would fail but the
695 // tests below never issue one.
696 let channel = Channel::from_static("http://127.0.0.1:1").connect_lazy();
697 WorkerClient::from_channel(channel, addr.to_string())
698 }
699
700 #[tokio::test]
701 async fn test_reconnect_if_stale_coalesces_when_generation_advanced() {
702 // Scenario: generation 5 is cached. Caller A "observes" a failure
703 // on gen 5 and calls reconnect_if_stale(5). Before it enters the
704 // critical section, caller B has already replaced gen 5 with gen 6
705 // (simulated by manually bumping via test_install). Caller A must
706 // NOT trigger a second reconnect — it should return gen 6 as-is.
707 let pool = WorkerClientPool::new(GoosefsConfig::new("127.0.0.1:9200"));
708 let addr = "test-worker:9203";
709
710 // Install a gen-1 client, then another gen-2 client (simulating
711 // "someone else already reconnected").
712 pool.test_install(addr, fake_client(addr)).await;
713 let gen_before = pool.test_current_generation(addr).await.unwrap();
714 pool.test_install(addr, fake_client(addr)).await;
715 let gen_after = pool.test_current_generation(addr).await.unwrap();
716 assert!(gen_after > gen_before);
717
718 // Caller passes the *old* generation — pool must short-circuit and
719 // NOT call WorkerClient::connect (which would fail against a
720 // non-existent host and fail the test).
721 let result = pool.reconnect_if_stale(addr, gen_before).await;
722 assert!(
723 result.is_ok(),
724 "coalesced reconnect must short-circuit without network I/O, got {:?}",
725 result.err()
726 );
727 let returned = result.unwrap();
728 assert_eq!(
729 returned.generation(),
730 gen_after,
731 "caller must receive the already-replaced generation"
732 );
733 assert_eq!(
734 pool.test_current_generation(addr).await,
735 Some(gen_after),
736 "cached generation must not advance for a coalesced caller"
737 );
738 }
739
740 #[tokio::test]
741 async fn test_reconnect_locks_are_per_address() {
742 // Acquiring the reconnect lock for addr-A must not block acquiring
743 // the lock for addr-B. Without per-address locks, unrelated worker
744 // reconnects would serialise through one global mutex.
745 let pool = WorkerClientPool::new(GoosefsConfig::new("127.0.0.1:9200"));
746 let lock_a = pool.reconnect_lock_for("worker-a:9203").await;
747 let lock_b = pool.reconnect_lock_for("worker-b:9203").await;
748
749 // Hold A, must still be able to grab B immediately.
750 let guard_a = lock_a.lock().await;
751 let guard_b = tokio::time::timeout(std::time::Duration::from_millis(50), lock_b.lock())
752 .await
753 .expect("lock for different address must not be blocked");
754 drop(guard_b);
755 drop(guard_a);
756 }
757
758 #[tokio::test]
759 async fn test_generation_is_monotonic_across_installs() {
760 let pool = WorkerClientPool::new(GoosefsConfig::new("127.0.0.1:9200"));
761 let addr = "w:9203";
762
763 pool.test_install(addr, fake_client(addr)).await;
764 let g1 = pool.test_current_generation(addr).await.unwrap();
765
766 pool.test_install(addr, fake_client(addr)).await;
767 let g2 = pool.test_current_generation(addr).await.unwrap();
768
769 pool.test_install(addr, fake_client(addr)).await;
770 let g3 = pool.test_current_generation(addr).await.unwrap();
771
772 assert!(g1 < g2, "gen {} not less than {}", g1, g2);
773 assert!(g2 < g3, "gen {} not less than {}", g2, g3);
774 }
775}