1pub mod config;
6pub mod utils;
7
8pub use config::{CliArgs, Config};
9
10use std::{collections::HashMap, error::Error, net::SocketAddr, sync::Arc};
11
12use futures::{SinkExt, StreamExt};
13use log::{error, info, warn};
14
15use tokio::{
16 net::{TcpListener, TcpStream},
17 sync::{mpsc, RwLock},
18};
19use tokio_util::codec::{Framed, LinesCodec};
20
21type Tx = mpsc::UnboundedSender<String>;
22
23struct Server {
24 clients: Arc<RwLock<HashMap<SocketAddr, Tx>>>,
25}
26
27impl Server {
28 pub fn new() -> Self {
29 Server {
30 clients: Arc::new(RwLock::new(HashMap::new())),
31 }
32 }
33
34 pub async fn run(&mut self, config: &config::Config) -> Result<(), Box<dyn Error>> {
35 let listener = TcpListener::bind(config.addr).await?;
36
37 loop {
38 let (stream, addr) = listener.accept().await?;
39 let clients = self.clients.clone();
40 let speed_rate = config.speed_rate;
41
42 tokio::spawn(async move {
43 if let Err(e) = Self::handle_connection(clients, stream, addr, speed_rate).await {
44 error!("client {} occurred error, error: {}", addr, e);
45 }
46 });
47 }
48 }
49
50 async fn handle_connection(
51 clients: Arc<RwLock<HashMap<SocketAddr, Tx>>>,
52 stream: TcpStream,
53 addr: SocketAddr,
54 rate_limit: u16,
55 ) -> Result<(), Box<dyn Error>> {
56 let (tx, mut rx) = mpsc::unbounded_channel();
57 let mut frame = Framed::new(stream, LinesCodec::new());
58
59 Self::server_send("Welcome to Mitsuha's chat room.", &mut frame).await?;
60
61 let name = Self::get_username(&mut frame).await?;
62
63 if Self::captcha(&mut frame).await? == false {
64 return Err("Bot detected")?;
65 }
66
67 Self::broadcast(
68 &clients,
69 &addr,
70 &name,
71 &format!("{} joined chat room.", name),
72 true,
73 )
74 .await?;
75
76 clients.write().await.insert(addr, tx);
77
78 let counter = utils::Counter::new(rate_limit);
80
81 loop {
82 tokio::select! {
83 Some(message) = rx.recv() => {
84 frame.send(message).await?;
85 }
86
87 result = frame.next() => match result {
88 Some(Ok(message)) => {
89 if counter.check().await {
90 Self::broadcast(&clients, &addr, &name, &message,false).await?;
91 } else {
92 Self::server_send("Oops, you have reached rate limit, please retry after 1 minute.",&mut frame).await?;
93 }
94
95 counter.add().await;
96 }
97
98 _ => break,
99 }
100 }
101 }
102
103 Self::broadcast(
104 &clients,
105 &addr,
106 &name,
107 &format!("{} left chat room.", name),
108 true,
109 )
110 .await?;
111 clients.write().await.remove(&addr);
112 Ok(())
113 }
114
115 async fn get_username(
116 frame: &mut Framed<TcpStream, LinesCodec>,
117 ) -> Result<String, Box<dyn Error>> {
118 Self::server_send("Please enter your name: ", frame).await?;
119
120 let name = match frame.next().await {
121 Some(Ok(n)) => n,
122 _ => {
123 return Err("Invalid name")?;
124 }
125 };
126
127 frame.send("\n").await?;
128
129 Ok(name)
130 }
131
132 async fn captcha(frame: &mut Framed<TcpStream, LinesCodec>) -> Result<bool, Box<dyn Error>> {
133 let param1: u8 = rand::random();
134 let param2: u8 = rand::random();
135 let answer = param1 as u16 + param2 as u16;
136 let captcha = format!("{} + {} = ?", param1, param2);
137
138 Self::server_send(&format!("Please solve the captcha: {}", captcha), frame).await?;
139
140 let input: u16 = match frame.next().await {
141 Some(Ok(n)) => n.parse()?,
142 _ => {
143 return Err("Invalid name")?;
144 }
145 };
146
147 let ret = input == answer;
148
149 if ret {
150 Self::server_send("Correct captcha, welcome!", frame).await?;
151 } else {
152 Self::server_send("WRONG CAPTCHA, DISCONNECTED!", frame).await?;
153 }
154
155 frame.send("\n").await?;
156
157 Ok(ret)
158 }
159
160 async fn broadcast(
161 clients: &Arc<RwLock<HashMap<SocketAddr, Tx>>>,
162 sender: &SocketAddr,
163 name: &str,
164 message: &str,
165 is_server: bool,
166 ) -> Result<(), Box<dyn Error>> {
167 let message = match is_server {
168 true => format!("[SERVER] {}", message),
169 false => format!("({}) {}", name, message),
170 };
171
172 info!("{}: {}", sender, message);
173
174 for (addr, tx) in clients.read().await.iter() {
175 if sender == addr {
176 continue;
177 }
178
179 let message = message.clone();
180 if let Err(e) = tx.send(message) {
181 warn!("error sending to {}, error: {}", addr, e);
182 }
183 }
184
185 Ok(())
186 }
187
188 async fn server_send(
189 message: &str,
190 frame: &mut Framed<TcpStream, LinesCodec>,
191 ) -> Result<(), Box<dyn Error>> {
192 frame.send(format!("[SERVER] {}", message)).await?;
193 Ok(())
194 }
195}
196
197pub async fn run(config: config::Config) -> Result<(), Box<dyn Error>> {
198 let mut server = Server::new();
199
200 server.run(&config).await
201}
202
// Placeholder for unit tests; none are defined yet.
#[cfg(test)]
mod tests {}