goose/
controller.rs

1//! Optional telnet and WebSocket Controller threads.
2//!
3//! By default, Goose launches both a telnet Controller and a WebSocket Controller, allowing
4//! real-time control of the running load test.
5
6use crate::config::GooseConfiguration;
7use crate::metrics::GooseMetrics;
8use crate::test_plan::{TestPlan, TestPlanHistory, TestPlanStepAction};
9use crate::util;
10use crate::{AttackPhase, GooseAttack, GooseAttackRunState, GooseError};
11
12use async_trait::async_trait;
13use futures::{SinkExt, StreamExt};
14use regex::{Regex, RegexSet};
15use serde::{Deserialize, Serialize};
16use std::io::{self, Write};
17use std::str::{self, FromStr};
18use std::time::Duration;
19use strum::IntoEnumIterator;
20use strum_macros::EnumIter;
21use tokio::io::{AsyncReadExt, AsyncWriteExt};
22use tokio::net::TcpListener;
23use tokio_tungstenite::tungstenite::Message;
24
25/// All commands recognized by the Goose Controllers.
26///
27/// Commands are not case sensitive. When sending commands to the WebSocket Controller,
28/// they must be formatted as json as defined by
29/// [ControllerWebSocketRequest](./struct.ControllerWebSocketRequest.html).
30///
31/// GOOSE DEVELOPER NOTE: The following steps are required to add a new command:
32///  1. Define the new command here in the ControllerCommand enum.
33///      - Commands will be displayed in the help screen in the order defined here, so
34///        they should be logically grouped.
35///  2. Add the new command to `ControllerCommand::details` and populate all
36///     `ControllerCommandDetails`, using other commands as an implementation reference.
37///       - The `regex` is used to identify the command, and optionally to extract a
38///         value (for example see `Hatchrate` and `Users`)
39///       - If additional validation is required beyond the regular expression, add
40///         the necessary logic to `ControllerCommand::validate_value`.
41///  3. Add any necessary parent process logic for the command to
42///     `GooseAttack::handle_controller_requests` (also in this file).
43///  4. Add a test for the new command in tests/controller.rs.
44#[derive(Clone, Debug, EnumIter, PartialEq, Eq)]
45pub enum ControllerCommand {
46    /// Displays a list of all commands supported by the Controller.
47    ///
48    /// # Example
49    /// Returns the a list of all supported Controller commands.
50    /// ```notest
51    /// help
52    /// ```
53    ///
54    /// This command can be run at any time.
55    Help,
56    /// Disconnect from the Controller.
57    ///
58    /// # Example
59    /// Disconnects from the Controller.
60    /// ```notest
61    /// exit
62    /// ```
63    ///
64    /// This command can be run at any time.
65    Exit,
66    /// Start an idle test.
67    ///
68    /// # Example
69    /// Starts an idle load test.
70    /// ```notest
71    /// start
72    /// ```
73    ///
74    /// Goose must be idle to process this command.
75    Start,
76    /// Stop a running test, putting it into an idle state.
77    ///
78    /// # Example
79    /// Stops a running (or stating) load test.
80    /// ```notest
81    /// stop
82    /// ```
83    ///
84    /// Goose must be running (or starting) to process this command.
85    Stop,
86    /// Tell the load test to shut down (which will disconnect the controller).
87    ///
88    /// # Example
89    /// Terminates the Goose process, cleanly shutting down the load test if running.
90    /// ```notest
91    /// shutdown
92    /// ```
93    ///
94    /// Goose can process this command at any time.
95    Shutdown,
96    /// Configure the host to load test.
97    ///
98    /// # Example
99    /// Tells Goose to generate load against <http://example.com/>.
100    /// ```notest
101    /// host http://example.com/
102    /// ```
103    ///
104    /// Goose must be idle to process this command.
105    Host,
106    /// Configure how quickly new [`GooseUser`](../goose/struct.GooseUser.html)s are launched.
107    ///
108    /// # Example
109    /// Tells Goose to launch a new user every 1.25 seconds.
110    /// ```notest
111    /// hatchrate 1.25
112    /// ```
113    ///
114    /// Goose can be idle or running when processing this command.
115    HatchRate,
116    /// Configure how long to take to launch all [`GooseUser`](../goose/struct.GooseUser.html)s.
117    ///
118    /// # Example
119    /// Tells Goose to launch a new user every 1.25 seconds.
120    /// ```notest
121    /// startuptime 1.25
122    /// ```
123    ///
124    /// Goose must be idle to process this command.
125    StartupTime,
126    /// Configure how many [`GooseUser`](../goose/struct.GooseUser.html)s to launch.
127    ///
128    /// # Example
129    /// Tells Goose to simulate 100 concurrent users.
130    /// ```notest
131    /// users 100
132    /// ```
133    ///
134    /// Can be configured on an idle or running load test.
135    Users,
136    /// Configure how long the load test should run before stopping and returning to an idle state.
137    ///
138    /// # Example
139    /// Tells Goose to run the load test for 1 minute, before automatically stopping.
140    /// ```notest
141    /// runtime 60
142    /// ```
143    ///
144    /// This can be configured when Goose is idle as well as when a Goose load test is running.
145    RunTime,
146    /// Define a load test plan. This will replace the previously configured test plan, if any.
147    ///
148    /// # Example
149    /// Tells Goose to launch 10 users in 5 seconds, maintain them for 30 seconds, and then spend 5 seconds
150    /// stopping them.
151    /// ```notest
152    /// testplan "10,5s;10,30s;0,5s"
153    /// ```
154    ///
155    /// Can be configured on an idle or running load test.
156    TestPlan,
157    /// Display the current [`GooseConfiguration`](../struct.GooseConfiguration.html)s.
158    ///
159    /// # Example
160    /// Returns the current Goose configuration.
161    /// ```notest
162    /// config
163    /// ```
164    Config,
165    /// Display the current [`GooseConfiguration`](../struct.GooseConfiguration.html)s in json format.
166    ///
167    /// # Example
168    /// Returns the current Goose configuration in json format.
169    /// ```notest
170    /// configjson
171    /// ```
172    ///
173    /// This command can be run at any time.
174    ConfigJson,
175    /// Display the current [`GooseMetric`](../metrics/struct.GooseMetrics.html)s.
176    ///
177    /// # Example
178    /// Returns the current Goose metrics.
179    /// ```notest
180    /// metrics
181    /// ```
182    ///
183    /// This command can be run at any time.
184    Metrics,
185    /// Display the current [`GooseMetric`](../metrics/struct.GooseMetrics.html)s in json format.
186    ///
187    /// # Example
188    /// Returns the current Goose metrics in json format.
189    /// ```notest
190    /// metricsjson
191    /// ```
192    ///
193    /// This command can be run at any time.
194    MetricsJson,
195}
196
197/// Defines details around identifying and processing ControllerCommands.
198///
199/// As order doesn't matter here, it's preferred to define commands in alphabetical order.
200impl ControllerCommand {
201    /// Each ControllerCommand must define complete ControllerCommentDetails.
202    ///  - `help` returns a ControllerHelp struct that is used to provide inline help.
203    ///  - `regex` returns an &str defining the regular expression used to match the command
204    ///    (and optionally to grab a value being set).
205    ///  - `process_response` contains a boxed closure that receives the parent process
206    ///    response to the command and responds to the controller appropriately.
207    fn details(&self) -> ControllerCommandDetails<'_> {
208        match self {
209            ControllerCommand::Config => ControllerCommandDetails {
210                help: ControllerHelp {
211                    name: "config",
212                    description: "display load test configuration\n",
213                },
214                regex: r"(?i)^config$",
215                process_response: Box::new(|response| {
216                    if let ControllerResponseMessage::Config(config) = response {
217                        Ok(format!("{config:#?}"))
218                    } else {
219                        Err("error loading configuration".to_string())
220                    }
221                }),
222            },
223            ControllerCommand::ConfigJson => ControllerCommandDetails {
224                help: ControllerHelp {
225                    name: "config-json",
226                    description: "display load test configuration in json format\n",
227                },
228                regex: r"(?i)^(configjson|config-json)$",
229                process_response: Box::new(|response| {
230                    if let ControllerResponseMessage::Config(config) = response {
231                        Ok(serde_json::to_string(&config).expect("unexpected serde failure"))
232                    } else {
233                        Err("error loading configuration".to_string())
234                    }
235                }),
236            },
237            ControllerCommand::Exit => ControllerCommandDetails {
238                help: ControllerHelp {
239                    name: "exit",
240                    description: "exit controller\n\n",
241                },
242                regex: r"(?i)^(exit|quit|q)$",
243                process_response: Box::new(|_| {
244                    let e = "received an impossible EXIT command";
245                    error!("{e}");
246                    Err(e.to_string())
247                }),
248            },
249            ControllerCommand::HatchRate => ControllerCommandDetails {
250                help: ControllerHelp {
251                    name: "hatchrate FLOAT",
252                    description: "set per-second rate users hatch\n",
253                },
254                regex: r"(?i)^(hatchrate|hatch_rate|hatch-rate) ([0-9]*(\.[0-9]*)?){1}$",
255                process_response: Box::new(|response| {
256                    if let ControllerResponseMessage::Bool(true) = response {
257                        Ok("hatch_rate configured".to_string())
258                    } else {
259                        Err("failed to configure hatch_rate".to_string())
260                    }
261                }),
262            },
263            ControllerCommand::Help => ControllerCommandDetails {
264                help: ControllerHelp {
265                    name: "help",
266                    description: "this help\n",
267                },
268                regex: r"(?i)^(help|\?)$",
269                process_response: Box::new(|_| {
270                    let e = "received an impossible HELP command";
271                    error!("{e}");
272                    Err(e.to_string())
273                }),
274            },
275            ControllerCommand::Host => ControllerCommandDetails {
276                help: ControllerHelp {
277                    name: "host HOST",
278                    description: "set host to load test, (ie https://web.site/)\n",
279                },
280                regex: r"(?i)^(host|hostname|host_name|host-name) ((https?)://.+)$",
281                process_response: Box::new(|response| {
282                    if let ControllerResponseMessage::Bool(true) = response {
283                        Ok("host configured".to_string())
284                    } else {
285                        Err("failed to reconfigure host, be sure host is valid and load test is idle".to_string())
286                    }
287                }),
288            },
289            ControllerCommand::Metrics => ControllerCommandDetails {
290                help: ControllerHelp {
291                    name: "metrics",
292                    description: "display metrics for current load test\n",
293                },
294                regex: r"(?i)^(metrics|stats)$",
295                process_response: Box::new(|response| {
296                    if let ControllerResponseMessage::Metrics(metrics) = response {
297                        Ok(metrics.to_string())
298                    } else {
299                        Err("error loading metrics".to_string())
300                    }
301                }),
302            },
303            ControllerCommand::MetricsJson => {
304                ControllerCommandDetails {
305                    help: ControllerHelp {
306                        name: "metrics-json",
307                        // No new-line as this is the last line of the help screen.
308                        description: "display metrics for current load test in json format",
309                    },
310                    regex: r"(?i)^(metricsjson|metrics-json|statsjson|stats-json)$",
311                    process_response: Box::new(|response| {
312                        if let ControllerResponseMessage::Metrics(metrics) = response {
313                            Ok(serde_json::to_string(&metrics).expect("unexpected serde failure"))
314                        } else {
315                            Err("error loading metrics".to_string())
316                        }
317                    }),
318                }
319            }
320            ControllerCommand::RunTime => ControllerCommandDetails {
321                help: ControllerHelp {
322                    name: "runtime TIME",
323                    description: "set how long to run test, (ie 1h30m5s)\n",
324                },
325                regex: r"(?i)^(run|runtime|run_time|run-time|) (\d+|((\d+?)h)?((\d+?)m)?((\d+?)s)?)$",
326                process_response: Box::new(|response| {
327                    if let ControllerResponseMessage::Bool(true) = response {
328                        Ok("run_time configured".to_string())
329                    } else {
330                        Err("failed to configure run_time".to_string())
331                    }
332                }),
333            },
334            ControllerCommand::Shutdown => ControllerCommandDetails {
335                help: ControllerHelp {
336                    name: "shutdown",
337                    description: "shutdown load test and exit controller\n\n",
338                },
339                regex: r"(?i)^shutdown$",
340                process_response: Box::new(|response| {
341                    if let ControllerResponseMessage::Bool(true) = response {
342                        Ok("load test shut down".to_string())
343                    } else {
344                        Err("failed to shut down load test".to_string())
345                    }
346                }),
347            },
348            ControllerCommand::Start => {
349                ControllerCommandDetails {
350                    help: ControllerHelp {
351                        name: "start",
352                        description: "start an idle load test\n",
353                    },
354                    regex: r"(?i)^start$",
355                    process_response: Box::new(|response| {
356                        if let ControllerResponseMessage::Bool(true) = response {
357                            Ok("load test started".to_string())
358                        } else {
359                            Err("unable to start load test, be sure it is idle and host is configured".to_string())
360                        }
361                    }),
362                }
363            }
364            ControllerCommand::StartupTime => ControllerCommandDetails {
365                help: ControllerHelp {
366                    name: "startup-time TIME",
367                    description: "set total time to take starting users\n",
368                },
369                regex: r"(?i)^(starttime|start_time|start-time|startup|startuptime|startup_time|startup-time) (\d+|((\d+?)h)?((\d+?)m)?((\d+?)s)?)$",
370                process_response: Box::new(|response| {
371                    if let ControllerResponseMessage::Bool(true) = response {
372                        Ok("startup_time configured".to_string())
373                    } else {
374                        Err(
375                            "failed to configure startup_time, be sure load test is idle"
376                                .to_string(),
377                        )
378                    }
379                }),
380            },
381            ControllerCommand::Stop => ControllerCommandDetails {
382                help: ControllerHelp {
383                    name: "stop",
384                    description: "stop a running load test and return to idle state\n",
385                },
386                regex: r"(?i)^stop$",
387                process_response: Box::new(|response| {
388                    if let ControllerResponseMessage::Bool(true) = response {
389                        Ok("load test stopped".to_string())
390                    } else {
391                        Err("load test not running, failed to stop".to_string())
392                    }
393                }),
394            },
395            ControllerCommand::TestPlan => ControllerCommandDetails {
396                help: ControllerHelp {
397                    name: "test-plan PLAN",
398                    description: "define or replace test-plan, (ie 10,5m;10,1h;0,30s)\n\n",
399                },
400                regex: r"(?i)^(testplan|test_plan|test-plan|plan) (((\d+)\s*,\s*(\d+|((\d+?)h)?((\d+?)m)?((\d+?)s)?)*;*)+)$",
401                process_response: Box::new(|response| {
402                    if let ControllerResponseMessage::Bool(true) = response {
403                        Ok("test-plan configured".to_string())
404                    } else {
405                        Err("failed to configure test-plan, be sure test-plan is valid".to_string())
406                    }
407                }),
408            },
409            ControllerCommand::Users => ControllerCommandDetails {
410                help: ControllerHelp {
411                    name: "users INT",
412                    description: "set number of simulated users\n",
413                },
414                regex: r"(?i)^(users?) (\d+)$",
415                process_response: Box::new(|response| {
416                    if let ControllerResponseMessage::Bool(true) = response {
417                        Ok("users configured".to_string())
418                    } else {
419                        Err("load test not idle, failed to reconfigure users".to_string())
420                    }
421                }),
422            },
423        }
424    }
425
426    // Optionally perform validation beyond what is possible with a regular expression.
427    //
428    // Return Some(value) if the value is valid, otherwise return None.
429    fn validate_value(&self, value: &str) -> Option<String> {
430        // The Regex that captures the host only validates that the host starts with
431        // http:// or https://. Now use a library to properly validate that this is
432        // a valid host before sending to the parent process.
433        if self == &ControllerCommand::Host {
434            if util::is_valid_host(value).is_ok() {
435                Some(value.to_string())
436            } else {
437                None
438            }
439        } else if value.is_empty() {
440            None
441        } else {
442            Some(value.to_string())
443        }
444    }
445
446    // If the regular expression that matches this command also matches a value, get and validate
447    // the value.
448    //
449    // Returns Some(value) if the value is valid, otherwise returns None.
450    fn get_value(&self, command_string: &str) -> Option<String> {
451        let regex = Regex::new(self.details().regex)
452            .expect("ControllerCommand::details().regex returned invalid regex [2]");
453        let caps = regex.captures(command_string).unwrap();
454        let value = caps.get(2).map_or("", |m| m.as_str());
455        self.validate_value(value)
456    }
457
458    // Builds a help screen displayed when a controller receives the `help` command.
459    fn display_help() -> String {
460        let mut help_text = Vec::new();
461        writeln!(
462            &mut help_text,
463            "{} {} controller commands:",
464            env!("CARGO_PKG_NAME"),
465            env!("CARGO_PKG_VERSION")
466        )
467        .expect("failed to write to buffer");
468        // Builds help screen in the order commands are defined in the ControllerCommand enum.
469        for command in ControllerCommand::iter() {
470            write!(
471                &mut help_text,
472                "{:<18} {}",
473                command.details().help.name,
474                command.details().help.description
475            )
476            .expect("failed to write to buffer");
477        }
478        String::from_utf8(help_text).expect("invalid utf-8 in help text")
479    }
480}
481
482/// The parent process side of the Controller functionality.
483impl GooseAttack {
484    /// Handle Controller requests.
485    pub(crate) async fn handle_controller_requests(
486        &mut self,
487        goose_attack_run_state: &mut GooseAttackRunState,
488    ) -> Result<(), GooseError> {
489        // If the controller is enabled, check if we've received any
490        // messages.
491        if let Some(c) = goose_attack_run_state.controller_channel_rx.as_ref() {
492            match c.try_recv() {
493                Ok(message) => {
494                    info!(
495                        "request from controller client {}: {:?}",
496                        message.client_id, message.request
497                    );
498                    // As order is not important here, commands should be defined in alphabetical order.
499                    match &message.request.command {
500                        // Send back a copy of the running configuration.
501                        ControllerCommand::Config | ControllerCommand::ConfigJson => {
502                            self.reply_to_controller(
503                                message,
504                                ControllerResponseMessage::Config(Box::new(
505                                    self.configuration.clone(),
506                                )),
507                            );
508                        }
509                        // Send back a copy of the running metrics.
510                        ControllerCommand::Metrics | ControllerCommand::MetricsJson => {
511                            self.reply_to_controller(
512                                message,
513                                ControllerResponseMessage::Metrics(Box::new(self.metrics.clone())),
514                            );
515                        }
516                        // Start the load test, and acknowledge command.
517                        ControllerCommand::Start => {
518                            // We can only start an idle load test.
519                            if self.attack_phase == AttackPhase::Idle {
520                                self.test_plan = TestPlan::build(&self.configuration);
521                                if self.prepare_load_test().is_ok() {
522                                    // Rebuild test plan in case any parameters have been changed.
523                                    self.set_attack_phase(
524                                        goose_attack_run_state,
525                                        AttackPhase::Increase,
526                                    );
527                                    self.reply_to_controller(
528                                        message,
529                                        ControllerResponseMessage::Bool(true),
530                                    );
531                                    // Reset the run state when starting a new load test.
532                                    self.reset_run_state(goose_attack_run_state).await?;
533                                    self.metrics.history.push(TestPlanHistory::step(
534                                        TestPlanStepAction::Increasing,
535                                        0,
536                                    ));
537                                } else {
538                                    // Do not move to Starting phase if unable to prepare load test.
539                                    self.reply_to_controller(
540                                        message,
541                                        ControllerResponseMessage::Bool(false),
542                                    );
543                                }
544                            } else {
545                                self.reply_to_controller(
546                                    message,
547                                    ControllerResponseMessage::Bool(false),
548                                );
549                            }
550                        }
551                        // Stop the load test, and acknowledge command.
552                        ControllerCommand::Stop => {
553                            // We can only stop a starting or running load test.
554                            if [AttackPhase::Increase, AttackPhase::Maintain]
555                                .contains(&self.attack_phase)
556                            {
557                                // Don't shutdown when load test is stopped by controller, remain idle instead.
558                                goose_attack_run_state.shutdown_after_stop = false;
559                                // Don't automatically restart the load test.
560                                self.configuration.no_autostart = true;
561                                self.cancel_attack(goose_attack_run_state).await?;
562                                self.reply_to_controller(
563                                    message,
564                                    ControllerResponseMessage::Bool(true),
565                                );
566                            } else {
567                                self.reply_to_controller(
568                                    message,
569                                    ControllerResponseMessage::Bool(false),
570                                );
571                            }
572                        }
573                        // Stop the load test, and acknowledge request.
574                        ControllerCommand::Shutdown => {
575                            // If load test is Idle, there are no metrics to display.
576                            if self.attack_phase == AttackPhase::Idle {
577                                self.metrics.display_metrics = false;
578                                self.set_attack_phase(
579                                    goose_attack_run_state,
580                                    AttackPhase::Decrease,
581                                );
582                            } else {
583                                self.cancel_attack(goose_attack_run_state).await?;
584                            }
585
586                            // Shutdown after stopping.
587                            goose_attack_run_state.shutdown_after_stop = true;
588                            // Confirm shut down to Controller.
589                            self.reply_to_controller(
590                                message,
591                                ControllerResponseMessage::Bool(true),
592                            );
593
594                            // Give the controller thread time to send the response.
595                            tokio::time::sleep(Duration::from_millis(250)).await;
596                        }
597                        ControllerCommand::Host => {
598                            if self.attack_phase == AttackPhase::Idle {
599                                // The controller uses a regular expression to validate that
600                                // this is a valid hostname, so simply use it with further
601                                // validation.
602                                if let Some(host) = &message.request.value {
603                                    info!(
604                                        "changing host from {:?} to {}",
605                                        self.configuration.host, host
606                                    );
607                                    self.configuration.host = host.to_string();
608                                    self.reply_to_controller(
609                                        message,
610                                        ControllerResponseMessage::Bool(true),
611                                    );
612                                } else {
613                                    debug!(
614                                        "controller didn't provide host: {:#?}",
615                                        &message.request
616                                    );
617                                    self.reply_to_controller(
618                                        message,
619                                        ControllerResponseMessage::Bool(false),
620                                    );
621                                }
622                            } else {
623                                self.reply_to_controller(
624                                    message,
625                                    ControllerResponseMessage::Bool(false),
626                                );
627                            }
628                        }
629                        ControllerCommand::Users => {
630                            // The controller uses a regular expression to validate that
631                            // this is a valid integer, so simply use it with further
632                            // validation.
633                            if let Some(users) = &message.request.value {
634                                // Use expect() as Controller uses regex to validate this is an integer.
635                                let new_users = usize::from_str(users)
636                                    .expect("failed to convert string to usize");
637                                // If setting users, any existing configuration for a test plan isn't valid.
638                                self.configuration.test_plan = None;
639
640                                match self.attack_phase {
641                                    // If the load test is idle, simply update the configuration.
642                                    AttackPhase::Idle => {
643                                        let current_users = if !self.test_plan.steps.is_empty() {
644                                            self.test_plan.steps[self.test_plan.current].0
645                                        } else {
646                                            self.configuration.users.unwrap_or_default()
647                                        };
648                                        info!(
649                                            "changing users from {current_users:?} to {new_users}"
650                                        );
651                                        self.configuration.users = Some(new_users);
652                                        self.reply_to_controller(
653                                            message,
654                                            ControllerResponseMessage::Bool(true),
655                                        );
656                                    }
657                                    // If the load test is running, rebuild the active test plan.
658                                    AttackPhase::Increase
659                                    | AttackPhase::Decrease
660                                    | AttackPhase::Maintain => {
661                                        info!(
662                                            "changing users from {} to {new_users}",
663                                            goose_attack_run_state.active_users
664                                        );
665                                        // Determine how long has elapsed since this step started.
666                                        let elapsed = self.step_elapsed() as usize;
667
668                                        // Determine how quickly to adjust user account.
669                                        let hatch_rate = if let Some(hatch_rate) =
670                                            self.configuration.hatch_rate.as_ref()
671                                        {
672                                            util::get_hatch_rate(Some(hatch_rate.to_string()))
673                                        } else {
674                                            util::get_hatch_rate(None)
675                                        };
676                                        // Convert hatch_rate to milliseconds.
677                                        let ms_hatch_rate = 1.0 / hatch_rate * 1_000.0;
678                                        // Determine how many users to increase or decrease by.
679                                        let user_difference = (goose_attack_run_state.active_users
680                                            as isize
681                                            - new_users as isize)
682                                            .abs();
683                                        // Multiply the user difference by the hatch rate to get the total_time required.
684                                        let total_time =
685                                            (ms_hatch_rate * user_difference as f32) as usize;
686
687                                        // Reset the test_plan to adjust to the newly specified users.
688                                        self.test_plan.steps = vec![
689                                            // Record how many active users there are currently.
690                                            (goose_attack_run_state.active_users, elapsed),
691                                            // Configure the new user count.
692                                            (new_users, total_time),
693                                        ];
694
695                                        // Reset the current step to what was happening when reconfiguration happened.
696                                        self.test_plan.current = 0;
697
698                                        // Allocate more users if increasing users.
699                                        if new_users > goose_attack_run_state.active_users {
700                                            self.weighted_users = self
701                                                .weight_scenario_users(user_difference as usize)?;
702                                        }
703
704                                        // Also update the running configurtion (this impacts if the test is stopped and then
705                                        // restarted through the controller).
706                                        self.configuration.users = Some(new_users);
707
708                                        // Finally, advance to the next step to adjust user count.
709                                        self.advance_test_plan(goose_attack_run_state);
710
711                                        self.reply_to_controller(
712                                            message,
713                                            ControllerResponseMessage::Bool(true),
714                                        );
715                                    }
716                                    _ => {
717                                        self.reply_to_controller(
718                                            message,
719                                            ControllerResponseMessage::Bool(false),
720                                        );
721                                    }
722                                }
723                            } else {
724                                warn!(
725                                    "[controller]: didn't provide users: {:#?}",
726                                    &message.request
727                                );
728                                self.reply_to_controller(
729                                    message,
730                                    ControllerResponseMessage::Bool(false),
731                                );
732                            }
733                        }
734                        ControllerCommand::HatchRate => {
735                            // The controller uses a regular expression to validate that
736                            // this is a valid float, so simply use it with further
737                            // validation.
738                            if let Some(hatch_rate) = &message.request.value {
739                                // If startup_time was already set, unset it first.
740                                if !self.configuration.startup_time.is_empty() {
741                                    info!(
742                                        "resetting startup_time from {} to 0",
743                                        self.configuration.startup_time
744                                    );
745                                    self.configuration.startup_time = "0".to_string();
746                                }
747                                info!(
748                                    "changing hatch_rate from {:?} to {}",
749                                    self.configuration.hatch_rate, hatch_rate
750                                );
751                                self.configuration.hatch_rate = Some(hatch_rate.clone());
752                                // If setting hatch_rate, any existing configuration for a test plan isn't valid.
753                                self.configuration.test_plan = None;
754                                self.reply_to_controller(
755                                    message,
756                                    ControllerResponseMessage::Bool(true),
757                                );
758                            } else {
759                                warn!(
760                                    "Controller didn't provide hatch_rate: {:#?}",
761                                    &message.request
762                                );
763                            }
764                        }
765                        ControllerCommand::StartupTime => {
766                            if self.attack_phase == AttackPhase::Idle {
767                                // The controller uses a regular expression to validate that
768                                // this is a valid startup time, so simply use it with further
769                                // validation.
770                                if let Some(startup_time) = &message.request.value {
771                                    // If hatch_rate was already set, unset it first.
772                                    if let Some(hatch_rate) = &self.configuration.hatch_rate {
773                                        info!("resetting hatch_rate from {hatch_rate} to None");
774                                        self.configuration.hatch_rate = None;
775                                    }
776                                    info!(
777                                        "changing startup_rate from {} to {}",
778                                        self.configuration.startup_time, startup_time
779                                    );
780                                    self.configuration.startup_time.clone_from(startup_time);
781                                    // If setting startup_time, any existing configuration for a test plan isn't valid.
782                                    self.configuration.test_plan = None;
783                                    self.reply_to_controller(
784                                        message,
785                                        ControllerResponseMessage::Bool(true),
786                                    );
787                                } else {
788                                    warn!(
789                                        "Controller didn't provide startup_time: {:#?}",
790                                        &message.request
791                                    );
792                                    self.reply_to_controller(
793                                        message,
794                                        ControllerResponseMessage::Bool(false),
795                                    );
796                                }
797                            } else {
798                                self.reply_to_controller(
799                                    message,
800                                    ControllerResponseMessage::Bool(false),
801                                );
802                            }
803                        }
804                        ControllerCommand::RunTime => {
805                            // The controller uses a regular expression to validate that
806                            // this is a valid run time, so simply use it with further
807                            // validation.
808                            if let Some(run_time) = &message.request.value {
809                                info!(
810                                    "changing run_time from {:?} to {}",
811                                    self.configuration.run_time, run_time
812                                );
813                                self.configuration.run_time.clone_from(run_time);
814                                // If setting run_time, any existing configuration for a test plan isn't valid.
815                                self.configuration.test_plan = None;
816                                self.reply_to_controller(
817                                    message,
818                                    ControllerResponseMessage::Bool(true),
819                                );
820                            } else {
821                                warn!(
822                                    "Controller didn't provide run_time: {:#?}",
823                                    &message.request
824                                );
825                            }
826                        }
827                        ControllerCommand::TestPlan => {
828                            if let Some(value) = &message.request.value {
829                                match value.parse::<TestPlan>() {
830                                    Ok(t) => {
831                                        // Switch the configuration to use the test plan.
832                                        self.configuration.test_plan = Some(t.clone());
833                                        self.configuration.users = None;
834                                        self.configuration.hatch_rate = None;
835                                        self.configuration.startup_time = "0".to_string();
836                                        self.configuration.run_time = "0".to_string();
837                                        match self.attack_phase {
838                                            // If the load test is idle, just update the configuration.
839                                            AttackPhase::Idle => {
840                                                self.reply_to_controller(
841                                                    message,
842                                                    ControllerResponseMessage::Bool(true),
843                                                );
844                                            }
845                                            // If the load test is running, rebuild the active test plan.
846                                            AttackPhase::Increase
847                                            | AttackPhase::Decrease
848                                            | AttackPhase::Maintain => {
849                                                // Rebuild the active test plan.
850                                                self.test_plan = t;
851
852                                                // Reallocate users.
853                                                self.weighted_users = self.weight_scenario_users(
854                                                    self.test_plan.total_users(),
855                                                )?;
856
857                                                // Determine how long the current step has been running.
858                                                let elapsed = self.step_elapsed() as usize;
859
860                                                // Insert the current state of the test plan before the new test plan.
861                                                self.test_plan.steps.insert(
862                                                    0,
863                                                    (goose_attack_run_state.active_users, elapsed),
864                                                );
865
866                                                // Finally, advance to the next step to adjust user count.
867                                                self.advance_test_plan(goose_attack_run_state);
868
869                                                // The load test is successfully reconfigured.
870                                                self.reply_to_controller(
871                                                    message,
872                                                    ControllerResponseMessage::Bool(true),
873                                                );
874                                            }
875                                            _ => {
876                                                unreachable!("Controller used in impossible phase.")
877                                            }
878                                        }
879                                    }
880                                    Err(e) => {
881                                        warn!("Controller provided invalid test_plan: {e}");
882                                        self.reply_to_controller(
883                                            message,
884                                            ControllerResponseMessage::Bool(false),
885                                        );
886                                    }
887                                }
888                            } else {
889                                warn!(
890                                    "Controller didn't provide test_plan: {:#?}",
891                                    &message.request
892                                );
893                                self.reply_to_controller(
894                                    message,
895                                    ControllerResponseMessage::Bool(false),
896                                );
897                            }
898                        }
899                        // These messages shouldn't be received here.
900                        ControllerCommand::Help | ControllerCommand::Exit => {
901                            warn!("Unexpected command: {:?}", &message.request);
902                        }
903                    }
904                }
905                Err(e) => {
906                    // Errors can be ignored, they happen any time there are no messages.
907                    debug!("error receiving message: {e}");
908                }
909            }
910        };
911        Ok(())
912    }
913
914    /// Use the provided oneshot channel to reply to a controller client request.
915    pub(crate) fn reply_to_controller(
916        &mut self,
917        request: ControllerRequest,
918        response: ControllerResponseMessage,
919    ) {
920        if let Some(oneshot_tx) = request.response_channel {
921            if oneshot_tx
922                .send(ControllerResponse {
923                    _client_id: request.client_id,
924                    response,
925                })
926                .is_err()
927            {
928                warn!("failed to send response to controller via one-shot channel")
929            }
930        }
931    }
932}
933
934/// The control loop listens for connections on the configured TCP port. Each connection
935/// spawns a new thread so multiple clients can connect. Handles incoming connections for
936/// both telnet and WebSocket clients.
937///  -  @TODO: optionally limit how many controller connections are allowed
938///  -  @TODO: optionally require client authentication
939///  -  @TODO: optionally ssl-encrypt client communication
940pub(crate) async fn controller_main(
941    // Expose load test configuration to controller thread.
942    configuration: GooseConfiguration,
943    // For sending requests to the parent process.
944    channel_tx: flume::Sender<ControllerRequest>,
945    // Which type of controller to launch.
946    protocol: ControllerProtocol,
947) -> io::Result<()> {
948    // Build protocol-appropriate address.
949    let address = match &protocol {
950        ControllerProtocol::Telnet => format!(
951            "{}:{}",
952            configuration.telnet_host, configuration.telnet_port
953        ),
954        ControllerProtocol::WebSocket => format!(
955            "{}:{}",
956            configuration.websocket_host, configuration.websocket_port
957        ),
958    };
959
960    // All controllers use a TcpListener port.
961    debug!("[controller]: preparing to bind {protocol:?} to: {address}");
962    let listener = TcpListener::bind(&address).await?;
963    info!("[controller]: {protocol:?} listening on: {address}");
964
965    // Counter increments each time a controller client connects with this protocol.
966    let mut thread_id: u32 = 0;
967
968    // Wait for a connection.
969    while let Ok((stream, _)) = listener.accept().await {
970        thread_id += 1;
971
972        // Identify the client ip and port, used primarily for debug logging.
973        let peer_address = stream
974            .peer_addr()
975            .map_or("UNKNOWN ADDRESS".to_string(), |p| p.to_string());
976
977        // Create a per-client Controller state.
978        let controller_state = ControllerState {
979            thread_id,
980            peer_address,
981            channel_tx: channel_tx.clone(),
982            protocol: protocol.clone(),
983        };
984
985        // Spawn a new thread to communicate with a client. The returned JoinHandle is
986        // ignored as the thread simply runs until the client exits or Goose shuts down.
987        // Don't .await the tokio::spawn or Goose can't handle multiple simultaneous
988        // connections.
989        let _ignored_joinhandle = tokio::spawn(controller_state.accept_connections(stream));
990    }
991
992    Ok(())
993}
994
995/// Implement [`FromStr`] to convert controller commands and optional values to the proper enum
996/// representation.
997impl FromStr for ControllerCommand {
998    type Err = GooseError;
999
1000    // Use regular expressions to convert controller input to ControllerCommands.
1001    fn from_str(s: &str) -> Result<Self, Self::Err> {
1002        // Load all ControllerCommand regex into a set.
1003        let mut regex_set: Vec<String> = Vec::new();
1004        let mut keys = Vec::new();
1005        for t in ControllerCommand::iter() {
1006            keys.push(t.clone());
1007            regex_set.push(t.details().regex.to_string());
1008        }
1009        let commands = RegexSet::new(regex_set)
1010            .expect("ControllerCommand::details().regex returned invalid regex");
1011        let matches: Vec<_> = commands.matches(s).into_iter().collect();
1012        // This happens any time the controller receives an invalid command.
1013        if matches.is_empty() {
1014            Err(GooseError::InvalidControllerCommand {
1015                detail: format!("unrecognized controller command: '{s}'."),
1016            })
1017        // This shouldn't ever happen, but if it does report all available information.
1018        } else if matches.len() > 1 {
1019            let mut matched_commands = Vec::new();
1020            for index in matches {
1021                matched_commands.push(keys[index].clone())
1022            }
1023            Err(GooseError::InvalidControllerCommand {
1024                detail: format!(
1025                    "matched multiple controller commands: '{s}' ({matched_commands:?})."
1026                ),
1027            })
1028        // Only one command matched.
1029        } else {
1030            Ok(keys[*matches.first().unwrap()].clone())
1031        }
1032    }
1033}
1034
1035/// Send a message to the client TcpStream, no prompt or line feed.
1036async fn write_to_socket_raw(socket: &mut tokio::net::TcpStream, message: &str) {
1037    if socket
1038        // Add a linefeed to the end of the message.
1039        .write_all(message.as_bytes())
1040        .await
1041        .is_err()
1042    {
1043        warn!("failed to write data to socket");
1044    }
1045}
1046
/// Goose supports two different Controller protocols: telnet and WebSocket.
///
/// The protocol selects which host and port the Controller listens on and how each
/// client connection is served.
#[derive(Clone, Debug)]
pub(crate) enum ControllerProtocol {
    /// Allows control of Goose via telnet.
    Telnet,
    /// Allows control of Goose via a WebSocket.
    WebSocket,
}
1055
/// All commands define their own ControllerHelp to create a help screen.
#[derive(Clone, Debug)]
pub(crate) struct ControllerHelp<'a> {
    // The name of the controller command.
    name: &'a str,
    // A description of the controller command.
    description: &'a str,
}
1064
/// Defines the regular expression used to identify a command and optionally the associated
/// value, as well as the logic for letting the controller know whether or not the
/// recognized command worked correctly.
pub(crate) struct ControllerCommandDetails<'a> {
    // The name and description of the controller command.
    help: ControllerHelp<'a>,
    // A [regular expression](https://docs.rs/regex/1.5.5/regex/struct.Regex.html) for
    // matching the command and optional value.
    regex: &'a str,
    // Turns the parent's reply to a request into the text reported to the client.
    // NOTE(review): presumably `Ok` carries the success message and `Err` the failure
    // message -- confirm against the implementations in `ControllerCommand::details`.
    process_response: Box<dyn Fn(ControllerResponseMessage) -> Result<String, String>>,
}
1078
/// This structure is used to send commands and values to the parent process.
#[derive(Debug)]
pub(crate) struct ControllerRequestMessage {
    /// The command that is being sent to the parent.
    pub command: ControllerCommand,
    /// An optional value that is being sent to the parent, still in its string form
    /// (for example the new number of users or the new hatch rate).
    pub value: Option<String>,
}
1087
/// An enumeration of all messages the parent can reply back to the controller thread.
#[derive(Debug)]
pub(crate) enum ControllerResponseMessage {
    /// A response containing a boolean value; the parent replies `true` when a
    /// request was applied and `false` when it was rejected.
    Bool(bool),
    /// A response containing the load test configuration (boxed).
    Config(Box<GooseConfiguration>),
    /// A response containing current load test metrics (boxed).
    Metrics(Box<GooseMetrics>),
}
1098
/// The request that's passed from the controller to the parent thread.
#[derive(Debug)]
pub(crate) struct ControllerRequest {
    /// Optional one-shot channel if a reply is required. When `None`, the parent
    /// sends no reply for this request.
    pub response_channel: Option<tokio::sync::oneshot::Sender<ControllerResponse>>,
    /// An integer identifying which controller client is making the request.
    pub client_id: u32,
    /// The actual request message.
    pub request: ControllerRequestMessage,
}
1109
/// The response that's passed from the parent to the controller.
#[derive(Debug)]
pub(crate) struct ControllerResponse {
    /// An integer identifying which controller the parent is responding to; copied
    /// from the `client_id` of the request being answered.
    // NOTE(review): the leading underscore suggests the value is not read anywhere
    // yet -- confirm before relying on it.
    pub _client_id: u32,
    /// The actual response message.
    pub response: ControllerResponseMessage,
}
1118
/// This structure defines the required json format of any request sent to the WebSocket
/// Controller.
///
/// Requests must be made in the following format:
/// ```json
/// {
///     "request": String
/// }
/// ```
///
/// The request "String" value must be a valid
/// [`ControllerCommand`](./enum.ControllerCommand.html).
///
/// # Example
/// The following request will shut down the load test:
/// ```json
/// {
///     "request": "shutdown"
/// }
/// ```
///
/// Responses will be formatted as defined in
/// [ControllerWebSocketResponse](./struct.ControllerWebSocketResponse.html).
#[derive(Debug, Deserialize, Serialize)]
pub struct ControllerWebSocketRequest {
    /// A valid command string.
    pub request: String,
}
1148
/// This structure defines the json format of any response returned from the WebSocket
/// Controller.
///
/// Responses are in the following format:
/// ```json
/// {
///     "response": String,
///     "success": bool
/// }
/// ```
///
/// # Example
/// The following response will be returned when a request is made to shut down the
/// load test:
/// ```json
/// {
///     "response": "load test shut down",
///     "success": true
/// }
/// ```
///
/// Requests must be formatted as defined in
/// [ControllerWebSocketRequest](./struct.ControllerWebSocketRequest.html).
#[derive(Debug, Deserialize, Serialize)]
pub struct ControllerWebSocketResponse {
    /// The response from the controller.
    pub response: String,
    /// Whether the request was successful or not.
    pub success: bool,
}
1179
/// Return type to indicate whether or not to exit the Controller thread: `true` means
/// the per-client loop should stop serving the client.
type ControllerExit = bool;
1182
/// The telnet Controller message buffer: a fixed 1024-byte array that each read from
/// a telnet client is stored in.
type ControllerTelnetMessage = [u8; 1024];
1185
/// The WebSocket Controller message buffer: either a received WebSocket message or
/// the error raised while trying to receive one.
type ControllerWebSocketMessage = std::result::Result<
    tokio_tungstenite::tungstenite::Message,
    tokio_tungstenite::tungstenite::Error,
>;
1191
/// Simplify the ControllerExecuteCommand trait definition for WebSockets. This is the
/// sending half of a WebSocket stream over a `TcpStream` after it has been split.
type ControllerWebSocketSender = futures::stream::SplitSink<
    tokio_tungstenite::WebSocketStream<tokio::net::TcpStream>,
    tokio_tungstenite::tungstenite::Message,
>;
1197
/// This state object is created in the main Controller thread and then passed to the specific
/// per-client thread. One instance exists per connected client.
pub(crate) struct ControllerState {
    /// Track which controller-thread this is; the value is the running count of
    /// clients that have connected with this protocol.
    thread_id: u32,
    /// Track the ip and port of the connected TCP client, or `UNKNOWN ADDRESS` when
    /// the peer address could not be determined.
    peer_address: String,
    /// A shared channel for communicating with the parent process.
    channel_tx: flume::Sender<ControllerRequest>,
    /// Which protocol this Controller understands.
    protocol: ControllerProtocol,
}
1210// Defines functions shared by all Controllers.
1211impl ControllerState {
1212    async fn accept_connections(self, mut socket: tokio::net::TcpStream) {
1213        info!(
1214            "{:?} client [{}] connected from {}",
1215            self.protocol, self.thread_id, self.peer_address
1216        );
1217        match self.protocol {
1218            ControllerProtocol::Telnet => {
1219                let mut buf: ControllerTelnetMessage = [0; 1024];
1220
1221                // Display initial goose> prompt.
1222                write_to_socket_raw(&mut socket, "goose> ").await;
1223
1224                loop {
1225                    // Process data received from the client in a loop.
1226                    let n = match socket.read(&mut buf).await {
1227                        Ok(data) => data,
1228                        Err(_) => {
1229                            info!(
1230                                "Telnet client [{}] disconnected from {}",
1231                                self.thread_id, self.peer_address
1232                            );
1233                            break;
1234                        }
1235                    };
1236
1237                    // Invalid request, exit.
1238                    if n == 0 {
1239                        info!(
1240                            "Telnet client [{}] disconnected from {}",
1241                            self.thread_id, self.peer_address
1242                        );
1243                        break;
1244                    }
1245
1246                    // Extract the command string in a protocol-specific way.
1247                    if let Ok(command_string) = self.get_command_string(buf).await {
1248                        // Extract the command and value in a generic way.
1249                        if let Ok(request_message) = self.get_match(command_string.trim()).await {
1250                            // Act on the commmand received.
1251                            if self.execute_command(&mut socket, request_message).await {
1252                                // If execute_command returns true, it's time to exit.
1253                                info!(
1254                                    "Telnet client [{}] disconnected from {}",
1255                                    self.thread_id, self.peer_address
1256                                );
1257                                break;
1258                            }
1259                        } else {
1260                            self.write_to_socket(
1261                                &mut socket,
1262                                Err("unrecognized command".to_string()),
1263                            )
1264                            .await;
1265                        }
1266                    } else {
1267                        // Corrupted request from telnet client, exit.
1268                        info!(
1269                            "Telnet client [{}] disconnected from {}",
1270                            self.thread_id, self.peer_address
1271                        );
1272                        break;
1273                    }
1274                }
1275            }
1276            ControllerProtocol::WebSocket => {
1277                let stream = match tokio_tungstenite::accept_async(socket).await {
1278                    Ok(s) => s,
1279                    Err(e) => {
1280                        info!("invalid WebSocket handshake: {e}");
1281                        return;
1282                    }
1283                };
1284                let (mut ws_sender, mut ws_receiver) = stream.split();
1285
1286                loop {
1287                    // Wait until the client sends a command.
1288                    let data = match ws_receiver.next().await {
1289                        Some(d) => d,
1290                        None => {
1291                            // Returning with no data means the client disconnected.
1292                            info!(
1293                                "WebSocket client [{}] disconnected from {}",
1294                                self.thread_id, self.peer_address
1295                            );
1296                            break;
1297                        }
1298                    };
1299
1300                    // Extract the command string in a protocol-specific way.
1301                    if let Ok(command_string) = self.get_command_string(data).await {
1302                        // Extract the command and value in a generic way.
1303                        if let Ok(request_message) = self.get_match(command_string.trim()).await {
1304                            if self.execute_command(&mut ws_sender, request_message).await {
1305                                // If execute_command() returns true, it's time to exit.
1306                                info!(
1307                                    "WebSocket client [{}] disconnected from {}",
1308                                    self.thread_id, self.peer_address
1309                                );
1310                                break;
1311                            }
1312                        } else {
1313                            self.write_to_socket(
1314                                &mut ws_sender,
1315                                Err(
1316                                    "unrecognized command, see Goose book https://book.goose.rs/controller/websocket.html"
1317                                        .to_string(),
1318                                ),
1319                            )
1320                            .await;
1321                        }
1322                    } else {
1323                        self.write_to_socket(
1324                            &mut ws_sender,
1325                            Err(
1326                                "invalid json, see Goose book https://book.goose.rs/controller/websocket.html"
1327                                    .to_string(),
1328                            ),
1329                        )
1330                        .await;
1331                    }
1332                }
1333            }
1334        }
1335    }
1336
1337    // Both Controllers use a common function to identify commands.
1338    async fn get_match(
1339        &self,
1340        command_string: &str,
1341    ) -> Result<ControllerRequestMessage, GooseError> {
1342        // Use FromStr to convert &str to ControllerCommand.
1343        let command: ControllerCommand = ControllerCommand::from_str(command_string)?;
1344        // Extract value if there is one, otherwise will be None.
1345        let value: Option<String> = command.get_value(command_string);
1346
1347        Ok(ControllerRequestMessage { command, value })
1348    }
1349
1350    /// Process a request entirely within the Controller thread, without sending a message
1351    /// to the parent thread.
1352    fn process_local_command(&self, request_message: &ControllerRequestMessage) -> Option<String> {
1353        match request_message.command {
1354            ControllerCommand::Help => Some(ControllerCommand::display_help()),
1355            ControllerCommand::Exit => Some("goodbye!".to_string()),
1356            // All other commands require sending the request to the parent thread.
1357            _ => None,
1358        }
1359    }
1360
1361    /// Send a message to parent thread, with or without an optional value, and wait for
1362    /// a reply.
1363    async fn process_command(
1364        &self,
1365        request: ControllerRequestMessage,
1366    ) -> Result<ControllerResponseMessage, String> {
1367        // Create a one-shot channel to allow the parent to reply to our request. As flume
1368        // doesn't implement a one-shot channel, we use tokio for this temporary channel.
1369        let (response_tx, response_rx): (
1370            tokio::sync::oneshot::Sender<ControllerResponse>,
1371            tokio::sync::oneshot::Receiver<ControllerResponse>,
1372        ) = tokio::sync::oneshot::channel();
1373
1374        if self
1375            .channel_tx
1376            .try_send(ControllerRequest {
1377                response_channel: Some(response_tx),
1378                client_id: self.thread_id,
1379                request,
1380            })
1381            .is_err()
1382        {
1383            return Err("parent process has closed the controller channel".to_string());
1384        }
1385
1386        // Await response from parent.
1387        match response_rx.await {
1388            Ok(value) => Ok(value.response),
1389            Err(e) => Err(format!("one-shot channel dropped without reply: {e}")),
1390        }
1391    }
1392}
1393
/// Controller-protocol-specific functions, necessary to manage the different way each
/// Controller protocol communicates with a client.
///
/// `T` is the protocol's raw request type; this file implements the trait for
/// `ControllerTelnetMessage` and `ControllerWebSocketMessage`.
#[async_trait]
trait Controller<T> {
    /// Extract the command string from a Controller client request.
    ///
    /// Returns the command text on success, or an error message describing why the raw
    /// request could not be turned into a command string.
    async fn get_command_string(&self, raw_value: T) -> Result<String, String>;
}
1401#[async_trait]
1402impl Controller<ControllerTelnetMessage> for ControllerState {
1403    // Extract the command string from a telnet Controller client request.
1404    async fn get_command_string(
1405        &self,
1406        raw_value: ControllerTelnetMessage,
1407    ) -> Result<String, String> {
1408        let command_string = match str::from_utf8(&raw_value) {
1409            Ok(m) => m.lines().next().unwrap_or_default(),
1410            Err(e) => {
1411                let error = format!("ignoring unexpected input from telnet controller: {e}");
1412                info!("{error}");
1413                return Err(error);
1414            }
1415        };
1416
1417        Ok(command_string.to_string())
1418    }
1419}
1420#[async_trait]
1421impl Controller<ControllerWebSocketMessage> for ControllerState {
1422    // Extract the command string from a WebSocket Controller client request.
1423    async fn get_command_string(
1424        &self,
1425        raw_value: ControllerWebSocketMessage,
1426    ) -> Result<String, String> {
1427        if let Ok(request) = raw_value {
1428            if request.is_text() {
1429                if let Ok(request) = request.into_text() {
1430                    debug!("websocket request: {:?}", request.trim());
1431                    let command_string: ControllerWebSocketRequest =
1432                        match serde_json::from_str(&request) {
1433                            Ok(c) => c,
1434                            Err(_) => {
1435                                return Err("invalid json, see Goose book https://book.goose.rs/controller/websocket.html"
1436                                    .to_string())
1437                            }
1438                        };
1439                    return Ok(command_string.request);
1440                } else {
1441                    // Failed to consume the WebSocket message and convert it to a String.
1442                    return Err("unsupported string format".to_string());
1443                }
1444            } else {
1445                // Received a non-text WebSocket message.
1446                return Err("unsupported format, requests must be sent as text".to_string());
1447            }
1448        }
1449        // Improper WebSocket handshake.
1450        Err("WebSocket handshake error".to_string())
1451    }
1452}
/// Protocol-specific command execution and response delivery for a Controller client.
///
/// `T` is the writable half of the client connection; this file implements the trait
/// for `tokio::net::TcpStream` (telnet) and `ControllerWebSocketSender` (WebSocket).
#[async_trait]
trait ControllerExecuteCommand<T> {
    /// Run the command received from a Controller request.
    ///
    /// Returns a `ControllerExit`; the implementations in this file return `true` when
    /// the Controller connection should be closed.
    async fn execute_command(
        &self,
        socket: &mut T,
        request_message: ControllerRequestMessage,
    ) -> ControllerExit;

    /// Send a response to the Controller client.
    ///
    /// The response is wrapped in a `Result` to indicate whether the request was
    /// successful (`Ok`) or not (`Err`).
    async fn write_to_socket(&self, socket: &mut T, response_message: Result<String, String>);
}
1466
1467#[async_trait]
1468impl ControllerExecuteCommand<tokio::net::TcpStream> for ControllerState {
1469    // Run the command received from a telnet Controller request.
1470    async fn execute_command(
1471        &self,
1472        socket: &mut tokio::net::TcpStream,
1473        request_message: ControllerRequestMessage,
1474    ) -> ControllerExit {
1475        // First handle commands that don't require interaction with the parent process.
1476        if let Some(message) = self.process_local_command(&request_message) {
1477            self.write_to_socket(socket, Ok(message)).await;
1478            // If Exit was received return true to exit, otherwise return false.
1479            return request_message.command == ControllerCommand::Exit;
1480        }
1481
1482        // Retain a copy of the command used when processing the parent response.
1483        let command = request_message.command.clone();
1484
1485        // Now handle commands that require interaction with the parent process.
1486        let response = match self.process_command(request_message).await {
1487            Ok(r) => r,
1488            Err(e) => {
1489                // Receiving an error here means the parent closed the communication
1490                // channel. Write the error to the Controller client and then return
1491                // true to exit.
1492                self.write_to_socket(socket, Err(e)).await;
1493                return true;
1494            }
1495        };
1496
1497        // If Shutdown command was received return true to exit, otherwise return false.
1498        let exit_controller = command == ControllerCommand::Shutdown;
1499
1500        // Write the response to the Controller client socket.
1501        let processed_response = (command.details().process_response)(response);
1502        self.write_to_socket(socket, processed_response).await;
1503
1504        // Return true if it's time to exit the Controller.
1505        exit_controller
1506    }
1507
1508    // Send response to telnet Controller client.
1509    async fn write_to_socket(
1510        &self,
1511        socket: &mut tokio::net::TcpStream,
1512        message: Result<String, String>,
1513    ) {
1514        // Send result to telnet Controller client, whether Ok() or Err().
1515        let response_message = match message {
1516            Ok(m) => m,
1517            Err(e) => e,
1518        };
1519        if socket
1520            // Add a linefeed to the end of the message, followed by a prompt.
1521            .write_all([&response_message, "\ngoose> "].concat().as_bytes())
1522            .await
1523            .is_err()
1524        {
1525            warn!("failed to write data to socker");
1526        };
1527    }
1528}
1529
1530#[async_trait]
1531impl ControllerExecuteCommand<ControllerWebSocketSender> for ControllerState {
1532    // Run the command received from a WebSocket Controller request.
1533    async fn execute_command(
1534        &self,
1535        socket: &mut ControllerWebSocketSender,
1536        request_message: ControllerRequestMessage,
1537    ) -> ControllerExit {
1538        // First handle commands that don't require interaction with the parent process.
1539        if let Some(message) = self.process_local_command(&request_message) {
1540            self.write_to_socket(socket, Ok(message)).await;
1541
1542            // If Exit was received return true to exit, otherwise return false.
1543            let exit_controller = request_message.command == ControllerCommand::Exit;
1544            // If exiting, notify the WebSocket client that this connection is closing.
1545            if exit_controller
1546                && socket
1547                    .send(Message::Close(Some(tokio_tungstenite::tungstenite::protocol::CloseFrame {
1548                        code: tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode::Normal,
1549                        reason: "exit".into(),
1550                    })))
1551                    .await
1552                    .is_err()
1553            {
1554                warn!("failed to write data to stream");
1555            }
1556
1557            return exit_controller;
1558        }
1559
1560        // WebSocket Controller always returns JSON, convert command where necessary.
1561        let command = match request_message.command {
1562            ControllerCommand::Config => ControllerCommand::ConfigJson,
1563            ControllerCommand::Metrics => ControllerCommand::MetricsJson,
1564            _ => request_message.command.clone(),
1565        };
1566
1567        // Now handle commands that require interaction with the parent process.
1568        let response = match self.process_command(request_message).await {
1569            Ok(r) => r,
1570            Err(e) => {
1571                // Receiving an error here means the parent closed the communication
1572                // channel. Write the error to the Controller client and then return
1573                // true to exit.
1574                self.write_to_socket(socket, Err(e)).await;
1575                return true;
1576            }
1577        };
1578
1579        // If Shutdown command was received return true to exit, otherwise return false.
1580        let exit_controller = command == ControllerCommand::Shutdown;
1581
1582        // Write the response to the Controller client socket.
1583        let processed_response = (command.details().process_response)(response);
1584        self.write_to_socket(socket, processed_response).await;
1585
1586        // If exiting, notify the WebSocket client that this connection is closing.
1587        if exit_controller
1588            && socket
1589                .send(Message::Close(Some(tokio_tungstenite::tungstenite::protocol::CloseFrame {
1590                    code: tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode::Normal,
1591                    reason: "shutdown".into(),
1592                })))
1593                .await
1594                .is_err()
1595        {
1596            warn!("failed to write data to stream");
1597        }
1598
1599        // Return true if it's time to exit the Controller.
1600        exit_controller
1601    }
1602
1603    // Send a json-formatted response to the WebSocket.
1604    async fn write_to_socket(
1605        &self,
1606        socket: &mut ControllerWebSocketSender,
1607        response_result: Result<String, String>,
1608    ) {
1609        let success;
1610        let response = match response_result {
1611            Ok(m) => {
1612                success = true;
1613                m
1614            }
1615            Err(e) => {
1616                success = false;
1617                e
1618            }
1619        };
1620        if let Err(e) = socket
1621            .send(Message::Text(
1622                match serde_json::to_string(&ControllerWebSocketResponse {
1623                    response,
1624                    // Success is true if there is no error, false if there is an error.
1625                    success,
1626                }) {
1627                    Ok(json) => json.into(),
1628                    Err(e) => {
1629                        warn!("failed to json encode response: {e}");
1630                        return;
1631                    }
1632                },
1633            ))
1634            .await
1635        {
1636            info!("failed to write data to websocket: {e}");
1637        }
1638    }
1639}