1use std::{
2 any::Any,
3 marker::PhantomData,
4 mem::ManuallyDrop,
5 pin::Pin,
6 ptr,
7 task::{self, Poll, RawWaker, RawWakerVTable, Waker},
8};
9
10use sealed::sealed;
11use unicycle::StreamsUnordered;
12
13use self::pinarcmutex::{PinArcMutex, PinArcMutexGuard};
14use crate::envelope::Envelope;
15
16pub(crate) trait SourceStream: Send + 'static {
17 fn as_any_mut(self: Pin<&mut Self>) -> Pin<&mut dyn Any>;
18 fn poll_recv(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Option<Envelope>>;
19}
20
21#[must_use = "sources do nothing unless you attach them"]
28pub struct UnattachedSource<H> {
29 source: UntypedSourceArc,
30 handle: H,
31}
32
33impl<H> UnattachedSource<H> {
34 pub(crate) fn new<S>(source: SourceArc<S>, handle: impl FnOnce(SourceArc<S>) -> H) -> Self
35 where
36 S: SourceStream + ?Sized,
37 {
38 Self {
39 source: source.inner.to_owner(),
40 handle: handle(source),
41 }
42 }
43
44 pub(crate) fn attach_to(self, sources: &mut Sources) -> H {
45 sources.push(self.source);
46 self.handle
47 }
48}
49
50#[sealed(pub(crate))]
54pub trait SourceHandle {
55 fn is_terminated(&self) -> bool;
57
58 fn terminate(self) -> bool
63 where
64 Self: Sized,
65 {
66 self.terminate_by_ref()
67 }
68
69 fn terminate_by_ref(&self) -> bool;
74}
75
76pub(crate) struct SourceArc<S: ?Sized> {
79 inner: UntypedSourceArc,
80 marker: PhantomData<S>,
81}
82
83impl<S: ?Sized> SourceArc<S> {
84 pub(crate) fn from_untyped(inner: UntypedSourceArc) -> Self {
88 let marker = PhantomData;
89 Self { inner, marker }
90 }
91}
92
93impl<S: SourceStream> SourceArc<S> {
94 pub(crate) fn new(source: S, oneshot: bool) -> Self {
95 Self::from_untyped(UntypedSourceArc::new(source, oneshot))
96 }
97}
98
99impl<S: ?Sized> SourceArc<S> {
100 pub(crate) fn lock(&self) -> Option<SourceStreamGuard<'_, S>> {
102 let inner = self.inner.inner.lock();
103
104 if inner.status() == StreamStatus::Terminated {
108 return None;
109 }
110
111 Some(SourceStreamGuard {
112 inner,
113 marker: PhantomData,
114 })
115 }
116
117 pub(crate) fn is_terminated(&self) -> bool {
118 self.inner.inner.lock().status() == StreamStatus::Terminated
119 }
120
121 pub(crate) fn terminate_by_ref(&self) -> bool {
122 if let Some(guard) = self.lock() {
123 guard.terminate();
124 true
125 } else {
126 false
127 }
128 }
129}
130
131pub(crate) struct SourceStreamGuard<'a, S: ?Sized> {
134 inner: PinArcMutexGuard<'a, StreamWithWaker<dyn SourceStream>>,
135 marker: PhantomData<S>,
136}
137
138impl<S: ?Sized> SourceStreamGuard<'_, S> {
139 pub(crate) fn terminate(mut self) {
140 self.inner.get_mut().terminate();
141
142 self.inner.wake();
145 }
146
147 pub(crate) fn wake(&self) {
148 self.inner.wake();
149 }
150}
151
152impl<S: 'static> SourceStreamGuard<'_, S> {
153 pub(crate) fn stream(&mut self) -> Pin<&mut S> {
154 let inner = self.inner.get_mut();
155 let stream = inner.stream().as_any_mut();
156
157 unsafe { stream.map_unchecked_mut(|s| s.downcast_mut::<S>().expect("invalid source type")) }
159 }
160}
161
162pub(crate) struct UntypedSourceArc {
165 is_owner: bool,
168 inner: PinArcMutex<StreamWithWaker<dyn SourceStream>>,
169}
170
171impl UntypedSourceArc {
172 pub(crate) fn new(stream: impl SourceStream, oneshot: bool) -> Self {
173 Self {
174 is_owner: false,
175 inner: pinarcmutex::new!(StreamWithWaker {
176 waker: noop_waker(),
177 status: if oneshot {
178 StreamStatus::Oneshot
179 } else {
180 StreamStatus::Stream
181 },
182 stream: ManuallyDrop::new(stream),
183 }),
184 }
185 }
186
187 fn to_owner(&self) -> Self {
188 Self {
189 is_owner: true,
190 inner: self.inner.clone(),
191 }
192 }
193}
194
195impl Drop for UntypedSourceArc {
196 fn drop(&mut self) {
197 if !self.is_owner {
201 return;
202 }
203
204 let mut inner = self.inner.lock();
205 if inner.status() != StreamStatus::Terminated {
206 inner.get_mut().terminate();
207 }
208 }
209}
210
/// A stream bundled with the waker of its last poller and a lifecycle
/// status.
struct StreamWithWaker<S: ?Sized> {
    /// Waker used by `wake`; replaced by `update_waker`.
    waker: Waker,
    /// Current lifecycle state of `stream`.
    status: StreamStatus,
    /// The stream itself; it is dropped manually by `terminate`, after which
    /// it must not be touched.
    stream: ManuallyDrop<S>,
}

/// Lifecycle state of a stream stored in [`StreamWithWaker`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum StreamStatus {
    /// The stream has been dropped and must not be accessed anymore.
    Terminated,
    /// A live stream.
    Stream,
    /// A live stream created with the `oneshot` flag set.
    Oneshot,
}

impl<S: ?Sized> StreamWithWaker<S> {
    /// Returns the current lifecycle status.
    fn status(&self) -> StreamStatus {
        self.status
    }

    /// Stores the waker of `cx` so that a later `wake` notifies that task.
    fn update_waker(self: Pin<&mut Self>, cx: &task::Context<'_>) {
        // SAFETY: only the `waker` field is modified and it is never treated
        // as pinned; the stream stays where it is.
        let this = unsafe { self.get_unchecked_mut() };
        this.waker.clone_from(cx.waker());
    }

    /// Wakes the stored waker without consuming it.
    fn wake(&self) {
        self.waker.wake_by_ref();
    }

    /// Projects the pin onto the inner stream.
    ///
    /// # Panics
    ///
    /// Panics if the stream has already been terminated.
    fn stream(self: Pin<&mut Self>) -> Pin<&mut S> {
        assert_ne!(self.status, StreamStatus::Terminated);

        // SAFETY: this is a plain field projection and the stream is never
        // moved out; the assertion above guarantees it has not been dropped.
        unsafe { self.map_unchecked_mut(|this| &mut *this.stream) }
    }

    /// Marks the stream as terminated and drops it in place.
    ///
    /// # Panics
    ///
    /// Panics if the stream has already been terminated.
    fn terminate(self: Pin<&mut Self>) {
        assert_ne!(self.status, StreamStatus::Terminated);

        // SAFETY: nothing is moved out of the pinned value; the stream is
        // destroyed in place below.
        let Self { status, stream, .. } = unsafe { self.get_unchecked_mut() };

        // Flip the status first so the stream is never reachable again, even
        // if its destructor panics.
        *status = StreamStatus::Terminated;

        // SAFETY: the assertion above ensures the stream has not been dropped
        // yet, and the `Terminated` status prevents any further access.
        unsafe { ManuallyDrop::drop(stream) };
    }
}
267
268impl futures::Stream for UntypedSourceArc {
269 type Item = Envelope;
270
271 fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Option<Envelope>> {
272 let mut guard = self.inner.lock();
273
274 if guard.status() == StreamStatus::Terminated {
276 return Poll::Ready(None);
277 }
278
279 let result = guard.get_mut().stream().poll_recv(cx);
280
281 if result.is_pending() {
282 guard.get_mut().update_waker(cx);
284 } else if matches!(result, Poll::Ready(None)) || guard.status() == StreamStatus::Oneshot {
285 guard.get_mut().terminate();
286 }
287
288 result
289 }
290}
291
292fn noop_waker() -> Waker {
293 unsafe { Waker::from_raw(noop_raw_waker()) }
295}
296
/// Builds the [`RawWaker`] behind `noop_waker`: a null data pointer paired
/// with a vtable whose callbacks do nothing.
fn noop_raw_waker() -> RawWaker {
    /// Cloning yields another identical no-op raw waker.
    fn clone_noop(_: *const ()) -> RawWaker {
        noop_raw_waker()
    }

    /// Shared body for `wake`, `wake_by_ref` and `drop`: nothing to do.
    fn ignore(_: *const ()) {}

    static VTABLE: RawWakerVTable = RawWakerVTable::new(clone_noop, ignore, ignore, ignore);

    RawWaker::new(ptr::null(), &VTABLE)
}
308
309pub(crate) type Sources = StreamsUnordered<UntypedSourceArc>;
310
311mod pinarcmutex {
314 use std::{ops::Deref, pin::Pin, sync::Arc};
315
316 use parking_lot::{Mutex, MutexGuard};
317
318 macro_rules! new {
320 ($value:expr) => {
321 pinarcmutex::PinArcMutex {
322 __inner: std::sync::Arc::new(parking_lot::Mutex::new($value)),
323 }
324 };
325 }
326 pub(super) use new;
327
328 pub(super) struct PinArcMutex<T: ?Sized> {
329 pub(super) __inner: Arc<Mutex<T>>,
331 }
332
333 impl<T: ?Sized> PinArcMutex<T> {
334 pub(super) fn lock(&self) -> PinArcMutexGuard<'_, T> {
335 PinArcMutexGuard(self.__inner.lock())
336 }
337 }
338
339 impl<T: ?Sized> Clone for PinArcMutex<T> {
340 fn clone(&self) -> Self {
341 Self {
342 __inner: self.__inner.clone(),
343 }
344 }
345 }
346
347 pub(super) struct PinArcMutexGuard<'a, T: ?Sized>(MutexGuard<'a, T>);
348
349 impl<T: ?Sized> PinArcMutexGuard<'_, T> {
350 pub(super) fn get_mut(&mut self) -> Pin<&mut T> {
351 unsafe { Pin::new_unchecked(&mut *self.0) }
353 }
354 }
355
356 impl<T: ?Sized> Deref for PinArcMutexGuard<'_, T> {
357 type Target = T;
358
359 fn deref(&self) -> &Self::Target {
360 &self.0
363 }
364 }
365}