Skip to main content

lash_core/runtime/
queued_work_runner.rs

1use std::collections::HashSet;
2use std::sync::Arc;
3use std::time::Duration;
4
5use tokio::sync::mpsc;
6
7use crate::PluginError;
8
/// Fallback polling period for the queued-work runner.
///
/// Explicit pokes are how new work is normally noticed. This timer exists only
/// to catch work that was already queued when the runner started, or whose
/// wake-up poke never arrived.
const QUEUED_WORK_POLL_INTERVAL: Duration = Duration::from_millis(400);
14
/// One request from a queued-work runner to its run handle.
#[derive(Clone, Debug)]
pub struct QueuedWorkRunRequest {
    /// Session to target, or `None` for requests not tied to one session
    /// (the runner's startup and periodic poll drives, and global pokes).
    pub session_id: Option<String>,
    /// Short tag describing why the request was issued (for example
    /// `"startup"` or `"poll"`, or a caller-supplied poke/completion reason).
    pub reason: String,
    /// `true` for requests triggered by an explicit poke, and `false` for
    /// startup, poll and completion re-drives. Interpretation is left to the
    /// handle (presumably whether an idle outcome should be traced).
    pub trace_idle: bool,
}

impl QueuedWorkRunRequest {
    /// Builds a request, converting `reason` into an owned `String`.
    fn new(session_id: Option<String>, reason: impl Into<String>, trace_idle: bool) -> Self {
        let reason: String = reason.into();
        Self {
            session_id,
            reason,
            trace_idle,
        }
    }
}
31
/// What a single `run_queued_work` call accomplished.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum QueuedWorkRunOutcome {
    /// Work was submitted for `session_id`. The runner treats that session as
    /// inflight until completion is reported for it.
    Submitted { session_id: String },
    /// Nothing was submitted (presumably no ready work). The runner leaves its
    /// inflight bookkeeping unchanged.
    Idle,
}
37
38#[async_trait::async_trait]
39pub trait QueuedWorkRunHandle: Send + Sync {
40    async fn run_queued_work(
41        &self,
42        request: QueuedWorkRunRequest,
43    ) -> Result<QueuedWorkRunOutcome, PluginError>;
44}
45
/// Messages sent from poke handles to the runner's event loop.
enum QueuedWorkRunnerCommand {
    /// Ask the runner to drive a request, either for one session or, when
    /// `session_id` is `None`, without a session.
    Poke {
        session_id: Option<String>,
        reason: String,
    },
    /// Report that work submitted for `session_id` finished. The runner clears
    /// the session's inflight mark and drives the session again.
    Complete {
        session_id: String,
        reason: String,
    },
}
56
57pub struct QueuedWorkRunner {
58    run_handle: Arc<dyn QueuedWorkRunHandle>,
59    tx: mpsc::UnboundedSender<QueuedWorkRunnerCommand>,
60    rx: mpsc::UnboundedReceiver<QueuedWorkRunnerCommand>,
61}
62
63impl QueuedWorkRunner {
64    pub fn new(run_handle: Arc<dyn QueuedWorkRunHandle>) -> Self {
65        let (tx, rx) = mpsc::unbounded_channel();
66        Self { run_handle, tx, rx }
67    }
68
69    pub fn poke_handle(&self) -> QueuedWorkPoke {
70        QueuedWorkPoke {
71            tx: self.tx.clone(),
72        }
73    }
74
75    pub fn spawn(self) -> QueuedWorkPoke {
76        let poke = self.poke_handle();
77        tokio::spawn(async move {
78            self.run().await;
79        });
80        poke
81    }
82
83    async fn run(mut self) {
84        let mut inflight = HashSet::new();
85        self.drive(
86            QueuedWorkRunRequest::new(None, "startup", false),
87            &mut inflight,
88        )
89        .await;
90        let mut poll = tokio::time::interval(QUEUED_WORK_POLL_INTERVAL);
91        poll.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
92        loop {
93            tokio::select! {
94                command = self.rx.recv() => {
95                    let Some(command) = command else {
96                        break;
97                    };
98                    match command {
99                        QueuedWorkRunnerCommand::Poke { session_id, reason } => {
100                            self.drive(
101                                QueuedWorkRunRequest::new(session_id, reason, true),
102                                &mut inflight,
103                            )
104                            .await;
105                        }
106                        QueuedWorkRunnerCommand::Complete { session_id, reason } => {
107                            inflight.remove(&session_id);
108                            self.drive(
109                                QueuedWorkRunRequest::new(Some(session_id), reason, false),
110                                &mut inflight,
111                            )
112                            .await;
113                        }
114                    }
115                }
116                _ = poll.tick() => {
117                    self.drive(
118                        QueuedWorkRunRequest::new(None, "poll", false),
119                        &mut inflight,
120                    )
121                    .await;
122                }
123            }
124        }
125    }
126
127    async fn drive(&self, request: QueuedWorkRunRequest, inflight: &mut HashSet<String>) {
128        if let Some(session_id) = request.session_id.as_deref()
129            && inflight.contains(session_id)
130        {
131            return;
132        }
133        if request.session_id.is_none() && !inflight.is_empty() {
134            return;
135        }
136        match self.run_handle.run_queued_work(request).await {
137            Ok(QueuedWorkRunOutcome::Submitted { session_id }) => {
138                inflight.insert(session_id);
139            }
140            Ok(QueuedWorkRunOutcome::Idle) => {}
141            Err(err) => tracing::warn!("queued work runner drive failed: {err}"),
142        }
143    }
144}
145
146#[derive(Clone)]
147pub struct QueuedWorkPoke {
148    tx: mpsc::UnboundedSender<QueuedWorkRunnerCommand>,
149}
150
151impl QueuedWorkPoke {
152    pub fn poke(&self, reason: impl Into<String>) {
153        let _ = self.tx.send(QueuedWorkRunnerCommand::Poke {
154            session_id: None,
155            reason: reason.into(),
156        });
157    }
158
159    pub fn poke_session(&self, session_id: impl Into<String>, reason: impl Into<String>) {
160        let _ = self.tx.send(QueuedWorkRunnerCommand::Poke {
161            session_id: Some(session_id.into()),
162            reason: reason.into(),
163        });
164    }
165
166    pub fn complete_session(&self, session_id: impl Into<String>, reason: impl Into<String>) {
167        let _ = self.tx.send(QueuedWorkRunnerCommand::Complete {
168            session_id: session_id.into(),
169            reason: reason.into(),
170        });
171    }
172}
173
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::{HashSet, VecDeque};
    use std::sync::Mutex;

    /// Test double that records every request it receives.
    ///
    /// It answers with scripted outcomes and falls back to `Idle` once the
    /// script is exhausted.
    struct RecordingQueuedWorkRunHandle {
        requests: mpsc::UnboundedSender<QueuedWorkRunRequest>,
        responses: Mutex<VecDeque<QueuedWorkRunOutcome>>,
    }

    #[async_trait::async_trait]
    impl QueuedWorkRunHandle for RecordingQueuedWorkRunHandle {
        async fn run_queued_work(
            &self,
            request: QueuedWorkRunRequest,
        ) -> Result<QueuedWorkRunOutcome, PluginError> {
            self.requests
                .send(request)
                .expect("record queued work request");
            let scripted = self.responses.lock().expect("lock responses").pop_front();
            Ok(scripted.unwrap_or(QueuedWorkRunOutcome::Idle))
        }
    }

    /// Builds a recording handle preloaded with `responses`.
    ///
    /// Also returns the receiver on which recorded requests arrive.
    fn recording_handle(
        responses: impl IntoIterator<Item = QueuedWorkRunOutcome>,
    ) -> (
        Arc<RecordingQueuedWorkRunHandle>,
        mpsc::UnboundedReceiver<QueuedWorkRunRequest>,
    ) {
        let (requests, request_rx) = mpsc::unbounded_channel();
        let script: VecDeque<QueuedWorkRunOutcome> = responses.into_iter().collect();
        let handle = RecordingQueuedWorkRunHandle {
            requests,
            responses: Mutex::new(script),
        };
        (Arc::new(handle), request_rx)
    }

    /// Shorthand for building a request from borrowed strings.
    fn req(session: Option<&str>, reason: &str, trace_idle: bool) -> QueuedWorkRunRequest {
        QueuedWorkRunRequest::new(session.map(str::to_owned), reason, trace_idle)
    }

    #[tokio::test]
    async fn drive_holds_submitted_session_inflight_until_completion() {
        let root_submitted = QueuedWorkRunOutcome::Submitted {
            session_id: "root".to_string(),
        };
        let (handle, mut seen) = recording_handle([root_submitted, QueuedWorkRunOutcome::Idle]);
        let runner = QueuedWorkRunner::new(handle);
        let mut inflight = HashSet::new();

        // A wake for a session that is not inflight reaches the handle, and the
        // submission is tracked.
        runner
            .drive(req(Some("root"), "process_wake", true), &mut inflight)
            .await;
        let first = seen.try_recv().expect("first queued work request");
        assert_eq!(first.session_id.as_deref(), Some("root"));
        assert_eq!(first.reason, "process_wake");
        assert!(first.trace_idle);
        assert!(inflight.contains("root"));

        // While "root" is inflight, neither a repeat wake nor a global poll
        // reaches the handle.
        runner
            .drive(req(Some("root"), "duplicate", true), &mut inflight)
            .await;
        assert!(
            seen.try_recv().is_err(),
            "inflight session should suppress duplicate submission"
        );
        runner.drive(req(None, "poll", false), &mut inflight).await;
        assert!(
            seen.try_recv().is_err(),
            "global poll should not submit while work is inflight"
        );

        // Clearing the inflight mark, as completion does, lets the session be
        // driven again.
        inflight.remove("root");
        runner
            .drive(req(Some("root"), "queued_turn_completed", false), &mut inflight)
            .await;
        let resumed = seen
            .try_recv()
            .expect("completion should re-drive the session");
        assert_eq!(resumed.session_id.as_deref(), Some("root"));
        assert_eq!(resumed.reason, "queued_turn_completed");
        assert!(!resumed.trace_idle);
    }
}