// another_rxrust/operators/take.rs

1use crate::internals::stream_controller::*;
2use crate::prelude::*;
3use std::{
4  marker::PhantomData,
5  sync::{Arc, RwLock},
6};
7
/// Operator state backing `Observable::take`: the number of leading items
/// that will be passed through.
#[derive(Clone)]
pub struct Take<Item: Clone + Send + Sync> {
  /// Maximum number of items to forward.
  count: usize,
  /// Ties the operator to its item type without storing an item.
  _item: PhantomData<Item>,
}
16
17impl<'a, Item> Take<Item>
18where
19  Item: Clone + Send + Sync,
20{
21  pub fn new(count: usize) -> Take<Item> {
22    Take { count, _item: PhantomData }
23  }
24  pub fn execute(&self, source: Observable<'a, Item>) -> Observable<'a, Item> {
25    let count = self.count;
26
27    Observable::<Item>::create(move |s| {
28      let n = Arc::new(RwLock::new(0));
29
30      let sctl = StreamController::new(s);
31      let sctl_next = sctl.clone();
32      let sctl_error = sctl.clone();
33      let sctl_complete = sctl.clone();
34
35      source.inner_subscribe(sctl.new_observer(
36        move |serial, x| {
37          let (emit, complete) = {
38            let mut n = n.write().unwrap();
39            let nn = *n;
40            *n += 1;
41            (nn < count, (nn + 1) >= count)
42          };
43          if emit {
44            sctl_next.sink_next(x);
45          }
46          if complete {
47            sctl_next.upstream_abort_observe(&serial);
48            sctl_next.sink_complete(&serial);
49            sctl_next.finalize();
50          }
51        },
52        move |_, e| {
53          sctl_error.sink_error(e);
54        },
55        move |serial| sctl_complete.sink_complete(&serial),
56      ));
57    })
58  }
59}
60
61impl<'a, Item> Observable<'a, Item>
62where
63  Item: Clone + Send + Sync,
64{
65  pub fn take(&self, count: usize) -> Observable<'a, Item> {
66    Take::new(count).execute(self.clone())
67  }
68}
69
#[cfg(test)]
mod test {
  use crate::prelude::*;
  use std::{thread, time};

  /// Synchronous source of ten items piped through `take(2)`.
  #[test]
  fn basic() {
    let numbers = Observable::create(|s| {
      (0..10).for_each(|n| {
        s.next(n);
      });
      s.complete();
    });

    let first_two = numbers.take(2);
    first_two.subscribe(
      print_next_fmt!("{}"),
      print_error!(),
      print_complete!(),
    );
  }

  /// Slow source that checks `is_subscribed` before each item, so it can
  /// stop early once `take(2)` no longer wants items.
  #[test]
  fn thread() {
    let slow = Observable::create(|s| {
      let pause = time::Duration::from_millis(100);
      for n in 0..100 {
        if s.is_subscribed() {
          println!("emit {}", n);
          s.next(n);
          thread::sleep(pause);
        } else {
          println!("break!");
          break;
        }
      }
      if s.is_subscribed() {
        s.complete();
      }
    });

    slow.take(2).subscribe(
      print_next_fmt!("{}"),
      print_error!(),
      print_complete!(),
    );

    let settle = time::Duration::from_millis(1000);
    thread::sleep(settle);
  }
}
115}