Skip to main content

razor_stream/client/
timer.rs

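//! The struct that implements a timer for `ClientTask`: it keeps track of sent tasks until
//! they are taken back, and fails them with a timeout error once they expire.
//!
//! This module is only for transport implementations, not for end users.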
4
5use std::{
6    collections::vec_deque::VecDeque,
7    future::Future,
8    mem::swap,
9    pin::Pin,
10    sync::atomic::{AtomicBool, AtomicU64, Ordering},
11    task::*,
12};
13
14use crate::client::task::ClientTaskDone;
15use crate::client::*;
16use crossfire::{stream::AsyncStream, *};
17use rustc_hash::FxHashMap;
18use sync_utils::waitgroup::WaitGroupGuard;
19
20pub struct ClientTaskItem<T: ClientTask> {
21    pub task: Option<T>,
22    _upstream: WaitGroupGuard,
23}
24
25pub(crate) struct DelayTasksBatch<T: ClientTask> {
26    tasks: FxHashMap<u64, ClientTaskItem<T>>,
27}
28
29pub struct ClientTaskTimer<F: ClientFacts> {
30    conn_id: String,
31    pending_tasks_recv: AsyncStream<mpsc::Array<ClientTaskItem<F::Task>>>,
32    pending_tasks_sender: MAsyncTx<mpsc::Array<ClientTaskItem<F::Task>>>,
33    pending_task_count: AtomicU64,
34
35    sent_tasks: FxHashMap<u64, ClientTaskItem<F::Task>>, // sent_tasks of the current second
36    delay_tasks_queue: VecDeque<DelayTasksBatch<F::Task>>, // sent_tasks of past seconds
37
38    min_delay_seq: u64,
39    task_timeout: usize, // in seconds
40    // TODO what if seq reach max u64, should exit client
41    processed_seq: u64,
42    reg_stopped_flag: AtomicBool,
43}
44
45unsafe impl<T: ClientFacts> Send for ClientTaskTimer<T> {}
46unsafe impl<T: ClientFacts> Sync for ClientTaskTimer<T> {}
47
48impl<F: ClientFacts> ClientTaskTimer<F> {
49    pub fn new(conn_id: String, task_timeout: usize, mut thresholds: usize) -> Self {
50        if thresholds == 0 {
51            thresholds = 500;
52        }
53        let (pending_tx, pending_rx) = mpsc::bounded_async(thresholds * 2);
54        Self {
55            conn_id,
56            pending_tasks_recv: pending_rx.into_stream(),
57            pending_tasks_sender: pending_tx,
58            pending_task_count: AtomicU64::new(0),
59            sent_tasks: FxHashMap::default(),
60            min_delay_seq: 0,
61            task_timeout,
62            delay_tasks_queue: VecDeque::with_capacity(task_timeout),
63            processed_seq: 0,
64            reg_stopped_flag: AtomicBool::new(false),
65        }
66    }
67
68    pub fn pending_task_count_ref(&self) -> &AtomicU64 {
69        &self.pending_task_count
70    }
71
72    pub fn clean_pending_tasks(&mut self, facts: &F) {
73        loop {
74            match self.pending_tasks_recv.try_recv() {
75                Ok(task) => {
76                    self.got_pending_task(task);
77                }
78                Err(_) => {
79                    break;
80                }
81            }
82        }
83        let mut task_seqs: Vec<u64> = Vec::with_capacity(self.sent_tasks.len());
84        for (key, _) in self.sent_tasks.iter() {
85            task_seqs.push(*key);
86        }
87        for key in task_seqs {
88            let mut task_item = self.sent_tasks.remove(&key).unwrap();
89            let mut task = task_item.task.take().unwrap();
90            task.set_rpc_error(RpcIntErr::IO);
91            facts.error_handle(task);
92        }
93        for tasks_batch_in_second in self.delay_tasks_queue.iter_mut() {
94            let mut task_seqs: Vec<u64> = Vec::with_capacity(tasks_batch_in_second.tasks.len());
95            for (key, _) in tasks_batch_in_second.tasks.iter() {
96                task_seqs.push(*key);
97            }
98            for key in task_seqs {
99                let mut task_item = tasks_batch_in_second.tasks.remove(&key).unwrap();
100                let mut task = task_item.task.take().unwrap();
101                task.set_rpc_error(RpcIntErr::IO);
102                facts.error_handle(task);
103            }
104        }
105    }
106
107    pub fn check_pending_tasks_empty(&mut self) -> bool {
108        loop {
109            match self.pending_tasks_recv.try_recv() {
110                Ok(task) => {
111                    self.got_pending_task(task);
112                }
113                Err(_) => {
114                    break;
115                }
116            }
117        }
118        if !self.sent_tasks.is_empty() {
119            return false;
120        }
121        for tasks_batch_in_second in self.delay_tasks_queue.iter() {
122            if !tasks_batch_in_second.tasks.is_empty() {
123                return false;
124            }
125        }
126        return true;
127    }
128
129    // register noti for task.
130    #[inline(always)]
131    pub async fn reg_task(&self, task: F::Task, wg: WaitGroupGuard) {
132        let _ = self
133            .pending_tasks_sender
134            .send(ClientTaskItem { task: Some(task), _upstream: wg })
135            .await;
136    }
137
138    // stop register.
139    pub fn stop_reg_task(&mut self) {
140        self.reg_stopped_flag.store(true, Ordering::SeqCst);
141    }
142
143    pub async fn take_task(&mut self, seq: u64) -> Option<ClientTaskItem<F::Task>> {
144        // ping resp won't readh here
145        if seq < self.min_delay_seq {
146            return None; // Task is already timeouted by us
147        }
148        if seq > self.processed_seq {
149            let f = WaitRegTaskFuture { noti: self, target_seq: seq };
150            if f.await.is_err() {
151                return None;
152            }
153        }
154
155        if let Some(_removed_task) = self.sent_tasks.remove(&seq) {
156            return Some(_removed_task);
157        }
158        for tasks_batch_in_second in self.delay_tasks_queue.iter_mut() {
159            if let Some(_task) = tasks_batch_in_second.tasks.remove(&seq) {
160                return Some(_task);
161            }
162        }
163        return None;
164    }
165
166    #[inline]
167    pub fn poll_sent_task<'a>(&mut self, ctx: &mut Context) -> bool {
168        let mut got = false;
169        // Need to poll_item in order to register waker
170        loop {
171            match self.pending_tasks_recv.poll_item(ctx) {
172                Poll::Ready(Some(_task)) => {
173                    self.got_pending_task(_task);
174                    got = true;
175                    continue;
176                }
177                _ => break, // empty or disconnect
178            }
179        }
180        got
181    }
182
183    // return None if task is store in sent_tasks
184    #[inline]
185    fn got_pending_task(&mut self, task_item: ClientTaskItem<F::Task>) {
186        self.pending_task_count.fetch_sub(1, Ordering::SeqCst);
187        let t = task_item.task.as_ref().unwrap();
188        let task_seq = t.seq();
189        self.processed_seq = task_seq;
190        self.sent_tasks.insert(task_seq, task_item);
191    }
192
193    pub fn adjust_task_queue(&mut self, facts: &F) {
194        // 1. move wait_confirmed to overtime
195        let mut tasks_batch_in_second = FxHashMap::default();
196        swap(&mut self.sent_tasks, &mut tasks_batch_in_second);
197
198        // 2.notify req with timeout err
199        self.delay_tasks_queue.push_front(DelayTasksBatch { tasks: tasks_batch_in_second });
200
201        if self.delay_tasks_queue.len() > self.task_timeout {
202            let real_timeout = self.delay_tasks_queue.pop_back().unwrap();
203            if !real_timeout.tasks.is_empty() {
204                let mut min_seq = 0;
205                for (_seq, mut task_item) in real_timeout.tasks {
206                    let mut task = task_item.task.take().unwrap();
207                    let seq = task.seq();
208                    if min_seq == 0 {
209                        min_seq = seq;
210                    } else {
211                        if min_seq > seq {
212                            min_seq = seq;
213                        }
214                    }
215                    warn!("{} task {:?} is timeout", self.conn_id, task,);
216                    task.set_rpc_error(RpcIntErr::Timeout);
217                    facts.error_handle(task);
218                }
219                self.min_delay_seq = min_seq;
220            }
221        }
222    }
223}
224
225struct WaitRegTaskFuture<'a, F>
226where
227    F: ClientFacts,
228{
229    noti: &'a mut ClientTaskTimer<F>,
230    target_seq: u64,
231}
232
233impl<'a, F> Future for WaitRegTaskFuture<'a, F>
234where
235    F: ClientFacts,
236{
237    type Output = Result<(), ()>;
238
239    fn poll(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<Self::Output> {
240        let mut _self = self.get_mut();
241        if _self.noti.processed_seq >= _self.target_seq {
242            return Poll::Ready(Ok(()));
243        }
244        if _self.noti.reg_stopped_flag.load(Ordering::SeqCst) {
245            return Poll::Ready(Err(()));
246        }
247        if _self.noti.poll_sent_task(ctx) {
248            if _self.noti.processed_seq >= _self.target_seq {
249                return Poll::Ready(Ok(()));
250            }
251        }
252        Poll::Pending
253    }
254}