1use std::collections::HashSet;
2use std::sync::Arc;
3use std::time::Duration;
4
5use tokio::sync::mpsc;
6
7use crate::PluginError;
8
/// Period of the runner's background poll for queued work. A poll tick issues an
/// unscoped drive, which is skipped while any session is inflight.
const QUEUED_WORK_POLL_INTERVAL: Duration = Duration::from_millis(400);
14
/// One request from the runner asking its [`QueuedWorkRunHandle`] to look for queued work.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct QueuedWorkRunRequest {
    /// Session the request is scoped to. `None` marks an unscoped request
    /// (startup, periodic poll, or a global poke).
    pub session_id: Option<String>,
    /// Short tag recording why the drive happened (e.g. `"startup"`, `"poll"`).
    pub reason: String,
    /// Set by the runner for explicit pokes and cleared for startup, poll, and
    /// completion re-drives. NOTE(review): presumably asks the handle to trace an
    /// idle outcome — confirm against the handle implementation.
    pub trace_idle: bool,
}
21
22impl QueuedWorkRunRequest {
23 fn new(session_id: Option<String>, reason: impl Into<String>, trace_idle: bool) -> Self {
24 Self {
25 session_id,
26 reason: reason.into(),
27 trace_idle,
28 }
29 }
30}
31
/// What a single [`QueuedWorkRunHandle::run_queued_work`] call reported.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum QueuedWorkRunOutcome {
    /// Work was submitted for `session_id`; the runner keeps that session inflight
    /// until it is told the session completed.
    Submitted { session_id: String },
    /// No work was submitted.
    Idle,
}
37
38#[async_trait::async_trait]
39pub trait QueuedWorkRunHandle: Send + Sync {
40 async fn run_queued_work(
41 &self,
42 request: QueuedWorkRunRequest,
43 ) -> Result<QueuedWorkRunOutcome, PluginError>;
44}
45
/// Messages that [`QueuedWorkPoke`] handles send to the runner loop.
enum QueuedWorkRunnerCommand {
    /// Drive queued work, either for one session or unscoped (`session_id == None`).
    Poke {
        reason: String,
        session_id: Option<String>,
    },
    /// The session's submitted work finished: clear it from the inflight set and
    /// drive that session again.
    Complete {
        reason: String,
        session_id: String,
    },
}
56
57pub struct QueuedWorkRunner {
58 run_handle: Arc<dyn QueuedWorkRunHandle>,
59 tx: mpsc::UnboundedSender<QueuedWorkRunnerCommand>,
60 rx: mpsc::UnboundedReceiver<QueuedWorkRunnerCommand>,
61}
62
63impl QueuedWorkRunner {
64 pub fn new(run_handle: Arc<dyn QueuedWorkRunHandle>) -> Self {
65 let (tx, rx) = mpsc::unbounded_channel();
66 Self { run_handle, tx, rx }
67 }
68
69 pub fn poke_handle(&self) -> QueuedWorkPoke {
70 QueuedWorkPoke {
71 tx: self.tx.clone(),
72 }
73 }
74
75 pub fn spawn(self) -> QueuedWorkPoke {
76 let poke = self.poke_handle();
77 tokio::spawn(async move {
78 self.run().await;
79 });
80 poke
81 }
82
83 async fn run(mut self) {
84 let mut inflight = HashSet::new();
85 self.drive(
86 QueuedWorkRunRequest::new(None, "startup", false),
87 &mut inflight,
88 )
89 .await;
90 let mut poll = tokio::time::interval(QUEUED_WORK_POLL_INTERVAL);
91 poll.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
92 loop {
93 tokio::select! {
94 command = self.rx.recv() => {
95 let Some(command) = command else {
96 break;
97 };
98 match command {
99 QueuedWorkRunnerCommand::Poke { session_id, reason } => {
100 self.drive(
101 QueuedWorkRunRequest::new(session_id, reason, true),
102 &mut inflight,
103 )
104 .await;
105 }
106 QueuedWorkRunnerCommand::Complete { session_id, reason } => {
107 inflight.remove(&session_id);
108 self.drive(
109 QueuedWorkRunRequest::new(Some(session_id), reason, false),
110 &mut inflight,
111 )
112 .await;
113 }
114 }
115 }
116 _ = poll.tick() => {
117 self.drive(
118 QueuedWorkRunRequest::new(None, "poll", false),
119 &mut inflight,
120 )
121 .await;
122 }
123 }
124 }
125 }
126
127 async fn drive(&self, request: QueuedWorkRunRequest, inflight: &mut HashSet<String>) {
128 if let Some(session_id) = request.session_id.as_deref()
129 && inflight.contains(session_id)
130 {
131 return;
132 }
133 if request.session_id.is_none() && !inflight.is_empty() {
134 return;
135 }
136 match self.run_handle.run_queued_work(request).await {
137 Ok(QueuedWorkRunOutcome::Submitted { session_id }) => {
138 inflight.insert(session_id);
139 }
140 Ok(QueuedWorkRunOutcome::Idle) => {}
141 Err(err) => tracing::warn!("queued work runner drive failed: {err}"),
142 }
143 }
144}
145
146#[derive(Clone)]
147pub struct QueuedWorkPoke {
148 tx: mpsc::UnboundedSender<QueuedWorkRunnerCommand>,
149}
150
151impl QueuedWorkPoke {
152 pub fn poke(&self, reason: impl Into<String>) {
153 let _ = self.tx.send(QueuedWorkRunnerCommand::Poke {
154 session_id: None,
155 reason: reason.into(),
156 });
157 }
158
159 pub fn poke_session(&self, session_id: impl Into<String>, reason: impl Into<String>) {
160 let _ = self.tx.send(QueuedWorkRunnerCommand::Poke {
161 session_id: Some(session_id.into()),
162 reason: reason.into(),
163 });
164 }
165
166 pub fn complete_session(&self, session_id: impl Into<String>, reason: impl Into<String>) {
167 let _ = self.tx.send(QueuedWorkRunnerCommand::Complete {
168 session_id: session_id.into(),
169 reason: reason.into(),
170 });
171 }
172}
173
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::{HashSet, VecDeque};
    use std::sync::Mutex;

    /// Test double that records every request it receives and replays scripted
    /// outcomes in order, answering `Idle` once the script is exhausted.
    struct RecordingQueuedWorkRunHandle {
        requests: mpsc::UnboundedSender<QueuedWorkRunRequest>,
        responses: Mutex<VecDeque<QueuedWorkRunOutcome>>,
    }

    #[async_trait::async_trait]
    impl QueuedWorkRunHandle for RecordingQueuedWorkRunHandle {
        async fn run_queued_work(
            &self,
            request: QueuedWorkRunRequest,
        ) -> Result<QueuedWorkRunOutcome, PluginError> {
            self.requests
                .send(request)
                .expect("record queued work request");
            Ok(self
                .responses
                .lock()
                .expect("lock responses")
                .pop_front()
                .unwrap_or(QueuedWorkRunOutcome::Idle))
        }
    }

    /// Builds a recording handle scripted with `responses`, plus the receiver that
    /// observes each request the handle is sent.
    fn recording_handle(
        responses: impl IntoIterator<Item = QueuedWorkRunOutcome>,
    ) -> (
        Arc<RecordingQueuedWorkRunHandle>,
        mpsc::UnboundedReceiver<QueuedWorkRunRequest>,
    ) {
        let (requests, request_rx) = mpsc::unbounded_channel();
        (
            Arc::new(RecordingQueuedWorkRunHandle {
                requests,
                responses: Mutex::new(responses.into_iter().collect()),
            }),
            request_rx,
        )
    }

    #[tokio::test]
    async fn drive_holds_submitted_session_inflight_until_completion() {
        let (handle, mut requests) = recording_handle([
            QueuedWorkRunOutcome::Submitted {
                session_id: "root".to_string(),
            },
            QueuedWorkRunOutcome::Idle,
        ]);
        let runner = QueuedWorkRunner::new(handle);
        let mut inflight = HashSet::new();

        runner
            .drive(
                QueuedWorkRunRequest::new(Some("root".to_string()), "process_wake", true),
                &mut inflight,
            )
            .await;

        let first = requests.try_recv().expect("first queued work request");
        assert_eq!(first.session_id.as_deref(), Some("root"));
        assert_eq!(first.reason, "process_wake");
        assert!(first.trace_idle);
        assert!(inflight.contains("root"));

        runner
            .drive(
                QueuedWorkRunRequest::new(Some("root".to_string()), "duplicate", true),
                &mut inflight,
            )
            .await;
        assert!(
            requests.try_recv().is_err(),
            "inflight session should suppress duplicate submission"
        );

        runner
            .drive(
                QueuedWorkRunRequest::new(None, "poll", false),
                &mut inflight,
            )
            .await;
        assert!(
            requests.try_recv().is_err(),
            "global poll should not submit while work is inflight"
        );

        inflight.remove("root");
        runner
            .drive(
                QueuedWorkRunRequest::new(Some("root".to_string()), "queued_turn_completed", false),
                &mut inflight,
            )
            .await;

        let resumed = requests
            .try_recv()
            .expect("completion should re-drive the session");
        assert_eq!(resumed.session_id.as_deref(), Some("root"));
        assert_eq!(resumed.reason, "queued_turn_completed");
        assert!(!resumed.trace_idle);
    }

    #[tokio::test]
    async fn drive_allows_other_sessions_while_one_is_inflight() {
        let (handle, mut requests) = recording_handle([QueuedWorkRunOutcome::Submitted {
            session_id: "other".to_string(),
        }]);
        let runner = QueuedWorkRunner::new(handle);
        let mut inflight = HashSet::from(["root".to_string()]);

        runner
            .drive(
                QueuedWorkRunRequest::new(Some("other".to_string()), "process_wake", true),
                &mut inflight,
            )
            .await;

        let request = requests
            .try_recv()
            .expect("a different session should still be driven");
        assert_eq!(request.session_id.as_deref(), Some("other"));
        assert!(inflight.contains("root"));
        assert!(inflight.contains("other"));
    }

    #[tokio::test]
    async fn idle_outcome_does_not_mark_anything_inflight() {
        let (handle, mut requests) = recording_handle(Vec::new());
        let runner = QueuedWorkRunner::new(handle);
        let mut inflight = HashSet::new();

        // Two polls in a row both reach the handle because `Idle` leaves the set empty.
        for _ in 0..2 {
            runner
                .drive(
                    QueuedWorkRunRequest::new(None, "poll", false),
                    &mut inflight,
                )
                .await;
            let request = requests
                .try_recv()
                .expect("global poll should run while nothing is inflight");
            assert_eq!(request.session_id, None);
            assert_eq!(request.reason, "poll");
            assert!(inflight.is_empty());
        }
    }
}