1use std::collections::HashMap;
2use std::net::SocketAddr;
3use std::sync::atomic::{AtomicU64, Ordering};
4use std::sync::Arc;
5
6use bytes::BytesMut;
7use log::{error, info};
8use std::future::Future;
9use tokio::io::{AsyncReadExt, AsyncWriteExt};
10use tokio::net::{TcpListener, TcpStream};
11use tokio::sync::{broadcast, mpsc};
12
13use crate::error::*;
14use crate::resp::*;
15
/// Address a [`Server`] binds to unless [`Server::listen`] overrides it:
/// loopback only, on port 6379 (the standard Redis port).
const DEFAULT_ADDRESS: &str = "127.0.0.1:6379";
17
18#[derive(Debug)]
19pub struct PeerContext<T>
20where
21 T: Default,
22{
23 peer: SocketAddr,
24 local: SocketAddr,
25 ctx: HashMap<String, Value>,
26 pub user_data: T,
27}
28
29impl<T> PeerContext<T>
30where
31 T: Default,
32{
33 pub fn set<K>(&mut self, key: K, value: Value) -> Option<Value>
34 where
35 K: Into<String>,
36 {
37 self.ctx.insert(key.into(), value)
38 }
39
40 pub fn contains_key(&self, key: &str) -> bool {
41 self.ctx.contains_key(key)
42 }
43
44 pub fn get(&mut self, key: &str) -> Option<&Value> {
45 self.ctx.get(key)
46 }
47
48 pub fn get_mut(&mut self, key: &str) -> Option<&mut Value> {
49 self.ctx.get_mut(key)
50 }
51
52 pub fn peer_addr(&self) -> &SocketAddr {
53 &self.peer
54 }
55
56 pub fn local_addr(&self) -> &SocketAddr {
57 &self.local
58 }
59}
60
61pub trait EventHandler {
62 type ClientUserData: Default + Send + Sync;
63
64 fn on_request(
65 &self,
66 peer: &mut PeerContext<Self::ClientUserData>,
67 request: Value,
68 ) -> impl Future<Output = Result<Value>> + Send;
69 fn on_connect(&self, _id: u64) -> impl Future<Output = Result<Self::ClientUserData>> + Send {
70 async { Ok(Self::ClientUserData::default()) }
71 }
72 fn on_disconnect(&self, _id: u64) -> impl Future<Output = ()> + Send {
73 async {}
74 }
75}
76
77struct Shutdown {
78 is_shutdown: bool,
79 notify: broadcast::Receiver<()>,
80}
81
82impl Shutdown {
83 pub(crate) fn new(notify: broadcast::Receiver<()>) -> Shutdown {
84 Shutdown {
85 is_shutdown: false,
86 notify,
87 }
88 }
89
90 pub(crate) async fn recv(&mut self) {
91 if self.is_shutdown {
92 return;
93 }
94 let _ = self.notify.recv().await;
95 self.is_shutdown = true;
96 }
97
98 pub(crate) fn is_shutdown(&self) -> bool {
99 self.is_shutdown
100 }
101}
102
103pub struct Server<H>
104where
105 H: EventHandler + Send + Sync + 'static,
106{
107 handler: Arc<H>,
108 address: String,
109 client_id: Arc<AtomicU64>,
110}
111
112impl<H> Server<H>
113where
114 H: EventHandler + Send + Sync + 'static,
115{
116 pub fn new(handler: H) -> Self {
117 Server {
118 handler: Arc::new(handler),
119 address: DEFAULT_ADDRESS.into(),
120 client_id: Arc::new(AtomicU64::default()),
121 }
122 }
123
124 pub fn listen(&mut self, addr: impl Into<String>) -> Result<&mut Self> {
125 self.address = addr.into();
126 Ok(self)
127 }
128
129 async fn run_client_loop(
130 user_data: H::ClientUserData,
131 handler: Arc<H>,
132 mut socket: TcpStream,
133 ) -> Result<()> {
134 let mut client = PeerContext {
135 peer: socket.peer_addr()?,
136 local: socket.local_addr()?,
137 ctx: HashMap::new(),
138 user_data,
139 };
140 let mut rd = BytesMut::new();
141 let mut wr = BytesMut::new();
142 let mut decoder = ValueDecoder::default();
143 loop {
144 if let Some(value) = decoder.try_decode(&mut rd)? {
145 handler
146 .on_request(&mut client, value)
147 .await?
148 .encode(&mut wr);
149 socket.write_all(&wr).await?;
150 wr.clear();
151 socket.flush().await?;
152 }
153
154 if 0 == socket.read_buf(&mut rd).await? && rd.is_empty() {
155 return Ok(());
156 }
157 }
158 }
159
160 async fn run_client_hooks(id: u64, handler: Arc<H>, socket: TcpStream) -> Result<()> {
161 let user_data = handler.on_connect(id).await?;
162 let result = Self::run_client_loop(user_data, Arc::clone(&handler), socket).await;
163 handler.on_disconnect(id).await;
164 result
165 }
166
167 async fn run_client(
168 id: u64,
169 handler: Arc<H>,
170 socket: TcpStream,
171 notify_shutdown: broadcast::Sender<()>,
172 _shutdown_complete_tx: mpsc::Sender<()>,
173 ) -> Result<()> {
174 let mut shutdown = Shutdown::new(notify_shutdown.subscribe());
175 tokio::select! {
176 res = Self::run_client_hooks(id, handler, socket) => { res }
177 _ = shutdown.recv() => { Ok(()) }
178 }
179 }
180
181 async fn run_accept_loop(
182 &mut self,
183 listener: TcpListener,
184 notify_shutdown: broadcast::Sender<()>,
185 shutdown_complete_tx: mpsc::Sender<()>,
186 ) -> Result<()> {
187 let mut shutdown = Shutdown::new(notify_shutdown.subscribe());
188 while !shutdown.is_shutdown() {
189 tokio::select! {
190 res = listener.accept() => {
191 if let Ok((socket, _)) = res {
192 let handler = Arc::clone(&self.handler);
193 let client_id = Arc::clone(&self.client_id);
194 let notify_shutdown = Clone::clone(¬ify_shutdown);
195 let shutdown_complete_tx = Clone::clone(&shutdown_complete_tx);
196 tokio::spawn(async move {
197 Self::run_client(client_id.fetch_add(1, Ordering::AcqRel), handler, socket, notify_shutdown, shutdown_complete_tx).await
198 });
199 }
200 },
201 _ = shutdown.recv() => {}
202 }
203 }
204 Ok(())
205 }
206
207 pub async fn serve(&mut self, shutdown: impl Future) -> Result<()> {
208 let listener = TcpListener::bind(&self.address).await?;
209
210 let (notify_shutdown, _) = broadcast::channel(1);
211 let (shutdown_complete_tx, mut shutdown_complete_rx) = mpsc::channel(1);
212
213 tokio::select! {
214 res = self.run_accept_loop(listener, Clone::clone(¬ify_shutdown), Clone::clone(&shutdown_complete_tx)) => {
215 if let Err(err) = res {
216 error!("Failed to accept {:?}", err);
217 }
218 }
219 _ = shutdown => {
220 info!("shutting down server");
221 }
222 }
223
224 drop(notify_shutdown);
225 drop(shutdown_complete_tx);
226
227 let _ = shutdown_complete_rx.recv().await;
228 Ok(())
229 }
230}