// resp_async/io.rs

use std::collections::HashMap;
use std::future::Future;
use std::io;
use std::net::SocketAddr;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;

use bytes::BytesMut;
use log::{error, info};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::{broadcast, mpsc};

use crate::error::*;
use crate::resp::*;
/// Address a [`Server`] binds to unless `Server::listen` overrides it: the IPv4 loopback
/// interface on port 6379, the standard Redis port.
const DEFAULT_ADDRESS: &str = "127.0.0.1:6379";
18#[derive(Debug)]
19pub struct PeerContext<T>
20where
21    T: Default,
22{
23    peer: SocketAddr,
24    local: SocketAddr,
25    ctx: HashMap<String, Value>,
26    pub user_data: T,
27}
28
29impl<T> PeerContext<T>
30where
31    T: Default,
32{
33    pub fn set<K>(&mut self, key: K, value: Value) -> Option<Value>
34    where
35        K: Into<String>,
36    {
37        self.ctx.insert(key.into(), value)
38    }
39
40    pub fn contains_key(&self, key: &str) -> bool {
41        self.ctx.contains_key(key)
42    }
43
44    pub fn get(&mut self, key: &str) -> Option<&Value> {
45        self.ctx.get(key)
46    }
47
48    pub fn get_mut(&mut self, key: &str) -> Option<&mut Value> {
49        self.ctx.get_mut(key)
50    }
51
52    pub fn peer_addr(&self) -> &SocketAddr {
53        &self.peer
54    }
55
56    pub fn local_addr(&self) -> &SocketAddr {
57        &self.local
58    }
59}
60
61pub trait EventHandler {
62    type ClientUserData: Default + Send + Sync;
63
64    fn on_request(
65        &self,
66        peer: &mut PeerContext<Self::ClientUserData>,
67        request: Value,
68    ) -> impl Future<Output = Result<Value>> + Send;
69    fn on_connect(&self, _id: u64) -> impl Future<Output = Result<Self::ClientUserData>> + Send {
70        async { Ok(Self::ClientUserData::default()) }
71    }
72    fn on_disconnect(&self, _id: u64) -> impl Future<Output = ()> + Send {
73        async {}
74    }
75}
76
77struct Shutdown {
78    is_shutdown: bool,
79    notify: broadcast::Receiver<()>,
80}
81
82impl Shutdown {
83    pub(crate) fn new(notify: broadcast::Receiver<()>) -> Shutdown {
84        Shutdown {
85            is_shutdown: false,
86            notify,
87        }
88    }
89
90    pub(crate) async fn recv(&mut self) {
91        if self.is_shutdown {
92            return;
93        }
94        let _ = self.notify.recv().await;
95        self.is_shutdown = true;
96    }
97
98    pub(crate) fn is_shutdown(&self) -> bool {
99        self.is_shutdown
100    }
101}
102
103pub struct Server<H>
104where
105    H: EventHandler + Send + Sync + 'static,
106{
107    handler: Arc<H>,
108    address: String,
109    client_id: Arc<AtomicU64>,
110}
111
112impl<H> Server<H>
113where
114    H: EventHandler + Send + Sync + 'static,
115{
116    pub fn new(handler: H) -> Self {
117        Server {
118            handler: Arc::new(handler),
119            address: DEFAULT_ADDRESS.into(),
120            client_id: Arc::new(AtomicU64::default()),
121        }
122    }
123
124    pub fn listen(&mut self, addr: impl Into<String>) -> Result<&mut Self> {
125        self.address = addr.into();
126        Ok(self)
127    }
128
129    async fn run_client_loop(
130        user_data: H::ClientUserData,
131        handler: Arc<H>,
132        mut socket: TcpStream,
133    ) -> Result<()> {
134        let mut client = PeerContext {
135            peer: socket.peer_addr()?,
136            local: socket.local_addr()?,
137            ctx: HashMap::new(),
138            user_data,
139        };
140        let mut rd = BytesMut::new();
141        let mut wr = BytesMut::new();
142        let mut decoder = ValueDecoder::default();
143        loop {
144            if let Some(value) = decoder.try_decode(&mut rd)? {
145                handler
146                    .on_request(&mut client, value)
147                    .await?
148                    .encode(&mut wr);
149                socket.write_all(&wr).await?;
150                wr.clear();
151                socket.flush().await?;
152            }
153
154            if 0 == socket.read_buf(&mut rd).await? && rd.is_empty() {
155                return Ok(());
156            }
157        }
158    }
159
160    async fn run_client_hooks(id: u64, handler: Arc<H>, socket: TcpStream) -> Result<()> {
161        let user_data = handler.on_connect(id).await?;
162        let result = Self::run_client_loop(user_data, Arc::clone(&handler), socket).await;
163        handler.on_disconnect(id).await;
164        result
165    }
166
167    async fn run_client(
168        id: u64,
169        handler: Arc<H>,
170        socket: TcpStream,
171        notify_shutdown: broadcast::Sender<()>,
172        _shutdown_complete_tx: mpsc::Sender<()>,
173    ) -> Result<()> {
174        let mut shutdown = Shutdown::new(notify_shutdown.subscribe());
175        tokio::select! {
176            res = Self::run_client_hooks(id, handler, socket) => { res }
177            _ = shutdown.recv() => { Ok(()) }
178        }
179    }
180
181    async fn run_accept_loop(
182        &mut self,
183        listener: TcpListener,
184        notify_shutdown: broadcast::Sender<()>,
185        shutdown_complete_tx: mpsc::Sender<()>,
186    ) -> Result<()> {
187        let mut shutdown = Shutdown::new(notify_shutdown.subscribe());
188        while !shutdown.is_shutdown() {
189            tokio::select! {
190                res = listener.accept() => {
191                    if let Ok((socket, _)) = res {
192                        let handler = Arc::clone(&self.handler);
193                        let client_id = Arc::clone(&self.client_id);
194                        let notify_shutdown = Clone::clone(&notify_shutdown);
195                        let shutdown_complete_tx = Clone::clone(&shutdown_complete_tx);
196                        tokio::spawn(async move {
197                            Self::run_client(client_id.fetch_add(1, Ordering::AcqRel), handler, socket, notify_shutdown, shutdown_complete_tx).await
198                        });
199                    }
200                },
201                _ = shutdown.recv() => {}
202            }
203        }
204        Ok(())
205    }
206
207    pub async fn serve(&mut self, shutdown: impl Future) -> Result<()> {
208        let listener = TcpListener::bind(&self.address).await?;
209
210        let (notify_shutdown, _) = broadcast::channel(1);
211        let (shutdown_complete_tx, mut shutdown_complete_rx) = mpsc::channel(1);
212
213        tokio::select! {
214            res = self.run_accept_loop(listener, Clone::clone(&notify_shutdown), Clone::clone(&shutdown_complete_tx)) => {
215                if let Err(err) = res {
216                    error!("Failed to accept {:?}", err);
217                }
218            }
219            _ = shutdown => {
220                info!("shutting down server");
221            }
222        }
223
224        drop(notify_shutdown);
225        drop(shutdown_complete_tx);
226
227        let _ = shutdown_complete_rx.recv().await;
228        Ok(())
229    }
230}