crux_core/capability/
shell_stream.rs1use std::{
2 sync::{Arc, Mutex},
3 task::{Poll, Waker},
4};
5
6use futures::Stream;
7
8use super::{channel, channel::Receiver};
9use crate::core::Request;
10
11pub struct ShellStream<T> {
12 shared_state: Arc<Mutex<SharedState<T>>>,
13}
14
15struct SharedState<T> {
16 receiver: Receiver<T>,
17 waker: Option<Waker>,
18 send_request: Option<Box<dyn FnOnce() + Send + 'static>>,
19}
20
21impl<T> Stream for ShellStream<T> {
22 type Item = T;
23
24 fn poll_next(
25 self: std::pin::Pin<&mut Self>,
26 cx: &mut std::task::Context<'_>,
27 ) -> Poll<Option<Self::Item>> {
28 let mut shared_state = self.shared_state.lock().unwrap();
29
30 if let Some(send_request) = shared_state.send_request.take() {
31 send_request();
32 }
33
34 match shared_state.receiver.try_receive() {
35 Ok(Some(next)) => Poll::Ready(Some(next)),
36 Ok(None) => {
37 shared_state.waker = Some(cx.waker().clone());
38 Poll::Pending
39 }
40 Err(()) => Poll::Ready(None),
41 }
42 }
43}
44
45#[expect(deprecated)]
46impl<Op, Ev> crate::capability::CapabilityContext<Op, Ev>
47where
48 Op: crate::capability::Operation,
49 Ev: 'static,
50{
51 pub fn stream_from_shell(&self, operation: Op) -> ShellStream<Op::Output> {
57 let (sender, receiver) = channel();
58 let shared_state = Arc::new(Mutex::new(SharedState {
59 receiver,
60 waker: None,
61 send_request: None,
62 }));
63
64 let callback_shared_state = Arc::downgrade(&shared_state);
67
68 let request = Request::resolves_many_times(operation, move |result| {
69 let Some(shared_state) = callback_shared_state.upgrade() else {
70 return Err(());
72 };
73
74 let mut shared_state = shared_state.lock().unwrap();
75
76 sender.send(result);
77 if let Some(waker) = shared_state.waker.take() {
78 waker.wake();
79 }
80
81 Ok(())
82 });
83
84 let send_req_context = self.clone();
87 let send_request = move || send_req_context.send_request(request);
88 shared_state.lock().unwrap().send_request = Some(Box::new(send_request));
89
90 ShellStream { shared_state }
91 }
92}
93
#[cfg(test)]
#[expect(deprecated)]
mod tests {
    use assert_matches::assert_matches;

    use crate::capability::{CapabilityContext, Operation, channel, executor_and_spawner};

    /// Operation whose output is `Some(Done)` when the consumer should stop.
    #[derive(Clone, PartialEq, Eq, Debug)]
    struct TestOperation;

    impl Operation for TestOperation {
        type Output = Option<Done>;
    }

    #[derive(serde::Deserialize, PartialEq, Eq, Debug)]
    struct Done;

    #[test]
    fn test_shell_stream() {
        let (request_sender, requests) = channel();
        let (event_sender, events) = channel::<()>();
        let (executor, spawner) = executor_and_spawner();
        let context =
            CapabilityContext::new(request_sender, event_sender.clone(), spawner.clone());

        // Asserts that neither a request nor an event is currently queued.
        let assert_quiet = || {
            assert_matches!(requests.receive(), None);
            assert_matches!(events.receive(), None);
        };

        let mut shell_stream = context.stream_from_shell(TestOperation);

        // Creating the stream must not send anything.
        assert_quiet();

        // Nothing has been spawned yet, so running the executor changes nothing.
        executor.run_all();
        assert_quiet();

        spawner.spawn(async move {
            use futures::StreamExt;
            while let Some(maybe_done) = shell_stream.next().await {
                event_sender.send(());
                if maybe_done.is_some() {
                    break;
                }
            }
        });

        // Spawning alone does not poll the stream.
        assert_quiet();

        // The first poll is what sends the request.
        executor.run_all();
        let mut shell_request = requests.receive().expect("we should have a request here");

        assert_quiet();

        shell_request.resolve(None).unwrap();

        executor.run_all();

        // One resolve produces exactly one event and no further request.
        assert_matches!(requests.receive(), None);
        assert_matches!(events.receive(), Some(()));
        assert_matches!(events.receive(), None);

        // Several resolves before the executor runs are all delivered.
        for output in [None, None, Some(Done)] {
            shell_request.resolve(output).unwrap();
        }
        executor.run_all();

        assert_matches!(requests.receive(), None);
        for _ in 0..3 {
            assert_matches!(events.receive(), Some(()));
        }
        assert_matches!(events.receive(), None);

        // The task broke out of its loop on `Some(Done)`, dropping the stream.
        shell_request
            .resolve(None)
            .expect_err("resolving a finished task should error");
    }
}