jupiter_rs/
commands.rs

1use std::collections::HashMap;
2use std::sync::{Arc, Mutex};
3
4use anyhow::anyhow;
5use bytes::BytesMut;
6
7use crate::platform::Platform;
8use crate::request::Request;
9use crate::response::{OutputError, Response};
10use crate::watch::{Average, Watch};
11
12#[derive(Debug)]
13pub enum CommandError {
14    OutputError(OutputError),
15    ClientError(anyhow::Error),
16    ServerError(anyhow::Error),
17}
18
/// Builds a [`CommandError::ServerError`] from a message or a format string.
///
/// Accepts the same forms as `anyhow::anyhow!`:
/// `server_error!(err)` or `server_error!("failed: {}", reason)`.
///
/// Each arm expands to a single expression, so the macro can be used anywhere an
/// expression is expected (e.g. `return Err(server_error!("x: {}", x));`).
/// `$crate` keeps the path valid both inside this crate and in dependent crates.
/// The calling crate must itself depend on `anyhow`.
#[macro_export]
macro_rules! server_error {
    ($err:expr $(,)?) => {
        $crate::commands::CommandError::ServerError(::anyhow::anyhow!($err))
    };
    ($fmt:expr, $($arg:tt)*) => {
        $crate::commands::CommandError::ServerError(::anyhow::anyhow!($fmt, $($arg)*))
    };
}
30
/// Builds a [`CommandError::ClientError`] from a message or a format string.
///
/// Accepts the same forms as `anyhow::anyhow!`:
/// `client_error!(err)` or `client_error!("bad argument: {}", arg)`.
///
/// Each arm expands to a single expression, so the macro can be used anywhere an
/// expression is expected (e.g. `return Err(client_error!("x: {}", x));`).
/// `$crate` keeps the path valid both inside this crate and in dependent crates.
/// The calling crate must itself depend on `anyhow`.
#[macro_export]
macro_rules! client_error {
    ($err:expr $(,)?) => {
        $crate::commands::CommandError::ClientError(::anyhow::anyhow!($err))
    };
    ($fmt:expr, $($arg:tt)*) => {
        $crate::commands::CommandError::ClientError(::anyhow::anyhow!($fmt, $($arg)*))
    };
}
42
43impl From<OutputError> for CommandError {
44    fn from(output_error: OutputError) -> Self {
45        CommandError::OutputError(output_error)
46    }
47}
48
49impl From<anyhow::Error> for CommandError {
50    fn from(error: anyhow::Error) -> Self {
51        CommandError::ClientError(error)
52    }
53}
54
55pub type CommandResult = std::result::Result<(), CommandError>;
56
57pub trait ResultExt {
58    fn complete(self, call: Call);
59}
60
61impl ResultExt for CommandResult {
62    fn complete(self, call: Call) {
63        call.complete(self);
64    }
65}
66
67pub struct Call {
68    pub request: Request,
69    pub response: Response,
70    pub token: usize,
71    callback: tokio::sync::oneshot::Sender<Result<BytesMut, OutputError>>,
72}
73
74impl Call {
75    pub fn complete(mut self, result: CommandResult) {
76        let result = match result {
77            Ok(_) => self.response.complete(),
78            Err(CommandError::OutputError(error)) => Err(error),
79            Err(CommandError::ClientError(error)) => {
80                if let Err(error) = self.response.error(&format!("CLIENT: {}", error)) {
81                    Err(error)
82                } else {
83                    self.response.complete()
84                }
85            }
86            Err(CommandError::ServerError(error)) => {
87                if let Err(error) = self.response.error(&format!("SERVER: {}", error)) {
88                    Err(error)
89                } else {
90                    self.response.complete()
91                }
92            }
93        };
94
95        if let Err(_) = self.callback.send(result) {
96            log::error!("Failed to submit a result to a oneshot callback channel!");
97        }
98    }
99}
100
101pub type Queue = tokio::sync::mpsc::Sender<Call>;
102pub type Endpoint = tokio::sync::mpsc::Receiver<Call>;
103
104pub fn queue() -> (Queue, Endpoint) {
105    tokio::sync::mpsc::channel(1024)
106}
107
108pub struct Command {
109    pub name: &'static str,
110    queue: Queue,
111    token: usize,
112    call_metrics: Average,
113}
114
115impl Command {
116    pub fn call_count(&self) -> i32 {
117        self.call_metrics.count() as i32
118    }
119
120    pub fn avg_duration(&self) -> i32 {
121        self.call_metrics.avg() as i32
122    }
123}
124
125pub struct CommandDictionary {
126    commands: Mutex<HashMap<&'static str, Arc<Command>>>,
127}
128
129pub struct Dispatcher {
130    commands: HashMap<&'static str, (Arc<Command>, Queue)>,
131}
132
133impl CommandDictionary {
134    pub fn new() -> Self {
135        CommandDictionary {
136            commands: Mutex::new(HashMap::default()),
137        }
138    }
139
140    pub fn install(platform: &Arc<Platform>) -> Arc<Self> {
141        let commands = Arc::new(CommandDictionary::new());
142        platform.register::<CommandDictionary>(commands.clone());
143
144        commands
145    }
146
147    pub fn register_command(&self, name: &'static str, queue: Queue, token: usize) {
148        let mut commands = self.commands.lock().unwrap();
149        if commands.get(name).is_some() {
150            log::error!("Not going to register command {} as there is already a command present for this name",
151                   name);
152        } else {
153            log::debug!("Registering command {}...", name);
154            commands.insert(
155                name,
156                Arc::new(Command {
157                    name,
158                    queue,
159                    token,
160                    call_metrics: Average::new(),
161                }),
162            );
163        }
164    }
165
166    pub fn commands(&self) -> Vec<Arc<Command>> {
167        let mut result = Vec::new();
168        for command in self.commands.lock().unwrap().values() {
169            result.push(command.clone());
170        }
171
172        return result;
173    }
174
175    pub fn dispatcher(&self) -> Dispatcher {
176        let commands = self.commands.lock().unwrap();
177        let mut cloned_commands = HashMap::with_capacity(commands.len());
178        for command in commands.values() {
179            cloned_commands.insert(command.name, (command.clone(), command.queue.clone()));
180        }
181
182        Dispatcher {
183            commands: cloned_commands,
184        }
185    }
186}
187
188impl Dispatcher {
189    pub async fn invoke(&mut self, request: Request) -> Result<BytesMut, OutputError> {
190        let mut response = Response::new();
191        match self.commands.get_mut(request.command()) {
192            Some((command, queue)) => {
193                Dispatcher::invoke_command(command, queue, request, response).await
194            }
195            _ => {
196                response.error(&format!("CLIENT: Unknown command: {}", request.command()))?;
197                Ok(response.complete()?)
198            }
199        }
200    }
201
202    async fn invoke_command(
203        command: &Arc<Command>,
204        queue: &mut Queue,
205        request: Request,
206        response: Response,
207    ) -> Result<BytesMut, OutputError> {
208        let (callback, promise) = tokio::sync::oneshot::channel();
209        let task = Call {
210            request,
211            response,
212            callback,
213            token: command.token,
214        };
215
216        let watch = Watch::start();
217        if let Err(_) = queue.send(task).await {
218            Err(OutputError::ProtocolError(anyhow!(
219                "Failed to submit command into queue!"
220            )))
221        } else {
222            match promise.await {
223                Ok(result) => {
224                    command.call_metrics.add(watch.micros());
225                    result
226                }
227                _ => Err(OutputError::ProtocolError(anyhow!(
228                    "Command {} did not yield any result!",
229                    command.name
230                ))),
231            }
232        }
233    }
234}
235
#[cfg(test)]
mod tests {
    use bytes::BytesMut;
    use num_derive::FromPrimitive;
    use num_traits::FromPrimitive;

    use crate::commands::{queue, Call, CommandDictionary, CommandError, CommandResult, ResultExt};
    use crate::request::Request;

    /// Handler for `PING`: replies with the simple string `PONG`.
    fn ping(task: &mut Call) -> CommandResult {
        task.response.simple("PONG")?;
        Ok(())
    }

    /// Handler for `TEST`: replies with the simple string `OK`.
    fn test(task: &mut Call) -> CommandResult {
        task.response.simple("OK")?;
        Ok(())
    }

    /// Tokens used to tell the test commands apart on the shared queue.
    #[derive(FromPrimitive)]
    enum TestCommands {
        Ping,
        Test,
    }

    /// Routes one call to its handler based on the token it was registered with.
    fn handle(mut call: Call) {
        match TestCommands::from_usize(call.token) {
            Some(TestCommands::Ping) => ping(&mut call).complete(call),
            Some(TestCommands::Test) => test(&mut call).complete(call),
            None => call.complete(Err(CommandError::ServerError(anyhow::anyhow!(
                "Unknown token received!"
            )))),
        }
    }

    /// Parses a complete RESP request, panicking if it is malformed or incomplete.
    fn parse(raw: &str) -> Request {
        Request::parse(&mut BytesMut::from(raw)).unwrap().unwrap()
    }

    #[test]
    fn a_command_can_be_executed() {
        tokio_test::block_on(async {
            let (queue, mut endpoint) = queue();
            tokio::spawn(async move {
                // Serve calls until every sender has been dropped.
                while let Some(call) = endpoint.recv().await {
                    handle(call);
                }
            });

            let commands = CommandDictionary::new();
            commands.register_command("PING", queue.clone(), TestCommands::Ping as usize);
            commands.register_command("TEST", queue.clone(), TestCommands::Test as usize);
            let mut dispatcher = commands.dispatcher();

            let cases = [
                ("*1\r\n$4\r\nPING\r\n", "+PONG\r\n"),
                ("*1\r\n$4\r\nTEST\r\n", "+OK\r\n"),
            ];
            for (raw, expected) in cases.iter() {
                let reply = dispatcher.invoke(parse(raw)).await.unwrap();
                assert_eq!(std::str::from_utf8(&reply[..]).unwrap(), *expected);
            }
        });
    }
}