1use crate::compaction::{self, CompactionPreparation, CompactionResult};
7use crate::error::{Error, Result};
8use crate::provider::Provider;
9use asupersync::runtime::{JoinHandle, RuntimeHandle};
10use futures::FutureExt;
11use futures::channel::oneshot;
12use std::sync::Arc;
13use std::time::{Duration, Instant};
14
/// Limits applied to background compaction by a `CompactionWorkerState`.
#[derive(Debug, Clone)]
pub struct CompactionQuota {
    /// Minimum time that must elapse between two compaction starts
    /// (enforced by `CompactionWorkerState::can_start`).
    pub cooldown: Duration,
    /// How long a started compaction may keep running before
    /// `CompactionWorkerState::try_recv` aborts it and reports a timeout error.
    pub timeout: Duration,
    /// Maximum number of compactions that may be started over the lifetime of
    /// one `CompactionWorkerState` (presumably one per session — confirm with callers).
    pub max_attempts_per_session: u32,
}
25
26impl Default for CompactionQuota {
27 fn default() -> Self {
28 Self {
29 cooldown: Duration::from_secs(60),
30 timeout: Duration::from_secs(120),
31 max_attempts_per_session: 100,
32 }
33 }
34}
35
/// Value produced by a background compaction task: the compaction result, or a
/// session error (failure, abort, timeout, or panic inside the task).
type CompactionOutcome = Result<CompactionResult>;
37
/// A compaction task that has been spawned but whose result has not been collected yet.
struct PendingCompaction {
    /// Handle used to poll for completion and to retrieve the task's outcome.
    join: JoinHandle<CompactionOutcome>,
    /// One-shot abort signal for the task. `None` once the signal has been sent,
    /// or when the task was created without an abort channel.
    abort_tx: Option<oneshot::Sender<()>>,
    /// When the task was spawned; compared against `CompactionQuota::timeout`.
    started_at: Instant,
}
43
44impl PendingCompaction {
45 fn is_finished(&self) -> bool {
46 self.join.is_finished()
47 }
48
49 fn abort(&mut self) {
50 if let Some(abort_tx) = self.abort_tx.take() {
51 if abort_tx.send(()).is_err() {
52 tracing::debug!("abort signal receiver was already dropped");
53 }
54 }
55 }
56}
57
/// Tracks at most one background compaction and enforces a [`CompactionQuota`]
/// on how often and how many compactions may be started.
pub(crate) struct CompactionWorkerState {
    /// The in-flight compaction, if any.
    pending: Option<PendingCompaction>,
    /// When the most recent compaction was started; drives the cooldown check.
    last_start: Option<Instant>,
    /// Number of compactions started so far; compared against
    /// `quota.max_attempts_per_session`.
    attempt_count: u32,
    /// Limits consulted by `can_start` and `try_recv`.
    quota: CompactionQuota,
}
65
66impl CompactionWorkerState {
67 pub const fn new(quota: CompactionQuota) -> Self {
68 Self {
69 pending: None,
70 last_start: None,
71 attempt_count: 0,
72 quota,
73 }
74 }
75
76 pub fn can_start(&self) -> bool {
78 if self.pending.is_some() {
79 return false;
80 }
81 if self.attempt_count >= self.quota.max_attempts_per_session {
82 return false;
83 }
84 if let Some(last) = self.last_start {
85 if last.elapsed() < self.quota.cooldown {
86 return false;
87 }
88 }
89 true
90 }
91
92 pub async fn try_recv(&mut self) -> Option<CompactionOutcome> {
94 let timed_out = self
96 .pending
97 .as_ref()
98 .is_some_and(|p| p.started_at.elapsed() > self.quota.timeout);
99
100 if timed_out {
101 if let Some(mut pending) = self.pending.take() {
102 pending.abort();
103 }
104 return Some(Err(Error::session(
105 "Background compaction timed out".to_string(),
106 )));
107 }
108
109 if !self
110 .pending
111 .as_ref()
112 .is_some_and(PendingCompaction::is_finished)
113 {
114 return None;
115 }
116
117 let pending = self.pending.take()?;
118 Some(pending.join.await)
119 }
120
121 pub fn start(
123 &mut self,
124 runtime_handle: &RuntimeHandle,
125 preparation: CompactionPreparation,
126 provider: Arc<dyn Provider>,
127 api_key: String,
128 custom_instructions: Option<String>,
129 ) {
130 debug_assert!(
131 self.can_start(),
132 "start() called while can_start() is false"
133 );
134
135 let (abort_tx, abort_rx) = oneshot::channel();
136 let now = Instant::now();
137 let join = runtime_handle.spawn(async move {
138 run_compaction_task(
139 preparation,
140 provider,
141 api_key,
142 custom_instructions,
143 abort_rx,
144 )
145 .await
146 });
147
148 self.pending = Some(PendingCompaction {
149 join,
150 abort_tx: Some(abort_tx),
151 started_at: now,
152 });
153 self.last_start = Some(now);
154 self.attempt_count = self.attempt_count.saturating_add(1);
155 }
156}
157
158impl Drop for CompactionWorkerState {
159 fn drop(&mut self) {
160 if let Some(mut pending) = self.pending.take() {
161 pending.abort();
162 }
163 }
164}
165
166#[allow(clippy::needless_pass_by_value)]
167async fn run_compaction_task(
168 preparation: CompactionPreparation,
169 provider: Arc<dyn Provider>,
170 api_key: String,
171 custom_instructions: Option<String>,
172 abort_rx: oneshot::Receiver<()>,
173) -> CompactionOutcome {
174 let abort_fut = async move {
175 if abort_rx.await.is_err() {
176 tracing::debug!("abort signal sender was dropped before sending abort");
177 }
178 Err(Error::session("Background compaction aborted".to_string()))
179 }
180 .fuse();
181 let compaction_fut = std::panic::AssertUnwindSafe(compaction::compact(
182 preparation,
183 provider,
184 &api_key,
185 custom_instructions.as_deref(),
186 ))
187 .catch_unwind()
188 .fuse();
189
190 futures::pin_mut!(abort_fut, compaction_fut);
191
192 match futures::future::select(abort_fut, compaction_fut).await {
193 futures::future::Either::Left((abort_result, _)) => abort_result,
194 futures::future::Either::Right((Ok(result), _)) => result,
195 futures::future::Either::Right((Err(_), _)) => Err(Error::session(
196 "Background compaction worker panicked".to_string(),
197 )),
198 }
199}
200
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, Ordering};

    /// Builds a worker governed by `quota`.
    fn make_worker(quota: CompactionQuota) -> CompactionWorkerState {
        CompactionWorkerState::new(quota)
    }

    /// Builds a worker governed by `CompactionQuota::default()`.
    fn default_worker() -> CompactionWorkerState {
        make_worker(CompactionQuota::default())
    }

    /// Runs the future produced by `make_future` to completion on a fresh
    /// current-thread asupersync runtime, handing it the runtime handle so it
    /// can spawn tasks.
    fn run_async<T, F>(make_future: impl FnOnce(RuntimeHandle) -> F) -> T
    where
        F: std::future::Future<Output = T>,
    {
        let runtime = asupersync::runtime::RuntimeBuilder::current_thread()
            .build()
            .expect("build test runtime");
        let runtime_handle = runtime.handle();
        runtime.block_on(make_future(runtime_handle))
    }

    /// Installs `pending` directly into the worker, mirroring the bookkeeping
    /// `start()` does (last start time, attempt count) without spawning a real
    /// compaction.
    fn inject_pending(worker: &mut CompactionWorkerState, pending: PendingCompaction) {
        worker.pending = Some(pending);
        worker.last_start = Some(Instant::now());
        worker.attempt_count += 1;
    }

    /// Spawns a task that immediately yields `outcome`; it has no abort channel.
    async fn ready_pending_with_handle(
        runtime_handle: RuntimeHandle,
        outcome: CompactionOutcome,
    ) -> PendingCompaction {
        let join = runtime_handle.spawn(async move { outcome });
        PendingCompaction {
            join,
            abort_tx: None,
            started_at: Instant::now(),
        }
    }

    /// Spawns a task that stays parked until its abort receiver resolves
    /// (signal sent or sender dropped). It then sets `aborted` (when provided)
    /// and finishes with the "aborted" session error.
    async fn parked_pending_with_handle(
        runtime_handle: RuntimeHandle,
        aborted: Option<Arc<AtomicBool>>,
    ) -> PendingCompaction {
        let (abort_tx, abort_rx) = oneshot::channel();
        let join = runtime_handle.spawn(async move {
            if abort_rx.await.is_err() {
                tracing::debug!("abort signal sender was dropped before sending abort");
            }
            if let Some(flag) = aborted {
                flag.store(true, Ordering::SeqCst);
            }
            Err(Error::session("Background compaction aborted".to_string()))
        });
        PendingCompaction {
            join,
            abort_tx: Some(abort_tx),
            started_at: Instant::now(),
        }
    }

    #[test]
    fn fresh_worker_can_start() {
        let w = default_worker();
        assert!(w.can_start());
    }

    #[test]
    fn cannot_start_while_pending() {
        run_async(|runtime_handle| async move {
            let mut w = default_worker();
            let pending = parked_pending_with_handle(runtime_handle, None).await;
            inject_pending(&mut w, pending);
            assert!(!w.can_start());
        });
    }

    #[test]
    fn cannot_start_during_cooldown() {
        // A one-hour cooldown that just started must block a new start.
        let mut w = make_worker(CompactionQuota {
            cooldown: Duration::from_secs(3600),
            ..CompactionQuota::default()
        });
        w.last_start = Some(Instant::now());
        w.attempt_count = 1;
        assert!(!w.can_start());
    }

    #[test]
    fn can_start_after_cooldown() {
        // Backdate the last start by one second (falling back to "now" if the
        // clock cannot go that far back); a zero cooldown is then satisfied.
        let mut w = make_worker(CompactionQuota {
            cooldown: Duration::from_millis(0),
            ..CompactionQuota::default()
        });
        w.last_start = Some(
            Instant::now()
                .checked_sub(Duration::from_secs(1))
                .unwrap_or_else(Instant::now),
        );
        w.attempt_count = 1;
        assert!(w.can_start());
    }

    #[test]
    fn max_attempts_blocks_start() {
        let mut w = make_worker(CompactionQuota {
            max_attempts_per_session: 2,
            cooldown: Duration::from_millis(0),
            ..CompactionQuota::default()
        });
        w.attempt_count = 2;
        assert!(!w.can_start());
    }

    #[test]
    fn try_recv_none_when_no_pending() {
        run_async(|_runtime_handle| async move {
            let mut w = default_worker();
            assert!(w.try_recv().await.is_none());
        });
    }

    #[test]
    fn try_recv_none_when_not_ready() {
        run_async(|runtime_handle| async move {
            let mut w = default_worker();
            let pending = parked_pending_with_handle(runtime_handle, None).await;
            inject_pending(&mut w, pending);
            assert!(w.try_recv().await.is_none());
            // Polling an unfinished task must leave it in place.
            assert!(w.pending.is_some());
        });
    }

    #[test]
    fn dropping_worker_aborts_pending_task() {
        run_async(|runtime_handle| async move {
            let aborted = Arc::new(AtomicBool::new(false));
            let mut w = default_worker();
            let pending =
                parked_pending_with_handle(runtime_handle, Some(Arc::clone(&aborted))).await;
            inject_pending(&mut w, pending);

            drop(w);
            // Give the spawned task a chance to observe the abort and set the flag.
            asupersync::time::sleep(
                asupersync::time::wall_now(),
                std::time::Duration::from_millis(50),
            )
            .await;

            assert!(
                aborted.load(Ordering::SeqCst),
                "dropping the worker should abort the pending task"
            );
        });
    }

    #[test]
    fn try_recv_timeout() {
        run_async(|runtime_handle| async move {
            let aborted = Arc::new(AtomicBool::new(false));
            // Zero timeout plus a backdated start guarantees the deadline has passed.
            let mut w = make_worker(CompactionQuota {
                timeout: Duration::from_millis(0),
                ..CompactionQuota::default()
            });
            let mut pending =
                parked_pending_with_handle(runtime_handle, Some(Arc::clone(&aborted))).await;
            pending.started_at = Instant::now()
                .checked_sub(Duration::from_secs(1))
                .unwrap_or_else(Instant::now);
            inject_pending(&mut w, pending);

            let outcome = w.try_recv().await.expect("should return timeout error");
            assert!(outcome.is_err());
            let err_msg = outcome.unwrap_err().to_string();
            assert!(err_msg.contains("timed out"), "got: {err_msg}");

            // Give the spawned task a chance to observe the abort and set the flag.
            asupersync::time::sleep(
                asupersync::time::wall_now(),
                std::time::Duration::from_millis(50),
            )
            .await;
            assert!(
                aborted.load(Ordering::SeqCst),
                "timing out the worker should abort the pending task"
            );
        });
    }

    #[test]
    fn try_recv_success() {
        run_async(|runtime_handle| async move {
            let mut w = default_worker();

            let result = CompactionResult {
                // Minimal fixture; only `summary` is asserted below.
                summary: "test summary".to_string(),
                first_kept_entry_id: "entry-1".to_string(),
                tokens_before: 1000,
                details: compaction::CompactionDetails {
                    read_files: vec![],
                    modified_files: vec![],
                },
            };
            let pending = ready_pending_with_handle(runtime_handle, Ok(result)).await;
            inject_pending(&mut w, pending);
            // Let the spawned task run so `is_finished()` reports true.
            asupersync::time::sleep(
                asupersync::time::wall_now(),
                std::time::Duration::from_millis(50),
            )
            .await;

            let outcome = w.try_recv().await.expect("should have result");
            let result = outcome.expect("should be Ok");
            assert_eq!(result.summary, "test summary");
            assert!(w.pending.is_none());
        });
    }
}