chatting/
lib.rs

1//! # chatting
2//!
3//! A simple chat server written in Rust
4
5pub mod config;
6pub mod utils;
7
8pub use config::{CliArgs, Config};
9
10use std::{collections::HashMap, error::Error, net::SocketAddr, sync::Arc};
11
12use futures::{SinkExt, StreamExt};
13use log::{error, info, warn};
14
15use tokio::{
16    net::{TcpListener, TcpStream},
17    sync::{mpsc, RwLock},
18};
19use tokio_util::codec::{Framed, LinesCodec};
20
21type Tx = mpsc::UnboundedSender<String>;
22
23struct Server {
24    clients: Arc<RwLock<HashMap<SocketAddr, Tx>>>,
25}
26
27impl Server {
28    pub fn new() -> Self {
29        Server {
30            clients: Arc::new(RwLock::new(HashMap::new())),
31        }
32    }
33
34    pub async fn run(&mut self, config: &config::Config) -> Result<(), Box<dyn Error>> {
35        let listener = TcpListener::bind(config.addr).await?;
36
37        loop {
38            let (stream, addr) = listener.accept().await?;
39            let clients = self.clients.clone();
40            let speed_rate = config.speed_rate;
41
42            tokio::spawn(async move {
43                if let Err(e) = Self::handle_connection(clients, stream, addr, speed_rate).await {
44                    error!("client {} occurred error, error: {}", addr, e);
45                }
46            });
47        }
48    }
49
50    async fn handle_connection(
51        clients: Arc<RwLock<HashMap<SocketAddr, Tx>>>,
52        stream: TcpStream,
53        addr: SocketAddr,
54        rate_limit: u16,
55    ) -> Result<(), Box<dyn Error>> {
56        let (tx, mut rx) = mpsc::unbounded_channel();
57        let mut frame = Framed::new(stream, LinesCodec::new());
58
59        Self::server_send("Welcome to Mitsuha's chat room.", &mut frame).await?;
60
61        let name = Self::get_username(&mut frame).await?;
62
63        if Self::captcha(&mut frame).await? == false {
64            return Err("Bot detected")?;
65        }
66
67        Self::broadcast(
68            &clients,
69            &addr,
70            &name,
71            &format!("{} joined chat room.", name),
72            true,
73        )
74        .await?;
75
76        clients.write().await.insert(addr, tx);
77
78        // Initialize speed counter
79        let counter = utils::Counter::new(rate_limit);
80
81        loop {
82            tokio::select! {
83                Some(message) = rx.recv() => {
84                    frame.send(message).await?;
85                }
86
87                result = frame.next() => match result {
88                    Some(Ok(message)) => {
89                        if counter.check().await {
90                            Self::broadcast(&clients, &addr, &name, &message,false).await?;
91                        } else {
92                            Self::server_send("Oops, you have reached rate limit, please retry after 1 minute.",&mut frame).await?;
93                        }
94
95                        counter.add().await;
96                    }
97
98                    _ => break,
99                }
100            }
101        }
102
103        Self::broadcast(
104            &clients,
105            &addr,
106            &name,
107            &format!("{} left chat room.", name),
108            true,
109        )
110        .await?;
111        clients.write().await.remove(&addr);
112        Ok(())
113    }
114
115    async fn get_username(
116        frame: &mut Framed<TcpStream, LinesCodec>,
117    ) -> Result<String, Box<dyn Error>> {
118        Self::server_send("Please enter your name: ", frame).await?;
119
120        let name = match frame.next().await {
121            Some(Ok(n)) => n,
122            _ => {
123                return Err("Invalid name")?;
124            }
125        };
126
127        frame.send("\n").await?;
128
129        Ok(name)
130    }
131
132    async fn captcha(frame: &mut Framed<TcpStream, LinesCodec>) -> Result<bool, Box<dyn Error>> {
133        let param1: u8 = rand::random();
134        let param2: u8 = rand::random();
135        let answer = param1 as u16 + param2 as u16;
136        let captcha = format!("{} + {} = ?", param1, param2);
137
138        Self::server_send(&format!("Please solve the captcha: {}", captcha), frame).await?;
139
140        let input: u16 = match frame.next().await {
141            Some(Ok(n)) => n.parse()?,
142            _ => {
143                return Err("Invalid name")?;
144            }
145        };
146
147        let ret = input == answer;
148
149        if ret {
150            Self::server_send("Correct captcha, welcome!", frame).await?;
151        } else {
152            Self::server_send("WRONG CAPTCHA, DISCONNECTED!", frame).await?;
153        }
154
155        frame.send("\n").await?;
156
157        Ok(ret)
158    }
159
160    async fn broadcast(
161        clients: &Arc<RwLock<HashMap<SocketAddr, Tx>>>,
162        sender: &SocketAddr,
163        name: &str,
164        message: &str,
165        is_server: bool,
166    ) -> Result<(), Box<dyn Error>> {
167        let message = match is_server {
168            true => format!("[SERVER] {}", message),
169            false => format!("({}) {}", name, message),
170        };
171
172        info!("{}: {}", sender, message);
173
174        for (addr, tx) in clients.read().await.iter() {
175            if sender == addr {
176                continue;
177            }
178
179            let message = message.clone();
180            if let Err(e) = tx.send(message) {
181                warn!("error sending to {}, error: {}", addr, e);
182            }
183        }
184
185        Ok(())
186    }
187
188    async fn server_send(
189        message: &str,
190        frame: &mut Framed<TcpStream, LinesCodec>,
191    ) -> Result<(), Box<dyn Error>> {
192        frame.send(format!("[SERVER] {}", message)).await?;
193        Ok(())
194    }
195}
196
197pub async fn run(config: config::Config) -> Result<(), Box<dyn Error>> {
198    let mut server = Server::new();
199
200    server.run(&config).await
201}
202
#[cfg(test)]
mod tests {
    // Placeholder: no unit tests are defined in this module yet.
}