1use std::collections::HashMap;
2use std::sync::{Arc, Mutex};
3
4use anyhow::anyhow;
5use bytes::BytesMut;
6
7use crate::platform::Platform;
8use crate::request::Request;
9use crate::response::{OutputError, Response};
10use crate::watch::{Average, Watch};
11
12#[derive(Debug)]
13pub enum CommandError {
14 OutputError(OutputError),
15 ClientError(anyhow::Error),
16 ServerError(anyhow::Error),
17}
18
/// Creates a `CommandError::ServerError` wrapping an `anyhow::Error`.
///
/// Accepts the same arguments as `anyhow::anyhow!`: a single expression
/// (`server_error!("Disk full")`) or a format string followed by its arguments
/// (`server_error!("Cannot open {}", path)`). Both forms expand to a single
/// expression, so the macro can be used wherever a value is expected, e.g.
/// `return Err(server_error!("Cannot open {}", path));`.
///
/// `$crate` is used so the macro works inside this crate as well as in crates that
/// depend on it (even under a renamed dependency). The calling crate needs `anyhow`
/// as a dependency.
#[macro_export]
macro_rules! server_error {
    ($err:expr $(,)?) => {
        $crate::commands::CommandError::ServerError(::anyhow::anyhow!($err))
    };
    ($fmt:expr, $($arg:tt)*) => {
        $crate::commands::CommandError::ServerError(::anyhow::anyhow!($fmt, $($arg)*))
    };
}
30
/// Creates a `CommandError::ClientError` wrapping an `anyhow::Error`.
///
/// Accepts the same arguments as `anyhow::anyhow!`: a single expression
/// (`client_error!("Missing parameter")`) or a format string followed by its arguments
/// (`client_error!("Unknown key: {}", key)`). Both forms expand to a single
/// expression, so the macro can be used wherever a value is expected, e.g.
/// `return Err(client_error!("Unknown key: {}", key));`.
///
/// `$crate` is used so the macro works inside this crate as well as in crates that
/// depend on it (even under a renamed dependency). The calling crate needs `anyhow`
/// as a dependency.
#[macro_export]
macro_rules! client_error {
    ($err:expr $(,)?) => {
        $crate::commands::CommandError::ClientError(::anyhow::anyhow!($err))
    };
    ($fmt:expr, $($arg:tt)*) => {
        $crate::commands::CommandError::ClientError(::anyhow::anyhow!($fmt, $($arg)*))
    };
}
42
43impl From<OutputError> for CommandError {
44 fn from(output_error: OutputError) -> Self {
45 CommandError::OutputError(output_error)
46 }
47}
48
49impl From<anyhow::Error> for CommandError {
50 fn from(error: anyhow::Error) -> Self {
51 CommandError::ClientError(error)
52 }
53}
54
55pub type CommandResult = std::result::Result<(), CommandError>;
56
57pub trait ResultExt {
58 fn complete(self, call: Call);
59}
60
61impl ResultExt for CommandResult {
62 fn complete(self, call: Call) {
63 call.complete(self);
64 }
65}
66
67pub struct Call {
68 pub request: Request,
69 pub response: Response,
70 pub token: usize,
71 callback: tokio::sync::oneshot::Sender<Result<BytesMut, OutputError>>,
72}
73
74impl Call {
75 pub fn complete(mut self, result: CommandResult) {
76 let result = match result {
77 Ok(_) => self.response.complete(),
78 Err(CommandError::OutputError(error)) => Err(error),
79 Err(CommandError::ClientError(error)) => {
80 if let Err(error) = self.response.error(&format!("CLIENT: {}", error)) {
81 Err(error)
82 } else {
83 self.response.complete()
84 }
85 }
86 Err(CommandError::ServerError(error)) => {
87 if let Err(error) = self.response.error(&format!("SERVER: {}", error)) {
88 Err(error)
89 } else {
90 self.response.complete()
91 }
92 }
93 };
94
95 if let Err(_) = self.callback.send(result) {
96 log::error!("Failed to submit a result to a oneshot callback channel!");
97 }
98 }
99}
100
101pub type Queue = tokio::sync::mpsc::Sender<Call>;
102pub type Endpoint = tokio::sync::mpsc::Receiver<Call>;
103
104pub fn queue() -> (Queue, Endpoint) {
105 tokio::sync::mpsc::channel(1024)
106}
107
108pub struct Command {
109 pub name: &'static str,
110 queue: Queue,
111 token: usize,
112 call_metrics: Average,
113}
114
115impl Command {
116 pub fn call_count(&self) -> i32 {
117 self.call_metrics.count() as i32
118 }
119
120 pub fn avg_duration(&self) -> i32 {
121 self.call_metrics.avg() as i32
122 }
123}
124
125pub struct CommandDictionary {
126 commands: Mutex<HashMap<&'static str, Arc<Command>>>,
127}
128
129pub struct Dispatcher {
130 commands: HashMap<&'static str, (Arc<Command>, Queue)>,
131}
132
133impl CommandDictionary {
134 pub fn new() -> Self {
135 CommandDictionary {
136 commands: Mutex::new(HashMap::default()),
137 }
138 }
139
140 pub fn install(platform: &Arc<Platform>) -> Arc<Self> {
141 let commands = Arc::new(CommandDictionary::new());
142 platform.register::<CommandDictionary>(commands.clone());
143
144 commands
145 }
146
147 pub fn register_command(&self, name: &'static str, queue: Queue, token: usize) {
148 let mut commands = self.commands.lock().unwrap();
149 if commands.get(name).is_some() {
150 log::error!("Not going to register command {} as there is already a command present for this name",
151 name);
152 } else {
153 log::debug!("Registering command {}...", name);
154 commands.insert(
155 name,
156 Arc::new(Command {
157 name,
158 queue,
159 token,
160 call_metrics: Average::new(),
161 }),
162 );
163 }
164 }
165
166 pub fn commands(&self) -> Vec<Arc<Command>> {
167 let mut result = Vec::new();
168 for command in self.commands.lock().unwrap().values() {
169 result.push(command.clone());
170 }
171
172 return result;
173 }
174
175 pub fn dispatcher(&self) -> Dispatcher {
176 let commands = self.commands.lock().unwrap();
177 let mut cloned_commands = HashMap::with_capacity(commands.len());
178 for command in commands.values() {
179 cloned_commands.insert(command.name, (command.clone(), command.queue.clone()));
180 }
181
182 Dispatcher {
183 commands: cloned_commands,
184 }
185 }
186}
187
188impl Dispatcher {
189 pub async fn invoke(&mut self, request: Request) -> Result<BytesMut, OutputError> {
190 let mut response = Response::new();
191 match self.commands.get_mut(request.command()) {
192 Some((command, queue)) => {
193 Dispatcher::invoke_command(command, queue, request, response).await
194 }
195 _ => {
196 response.error(&format!("CLIENT: Unknown command: {}", request.command()))?;
197 Ok(response.complete()?)
198 }
199 }
200 }
201
202 async fn invoke_command(
203 command: &Arc<Command>,
204 queue: &mut Queue,
205 request: Request,
206 response: Response,
207 ) -> Result<BytesMut, OutputError> {
208 let (callback, promise) = tokio::sync::oneshot::channel();
209 let task = Call {
210 request,
211 response,
212 callback,
213 token: command.token,
214 };
215
216 let watch = Watch::start();
217 if let Err(_) = queue.send(task).await {
218 Err(OutputError::ProtocolError(anyhow!(
219 "Failed to submit command into queue!"
220 )))
221 } else {
222 match promise.await {
223 Ok(result) => {
224 command.call_metrics.add(watch.micros());
225 result
226 }
227 _ => Err(OutputError::ProtocolError(anyhow!(
228 "Command {} did not yield any result!",
229 command.name
230 ))),
231 }
232 }
233 }
234}
235
#[cfg(test)]
mod tests {
    use bytes::BytesMut;
    use num_derive::FromPrimitive;
    use num_traits::FromPrimitive;

    use crate::commands::{queue, Call, CommandDictionary, CommandError, CommandResult, ResultExt};
    use crate::request::Request;

    /// Test command answering with the simple string `PONG`.
    fn ping(task: &mut Call) -> CommandResult {
        task.response.simple("PONG")?;
        Ok(())
    }

    /// Test command answering with the simple string `OK`.
    fn test(task: &mut Call) -> CommandResult {
        task.response.simple("OK")?;
        Ok(())
    }

    /// Tokens identifying the test commands, which all share a single queue.
    #[derive(FromPrimitive)]
    enum TestCommands {
        Ping,
        Test,
    }

    /// Parses `raw` into a request, panicking if parsing fails or yields no request.
    fn parse_request(raw: &str) -> Request {
        Request::parse(&mut BytesMut::from(raw)).unwrap().unwrap()
    }

    #[test]
    fn a_command_can_be_executed() {
        tokio_test::block_on(async {
            let (sender, mut endpoint) = queue();

            // Serves all test commands from one endpoint, dispatching on the call's token.
            tokio::spawn(async move {
                while let Some(mut call) = endpoint.recv().await {
                    match TestCommands::from_usize(call.token) {
                        Some(TestCommands::Ping) => ping(&mut call).complete(call),
                        Some(TestCommands::Test) => test(&mut call).complete(call),
                        None => call.complete(Err(CommandError::ServerError(anyhow::anyhow!(
                            "Unknown token received!"
                        )))),
                    }
                }
            });

            let commands = CommandDictionary::new();
            commands.register_command("PING", sender.clone(), TestCommands::Ping as usize);
            commands.register_command("TEST", sender.clone(), TestCommands::Test as usize);
            let mut dispatcher = commands.dispatcher();

            let pong = dispatcher
                .invoke(parse_request("*1\r\n$4\r\nPING\r\n"))
                .await
                .unwrap();
            assert_eq!(std::str::from_utf8(&pong[..]).unwrap(), "+PONG\r\n");

            let ok = dispatcher
                .invoke(parse_request("*1\r\n$4\r\nTEST\r\n"))
                .await
                .unwrap();
            assert_eq!(std::str::from_utf8(&ok[..]).unwrap(), "+OK\r\n");
        });
    }
}