1use crate::error::{ApiError, RequestError};
2use crate::message::{Action, FromMessage, IntoMessage, Message};
3use crate::request::{Request, RequestHandler};
4
5pub use lafere::packet::PlainBytes;
6use lafere::packet::{Packet, PacketBytes};
7pub use lafere::server::Config;
8use lafere::server::{self, Connection};
9use lafere::util::{Listener, ListenerExt, SocketAddr};
10
11#[cfg(feature = "encrypted")]
12pub use lafere::packet::EncryptedBytes;
13
14use std::any::{Any, TypeId};
15use std::collections::HashMap;
16use std::io;
17use std::sync::{Arc, Mutex};
18
19#[cfg(feature = "encrypted")]
20use crypto::signature::Keypair;
21
/// Server-wide options; handlers can read them through [`Data::cfg`].
#[derive(Debug, Default)]
#[non_exhaustive]
pub struct ServerConfig {
    /// Flag set with [`Server::set_log_errors`] (`false` by default).
    /// NOTE(review): not read in this file — presumably request handlers
    /// consult it to decide whether to log their errors; confirm.
    pub log_errors: bool,
}
27
28pub struct Data {
29 cfg: ServerConfig,
30 inner: HashMap<TypeId, Box<dyn Any + Send + Sync>>,
31}
32
33impl Data {
34 fn new() -> Self {
35 Self {
36 cfg: ServerConfig::default(),
37 inner: HashMap::new(),
38 }
39 }
40
41 pub fn cfg(&self) -> &ServerConfig {
42 &self.cfg
43 }
44
45 pub fn exists<D>(&self) -> bool
46 where
47 D: Any,
48 {
49 TypeId::of::<D>() == TypeId::of::<Session>()
50 || self.inner.contains_key(&TypeId::of::<D>())
51 }
52
53 fn insert<D>(&mut self, data: D)
54 where
55 D: Any + Send + Sync,
56 {
57 self.inner.insert(data.type_id(), Box::new(data));
58 }
59
60 pub fn get<D>(&self) -> Option<&D>
61 where
62 D: Any,
63 {
64 self.inner
65 .get(&TypeId::of::<D>())
66 .and_then(|a| a.downcast_ref())
67 }
68
69 pub fn get_or_sess<'a, D>(&'a self, sess: &'a Session) -> Option<&'a D>
70 where
71 D: Any,
72 {
73 if TypeId::of::<D>() == TypeId::of::<Session>() {
74 <dyn Any>::downcast_ref(sess)
75 } else {
76 self.get()
77 }
78 }
79}
80
81struct Requests<A, B> {
82 inner: HashMap<A, Box<dyn RequestHandler<B, Action = A> + Send + Sync>>,
83}
84
85impl<A, B> Requests<A, B>
86where
87 A: Action,
88{
89 fn new() -> Self {
90 Self {
91 inner: HashMap::new(),
92 }
93 }
94
95 fn insert<H>(&mut self, handler: H)
96 where
97 H: RequestHandler<B, Action = A> + Send + Sync + 'static,
98 {
99 self.inner.insert(H::action(), Box::new(handler));
100 }
101
102 fn get(
103 &self,
104 action: &A,
105 ) -> Option<&Box<dyn RequestHandler<B, Action = A> + Send + Sync>> {
106 self.inner.get(action)
107 }
108}
109
/// A server that is still being configured.
///
/// Handlers and data are registered on it; [`Server::build`] then turns it
/// into a [`BuiltServer`].
pub struct Server<A, B, L, More> {
    // Listener that new connections are accepted from.
    inner: L,
    // Request handlers, keyed by action.
    requests: Requests<A, B>,
    // Data handed to every handler, including the `ServerConfig`.
    data: Data,
    // Connection configuration; `run` clones it for each new connection.
    cfg: Config,
    // Extra state for the connection constructor: `()` for plain servers,
    // the `Keypair` for encrypted ones.
    more: More,
}
117
118impl<A, B, L, More> Server<A, B, L, More>
119where
120 A: Action,
121{
122 pub fn register_request<H>(&mut self, handler: H)
123 where
124 H: RequestHandler<B, Action = A> + Send + Sync + 'static,
125 {
126 handler.validate_data(&self.data);
127 self.requests.insert(handler);
128 }
129
130 pub fn register_data<D>(&mut self, data: D)
131 where
132 D: Any + Send + Sync,
133 {
134 self.data.insert(data);
135 }
136}
137
138impl<A, B, L, More> Server<A, B, L, More>
139where
140 A: Action,
141 L: Listener,
142{
143 pub fn set_log_errors(&mut self, log: bool) {
146 self.data.cfg.log_errors = log;
147 }
148
149 pub fn build(self) -> BuiltServer<A, B, L, More> {
151 let shared = Arc::new(Shared {
152 requests: self.requests,
153 data: self.data,
154 });
155
156 BuiltServer {
157 inner: self.inner,
158 shared,
159 more: self.more,
160 }
161 }
162}
163
164impl<A, L> Server<A, PlainBytes, L, ()>
165where
166 A: Action,
167 L: Listener,
168{
169 pub fn new(listener: L, cfg: Config) -> Self {
170 Self {
171 inner: listener,
172 requests: Requests::new(),
173 data: Data::new(),
174 cfg,
175 more: (),
176 }
177 }
178
179 pub async fn run(self) -> io::Result<()>
180 where
181 A: Send + Sync + 'static,
182 {
183 let cfg = self.cfg.clone();
184
185 self.build()
186 .run_raw(|_, stream| Connection::new(stream, cfg.clone()))
187 .await
188 }
189}
190
#[cfg(feature = "encrypted")]
#[cfg_attr(docsrs, doc(cfg(feature = "encrypted")))]
impl<A, L> Server<A, EncryptedBytes, L, Keypair>
where
    A: Action,
    L: Listener,
{
    /// Creates a server whose connections are encrypted with `key`.
    ///
    /// No handlers or data are registered yet; `cfg` is applied to every
    /// accepted connection.
    pub fn new_encrypted(listener: L, cfg: Config, key: Keypair) -> Self {
        let requests = Requests::new();
        let data = Data::new();
        Self {
            inner: listener,
            requests,
            data,
            cfg,
            more: key,
        }
    }

    /// Builds the server and serves encrypted connections until accepting a
    /// new connection fails. Each connection gets its own clone of the key.
    ///
    /// # Errors
    /// Returns the I/O error produced by the listener's `accept`.
    pub async fn run(self) -> io::Result<()>
    where
        A: Send + Sync + 'static,
    {
        let conn_cfg = self.cfg.clone();
        let mut built = self.build();
        built
            .run_raw(move |keypair, stream| {
                Connection::new_encrypted(
                    stream,
                    conn_cfg.clone(),
                    keypair.clone(),
                )
            })
            .await
    }
}
221
/// State shared (through an `Arc`) by a [`BuiltServer`] and the tasks it
/// spawns for connections and requests.
struct Shared<A, B> {
    // Registered request handlers, keyed by action.
    requests: Requests<A, B>,
    // Data passed to every handler invocation.
    data: Data,
}
228
/// A server whose handlers and data are fixed, produced by
/// [`Server::build`]. It can dispatch requests directly or serve
/// connections accepted from its listener.
pub struct BuiltServer<A, B, L, More> {
    // Listener that new connections are accepted from.
    inner: L,
    // Handlers and data, cloned into every spawned task.
    shared: Arc<Shared<A, B>>,
    // Passed to the connection constructor for every accepted stream.
    more: More,
}
234
235impl<A, B, L, More> BuiltServer<A, B, L, More>
236where
237 A: Action,
238 L: Listener,
239{
240 pub fn get_data<D>(&self) -> Option<&D>
241 where
242 D: Any,
243 {
244 self.shared.data.get()
245 }
246
247 pub async fn request<R>(
248 &self,
249 r: R,
250 session: &Arc<Session>,
251 ) -> Result<R::Response, R::Error>
252 where
253 R: Request<Action = A>,
254 R: IntoMessage<A, B>,
255 R::Response: FromMessage<A, B>,
256 R::Error: FromMessage<A, B>,
257 B: PacketBytes,
258 {
259 let mut msg = r.into_message().map_err(R::Error::from_message_error)?;
260 msg.header_mut().set_action(R::ACTION);
261
262 let action = *msg.action().unwrap();
264
265 let handler = match self.shared.requests.get(&action) {
266 Some(handler) => handler,
267 None => {
271 tracing::error!("no handler for {:?}", action);
272 return Err(R::Error::from_request_error(
273 RequestError::NoResponse,
274 ));
275 }
276 };
277
278 let r = handler.handle(msg, &self.shared.data, session).await;
279
280 let res = match r {
281 Ok(mut msg) => {
282 msg.header_mut().set_action(action);
283 msg
284 }
285 Err(e) => {
286 tracing::error!("handler returned an error {:?}", e);
290
291 return Err(R::Error::from_request_error(
292 RequestError::NoResponse,
293 ));
294 }
295 };
296
297 if res.is_success() {
299 R::Response::from_message(res).map_err(R::Error::from_message_error)
300 } else {
301 R::Error::from_message(res)
302 .map(Err)
303 .map_err(R::Error::from_message_error)?
304 }
305 }
306
307 async fn run_raw<F>(&mut self, new_connection: F) -> io::Result<()>
308 where
309 A: Action + Send + Sync + 'static,
310 B: PacketBytes + Send + 'static,
311 F: Fn(&More, L::Stream) -> Connection<Message<A, B>>,
312 {
313 loop {
314 let (stream, addr) = self.inner.accept().await?;
316
317 let mut con = new_connection(&self.more, stream);
318 let session = Arc::new(Session::new(addr));
319 session.set(con.configurator());
320
321 let share = self.shared.clone();
322 tokio::spawn(async move {
323 while let Some(req) = con.receive().await {
324 let (msg, resp) = match req {
326 server::Message::Request(msg, resp) => (msg, resp),
327 _ => continue,
329 };
330
331 let share = share.clone();
332 let session = session.clone();
333
334 let action = match msg.action() {
335 Some(act) => *act,
336 None => {
340 tracing::error!("invalid action received");
341 continue;
342 }
343 };
344
345 tokio::spawn(async move {
346 let handler = match share.requests.get(&action) {
347 Some(handler) => handler,
348 None => {
352 tracing::error!("no handler for {:?}", action);
353 return;
354 }
355 };
356 let r =
357 handler.handle(msg, &share.data, &session).await;
358
359 match r {
360 Ok(mut msg) => {
361 msg.header_mut().set_action(action);
362 let _ = resp.send(msg);
364 }
365 Err(e) => {
366 tracing::error!(
370 "handler returned an error {:?}",
371 e
372 );
373 }
374 }
375 });
376 }
377 });
378 }
379 }
380}
381
382pub struct Session {
383 addr: SocketAddr,
385 data: Mutex<HashMap<TypeId, Box<dyn Any + Send + Sync>>>,
386}
387
388impl Session {
389 pub fn new(addr: SocketAddr) -> Self {
390 Self {
391 addr,
392 data: Mutex::new(HashMap::new()),
393 }
394 }
395
396 pub fn addr(&self) -> &SocketAddr {
397 &self.addr
398 }
399
400 pub fn set<D>(&self, data: D)
401 where
402 D: Any + Send + Sync,
403 {
404 self.data
405 .lock()
406 .unwrap()
407 .insert(data.type_id(), Box::new(data));
408 }
409
410 pub fn get<D>(&self) -> Option<D>
411 where
412 D: Any + Clone + Send + Sync,
413 {
414 self.data
415 .lock()
416 .unwrap()
417 .get(&TypeId::of::<D>())
418 .and_then(|d| d.downcast_ref())
419 .map(Clone::clone)
420 }
421
422 pub fn take<D>(&self) -> Option<D>
423 where
424 D: Any + Send + Sync,
425 {
426 self.data
427 .lock()
428 .unwrap()
429 .remove(&TypeId::of::<D>())
430 .and_then(|d| d.downcast().ok())
431 .map(|b| *b)
432 }
433}
434
#[cfg(all(test, feature = "json"))]
mod json_tests {
    use super::*;

    use crate::error;
    use crate::message;
    use crate::request::Request;
    use codegen::{api, FromMessage, IntoMessage};

    use std::fmt;

    use lafere::util::testing::PanicListener;

    use serde::{Deserialize, Serialize};

    #[derive(Debug, Serialize, Deserialize, IntoMessage, FromMessage)]
    #[message(json)]
    struct TestReq {
        hello: u64,
    }

    #[derive(Debug, Serialize, Deserialize, IntoMessage, FromMessage)]
    #[message(json)]
    struct TestReq2 {
        hello: u64,
    }

    #[derive(Debug, Serialize, Deserialize, IntoMessage, FromMessage)]
    #[message(json)]
    struct TestResp {
        hi: u64,
    }

    // Each request type gets its own action. Handlers are keyed by action,
    // so sharing one would make the second registration replace the first.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
    pub enum Action {
        Empty,
        Empty2,
    }

    #[derive(
        Debug, Clone, Serialize, Deserialize, IntoMessage, FromMessage,
    )]
    #[message(json)]
    pub enum Error {
        RequestError(String),
        MessageError(String),
    }

    impl fmt::Display for Error {
        fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
            fmt::Debug::fmt(self, fmt)
        }
    }

    impl std::error::Error for Error {}

    impl error::ApiError for Error {
        fn from_request_error(e: error::RequestError) -> Self {
            Self::RequestError(e.to_string())
        }

        fn from_message_error(e: error::MessageError) -> Self {
            Self::MessageError(e.to_string())
        }
    }

    impl message::Action for Action {
        fn from_u16(num: u16) -> Option<Self> {
            match num {
                1 => Some(Self::Empty),
                2 => Some(Self::Empty2),
                _ => None,
            }
        }

        fn as_u16(&self) -> u16 {
            match self {
                Self::Empty => 1,
                Self::Empty2 => 2,
            }
        }
    }

    impl Request for TestReq {
        type Action = Action;
        type Response = TestResp;
        type Error = Error;

        const ACTION: Action = Action::Empty;
    }

    impl Request for TestReq2 {
        type Action = Action;
        type Response = TestResp;
        type Error = Error;

        const ACTION: Action = Action::Empty2;
    }

    #[api(TestReq)]
    async fn test(req: TestReq) -> Result<TestResp, Error> {
        println!("req {:?}", req);
        Ok(TestResp { hi: req.hello })
    }

    // Answers differently from `test`, so the assertions below can tell
    // which handler actually served each request.
    #[api(TestReq2)]
    async fn test_2(req: TestReq2) -> Result<TestResp, Error> {
        println!("req {:?}", req);
        Ok(TestResp { hi: req.hello + 1 })
    }

    #[tokio::test]
    async fn test_direct_request() {
        let mut server = Server::new(
            PanicListener::new(),
            Config {
                timeout: std::time::Duration::from_millis(10),
                body_limit: 4096,
            },
        );

        server.register_data(String::from("global String"));

        server.register_request(test);
        server.register_request(test_2);

        let server = server.build();
        let session = Arc::new(Session::new(SocketAddr::V4(
            "127.0.0.1:8080".parse().unwrap(),
        )));

        let r = server
            .request(TestReq { hello: 100 }, &session)
            .await
            .unwrap();
        assert_eq!(r.hi, 100);

        let r = server
            .request(TestReq2 { hello: 100 }, &session)
            .await
            .unwrap();
        assert_eq!(r.hi, 101);

        assert_eq!(server.get_data::<String>().unwrap(), "global String");
    }
}
572
#[cfg(all(test, feature = "protobuf"))]
// Compile-only coverage for the protobuf message derives: this module has no
// `#[test]` functions; it only checks that these types derive cleanly.
mod protobuf_tests {
    use codegen::{FromMessage, IntoMessage};
    use protopuffer::{DecodeMessage, EncodeMessage};

    #[derive(Debug, Default)]
    #[derive(EncodeMessage, DecodeMessage, IntoMessage, FromMessage)]
    #[message(protobuf)]
    struct TestReq {
        #[field(1)]
        hello: u64,
    }

    #[derive(Debug, Default)]
    #[derive(EncodeMessage, DecodeMessage, IntoMessage, FromMessage)]
    #[message(protobuf)]
    struct TestReq2 {
        #[field(1)]
        hello: u64,
    }
}