// recoco_utils/concur_control.rs

// ReCoco is a Rust-only fork of CocoIndex, by [CocoIndex](https://CocoIndex)
// Original code from CocoIndex is copyrighted by CocoIndex
// SPDX-FileCopyrightText: 2025-2026 CocoIndex (upstream)
// SPDX-FileContributor: CocoIndex Contributors
//
// All modifications from the upstream for ReCoco are copyrighted by Knitli Inc.
// SPDX-FileCopyrightText: 2026 Knitli Inc. (ReCoco)
// SPDX-FileContributor: Adam Poulemanos <adam@knit.li>
//
// Both the upstream CocoIndex code and the ReCoco modifications are licensed under the Apache-2.0 License.
// SPDX-License-Identifier: Apache-2.0
12
13use std::sync::Arc;
14use tokio::sync::{AcquireError, OwnedSemaphorePermit, Semaphore};
15
16struct WeightedSemaphore {
17    downscale_factor: u8,
18    downscaled_quota: u32,
19    sem: Arc<Semaphore>,
20}
21
22impl WeightedSemaphore {
23    pub fn new(quota: usize) -> Self {
24        let mut downscale_factor = 0;
25        let mut downscaled_quota = quota;
26        while downscaled_quota > u32::MAX as usize {
27            downscaled_quota >>= 1;
28            downscale_factor += 1;
29        }
30        let sem = Arc::new(Semaphore::new(downscaled_quota));
31        Self {
32            downscaled_quota: downscaled_quota as u32,
33            downscale_factor,
34            sem,
35        }
36    }
37
38    async fn acquire_reservation(&self) -> Result<OwnedSemaphorePermit, AcquireError> {
39        self.sem.clone().acquire_owned().await
40    }
41
42    async fn acquire(
43        &self,
44        weight: usize,
45        reserved: bool,
46    ) -> Result<Option<OwnedSemaphorePermit>, AcquireError> {
47        let downscaled_weight = (weight >> self.downscale_factor) as u32;
48        let capped_weight = downscaled_weight.min(self.downscaled_quota);
49        let reserved_weight = if reserved { 1 } else { 0 };
50        if reserved_weight >= capped_weight {
51            return Ok(None);
52        }
53        Ok(Some(
54            self.sem
55                .clone()
56                .acquire_many_owned(capped_weight - reserved_weight)
57                .await?,
58        ))
59    }
60}
61
/// Limits applied by a [`ConcurrencyController`].
///
/// A `None` field leaves that dimension unlimited, so `Options::default()` imposes no limits.
#[derive(Debug, Clone, Default)]
pub struct Options {
    /// Maximum number of rows that may be in flight at once.
    pub max_inflight_rows: Option<usize>,
    /// Approximate maximum number of bytes that may be in flight at once.
    pub max_inflight_bytes: Option<usize>,
}
66
67pub struct ConcurrencyControllerPermit {
68    _inflight_count_permit: Option<OwnedSemaphorePermit>,
69    _inflight_bytes_permit: Option<OwnedSemaphorePermit>,
70}
71
72pub struct ConcurrencyController {
73    inflight_count_sem: Option<Arc<Semaphore>>,
74    inflight_bytes_sem: Option<WeightedSemaphore>,
75}
76
77pub static BYTES_UNKNOWN_YET: Option<fn() -> usize> = None;
78
79impl ConcurrencyController {
80    pub fn new(exec_options: &Options) -> Self {
81        Self {
82            inflight_count_sem: exec_options
83                .max_inflight_rows
84                .map(|max| Arc::new(Semaphore::new(max))),
85            inflight_bytes_sem: exec_options.max_inflight_bytes.map(WeightedSemaphore::new),
86        }
87    }
88
89    /// If `bytes_fn` is `None`, it means the number of bytes is not known yet.
90    /// The controller will reserve a minimum number of bytes.
91    /// The caller should call `acquire_bytes_with_reservation` with the actual number of bytes later.
92    pub async fn acquire(
93        &self,
94        bytes_fn: Option<impl FnOnce() -> usize>,
95    ) -> Result<ConcurrencyControllerPermit, AcquireError> {
96        let inflight_count_permit = if let Some(sem) = &self.inflight_count_sem {
97            Some(sem.clone().acquire_owned().await?)
98        } else {
99            None
100        };
101        let inflight_bytes_permit = if let Some(sem) = &self.inflight_bytes_sem {
102            if let Some(bytes_fn) = bytes_fn {
103                sem.acquire(bytes_fn(), false).await?
104            } else {
105                Some(sem.acquire_reservation().await?)
106            }
107        } else {
108            None
109        };
110        Ok(ConcurrencyControllerPermit {
111            _inflight_count_permit: inflight_count_permit,
112            _inflight_bytes_permit: inflight_bytes_permit,
113        })
114    }
115
116    pub async fn acquire_bytes_with_reservation(
117        &self,
118        bytes_fn: impl FnOnce() -> usize,
119    ) -> Result<Option<OwnedSemaphorePermit>, AcquireError> {
120        if let Some(sem) = &self.inflight_bytes_sem {
121            sem.acquire(bytes_fn(), true).await
122        } else {
123            Ok(None)
124        }
125    }
126}
127
128pub struct CombinedConcurrencyControllerPermit {
129    _permit: ConcurrencyControllerPermit,
130    _global_permit: ConcurrencyControllerPermit,
131}
132
133pub struct CombinedConcurrencyController {
134    controller: ConcurrencyController,
135    global_controller: Arc<ConcurrencyController>,
136    needs_num_bytes: bool,
137}
138
139impl CombinedConcurrencyController {
140    pub fn new(exec_options: &Options, global_controller: Arc<ConcurrencyController>) -> Self {
141        Self {
142            controller: ConcurrencyController::new(exec_options),
143            needs_num_bytes: exec_options.max_inflight_bytes.is_some()
144                || global_controller.inflight_bytes_sem.is_some(),
145            global_controller,
146        }
147    }
148
149    pub async fn acquire(
150        &self,
151        bytes_fn: Option<impl FnOnce() -> usize>,
152    ) -> Result<CombinedConcurrencyControllerPermit, AcquireError> {
153        let num_bytes_fn = if let Some(bytes_fn) = bytes_fn
154            && self.needs_num_bytes
155        {
156            let num_bytes = bytes_fn();
157            Some(move || num_bytes)
158        } else {
159            None
160        };
161
162        let permit = self.controller.acquire(num_bytes_fn).await?;
163        let global_permit = self.global_controller.acquire(num_bytes_fn).await?;
164        Ok(CombinedConcurrencyControllerPermit {
165            _permit: permit,
166            _global_permit: global_permit,
167        })
168    }
169
170    pub async fn acquire_bytes_with_reservation(
171        &self,
172        bytes_fn: impl FnOnce() -> usize,
173    ) -> Result<(Option<OwnedSemaphorePermit>, Option<OwnedSemaphorePermit>), AcquireError> {
174        let num_bytes = bytes_fn();
175        let permit = self
176            .controller
177            .acquire_bytes_with_reservation(move || num_bytes)
178            .await?;
179        let global_permit = self
180            .global_controller
181            .acquire_bytes_with_reservation(move || num_bytes)
182            .await?;
183        Ok((permit, global_permit))
184    }
185}