1use std::{collections::HashMap, error::Error, num::NonZeroU64};
2
3use alkahest::{Schema, SeqUnpacked, Unpacked};
4use scoped_arena::Scope;
5
6use crate::channel::{Channel, Listener};
7
8use super::*;
9
10#[derive(Debug, thiserror::Error)]
11pub enum ServerError<E: Error + 'static> {
12 #[error("Client channel error: {source}")]
13 ChannelError {
14 #[from]
15 source: E,
16 },
17
18 #[error("Unexpected server message")]
19 UnexpectedMessage,
20}
21
/// Lifecycle stage of a client known to the server session.
#[derive(PartialEq, Eq)]
enum ClientState {
    /// The channel was accepted, but the connection has not been accepted yet.
    Pending,
    /// The connection was accepted; the client receives updates.
    Connected,
    /// The client is scheduled for removal from the session.
    Disconnected,
}
28
29struct Client<C> {
30 state: ClientState,
31 last_input_step: u64,
32 next_update_step: u64,
33 channel: C,
34}
35
36pub struct ServerSession<C, L> {
37 listener: L,
38 current_step: u64,
39 clients: HashMap<NonZeroU64, Client<C>>,
40 next_client_id: NonZeroU64,
41}
42
43pub enum Event<'a, C, P: Schema, I: Schema> {
44 ClientConnect(ClientConnectEvent<'a, C>),
45 AddPlayer(AddPlayerEvent<'a, C, P>),
46 Inputs(InputsEvent<'a, I>),
47 Disconnected,
48}
49
50pub struct ClientConnectEvent<'a, C> {
51 client: &'a mut Client<C>,
52 current_step: u64,
53}
54
55impl<C> ClientConnectEvent<'_, C>
56where
57 C: Channel,
58{
59 pub async fn accept(self, scope: &Scope<'_>) -> Result<(), C::Error> {
60 self.client.state = ClientState::Connected;
61
62 self.client
63 .channel
64 .send_reliable::<ServerMessage, _>(
65 ServerMessageConnectedPack {
66 step: self.current_step,
67 },
68 scope,
69 )
70 .await
71 }
72}
73
74pub struct AddPlayerEvent<'a, C, P: Schema> {
75 client: &'a mut Client<C>,
76 player: Unpacked<'a, P>,
77}
78
79impl<'a, C, P> AddPlayerEvent<'a, C, P>
80where
81 C: Channel,
82 P: Schema,
83{
84 pub fn player(&self) -> &Unpacked<'a, P> {
85 &self.player
86 }
87
88 pub async fn accept<J, K>(self, info: K, scope: &Scope<'_>) -> Result<(), C::Error>
89 where
90 J: Schema,
91 K: Pack<J>,
92 {
93 self.accept_with::<J, K, _>(|_| info, scope).await
94 }
95
96 pub async fn accept_with<J, K, F>(self, f: F, scope: &Scope<'_>) -> Result<(), C::Error>
97 where
98 J: Schema,
99 K: Pack<J>,
100 F: FnOnce(Unpacked<'a, P>) -> K,
101 {
102 self.try_accept_with(|player| Ok(f(player)), scope).await
103 }
104
105 pub async fn try_accept_with<J, K, F, E>(self, f: F, scope: &Scope<'_>) -> Result<(), E>
106 where
107 J: Schema,
108 K: Pack<J>,
109 F: FnOnce(Unpacked<'a, P>) -> Result<K, E>,
110 E: From<C::Error>,
111 {
112 let info = f(self.player)?;
113
114 self.client
115 .channel
116 .send_reliable::<ServerMessage<J>, _>(ServerMessagePlayerJoinedPack { info }, scope)
117 .await?;
118
119 Ok(())
120 }
121}
122
123pub struct InputsEvent<'a, I: Schema> {
124 inputs: SeqUnpacked<'a, (PlayerId, I)>,
125 step: u64,
126}
127
128impl<'a, I> InputsEvent<'a, I>
129where
130 I: Schema,
131{
132 pub fn inputs(&self) -> impl Iterator<Item = (PlayerId, Unpacked<'a, I>)> {
133 self.inputs
134 .filter_map(|(pid, input)| Some((pid.ok()?, input)))
135 }
136
137 pub fn step(&self) -> u64 {
138 self.step
139 }
140}
141
142impl<C, L> ServerSession<C, L>
143where
144 C: Channel,
145 L: Listener<Channel = C>,
146{
147 pub fn new(listener: L) -> Self {
149 ServerSession {
150 listener,
151 current_step: 0,
152 clients: HashMap::new(),
153 next_client_id: unsafe {
154 NonZeroU64::new_unchecked(1)
157 },
158 }
159 }
160
161 pub fn current_step(&self) -> u64 {
162 self.current_step
163 }
164
165 pub async fn advance<'a, U, F, K>(&mut self, mut updates: F, scope: &Scope<'_>)
168 where
169 U: Schema,
170 F: FnMut(u64) -> K,
171 K: Pack<U>,
172 {
173 for client in self.clients.values_mut() {
174 if let ClientState::Connected = client.state {
175 let result = client
176 .channel
177 .send::<ServerMessage<(), U>, _>(
178 ServerMessageUpdatesPack {
179 updates: updates(self.current_step - client.next_update_step),
180 server_step: self.current_step,
181 },
182 scope,
183 )
184 .await;
185
186 if let Err(err) = result {
187 tracing::error!("Client channel error: {}", err);
188 client.state = ClientState::Disconnected;
189 }
190 }
191 }
192 self.current_step += 1;
193 }
194
195 pub fn events<'a, P, I>(
196 &'a mut self,
197 scope: &'a Scope<'_>,
198 ) -> Result<impl Iterator<Item = (ClientId, Event<'a, C, P, I>)> + 'a, L::Error>
199 where
200 P: Schema,
201 I: Schema,
202 {
203 let current_step = self.current_step;
204
205 let disconnected = self
206 .clients
207 .iter()
208 .filter_map(|(cid, client)| match client.state {
209 ClientState::Disconnected => Some(*cid),
210 _ => None,
211 });
212
213 let disconnected = scope.to_scope_from_iter(disconnected);
214
215 let disconnected_events = disconnected
216 .iter()
217 .map(|cid| (ClientId(*cid), Event::Disconnected));
218
219 self.clients
220 .retain(|_, client| !matches!(client.state, ClientState::Disconnected));
221
222 loop {
223 match self.listener.try_accept()? {
224 None => break,
225 Some(channel) => {
226 let client = Client {
227 state: ClientState::Pending,
228 channel,
229 last_input_step: 0,
230 next_update_step: 0,
231 };
232
233 self.clients.insert(self.next_client_id, client);
234 self.next_client_id = NonZeroU64::new(self.next_client_id.get() + 1)
235 .expect("u64 overflow is unexpected");
236 }
237 }
238 }
239
240 let events = self.clients.iter_mut().filter_map(move |(&id, client)| {
241 debug_assert!(!matches!(client.state, ClientState::Disconnected));
242
243 let cid = ClientId(id);
244 let msgs = client.channel.recv::<ClientMessage<P, I>>(scope);
245 match msgs {
246 Ok(Some(ClientMessageUnpacked::Connect { token: _ })) => {
247 if let ClientState::Pending = client.state {
248 Some((
249 cid,
250 Event::ClientConnect(ClientConnectEvent {
251 client,
252 current_step,
253 }),
254 ))
255 } else {
256 client.state = ClientState::Disconnected;
257 Some((cid, Event::Disconnected))
258 }
259 }
260 Ok(Some(ClientMessageUnpacked::AddPlayer { player })) => {
261 if let ClientState::Connected = client.state {
262 Some((cid, Event::AddPlayer(AddPlayerEvent { client, player })))
263 } else {
264 Some((cid, Event::Disconnected))
265 }
266 }
267 Ok(Some(ClientMessageUnpacked::Inputs {
268 step,
269 next_update_step,
270 inputs,
271 })) => {
272 if let ClientState::Connected = client.state {
273 client.next_update_step = next_update_step;
274 if client.last_input_step <= step {
275 client.last_input_step = step;
276 Some((cid, Event::Inputs(InputsEvent { inputs, step })))
277 } else {
278 None
279 }
280 } else {
281 Some((cid, Event::Disconnected))
282 }
283 }
284 Ok(None) => None,
285 Err(err) => {
286 tracing::error!("Client error: {}", err);
287 client.state = ClientState::Disconnected;
288 Some((cid, Event::Disconnected))
289 }
290 }
291 });
292
293 Ok(disconnected_events.chain(events))
294 }
295}