Skip to main content

fluss/client/write/
accumulator.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::client::broadcast;
19use crate::client::write::IdempotenceManager;
20use crate::client::write::batch::WriteBatch::{ArrowLog, Kv};
21use crate::client::write::batch::{ArrowLogWriteBatch, KvWriteBatch, WriteBatch};
22use crate::client::{LogWriteRecord, Record, ResultHandle, WriteRecord};
23use crate::cluster::{BucketLocation, Cluster, ServerNode};
24use crate::config::Config;
25use crate::error::{Error, Result};
26use crate::metadata::{PhysicalTablePath, TableBucket};
27use crate::record::{NO_BATCH_SEQUENCE, NO_WRITER_ID};
28use crate::util::current_time_ms;
29use crate::{BucketId, PartitionId, TableId};
30use dashmap::DashMap;
31use parking_lot::{Condvar, Mutex, RwLock};
32use std::collections::{HashMap, HashSet, VecDeque};
33use std::sync::Arc;
34use std::sync::atomic::{AtomicBool, AtomicI32, AtomicI64, AtomicUsize, Ordering};
35use std::time::{Duration, Instant};
36use tokio::sync::Notify;
37
/// Byte-counting semaphore that blocks producers when total buffered memory
/// exceeds the configured limit. Matches Java's `LazyMemorySegmentPool` behavior.
///
/// Bytes are reserved through `acquire`, which returns a `MemoryPermit`;
/// dropping that permit hands the bytes back via `release`.
///
/// TODO: Replace `notify_all()` with per-waiter FIFO signaling (Java uses per-request
/// Condition objects in a Deque) to avoid thundering herd under high contention.
///
/// TODO: Track actual batch memory usage instead of reserving a fixed `writer_batch_size`
/// per batch. This over-counts when batches don't fill completely, reducing effective
/// throughput. Requires tighter coupling with batch internals.
pub(crate) struct MemoryLimiter {
    /// Number of bytes currently reserved by outstanding permits.
    state: Mutex<usize>,
    /// Condition variable producers wait on while the request does not fit;
    /// notified on every `release` and on `close`.
    cond: Condvar,
    /// Upper bound, in bytes, on the sum of all reservations.
    max_memory: usize,
    /// Longest time a single `acquire` call may spend waiting for memory.
    wait_timeout: Duration,
    /// Set once by `close`; `acquire` fails with `WriterClosed` afterwards.
    closed: AtomicBool,
    /// Number of threads currently parked inside `acquire`.
    waiting_count: AtomicUsize,
}
55
56impl MemoryLimiter {
57    pub fn new(max_memory: usize, wait_timeout: Duration) -> Self {
58        Self {
59            state: Mutex::new(0),
60            cond: Condvar::new(),
61            max_memory,
62            wait_timeout,
63            closed: AtomicBool::new(false),
64            waiting_count: AtomicUsize::new(0),
65        }
66    }
67
68    /// Try to acquire `size` bytes. Blocks until memory is available,
69    /// the timeout expires, or the limiter is closed.
70    /// Returns a `MemoryPermit` on success.
71    pub fn acquire(self: &Arc<Self>, size: usize) -> Result<MemoryPermit> {
72        if self.closed.load(Ordering::Acquire) {
73            return Err(Error::WriterClosed {
74                message: "Memory limiter is closed".to_string(),
75            });
76        }
77
78        if size > self.max_memory {
79            return Err(Error::IllegalArgument {
80                message: format!(
81                    "Batch size {} exceeds total buffer memory limit {}",
82                    size, self.max_memory
83                ),
84            });
85        }
86
87        let mut used = self.state.lock();
88        let deadline = Instant::now() + self.wait_timeout;
89        while *used + size > self.max_memory {
90            self.waiting_count.fetch_add(1, Ordering::Relaxed);
91            let result = self.cond.wait_until(&mut used, deadline);
92            self.waiting_count.fetch_sub(1, Ordering::Relaxed);
93
94            if self.closed.load(Ordering::Acquire) {
95                return Err(Error::WriterClosed {
96                    message: "Memory limiter is closed".to_string(),
97                });
98            }
99            if result.timed_out() && *used + size > self.max_memory {
100                return Err(Error::BufferExhausted {
101                    message: format!(
102                        "Failed to allocate {} bytes for write batch within {}ms. \
103                         {} of {} bytes in use, {} threads waiting.",
104                        size,
105                        self.wait_timeout.as_millis(),
106                        *used,
107                        self.max_memory,
108                        self.waiting_count.load(Ordering::Relaxed),
109                    ),
110                });
111            }
112        }
113
114        *used += size;
115        Ok(MemoryPermit {
116            limiter: Arc::clone(self),
117            size,
118        })
119    }
120
121    fn release(&self, size: usize) {
122        let mut used = self.state.lock();
123        *used = used.saturating_sub(size);
124        self.cond.notify_all();
125    }
126
127    /// Returns true if any producers are currently blocked waiting for memory.
128    /// Used by `ready()` to mark all batches as immediately sendable when
129    /// memory is exhausted (matching Java's `exhausted` flag).
130    pub fn has_waiters(&self) -> bool {
131        self.waiting_count.load(Ordering::Relaxed) > 0
132    }
133
134    /// Mark the limiter as closed and wake all blocked producers.
135    fn close(&self) {
136        self.closed.store(true, Ordering::Release);
137        self.cond.notify_all();
138    }
139}
140
/// RAII guard that releases memory back to the `MemoryLimiter` on drop.
///
/// Created by `MemoryLimiter::acquire`; the reservation lives exactly as long
/// as this value does.
pub(crate) struct MemoryPermit {
    /// Limiter the reserved bytes are returned to.
    limiter: Arc<MemoryLimiter>,
    /// Number of bytes this permit holds.
    size: usize,
}
146
147impl std::fmt::Debug for MemoryPermit {
148    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
149        f.debug_struct("MemoryPermit")
150            .field("size", &self.size)
151            .finish_non_exhaustive()
152    }
153}
154
155impl Drop for MemoryPermit {
156    fn drop(&mut self) {
157        if self.size > 0 {
158            self.limiter.release(self.size);
159        }
160    }
161}
162
// Snapshot of one table's per-bucket batch queues: each entry pairs a bucket id
// with a shared handle to that bucket's deque of pending write batches.
type BucketBatches = Vec<(BucketId, Arc<Mutex<VecDeque<WriteBatch>>>)>;
165
/// Buffers write records into per-bucket batches until the sender drains them.
#[allow(dead_code)]
pub struct RecordAccumulator {
    /// Writer configuration; `writer_batch_size` is read when sizing new batches.
    config: Config,
    /// Pending batches, keyed by physical table path and then by bucket id.
    write_batches: DashMap<Arc<PhysicalTablePath>, BucketAndWriteBatches>,
    // batch_id -> (complete callback, memory permit)
    incomplete_batches: RwLock<HashMap<i64, (ResultHandle, MemoryPermit)>>,
    /// Copy of `config.writer_batch_timeout_ms`; a batch that has waited at
    /// least this long is considered sendable.
    batch_timeout_ms: i64,
    /// Set by `close()`; makes every batch immediately sendable.
    closed: AtomicBool,
    /// NOTE(review): presumably backs `flush_in_progress()`, which is not
    /// visible in this part of the file - confirm.
    flushes_in_progress: AtomicI32,
    /// NOTE(review): not referenced by any code visible in this part of the file.
    appends_in_progress: i32,
    /// Per node id: bucket index at which the last drain for that node stopped.
    nodes_drain_index: Mutex<HashMap<i32, usize>>,
    /// Counter used to hand out ids for newly created batches.
    batch_id: AtomicI64,
    /// Tracks writer id, sequences and in-flight batches for idempotent writes.
    idempotence_manager: Arc<IdempotenceManager>,
    /// Bounds the total memory reserved by batches that are not yet completed.
    memory_limiter: Arc<MemoryLimiter>,
    /// Wakes the sender task when new batches are created or existing batches
    /// become full, so the sender can drain them immediately instead of waiting
    /// for its next poll cycle. This is the Rust equivalent of Java's
    /// `Sender.wakeup()` / Kafka's `RecordAccumulator.wakeup()`.
    sender_wakeup: Notify,
}
186
187impl RecordAccumulator {
188    pub fn new(config: Config, idempotence_manager: Arc<IdempotenceManager>) -> Self {
189        let batch_timeout_ms = config.writer_batch_timeout_ms;
190        let memory_limiter = Arc::new(MemoryLimiter::new(
191            config.writer_buffer_memory_size,
192            Duration::from_millis(config.writer_buffer_wait_timeout_ms),
193        ));
194        RecordAccumulator {
195            config,
196            write_batches: Default::default(),
197            incomplete_batches: Default::default(),
198            batch_timeout_ms,
199            closed: Default::default(),
200            flushes_in_progress: Default::default(),
201            appends_in_progress: Default::default(),
202            nodes_drain_index: Default::default(),
203            batch_id: Default::default(),
204            idempotence_manager,
205            memory_limiter,
206            sender_wakeup: Notify::new(),
207        }
208    }
209
210    fn try_append(
211        &self,
212        record: &WriteRecord,
213        dq: &mut VecDeque<WriteBatch>,
214    ) -> Result<Option<RecordAppendResult>> {
215        let dq_size = dq.len();
216        if let Some(last_batch) = dq.back_mut() {
217            return if let Some(result_handle) = last_batch.try_append(record)? {
218                Ok(Some(RecordAppendResult::new(
219                    result_handle,
220                    dq_size > 1 || last_batch.is_closed(),
221                    false,
222                    false,
223                )))
224            } else {
225                Ok(None)
226            };
227        }
228        Ok(None)
229    }
230
231    fn append_new_batch(
232        &self,
233        cluster: &Cluster,
234        record: &WriteRecord,
235        dq: &mut VecDeque<WriteBatch>,
236        permit: MemoryPermit,
237        alloc_size: usize,
238    ) -> Result<RecordAppendResult> {
239        let physical_table_path = &record.physical_table_path;
240        let table_path = physical_table_path.get_table_path();
241        let table_info = cluster.get_table(table_path)?;
242        let arrow_compression_info = table_info.get_table_config().get_arrow_compression_info()?;
243        let row_type = &table_info.row_type;
244
245        let schema_id = table_info.schema_id;
246
247        let mut batch: WriteBatch = match record.record() {
248            Record::Log(_) => ArrowLog(ArrowLogWriteBatch::new(
249                self.batch_id.fetch_add(1, Ordering::Relaxed),
250                Arc::clone(physical_table_path),
251                schema_id,
252                arrow_compression_info,
253                row_type,
254                current_time_ms(),
255                matches!(&record.record, Record::Log(LogWriteRecord::RecordBatch(_))),
256            )?),
257            Record::Kv(kv_record) => Kv(KvWriteBatch::new(
258                self.batch_id.fetch_add(1, Ordering::Relaxed),
259                Arc::clone(physical_table_path),
260                schema_id,
261                alloc_size,
262                record.write_format.to_kv_format()?,
263                kv_record.target_columns.clone(),
264                current_time_ms(),
265            )),
266        };
267
268        let batch_id = batch.batch_id();
269
270        let result_handle = batch
271            .try_append(record)?
272            .expect("must append to a new batch");
273
274        let batch_is_closed = batch.is_closed();
275        dq.push_back(batch);
276
277        self.incomplete_batches
278            .write()
279            .insert(batch_id, (result_handle.clone(), permit));
280        Ok(RecordAppendResult::new(
281            result_handle,
282            dq.len() > 1 || batch_is_closed,
283            true,
284            false,
285        ))
286    }
287
288    pub fn append(
289        &self,
290        record: &WriteRecord<'_>,
291        bucket_id: BucketId,
292        cluster: &Cluster,
293        abort_if_batch_full: bool,
294    ) -> Result<RecordAppendResult> {
295        let physical_table_path = &record.physical_table_path;
296        let table_path = physical_table_path.get_table_path();
297        let table_info = cluster.get_table(table_path)?;
298        let is_partitioned_table = table_info.is_partitioned();
299
300        let partition_id = if is_partitioned_table {
301            cluster.get_partition_id(physical_table_path)
302        } else {
303            None
304        };
305
306        let dq = {
307            let mut binding = self
308                .write_batches
309                .entry(Arc::clone(physical_table_path))
310                .or_insert_with(|| BucketAndWriteBatches {
311                    table_id: table_info.table_id,
312                    is_partitioned_table,
313                    partition_id,
314                    batches: Default::default(),
315                });
316            let bucket_and_batches = binding.value_mut();
317            bucket_and_batches
318                .batches
319                .entry(bucket_id)
320                .or_insert_with(|| Arc::new(Mutex::new(VecDeque::new())))
321                .clone()
322        };
323
324        let mut dq_guard = dq.lock();
325        if let Some(append_result) = self.try_append(record, &mut dq_guard)? {
326            return Ok(append_result);
327        }
328
329        if abort_if_batch_full {
330            return Ok(RecordAppendResult::new_without_result_handle(
331                true, false, true,
332            ));
333        }
334
335        // Drop dq lock before blocking on memory to prevent deadlock:
336        // producer holds dq + blocks on memory, while sender needs dq to drain.
337        drop(dq_guard);
338
339        let batch_size = self.config.writer_batch_size as usize;
340        let record_size = record.estimated_record_size();
341        let alloc_size = batch_size.max(record_size);
342        let permit = self.memory_limiter.acquire(alloc_size)?;
343
344        // Re-acquire dq lock after memory is available
345        let mut dq_guard = dq.lock();
346        // Re-try: another thread may have created a batch while we waited
347        if let Some(append_result) = self.try_append(record, &mut dq_guard)? {
348            return Ok(append_result); // permit drops here, memory released
349        }
350
351        self.append_new_batch(cluster, record, &mut dq_guard, permit, alloc_size)
352    }
353
354    pub fn ready(&self, cluster: &Arc<Cluster>) -> Result<ReadyCheckResult> {
355        // Snapshot just the Arcs we need, avoiding cloning the entire BucketAndWriteBatches struct
356        let entries: Vec<(Arc<PhysicalTablePath>, Option<PartitionId>, BucketBatches)> = self
357            .write_batches
358            .iter()
359            .map(|entry| {
360                let physical_table_path = Arc::clone(entry.key());
361                let partition_id = entry.value().partition_id;
362                let bucket_batches: Vec<_> = entry
363                    .value()
364                    .batches
365                    .iter()
366                    .map(|(bucket_id, batch_arc)| (*bucket_id, batch_arc.clone()))
367                    .collect();
368                (physical_table_path, partition_id, bucket_batches)
369            })
370            .collect();
371
372        let mut ready_nodes = HashSet::new();
373        let mut next_ready_check_delay_ms = self.batch_timeout_ms;
374        let mut unknown_leader_tables = HashSet::new();
375        let exhausted = self.memory_limiter.has_waiters();
376
377        for (physical_table_path, mut partition_id, bucket_batches) in entries {
378            next_ready_check_delay_ms = self.bucket_ready(
379                &physical_table_path,
380                physical_table_path.get_partition_name().is_some(),
381                &mut partition_id,
382                bucket_batches,
383                &mut ready_nodes,
384                &mut unknown_leader_tables,
385                cluster,
386                next_ready_check_delay_ms,
387                exhausted,
388            )?
389        }
390
391        Ok(ReadyCheckResult {
392            ready_nodes,
393            next_ready_check_delay_ms,
394            unknown_leader_tables,
395        })
396    }
397
398    #[allow(clippy::too_many_arguments)]
399    fn bucket_ready(
400        &self,
401        physical_table_path: &Arc<PhysicalTablePath>,
402        is_partitioned_table: bool,
403        partition_id: &mut Option<PartitionId>,
404        bucket_batches: BucketBatches,
405        ready_nodes: &mut HashSet<ServerNode>,
406        unknown_leader_tables: &mut HashSet<Arc<PhysicalTablePath>>,
407        cluster: &Cluster,
408        next_ready_check_delay_ms: i64,
409        exhausted: bool,
410    ) -> Result<i64> {
411        let mut next_delay = next_ready_check_delay_ms;
412
413        // First check this table has partitionId.
414        if is_partitioned_table && partition_id.is_none() {
415            let partition_id = cluster.get_partition_id(physical_table_path);
416
417            if partition_id.is_some() {
418                // Update the cached partition_id
419                if let Some(mut entry) = self.write_batches.get_mut(physical_table_path) {
420                    entry.partition_id = partition_id;
421                }
422            } else {
423                log::debug!(
424                    "Partition does not exist for {}, bucket will not be set to ready",
425                    physical_table_path.as_ref()
426                );
427
428                // TODO: we shouldn't add unready partitions to unknownLeaderTables,
429                // because it cases PartitionNotExistException later
430                unknown_leader_tables.insert(Arc::clone(physical_table_path));
431                return Ok(next_delay);
432            }
433        }
434
435        for (bucket_id, batch) in bucket_batches {
436            let batch_guard = batch.lock();
437            if batch_guard.is_empty() {
438                continue;
439            }
440
441            let batch = batch_guard.front().unwrap();
442            let waited_time_ms = batch.waited_time_ms(current_time_ms());
443            let deque_size = batch_guard.len();
444            let full = deque_size > 1 || batch.is_closed();
445            let table_bucket = cluster.get_table_bucket(physical_table_path, bucket_id)?;
446            if let Some(leader) = cluster.leader_for(&table_bucket) {
447                next_delay = self.batch_ready(
448                    leader,
449                    waited_time_ms,
450                    full,
451                    exhausted,
452                    ready_nodes,
453                    next_delay,
454                );
455            } else {
456                unknown_leader_tables.insert(Arc::clone(physical_table_path));
457            }
458        }
459        Ok(next_delay)
460    }
461
462    fn batch_ready(
463        &self,
464        leader: &ServerNode,
465        waited_time_ms: i64,
466        full: bool,
467        exhausted: bool,
468        ready_nodes: &mut HashSet<ServerNode>,
469        next_ready_check_delay_ms: i64,
470    ) -> i64 {
471        if !ready_nodes.contains(leader) {
472            let expired = waited_time_ms >= self.batch_timeout_ms;
473            let sendable = full
474                || expired
475                || exhausted
476                || self.closed.load(Ordering::Acquire)
477                || self.flush_in_progress();
478
479            if sendable {
480                ready_nodes.insert(leader.clone());
481            } else {
482                let time_left_ms = self.batch_timeout_ms.saturating_sub(waited_time_ms);
483                return next_ready_check_delay_ms.min(time_left_ms);
484            }
485        }
486        next_ready_check_delay_ms
487    }
488
489    pub fn drain(
490        &self,
491        cluster: Arc<Cluster>,
492        nodes: &HashSet<ServerNode>,
493        max_size: i32,
494    ) -> Result<HashMap<i32, Vec<ReadyWriteBatch>>> {
495        if nodes.is_empty() {
496            return Ok(HashMap::new());
497        }
498        let mut batches = HashMap::new();
499        for node in nodes {
500            let ready = self.drain_batches_for_one_node(&cluster, node, max_size)?;
501            if !ready.is_empty() {
502                batches.insert(node.id(), ready);
503            }
504        }
505
506        Ok(batches)
507    }
508
509    /// Matches Java's `shouldStopDrainBatchesForBucket`. Returns true if
510    /// this bucket should be skipped during drain.
511    fn should_stop_drain_batches_for_bucket(
512        &self,
513        first: &WriteBatch,
514        table_bucket: &TableBucket,
515    ) -> bool {
516        if !self.idempotence_manager.is_enabled() {
517            return false;
518        }
519        if !self.idempotence_manager.is_writer_id_valid() {
520            return true;
521        }
522
523        // Use batch_id comparison instead of sequence comparison. After
524        // handle_failed_batch adjusts InFlightBatch sequences, the WriteBatch's
525        // stored sequence may be stale (re_enqueue syncs it, but this is more
526        // robust). Java can compare sequences because resetWriterState mutates
527        // the batch directly; Rust uses lightweight InFlightBatch proxies.
528        let is_first_in_flight = self.idempotence_manager.in_flight_count(table_bucket) == 0
529            || (first.has_batch_sequence()
530                && self
531                    .idempotence_manager
532                    .is_first_in_flight_batch(table_bucket, first.batch_id()));
533
534        if is_first_in_flight {
535            return false;
536        }
537
538        if !first.has_batch_sequence() {
539            // Fresh batch: respect max in-flight limit
540            !self
541                .idempotence_manager
542                .can_send_more_requests(table_bucket)
543        } else {
544            // Re-enqueued batch that's NOT first in-flight: stop
545            true
546        }
547    }
548
    /// Collects sendable batches for all buckets led by `node`.
    ///
    /// Buckets are visited round-robin, starting at the index stored for this
    /// node by the previous drain. At most the head batch of each bucket is
    /// taken, and the walk ends after one full cycle or as soon as adding the
    /// next batch would exceed `max_size` (unless nothing has been collected
    /// yet). Fresh batches get their writer id and sequence assigned here when
    /// idempotence is active.
    fn drain_batches_for_one_node(
        &self,
        cluster: &Cluster,
        node: &ServerNode,
        max_size: i32,
    ) -> Result<Vec<ReadyWriteBatch>> {
        // Running total of the estimated bytes collected so far.
        let mut size: usize = 0;
        let buckets = self.get_all_buckets_in_current_node(node, cluster);
        let mut ready = Vec::new();

        if buckets.is_empty() {
            return Ok(ready);
        }

        // Resume where the previous drain for this node stopped; the modulo
        // keeps the index valid if the bucket list has shrunk since then.
        let start = {
            let mut nodes_drain_index_guard = self.nodes_drain_index.lock();
            let drain_index = nodes_drain_index_guard.entry(node.id()).or_insert(0);
            *drain_index % buckets.len()
        };

        let mut current_index = start;
        let mut last_processed_index;

        loop {
            let bucket = &buckets[current_index];
            let table_path = bucket.physical_table_path();
            let table_bucket = bucket.table_bucket.clone();
            // Advance first so that every exit path below (including
            // `continue`) sees the index of the next bucket to visit.
            last_processed_index = current_index;
            current_index = (current_index + 1) % buckets.len();

            let deque = self
                .write_batches
                .get(table_path)
                .and_then(|bucket_and_write_batches| {
                    bucket_and_write_batches
                        .batches
                        .get(&table_bucket.bucket_id())
                        .cloned()
                });

            if let Some(deque) = deque {
                let mut maybe_batch = None;
                {
                    // The deque lock is held only while deciding on and
                    // popping the head batch.
                    let mut batch_lock = deque.lock();
                    if !batch_lock.is_empty() {
                        let first_batch = batch_lock.front().unwrap();

                        if size + first_batch.estimated_size_in_bytes() > max_size as usize
                            && !ready.is_empty()
                        {
                            // there is a rare case that a single batch size is larger than the request size
                            // due to compression; in this case we will still eventually send this batch in
                            // a single request.
                            break;
                        }

                        // Improvement: `continue` instead of `break` to skip
                        // only this bucket, not all buckets for the node.
                        if self.should_stop_drain_batches_for_bucket(first_batch, &table_bucket) {
                            // `continue` bypasses the end-of-loop check, so
                            // the full-cycle test has to be repeated here.
                            if current_index == start {
                                break;
                            }
                            continue;
                        }

                        maybe_batch = Some(batch_lock.pop_front().unwrap());
                    }
                }

                if let Some(ref mut batch) = maybe_batch {
                    // Assign writer state to fresh batches (matching Java's drain loop)
                    let writer_id = if self.idempotence_manager.is_enabled() {
                        self.idempotence_manager.writer_id()
                    } else {
                        NO_WRITER_ID
                    };
                    // Batches that already carry a sequence (re-enqueued ones)
                    // keep the state they were given earlier.
                    if writer_id != NO_WRITER_ID && !batch.has_batch_sequence() {
                        self.idempotence_manager
                            .maybe_update_writer_id(&table_bucket);
                        let seq = self
                            .idempotence_manager
                            .next_sequence_and_increment(&table_bucket);
                        batch.set_writer_state(writer_id, seq);
                        self.idempotence_manager.add_in_flight_batch(
                            &table_bucket,
                            seq,
                            batch.batch_id(),
                        );
                    }
                }

                if let Some(mut batch) = maybe_batch {
                    let current_batch_size = batch.estimated_size_in_bytes();
                    size += current_batch_size;

                    // mark the batch as drained.
                    batch.drained(current_time_ms());
                    ready.push(ReadyWriteBatch {
                        table_bucket,
                        write_batch: batch,
                    });
                }
            }
            // One full cycle over the node's buckets has been completed.
            if current_index == start {
                break;
            }
        }

        // Store the last processed index to maintain round-robin fairness
        {
            let mut nodes_drain_index_guard = self.nodes_drain_index.lock();
            nodes_drain_index_guard.insert(node.id(), last_processed_index);
        }

        Ok(ready)
    }
665
666    pub fn remove_incomplete_batches(&self, batch_id: i64) {
667        self.incomplete_batches.write().remove(&batch_id);
668    }
669
670    pub fn re_enqueue(&self, mut ready_write_batch: ReadyWriteBatch) {
671        ready_write_batch.write_batch.re_enqueued();
672
673        // Sync WriteBatch sequence with IdempotenceManager's adjusted sequence.
674        // When handle_failed_batch adjusts InFlightBatch sequences (after a prior
675        // batch fails), the WriteBatch is not updated (unlike Java which calls
676        // resetWriterState on the actual batch). We must sync here so that:
677        // 1. should_stop_drain_batches_for_bucket comparisons work correctly
678        // 2. build() produces bytes with the correct (adjusted) sequence
679        if self.idempotence_manager.is_enabled()
680            && ready_write_batch.write_batch.has_batch_sequence()
681        {
682            if let Some(adjusted_seq) = self.idempotence_manager.get_adjusted_sequence(
683                &ready_write_batch.table_bucket,
684                ready_write_batch.write_batch.batch_id(),
685            ) {
686                if adjusted_seq != ready_write_batch.write_batch.batch_sequence() {
687                    let writer_id = ready_write_batch.write_batch.writer_id();
688                    ready_write_batch
689                        .write_batch
690                        .set_writer_state(writer_id, adjusted_seq);
691                }
692            }
693        }
694
695        let dq = self.get_or_create_deque(&ready_write_batch);
696        let mut dq_guard = dq.lock();
697        if self.idempotence_manager.is_enabled() {
698            self.insert_in_sequence_order(&mut dq_guard, ready_write_batch);
699        } else {
700            dq_guard.push_front(ready_write_batch.write_batch);
701        }
702    }
703
704    /// Insert a re-enqueued batch in sequence order. Matches Java's
705    /// `insertInSequenceOrder`. If the batch is the next expected in-flight,
706    /// push to front; otherwise, find the correct sorted position.
707    fn insert_in_sequence_order(
708        &self,
709        dq: &mut VecDeque<WriteBatch>,
710        ready_write_batch: ReadyWriteBatch,
711    ) {
712        debug_assert!(
713            ready_write_batch.write_batch.batch_sequence() != NO_BATCH_SEQUENCE,
714            "Re-enqueuing a batch without a sequence (batch_id={})",
715            ready_write_batch.write_batch.batch_id()
716        );
717        debug_assert!(
718            self.idempotence_manager
719                .in_flight_count(&ready_write_batch.table_bucket)
720                > 0,
721            "Re-enqueuing a batch not tracked in in-flight (batch_id={}, bucket={})",
722            ready_write_batch.write_batch.batch_id(),
723            ready_write_batch.table_bucket
724        );
725
726        if dq.is_empty() {
727            dq.push_front(ready_write_batch.write_batch);
728            return;
729        }
730
731        // If it's the first in-flight batch for its bucket, push to front
732        if self.idempotence_manager.is_first_in_flight_batch(
733            &ready_write_batch.table_bucket,
734            ready_write_batch.write_batch.batch_id(),
735        ) {
736            dq.push_front(ready_write_batch.write_batch);
737            return;
738        }
739
740        // Find the correct position sorted by batch_sequence
741        let batch_seq = ready_write_batch.write_batch.batch_sequence();
742        let mut insert_pos = dq.len();
743        for (i, existing) in dq.iter().enumerate() {
744            if existing.has_batch_sequence() && existing.batch_sequence() > batch_seq {
745                insert_pos = i;
746                break;
747            }
748        }
749        dq.insert(insert_pos, ready_write_batch.write_batch);
750    }
751
752    fn get_or_create_deque(
753        &self,
754        ready_write_batch: &ReadyWriteBatch,
755    ) -> Arc<Mutex<VecDeque<WriteBatch>>> {
756        let physical_table_path = ready_write_batch.write_batch.physical_table_path();
757        let bucket_id = ready_write_batch.table_bucket.bucket_id();
758        let table_id = ready_write_batch.table_bucket.table_id();
759        let partition_id = ready_write_batch.table_bucket.partition_id();
760        let is_partitioned_table = partition_id.is_some();
761
762        let mut binding = self
763            .write_batches
764            .entry(Arc::clone(physical_table_path))
765            .or_insert_with(|| BucketAndWriteBatches {
766                table_id,
767                is_partitioned_table,
768                partition_id,
769                batches: Default::default(),
770            });
771        let bucket_and_batches = binding.value_mut();
772        bucket_and_batches
773            .batches
774            .entry(bucket_id)
775            .or_insert_with(|| Arc::new(Mutex::new(VecDeque::new())))
776            .clone()
777    }
778
779    /// Mark the accumulator as closed. All batches become immediately ready
780    /// (sendable) in `batch_ready`, triggering a full drain without waiting
781    /// for `batch_timeout_ms`. Matches Java's `RecordAccumulator.close()`.
782    pub fn close(&self) {
783        self.closed.store(true, Ordering::Release);
784        self.wakeup_sender();
785    }
786
787    pub fn is_closed(&self) -> bool {
788        self.closed.load(Ordering::Acquire)
789    }
790
791    pub fn abort_batches(&self, error: broadcast::Error) {
792        self.memory_limiter.close();
793        // Complete batches still in deques (not yet drained).
794        for mut entry in self.write_batches.iter_mut() {
795            for (_bucket_id, deque) in entry.value_mut().batches.iter_mut() {
796                let mut dq = deque.lock();
797                while let Some(batch) = dq.pop_front() {
798                    batch.complete(Err(error.clone()));
799                }
800            }
801        }
802        // Fail any remaining handles (including in-flight batches that were
803        // drained but not yet completed). This is a no-op for handles already
804        // completed above via WriteBatch::complete.
805        let mut incomplete = self.incomplete_batches.write();
806        for (handle, _permit) in incomplete.values() {
807            handle.fail(error.clone());
808        }
809        incomplete.clear();
810    }
811
812    pub fn has_incomplete(&self) -> bool {
813        !self.incomplete_batches.read().is_empty()
814    }
815
816    /// Wake the sender task so it can drain ready batches immediately.
817    pub fn wakeup_sender(&self) {
818        self.sender_wakeup.notify_one();
819    }
820
821    /// Returns a future that completes when `wakeup_sender()` is called.
822    pub fn notified(&self) -> tokio::sync::futures::Notified<'_> {
823        self.sender_wakeup.notified()
824    }
825
826    fn get_all_buckets_in_current_node(
827        &self,
828        current: &ServerNode,
829        cluster: &Cluster,
830    ) -> Vec<BucketLocation> {
831        let mut buckets = vec![];
832        for bucket_locations in cluster.get_bucket_locations_by_path().values() {
833            for bucket_location in bucket_locations {
834                if let Some(leader) = bucket_location.leader() {
835                    if current.id() == leader.id() {
836                        buckets.push(bucket_location.clone());
837                    }
838                }
839            }
840        }
841        buckets
842    }
843
844    pub fn has_undrained(&self) -> bool {
845        for entry in self.write_batches.iter() {
846            for (_, batch_deque) in entry.value().batches.iter() {
847                if !batch_deque.lock().is_empty() {
848                    return true;
849                }
850            }
851        }
852        false
853    }
854
855    pub fn get_physical_table_paths_in_batches(&self) -> Vec<Arc<PhysicalTablePath>> {
856        self.write_batches
857            .iter()
858            .map(|entry| Arc::clone(entry.key()))
859            .collect()
860    }
861
862    fn flush_in_progress(&self) -> bool {
863        self.flushes_in_progress.load(Ordering::SeqCst) > 0
864    }
865
866    pub fn begin_flush(&self) {
867        self.flushes_in_progress.fetch_add(1, Ordering::SeqCst);
868        self.wakeup_sender();
869    }
870
871    #[allow(unused_must_use)]
872    pub async fn await_flush_completion(&self) -> Result<()> {
873        // Clone handles before awaiting to avoid holding RwLock read guard across await points
874        let handles: Vec<_> = self
875            .incomplete_batches
876            .read()
877            .values()
878            .map(|(h, _)| h.clone())
879            .collect();
880
881        // Await on all handles
882        let result = async {
883            for result_handle in handles {
884                result_handle.wait().await?;
885            }
886            Ok(())
887        }
888        .await;
889
890        // Always decrement flushes_in_progress, even if an error occurred
891        // This mimics the Java finally block behavior
892        self.flushes_in_progress.fetch_sub(1, Ordering::SeqCst);
893
894        result
895    }
896}
897
898pub struct ReadyWriteBatch {
899    pub table_bucket: TableBucket,
900    pub write_batch: WriteBatch,
901}
902
903impl ReadyWriteBatch {
904    pub fn write_batch(&self) -> &WriteBatch {
905        &self.write_batch
906    }
907}
908
909#[allow(dead_code)]
910struct BucketAndWriteBatches {
911    table_id: TableId,
912    is_partitioned_table: bool,
913    partition_id: Option<PartitionId>,
914    batches: HashMap<BucketId, Arc<Mutex<VecDeque<WriteBatch>>>>,
915}
916
917pub struct RecordAppendResult {
918    pub batch_is_full: bool,
919    pub new_batch_created: bool,
920    pub abort_record_for_new_batch: bool,
921    pub result_handle: Option<ResultHandle>,
922}
923
924impl RecordAppendResult {
925    fn new(
926        result_handle: ResultHandle,
927        batch_is_full: bool,
928        new_batch_created: bool,
929        abort_record_for_new_batch: bool,
930    ) -> Self {
931        Self {
932            batch_is_full,
933            new_batch_created,
934            abort_record_for_new_batch,
935            result_handle: Some(result_handle),
936        }
937    }
938
939    fn new_without_result_handle(
940        batch_is_full: bool,
941        new_batch_created: bool,
942        abort_record_for_new_batch: bool,
943    ) -> Self {
944        Self {
945            batch_is_full,
946            new_batch_created,
947            abort_record_for_new_batch,
948            result_handle: None,
949        }
950    }
951}
952
953pub struct ReadyCheckResult {
954    pub ready_nodes: HashSet<ServerNode>,
955    pub next_ready_check_delay_ms: i64,
956    pub unknown_leader_tables: HashSet<Arc<PhysicalTablePath>>,
957}
958
959impl ReadyCheckResult {
960    pub fn new(
961        ready_nodes: HashSet<ServerNode>,
962        next_ready_check_delay_ms: i64,
963        unknown_leader_tables: HashSet<Arc<PhysicalTablePath>>,
964    ) -> Self {
965        ReadyCheckResult {
966            ready_nodes,
967            next_ready_check_delay_ms,
968            unknown_leader_tables,
969        }
970    }
971}
972
973#[cfg(test)]
974mod tests {
975    use super::*;
976    use crate::metadata::TablePath;
977    use crate::row::{Datum, GenericRow};
978    use crate::test_utils::{build_cluster, build_table_info};
979    use std::sync::Arc;
980
981    fn disabled_idempotence() -> Arc<IdempotenceManager> {
982        Arc::new(IdempotenceManager::new(false, 5))
983    }
984
985    fn enabled_idempotence() -> Arc<IdempotenceManager> {
986        Arc::new(IdempotenceManager::new(true, 5))
987    }
988
989    #[tokio::test]
990    async fn re_enqueue_increments_attempts() -> Result<()> {
991        let config = Config::default();
992        let accumulator = RecordAccumulator::new(config, disabled_idempotence());
993        let table_path = TablePath::new("db".to_string(), "tbl".to_string());
994        let physical_table_path = Arc::new(PhysicalTablePath::of(Arc::new(table_path.clone())));
995        let table_info = Arc::new(build_table_info(table_path.clone(), 1, 1));
996        let cluster = Arc::new(build_cluster(&table_path, 1, 1));
997        let row = GenericRow {
998            values: vec![Datum::Int32(1)],
999        };
1000        let record = WriteRecord::for_append(table_info, physical_table_path, 1, &row);
1001
1002        accumulator.append(&record, 0, &cluster, false)?;
1003
1004        let server = cluster.get_tablet_server(1).expect("server");
1005        let nodes = HashSet::from([server.clone()]);
1006        let mut batches = accumulator.drain(cluster.clone(), &nodes, 1024 * 1024)?;
1007        let mut drained = batches.remove(&1).expect("drained batches");
1008        let batch = drained.pop().expect("batch");
1009        assert_eq!(batch.write_batch.attempts(), 0);
1010
1011        accumulator.re_enqueue(batch);
1012
1013        let mut batches = accumulator.drain(cluster, &nodes, 1024 * 1024)?;
1014        let mut drained = batches.remove(&1).expect("drained batches");
1015        let batch = drained.pop().expect("batch");
1016        assert_eq!(batch.write_batch.attempts(), 1);
1017        Ok(())
1018    }
1019
1020    #[tokio::test]
1021    async fn flush_counter_decremented_on_error() -> Result<()> {
1022        use crate::client::write::broadcast::BroadcastOnce;
1023        use std::sync::atomic::Ordering;
1024
1025        let config = Config::default();
1026        let accumulator = RecordAccumulator::new(config, disabled_idempotence());
1027
1028        accumulator.begin_flush();
1029        assert_eq!(accumulator.flushes_in_progress.load(Ordering::SeqCst), 1);
1030
1031        // Create a failing batch by dropping the BroadcastOnce without broadcasting
1032        {
1033            let broadcast = BroadcastOnce::default();
1034            let receiver = broadcast.receiver();
1035            let handle = ResultHandle::new(receiver);
1036            let permit = accumulator.memory_limiter.acquire(1024).unwrap();
1037            accumulator
1038                .incomplete_batches
1039                .write()
1040                .insert(1, (handle, permit));
1041            // broadcast is dropped here, causing an error
1042        }
1043
1044        // Await flush completion should fail but still decrement counter
1045        let result = accumulator.await_flush_completion().await;
1046        assert!(result.is_err());
1047
1048        // Counter should still be decremented (this is the critical fix!)
1049        assert_eq!(accumulator.flushes_in_progress.load(Ordering::SeqCst), 0);
1050        assert!(!accumulator.flush_in_progress());
1051
1052        Ok(())
1053    }
1054
1055    fn append_and_drain(
1056        accumulator: &RecordAccumulator,
1057        cluster: &Arc<crate::cluster::Cluster>,
1058        table_path: &TablePath,
1059        bucket_id: i32,
1060    ) -> Result<ReadyWriteBatch> {
1061        let table_info = Arc::new(build_table_info(table_path.clone(), 1, 2));
1062        let physical_table_path = Arc::new(PhysicalTablePath::of(Arc::new(table_path.clone())));
1063        let row = GenericRow {
1064            values: vec![Datum::Int32(1)],
1065        };
1066        let record = WriteRecord::for_append(table_info, physical_table_path, 1, &row);
1067        accumulator.append(&record, bucket_id, cluster, false)?;
1068        let server = cluster.get_tablet_server(1).expect("server");
1069        let nodes = HashSet::from([server.clone()]);
1070        let mut batches = accumulator.drain(cluster.clone(), &nodes, 1024 * 1024)?;
1071        let mut drained = batches.remove(&1).expect("drained batches");
1072        Ok(drained.pop().expect("batch"))
1073    }
1074
1075    #[test]
1076    fn test_should_stop_drain_for_fresh_batch_over_limit() {
1077        let idempotence = Arc::new(IdempotenceManager::new(true, 2));
1078        idempotence.set_writer_id(42);
1079        let config = Config::default();
1080        let accumulator = RecordAccumulator::new(config, Arc::clone(&idempotence));
1081        let table_path = TablePath::new("db".to_string(), "tbl".to_string());
1082        let cluster = Arc::new(build_cluster(&table_path, 1, 1));
1083        let table_info = Arc::new(build_table_info(table_path.clone(), 1, 1));
1084        let physical_table_path = Arc::new(PhysicalTablePath::of(Arc::new(table_path.clone())));
1085        let row = GenericRow {
1086            values: vec![Datum::Int32(1)],
1087        };
1088        let record = WriteRecord::for_append(table_info, physical_table_path, 1, &row);
1089        accumulator
1090            .append(&record, 0, &cluster, false)
1091            .expect("append");
1092
1093        let table_bucket = TableBucket::new(1, 0);
1094
1095        // Add 2 in-flight batches (reaching the max_in_flight=2)
1096        idempotence.add_in_flight_batch(&table_bucket, 0, 100);
1097        idempotence.add_in_flight_batch(&table_bucket, 1, 101);
1098
1099        // Get the front batch from the deque
1100        let entry = accumulator
1101            .write_batches
1102            .get(&PhysicalTablePath::of(Arc::new(table_path)))
1103            .unwrap();
1104        let dq = entry.batches.get(&0).unwrap();
1105        let dq_guard = dq.lock();
1106        let first_batch = dq_guard.front().unwrap();
1107
1108        // Fresh batch (no batch_sequence) with in-flight at limit → should stop
1109        assert!(!first_batch.has_batch_sequence());
1110        assert!(accumulator.should_stop_drain_batches_for_bucket(first_batch, &table_bucket));
1111
1112        // Remove one in-flight → under limit → should not stop
1113        drop(dq_guard);
1114        idempotence.remove_in_flight_batch(&table_bucket, 101);
1115        let dq_guard = entry.batches.get(&0).unwrap().lock();
1116        let first_batch = dq_guard.front().unwrap();
1117        assert!(!accumulator.should_stop_drain_batches_for_bucket(first_batch, &table_bucket));
1118    }
1119
1120    #[test]
1121    fn test_should_stop_drain_for_retry_not_first_inflight() {
1122        let idempotence = enabled_idempotence();
1123        idempotence.set_writer_id(42);
1124        let config = Config::default();
1125        let accumulator = RecordAccumulator::new(config, Arc::clone(&idempotence));
1126        let table_path = TablePath::new("db".to_string(), "tbl".to_string());
1127        let cluster = Arc::new(build_cluster(&table_path, 1, 1));
1128
1129        // Drain two separate batches to get batch0(seq=0) and batch1(seq=1)
1130        let batch0 =
1131            append_and_drain(&accumulator, &cluster, &table_path, 0).expect("drain batch0");
1132        let batch1 =
1133            append_and_drain(&accumulator, &cluster, &table_path, 0).expect("drain batch1");
1134
1135        assert_eq!(batch0.write_batch.batch_sequence(), 0);
1136        assert_eq!(batch1.write_batch.batch_sequence(), 1);
1137
1138        let batch1_id = batch1.write_batch.batch_id();
1139        let table_bucket = batch0.table_bucket.clone();
1140
1141        // Re-enqueue only batch1 (simulating batch0 still in-flight, batch1 got error)
1142        accumulator.re_enqueue(batch1);
1143
1144        let entry = accumulator
1145            .write_batches
1146            .get(&PhysicalTablePath::of(Arc::new(table_path)))
1147            .unwrap();
1148        let dq = entry.batches.get(&0).unwrap();
1149        let dq_guard = dq.lock();
1150        let first_batch = dq_guard.front().unwrap();
1151
1152        // Batch1 is re-enqueued with seq=1, but batch0 (seq=0) is the first in-flight.
1153        // batch1's batch_id != first in-flight batch_id → should stop.
1154        assert!(first_batch.has_batch_sequence());
1155        assert_eq!(first_batch.batch_id(), batch1_id);
1156        assert!(accumulator.should_stop_drain_batches_for_bucket(first_batch, &table_bucket));
1157    }
1158
1159    #[tokio::test]
1160    async fn test_insert_in_sequence_order() -> Result<()> {
1161        let idempotence = enabled_idempotence();
1162        idempotence.set_writer_id(42);
1163        let config = Config::default();
1164        let accumulator = RecordAccumulator::new(config, Arc::clone(&idempotence));
1165        let table_path = TablePath::new("db".to_string(), "tbl".to_string());
1166        let cluster = Arc::new(build_cluster(&table_path, 1, 2));
1167
1168        // Create and drain 3 batches to get them with sequences 0, 1, 2
1169        let batch0 = append_and_drain(&accumulator, &cluster, &table_path, 0)?;
1170        let batch1 = append_and_drain(&accumulator, &cluster, &table_path, 0)?;
1171        let batch2 = append_and_drain(&accumulator, &cluster, &table_path, 0)?;
1172
1173        assert_eq!(batch0.write_batch.batch_sequence(), 0);
1174        assert_eq!(batch1.write_batch.batch_sequence(), 1);
1175        assert_eq!(batch2.write_batch.batch_sequence(), 2);
1176
1177        let batch0_id = batch0.write_batch.batch_id();
1178        let batch1_id = batch1.write_batch.batch_id();
1179        let batch2_id = batch2.write_batch.batch_id();
1180        let table_bucket = batch0.table_bucket.clone();
1181
1182        // Re-enqueue in reverse order: 2, 0, 1
1183        // insert_in_sequence_order should sort them as: 0, 1, 2
1184        accumulator.re_enqueue(batch2);
1185        accumulator.re_enqueue(batch0);
1186        accumulator.re_enqueue(batch1);
1187
1188        // Verify the deque order directly
1189        let entry = accumulator
1190            .write_batches
1191            .get(&PhysicalTablePath::of(Arc::new(table_path)))
1192            .unwrap();
1193        let dq = entry.batches.get(&0).unwrap();
1194        let dq_guard = dq.lock();
1195        assert_eq!(dq_guard.len(), 3);
1196        // batch0 (seq=0) is the first in-flight, so it should be at front
1197        assert_eq!(dq_guard[0].batch_id(), batch0_id);
1198        assert_eq!(dq_guard[0].batch_sequence(), 0);
1199        assert_eq!(dq_guard[1].batch_id(), batch1_id);
1200        assert_eq!(dq_guard[1].batch_sequence(), 1);
1201        assert_eq!(dq_guard[2].batch_id(), batch2_id);
1202        assert_eq!(dq_guard[2].batch_sequence(), 2);
1203        drop(dq_guard);
1204
1205        // Drain: first in-flight is seq=0, so batch0 passes should_stop check
1206        let server = cluster.get_tablet_server(1).expect("server");
1207        let nodes = HashSet::from([server.clone()]);
1208        let mut batches = accumulator.drain(cluster.clone(), &nodes, 1024 * 1024)?;
1209        let drained = batches.remove(&1).expect("drained batches");
1210        assert_eq!(drained.len(), 1);
1211        assert_eq!(drained[0].write_batch.batch_sequence(), 0);
1212
1213        // Complete batch0 so batch1 becomes first in-flight
1214        idempotence.handle_completed_batch(&table_bucket, batch0_id, 42);
1215
1216        let mut batches = accumulator.drain(cluster.clone(), &nodes, 1024 * 1024)?;
1217        let drained = batches.remove(&1).expect("drained");
1218        assert_eq!(drained[0].write_batch.batch_sequence(), 1);
1219
1220        idempotence.handle_completed_batch(&table_bucket, batch1_id, 42);
1221
1222        let mut batches = accumulator.drain(cluster, &nodes, 1024 * 1024)?;
1223        let drained = batches.remove(&1).expect("drained");
1224        assert_eq!(drained[0].write_batch.batch_sequence(), 2);
1225
1226        Ok(())
1227    }
1228
1229    #[tokio::test]
1230    async fn test_abort_batches() -> Result<()> {
1231        let idempotence = disabled_idempotence();
1232        let config = Config::default();
1233        let accumulator = RecordAccumulator::new(config, Arc::clone(&idempotence));
1234        let table_path = TablePath::new("db".to_string(), "tbl".to_string());
1235        let physical_table_path = Arc::new(PhysicalTablePath::of(Arc::new(table_path.clone())));
1236        let table_info = Arc::new(build_table_info(table_path.clone(), 1, 1));
1237        let cluster = Arc::new(build_cluster(&table_path, 1, 1));
1238        let row = GenericRow {
1239            values: vec![Datum::Int32(1)],
1240        };
1241        let record = WriteRecord::for_append(table_info, physical_table_path, 1, &row);
1242
1243        let result = accumulator.append(&record, 0, &cluster, false)?;
1244        let handle = result.result_handle.expect("handle");
1245        assert!(accumulator.has_incomplete());
1246
1247        accumulator.abort_batches(broadcast::Error::Client {
1248            message: "test abort".to_string(),
1249        });
1250
1251        assert!(!accumulator.has_incomplete());
1252        assert!(!accumulator.has_undrained());
1253
1254        // The handle should receive the error
1255        let batch_result = handle.wait().await?;
1256        assert!(matches!(
1257            batch_result,
1258            Err(broadcast::Error::Client { message }) if message == "test abort"
1259        ));
1260        Ok(())
1261    }
1262
1263    #[tokio::test]
1264    async fn test_drain_skips_blocked_bucket_continues_others() -> Result<()> {
1265        // Use max_in_flight=1 so that one in-flight batch blocks further draining
1266        let idempotence = Arc::new(IdempotenceManager::new(true, 1));
1267        idempotence.set_writer_id(42);
1268        let config = Config::default();
1269        let accumulator = RecordAccumulator::new(config, Arc::clone(&idempotence));
1270        let table_path = TablePath::new("db".to_string(), "tbl".to_string());
1271        let cluster = Arc::new(build_cluster(&table_path, 1, 2));
1272
1273        // Append to both buckets
1274        let table_info = Arc::new(build_table_info(table_path.clone(), 1, 2));
1275        let physical_table_path = Arc::new(PhysicalTablePath::of(Arc::new(table_path.clone())));
1276        let row = GenericRow {
1277            values: vec![Datum::Int32(1)],
1278        };
1279
1280        // Append to bucket 0
1281        let record =
1282            WriteRecord::for_append(table_info.clone(), physical_table_path.clone(), 1, &row);
1283        accumulator.append(&record, 0, &cluster, false)?;
1284
1285        // Append to bucket 1
1286        let record =
1287            WriteRecord::for_append(table_info.clone(), physical_table_path.clone(), 1, &row);
1288        accumulator.append(&record, 1, &cluster, false)?;
1289
1290        // Drain once — both buckets get batches assigned with sequences
1291        let server = cluster.get_tablet_server(1).expect("server");
1292        let nodes = HashSet::from([server.clone()]);
1293        let batches = accumulator.drain(cluster.clone(), &nodes, 1024 * 1024)?;
1294        let drained = batches.get(&1).expect("drained");
1295        // Both buckets should produce batches
1296        assert_eq!(drained.len(), 2);
1297
1298        // Now: both buckets have 1 in-flight each (added during drain).
1299        // Append another record to each bucket.
1300        let record =
1301            WriteRecord::for_append(table_info.clone(), physical_table_path.clone(), 1, &row);
1302        accumulator.append(&record, 0, &cluster, false)?;
1303        let record = WriteRecord::for_append(table_info, physical_table_path, 1, &row);
1304        accumulator.append(&record, 1, &cluster, false)?;
1305
1306        // With max_in_flight=1, both buckets are at limit → should_stop returns true
1307        // for fresh batches. The drain should skip both (continue, not break).
1308        let batches2 = accumulator.drain(cluster.clone(), &nodes, 1024 * 1024)?;
1309        // No batches should be drained (both blocked)
1310        assert!(
1311            batches2.is_empty() || batches2.get(&1).is_none_or(|b| b.is_empty()),
1312            "Expected no batches when all buckets are blocked"
1313        );
1314
1315        // Complete the in-flight for bucket 0
1316        let bucket0_batch = &drained[0];
1317        idempotence.handle_completed_batch(
1318            &bucket0_batch.table_bucket,
1319            bucket0_batch.write_batch.batch_id(),
1320            42,
1321        );
1322
1323        // Now bucket 0 is unblocked but bucket 1 is still blocked
1324        let batches3 = accumulator.drain(cluster, &nodes, 1024 * 1024)?;
1325        let drained3 = batches3.get(&1).expect("some drained");
1326        // Only bucket 0 should produce a batch (continue skipped bucket 1)
1327        assert_eq!(drained3.len(), 1);
1328        assert_eq!(drained3[0].table_bucket.bucket_id(), 0);
1329
1330        Ok(())
1331    }
1332
1333    #[test]
1334    fn test_memory_limiter_acquire_release() {
1335        let limiter = Arc::new(MemoryLimiter::new(1024, Duration::from_secs(1)));
1336
1337        let permit1 = limiter.acquire(512).unwrap();
1338        let permit2 = limiter.acquire(512).unwrap();
1339
1340        // At capacity — verify used is 1024
1341        assert_eq!(*limiter.state.lock(), 1024);
1342
1343        // Release one permit, verify used drops
1344        drop(permit1);
1345        assert_eq!(*limiter.state.lock(), 512);
1346
1347        drop(permit2);
1348        assert_eq!(*limiter.state.lock(), 0);
1349    }
1350
1351    #[test]
1352    fn test_memory_limiter_oversized_batch_fails_immediately() {
1353        let limiter = Arc::new(MemoryLimiter::new(1024, Duration::from_secs(60)));
1354
1355        let result = limiter.acquire(2048);
1356        assert!(matches!(result.unwrap_err(), Error::IllegalArgument { .. }));
1357    }
1358
1359    #[test]
1360    fn test_memory_limiter_blocks_then_unblocks() {
1361        let limiter = Arc::new(MemoryLimiter::new(1024, Duration::from_secs(5)));
1362
1363        let permit = limiter.acquire(1024).unwrap();
1364        assert_eq!(*limiter.state.lock(), 1024);
1365
1366        // Spawn a thread that tries to acquire — it should block
1367        let limiter2 = Arc::clone(&limiter);
1368        let handle = std::thread::spawn(move || limiter2.acquire(512));
1369
1370        // Give the thread time to block
1371        std::thread::sleep(Duration::from_millis(50));
1372        // Still at capacity (thread is blocked)
1373        assert_eq!(*limiter.state.lock(), 1024);
1374
1375        // Release the permit — thread should unblock
1376        drop(permit);
1377
1378        let result = handle.join().unwrap();
1379        assert!(result.is_ok());
1380        let _permit2 = result.unwrap();
1381        assert_eq!(*limiter.state.lock(), 512);
1382    }
1383
1384    #[test]
1385    fn test_memory_limiter_timeout() {
1386        let limiter = Arc::new(MemoryLimiter::new(1024, Duration::from_millis(100)));
1387
1388        let _permit = limiter.acquire(1024).unwrap();
1389
1390        // This should timeout
1391        let start = Instant::now();
1392        let result = limiter.acquire(512);
1393        let elapsed = start.elapsed();
1394
1395        assert!(matches!(result.unwrap_err(), Error::BufferExhausted { .. }));
1396        assert!(elapsed >= Duration::from_millis(80)); // allow some timing slack
1397    }
1398
1399    #[test]
1400    fn test_memory_limiter_close_fails_immediately() {
1401        let limiter = Arc::new(MemoryLimiter::new(1024, Duration::from_secs(60)));
1402
1403        let _permit = limiter.acquire(512).unwrap();
1404
1405        limiter.close();
1406
1407        // New acquire should fail immediately, not block for 60s
1408        let start = Instant::now();
1409        let result = limiter.acquire(256);
1410        let elapsed = start.elapsed();
1411
1412        assert!(matches!(result.unwrap_err(), Error::WriterClosed { .. }));
1413        assert!(elapsed < Duration::from_millis(50));
1414    }
1415
1416    #[test]
1417    fn test_memory_limiter_close_unblocks_waiting_threads() {
1418        let limiter = Arc::new(MemoryLimiter::new(1024, Duration::from_secs(60)));
1419
1420        // Fill the limiter completely
1421        let _permit = limiter.acquire(1024).unwrap();
1422
1423        // Spawn a thread that blocks waiting for memory
1424        let limiter2 = Arc::clone(&limiter);
1425        let handle = std::thread::spawn(move || {
1426            let start = Instant::now();
1427            let result = limiter2.acquire(512);
1428            (result, start.elapsed())
1429        });
1430
1431        // Give the thread time to block
1432        std::thread::sleep(Duration::from_millis(50));
1433        assert_eq!(limiter.waiting_count.load(Ordering::Relaxed), 1);
1434
1435        // Close the limiter — should unblock the waiting thread
1436        limiter.close();
1437
1438        let (result, elapsed) = handle.join().unwrap();
1439        assert!(matches!(result.unwrap_err(), Error::WriterClosed { .. }));
1440        assert!(elapsed < Duration::from_secs(5)); // should not wait the full 60s
1441    }
1442
1443    #[test]
1444    fn test_oversized_kv_record_does_not_panic() {
1445        use crate::client::write::write_format::WriteFormat;
1446        use crate::client::write::{RowBytes, WriteRecord};
1447        use bytes::Bytes;
1448
1449        // Use a tiny batch size so the KV record exceeds it
1450        let config = Config {
1451            writer_batch_size: 64,
1452            writer_buffer_memory_size: 1024 * 1024,
1453            ..Config::default()
1454        };
1455
1456        let accumulator = RecordAccumulator::new(config, disabled_idempotence());
1457        let table_path = TablePath::new("db".to_string(), "tbl".to_string());
1458        let table_info = Arc::new(build_table_info(table_path.clone(), 1, 1));
1459        let physical_table_path = Arc::new(PhysicalTablePath::of(Arc::new(table_path.clone())));
1460        let cluster = Arc::new(build_cluster(&table_path, 1, 1));
1461
1462        // Create a KV record larger than batch_size (64 bytes)
1463        let key = Bytes::from(vec![0u8; 32]);
1464        let value = vec![0u8; 256];
1465        let record = WriteRecord::for_upsert(
1466            table_info,
1467            physical_table_path,
1468            1,
1469            key,
1470            None,
1471            WriteFormat::CompactedKv,
1472            None,
1473            Some(RowBytes::Owned(Bytes::from(value))),
1474        );
1475
1476        // This used to panic with "must append to a new batch" because
1477        // the KV write limit was hardcoded to DEFAULT_WRITE_LIMIT (256 bytes)
1478        // instead of using alloc_size = max(batch_size, record_size).
1479        let result = accumulator.append(&record, 0, &cluster, false);
1480        assert!(result.is_ok(), "oversized KV record should not panic");
1481    }
1482
1483    #[test]
1484    fn test_memory_permit_accounts_for_oversized_record() {
1485        use crate::client::write::write_format::WriteFormat;
1486        use crate::client::write::{RowBytes, WriteRecord};
1487        use bytes::Bytes;
1488
1489        let config = Config {
1490            writer_batch_size: 64,
1491            writer_buffer_memory_size: 1024 * 1024,
1492            ..Config::default()
1493        };
1494
1495        let accumulator = RecordAccumulator::new(config, disabled_idempotence());
1496        let table_path = TablePath::new("db".to_string(), "tbl".to_string());
1497        let table_info = Arc::new(build_table_info(table_path.clone(), 1, 1));
1498        let physical_table_path = Arc::new(PhysicalTablePath::of(Arc::new(table_path.clone())));
1499        let cluster = Arc::new(build_cluster(&table_path, 1, 1));
1500
1501        let key = Bytes::from(vec![0u8; 32]);
1502        let value = vec![0u8; 256];
1503        let record = WriteRecord::for_upsert(
1504            table_info,
1505            physical_table_path,
1506            1,
1507            key,
1508            None,
1509            WriteFormat::CompactedKv,
1510            None,
1511            Some(RowBytes::Owned(Bytes::from(value))),
1512        );
1513
1514        // estimated_record_size includes batch header overhead
1515        let expected_alloc = record.estimated_record_size();
1516        assert!(expected_alloc > 64, "record should exceed batch_size=64");
1517
1518        accumulator.append(&record, 0, &cluster, false).unwrap();
1519
1520        // The permit should reserve max(batch_size, estimated_record_size) bytes.
1521        let used = *accumulator.memory_limiter.state.lock();
1522        assert_eq!(
1523            used, expected_alloc,
1524            "memory limiter should reserve max(batch_size, estimated_record_size)"
1525        );
1526    }
1527
1528    #[tokio::test]
1529    async fn test_sender_wakeup_notifies() {
1530        let accumulator = RecordAccumulator::new(Config::default(), disabled_idempotence());
1531
1532        // notified() should complete when wakeup_sender() is called
1533        let notified = accumulator.notified();
1534        accumulator.wakeup_sender();
1535        // If wakeup doesn't work, this would hang forever.
1536        tokio::time::timeout(Duration::from_millis(100), notified)
1537            .await
1538            .expect("notified should complete after wakeup_sender");
1539    }
1540}