1#![recursion_limit = "256"]
2#![warn(missing_docs)]
54
55use std::io;
56use std::net::SocketAddr;
57use std::str;
58use std::sync::Arc;
59
60use futures::prelude::*;
61use log::{debug, error, info, trace, warn};
62use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader, ReadHalf, WriteHalf};
63use tokio::net::{TcpListener, TcpStream};
64use tokio::sync::{
65 broadcast::{self, Receiver, Sender},
66 Semaphore,
67};
68
69mod error;
70
71pub use error::Error;
72pub use error::Result;
73
/// Shared request handler: takes one request line and produces the response
/// text. Held behind an `Arc` so every client task can own a cheap clone, and
/// required to be `Send + Sync` so it can be called from any task or thread.
type HandleFn = Arc<dyn Fn(&str) -> String + Send + Sync + 'static>;
75
/// Settings used to build a [`Server`].
///
/// Start from [`Config::default`] and chain the builder methods to override
/// individual values.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Config {
    /// IP address the server binds to, e.g. `127.0.0.1` (the default).
    host: String,

    /// TCP port the server binds to. Defaults to `7343`.
    port: u16,

    /// Maximum number of clients served at the same time; connections beyond
    /// this limit are rejected. Defaults to `32`.
    max_clients: usize,

    /// Initial capacity, in bytes, of each client's line buffer and read
    /// buffer. Defaults to `1024`. This is a capacity, not a limit: longer
    /// lines are still read in full.
    client_buf_size: usize,

    /// Whether the handler function runs via `tokio::task::spawn_blocking`
    /// instead of directly on the client's async task. Defaults to `false`.
    handle_fn_blocks: bool,
}

impl Config {
    /// Sets the IP address the server binds to.
    #[must_use]
    pub fn host<S>(mut self, val: S) -> Self
    where
        S: Into<String>,
    {
        self.host = val.into();
        self
    }

    /// Sets the TCP port the server binds to.
    #[must_use]
    pub fn port(mut self, val: u16) -> Self {
        self.port = val;
        self
    }

    /// Sets the maximum number of clients served concurrently.
    #[must_use]
    pub fn max_clients(mut self, val: usize) -> Self {
        self.max_clients = val;
        self
    }

    /// Sets the initial per-client buffer capacity in bytes.
    #[must_use]
    pub fn client_buf_size(mut self, val: usize) -> Self {
        self.client_buf_size = val;
        self
    }

    /// Marks the handler function as blocking, so every call is executed via
    /// `tokio::task::spawn_blocking` rather than on the async task itself.
    ///
    /// Use this for handlers that perform blocking I/O or long computations.
    #[must_use]
    pub fn handle_fn_blocks(mut self) -> Self {
        self.handle_fn_blocks = true;
        self
    }
}

impl Default for Config {
    /// `127.0.0.1:7343`, up to 32 clients, 1024-byte buffers, non-blocking
    /// handler.
    fn default() -> Config {
        Config {
            host: "127.0.0.1".into(),
            port: 7343,
            max_clients: 32,
            client_buf_size: 1024,
            handle_fn_blocks: false,
        }
    }
}
144
145struct Client {
146 buf: String,
147 reader: BufReader<ReadHalf<TcpStream>>,
148 writer: WriteHalf<TcpStream>,
149 handle_fn: HandleFn,
150 handle_fn_blocks: bool,
151}
152
153impl Client {
154 fn new(config: &Config, stream: TcpStream, handle_fn: &HandleFn) -> Client {
155 let (reader, writer) = tokio::io::split(stream);
156
157 let buf = String::with_capacity(config.client_buf_size);
158 let reader = BufReader::with_capacity(config.client_buf_size, reader);
159 let handle_fn = Arc::clone(handle_fn);
160
161 Client {
162 buf,
163 reader,
164 writer,
165 handle_fn,
166 handle_fn_blocks: config.handle_fn_blocks,
167 }
168 }
169
170 fn spawn(self, clients: &Arc<Semaphore>, shutdown_send: &Sender<ControlMsg>) {
171 let clients = Arc::clone(&clients);
172 let shutdown_recv = shutdown_send.subscribe();
173
174 tokio::spawn(self.try_accept(clients, shutdown_recv));
175 }
176
177 async fn try_accept(self, clients: Arc<Semaphore>, shutdown_recv: Receiver<ControlMsg>) {
178 let permit = match clients.try_acquire() {
179 Ok(client) => client,
180 Err(_) => {
181 warn!("rejecting client; max connections reached");
182 return;
183 }
184 };
185
186 trace!("accept client connection");
187 self.accept(shutdown_recv).await;
188
189 drop(permit);
190 }
191
192 async fn accept(mut self, mut shutdown_recv: Receiver<ControlMsg>) {
193 loop {
194 futures::select! {
195 result = self.handle_line().fuse() => {
196 if let Err(e) = result {
197 debug!("Error handling value - shutting down connection: {}", e);
198 break;
199 }
200
201 self.buf.clear();
202 }
203 control_msg = shutdown_recv.recv().fuse() => {
204 match control_msg {
205 Ok(ControlMsg::Shutdown) => {
206 info!("Shutting down server");
207 break;
208 }
209 Err(e) => {
210 error!("Error receiving control message {:?}", e);
211 break;
212 }
213 }
214 }
215 }
216 }
217
218 self.shutdown().await;
219 }
220
221 async fn shutdown(self) {
222 if let Err(e) = self
223 .reader
224 .into_inner()
225 .unsplit(self.writer)
226 .shutdown()
227 .await
228 {
229 debug!("Error closing socket connection {:?}", e);
230 }
231 }
232
233 async fn handle_line(&mut self) -> Result<()> {
234 let bytes_read = self.reader.read_line(&mut self.buf).await?;
235 if bytes_read == 0 {
236 return Err(Error::NoBytesRead);
237 }
238
239 let slice = if self.buf.is_empty() {
240 &self.buf[..]
241 } else {
242 &self.buf[0..self.buf.len() - 1]
244 };
245
246 trace!("Read line: \"{}\"", slice);
247
248 let handle_fn = &self.handle_fn;
249 let mut response = if self.handle_fn_blocks {
250 let handle_fn = Arc::clone(handle_fn);
251 let string = slice.to_owned();
252
253 tokio::task::spawn_blocking(move || handle_fn(&string)).await?
254 } else {
255 handle_fn(&slice)
256 };
257
258 response.push('\n');
259
260 self.writer.write_all(response.as_bytes()).await?;
261 trace!("Wrote response: \"{}\"", response.trim());
262
263 Ok(())
264 }
265}
266
267pub struct Handle {
269 sender: Sender<ControlMsg>,
270}
271
272impl Handle {
273 pub fn shutdown(self) {
275 let _ = self.sender.send(ControlMsg::Shutdown);
279 }
280}
281
282pub struct Server {
284 handler: HandleFn,
285 config: Config,
286 address: SocketAddr,
287 shutdown_recv: Receiver<ControlMsg>,
288 shutdown_send: Sender<ControlMsg>,
289 clients: Arc<Semaphore>,
290}
291
292impl Server {
293 pub fn new<F>(config: Config, func: F) -> Result<Server>
313 where
314 F: Fn(&str) -> String + 'static + Send + Sync,
315 {
316 let address = format!("{host}:{port}", host = config.host, port = config.port).parse()?;
317 let (shutdown_send, shutdown_recv) = broadcast::channel(1);
318 let clients = Arc::new(Semaphore::new(config.max_clients));
319
320 Ok(Server {
321 handler: Arc::new(func),
322 config,
323 address,
324 shutdown_send,
325 shutdown_recv,
326 clients,
327 })
328 }
329
330 pub fn handle(&self) -> Handle {
332 Handle {
333 sender: self.shutdown_send.clone(),
334 }
335 }
336
337 pub async fn run(&mut self) -> io::Result<()> {
339 info!("Listening at {}", self.address);
340 let listener = TcpListener::bind(self.address).await?;
341
342 loop {
343 futures::select! {
344 accept = listener.accept().fuse() => {
345 let (socket, _) = match accept {
346 Ok(socket) => socket,
347 Err(e) => {
348 error!("Error accepting connection: {}", e);
349 continue;
350 }
351 };
352
353 self.accept(socket);
354 }
355 control_msg = self.shutdown_recv.recv().fuse() => {
356 match control_msg {
357 Ok(ControlMsg::Shutdown) => {
358 info!("Shutting down server");
359 break;
360 }
361 Err(e) => {
362 error!("Error receiving control message {:?}", e);
363 break;
364 }
365 }
366 }
367 }
368 }
369
370 Ok(())
371 }
372
373 fn accept(&self, socket: TcpStream) {
374 let client = Client::new(&self.config, socket, &self.handler);
375
376 client.spawn(&self.clients, &self.shutdown_send);
377 }
378}
379
/// Control message broadcast from a `Handle` to the server's accept loop and
/// to every client task subscribed to the same channel.
#[doc(hidden)]
#[derive(Clone, Debug)]
pub enum ControlMsg {
    /// Receivers stop their loop: the server stops accepting and each client
    /// connection is shut down.
    Shutdown,
}
386
#[cfg(test)]
mod tests {
    use std::net::SocketAddr;
    use tokio::io::{self, AsyncBufReadExt, AsyncWriteExt, BufReader, ReadHalf, WriteHalf};
    use tokio::net::TcpStream;

    use super::{Config, Handle, Server};

    /// Anything the test client can send as raw bytes.
    trait AsByteSlice {
        fn as_byte_slice(&self) -> &[u8];
    }

    impl AsByteSlice for String {
        fn as_byte_slice(&self) -> &[u8] {
            self.as_bytes()
        }
    }

    impl AsByteSlice for &str {
        fn as_byte_slice(&self) -> &[u8] {
            self.as_bytes()
        }
    }

    /// Minimal line-oriented client used to talk to the server under test.
    struct Client {
        stream_read: BufReader<ReadHalf<TcpStream>>,
        stream_write: WriteHalf<TcpStream>,
    }

    impl Client {
        /// Connects to the server described by `config`, retrying while the
        /// connection is refused.
        pub async fn new(config: &Config) -> Self {
            let stream = Client::connect(config).await;

            let (stream_read, stream_write) = io::split(stream);
            let stream_read = BufReader::new(stream_read);

            Self {
                stream_read,
                stream_write,
            }
        }

        /// Connects to `config.host:config.port`, retrying while the
        /// connection is refused because the server has not bound yet.
        ///
        /// # Panics
        ///
        /// Panics on any connection error other than `ConnectionRefused`.
        async fn connect(config: &Config) -> TcpStream {
            let addr = format!("{}:{}", config.host, config.port)
                .parse::<SocketAddr>()
                .unwrap();

            loop {
                match TcpStream::connect(addr).await {
                    Ok(stream) => return stream,
                    Err(err) if err.kind() == ::std::io::ErrorKind::ConnectionRefused => {
                        // `#[tokio::test]` uses a current-thread runtime: the
                        // spawned server task can only bind its listener while
                        // this task is suspended, so yield explicitly instead
                        // of retrying in a tight loop that may never suspend.
                        tokio::task::yield_now().await;
                    }
                    Err(err) => panic!("failed to connect; {}", err),
                }
            }
        }

        /// Sends `bytes` followed by a newline.
        pub async fn send<B>(&mut self, bytes: B)
        where
            B: AsByteSlice,
        {
            self.stream_write
                .write_all(bytes.as_byte_slice())
                .await
                .expect("successfully send bytes");
            self.stream_write
                .write_all(b"\n")
                .await
                .expect("successfully send bytes");
        }

        /// Reads one response line, with trailing whitespace removed.
        pub async fn recv(&mut self) -> String {
            let mut buf = String::new();
            self.stream_read
                .read_line(&mut buf)
                .await
                .expect("read_line");

            buf.trim_end().into()
        }

        /// Reads one response line and asserts it equals `s`.
        pub async fn expect(&mut self, s: &str) {
            let got = self.recv().await;
            assert_eq!(got.as_str(), s);
        }
    }

    /// Starts a server for `config` on the current runtime; it is shut down
    /// when the returned guard is dropped.
    fn run_server(config: &Config) -> TestHandle {
        let _ = env_logger::try_init();
        let config = config.to_owned();

        let mut server = Server::new(config, |query| match query {
            "version" => String::from("0.1.0"),
            _ => String::from("unknown command"),
        })
        .unwrap();

        let handle = server.handle();

        tokio::spawn(async move {
            server.run().await.unwrap();
        });

        TestHandle {
            handle: Some(handle),
        }
    }

    /// Guard that shuts the test server down when dropped, including when a
    /// test panics.
    pub struct TestHandle {
        handle: Option<Handle>,
    }

    impl Drop for TestHandle {
        fn drop(&mut self) {
            if let Some(handle) = self.handle.take() {
                handle.shutdown();
            }
        }
    }

    #[tokio::test]
    async fn it_works() {
        let config = Config::default();
        let _server = run_server(&config);

        {
            let mut client = Client::new(&config).await;
            client.send("version").await;
            client.expect("0.1.0").await;
            client.send("nope").await;
            client.expect("unknown command").await;
        }
    }

    #[tokio::test]
    async fn send_empty_line() {
        let config = Config::default().port(5501);
        let _server = run_server(&config);

        {
            let mut client = Client::new(&config).await;
            client.send("").await;
            client.expect("unknown command").await;
            client.send("version").await;
            client.expect("0.1.0").await;
        }
    }

    #[tokio::test]
    async fn multiple_commands_received_at_once() {
        let config = Config::default().port(5502);
        let _server = run_server(&config);

        {
            let mut client = Client::new(&config).await;
            client.send("version\nversion").await;

            let got = client.recv().await;
            assert!(got.contains("0.1.0"));
        }
    }

    #[tokio::test]
    async fn exceed_max_clients() {
        let config = Config::default().max_clients(1).port(5503);
        let _server = run_server(&config);

        {
            let mut client = Client::new(&config).await;
            {
                // A second connection beyond the limit must not disturb the
                // first one.
                let _client = Client::new(&config).await;
            }
            client.send("version").await;
            client.expect("0.1.0").await;
        }
    }
}