Skip to main content

razor_stream/client/
timer.rs

1//! The struct that implements a timer for ClientTask
2//!
3//! This module is only for transport implementation, not for the user.
4
5use std::{
6    collections::{BTreeMap, vec_deque::VecDeque},
7    future::Future,
8    mem::swap,
9    pin::Pin,
10    sync::atomic::{AtomicBool, AtomicU64, Ordering},
11    task::*,
12};
13
14use crate::client::task::ClientTaskDone;
15use crate::client::*;
16use crossfire::{stream::AsyncStream, waitgroup::WaitGroupGuard, *};
17
/// A registered task paired with the wait-group guard that keeps the
/// upstream alive while the task is in flight.
pub struct ClientTaskItem<T: ClientTask> {
    // Taken (left as None) once the task is handed back to the caller
    // or completed with an error.
    pub task: Option<T>,
    // Held only for its Drop side effect: releases the upstream wait
    // group when this item is dropped.
    _upstream: WaitGroupGuard<()>,
}
22
/// One second's worth of already-sent tasks, aged out of
/// `ClientTaskTimer::sent_tasks` by `adjust_task_queue`.
pub(crate) struct DelayTasksBatch<T: ClientTask> {
    // Keyed by task sequence number.
    tasks: BTreeMap<u64, ClientTaskItem<T>>,
}
26
/// Per-connection timeout tracker for sent tasks.
///
/// Tasks are registered through a bounded channel, drained into
/// `sent_tasks`, and aged second-by-second through `delay_tasks_queue`
/// until they either get a response (`take_task`) or time out.
pub struct ClientTaskTimer<F: ClientFacts> {
    // Connection id, used only for log messages.
    conn_id: String,
    pending_tasks_recv: AsyncStream<mpsc::Array<ClientTaskItem<F::Task>>>,
    pending_tasks_sender: MAsyncTx<mpsc::Array<ClientTaskItem<F::Task>>>,
    // Tasks sitting in the channel, not yet drained into sent_tasks.
    // NOTE(review): only decremented here (got_pending_task); presumably
    // incremented by the caller around reg_task — confirm.
    pending_task_count: AtomicU64,

    sent_tasks: BTreeMap<u64, ClientTaskItem<F::Task>>, // sent_tasks of the current second
    delay_tasks_queue: VecDeque<DelayTasksBatch<F::Task>>, // sent_tasks of past seconds

    // Sequences below this have already been timed out; responses for
    // them are rejected outright in take_task.
    min_delay_seq: u64,
    task_timeout: usize, // in seconds
    // TODO what if seq reach max u64, should exit client
    // Highest sequence drained from the channel so far.
    processed_seq: u64,
    // Set by stop_reg_task; makes WaitRegTaskFuture resolve with Err.
    reg_stopped_flag: AtomicBool,
}
42
// SAFETY: NOTE(review): these blanket impls assert thread-safety the
// compiler could not prove (F::Task or the channel internals may not be
// Send/Sync). Presumably the timer is mutated only from the single
// transport task — confirm before sharing across threads.
unsafe impl<T: ClientFacts> Send for ClientTaskTimer<T> {}
unsafe impl<T: ClientFacts> Sync for ClientTaskTimer<T> {}
45
46impl<F: ClientFacts> ClientTaskTimer<F> {
47    pub fn new(conn_id: String, task_timeout: usize, mut thresholds: usize) -> Self {
48        if thresholds == 0 {
49            thresholds = 500;
50        }
51        let (pending_tx, pending_rx) = mpsc::bounded_async(thresholds * 2);
52        Self {
53            conn_id,
54            pending_tasks_recv: pending_rx.into_stream(),
55            pending_tasks_sender: pending_tx,
56            pending_task_count: AtomicU64::new(0),
57            sent_tasks: BTreeMap::new(),
58            min_delay_seq: 0,
59            task_timeout,
60            delay_tasks_queue: VecDeque::with_capacity(task_timeout),
61            processed_seq: 0,
62            reg_stopped_flag: AtomicBool::new(false),
63        }
64    }
65
66    pub fn pending_task_count_ref(&self) -> &AtomicU64 {
67        &self.pending_task_count
68    }
69
70    pub fn clean_pending_tasks(&mut self, facts: &F) {
71        while let Ok(task) = self.pending_tasks_recv.try_recv() {
72            self.got_pending_task(task);
73        }
74        let mut task_seqs: Vec<u64> = Vec::with_capacity(self.sent_tasks.len());
75        for (key, _) in self.sent_tasks.iter() {
76            task_seqs.push(*key);
77        }
78        for key in task_seqs {
79            let mut task_item = self.sent_tasks.remove(&key).unwrap();
80            let mut task = task_item.task.take().unwrap();
81            task.set_rpc_error(RpcIntErr::IO);
82            facts.error_handle(task);
83        }
84        for tasks_batch_in_second in self.delay_tasks_queue.iter_mut() {
85            let mut task_seqs: Vec<u64> = Vec::with_capacity(tasks_batch_in_second.tasks.len());
86            for (key, _) in tasks_batch_in_second.tasks.iter() {
87                task_seqs.push(*key);
88            }
89            for key in task_seqs {
90                let mut task_item = tasks_batch_in_second.tasks.remove(&key).unwrap();
91                let mut task = task_item.task.take().unwrap();
92                task.set_rpc_error(RpcIntErr::IO);
93                facts.error_handle(task);
94            }
95        }
96    }
97
98    pub fn check_pending_tasks_empty(&mut self) -> bool {
99        while let Ok(task) = self.pending_tasks_recv.try_recv() {
100            self.got_pending_task(task);
101        }
102        if !self.sent_tasks.is_empty() {
103            return false;
104        }
105        for tasks_batch_in_second in self.delay_tasks_queue.iter() {
106            if !tasks_batch_in_second.tasks.is_empty() {
107                return false;
108            }
109        }
110        return true;
111    }
112
113    // register noti for task.
114    #[inline(always)]
115    pub async fn reg_task(&self, task: F::Task, wg: WaitGroupGuard<()>) {
116        let _ = self
117            .pending_tasks_sender
118            .send(ClientTaskItem { task: Some(task), _upstream: wg })
119            .await;
120    }
121
122    // stop register.
123    pub fn stop_reg_task(&mut self) {
124        self.reg_stopped_flag.store(true, Ordering::SeqCst);
125    }
126
127    pub async fn take_task(&mut self, seq: u64) -> Option<ClientTaskItem<F::Task>> {
128        // ping resp won't readh here
129        if seq < self.min_delay_seq {
130            return None; // Task is already timeouted by us
131        }
132        if seq > self.processed_seq {
133            let f = WaitRegTaskFuture { noti: self, target_seq: seq };
134            if f.await.is_err() {
135                return None;
136            }
137        }
138
139        if let Some(_removed_task) = self.sent_tasks.remove(&seq) {
140            return Some(_removed_task);
141        }
142        for tasks_batch_in_second in self.delay_tasks_queue.iter_mut() {
143            if let Some(_task) = tasks_batch_in_second.tasks.remove(&seq) {
144                return Some(_task);
145            }
146        }
147        return None;
148    }
149
150    #[inline]
151    pub fn poll_sent_task(&mut self, ctx: &mut Context) -> bool {
152        let mut got = false;
153        // Need to poll_item in order to register waker
154        while let Poll::Ready(Some(_task)) = self.pending_tasks_recv.poll_item(ctx) {
155            self.got_pending_task(_task);
156            got = true;
157        }
158        // break on empty or disconnect
159        got
160    }
161
162    // return None if task is store in sent_tasks
163    #[inline]
164    fn got_pending_task(&mut self, task_item: ClientTaskItem<F::Task>) {
165        self.pending_task_count.fetch_sub(1, Ordering::SeqCst);
166        let t = task_item.task.as_ref().unwrap();
167        let task_seq = t.seq();
168        self.processed_seq = task_seq;
169        self.sent_tasks.insert(task_seq, task_item);
170    }
171
172    pub fn adjust_task_queue(&mut self, facts: &F) {
173        // 1. move wait_confirmed to overtime
174        let mut tasks_batch_in_second = BTreeMap::new();
175        swap(&mut self.sent_tasks, &mut tasks_batch_in_second);
176
177        // 2.notify req with timeout err
178        self.delay_tasks_queue.push_front(DelayTasksBatch { tasks: tasks_batch_in_second });
179
180        if self.delay_tasks_queue.len() > self.task_timeout {
181            let real_timeout = self.delay_tasks_queue.pop_back().unwrap();
182            if !real_timeout.tasks.is_empty() {
183                let mut min_seq = 0;
184                for (_seq, mut task_item) in real_timeout.tasks {
185                    let mut task = task_item.task.take().unwrap();
186                    let seq = task.seq();
187                    if min_seq == 0 || min_seq > seq {
188                        min_seq = seq;
189                    }
190                    warn!("{} task {:?} is timeout", self.conn_id, task,);
191                    task.set_rpc_error(RpcIntErr::Timeout);
192                    facts.error_handle(task);
193                }
194                self.min_delay_seq = min_seq;
195            }
196        }
197    }
198}
199
/// Future that resolves `Ok(())` once the timer has drained registrations
/// up to `target_seq` from the pending channel, or `Err(())` if
/// registration was stopped first.
struct WaitRegTaskFuture<'a, F>
where
    F: ClientFacts,
{
    // Exclusive access to the timer; polling drains its pending channel.
    noti: &'a mut ClientTaskTimer<F>,
    // The sequence number whose registration we are waiting for.
    target_seq: u64,
}
207
208impl<'a, F> Future for WaitRegTaskFuture<'a, F>
209where
210    F: ClientFacts,
211{
212    type Output = Result<(), ()>;
213
214    fn poll(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<Self::Output> {
215        let mut _self = self.get_mut();
216        if _self.noti.processed_seq >= _self.target_seq {
217            return Poll::Ready(Ok(()));
218        }
219        if _self.noti.reg_stopped_flag.load(Ordering::SeqCst) {
220            return Poll::Ready(Err(()));
221        }
222        if _self.noti.poll_sent_task(ctx) && _self.noti.processed_seq >= _self.target_seq {
223            return Poll::Ready(Ok(()));
224        }
225        Poll::Pending
226    }
227}