rpc_toy/
lib.rs

1//! ## A Toy RPC framework that is bug-prone, super slow, and hard to use
2//! ### But.. why?
//! Well, I'm taking a class on distributed systems now, and RPC was one of the topics, so...
//! what better way to understand a topic than to implement it?
5//!
//! So that's where this came from. I repeat: this is a toy. Please, no one
//! use this.
8//!
9//! ### General Information:
10//! This library is built upon two main protocols/concepts:
11//! - JSON serialization
12//! - TCP
13//!
//! `JSON` is used for everything as a convenient (yet probably slow) way
//! to serialize data types to strings. The strings are then converted to
//! `UTF-8` bytes and passed down.
//! `TCP` is our choice of transport-layer protocol. It makes moving the bytes
//! from point A to point B easy, so why not?
//! Again, this is a toy. It is not meant to be used for benchmarking or anything else.
20//! ### How to use:
//! This library includes two main structs:
//! a `Client` struct and a `Server` struct.
//! A server registers RPC functions, and a client
//! can then call them.
25//!
26//! ### Examples:
27//! #### Example Client
28//! ```no_run
29//! use rpc_toy::Client;
30//! // You can create a new client using "new"
31//! let mut client = Client::new("127.0.0.1:3001").unwrap();
32//! // All arguments have to be passed in as a slice of `serde_json::Value`s
33//! let one = serde_json::to_value(1u32).unwrap();
34//! let two = serde_json::to_value(2u32).unwrap();
35//! let args = vec![one, two];
36//! // Use the `call` function to call remote procedures
37//! let res = client.call("Add", &args).unwrap();
38//!
39//! let three: u32 = serde_json::from_value(res.unwrap()).unwrap();
40//! assert_eq!(three, 3);
41//! ```
42//! #### Example Server
43//! ```no_run
44//! use rpc_toy::Server;
45//! let mut server = Server::new();
46//! server.register("Add", |args| {
47//!     let one = args.get(0).unwrap();
48//!     let two = args.get(1).unwrap();
49//!     let one = serde_json::from_value::<u32>(one.clone()).unwrap();
50//!     let two = serde_json::from_value::<u32>(two.clone()).unwrap();
51//!     let three = one + two;
52//!     return Some(serde_json::to_value(three).unwrap());
53//! });
54//! server.listen("127.0.0.1:3001").unwrap();
55//! ```
56//! ### Message encodings:
57//!
//! |  The client message encoding                                        |
//! | :-----------------------------------------------------------------: |
//! | 32 bits (big-endian) for the length of the function name            |
//! | The name of the function, encoded as UTF-8                          |
//! | 32 bits (big-endian) for the argument length, or zero to terminate  |
//! | The argument encoded as a UTF-8 JSON string                         |
//! | The next argument's length, or zero to terminate                    |
//! | Repeats until termination ...                                       |
66//! -----------------------------------------------------
67//!
//! |  The server message encoding                                                  |
//! | :---------------------------------------------------------------------------: |
//! | 32 bits (big-endian) for the length of the response; zero for a void function |
//! | The response encoded as a UTF-8 JSON string (omitted when the length is zero) |
72//! -----------------------------------------------------
73mod error;
74pub use error::Error;
75type Result<T, E = error::Error> = std::result::Result<T, E>;
76use std::io::prelude::*;
77use std::{
78    collections::HashMap,
79    net::{TcpListener, TcpStream},
80};
81
82/// An RPC client
83/// This is the main struct that should be used for
84/// implementing an RPC client.
85pub struct Client {
86    stream: TcpStream,
87}
88
89impl Client {
90    /// Creates a new client that connects to an RPC server
91    /// # Arguments:
92    /// - `addr` The address the TCP client should connect to
93    ///     this should be in the form "host:port"
94    /// # Example:
95    /// ```rust
96    /// use rpc_toy::Client;
97    /// let client = Client::new("127.0.0.1:3001");
98    /// ```
99    pub fn new(addr: &str) -> Result<Self> {
100        Ok(Client {
101            stream: TcpStream::connect(addr)?,
102        })
103    }
104    /// Invokes an RPC, this is the mechanism to "call" functions
105    /// on a remote server
106    ///
107    /// # Arguments:
108    /// - `fn_name`: The name of the function to call
109    ///   NOTE: The server **MUST** have registered this function, otherwise
110    ///   (currently) expect weird stuff to happen :)
111    /// - `args` a slice of `serde_json::Value`s. This represents the arguments
112    ///   that will be passed onto the server's functions
113    /// # Returns:
114    /// - a `Result<Option<serde_json::Value>>>`, which is `Ok` if nothing errored out
115    ///   the `Option` will be `None` if this is a void function, otherwise it will be
116    ///   `Some(value)` where `value` is a `serde_json::Value` representing the return value
117    ///   of the function
118    /// # Example:
119    /// ```no_run
120    /// use rpc_toy::Client;
121    /// let mut client = Client::new("127.0.0.1:3001").unwrap();
122    /// let one = serde_json::to_value(1u32).unwrap();
123    /// let two = serde_json::to_value(2u32).unwrap();
124    /// let args = vec![one, two];
125    /// let res = client.call("Add", &args).unwrap();
126    /// let three: u32 = serde_json::from_value(res.unwrap()).unwrap();
127    /// assert_eq!(three, 3);
128    /// ```
129    pub fn call(
130        &mut self,
131        fn_name: &str,
132        args: &[serde_json::Value],
133    ) -> Result<Option<serde_json::Value>> {
134        let mut bytes = Vec::new();
135        let fn_name = fn_name.as_bytes();
136        bytes.extend_from_slice(&(fn_name.len() as u32).to_be_bytes());
137        bytes.extend_from_slice(fn_name);
138        for arg in args {
139            let arg = serde_json::to_string(&arg)?;
140            let arg = arg.as_bytes();
141            bytes.extend_from_slice(&(arg.len() as u32).to_be_bytes());
142            bytes.extend_from_slice(arg);
143        }
144        bytes.extend_from_slice(&(0u32).to_be_bytes());
145        self.stream.write_all(&bytes)?;
146        let mut response_len = [0; 4];
147        self.stream.read_exact(&mut response_len)?;
148        let response_len = u32::from_be_bytes(response_len);
149        if response_len == 0 {
150            // void function
151            return Ok(None);
152        }
153        let mut response = vec![0; response_len as usize];
154        self.stream.read_exact(&mut response)?;
155        let response = std::str::from_utf8(&response)?;
156        Ok(Some(serde_json::from_str(response)?))
157    }
158}
159
160use std::sync::Mutex;
161type RPCFn = Box<dyn Fn(&[serde_json::Value]) -> Option<serde_json::Value> + Send>;
162
163/// A struct representing an RPC server, this is to be
164/// used to implement the server.
165#[derive(Default)]
166pub struct Server {
167    // At the time this looks ugly and there is currently no
168    // way to alias function traits :(
169    fn_table: Mutex<HashMap<String, RPCFn>>,
170}
171
172impl Server {
173    /// Creates a new RPC server
174    pub fn new() -> Self {
175        Self {
176            fn_table: Mutex::new(HashMap::new()),
177        }
178    }
179
180    /// Registers functions to be used in the RPC server
181    /// only functions registered using this function can be
182    /// called from the client
183    ///
184    /// # Arguments:
185    /// - `fn_name` the name of the function to register, this **MUST**
186    ///    be the same name the client expects to use
187    /// - `function` the function to run once the RPC is invoked
188    ///
189    /// # Example:
190    /// ```
191    /// use rpc_toy::Server;
192    /// let mut server = Server::new();
193    /// server.register("Add", |args| {
194    ///     let one = args.get(0).unwrap();
195    ///     let two = args.get(1).unwrap();
196    ///     let one = serde_json::from_value::<u32>(one.clone()).unwrap();
197    ///     let two = serde_json::from_value::<u32>(two.clone()).unwrap();
198    ///     let three = one + two;
199    ///     return Some(serde_json::to_value(three).unwrap());
200    /// });
201    /// ```
202    pub fn register<F>(&mut self, fn_name: &str, function: F)
203    where
204        F: Fn(&[serde_json::Value]) -> Option<serde_json::Value> + 'static + Send,
205    {
206        self.fn_table
207            .lock()
208            .unwrap()
209            .insert(fn_name.to_string(), Box::new(function));
210    }
211
212    /// Listen to RPC connections
213    /// This function must be run in order to start listening
214    /// for calls over the network
215    ///
216    /// # Arguments:
217    /// - `addr`: An address to bind to, must be in the form:
218    ///     "host:port"
219    /// # Examples:
220    /// ```no_run
221    /// use rpc_toy::Server;
222    /// let mut server = Server::new();
223    /// server.listen("127.0.0.1:3001").unwrap();
224    /// ```
225    pub fn listen(&self, addr: &str) -> Result<()> {
226        let listener = TcpListener::bind(addr)?;
227        for incoming in listener.incoming() {
228            let stream = incoming?;
229            crossbeam::thread::scope(move |s| {
230                s.spawn(move |_| self.handle_client(stream));
231            })
232            .ok();
233        }
234        Ok(())
235    }
236
237    fn handle_client(&self, mut stream: TcpStream) -> Result<()> {
238        // We first read the lenght of the name of the function
239        // that should be encoded as a big endian 4 byte value
240        let mut fn_name_len = [0; 4];
241        while stream.read_exact(&mut fn_name_len).is_ok() {
242            let fn_name_len = u32::from_be_bytes(fn_name_len);
243            let mut fn_name = vec![0; fn_name_len as usize];
244            // We then read the name of the function as a utf8 formatted
245            // string
246            stream.read_exact(&mut fn_name)?;
247            let fn_name = std::str::from_utf8(&fn_name)?;
248            // We check if the server has a function of that name
249            // registered
250            if self.fn_table.lock().unwrap().contains_key(fn_name) {
251                // We read all the arguments
252                let mut args = Vec::new();
253                let mut arg_len = [0; 4];
254                stream.read_exact(&mut arg_len)?;
255                let mut arg_len = u32::from_be_bytes(arg_len);
256                while arg_len != 0 {
257                    let mut arg = vec![0; arg_len as usize];
258                    stream.read_exact(&mut arg)?;
259                    let arg_str = std::str::from_utf8(&arg)?;
260                    args.push(serde_json::from_str(arg_str)?);
261                    let mut arg_len_buff = [0; 4];
262                    stream.read_exact(&mut arg_len_buff)?;
263                    arg_len = u32::from_be_bytes(arg_len_buff);
264                }
265                // We call the function the server registered
266                let res = (self.fn_table.lock().unwrap().get(fn_name).unwrap())(&args);
267                match res {
268                    Some(res) => {
269                        let res_str = serde_json::to_string(&res)?;
270                        let res = res_str.as_bytes();
271                        stream.write_all(&(res.len() as u32).to_be_bytes())?;
272                        stream.write_all(res)?;
273                    }
274                    None => {
275                        stream.write_all(&(0 as u32).to_be_bytes())?;
276                    }
277                }
278            } else {
279                // TODO: Implement error handling
280                // The server should send back an error to the client
281                // letting it know that there is no function with that name
282                break;
283            }
284        }
285        Ok(())
286    }
287}