linebased/
lib.rs

1#![recursion_limit = "256"]
2//! Drop-in TCP command server
3//!
//! Provide a callback that is given each command sent by a client and synchronously returns the response.
5//! `tokio` is used internally so multiple clients may be active.
6//!
7//! # Examples
8//!
9//! ```no_run
10//! use linebased::Server;
11//!
12//! #[tokio::main]
13//! async fn main() {
14//!     // Create a server with the default config and a
15//!     // handler that only knows the "version" command
16//!     let mut server = Server::new(Default::default(), |query| {
17//!         match query {
18//!             "version" => String::from("0.1.0"),
19//!             _ => String::from("unknown command"),
20//!         }
21//!     }).unwrap();
22//!
23//!     server.run().await.unwrap();
24//! }
25//! ```
26//!
//! Running a server in the background is also possible: just spawn the future
//! returned by `Server::run`. Request a handle from the server so that you
//! can shut it down gracefully.
30//!
31//! ```no_run
32//! use linebased::Server;
33//! use std::thread;
34//!
35//! #[tokio::main]
36//! async fn main() {
37//!     let mut server = Server::new(Default::default(), |query| {
38//!         match query {
39//!             "version" => String::from("0.1.0"),
40//!             _ => String::from("unknown command"),
41//!         }
42//!     }).unwrap();
43//!
44//!     let handle = server.handle();
45//!     let fut = tokio::spawn(async move { server.run().await });
46//!
47//!     // Time passes
48//!
49//!     handle.shutdown();
50//!     fut.await.expect("failed to spawn future").expect("Error from linebased::Server::run");
51//! }
52//! ```
53#![warn(missing_docs)]
54
55use std::io;
56use std::net::SocketAddr;
57use std::str;
58use std::sync::Arc;
59
60use futures::prelude::*;
61use log::{debug, error, info, trace, warn};
62use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader, ReadHalf, WriteHalf};
63use tokio::net::{TcpListener, TcpStream};
64use tokio::sync::{
65    broadcast::{self, Receiver, Sender},
66    Semaphore,
67};
68
69mod error;
70
71pub use error::Error;
72pub use error::Result;
73
74type HandleFn = Arc<dyn Fn(&str) -> String + 'static + Send + Sync>;
75
76/// Server configuration
77#[derive(Debug, Clone)]
78pub struct Config {
79    /// Address to listen on
80    host: String,
81
82    /// Port to listen on
83    port: u16,
84
85    /// Maximum number of client connections
86    max_clients: usize,
87
88    /// initial per-client buffer size, will grow beyond this limit if required.
89    client_buf_size: usize,
90
91    /// If the handler function blocks, it should be spawned on its own worker
92    /// thread so that the tokio threadpool isn't blocked by it.
93    handle_fn_blocks: bool,
94}
95
96impl Config {
97    /// Set host address to listen on
98    pub fn host<S>(mut self, val: S) -> Self
99    where
100        S: Into<String>,
101    {
102        self.host = val.into();
103        self
104    }
105
106    /// Set port to listen on
107    pub fn port(mut self, val: u16) -> Self {
108        self.port = val;
109        self
110    }
111
112    /// set maximum number of clients
113    pub fn max_clients(mut self, val: usize) -> Self {
114        self.max_clients = val;
115        self
116    }
117
118    /// Set the initial per-client buffer size, will grow beyond this limit if required.
119    pub fn client_buf_size(mut self, val: usize) -> Self {
120        self.client_buf_size = val;
121        self
122    }
123
124    /// Indicate that the handle fn might block and should be spawned on a tokio
125    /// worker thread. This should be used if the fn does heavy computation, or
126    /// has any blocking i/o in it.
127    pub fn handle_fn_blocks(mut self) -> Self {
128        self.handle_fn_blocks = true;
129        self
130    }
131}
132
133impl Default for Config {
134    fn default() -> Config {
135        Config {
136            host: "127.0.0.1".into(),
137            port: 7343,
138            max_clients: 32,
139            client_buf_size: 1024,
140            handle_fn_blocks: false,
141        }
142    }
143}
144
145struct Client {
146    buf: String,
147    reader: BufReader<ReadHalf<TcpStream>>,
148    writer: WriteHalf<TcpStream>,
149    handle_fn: HandleFn,
150    handle_fn_blocks: bool,
151}
152
153impl Client {
154    fn new(config: &Config, stream: TcpStream, handle_fn: &HandleFn) -> Client {
155        let (reader, writer) = tokio::io::split(stream);
156
157        let buf = String::with_capacity(config.client_buf_size);
158        let reader = BufReader::with_capacity(config.client_buf_size, reader);
159        let handle_fn = Arc::clone(handle_fn);
160
161        Client {
162            buf,
163            reader,
164            writer,
165            handle_fn,
166            handle_fn_blocks: config.handle_fn_blocks,
167        }
168    }
169
170    fn spawn(self, clients: &Arc<Semaphore>, shutdown_send: &Sender<ControlMsg>) {
171        let clients = Arc::clone(&clients);
172        let shutdown_recv = shutdown_send.subscribe();
173
174        tokio::spawn(self.try_accept(clients, shutdown_recv));
175    }
176
177    async fn try_accept(self, clients: Arc<Semaphore>, shutdown_recv: Receiver<ControlMsg>) {
178        let permit = match clients.try_acquire() {
179            Ok(client) => client,
180            Err(_) => {
181                warn!("rejecting client; max connections reached");
182                return;
183            }
184        };
185
186        trace!("accept client connection");
187        self.accept(shutdown_recv).await;
188
189        drop(permit);
190    }
191
192    async fn accept(mut self, mut shutdown_recv: Receiver<ControlMsg>) {
193        loop {
194            futures::select! {
195                result = self.handle_line().fuse() => {
196                    if let Err(e) = result {
197                        debug!("Error handling value - shutting down connection: {}", e);
198                        break;
199                    }
200
201                    self.buf.clear();
202                }
203                control_msg = shutdown_recv.recv().fuse() => {
204                    match control_msg {
205                        Ok(ControlMsg::Shutdown) => {
206                            info!("Shutting down server");
207                            break;
208                        }
209                        Err(e) => {
210                            error!("Error receiving control message {:?}", e);
211                            break;
212                        }
213                    }
214                }
215            }
216        }
217
218        self.shutdown().await;
219    }
220
221    async fn shutdown(self) {
222        if let Err(e) = self
223            .reader
224            .into_inner()
225            .unsplit(self.writer)
226            .shutdown()
227            .await
228        {
229            debug!("Error closing socket connection {:?}", e);
230        }
231    }
232
233    async fn handle_line(&mut self) -> Result<()> {
234        let bytes_read = self.reader.read_line(&mut self.buf).await?;
235        if bytes_read == 0 {
236            return Err(Error::NoBytesRead);
237        }
238
239        let slice = if self.buf.is_empty() {
240            &self.buf[..]
241        } else {
242            // Remove the newline at the end of the string
243            &self.buf[0..self.buf.len() - 1]
244        };
245
246        trace!("Read line: \"{}\"", slice);
247
248        let handle_fn = &self.handle_fn;
249        let mut response = if self.handle_fn_blocks {
250            let handle_fn = Arc::clone(handle_fn);
251            let string = slice.to_owned();
252
253            tokio::task::spawn_blocking(move || handle_fn(&string)).await?
254        } else {
255            handle_fn(&slice)
256        };
257
258        response.push('\n');
259
260        self.writer.write_all(response.as_bytes()).await?;
261        trace!("Wrote response: \"{}\"", response.trim());
262
263        Ok(())
264    }
265}
266
267/// Handle for the server
268pub struct Handle {
269    sender: Sender<ControlMsg>,
270}
271
272impl Handle {
273    /// Request the server to shutdown gracefully
274    pub fn shutdown(self) {
275        // send only returns an error if there are no receivers active, meaning
276        // the server was already shut down, so it is safe to ignore this
277        // result.
278        let _ = self.sender.send(ControlMsg::Shutdown);
279    }
280}
281
282/// The linebased TCP server
283pub struct Server {
284    handler: HandleFn,
285    config: Config,
286    address: SocketAddr,
287    shutdown_recv: Receiver<ControlMsg>,
288    shutdown_send: Sender<ControlMsg>,
289    clients: Arc<Semaphore>,
290}
291
292impl Server {
293    /// Create a new server
294    // # Examples
295    ///
296    /// ```no_run
297    /// use linebased::Server;
298    ///
299    /// // Create a server with the default config and a
300    /// // handler that only knows the "version" command
301    /// let mut server = Server::new(Default::default(), |query| {
302    ///     match query {
303    ///         "version" => {
304    ///             String::from("0.1.0")
305    ///         },
306    ///         _ => {
307    ///             String::from("unknown command")
308    ///         }
309    ///     }
310    /// }).unwrap();
311    /// ```
312    pub fn new<F>(config: Config, func: F) -> Result<Server>
313    where
314        F: Fn(&str) -> String + 'static + Send + Sync,
315    {
316        let address = format!("{host}:{port}", host = config.host, port = config.port).parse()?;
317        let (shutdown_send, shutdown_recv) = broadcast::channel(1);
318        let clients = Arc::new(Semaphore::new(config.max_clients));
319
320        Ok(Server {
321            handler: Arc::new(func),
322            config,
323            address,
324            shutdown_send,
325            shutdown_recv,
326            clients,
327        })
328    }
329
330    /// Get a handle for the server so graceful shutdown can be requested
331    pub fn handle(&self) -> Handle {
332        Handle {
333            sender: self.shutdown_send.clone(),
334        }
335    }
336
337    /// Run the event loop
338    pub async fn run(&mut self) -> io::Result<()> {
339        info!("Listening at {}", self.address);
340        let listener = TcpListener::bind(self.address).await?;
341
342        loop {
343            futures::select! {
344                accept = listener.accept().fuse() => {
345                    let (socket, _) = match accept {
346                        Ok(socket) => socket,
347                        Err(e) => {
348                            error!("Error accepting connection: {}", e);
349                            continue;
350                        }
351                    };
352
353                    self.accept(socket);
354                }
355                control_msg = self.shutdown_recv.recv().fuse() => {
356                    match control_msg {
357                        Ok(ControlMsg::Shutdown) => {
358                            info!("Shutting down server");
359                            break;
360                        }
361                        Err(e) => {
362                            error!("Error receiving control message {:?}", e);
363                            break;
364                        }
365                    }
366                }
367            }
368        }
369
370        Ok(())
371    }
372
373    fn accept(&self, socket: TcpStream) {
374        let client = Client::new(&self.config, socket, &self.handler);
375
376        client.spawn(&self.clients, &self.shutdown_send);
377    }
378}
379
380#[doc(hidden)]
381#[derive(Debug, Clone)]
382pub enum ControlMsg {
383    /// Stop the server and end all connections immediately
384    Shutdown,
385}
386
387#[cfg(test)]
388mod tests {
389    use std::net::SocketAddr;
390    use tokio::io::{self, AsyncBufReadExt, AsyncWriteExt, BufReader, ReadHalf, WriteHalf};
391    use tokio::net::TcpStream;
392
393    use super::{Config, Handle, Server};
394
395    trait AsByteSlice {
396        fn as_byte_slice(&self) -> &[u8];
397    }
398
399    impl AsByteSlice for String {
400        fn as_byte_slice(&self) -> &[u8] {
401            self.as_bytes()
402        }
403    }
404
405    impl<'a> AsByteSlice for &'a str {
406        fn as_byte_slice(&self) -> &[u8] {
407            self.as_bytes()
408        }
409    }
410
411    /// Client for testing
412    struct Client {
413        stream_read: BufReader<ReadHalf<TcpStream>>,
414        stream_write: WriteHalf<TcpStream>,
415    }
416
417    impl Client {
418        /// Create a client
419        ///
420        /// Any errors will panic since this is for testing only. The server is
421        /// assumed to be on the default port. Performance is not a
422        /// consideration here; only ergonomics, correctness, and failing early.
423        pub async fn new(config: &Config) -> Self {
424            let stream = Client::connect(config).await;
425
426            let (stream_read, stream_write) = io::split(stream);
427            let stream_read = BufReader::new(stream_read);
428
429            Self {
430                stream_read,
431                stream_write,
432            }
433        }
434
435        /// Connect to the server
436        ///
437        /// Retries as long as error is connection refused. I guess this can
438        /// mean tests hang if something is wrong. Oh well.
439        async fn connect(config: &Config) -> TcpStream {
440            loop {
441                match TcpStream::connect(
442                    format!("{}:{}", config.host, config.port)
443                        .parse::<SocketAddr>()
444                        .unwrap(),
445                )
446                .await
447                {
448                    Ok(stream) => return stream,
449                    Err(err) => match err.kind() {
450                        ::std::io::ErrorKind::ConnectionRefused => continue,
451                        _ => panic!("failed to connect; {}", err),
452                    },
453                }
454            }
455        }
456
457        /// Sends all bytes to the remote
458        pub async fn send<B>(&mut self, bytes: B)
459        where
460            B: AsByteSlice,
461        {
462            self.stream_write
463                .write_all(bytes.as_byte_slice())
464                .await
465                .expect("successfully send bytes");
466            self.stream_write
467                .write_all(b"\n")
468                .await
469                .expect("successfully send bytes");
470        }
471
472        /// Receive the next line.
473        ///
474        /// Extra data is buffered internally.
475        pub async fn recv(&mut self) -> String {
476            let mut buf = String::new();
477            self.stream_read
478                .read_line(&mut buf)
479                .await
480                .expect("read_line");
481
482            buf.trim_end().into()
483        }
484
485        pub async fn expect(&mut self, s: &str) {
486            let got = self.recv().await;
487            assert_eq!(got.as_str(), s);
488        }
489    }
490
491    fn run_server(config: &Config) -> TestHandle {
492        let _ = env_logger::try_init();
493        let config = config.to_owned();
494
495        let mut server = Server::new(config, |query| match query {
496            "version" => String::from("0.1.0"),
497            _ => String::from("unknown command"),
498        })
499        .unwrap();
500
501        let handle = server.handle();
502
503        tokio::spawn(async move {
504            server.run().await.unwrap();
505        });
506
507        TestHandle {
508            handle: Some(handle),
509        }
510    }
511
512    /// Handle wrapping test server
513    ///
514    /// Requests graceful shutdown and joins with thread on drop
515    pub struct TestHandle {
516        handle: Option<Handle>,
517    }
518
519    impl Drop for TestHandle {
520        fn drop(&mut self) {
521            let _ = self.handle.take().unwrap().shutdown();
522        }
523    }
524
525    #[tokio::test]
526    async fn it_works() {
527        let config = Config::default();
528        let _server = run_server(&config);
529
530        {
531            let mut client = Client::new(&config).await;
532            client.send("version").await;
533            client.expect("0.1.0").await;
534            client.send("nope").await;
535            client.expect("unknown command").await;
536        }
537    }
538
539    #[tokio::test]
540    async fn send_empty_line() {
541        let config = Config::default().port(5501);
542        let _server = run_server(&config);
543
544        {
545            let mut client = Client::new(&config).await;
546            client.send("").await;
547            client.expect("unknown command").await;
548            // commands should continue to work
549            client.send("version").await;
550            client.expect("0.1.0").await;
551        }
552    }
553
554    #[tokio::test]
555    async fn multiple_commands_received_at_once() {
556        let config = Config::default().port(5502);
557        let _server = run_server(&config);
558
559        {
560            let mut client = Client::new(&config).await;
561            client.send("version\nversion").await;
562
563            // This is a bug. Second response may or may not have a prompt.
564            let got = client.recv().await;
565            assert!(got.contains("0.1.0"));
566        }
567    }
568
569    #[tokio::test]
570    async fn exceed_max_clients() {
571        let config = Config::default().max_clients(1).port(5503);
572        let _server = run_server(&config);
573
574        {
575            let mut client = Client::new(&config).await;
576            {
577                // should get disconnected immediately
578                let _client = Client::new(&config).await;
579            }
580            client.send("version").await;
581            client.expect("0.1.0").await;
582        }
583    }
584}