#![warn(rust_2018_idioms)]
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::{mpsc, Mutex};
use tokio_stream::StreamExt;
use tokio_util::codec::{Framed, LinesCodec};
use futures::SinkExt;
use std::collections::HashMap;
use std::env;
use std::error::Error;
use std::io;
use std::net::SocketAddr;
use std::sync::Arc;
/// Address the server binds to when no address is given as the first
/// command-line argument.
const DEFAULT_ADDR: &str = "127.0.0.1:6142";
/// Entry point: configures tracing, binds the listener and serves clients.
///
/// The bind address is taken from the first command-line argument, falling
/// back to `DEFAULT_ADDR`. Every accepted connection is handled by its own
/// spawned task running `process`, all sharing one `Shared` registry.
///
/// # Errors
///
/// Returns an error if the `chat=info` directive fails to parse, if binding
/// the listener fails, or if accepting a connection fails.
#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
    use tracing_subscriber::{fmt::format::FmtSpan, EnvFilter};

    // Build the filter up front so the subscriber chain below stays short.
    let filter = EnvFilter::from_default_env().add_directive("chat=info".parse()?);
    tracing_subscriber::fmt()
        .with_env_filter(filter)
        .with_span_events(FmtSpan::FULL)
        .init();

    // One registry of connected peers, shared by every connection task.
    let shared = Arc::new(Mutex::new(Shared::new()));

    let bind_addr = match env::args().nth(1) {
        Some(arg) => arg,
        None => DEFAULT_ADDR.to_string(),
    };
    let listener = TcpListener::bind(&bind_addr).await?;
    tracing::info!("server running on {}", bind_addr);

    loop {
        let (socket, peer_addr) = listener.accept().await?;
        // Each task gets its own handle to the shared registry.
        let shared = Arc::clone(&shared);
        tokio::spawn(async move {
            tracing::debug!("accepted connection from {}", peer_addr);
            if let Err(err) = process(shared, socket, peer_addr).await {
                tracing::warn!("Connection from {} failed: {:?}", peer_addr, err);
            }
        });
    }
}
/// Sending half of an unbounded channel carrying `String` messages.
type Tx = mpsc::UnboundedSender<String>;
/// Receiving half of an unbounded channel carrying `String` messages.
type Rx = mpsc::UnboundedReceiver<String>;
/// State shared between all connection tasks.
struct Shared {
// Maps each registered peer's socket address to the sender used to queue
// messages for that peer.
peers: HashMap<SocketAddr, Tx>,
}
/// Per-connection state for one chat client.
struct Peer {
// The TCP stream wrapped in a line-based codec; used both to read lines
// from the client and to write lines back to it.
lines: Framed<TcpStream, LinesCodec>,
// Receiving half of the channel whose sending half is stored in
// `Shared::peers` under this peer's address.
rx: Rx,
}
impl Shared {
    /// Creates a registry with no peers.
    fn new() -> Self {
        Self {
            peers: HashMap::new(),
        }
    }

    /// Queues `message` for every registered peer except `sender`.
    ///
    /// Peers whose channel rejects the message are dropped from the registry
    /// once all sends have been attempted. The function is declared `async`
    /// but never awaits.
    async fn broadcast(&mut self, sender: SocketAddr, message: &str) {
        let payload = message.to_string();

        // First pass: attempt delivery and remember which peers failed. The
        // `&&` short-circuits so nothing is ever sent to `sender` itself.
        let dead: Vec<SocketAddr> = self
            .peers
            .iter()
            .filter(|(peer_addr, tx)| **peer_addr != sender && tx.send(payload.clone()).is_err())
            .map(|(peer_addr, _)| *peer_addr)
            .collect();

        // Second pass: the map cannot be mutated while it is being iterated.
        for peer_addr in dead {
            self.peers.remove(&peer_addr);
            tracing::debug!("Removed disconnected peer: {}", peer_addr);
        }
    }
}
impl Peer {
    /// Registers a new peer in `state` and returns its connection handle.
    ///
    /// A fresh unbounded channel is created; the sending half is stored in
    /// `state` under the connection's remote address and the receiving half
    /// is kept in the returned `Peer`.
    ///
    /// # Errors
    ///
    /// Returns the I/O error from looking up the remote address of the
    /// underlying stream; nothing is registered in that case.
    async fn new(
        state: Arc<Mutex<Shared>>,
        lines: Framed<TcpStream, LinesCodec>,
    ) -> io::Result<Peer> {
        let remote = lines.get_ref().peer_addr()?;
        let (sender, receiver) = mpsc::unbounded_channel();

        {
            // Hold the lock only for the insertion itself.
            let mut shared = state.lock().await;
            shared.peers.insert(remote, sender);
        }

        Ok(Peer {
            lines,
            rx: receiver,
        })
    }
}
async fn process(
state: Arc<Mutex<Shared>>,
stream: TcpStream,
addr: SocketAddr,
) -> Result<(), Box<dyn Error>> {
let mut lines = Framed::new(stream, LinesCodec::new());
lines.send("Please enter your username:").await?;
let username = match lines.next().await {
Some(Ok(line)) => line,
_ => {
tracing::error!("Failed to get username from {}. Client disconnected.", addr);
return Ok(());
}
};
let mut peer = Peer::new(state.clone(), lines).await?;
{
let mut state = state.lock().await;
let msg = format!("{username} has joined the chat");
tracing::info!("{}", msg);
state.broadcast(addr, &msg).await;
}
loop {
tokio::select! {
Some(msg) = peer.rx.recv() => {
if let Err(e) = peer.lines.send(&msg).await {
tracing::error!("Failed to send message to {}: {:?}", username, e);
break;
}
}
result = peer.lines.next() => match result {
Some(Ok(msg)) => {
let mut state = state.lock().await;
let msg = format!("{username}: {msg}");
state.broadcast(addr, &msg).await;
}
Some(Err(e)) => {
tracing::error!(
"an error occurred while processing messages for {}; error = {:?}",
username,
e
);
break;
}
None => break,
},
}
}
{
let mut state = state.lock().await;
state.peers.remove(&addr);
let msg = format!("{username} has left the chat");
tracing::info!("{}", msg);
state.broadcast(addr, &msg).await;
}
Ok(())
}