1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
use std::net::SocketAddr;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use tokio::net::{TcpListener, TcpStream};
use tokio::*;
use bird_protocol::packet::PacketState;
use tokio::task::yield_now;
use crate::connection::Connection;
use crate::handler::{ConnectionHandler, ReadHandler};
use crate::read::ReadStreamQueue;
use crate::write::WriteStreamQueue;
pub struct ProtocolServerDeclare<
H: ReadHandler + Sized + Send + Sync + 'static,
C: ConnectionHandler + Sized + Send + Sync + 'static,
> {
pub host: String,
pub read_handler: H,
pub connection_handler: C,
}
pub struct ProtocolServerRuntime {
pub running: AtomicBool,
}
pub struct ProtocolServerTask {
pub runtime: Arc<ProtocolServerRuntime>,
pub task: task::JoinHandle<io::Result<()>>,
}
pub fn run_server<
H: ReadHandler + Sized + Send + Sync + 'static,
C: ConnectionHandler + Sized + Send + Sync + 'static
>(declare: ProtocolServerDeclare<H, C>) -> ProtocolServerTask {
let runtime = Arc::new(
ProtocolServerRuntime {
running: AtomicBool::new(true),
}
);
let task_runtime = runtime.clone();
ProtocolServerTask {
task: tokio::spawn(async move {
run_server_runtime(declare, task_runtime).await
}),
runtime,
}
}
const CHANNEL_BUFFER_SIZE: usize = 128;
const READ_BUFFER_SIZE: usize = 1024;
async fn run_server_runtime<
H: ReadHandler + Sized + Send + Sync + 'static,
C: ConnectionHandler + Sized + Send + Sync + 'static
>(declare: ProtocolServerDeclare<H, C>, runtime: Arc<ProtocolServerRuntime>) -> io::Result<()> {
let declare = Arc::new(declare);
let listener = TcpListener::bind(&declare.host).await?;
while runtime.running.load(Ordering::Acquire) {
let (stream, addr) = listener.accept().await?;
let declare = declare.clone();
let runtime = runtime.clone();
tokio::spawn(async move { run_connection(declare, runtime, stream, addr).await });
}
Ok(())
}
async fn run_connection<
H: ReadHandler + Sized + Send + Sync + 'static,
C: ConnectionHandler + Sized + Send + Sync + 'static
>(declare: Arc<ProtocolServerDeclare<H, C>>, runtime: Arc<ProtocolServerRuntime>, stream: TcpStream, addr: SocketAddr) {
let (read_half, write_half) = stream.into_split();
let (sender, receiver) =
sync::mpsc::channel(CHANNEL_BUFFER_SIZE);
let connection = Arc::new(Connection::new(addr, sender));
declare.connection_handler.handle_connection(connection.clone());
let mut read_queue = ReadStreamQueue::<READ_BUFFER_SIZE>::new(read_half);
{
let write_queue = WriteStreamQueue { write_half, receiver };
let connection = connection.clone();
let declare = declare.clone();
tokio::spawn(async move { write_queue.run(connection, declare).await });
}
let mut state = PacketState::Handshake;
while runtime.running.load(Ordering::Acquire) {
if let Err(err) = read_queue.next_packet().await {
log::debug!("Received error while getting next packet: {:?}", err);
break;
}
if let Err(err) = declare.read_handler.handle(
connection.clone(), &mut state, &mut read_queue).await {
log::debug!("Received error while handling next packet: {:?}", err);
break;
}
yield_now().await;
}
let _ = connection.close().await;
}