1use fortanix_sgx_abi::FifoDescriptor;
8
9use super::{Fifo, Identified, QueueEvent, Receiver, RecvError, Sender, SendError, SynchronizationError, Synchronizer, Transmittable, TryRecvError, TrySendError};
10
11unsafe impl<T: Send, S: Send> Send for Sender<T, S> {}
12unsafe impl<T: Send, S: Sync> Sync for Sender<T, S> {}
13
14impl<T, S: Clone> Clone for Sender<T, S> {
15 fn clone(&self) -> Self {
16 Self {
17 inner: self.inner.clone(),
18 synchronizer: self.synchronizer.clone(),
19 }
20 }
21}
22
23impl<T: Transmittable, S: Synchronizer> Sender<T, S> {
24 pub unsafe fn from_descriptor(d: FifoDescriptor<T>, synchronizer: S) -> Self {
39 Self {
40 inner: Fifo::from_descriptor(d),
41 synchronizer,
42 }
43 }
44
45 pub fn try_send(&self, val: Identified<T>) -> Result<(), TrySendError> {
46 self.inner.try_send_impl(val).map(|wake_receiver| {
47 if wake_receiver {
48 self.synchronizer.notify(QueueEvent::NotEmpty);
49 }
50 })
51 }
52
53 pub fn try_send_multiple(&self, values: &[Identified<T>]) -> Result<usize, TrySendError> {
60 let mut wake_receiver = false;
61 let mut sent = 0;
62 for val in values {
63 wake_receiver |= match self.inner.try_send_impl(*val) {
64 Ok(wake_receiver) => wake_receiver,
65 Err(e) if sent == 0 => return Err(e),
66 Err(_) => break,
67 };
68 sent += 1;
69 }
70 if wake_receiver {
71 self.synchronizer.notify(QueueEvent::NotEmpty);
72 }
73 Ok(sent)
74 }
75
76 pub fn send(&self, val: Identified<T>) -> Result<(), SendError> {
77 loop {
78 match self.inner.try_send_impl(val) {
79 Ok(wake_receiver) => {
80 if wake_receiver {
81 self.synchronizer.notify(QueueEvent::NotEmpty);
82 }
83 return Ok(());
84 }
85 Err(TrySendError::QueueFull) => {
86 self.synchronizer
87 .wait(QueueEvent::NotFull)
88 .map_err(|SynchronizationError::ChannelClosed| SendError::Closed)?;
89 }
90 Err(TrySendError::Closed) => return Err(SendError::Closed),
91 };
92 }
93 }
94}
95
96unsafe impl<T: Send, S: Send> Send for Receiver<T, S> {}
97
98impl<T: Transmittable, S: Synchronizer> Receiver<T, S> {
99 pub unsafe fn from_descriptor(d: FifoDescriptor<T>, synchronizer: S) -> Self {
108 Self {
109 inner: Fifo::from_descriptor(d),
110 synchronizer,
111 }
112 }
113
114 pub fn try_recv(&self) -> Result<Identified<T>, TryRecvError> {
115 self.inner.try_recv_impl().map(|(val, wake_sender, _)| {
116 if wake_sender {
117 self.synchronizer.notify(QueueEvent::NotFull);
118 }
119 val
120 })
121 }
122
123 pub fn try_iter(&self) -> TryIter<'_, T, S> {
124 TryIter(self)
125 }
126
127 pub fn recv(&self) -> Result<Identified<T>, RecvError> {
128 loop {
129 match self.inner.try_recv_impl() {
130 Ok((val, wake_sender, _)) => {
131 if wake_sender {
132 self.synchronizer.notify(QueueEvent::NotFull);
133 }
134 return Ok(val);
135 }
136 Err(TryRecvError::QueueEmpty) => {
137 self.synchronizer
138 .wait(QueueEvent::NotEmpty)
139 .map_err(|SynchronizationError::ChannelClosed| RecvError::Closed)?;
140 }
141 Err(TryRecvError::Closed) => return Err(RecvError::Closed),
142 }
143 }
144 }
145}
146
147pub struct TryIter<'r, T: 'static, S>(&'r Receiver<T, S>);
148
149impl<'r, T: Transmittable, S: Synchronizer> Iterator for TryIter<'r, T, S> {
150 type Item = Identified<T>;
151
152 fn next(&mut self) -> Option<Self::Item> {
153 self.0.try_recv().ok()
154 }
155}
156
#[cfg(test)]
mod tests {
    use crate::fifo::bounded;
    use crate::test_support::pubsub::{Channel, Subscription};
    use crate::test_support::TestValue;
    use crate::*;
    use std::thread;

    /// Sends `n` values from one thread through a queue of capacity `len` and
    /// checks that they are received in order.
    fn do_single_sender(len: usize, n: u64) {
        let s = TestSynchronizer::new();
        let (tx, rx) = bounded(len, s);

        let h = thread::spawn(move || {
            for i in 0..n {
                tx.send(Identified { id: i + 1, data: TestValue(i) }).unwrap();
            }
        });

        for i in 0..n {
            let v = rx.recv().unwrap();
            assert_eq!(v.id, i + 1);
            assert_eq!(v.data.0, i);
        }

        h.join().unwrap();
    }

    #[test]
    fn single_sender() {
        do_single_sender(4, 10);
        do_single_sender(1, 10);
        do_single_sender(32, 1024);
        do_single_sender(1024, 32);
    }

    /// Runs `senders` threads that each send `n` values with distinct ids, and
    /// checks that every id in `1..=n * senders` is received exactly once.
    fn do_multi_sender(len: usize, n: u64, senders: u64) {
        let s = TestSynchronizer::new();
        let (tx, rx) = bounded(len, s);
        let mut handles = Vec::with_capacity(senders as usize);

        for t in 0..senders {
            let tx = tx.clone();
            handles.push(thread::spawn(move || {
                for i in 0..n {
                    let id = t * n + i + 1;
                    tx.send(Identified { id, data: TestValue(i) }).unwrap();
                }
            }));
        }

        let mut ids: Vec<u64> = (0..(n * senders)).map(|_| rx.recv().unwrap().id).collect();
        ids.sort_unstable();
        // Sender `t` uses ids `t * n + 1 ..= (t + 1) * n`, so together they cover
        // exactly `1..=n * senders` with no duplicates.
        assert!(ids.iter().copied().eq(1..=n * senders));

        for h in handles {
            h.join().unwrap();
        }
    }

    #[test]
    fn multi_sender() {
        do_multi_sender(4, 10, 3);
        do_multi_sender(4, 1, 100);
        do_multi_sender(2, 10, 100);
        do_multi_sender(1024, 30, 100);
    }

    #[test]
    fn try_error() {
        const N: u64 = 8;
        let s = TestSynchronizer::new();
        let (tx, rx) = bounded(N as _, s);

        for i in 0..N {
            tx.send(Identified { id: i + 1, data: TestValue(i) }).unwrap();
        }
        assert!(tx.try_send(Identified { id: N + 1, data: TestValue(N) }).is_err());

        for i in 0..N {
            let v = rx.recv().unwrap();
            assert_eq!(v.id, i + 1);
            assert_eq!(v.data.0, i);
        }
        assert!(rx.try_recv().is_err());
    }

    #[test]
    fn very_optimistic() {
        const N: u64 = 8;
        let s = TestSynchronizer::new();
        let (tx, rx) = bounded(N as _, s);

        for i in 0..N {
            tx.try_send(Identified { id: i + 1, data: TestValue(i) }).unwrap();
        }

        for i in 0..N {
            let v = rx.try_recv().unwrap();
            assert_eq!(v.id, i + 1);
            assert_eq!(v.data.0, i);
        }
    }

    #[test]
    fn mixed_try_send() {
        let s = TestSynchronizer::new();
        let (tx, rx) = bounded(8, s);

        let h = thread::spawn(move || {
            let mut sent_without_wait = 0;
            for _ in 0..7 {
                for i in 0..11 {
                    let v = Identified { id: i + 1, data: TestValue(i) };
                    // Fall back to the blocking send only when the fast path fails.
                    if tx.try_send(v).is_err() {
                        tx.send(v).unwrap();
                    } else {
                        sent_without_wait += 1;
                    }
                }
            }
            assert!(sent_without_wait > 0);
        });

        for _ in 0..7 {
            for i in 0..11 {
                let v = rx.recv().unwrap();
                assert_eq!(v.id, i + 1);
                assert_eq!(v.data.0, i);
            }
        }

        h.join().unwrap();
    }

    #[test]
    fn mixed_try_recv() {
        let s = TestSynchronizer::new();
        let (tx, rx) = bounded(8, s);

        let h = thread::spawn(move || {
            for _ in 0..11 {
                for i in 0..13 {
                    tx.send(Identified { id: i + 1, data: TestValue(i) }).unwrap();
                }
            }
        });

        for _ in 0..11 {
            for i in 0..13 {
                let v = match rx.try_recv() {
                    Ok(v) => v,
                    Err(_) => rx.recv().unwrap(),
                };
                assert_eq!(v.id, i + 1);
                assert_eq!(v.data.0, i);
            }
        }

        h.join().unwrap();
    }

    #[test]
    fn try_iter() {
        let s = TestSynchronizer::new();
        let (tx, rx) = bounded(8, s);
        const N: u64 = 2048;

        let h = thread::spawn(move || {
            for i in 0..N {
                tx.send(Identified { id: i + 1, data: TestValue(i) }).unwrap();
            }
        });

        let mut total = 0;
        while total < N {
            // Block for one value, then drain whatever else is already queued.
            for v in rx.recv().ok().into_iter().chain(rx.try_iter()) {
                assert_eq!(v.id, total + 1);
                assert_eq!(v.data.0, total);
                total += 1;
            }
        }

        h.join().unwrap();
    }

    #[test]
    fn try_send_multiple() {
        let s = TestSynchronizer::new();
        let (tx, rx) = bounded(32, s);
        const SENDERS: usize = 4;
        const N: usize = 1024;
        let mut handles = Vec::with_capacity(SENDERS);

        for t in 0..SENDERS {
            let tx = tx.clone();
            handles.push(thread::spawn(move || {
                let mut to_send = Vec::with_capacity(N);
                for i in 0..N {
                    let id = (t * N + i + 1) as u64;
                    to_send.push(Identified { id, data: TestValue(i as u64) });
                }
                let mut sent = 0;
                while sent < to_send.len() {
                    match tx.try_send_multiple(&to_send[sent..]) {
                        Err(_) => thread::yield_now(),
                        Ok(n) => sent += n,
                    }
                }
            }));
        }

        let mut ids: Vec<u64> = (0..N * SENDERS).map(|_| rx.recv().unwrap().id).collect();
        ids.sort_unstable();
        // Every id in `1..=N * SENDERS` must arrive exactly once: no value is
        // lost or duplicated by partial batch sends.
        assert!(ids.iter().copied().eq(1..=(N * SENDERS) as u64));

        for h in handles {
            h.join().unwrap();
        }
    }

    /// Test synchronizer that signals each `QueueEvent` over its own pubsub
    /// subscription.
    #[derive(Clone)]
    pub struct TestSynchronizer {
        not_empty: Subscription<()>,
        not_full: Subscription<()>,
    }

    impl TestSynchronizer {
        pub fn new() -> Self {
            Self {
                not_empty: Channel::new().subscribe(),
                not_full: Channel::new().subscribe(),
            }
        }
    }

    impl Default for TestSynchronizer {
        fn default() -> Self {
            Self::new()
        }
    }

    impl Synchronizer for TestSynchronizer {
        /// Waits on the subscription for `event`. Any receive error is reported
        /// as `ChannelClosed`.
        fn wait(&self, event: QueueEvent) -> Result<(), SynchronizationError> {
            match event {
                QueueEvent::NotEmpty => self.not_empty.recv(),
                QueueEvent::NotFull => self.not_full.recv(),
            }
            .map_err(|_| SynchronizationError::ChannelClosed)
        }

        /// Broadcasts `event` on its subscription. Failures are deliberately
        /// ignored: a notification with no listener is harmless in these tests.
        fn notify(&self, event: QueueEvent) {
            let _ = match event {
                QueueEvent::NotEmpty => self.not_empty.broadcast(()),
                QueueEvent::NotFull => self.not_full.broadcast(()),
            };
        }
    }
}