fire_stream_api/
server.rs

1use crate::error::{RequestError, ApiError};
2use crate::message::{Action, Message, IntoMessage, FromMessage};
3use crate::request::{RequestHandler, Request};
4
5use stream::util::{SocketAddr, Listener, ListenerExt};
6use stream::packet::{Packet, PacketBytes};
7pub use stream::packet::PlainBytes;
8use stream::server::{self, Connection};
9pub use stream::server::Config;
10
11#[cfg(feature = "encrypted")]
12pub use stream::packet::EncryptedBytes;
13
14use std::collections::HashMap;
15use std::any::{Any, TypeId};
16use std::sync::{Arc, Mutex};
17use std::io;
18
19#[cfg(feature = "encrypted")]
20use crypto::signature::Keypair;
21
/// Options that influence how a server behaves while handling requests.
///
/// The struct is `#[non_exhaustive]` so that more options can be added
/// later without breaking code outside this crate.
#[derive(Debug, Default)]
#[non_exhaustive]
pub struct ServerConfig {
	/// When `true`, errors returned from `#[api(*)]` functions are logged
	/// to tracing. Defaults to `false`.
	pub log_errors: bool,
}
27
28pub struct Data {
29	cfg: ServerConfig,
30	inner: HashMap<TypeId, Box<dyn Any + Send + Sync>>
31}
32
33impl Data {
34	fn new() -> Self {
35		Self {
36			cfg: ServerConfig::default(),
37			inner: HashMap::new()
38		}
39	}
40
41	pub fn cfg(&self) -> &ServerConfig {
42		&self.cfg
43	}
44
45	pub fn exists<D>(&self) -> bool
46	where D: Any {
47		TypeId::of::<D>() == TypeId::of::<Session>() ||
48		self.inner.contains_key(&TypeId::of::<D>())
49	}
50
51	fn insert<D>(&mut self, data: D)
52	where D: Any + Send + Sync {
53		self.inner.insert(data.type_id(), Box::new(data));
54	}
55
56	pub fn get<D>(&self) -> Option<&D>
57	where D: Any {
58		self.inner.get(&TypeId::of::<D>())
59			.and_then(|a| a.downcast_ref())
60	}
61
62	pub fn get_or_sess<'a, D>(&'a self, sess: &'a Session) -> Option<&'a D>
63	where D: Any {
64		if TypeId::of::<D>() == TypeId::of::<Session>() {
65			<dyn Any>::downcast_ref(sess)
66		} else {
67			self.get()
68		}
69	}
70}
71
72struct Requests<A, B> {
73	inner: HashMap<A, Box<dyn RequestHandler<B, Action=A> + Send + Sync>>
74}
75
76impl<A, B> Requests<A, B>
77where A: Action {
78	fn new() -> Self {
79		Self {
80			inner: HashMap::new()
81		}
82	}
83
84	fn insert<H>(&mut self, handler: H)
85	where H: RequestHandler<B, Action=A> + Send + Sync + 'static {
86		self.inner.insert(H::action(), Box::new(handler));
87	}
88
89	fn get(
90		&self,
91		action: &A
92	) -> Option<&Box<dyn RequestHandler<B, Action=A> + Send + Sync>> {
93		self.inner.get(action)
94	}
95
96}
97
98pub struct Server<A, B, L, More> {
99	inner: L,
100	requests: Requests<A, B>,
101	data: Data,
102	cfg: Config,
103	more: More
104}
105
106impl<A, B, L, More> Server<A, B, L, More>
107where A: Action {
108	pub fn register_request<H>(&mut self, handler: H)
109	where H: RequestHandler<B, Action=A> + Send + Sync + 'static {
110		handler.validate_data(&self.data);
111		self.requests.insert(handler);
112	}
113
114	pub fn register_data<D>(&mut self, data: D)
115	where D: Any + Send + Sync {
116		self.data.insert(data);
117	}
118}
119
120impl<A, B, L, More> Server<A, B, L, More>
121where
122	A: Action,
123	L: Listener
124{
125	/// If this is set to true
126	/// errors which are returned in `#[api(*)]` functions are logged to tracing
127	pub fn set_log_errors(&mut self, log: bool) {
128		self.data.cfg.log_errors = log;
129	}
130
131	/// optionally or just use run
132	pub fn build(self) -> BuiltServer<A, B, L, More> {
133		let shared = Arc::new(Shared {
134			requests: self.requests,
135			data: self.data
136		});
137
138		BuiltServer {
139			inner: self.inner,
140			shared,
141			more: self.more
142		}
143	}
144}
145
146impl<A, L> Server<A, PlainBytes, L, ()>
147where
148	A: Action,
149	L: Listener
150{
151	pub fn new(listener: L, cfg: Config) -> Self {
152		Self {
153			inner: listener,
154			requests: Requests::new(),
155			data: Data::new(),
156			cfg,
157			more: ()
158		}
159	}
160
161	pub async fn run(self) -> io::Result<()>
162	where A: Send + Sync + 'static {
163		let cfg = self.cfg.clone();
164
165		self.build().run_raw(|_, stream| {
166			Connection::new(stream, cfg.clone())
167		}).await
168	}
169
170}
171
#[cfg(feature = "encrypted")]
#[cfg_attr(docsrs, doc(cfg(feature = "encrypted")))]
impl<A, L> Server<A, EncryptedBytes, L, Keypair>
where
	A: Action,
	L: Listener
{
	/// Creates a server whose connections are created with
	/// `Connection::new_encrypted`, each using a clone of `key`.
	pub fn new_encrypted(listener: L, cfg: Config, key: Keypair) -> Self {
		Server {
			inner: listener,
			requests: Requests::new(),
			data: Data::new(),
			cfg,
			more: key,
		}
	}

	/// Builds the server and serves encrypted connections until accepting
	/// one fails.
	///
	/// # Errors
	/// Returns the I/O error produced by the listener's `accept`.
	pub async fn run(self) -> io::Result<()>
	where A: Send + Sync + 'static {
		let conn_cfg = self.cfg.clone();
		let mut built = self.build();

		built.run_raw(move |key, stream| {
			Connection::new_encrypted(stream, conn_cfg.clone(), key.clone())
		}).await
	}
}
198
// Built server: shared state, lifecycle and request dispatch
200
201struct Shared<A, B> {
202	requests: Requests<A, B>,
203	data: Data
204}
205
206pub struct BuiltServer<A, B, L, More> {
207	inner: L,
208	shared: Arc<Shared<A, B>>,
209	more: More
210}
211
212impl<A, B, L, More> BuiltServer<A, B, L, More>
213where
214	A: Action,
215	L: Listener
216{
217	pub fn get_data<D>(&self) -> Option<&D>
218	where D: Any {
219		self.shared.data.get()
220	}
221
222	pub async fn request<R>(
223		&self,
224		r: R,
225		session: &Arc<Session>
226	) -> Result<R::Response, R::Error>
227	where
228		R: Request<Action=A>,
229		R: IntoMessage<A, B>,
230		R::Response: FromMessage<A, B>,
231		R::Error: FromMessage<A, B>,
232		B: PacketBytes
233	{
234		let mut msg = r.into_message()
235			.map_err(R::Error::from_message_error)?;
236		msg.header_mut().set_action(R::ACTION);
237
238		// handle the request
239		let action = *msg.action().unwrap();
240
241		let handler = match self.shared.requests.get(&action) {
242			Some(handler) => handler,
243			// todo once we bump the version again
244			// we need to pass our own errors via packets
245			// not only those from the api users
246			None => {
247				tracing::error!("no handler for {:?}", action);
248				return Err(R::Error::from_request_error(
249					RequestError::NoResponse
250				))
251			}
252		};
253
254		let r = handler.handle(
255			msg,
256			&self.shared.data,
257			session
258		).await;
259
260		let res = match r {
261			Ok(mut msg) => {
262				msg.header_mut().set_action(action);
263				msg
264			},
265			Err(e) => {
266				// todo once we bump the version again
267				// we need to pass our own errors via packets
268				// not only those from the api users
269				tracing::error!(
270					"handler returned an error {:?}", e
271				);
272
273				return Err(R::Error::from_request_error(
274					RequestError::NoResponse
275				))
276			}
277		};
278
279		// now deserialize the response
280		if res.is_success() {
281			R::Response::from_message(res)
282				.map_err(R::Error::from_message_error)
283		} else {
284			R::Error::from_message(res)
285				.map(Err)
286				.map_err(R::Error::from_message_error)?
287		}
288	}
289
290	async fn run_raw<F>(&mut self, new_connection: F) -> io::Result<()>
291	where
292		A: Action + Send + Sync + 'static,
293		B: PacketBytes + Send + 'static,
294		F: Fn(&More, L::Stream) -> Connection<Message<A, B>>
295	{
296		loop {
297
298			// should we fail here??
299			let (stream, addr) = self.inner.accept().await?;
300
301			let mut con = new_connection(&self.more, stream);
302			let session = Arc::new(Session::new(addr));
303			session.set(con.configurator());
304
305			let share = self.shared.clone();
306			tokio::spawn(async move {
307				while let Some(req) = con.receive().await {
308					// todo replace with let else
309					let (msg, resp) = match req {
310						server::Message::Request(msg, resp) => (msg, resp),
311						// ignore streams for now
312						_ => continue
313					};
314
315					let share = share.clone();
316					let session = session.clone();
317
318					let action = match msg.action() {
319						Some(act) => *act,
320						// todo once we bump the version again
321						// we need to pass our own errors via packets
322						// not only those from the api users
323						None => {
324							tracing::error!("invalid action received");
325							continue
326						}
327					};
328
329					tokio::spawn(async move {
330						let handler = match share.requests.get(&action) {
331							Some(handler) => handler,
332							// todo once we bump the version again
333							// we need to pass our own errors via packets
334							// not only those from the api users
335							None => {
336								tracing::error!("no handler for {:?}", action);
337								return;
338							}
339						};
340						let r = handler.handle(
341							msg,
342							&share.data,
343							&session
344						).await;
345
346						match r {
347							Ok(mut msg) => {
348								msg.header_mut().set_action(action);
349								// i don't care about the response
350								let _ = resp.send(msg);
351							},
352							Err(e) => {
353								// todo once we bump the version again
354								// we need to pass our own errors via packets
355								// not only those from the api users
356								tracing::error!(
357									"handler returned an error {:?}", e
358								);
359							}
360						}
361					});
362				}
363			});
364		}
365	}
366}
367
368pub struct Session {
369	// (SocketAddr, S)
370	addr: SocketAddr,
371	data: Mutex<HashMap<TypeId, Box<dyn Any + Send + Sync>>>
372}
373
374impl Session {
375	pub fn new(addr: SocketAddr) -> Self {
376		Self {
377			addr,
378			data: Mutex::new(HashMap::new())
379		}
380	}
381
382	pub fn addr(&self) -> &SocketAddr {
383		&self.addr
384	}
385
386	pub fn set<D>(&self, data: D)
387	where D: Any + Send + Sync {
388		self.data.lock().unwrap()
389			.insert(data.type_id(), Box::new(data));
390	}
391
392	pub fn get<D>(&self) -> Option<D>
393	where D: Any + Clone + Send + Sync {
394		self.data.lock().unwrap()
395			.get(&TypeId::of::<D>())
396			.and_then(|d| d.downcast_ref())
397			.map(Clone::clone)
398	}
399
400	pub fn take<D>(&self) -> Option<D>
401	where D: Any + Send + Sync {
402		self.data.lock().unwrap()
403			.remove(&TypeId::of::<D>())
404			.and_then(|d| d.downcast().ok())
405			.map(|b| *b)
406	}
407}
408
409
#[cfg(all(test, feature = "json"))]
mod json_tests {
	use super::*;

	use codegen::{IntoMessage, FromMessage, api};
	use crate::request::Request;
	use crate::message;
	use crate::error;

	use std::fmt;

	use stream::util::testing::PanicListener;

	use serde::{Serialize, Deserialize};


	#[derive(Debug, Serialize, Deserialize, IntoMessage, FromMessage)]
	#[message(json)]
	struct TestReq {
		hello: u64
	}

	#[derive(Debug, Serialize, Deserialize, IntoMessage, FromMessage)]
	#[message(json)]
	struct TestReq2 {
		hello: u64
	}

	#[derive(Debug, Serialize, Deserialize, IntoMessage, FromMessage)]
	#[message(json)]
	struct TestResp {
		hi: u64
	}

	/// Actions used by the tests.
	///
	/// Every request type needs its own variant: handlers are stored by
	/// action, so two requests sharing one would overwrite each other's
	/// handler when registered.
	#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
	pub enum Action {
		Empty,
		Empty2
	}

	#[derive(Debug, Clone, Serialize, Deserialize, IntoMessage, FromMessage)]
	#[message(json)]
	pub enum Error {
		RequestError(String),
		MessageError(String)
	}

	impl fmt::Display for Error {
		fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
			fmt::Debug::fmt(self, fmt)
		}
	}

	impl std::error::Error for Error {}

	impl error::ApiError for Error {
		fn from_request_error(e: error::RequestError) -> Self {
			Self::RequestError(e.to_string())
		}

		fn from_message_error(e: error::MessageError) -> Self {
			Self::MessageError(e.to_string())
		}
	}

	impl message::Action for Action {
		fn from_u16(_num: u16) -> Option<Self> { todo!() }
		fn as_u16(&self) -> u16 { todo!() }
	}

	impl Request for TestReq {
		type Action = Action;
		type Response = TestResp;
		type Error = Error;

		const ACTION: Action = Action::Empty;
	}

	impl Request for TestReq2 {
		type Action = Action;
		type Response = TestResp;
		type Error = Error;

		const ACTION: Action = Action::Empty2;
	}

	#[api(TestReq)]
	async fn test(req: TestReq) -> Result<TestResp, Error> {
		println!("req {:?}", req);
		Ok(TestResp {
			hi: req.hello
		})
	}

	// returns a different value than `test` so the assertions below can
	// tell which handler answered
	#[api(TestReq2)]
	async fn test_2(
		req: TestReq2
	) -> Result<TestResp, Error> {
		println!("req {:?}", req);
		Ok(TestResp {
			hi: req.hello * 2
		})
	}

	#[tokio::test]
	async fn test_direct_request() {
		let mut server = Server::new(PanicListener::new(), Config {
			timeout: std::time::Duration::from_millis(10),
			body_limit: 4096
		});

		server.register_data(String::from("global String"));

		server.register_request(test);
		server.register_request(test_2);

		let server = server.build();
		let session = Arc::new(Session::new(
			SocketAddr::V4("127.0.0.1:8080".parse().unwrap())
		));

		// answered by `test`
		let r = server.request(TestReq { hello: 100 }, &session).await.unwrap();
		assert_eq!(r.hi, 100);

		// answered by `test_2`
		let r = server.request(
			TestReq2 { hello: 100 },
			&session
		).await.unwrap();
		assert_eq!(r.hi, 200);

		assert_eq!(server.get_data::<String>().unwrap(), "global String");
	}
}
542
543
#[cfg(all(test, feature = "protobuf"))]
mod protobuf_tests {
	use codegen::{IntoMessage, FromMessage};

	use fire_protobuf::{EncodeMessage, DecodeMessage};

	// No test functions use these types yet; presumably they only check
	// that the protobuf derives expand for request types.

	/// Protobuf request fixture with a single `u64` field (field number 1).
	#[derive(
		Debug, Default,
		EncodeMessage, DecodeMessage, IntoMessage, FromMessage
	)]
	#[message(protobuf)]
	struct TestReq {
		#[field(1)]
		hello: u64,
	}

	/// Second protobuf request fixture, identical in shape to `TestReq`.
	#[derive(
		Debug, Default,
		EncodeMessage, DecodeMessage, IntoMessage, FromMessage
	)]
	#[message(protobuf)]
	struct TestReq2 {
		#[field(1)]
		hello: u64,
	}
}