lafere_api/
server.rs

1use crate::error::{ApiError, RequestError};
2use crate::message::{Action, FromMessage, IntoMessage, Message};
3use crate::request::{Request, RequestHandler};
4
5pub use lafere::packet::PlainBytes;
6use lafere::packet::{Packet, PacketBytes};
7pub use lafere::server::Config;
8use lafere::server::{self, Connection};
9use lafere::util::{Listener, ListenerExt, SocketAddr};
10
11#[cfg(feature = "encrypted")]
12pub use lafere::packet::EncryptedBytes;
13
14use std::any::{Any, TypeId};
15use std::collections::HashMap;
16use std::io;
17use std::sync::{Arc, Mutex};
18
19#[cfg(feature = "encrypted")]
20use crypto::signature::Keypair;
21
/// Behavioural switches of a server that are visible to request handlers.
///
/// The type is `#[non_exhaustive]`, so code outside of this crate cannot
/// build it with a struct literal; use [`Default`] instead.
#[derive(Default, Debug)]
#[non_exhaustive]
pub struct ServerConfig {
	/// When `true`, errors returned from `#[api(*)]` functions are logged to
	/// tracing. Defaults to `false`.
	pub log_errors: bool,
}
27
28pub struct Data {
29	cfg: ServerConfig,
30	inner: HashMap<TypeId, Box<dyn Any + Send + Sync>>,
31}
32
33impl Data {
34	fn new() -> Self {
35		Self {
36			cfg: ServerConfig::default(),
37			inner: HashMap::new(),
38		}
39	}
40
41	pub fn cfg(&self) -> &ServerConfig {
42		&self.cfg
43	}
44
45	pub fn exists<D>(&self) -> bool
46	where
47		D: Any,
48	{
49		TypeId::of::<D>() == TypeId::of::<Session>()
50			|| self.inner.contains_key(&TypeId::of::<D>())
51	}
52
53	fn insert<D>(&mut self, data: D)
54	where
55		D: Any + Send + Sync,
56	{
57		self.inner.insert(data.type_id(), Box::new(data));
58	}
59
60	pub fn get<D>(&self) -> Option<&D>
61	where
62		D: Any,
63	{
64		self.inner
65			.get(&TypeId::of::<D>())
66			.and_then(|a| a.downcast_ref())
67	}
68
69	pub fn get_or_sess<'a, D>(&'a self, sess: &'a Session) -> Option<&'a D>
70	where
71		D: Any,
72	{
73		if TypeId::of::<D>() == TypeId::of::<Session>() {
74			<dyn Any>::downcast_ref(sess)
75		} else {
76			self.get()
77		}
78	}
79}
80
81struct Requests<A, B> {
82	inner: HashMap<A, Box<dyn RequestHandler<B, Action = A> + Send + Sync>>,
83}
84
85impl<A, B> Requests<A, B>
86where
87	A: Action,
88{
89	fn new() -> Self {
90		Self {
91			inner: HashMap::new(),
92		}
93	}
94
95	fn insert<H>(&mut self, handler: H)
96	where
97		H: RequestHandler<B, Action = A> + Send + Sync + 'static,
98	{
99		self.inner.insert(H::action(), Box::new(handler));
100	}
101
102	fn get(
103		&self,
104		action: &A,
105	) -> Option<&Box<dyn RequestHandler<B, Action = A> + Send + Sync>> {
106		self.inner.get(action)
107	}
108}
109
110pub struct Server<A, B, L, More> {
111	inner: L,
112	requests: Requests<A, B>,
113	data: Data,
114	cfg: Config,
115	more: More,
116}
117
118impl<A, B, L, More> Server<A, B, L, More>
119where
120	A: Action,
121{
122	pub fn register_request<H>(&mut self, handler: H)
123	where
124		H: RequestHandler<B, Action = A> + Send + Sync + 'static,
125	{
126		handler.validate_data(&self.data);
127		self.requests.insert(handler);
128	}
129
130	pub fn register_data<D>(&mut self, data: D)
131	where
132		D: Any + Send + Sync,
133	{
134		self.data.insert(data);
135	}
136}
137
138impl<A, B, L, More> Server<A, B, L, More>
139where
140	A: Action,
141	L: Listener,
142{
143	/// If this is set to true
144	/// errors which are returned in `#[api(*)]` functions are logged to tracing
145	pub fn set_log_errors(&mut self, log: bool) {
146		self.data.cfg.log_errors = log;
147	}
148
149	/// optionally or just use run
150	pub fn build(self) -> BuiltServer<A, B, L, More> {
151		let shared = Arc::new(Shared {
152			requests: self.requests,
153			data: self.data,
154		});
155
156		BuiltServer {
157			inner: self.inner,
158			shared,
159			more: self.more,
160		}
161	}
162}
163
164impl<A, L> Server<A, PlainBytes, L, ()>
165where
166	A: Action,
167	L: Listener,
168{
169	pub fn new(listener: L, cfg: Config) -> Self {
170		Self {
171			inner: listener,
172			requests: Requests::new(),
173			data: Data::new(),
174			cfg,
175			more: (),
176		}
177	}
178
179	pub async fn run(self) -> io::Result<()>
180	where
181		A: Send + Sync + 'static,
182	{
183		let cfg = self.cfg.clone();
184
185		self.build()
186			.run_raw(|_, stream| Connection::new(stream, cfg.clone()))
187			.await
188	}
189}
190
#[cfg(feature = "encrypted")]
#[cfg_attr(docsrs, doc(cfg(feature = "encrypted")))]
impl<A, L> Server<A, EncryptedBytes, L, Keypair>
where
	A: Action,
	L: Listener,
{
	/// Creates an encrypted server without any handlers or data.
	///
	/// `cfg` and `key` are passed to every connection created by
	/// [`Server::run`].
	pub fn new_encrypted(listener: L, cfg: Config, key: Keypair) -> Self {
		let requests = Requests::new();
		let data = Data::new();

		Self {
			inner: listener,
			requests,
			data,
			cfg,
			more: key,
		}
	}

	/// Builds the server and serves encrypted connections from the listener.
	///
	/// # Errors
	///
	/// Returns the error of the first failed accept; this is the only way
	/// the future completes.
	pub async fn run(self) -> io::Result<()>
	where
		A: Send + Sync + 'static,
	{
		// `build` drops the config, so keep a copy for the connections
		let connection_cfg = self.cfg.clone();
		let mut built = self.build();

		built
			.run_raw(move |key, stream| {
				Connection::new_encrypted(
					stream,
					connection_cfg.clone(),
					key.clone(),
				)
			})
			.await
	}
}
221
222// impl
223
224struct Shared<A, B> {
225	requests: Requests<A, B>,
226	data: Data,
227}
228
229pub struct BuiltServer<A, B, L, More> {
230	inner: L,
231	shared: Arc<Shared<A, B>>,
232	more: More,
233}
234
235impl<A, B, L, More> BuiltServer<A, B, L, More>
236where
237	A: Action,
238	L: Listener,
239{
240	pub fn get_data<D>(&self) -> Option<&D>
241	where
242		D: Any,
243	{
244		self.shared.data.get()
245	}
246
247	pub async fn request<R>(
248		&self,
249		r: R,
250		session: &Arc<Session>,
251	) -> Result<R::Response, R::Error>
252	where
253		R: Request<Action = A>,
254		R: IntoMessage<A, B>,
255		R::Response: FromMessage<A, B>,
256		R::Error: FromMessage<A, B>,
257		B: PacketBytes,
258	{
259		let mut msg = r.into_message().map_err(R::Error::from_message_error)?;
260		msg.header_mut().set_action(R::ACTION);
261
262		// handle the request
263		let action = *msg.action().unwrap();
264
265		let handler = match self.shared.requests.get(&action) {
266			Some(handler) => handler,
267			// todo once we bump the version again
268			// we need to pass our own errors via packets
269			// not only those from the api users
270			None => {
271				tracing::error!("no handler for {:?}", action);
272				return Err(R::Error::from_request_error(
273					RequestError::NoResponse,
274				));
275			}
276		};
277
278		let r = handler.handle(msg, &self.shared.data, session).await;
279
280		let res = match r {
281			Ok(mut msg) => {
282				msg.header_mut().set_action(action);
283				msg
284			}
285			Err(e) => {
286				// todo once we bump the version again
287				// we need to pass our own errors via packets
288				// not only those from the api users
289				tracing::error!("handler returned an error {:?}", e);
290
291				return Err(R::Error::from_request_error(
292					RequestError::NoResponse,
293				));
294			}
295		};
296
297		// now deserialize the response
298		if res.is_success() {
299			R::Response::from_message(res).map_err(R::Error::from_message_error)
300		} else {
301			R::Error::from_message(res)
302				.map(Err)
303				.map_err(R::Error::from_message_error)?
304		}
305	}
306
307	async fn run_raw<F>(&mut self, new_connection: F) -> io::Result<()>
308	where
309		A: Action + Send + Sync + 'static,
310		B: PacketBytes + Send + 'static,
311		F: Fn(&More, L::Stream) -> Connection<Message<A, B>>,
312	{
313		loop {
314			// should we fail here??
315			let (stream, addr) = self.inner.accept().await?;
316
317			let mut con = new_connection(&self.more, stream);
318			let session = Arc::new(Session::new(addr));
319			session.set(con.configurator());
320
321			let share = self.shared.clone();
322			tokio::spawn(async move {
323				while let Some(req) = con.receive().await {
324					// todo replace with let else
325					let (msg, resp) = match req {
326						server::Message::Request(msg, resp) => (msg, resp),
327						// ignore streams for now
328						_ => continue,
329					};
330
331					let share = share.clone();
332					let session = session.clone();
333
334					let action = match msg.action() {
335						Some(act) => *act,
336						// todo once we bump the version again
337						// we need to pass our own errors via packets
338						// not only those from the api users
339						None => {
340							tracing::error!("invalid action received");
341							continue;
342						}
343					};
344
345					tokio::spawn(async move {
346						let handler = match share.requests.get(&action) {
347							Some(handler) => handler,
348							// todo once we bump the version again
349							// we need to pass our own errors via packets
350							// not only those from the api users
351							None => {
352								tracing::error!("no handler for {:?}", action);
353								return;
354							}
355						};
356						let r =
357							handler.handle(msg, &share.data, &session).await;
358
359						match r {
360							Ok(mut msg) => {
361								msg.header_mut().set_action(action);
362								// i don't care about the response
363								let _ = resp.send(msg);
364							}
365							Err(e) => {
366								// todo once we bump the version again
367								// we need to pass our own errors via packets
368								// not only those from the api users
369								tracing::error!(
370									"handler returned an error {:?}",
371									e
372								);
373							}
374						}
375					});
376				}
377			});
378		}
379	}
380}
381
/// State belonging to a single peer: its address plus a type-indexed map of
/// values that can be changed through a shared reference.
///
/// At most one value is kept per concrete type.
pub struct Session {
	addr: SocketAddr,
	data: Mutex<HashMap<TypeId, Box<dyn Any + Send + Sync>>>,
}

impl Session {
	/// Creates a session for the peer at `addr` without any stored values.
	pub fn new(addr: SocketAddr) -> Self {
		Self {
			addr,
			data: Mutex::new(HashMap::new()),
		}
	}

	/// Returns the address this session was created with.
	pub fn addr(&self) -> &SocketAddr {
		&self.addr
	}

	/// Stores `data`, replacing a previously stored value of the same type.
	///
	/// # Panics
	///
	/// Panics if the internal mutex is poisoned.
	pub fn set<D>(&self, data: D)
	where
		D: Any + Send + Sync,
	{
		// The guard is a temporary of this `let` statement, so the lock is
		// released before the replaced value is dropped below.
		let replaced = self
			.data
			.lock()
			.unwrap()
			.insert(TypeId::of::<D>(), Box::new(data));

		// Dropping outside of the lock allows a `Drop` impl to access this
		// session again without locking the mutex recursively.
		drop(replaced);
	}

	/// Returns a clone of the stored value of type `D`, if any.
	///
	/// The value is cloned while the internal lock is held.
	///
	/// # Panics
	///
	/// Panics if the internal mutex is poisoned.
	pub fn get<D>(&self) -> Option<D>
	where
		D: Any + Clone + Send + Sync,
	{
		self.data
			.lock()
			.unwrap()
			.get(&TypeId::of::<D>())
			.and_then(|d| d.downcast_ref::<D>())
			.cloned()
	}

	/// Removes the stored value of type `D` and returns it, if any.
	///
	/// # Panics
	///
	/// Panics if the internal mutex is poisoned.
	pub fn take<D>(&self) -> Option<D>
	where
		D: Any + Send + Sync,
	{
		self.data
			.lock()
			.unwrap()
			.remove(&TypeId::of::<D>())
			.and_then(|d| d.downcast().ok())
			.map(|b| *b)
	}
}
434
/// Exercises a built server with json encoded messages, calling the
/// handlers directly instead of going through a connection.
#[cfg(all(test, feature = "json"))]
mod json_tests {
	use super::*;

	use crate::error;
	use crate::message;
	use crate::request::Request;
	use codegen::{api, FromMessage, IntoMessage};

	use std::fmt;

	use lafere::util::testing::PanicListener;

	use serde::{Deserialize, Serialize};

	// first request type
	#[derive(Debug, Serialize, Deserialize, IntoMessage, FromMessage)]
	#[message(json)]
	struct TestReq {
		hello: u64,
	}

	// second request type, registered for the same action as the first
	#[derive(Debug, Serialize, Deserialize, IntoMessage, FromMessage)]
	#[message(json)]
	struct TestReq2 {
		hello: u64,
	}

	// response shared by both requests
	#[derive(Debug, Serialize, Deserialize, IntoMessage, FromMessage)]
	#[message(json)]
	struct TestResp {
		hi: u64,
	}

	#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
	pub enum Action {
		Empty,
	}

	#[derive(
		Debug, Clone, Serialize, Deserialize, IntoMessage, FromMessage,
	)]
	#[message(json)]
	pub enum Error {
		RequestError(String),
		MessageError(String),
	}

	// Display just forwards to the Debug representation
	impl fmt::Display for Error {
		fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
			fmt::Debug::fmt(self, fmt)
		}
	}

	impl std::error::Error for Error {}

	impl error::ApiError for Error {
		fn from_request_error(e: error::RequestError) -> Self {
			let text = e.to_string();
			Self::RequestError(text)
		}

		fn from_message_error(e: error::MessageError) -> Self {
			let text = e.to_string();
			Self::MessageError(text)
		}
	}

	// the test below never reaches the numeric conversions
	impl message::Action for Action {
		fn from_u16(_num: u16) -> Option<Self> {
			todo!()
		}
		fn as_u16(&self) -> u16 {
			todo!()
		}
	}

	impl Request for TestReq {
		type Action = Action;
		type Response = TestResp;
		type Error = Error;

		const ACTION: Action = Action::Empty;
	}

	impl Request for TestReq2 {
		type Action = Action;
		type Response = TestResp;
		type Error = Error;

		const ACTION: Action = Action::Empty;
	}

	// echoes `hello` back as `hi`
	#[api(TestReq)]
	async fn test(req: TestReq) -> Result<TestResp, Error> {
		println!("req {:?}", req);
		Ok(TestResp { hi: req.hello })
	}

	// echoes `hello` back as `hi`
	#[api(TestReq2)]
	async fn test_2(req: TestReq2) -> Result<TestResp, Error> {
		println!("req {:?}", req);
		Ok(TestResp { hi: req.hello })
	}

	#[tokio::test]
	async fn test_direct_request() {
		let cfg = Config {
			timeout: std::time::Duration::from_millis(10),
			body_limit: 4096,
		};
		let mut server = Server::new(PanicListener::new(), cfg);

		server.register_data(String::from("global String"));

		server.register_request(test);
		server.register_request(test_2);

		let built = server.build();

		let addr = SocketAddr::V4("127.0.0.1:8080".parse().unwrap());
		let session = Arc::new(Session::new(addr));

		let first = built
			.request(TestReq { hello: 100 }, &session)
			.await
			.unwrap();
		assert_eq!(first.hi, 100);

		let second = built
			.request(TestReq2 { hello: 100 }, &session)
			.await
			.unwrap();
		assert_eq!(second.hi, 100);

		// registered data stays reachable after building
		assert_eq!(built.get_data::<String>().unwrap(), "global String");
	}
}
572
/// Message types encoded with protobuf.
///
/// The module contains no test functions; it only declares the types with
/// their derives.
#[cfg(all(test, feature = "protobuf"))]
mod protobuf_tests {
	use codegen::{FromMessage, IntoMessage};

	use protopuffer::{DecodeMessage, EncodeMessage};

	// first request type
	#[derive(
		Debug, Default, EncodeMessage, DecodeMessage, IntoMessage, FromMessage,
	)]
	#[message(protobuf)]
	struct TestReq {
		#[field(1)]
		hello: u64,
	}

	// second request type with the same layout as the first
	#[derive(
		Debug, Default, EncodeMessage, DecodeMessage, IntoMessage, FromMessage,
	)]
	#[message(protobuf)]
	struct TestReq2 {
		#[field(1)]
		hello: u64,
	}
}