crux_core/capability/shell_stream.rs

1use std::{
2    sync::{Arc, Mutex},
3    task::{Poll, Waker},
4};
5
6use futures::Stream;
7
8use super::{channel, channel::Receiver};
9use crate::core::Request;
10
11pub struct ShellStream<T> {
12    shared_state: Arc<Mutex<SharedState<T>>>,
13}
14
15struct SharedState<T> {
16    receiver: Receiver<T>,
17    waker: Option<Waker>,
18    send_request: Option<Box<dyn FnOnce() + Send + 'static>>,
19}
20
21impl<T> Stream for ShellStream<T> {
22    type Item = T;
23
24    fn poll_next(
25        self: std::pin::Pin<&mut Self>,
26        cx: &mut std::task::Context<'_>,
27    ) -> Poll<Option<Self::Item>> {
28        let mut shared_state = self.shared_state.lock().unwrap();
29
30        if let Some(send_request) = shared_state.send_request.take() {
31            send_request();
32        }
33
34        match shared_state.receiver.try_receive() {
35            Ok(Some(next)) => Poll::Ready(Some(next)),
36            Ok(None) => {
37                shared_state.waker = Some(cx.waker().clone());
38                Poll::Pending
39            }
40            Err(()) => Poll::Ready(None),
41        }
42    }
43}
44
45#[expect(deprecated)]
46impl<Op, Ev> crate::capability::CapabilityContext<Op, Ev>
47where
48    Op: crate::capability::Operation,
49    Ev: 'static,
50{
51    /// Send an effect request to the shell, expecting a stream of responses
52    ///
53    /// # Panics
54    ///
55    /// Panics if we can't acquire the shared state lock.
56    pub fn stream_from_shell(&self, operation: Op) -> ShellStream<Op::Output> {
57        let (sender, receiver) = channel();
58        let shared_state = Arc::new(Mutex::new(SharedState {
59            receiver,
60            waker: None,
61            send_request: None,
62        }));
63
64        // Our callback holds a weak pointer so the channel can be freed
65        // whenever the associated task ends.
66        let callback_shared_state = Arc::downgrade(&shared_state);
67
68        let request = Request::resolves_many_times(operation, move |result| {
69            let Some(shared_state) = callback_shared_state.upgrade() else {
70                // Let the caller know that the associated task has finished.
71                return Err(());
72            };
73
74            let mut shared_state = shared_state.lock().unwrap();
75
76            sender.send(result);
77            if let Some(waker) = shared_state.waker.take() {
78                waker.wake();
79            }
80
81            Ok(())
82        });
83
84        // Put a callback into our shared_state so that we only send
85        // our request to the shell when the stream is first polled.
86        let send_req_context = self.clone();
87        let send_request = move || send_req_context.send_request(request);
88        shared_state.lock().unwrap().send_request = Some(Box::new(send_request));
89
90        ShellStream { shared_state }
91    }
92}
93
#[cfg(test)]
#[expect(deprecated)]
mod tests {
    use assert_matches::assert_matches;

    use crate::capability::{CapabilityContext, Operation, channel, executor_and_spawner};

    /// Test operation: `None` marks an intermediate response, `Some(Done)` the final one.
    #[derive(Clone, PartialEq, Eq, Debug)]
    struct TestOperation;

    impl Operation for TestOperation {
        type Output = Option<Done>;
    }

    /// Marker carried by the last response of the stream.
    #[derive(serde::Deserialize, PartialEq, Eq, Debug)]
    struct Done;

    #[test]
    fn test_shell_stream() {
        use futures::StreamExt;

        let (request_tx, request_rx) = channel();
        let (event_tx, event_rx) = channel::<()>();
        let (executor, spawner) = executor_and_spawner();
        let context = CapabilityContext::new(request_tx, event_tx.clone(), spawner.clone());

        // Checks that neither a request nor an event is queued.
        let assert_idle = || {
            assert_matches!(request_rx.receive(), None);
            assert_matches!(event_rx.receive(), None);
        };

        let mut stream = context.stream_from_shell(TestOperation);

        // Creating the stream must not send a request: it is deferred until first poll.
        assert_idle();

        // Nor should creating it have spawned any task.
        executor.run_all();
        assert_idle();

        spawner.spawn(async move {
            while let Some(output) = stream.next().await {
                event_tx.send(());
                if output.is_some() {
                    break;
                }
            }
        });

        // Spawning alone does not poll the stream.
        assert_idle();

        // Running the task polls the stream, which sends exactly one request.
        executor.run_all();
        let mut request = request_rx
            .receive()
            .expect("we should have a request here");
        assert_idle();

        // A single response produces a single event.
        request.resolve(None).unwrap();
        executor.run_all();
        assert_matches!(request_rx.receive(), None);
        assert_matches!(event_rx.receive(), Some(()));
        assert_matches!(event_rx.receive(), None);

        // Two more intermediate responses plus the final one produce three events.
        for output in [None, None, Some(Done)] {
            request.resolve(output).unwrap();
        }
        executor.run_all();
        assert_matches!(request_rx.receive(), None);
        for _ in 0..3 {
            assert_matches!(event_rx.receive(), Some(()));
        }
        assert_matches!(event_rx.receive(), None);

        // The task has ended and dropped the stream, so further resolves are rejected.
        request
            .resolve(None)
            .expect_err("resolving a finished task should error");
    }
}