evoke_core/client_server/
server.rs

1use std::{collections::HashMap, error::Error, num::NonZeroU64};
2
3use alkahest::{Schema, SeqUnpacked, Unpacked};
4use scoped_arena::Scope;
5
6use crate::channel::{Channel, Listener};
7
8use super::*;
9
10#[derive(Debug, thiserror::Error)]
11pub enum ServerError<E: Error + 'static> {
12    #[error("Client channel error: {source}")]
13    ChannelError {
14        #[from]
15        source: E,
16    },
17
18    #[error("Unexpected server message")]
19    UnexpectedMessage,
20}
21
/// Connection life-cycle of a single client as tracked by the server session.
///
/// The enum is fieldless, so it is cheap to copy and can be printed in logs
/// and assertion messages.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ClientState {
    /// The channel was accepted but the connection has not been accepted yet.
    Pending,
    /// The connection was accepted; the client takes part in the session.
    Connected,
    /// The client is gone (channel error or protocol violation) and is waiting
    /// to be removed from the session.
    Disconnected,
}
28
29struct Client<C> {
30    state: ClientState,
31    last_input_step: u64,
32    next_update_step: u64,
33    channel: C,
34}
35
36pub struct ServerSession<C, L> {
37    listener: L,
38    current_step: u64,
39    clients: HashMap<NonZeroU64, Client<C>>,
40    next_client_id: NonZeroU64,
41}
42
43pub enum Event<'a, C, P: Schema, I: Schema> {
44    ClientConnect(ClientConnectEvent<'a, C>),
45    AddPlayer(AddPlayerEvent<'a, C, P>),
46    Inputs(InputsEvent<'a, I>),
47    Disconnected,
48}
49
50pub struct ClientConnectEvent<'a, C> {
51    client: &'a mut Client<C>,
52    current_step: u64,
53}
54
55impl<C> ClientConnectEvent<'_, C>
56where
57    C: Channel,
58{
59    pub async fn accept(self, scope: &Scope<'_>) -> Result<(), C::Error> {
60        self.client.state = ClientState::Connected;
61
62        self.client
63            .channel
64            .send_reliable::<ServerMessage, _>(
65                ServerMessageConnectedPack {
66                    step: self.current_step,
67                },
68                scope,
69            )
70            .await
71    }
72}
73
74pub struct AddPlayerEvent<'a, C, P: Schema> {
75    client: &'a mut Client<C>,
76    player: Unpacked<'a, P>,
77}
78
79impl<'a, C, P> AddPlayerEvent<'a, C, P>
80where
81    C: Channel,
82    P: Schema,
83{
84    pub fn player(&self) -> &Unpacked<'a, P> {
85        &self.player
86    }
87
88    pub async fn accept<J, K>(self, info: K, scope: &Scope<'_>) -> Result<(), C::Error>
89    where
90        J: Schema,
91        K: Pack<J>,
92    {
93        self.accept_with::<J, K, _>(|_| info, scope).await
94    }
95
96    pub async fn accept_with<J, K, F>(self, f: F, scope: &Scope<'_>) -> Result<(), C::Error>
97    where
98        J: Schema,
99        K: Pack<J>,
100        F: FnOnce(Unpacked<'a, P>) -> K,
101    {
102        self.try_accept_with(|player| Ok(f(player)), scope).await
103    }
104
105    pub async fn try_accept_with<J, K, F, E>(self, f: F, scope: &Scope<'_>) -> Result<(), E>
106    where
107        J: Schema,
108        K: Pack<J>,
109        F: FnOnce(Unpacked<'a, P>) -> Result<K, E>,
110        E: From<C::Error>,
111    {
112        let info = f(self.player)?;
113
114        self.client
115            .channel
116            .send_reliable::<ServerMessage<J>, _>(ServerMessagePlayerJoinedPack { info }, scope)
117            .await?;
118
119        Ok(())
120    }
121}
122
123pub struct InputsEvent<'a, I: Schema> {
124    inputs: SeqUnpacked<'a, (PlayerId, I)>,
125    step: u64,
126}
127
128impl<'a, I> InputsEvent<'a, I>
129where
130    I: Schema,
131{
132    pub fn inputs(&self) -> impl Iterator<Item = (PlayerId, Unpacked<'a, I>)> {
133        self.inputs
134            .filter_map(|(pid, input)| Some((pid.ok()?, input)))
135    }
136
137    pub fn step(&self) -> u64 {
138        self.step
139    }
140}
141
142impl<C, L> ServerSession<C, L>
143where
144    C: Channel,
145    L: Listener<Channel = C>,
146{
147    /// Create new server session via specified channel.
148    pub fn new(listener: L) -> Self {
149        ServerSession {
150            listener,
151            current_step: 0,
152            clients: HashMap::new(),
153            next_client_id: unsafe {
154                // # Safety
155                // 1 is not zero
156                NonZeroU64::new_unchecked(1)
157            },
158        }
159    }
160
161    pub fn current_step(&self) -> u64 {
162        self.current_step
163    }
164
165    /// Advances server-side simulation by one step.
166    /// Broadcasts updates to all clients.
167    pub async fn advance<'a, U, F, K>(&mut self, mut updates: F, scope: &Scope<'_>)
168    where
169        U: Schema,
170        F: FnMut(u64) -> K,
171        K: Pack<U>,
172    {
173        for client in self.clients.values_mut() {
174            if let ClientState::Connected = client.state {
175                let result = client
176                    .channel
177                    .send::<ServerMessage<(), U>, _>(
178                        ServerMessageUpdatesPack {
179                            updates: updates(self.current_step - client.next_update_step),
180                            server_step: self.current_step,
181                        },
182                        scope,
183                    )
184                    .await;
185
186                if let Err(err) = result {
187                    tracing::error!("Client channel error: {}", err);
188                    client.state = ClientState::Disconnected;
189                }
190            }
191        }
192        self.current_step += 1;
193    }
194
195    pub fn events<'a, P, I>(
196        &'a mut self,
197        scope: &'a Scope<'_>,
198    ) -> Result<impl Iterator<Item = (ClientId, Event<'a, C, P, I>)> + 'a, L::Error>
199    where
200        P: Schema,
201        I: Schema,
202    {
203        let current_step = self.current_step;
204
205        let disconnected = self
206            .clients
207            .iter()
208            .filter_map(|(cid, client)| match client.state {
209                ClientState::Disconnected => Some(*cid),
210                _ => None,
211            });
212
213        let disconnected = scope.to_scope_from_iter(disconnected);
214
215        let disconnected_events = disconnected
216            .iter()
217            .map(|cid| (ClientId(*cid), Event::Disconnected));
218
219        self.clients
220            .retain(|_, client| !matches!(client.state, ClientState::Disconnected));
221
222        loop {
223            match self.listener.try_accept()? {
224                None => break,
225                Some(channel) => {
226                    let client = Client {
227                        state: ClientState::Pending,
228                        channel,
229                        last_input_step: 0,
230                        next_update_step: 0,
231                    };
232
233                    self.clients.insert(self.next_client_id, client);
234                    self.next_client_id = NonZeroU64::new(self.next_client_id.get() + 1)
235                        .expect("u64 overflow is unexpected");
236                }
237            }
238        }
239
240        let events = self.clients.iter_mut().filter_map(move |(&id, client)| {
241            debug_assert!(!matches!(client.state, ClientState::Disconnected));
242
243            let cid = ClientId(id);
244            let msgs = client.channel.recv::<ClientMessage<P, I>>(scope);
245            match msgs {
246                Ok(Some(ClientMessageUnpacked::Connect { token: _ })) => {
247                    if let ClientState::Pending = client.state {
248                        Some((
249                            cid,
250                            Event::ClientConnect(ClientConnectEvent {
251                                client,
252                                current_step,
253                            }),
254                        ))
255                    } else {
256                        client.state = ClientState::Disconnected;
257                        Some((cid, Event::Disconnected))
258                    }
259                }
260                Ok(Some(ClientMessageUnpacked::AddPlayer { player })) => {
261                    if let ClientState::Connected = client.state {
262                        Some((cid, Event::AddPlayer(AddPlayerEvent { client, player })))
263                    } else {
264                        Some((cid, Event::Disconnected))
265                    }
266                }
267                Ok(Some(ClientMessageUnpacked::Inputs {
268                    step,
269                    next_update_step,
270                    inputs,
271                })) => {
272                    if let ClientState::Connected = client.state {
273                        client.next_update_step = next_update_step;
274                        if client.last_input_step <= step {
275                            client.last_input_step = step;
276                            Some((cid, Event::Inputs(InputsEvent { inputs, step })))
277                        } else {
278                            None
279                        }
280                    } else {
281                        Some((cid, Event::Disconnected))
282                    }
283                }
284                Ok(None) => None,
285                Err(err) => {
286                    tracing::error!("Client error: {}", err);
287                    client.state = ClientState::Disconnected;
288                    Some((cid, Event::Disconnected))
289                }
290            }
291        });
292
293        Ok(disconnected_events.chain(events))
294    }
295}