recoco_utils/
concur_control.rs1use std::sync::Arc;
14use tokio::sync::{AcquireError, OwnedSemaphorePermit, Semaphore};
15
16struct WeightedSemaphore {
17 downscale_factor: u8,
18 downscaled_quota: u32,
19 sem: Arc<Semaphore>,
20}
21
22impl WeightedSemaphore {
23 pub fn new(quota: usize) -> Self {
24 let mut downscale_factor = 0;
25 let mut downscaled_quota = quota;
26 while downscaled_quota > u32::MAX as usize {
27 downscaled_quota >>= 1;
28 downscale_factor += 1;
29 }
30 let sem = Arc::new(Semaphore::new(downscaled_quota));
31 Self {
32 downscaled_quota: downscaled_quota as u32,
33 downscale_factor,
34 sem,
35 }
36 }
37
38 async fn acquire_reservation(&self) -> Result<OwnedSemaphorePermit, AcquireError> {
39 self.sem.clone().acquire_owned().await
40 }
41
42 async fn acquire(
43 &self,
44 weight: usize,
45 reserved: bool,
46 ) -> Result<Option<OwnedSemaphorePermit>, AcquireError> {
47 let downscaled_weight = (weight >> self.downscale_factor) as u32;
48 let capped_weight = downscaled_weight.min(self.downscaled_quota);
49 let reserved_weight = if reserved { 1 } else { 0 };
50 if reserved_weight >= capped_weight {
51 return Ok(None);
52 }
53 Ok(Some(
54 self.sem
55 .clone()
56 .acquire_many_owned(capped_weight - reserved_weight)
57 .await?,
58 ))
59 }
60}
61
/// Concurrency limits used to build a `ConcurrencyController`; `None` disables a limit.
///
/// `Options::default()` leaves both limits unset, i.e. no throttling.
#[derive(Debug, Clone, Default)]
pub struct Options {
    /// Maximum number of rows that may be in flight at the same time.
    pub max_inflight_rows: Option<usize>,
    /// Maximum total byte weight that may be in flight at the same time.
    pub max_inflight_bytes: Option<usize>,
}
66
/// Permits held on behalf of one in-flight item, obtained from
/// `ConcurrencyController::acquire`.
///
/// The fields are never read. They exist so that the `OwnedSemaphorePermit`s are
/// returned to their semaphores when this value is dropped. A `None` field means the
/// corresponding limit is not configured, or that no (further) permits were needed.
pub struct ConcurrencyControllerPermit {
    // Declared first, so it is dropped (released) before the bytes permit.
    _inflight_count_permit: Option<OwnedSemaphorePermit>,
    _inflight_bytes_permit: Option<OwnedSemaphorePermit>,
}
71
72pub struct ConcurrencyController {
73 inflight_count_sem: Option<Arc<Semaphore>>,
74 inflight_bytes_sem: Option<WeightedSemaphore>,
75}
76
/// Pass as `bytes_fn` to `acquire` when an item's size is not known yet.
///
/// Instead of sizing the request, the call takes only a one-permit byte reservation.
/// That reservation can later be topped up with `acquire_bytes_with_reservation`.
pub static BYTES_UNKNOWN_YET: Option<fn() -> usize> = Option::None;
78
79impl ConcurrencyController {
80 pub fn new(exec_options: &Options) -> Self {
81 Self {
82 inflight_count_sem: exec_options
83 .max_inflight_rows
84 .map(|max| Arc::new(Semaphore::new(max))),
85 inflight_bytes_sem: exec_options.max_inflight_bytes.map(WeightedSemaphore::new),
86 }
87 }
88
89 pub async fn acquire(
93 &self,
94 bytes_fn: Option<impl FnOnce() -> usize>,
95 ) -> Result<ConcurrencyControllerPermit, AcquireError> {
96 let inflight_count_permit = if let Some(sem) = &self.inflight_count_sem {
97 Some(sem.clone().acquire_owned().await?)
98 } else {
99 None
100 };
101 let inflight_bytes_permit = if let Some(sem) = &self.inflight_bytes_sem {
102 if let Some(bytes_fn) = bytes_fn {
103 sem.acquire(bytes_fn(), false).await?
104 } else {
105 Some(sem.acquire_reservation().await?)
106 }
107 } else {
108 None
109 };
110 Ok(ConcurrencyControllerPermit {
111 _inflight_count_permit: inflight_count_permit,
112 _inflight_bytes_permit: inflight_bytes_permit,
113 })
114 }
115
116 pub async fn acquire_bytes_with_reservation(
117 &self,
118 bytes_fn: impl FnOnce() -> usize,
119 ) -> Result<Option<OwnedSemaphorePermit>, AcquireError> {
120 if let Some(sem) = &self.inflight_bytes_sem {
121 sem.acquire(bytes_fn(), true).await
122 } else {
123 Ok(None)
124 }
125 }
126}
127
/// Permits returned by `CombinedConcurrencyController::acquire`: one set from the local
/// controller and one from the global controller.
///
/// The fields are never read. They are released when this value is dropped.
pub struct CombinedConcurrencyControllerPermit {
    // Declared first, so local permits are released before the global ones.
    _permit: ConcurrencyControllerPermit,
    _global_permit: ConcurrencyControllerPermit,
}
132
133pub struct CombinedConcurrencyController {
134 controller: ConcurrencyController,
135 global_controller: Arc<ConcurrencyController>,
136 needs_num_bytes: bool,
137}
138
139impl CombinedConcurrencyController {
140 pub fn new(exec_options: &Options, global_controller: Arc<ConcurrencyController>) -> Self {
141 Self {
142 controller: ConcurrencyController::new(exec_options),
143 needs_num_bytes: exec_options.max_inflight_bytes.is_some()
144 || global_controller.inflight_bytes_sem.is_some(),
145 global_controller,
146 }
147 }
148
149 pub async fn acquire(
150 &self,
151 bytes_fn: Option<impl FnOnce() -> usize>,
152 ) -> Result<CombinedConcurrencyControllerPermit, AcquireError> {
153 let num_bytes_fn = if let Some(bytes_fn) = bytes_fn
154 && self.needs_num_bytes
155 {
156 let num_bytes = bytes_fn();
157 Some(move || num_bytes)
158 } else {
159 None
160 };
161
162 let permit = self.controller.acquire(num_bytes_fn).await?;
163 let global_permit = self.global_controller.acquire(num_bytes_fn).await?;
164 Ok(CombinedConcurrencyControllerPermit {
165 _permit: permit,
166 _global_permit: global_permit,
167 })
168 }
169
170 pub async fn acquire_bytes_with_reservation(
171 &self,
172 bytes_fn: impl FnOnce() -> usize,
173 ) -> Result<(Option<OwnedSemaphorePermit>, Option<OwnedSemaphorePermit>), AcquireError> {
174 let num_bytes = bytes_fn();
175 let permit = self
176 .controller
177 .acquire_bytes_with_reservation(move || num_bytes)
178 .await?;
179 let global_permit = self
180 .global_controller
181 .acquire_bytes_with_reservation(move || num_bytes)
182 .await?;
183 Ok((permit, global_permit))
184 }
185}