use std::convert::Infallible;
use crate::{
context::Context,
observable::{CoreObservable, ObservableType},
observer::Observer,
scheduler::{Duration, Instant, Scheduler, Task, TaskState},
};
/// Observable that emits the tick indices `0, 1, 2, …`, one per `period`.
///
/// Subscribing schedules a task on `scheduler` with an initial delay of `period`.
/// Each tick delivers the next index to the observer. Ticks continue until the
/// observer reports itself closed. The stream's error type is `Infallible`, so it
/// never errors.
#[derive(Clone, Debug)]
pub struct Interval<S> {
  /// Spacing between emissions; also the delay before the first emission.
  pub period: Duration,
  /// Scheduler on which the periodic emission task runs.
  pub scheduler: S,
}
/// State carried by the interval's scheduled task from one tick to the next.
struct IntervalState<O> {
  /// Downstream observer; the tick function treats `None` as "finished".
  observer: Option<O>,
  /// Target spacing between consecutive ticks.
  period: Duration,
  /// Index that the next tick will emit.
  counter: usize,
}
/// Runs one tick of an interval: emits the current counter value and decides
/// when the scheduler should run the task again.
///
/// The next wake-up is targeted at one `period` after this tick *started*. Time
/// spent inside `observer.next` is therefore subtracted from the returned sleep,
/// clamped at zero, instead of being added on top of the period.
///
/// Returns `TaskState::Finished` in three cases:
/// - there is no observer;
/// - the observer is already closed;
/// - the observer closes itself while handling this tick's value.
///
/// Otherwise returns `TaskState::Sleeping` with the remaining delay.
fn interval_task<O, Err>(state: &mut IntervalState<O>) -> TaskState
where
  O: Observer<usize, Err>,
{
  let Some(observer) = state.observer.as_mut() else {
    return TaskState::Finished;
  };
  if observer.is_closed() {
    return TaskState::Finished;
  }

  let tick_started = Instant::now();
  observer.next(state.counter);
  state.counter += 1;

  // The observer may complete or unsubscribe from inside `next` (e.g. a
  // downstream `take`). Stop right away instead of sleeping a whole extra period
  // only to find it closed on the next wake-up.
  if observer.is_closed() {
    return TaskState::Finished;
  }

  let next_tick = tick_started + state.period;
  let now = Instant::now();
  let sleep_for = if next_tick > now { next_tick - now } else { Duration::from_nanos(0) };
  TaskState::Sleeping(sleep_for)
}
/// An `Interval` yields its tick index as a `usize` and can never fail.
impl<S> ObservableType for Interval<S> {
  /// `Infallible`: an interval never emits an error.
  type Err = Infallible;

  /// Zero-based index of each tick.
  type Item<'a>
    = usize
  where
    Self: 'a;
}
impl<S, C> CoreObservable<C> for Interval<S>
where
C: Context,
C::Inner: Observer<usize, Infallible>,
S: Scheduler<Task<IntervalState<C::Inner>>> + Clone,
{
type Unsub = crate::scheduler::TaskHandle;
fn subscribe(self, context: C) -> Self::Unsub {
let observer = context.into_inner();
let state = IntervalState { observer: Some(observer), counter: 0, period: self.period };
let task = Task::new(state, interval_task);
self.scheduler.schedule(task, Some(self.period))
}
}
#[cfg(test)]
mod tests {
  use std::sync::{Arc, Mutex};

  use super::*;
  use crate::{
    prelude::*,
    scheduler::{Duration, Instant, LocalScheduler, SharedScheduler},
    subscription::Subscription,
  };

  // Builds a one-shot task that unsubscribes `handle` the first time it runs.
  fn create_unsubscribe_task<H: Subscription>(handle: H) -> Task<Option<H>> {
    Task::new(Some(handle), |h| {
      if let Some(h) = h.take() {
        h.unsubscribe();
      }
      TaskState::Finished
    })
  }

  // A 10ms local interval cancelled at 65ms yields >= 5 sequential indices.
  #[rxrust_macro::test(local)]
  async fn test_interval_basic() {
    let values = Arc::new(Mutex::new(Vec::new()));
    let values_c = values.clone();
    let handle = Local::interval(Duration::from_millis(10)).subscribe(move |v| {
      values_c.lock().unwrap().push(v);
    });
    let unsubscribe_task = create_unsubscribe_task(handle);
    let unsubscribe_done =
      LocalScheduler.schedule(unsubscribe_task, Some(Duration::from_millis(65)));
    unsubscribe_done.await;
    let result = values.lock().unwrap().clone();
    assert!(result.len() >= 5, "Expected at least 5 values, got {}", result.len());
    for (i, &val) in result.iter().enumerate() {
      assert_eq!(val, i, "Value at position {} should be {}", i, i);
    }
  }

  // Same as the basic test, but on the shared (multi-threaded) scheduler.
  #[rxrust_macro::test]
  async fn test_interval_shared() {
    let values = Arc::new(Mutex::new(Vec::new()));
    let values_c = values.clone();
    let handle = Shared::interval(Duration::from_millis(10)).subscribe(move |v| {
      values_c.lock().unwrap().push(v);
    });
    let unsubscribe_task = create_unsubscribe_task(handle);
    let unsubscribe_done =
      SharedScheduler.schedule(unsubscribe_task, Some(Duration::from_millis(65)));
    unsubscribe_done.await;
    let result = values.lock().unwrap().clone();
    assert!(result.len() >= 5, "Expected at least 5 values, got {}", result.len());
    for (i, &val) in result.iter().enumerate() {
      assert_eq!(val, i, "Shared interval value at position {} should be {}", i, i);
    }
  }

  // A 20ms interval cancelled at 80ms emits >= 3 values and really takes time.
  #[rxrust_macro::test(local)]
  async fn test_interval_timing() {
    let start_time = Instant::now();
    let values = Arc::new(Mutex::new(Vec::new()));
    let values_c = values.clone();
    let handle = Local::interval(Duration::from_millis(20)).subscribe(move |v| {
      values_c.lock().unwrap().push(v);
    });
    let unsubscribe_task = create_unsubscribe_task(handle);
    let unsubscribe_done =
      LocalScheduler.schedule(unsubscribe_task, Some(Duration::from_millis(80)));
    unsubscribe_done.await;
    let elapsed_time = start_time.elapsed();
    let result = values.lock().unwrap().clone();
    assert!(result.len() >= 3, "Expected at least 3 values in 80ms, got {}", result.len());
    for (i, &val) in result.iter().enumerate() {
      assert_eq!(val, i, "Timing test value at position {} should be {}", i, i);
    }
    assert!(
      elapsed_time >= Duration::from_millis(60),
      "Expected elapsed time >= 60ms, got {:?}",
      elapsed_time
    );
  }

  // After unsubscribing, no further values arrive.
  #[rxrust_macro::test]
  async fn test_interval_unsubscribe() {
    let values = Arc::new(Mutex::new(Vec::new()));
    let values_c = values.clone();
    let handle = Shared::interval(Duration::from_millis(10)).subscribe(move |v| {
      values_c.lock().unwrap().push(v);
    });
    let cancel_interval = create_unsubscribe_task(handle);
    let cancel_done = SharedScheduler.schedule(cancel_interval, Some(Duration::from_millis(35)));
    cancel_done.await;
    let count_at_unsub = values.lock().unwrap().len();
    let wait_50_millis = SharedScheduler
      .schedule(Task::new((), |_| TaskState::Finished), Some(Duration::from_millis(50)));
    wait_50_millis.await;
    assert_eq!(values.lock().unwrap().len(), count_at_unsub);
  }

  // A slow observer stretches the gap to roughly its processing time, not
  // processing + period. Once the observer is fast again, the gap returns to
  // about one period. The observer blocks with `std::thread::sleep` on purpose
  // to simulate slow synchronous work.
  #[cfg(not(target_arch = "wasm32"))]
  #[rxrust_macro::test(local)]
  async fn test_interval_fixed_rate_behavior() {
    let interval_period = Duration::from_millis(20);
    let slow_processing = Duration::from_millis(30);
    let fast_processing = Duration::from_millis(5);
    let emission_times = Arc::new(Mutex::new(Vec::new()));
    let times_clone = emission_times.clone();
    let test_start = Instant::now();
    let handle = Local::interval(interval_period).subscribe(move |value| {
      let emission_time = test_start.elapsed();
      times_clone
        .lock()
        .unwrap()
        .push((value, emission_time));
      let processing_time = if value < 2 { slow_processing } else { fast_processing };
      std::thread::sleep(processing_time);
    });
    let unsubscribe_task = create_unsubscribe_task(handle);
    let unsubscribe_done =
      LocalScheduler.schedule(unsubscribe_task, Some(Duration::from_millis(200)));
    unsubscribe_done.await;
    let emissions = emission_times.lock().unwrap().clone();
    assert!(
      emissions.len() >= 4,
      "Need at least 4 emissions to test behavior, got {}",
      emissions.len()
    );
    let intervals: Vec<Duration> = emissions
      .windows(2)
      .map(|pair| pair[1].1 - pair[0].1)
      .collect();
    let tolerance = Duration::from_millis(10);
    let expected_slow_interval = slow_processing;
    assert!(
      intervals[0] >= expected_slow_interval - tolerance,
      "Slow phase 1: interval should be ~processing_time. Expected ~{:?}, got {:?}",
      expected_slow_interval,
      intervals[0]
    );
    assert!(
      intervals[0] < slow_processing + interval_period,
      "Slow phase 1: should NOT add full period on top of processing. Got {:?}",
      intervals[0]
    );
    assert!(
      intervals[1] >= expected_slow_interval - tolerance,
      "Slow phase 2: interval should be ~processing_time. Expected ~{:?}, got {:?}",
      expected_slow_interval,
      intervals[1]
    );
    let expected_fast_interval = interval_period;
    assert!(
      intervals[2] >= expected_fast_interval - tolerance,
      "Fast phase: interval should be ~period. Expected ~{:?}, got {:?}",
      expected_fast_interval,
      intervals[2]
    );
    assert!(
      intervals[2] < interval_period + tolerance,
      "Fast phase: interval should be close to period. Got {:?}",
      intervals[2]
    );
    if intervals.len() > 3 {
      assert!(
        intervals[3] >= expected_fast_interval - tolerance,
        "Fast phase continued: expected ~{:?}, got {:?}",
        expected_fast_interval,
        intervals[3]
      );
    }
  }
}