1use crate::error::{RequestError, ApiError};
2use crate::message::{Action, Message, IntoMessage, FromMessage};
3use crate::request::{RequestHandler, Request};
4
5use stream::util::{SocketAddr, Listener, ListenerExt};
6use stream::packet::{Packet, PacketBytes};
7pub use stream::packet::PlainBytes;
8use stream::server::{self, Connection};
9pub use stream::server::Config;
10
11#[cfg(feature = "encrypted")]
12pub use stream::packet::EncryptedBytes;
13
14use std::collections::HashMap;
15use std::any::{Any, TypeId};
16use std::sync::{Arc, Mutex};
17use std::io;
18
19#[cfg(feature = "encrypted")]
20use crypto::signature::Keypair;
21
/// Server-wide settings, made available to every handler through
/// [`Data::cfg`].
#[non_exhaustive]
#[derive(Default, Debug)]
pub struct ServerConfig {
    /// Flag toggled by `Server::set_log_errors`; `false` by default.
    ///
    /// NOTE(review): nothing in this file reads it; presumably handlers
    /// consult it through [`Data::cfg`] — confirm.
    pub log_errors: bool,
}
27
28pub struct Data {
29 cfg: ServerConfig,
30 inner: HashMap<TypeId, Box<dyn Any + Send + Sync>>
31}
32
33impl Data {
34 fn new() -> Self {
35 Self {
36 cfg: ServerConfig::default(),
37 inner: HashMap::new()
38 }
39 }
40
41 pub fn cfg(&self) -> &ServerConfig {
42 &self.cfg
43 }
44
45 pub fn exists<D>(&self) -> bool
46 where D: Any {
47 TypeId::of::<D>() == TypeId::of::<Session>() ||
48 self.inner.contains_key(&TypeId::of::<D>())
49 }
50
51 fn insert<D>(&mut self, data: D)
52 where D: Any + Send + Sync {
53 self.inner.insert(data.type_id(), Box::new(data));
54 }
55
56 pub fn get<D>(&self) -> Option<&D>
57 where D: Any {
58 self.inner.get(&TypeId::of::<D>())
59 .and_then(|a| a.downcast_ref())
60 }
61
62 pub fn get_or_sess<'a, D>(&'a self, sess: &'a Session) -> Option<&'a D>
63 where D: Any {
64 if TypeId::of::<D>() == TypeId::of::<Session>() {
65 <dyn Any>::downcast_ref(sess)
66 } else {
67 self.get()
68 }
69 }
70}
71
72struct Requests<A, B> {
73 inner: HashMap<A, Box<dyn RequestHandler<B, Action=A> + Send + Sync>>
74}
75
76impl<A, B> Requests<A, B>
77where A: Action {
78 fn new() -> Self {
79 Self {
80 inner: HashMap::new()
81 }
82 }
83
84 fn insert<H>(&mut self, handler: H)
85 where H: RequestHandler<B, Action=A> + Send + Sync + 'static {
86 self.inner.insert(H::action(), Box::new(handler));
87 }
88
89 fn get(
90 &self,
91 action: &A
92 ) -> Option<&Box<dyn RequestHandler<B, Action=A> + Send + Sync>> {
93 self.inner.get(action)
94 }
95
96}
97
98pub struct Server<A, B, L, More> {
99 inner: L,
100 requests: Requests<A, B>,
101 data: Data,
102 cfg: Config,
103 more: More
104}
105
106impl<A, B, L, More> Server<A, B, L, More>
107where A: Action {
108 pub fn register_request<H>(&mut self, handler: H)
109 where H: RequestHandler<B, Action=A> + Send + Sync + 'static {
110 handler.validate_data(&self.data);
111 self.requests.insert(handler);
112 }
113
114 pub fn register_data<D>(&mut self, data: D)
115 where D: Any + Send + Sync {
116 self.data.insert(data);
117 }
118}
119
120impl<A, B, L, More> Server<A, B, L, More>
121where
122 A: Action,
123 L: Listener
124{
125 pub fn set_log_errors(&mut self, log: bool) {
128 self.data.cfg.log_errors = log;
129 }
130
131 pub fn build(self) -> BuiltServer<A, B, L, More> {
133 let shared = Arc::new(Shared {
134 requests: self.requests,
135 data: self.data
136 });
137
138 BuiltServer {
139 inner: self.inner,
140 shared,
141 more: self.more
142 }
143 }
144}
145
146impl<A, L> Server<A, PlainBytes, L, ()>
147where
148 A: Action,
149 L: Listener
150{
151 pub fn new(listener: L, cfg: Config) -> Self {
152 Self {
153 inner: listener,
154 requests: Requests::new(),
155 data: Data::new(),
156 cfg,
157 more: ()
158 }
159 }
160
161 pub async fn run(self) -> io::Result<()>
162 where A: Send + Sync + 'static {
163 let cfg = self.cfg.clone();
164
165 self.build().run_raw(|_, stream| {
166 Connection::new(stream, cfg.clone())
167 }).await
168 }
169
170}
171
#[cfg(feature = "encrypted")]
#[cfg_attr(docsrs, doc(cfg(feature = "encrypted")))]
impl<A, L> Server<A, EncryptedBytes, L, Keypair>
where
    A: Action,
    L: Listener,
{
    /// Creates an encrypted server.
    ///
    /// Every accepted connection is set up with `cfg` and a clone of `key`.
    pub fn new_encrypted(listener: L, cfg: Config, key: Keypair) -> Self {
        Server {
            inner: listener,
            requests: Requests::new(),
            data: Data::new(),
            cfg,
            more: key,
        }
    }

    /// Builds the server and serves encrypted connections until accepting
    /// one fails.
    ///
    /// # Errors
    /// Returns the listener error that ended the accept loop.
    pub async fn run(self) -> io::Result<()>
    where
        A: Send + Sync + 'static,
    {
        let stream_cfg = self.cfg.clone();
        let mut built = self.build();

        built
            .run_raw(move |key, stream| {
                Connection::new_encrypted(stream, stream_cfg.clone(), key.clone())
            })
            .await
    }
}
198
199struct Shared<A, B> {
202 requests: Requests<A, B>,
203 data: Data
204}
205
206pub struct BuiltServer<A, B, L, More> {
207 inner: L,
208 shared: Arc<Shared<A, B>>,
209 more: More
210}
211
212impl<A, B, L, More> BuiltServer<A, B, L, More>
213where
214 A: Action,
215 L: Listener
216{
217 pub fn get_data<D>(&self) -> Option<&D>
218 where D: Any {
219 self.shared.data.get()
220 }
221
222 pub async fn request<R>(
223 &self,
224 r: R,
225 session: &Arc<Session>
226 ) -> Result<R::Response, R::Error>
227 where
228 R: Request<Action=A>,
229 R: IntoMessage<A, B>,
230 R::Response: FromMessage<A, B>,
231 R::Error: FromMessage<A, B>,
232 B: PacketBytes
233 {
234 let mut msg = r.into_message()
235 .map_err(R::Error::from_message_error)?;
236 msg.header_mut().set_action(R::ACTION);
237
238 let action = *msg.action().unwrap();
240
241 let handler = match self.shared.requests.get(&action) {
242 Some(handler) => handler,
243 None => {
247 tracing::error!("no handler for {:?}", action);
248 return Err(R::Error::from_request_error(
249 RequestError::NoResponse
250 ))
251 }
252 };
253
254 let r = handler.handle(
255 msg,
256 &self.shared.data,
257 session
258 ).await;
259
260 let res = match r {
261 Ok(mut msg) => {
262 msg.header_mut().set_action(action);
263 msg
264 },
265 Err(e) => {
266 tracing::error!(
270 "handler returned an error {:?}", e
271 );
272
273 return Err(R::Error::from_request_error(
274 RequestError::NoResponse
275 ))
276 }
277 };
278
279 if res.is_success() {
281 R::Response::from_message(res)
282 .map_err(R::Error::from_message_error)
283 } else {
284 R::Error::from_message(res)
285 .map(Err)
286 .map_err(R::Error::from_message_error)?
287 }
288 }
289
290 async fn run_raw<F>(&mut self, new_connection: F) -> io::Result<()>
291 where
292 A: Action + Send + Sync + 'static,
293 B: PacketBytes + Send + 'static,
294 F: Fn(&More, L::Stream) -> Connection<Message<A, B>>
295 {
296 loop {
297
298 let (stream, addr) = self.inner.accept().await?;
300
301 let mut con = new_connection(&self.more, stream);
302 let session = Arc::new(Session::new(addr));
303 session.set(con.configurator());
304
305 let share = self.shared.clone();
306 tokio::spawn(async move {
307 while let Some(req) = con.receive().await {
308 let (msg, resp) = match req {
310 server::Message::Request(msg, resp) => (msg, resp),
311 _ => continue
313 };
314
315 let share = share.clone();
316 let session = session.clone();
317
318 let action = match msg.action() {
319 Some(act) => *act,
320 None => {
324 tracing::error!("invalid action received");
325 continue
326 }
327 };
328
329 tokio::spawn(async move {
330 let handler = match share.requests.get(&action) {
331 Some(handler) => handler,
332 None => {
336 tracing::error!("no handler for {:?}", action);
337 return;
338 }
339 };
340 let r = handler.handle(
341 msg,
342 &share.data,
343 &session
344 ).await;
345
346 match r {
347 Ok(mut msg) => {
348 msg.header_mut().set_action(action);
349 let _ = resp.send(msg);
351 },
352 Err(e) => {
353 tracing::error!(
357 "handler returned an error {:?}", e
358 );
359 }
360 }
361 });
362 }
363 });
364 }
365 }
366}
367
368pub struct Session {
369 addr: SocketAddr,
371 data: Mutex<HashMap<TypeId, Box<dyn Any + Send + Sync>>>
372}
373
374impl Session {
375 pub fn new(addr: SocketAddr) -> Self {
376 Self {
377 addr,
378 data: Mutex::new(HashMap::new())
379 }
380 }
381
382 pub fn addr(&self) -> &SocketAddr {
383 &self.addr
384 }
385
386 pub fn set<D>(&self, data: D)
387 where D: Any + Send + Sync {
388 self.data.lock().unwrap()
389 .insert(data.type_id(), Box::new(data));
390 }
391
392 pub fn get<D>(&self) -> Option<D>
393 where D: Any + Clone + Send + Sync {
394 self.data.lock().unwrap()
395 .get(&TypeId::of::<D>())
396 .and_then(|d| d.downcast_ref())
397 .map(Clone::clone)
398 }
399
400 pub fn take<D>(&self) -> Option<D>
401 where D: Any + Send + Sync {
402 self.data.lock().unwrap()
403 .remove(&TypeId::of::<D>())
404 .and_then(|d| d.downcast().ok())
405 .map(|b| *b)
406 }
407}
408
409
// In-process dispatch tests for JSON-encoded messages (needs the `json`
// feature).
#[cfg(all(test, feature = "json"))]
mod json_tests {
    use super::*;

    use codegen::{IntoMessage, FromMessage, api};
    use crate::request::Request;
    use crate::message;
    use crate::error;

    use std::fmt;

    use stream::util::testing::PanicListener;

    use serde::{Serialize, Deserialize};


    #[derive(Debug, Serialize, Deserialize, IntoMessage, FromMessage)]
    #[message(json)]
    struct TestReq {
        hello: u64
    }

    // Same JSON shape as `TestReq`.
    #[derive(Debug, Serialize, Deserialize, IntoMessage, FromMessage)]
    #[message(json)]
    struct TestReq2 {
        hello: u64
    }

    // Response for both requests; the handlers below echo `hello` into `hi`.
    #[derive(Debug, Serialize, Deserialize, IntoMessage, FromMessage)]
    #[message(json)]
    struct TestResp {
        hi: u64
    }

    // Single-variant action type; both test requests use `Action::Empty`.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
    pub enum Action {
        Empty
    }

    // API error type holding the stringified request or message error.
    #[derive(Debug, Clone, Serialize, Deserialize, IntoMessage, FromMessage)]
    #[message(json)]
    pub enum Error {
        RequestError(String),
        MessageError(String)
    }

    impl fmt::Display for Error {
        // Display reuses the Debug representation.
        fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
            fmt::Debug::fmt(self, fmt)
        }
    }

    impl std::error::Error for Error {}

    impl error::ApiError for Error {
        fn from_request_error(e: error::RequestError) -> Self {
            Self::RequestError(e.to_string())
        }

        fn from_message_error(e: error::MessageError) -> Self {
            Self::MessageError(e.to_string())
        }
    }

    // The numeric conversions are stubs that panic (`todo!()`) if called.
    // NOTE(review): presumably the in-process `request` path never converts
    // actions to/from `u16` — confirm before relying on this.
    impl message::Action for Action {
        fn from_u16(_num: u16) -> Option<Self> { todo!() }
        fn as_u16(&self) -> u16 { todo!() }
    }

    impl Request for TestReq {
        type Action = Action;
        type Response = TestResp;
        type Error = Error;

        const ACTION: Action = Action::Empty;
    }

    impl Request for TestReq2 {
        type Action = Action;
        type Response = TestResp;
        type Error = Error;

        const ACTION: Action = Action::Empty;
    }

    // Handler for `TestReq`: echoes `hello` back as `hi`.
    #[api(TestReq)]
    async fn test(req: TestReq) -> Result<TestResp, Error> {
        println!("req {:?}", req);
        Ok(TestResp {
            hi: req.hello
        })
    }

    // Handler for `TestReq2`: echoes `hello` back as `hi`.
    #[api(TestReq2)]
    async fn test_2(
        req: TestReq2
    ) -> Result<TestResp, Error> {
        println!("req {:?}", req);
        Ok(TestResp {
            hi: req.hello
        })
    }

    // Exercises `BuiltServer::request` and `get_data` without a network.
    // `PanicListener` is acceptable here because the accept loop (`run`) is
    // never started.
    //
    // NOTE(review): `TestReq` and `TestReq2` share `Action::Empty`, and
    // `Requests::insert` keys handlers by action. Registering `test_2`
    // therefore presumably replaces `test`, so both requests below are served
    // by `test_2`. They still pass only because the two request types have
    // the same JSON shape — confirm this is intended.
    #[tokio::test]
    async fn test_direct_request() {
        let mut server = Server::new(PanicListener::new(), Config {
            timeout: std::time::Duration::from_millis(10),
            body_limit: 4096
        });

        server.register_data(String::from("global String"));

        server.register_request(test);
        server.register_request(test_2);

        let server = server.build();
        let session = Arc::new(Session::new(
            SocketAddr::V4("127.0.0.1:8080".parse().unwrap())
        ));

        let r = server.request(TestReq { hello: 100 }, &session).await.unwrap();
        assert_eq!(r.hi, 100);

        let r = server.request(
            TestReq2 { hello: 100 },
            &session
        ).await.unwrap();
        assert_eq!(r.hi, 100);

        assert_eq!(server.get_data::<String>().unwrap(), "global String");
    }
}
542
543
// Protobuf-encoded message types for tests (needs the `protobuf` feature).
// NOTE(review): this module only declares types; it has no `#[test]`
// functions yet.
#[cfg(all(test, feature = "protobuf"))]
mod protobuf_tests {
    use codegen::{IntoMessage, FromMessage};

    use fire_protobuf::{EncodeMessage, DecodeMessage};


    // Request carrying a single protobuf field (tag 1).
    #[derive(Debug, Default)]
    #[derive(EncodeMessage, DecodeMessage, IntoMessage, FromMessage)]
    #[message(protobuf)]
    struct TestReq {
        #[field(1)]
        hello: u64
    }

    // Same layout as `TestReq`.
    #[derive(Debug, Default)]
    #[derive(EncodeMessage, DecodeMessage, IntoMessage, FromMessage)]
    #[message(protobuf)]
    struct TestReq2 {
        #[field(1)]
        hello: u64
    }
}