use anyhow::{Result, anyhow, bail};
use log::{error, info, warn};
use std::{
net::SocketAddr,
path::{Path, PathBuf},
sync::Arc,
};
use tokio::{
io::{AsyncRead, AsyncWrite},
net::{TcpListener, TcpStream},
signal,
sync::{broadcast, mpsc},
time::{self, Duration},
};
use tokio_rustls::{
TlsAcceptor,
rustls::{
ServerConfig as TlsServerConfig,
pki_types::{CertificateDer, PrivateKeyDer, pem::PemObject},
},
server::TlsStream,
};
use freezeout_core::{
connection::{self, EncryptedConnection},
crypto::{PeerId, SigningKey},
message::{Message, SignedMessage},
poker::Chips,
};
use crate::{
db::Db,
table::{Table, TableMessage},
tables_pool::{TablesPool, TablesPoolsError},
};
/// Server configuration consumed by [`run`].
#[derive(Debug)]
pub struct Config {
    /// Host/IP the TCP listener binds to (combined with `port` as `address:port`).
    pub address: String,
    /// Port the TCP listener binds to.
    pub port: u16,
    /// Number of tables created in the tables pool.
    pub tables: usize,
    /// Number of seats per table.
    pub seats: usize,
    /// Directory holding `server.phrase` (signing key) and `game.db`; when `None`,
    /// the platform config directory of the `freezeout` project is used.
    pub data_path: Option<PathBuf>,
    /// PEM file containing the TLS private key. TLS is enabled only when both this
    /// and `chain_path` are set; otherwise connections use NOISE encryption.
    pub key_path: Option<PathBuf>,
    /// PEM file containing the TLS certificate chain; see `key_path`.
    pub chain_path: Option<PathBuf>,
}
/// Runs the server until Ctrl-C is received or accepting connections fails.
///
/// Binds a TCP listener on `config.address:config.port` and loads (or creates) the
/// signing key and game database in `config.data_path` (or the default project
/// directory). TLS is enabled when both `config.key_path` and `config.chain_path`
/// are set. On exit, shutdown is signalled to tables and connection handlers, and
/// this function waits for them to finish.
///
/// # Errors
///
/// Returns an error if binding the listener or loading the key, database or TLS
/// material fails. It also returns an error if accepting connections keeps failing
/// after retries; in that case the graceful shutdown still runs first.
pub async fn run(config: Config) -> Result<()> {
    let addr = format!("{}:{}", config.address, config.port);
    info!(
        "Listening on {} with {} tables and {} seats per table",
        addr, config.tables, config.seats
    );

    let listener = TcpListener::bind(&addr)
        .await
        .map_err(|e| anyhow!("Tcp listener bind error: {e}"))?;

    let sk = load_signing_key(&config.data_path)?;
    let db = open_database(&config.data_path)?;

    let tls = match (config.key_path, config.chain_path) {
        (Some(key), Some(chain)) => Some(load_tls(&key, &chain)?),
        _ => {
            warn!("TLS not enabled, using NOISE encryption");
            None
        }
    };

    let shutdown_signal = signal::ctrl_c();
    // This file never sends on the broadcast channel. Shutdown is signalled by
    // dropping the sender, which closes the channel for every subscriber.
    let (shutdown_broadcast_tx, _) = broadcast::channel(1);
    // Every spawned task holds a clone of this sender. `recv` returns `None` once
    // all clones are dropped.
    let (shutdown_complete_tx, mut shutdown_complete_rx) = mpsc::channel(1);

    let tables = TablesPool::new(
        config.tables,
        config.seats,
        sk.clone(),
        db.clone(),
        &shutdown_broadcast_tx,
        &shutdown_complete_tx,
    );

    let mut server = Server {
        tables,
        sk,
        db,
        listener,
        tls,
        shutdown_broadcast_tx,
        shutdown_complete_tx,
    };

    // Keep the outcome instead of returning early, so that an accept failure
    // still goes through the graceful shutdown below.
    let res = tokio::select! {
        res = server.run() => res.map_err(|e| anyhow!("Tcp listener accept error: {e}")),
        _ = shutdown_signal => {
            info!("Received shutdown signal...");
            Ok(())
        }
    };

    // Drop the whole server, not just the two shutdown senders. This closes the
    // broadcast channel and releases everything else the server owns (listener,
    // tables pool, and any shutdown senders held through it). The `recv` below
    // then completes once every spawned task has dropped its own sender.
    drop(server);
    let _ = shutdown_complete_rx.recv().await;

    res
}
struct Server {
tables: TablesPool,
sk: Arc<SigningKey>,
db: Db,
listener: TcpListener,
tls: Option<TlsAcceptor>,
shutdown_broadcast_tx: broadcast::Sender<()>,
shutdown_complete_tx: mpsc::Sender<()>,
}
impl Server {
async fn run(&mut self) -> Result<()> {
loop {
let (stream, addr) = self.accept_with_retry().await?;
info!("Accepted connection from {addr}");
let mut handler = Handler {
tables: self.tables.clone(),
sk: self.sk.clone(),
db: self.db.clone(),
table: None,
shutdown_broadcast_rx: self.shutdown_broadcast_tx.subscribe(),
_shutdown_complete_tx: self.shutdown_complete_tx.clone(),
};
let tls_acceptor = self.tls.clone();
tokio::spawn(async move {
let res = if let Some(acceptor) = tls_acceptor {
match acceptor.accept(stream).await {
Ok(stream) => handler.run_tls(stream).await,
Err(e) => Err(e.into()),
}
} else {
handler.run_tcp(stream).await
};
if let Err(err) = res {
error!("Connection to {addr} {err}");
}
info!("Connection to {addr} closed");
});
}
}
async fn accept_with_retry(&self) -> Result<(TcpStream, SocketAddr)> {
let mut retry = 0;
loop {
match self.listener.accept().await {
Ok((socket, addr)) => {
return Ok((socket, addr));
}
Err(err) => {
if retry == 5 {
return Err(err.into());
}
}
}
time::sleep(Duration::from_secs(1 << retry)).await;
retry += 1;
}
}
}
/// Per-connection state for a single client session.
///
/// A handler waits for a `JoinServer` message. After that it relays messages
/// between the client connection and the player's table until the client
/// disconnects, the table closes the connection, or the server shuts down.
struct Handler {
    /// Pool used to seat the player at a table.
    tables: TablesPool,
    /// Server key used to sign every message this handler sends.
    sk: Arc<SigningKey>,
    /// Game database for player accounts and chip transfers.
    db: Db,
    /// Table the player is currently seated at, if any.
    table: Option<Arc<Table>>,
    /// Closed by the server to request shutdown.
    shutdown_broadcast_rx: broadcast::Receiver<()>,
    /// Never used directly; dropping it tells the server this handler finished.
    _shutdown_complete_tx: mpsc::Sender<()>,
}

impl Handler {
    /// Chips paid to buy in at a table. It is also the balance a player is
    /// refilled to, and the amount passed to `Db::join_server`.
    const JOIN_TABLE_CHIPS: Chips = Chips::new(1_000_000);

    /// Runs a client session over a TLS stream.
    async fn run_tls(&mut self, stream: TlsStream<TcpStream>) -> Result<()> {
        let mut conn = connection::accept_async(stream).await?;
        let res = self.handle_connection(&mut conn).await;
        conn.close().await;
        res
    }

    /// Runs a client session over a plain TCP stream.
    async fn run_tcp(&mut self, stream: TcpStream) -> Result<()> {
        let mut conn = connection::accept_async(stream).await?;
        let res = self.handle_connection(&mut conn).await;
        conn.close().await;
        res
    }

    /// Handles a connection: the `JoinServer` handshake, then the session loop.
    ///
    /// After a successful join, the player is removed from any table they are
    /// still seated at, however the session ends (including errors).
    ///
    /// # Errors
    ///
    /// Returns an error if the first message is not `JoinServer`, or on
    /// connection, database or send failures.
    async fn handle_connection<S>(&mut self, conn: &mut EncryptedConnection<S>) -> Result<()>
    where
        S: AsyncRead + AsyncWrite + Unpin,
    {
        let msg = tokio::select! {
            res = conn.recv() => match res {
                Some(Ok(msg)) => msg,
                Some(Err(err)) => return Err(err),
                None => return Ok(()),
            },
            _ = self.shutdown_broadcast_rx.recv() => {
                return Ok(());
            }
        };

        let (nickname, player_id) = match msg.message() {
            Message::JoinServer { nickname } => {
                let player = self
                    .db
                    .join_server(msg.sender(), nickname, Self::JOIN_TABLE_CHIPS)
                    .await?;
                let smsg = SignedMessage::new(
                    &self.sk,
                    Message::ServerJoined {
                        nickname: player.nickname,
                        chips: player.chips,
                    },
                );
                conn.send(&smsg).await?;
                (nickname.to_string(), msg.sender())
            }
            _ => bail!(
                "Invalid message from {} expecting a join server.",
                msg.sender()
            ),
        };

        // The session loop lives in its own method so that every exit path,
        // including errors propagated with `?`, reaches this cleanup.
        let res = self.run_session(conn, nickname, &player_id).await;
        if let Some(table) = &self.table {
            table.leave(&player_id).await;
        }
        res
    }

    /// Relays messages between the client and its table until the session ends.
    ///
    /// Returns `Ok(())` when the client disconnects, a table sends
    /// `TableMessage::Close`, or the server shuts down. It does not leave the
    /// table on exit; the caller does that.
    ///
    /// # Errors
    ///
    /// Returns connection, send and database errors.
    async fn run_session<S>(
        &mut self,
        conn: &mut EncryptedConnection<S>,
        nickname: String,
        player_id: &PeerId,
    ) -> Result<()>
    where
        S: AsyncRead + AsyncWrite + Unpin,
    {
        // Which source produced the next event.
        enum Branch {
            Conn(SignedMessage),
            Table(TableMessage),
        }

        // Tables talk back to this handler through clones of `table_tx`.
        let (table_tx, mut table_rx) = mpsc::channel(128);
        loop {
            let branch = tokio::select! {
                res = conn.recv() => match res {
                    Some(Ok(msg)) => Branch::Conn(msg),
                    Some(Err(err)) => return Err(err),
                    None => return Ok(()),
                },
                res = table_rx.recv() => match res {
                    Some(msg) => Branch::Table(msg),
                    None => return Ok(()),
                },
                _ = self.shutdown_broadcast_rx.recv() => return Ok(()),
            };

            match branch {
                Branch::Conn(msg) => match msg.message() {
                    Message::JoinTable => {
                        self.get_or_refill_chips(player_id).await?;
                        let has_chips = self
                            .db
                            .pay_from_player(player_id.clone(), Self::JOIN_TABLE_CHIPS)
                            .await?;
                        if has_chips {
                            let res = self
                                .tables
                                .join(
                                    player_id,
                                    &nickname,
                                    Self::JOIN_TABLE_CHIPS,
                                    table_tx.clone(),
                                )
                                .await;
                            match res {
                                Ok(table) => self.table = Some(table),
                                Err(e) => {
                                    // Refund the buy-in taken above.
                                    self.db
                                        .pay_to_player(player_id.clone(), Self::JOIN_TABLE_CHIPS)
                                        .await?;
                                    let msg = match e {
                                        TablesPoolsError::NoTablesLeft => Message::NoTablesLeft,
                                        TablesPoolsError::AlreadyJoined => {
                                            Message::PlayerAlreadyJoined
                                        }
                                    };
                                    conn.send(&SignedMessage::new(&self.sk, msg)).await?;
                                }
                            }
                        } else {
                            conn.send(&SignedMessage::new(&self.sk, Message::NotEnoughChips))
                                .await?;
                        }
                    }
                    Message::LeaveTable => {
                        // `self.table` is cleared when the table answers with
                        // `TableMessage::PlayerLeft`.
                        if let Some(table) = &self.table {
                            table.leave(player_id).await;
                        }
                    }
                    _ => {
                        if let Some(table) = &self.table {
                            table.message(msg).await;
                        }
                    }
                },
                Branch::Table(msg) => match msg {
                    TableMessage::Send(msg) => conn.send(&msg).await?,
                    TableMessage::PlayerLeft => {
                        self.table = None;
                        let chips = self.get_or_refill_chips(player_id).await?;
                        let msg = Message::ShowAccount { chips };
                        conn.send(&SignedMessage::new(&self.sk, msg)).await?;
                    }
                    TableMessage::Throttle(dt) => time::sleep(dt).await,
                    TableMessage::Close => {
                        info!("Connection closed by table message");
                        return Ok(());
                    }
                },
            }
        }
    }

    /// Returns the player's chip balance, first topping it up to
    /// `JOIN_TABLE_CHIPS` when it is below that amount.
    async fn get_or_refill_chips(&mut self, player_id: &PeerId) -> Result<Chips> {
        let mut player = self.db.get_player(player_id.clone()).await?;
        if player.chips < Self::JOIN_TABLE_CHIPS {
            let refill = Self::JOIN_TABLE_CHIPS - player.chips;
            self.db.pay_to_player(player_id.clone(), refill).await?;
            player.chips = Self::JOIN_TABLE_CHIPS;
        }
        Ok(player.chips)
    }
}
fn load_signing_key(path: &Option<PathBuf>) -> Result<Arc<SigningKey>> {
fn load_or_create(path: &Path) -> Result<Arc<SigningKey>> {
let keypair_path = path.join("server.phrase");
let keypair = if keypair_path.exists() {
info!("Loading keypair {}", keypair_path.display());
let passphrase = std::fs::read_to_string(keypair_path)?;
SigningKey::from_phrase(&passphrase)?
} else {
let keypair = SigningKey::default();
std::fs::create_dir_all(path)?;
std::fs::write(&keypair_path, keypair.phrase().as_bytes())?;
info!("Writing keypair {}", keypair_path.display());
keypair
};
Ok(Arc::new(keypair))
}
if let Some(path) = path {
load_or_create(path)
} else {
let Some(proj_dirs) = directories::ProjectDirs::from("", "", "freezeout") else {
bail!("Cannot find project dirs");
};
load_or_create(proj_dirs.config_dir())
}
}
fn open_database(path: &Option<PathBuf>) -> Result<Db> {
fn load_or_create(path: &Path) -> Result<Db> {
let db_path = path.join("game.db");
if db_path.exists() {
info!("Loading database {}", db_path.display());
Db::open(db_path)
} else {
std::fs::create_dir_all(path)?;
info!("Writing database {}", db_path.display());
Db::open(db_path)
}
}
if let Some(path) = path {
load_or_create(path)
} else {
let Some(proj_dirs) = directories::ProjectDirs::from("", "", "freezeout") else {
bail!("Cannot find project dirs");
};
load_or_create(proj_dirs.config_dir())
}
}
fn load_tls(key_path: &PathBuf, chain_path: &PathBuf) -> Result<TlsAcceptor> {
let key = PrivateKeyDer::from_pem_file(key_path)?;
let chain = CertificateDer::pem_file_iter(chain_path)?.collect::<Result<Vec<_>, _>>()?;
info!("Loaded TLS chain from {}", chain_path.display());
info!("Loaded TLS key from {}", key_path.display());
let config = TlsServerConfig::builder()
.with_no_client_auth()
.with_single_cert(chain, key)?;
Ok(TlsAcceptor::from(Arc::new(config)))
}