use crate::config::GooseConfiguration;
use crate::metrics::GooseMetrics;
use crate::test_plan::{TestPlan, TestPlanHistory, TestPlanStepAction};
use crate::util;
use crate::{AttackPhase, GooseAttack, GooseAttackRunState, GooseError};
use async_trait::async_trait;
use futures::{SinkExt, StreamExt};
use regex::{Regex, RegexSet};
use serde::{Deserialize, Serialize};
use std::io::{self, Write};
use std::str::{self, FromStr};
use std::time::Duration;
use strum::IntoEnumIterator;
use strum_macros::EnumIter;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener;
use tokio_tungstenite::tungstenite::Message;
/// Every command a controller client can send.
///
/// Variants are iterated in declaration order (via `EnumIter`) both when matching
/// input in `FromStr` and when listing commands in the `help` output.
#[derive(Clone, Debug, EnumIter, PartialEq, Eq)]
pub enum ControllerCommand {
    /// `help` / `?`: list the available commands (answered by the connection itself).
    Help,
    /// `exit` / `quit` / `q`: close this controller connection (answered by the connection itself).
    Exit,
    /// `start`: start an idle load test.
    Start,
    /// `stop`: stop a running load test and return to the idle state.
    Stop,
    /// `shutdown`: shut down the load test and close the controller connection.
    Shutdown,
    /// `host HOST`: set the host to load test.
    Host,
    /// `hatchrate FLOAT`: set the per-second rate at which users are started.
    HatchRate,
    /// `startup-time TIME`: set the total time to take starting users.
    StartupTime,
    /// `users INT`: set the number of simulated users.
    Users,
    /// `runtime TIME`: set how long the test runs.
    RunTime,
    /// `test-plan PLAN`: define or replace the test plan.
    TestPlan,
    /// `config`: display the load test configuration.
    Config,
    /// `config-json`: display the load test configuration as JSON.
    ConfigJson,
    /// `metrics` / `stats`: display metrics for the current load test.
    Metrics,
    /// `metrics-json` / `stats-json`: display metrics for the current load test as JSON.
    MetricsJson,
}
impl ControllerCommand {
fn details(&self) -> ControllerCommandDetails<'_> {
match self {
ControllerCommand::Config => ControllerCommandDetails {
help: ControllerHelp {
name: "config",
description: "display load test configuration\n",
},
regex: r"(?i)^config$",
process_response: Box::new(|response| {
if let ControllerResponseMessage::Config(config) = response {
Ok(format!("{config:#?}"))
} else {
Err("error loading configuration".to_string())
}
}),
},
ControllerCommand::ConfigJson => ControllerCommandDetails {
help: ControllerHelp {
name: "config-json",
description: "display load test configuration in json format\n",
},
regex: r"(?i)^(configjson|config-json)$",
process_response: Box::new(|response| {
if let ControllerResponseMessage::Config(config) = response {
Ok(serde_json::to_string(&config).expect("unexpected serde failure"))
} else {
Err("error loading configuration".to_string())
}
}),
},
ControllerCommand::Exit => ControllerCommandDetails {
help: ControllerHelp {
name: "exit",
description: "exit controller\n\n",
},
regex: r"(?i)^(exit|quit|q)$",
process_response: Box::new(|_| {
let e = "received an impossible EXIT command";
error!("{e}");
Err(e.to_string())
}),
},
ControllerCommand::HatchRate => ControllerCommandDetails {
help: ControllerHelp {
name: "hatchrate FLOAT",
description: "set per-second rate users hatch\n",
},
regex: r"(?i)^(hatchrate|hatch_rate|hatch-rate) ([0-9]*(\.[0-9]*)?){1}$",
process_response: Box::new(|response| {
if let ControllerResponseMessage::Bool(true) = response {
Ok("hatch_rate configured".to_string())
} else {
Err("failed to configure hatch_rate".to_string())
}
}),
},
ControllerCommand::Help => ControllerCommandDetails {
help: ControllerHelp {
name: "help",
description: "this help\n",
},
regex: r"(?i)^(help|\?)$",
process_response: Box::new(|_| {
let e = "received an impossible HELP command";
error!("{e}");
Err(e.to_string())
}),
},
ControllerCommand::Host => ControllerCommandDetails {
help: ControllerHelp {
name: "host HOST",
description: "set host to load test, (ie https://web.site/)\n",
},
regex: r"(?i)^(host|hostname|host_name|host-name) ((https?)://.+)$",
process_response: Box::new(|response| {
if let ControllerResponseMessage::Bool(true) = response {
Ok("host configured".to_string())
} else {
Err("failed to reconfigure host, be sure host is valid and load test is idle".to_string())
}
}),
},
ControllerCommand::Metrics => ControllerCommandDetails {
help: ControllerHelp {
name: "metrics",
description: "display metrics for current load test\n",
},
regex: r"(?i)^(metrics|stats)$",
process_response: Box::new(|response| {
if let ControllerResponseMessage::Metrics(metrics) = response {
Ok(metrics.to_string())
} else {
Err("error loading metrics".to_string())
}
}),
},
ControllerCommand::MetricsJson => {
ControllerCommandDetails {
help: ControllerHelp {
name: "metrics-json",
description: "display metrics for current load test in json format",
},
regex: r"(?i)^(metricsjson|metrics-json|statsjson|stats-json)$",
process_response: Box::new(|response| {
if let ControllerResponseMessage::Metrics(metrics) = response {
Ok(serde_json::to_string(&metrics).expect("unexpected serde failure"))
} else {
Err("error loading metrics".to_string())
}
}),
}
}
ControllerCommand::RunTime => ControllerCommandDetails {
help: ControllerHelp {
name: "runtime TIME",
description: "set how long to run test, (ie 1h30m5s)\n",
},
regex: r"(?i)^(run|runtime|run_time|run-time|) (\d+|((\d+?)h)?((\d+?)m)?((\d+?)s)?)$",
process_response: Box::new(|response| {
if let ControllerResponseMessage::Bool(true) = response {
Ok("run_time configured".to_string())
} else {
Err("failed to configure run_time".to_string())
}
}),
},
ControllerCommand::Shutdown => ControllerCommandDetails {
help: ControllerHelp {
name: "shutdown",
description: "shutdown load test and exit controller\n\n",
},
regex: r"(?i)^shutdown$",
process_response: Box::new(|response| {
if let ControllerResponseMessage::Bool(true) = response {
Ok("load test shut down".to_string())
} else {
Err("failed to shut down load test".to_string())
}
}),
},
ControllerCommand::Start => {
ControllerCommandDetails {
help: ControllerHelp {
name: "start",
description: "start an idle load test\n",
},
regex: r"(?i)^start$",
process_response: Box::new(|response| {
if let ControllerResponseMessage::Bool(true) = response {
Ok("load test started".to_string())
} else {
Err("unable to start load test, be sure it is idle and host is configured".to_string())
}
}),
}
}
ControllerCommand::StartupTime => ControllerCommandDetails {
help: ControllerHelp {
name: "startup-time TIME",
description: "set total time to take starting users\n",
},
regex: r"(?i)^(starttime|start_time|start-time|startup|startuptime|startup_time|startup-time) (\d+|((\d+?)h)?((\d+?)m)?((\d+?)s)?)$",
process_response: Box::new(|response| {
if let ControllerResponseMessage::Bool(true) = response {
Ok("startup_time configured".to_string())
} else {
Err(
"failed to configure startup_time, be sure load test is idle"
.to_string(),
)
}
}),
},
ControllerCommand::Stop => ControllerCommandDetails {
help: ControllerHelp {
name: "stop",
description: "stop a running load test and return to idle state\n",
},
regex: r"(?i)^stop$",
process_response: Box::new(|response| {
if let ControllerResponseMessage::Bool(true) = response {
Ok("load test stopped".to_string())
} else {
Err("load test not running, failed to stop".to_string())
}
}),
},
ControllerCommand::TestPlan => ControllerCommandDetails {
help: ControllerHelp {
name: "test-plan PLAN",
description: "define or replace test-plan, (ie 10,5m;10,1h;0,30s)\n\n",
},
regex: r"(?i)^(testplan|test_plan|test-plan|plan) (((\d+)\s*,\s*(\d+|((\d+?)h)?((\d+?)m)?((\d+?)s)?)*;*)+)$",
process_response: Box::new(|response| {
if let ControllerResponseMessage::Bool(true) = response {
Ok("test-plan configured".to_string())
} else {
Err("failed to configure test-plan, be sure test-plan is valid".to_string())
}
}),
},
ControllerCommand::Users => ControllerCommandDetails {
help: ControllerHelp {
name: "users INT",
description: "set number of simulated users\n",
},
regex: r"(?i)^(users?) (\d+)$",
process_response: Box::new(|response| {
if let ControllerResponseMessage::Bool(true) = response {
Ok("users configured".to_string())
} else {
Err("load test not idle, failed to reconfigure users".to_string())
}
}),
},
}
}
fn validate_value(&self, value: &str) -> Option<String> {
if self == &ControllerCommand::Host {
if util::is_valid_host(value).is_ok() {
Some(value.to_string())
} else {
None
}
} else if value.is_empty() {
None
} else {
Some(value.to_string())
}
}
fn get_value(&self, command_string: &str) -> Option<String> {
let regex = Regex::new(self.details().regex)
.expect("ControllerCommand::details().regex returned invalid regex [2]");
let caps = regex.captures(command_string).unwrap();
let value = caps.get(2).map_or("", |m| m.as_str());
self.validate_value(value)
}
fn display_help() -> String {
let mut help_text = Vec::new();
writeln!(
&mut help_text,
"{} {} controller commands:",
env!("CARGO_PKG_NAME"),
env!("CARGO_PKG_VERSION")
)
.expect("failed to write to buffer");
for command in ControllerCommand::iter() {
write!(
&mut help_text,
"{:<18} {}",
command.details().help.name,
command.details().help.description
)
.expect("failed to write to buffer");
}
String::from_utf8(help_text).expect("invalid utf-8 in help text")
}
}
impl GooseAttack {
    /// Handles at most one pending request from a controller client.
    ///
    /// `goose_attack_run_state.controller_channel_rx` is polled with `try_recv`, so
    /// this does not wait for a request to arrive and does nothing when no controller
    /// channel exists. Each handled request is answered through
    /// [`GooseAttack::reply_to_controller`]. `help` and `exit` are answered by the
    /// controller connection itself and are only logged here.
    ///
    /// # Errors
    ///
    /// Propagates errors from `reset_run_state`, `cancel_attack`, and
    /// `weight_scenario_users`.
    pub(crate) async fn handle_controller_requests(
        &mut self,
        goose_attack_run_state: &mut GooseAttackRunState,
    ) -> Result<(), GooseError> {
        if let Some(c) = goose_attack_run_state.controller_channel_rx.as_ref() {
            match c.try_recv() {
                Ok(message) => {
                    info!(
                        "request from controller client {}: {:?}",
                        message.client_id, message.request
                    );
                    match &message.request.command {
                        ControllerCommand::Config | ControllerCommand::ConfigJson => {
                            self.reply_to_controller(
                                message,
                                ControllerResponseMessage::Config(Box::new(
                                    self.configuration.clone(),
                                )),
                            );
                        }
                        ControllerCommand::Metrics | ControllerCommand::MetricsJson => {
                            self.reply_to_controller(
                                message,
                                ControllerResponseMessage::Metrics(Box::new(self.metrics.clone())),
                            );
                        }
                        ControllerCommand::Start => {
                            if self.attack_phase == AttackPhase::Idle {
                                self.test_plan = TestPlan::build(&self.configuration);
                                if self.prepare_load_test().is_ok() {
                                    self.set_attack_phase(
                                        goose_attack_run_state,
                                        AttackPhase::Increase,
                                    );
                                    self.reply_to_controller(
                                        message,
                                        ControllerResponseMessage::Bool(true),
                                    );
                                    self.reset_run_state(goose_attack_run_state).await?;
                                    self.metrics.history.push(TestPlanHistory::step(
                                        TestPlanStepAction::Increasing,
                                        0,
                                    ));
                                } else {
                                    self.reply_to_controller(
                                        message,
                                        ControllerResponseMessage::Bool(false),
                                    );
                                }
                            } else {
                                self.reply_to_controller(
                                    message,
                                    ControllerResponseMessage::Bool(false),
                                );
                            }
                        }
                        ControllerCommand::Stop => {
                            if [AttackPhase::Increase, AttackPhase::Maintain]
                                .contains(&self.attack_phase)
                            {
                                // Stay alive and idle after stopping instead of exiting.
                                goose_attack_run_state.shutdown_after_stop = false;
                                self.configuration.no_autostart = true;
                                self.cancel_attack(goose_attack_run_state).await?;
                                self.reply_to_controller(
                                    message,
                                    ControllerResponseMessage::Bool(true),
                                );
                            } else {
                                self.reply_to_controller(
                                    message,
                                    ControllerResponseMessage::Bool(false),
                                );
                            }
                        }
                        ControllerCommand::Shutdown => {
                            if self.attack_phase == AttackPhase::Idle {
                                self.metrics.display_metrics = false;
                                self.set_attack_phase(
                                    goose_attack_run_state,
                                    AttackPhase::Decrease,
                                );
                            } else {
                                self.cancel_attack(goose_attack_run_state).await?;
                            }
                            goose_attack_run_state.shutdown_after_stop = true;
                            self.reply_to_controller(
                                message,
                                ControllerResponseMessage::Bool(true),
                            );
                            tokio::time::sleep(Duration::from_millis(250)).await;
                        }
                        ControllerCommand::Host => {
                            if self.attack_phase == AttackPhase::Idle {
                                if let Some(host) = &message.request.value {
                                    info!(
                                        "changing host from {:?} to {}",
                                        self.configuration.host, host
                                    );
                                    self.configuration.host = host.to_string();
                                    self.reply_to_controller(
                                        message,
                                        ControllerResponseMessage::Bool(true),
                                    );
                                } else {
                                    debug!(
                                        "controller didn't provide host: {:#?}",
                                        &message.request
                                    );
                                    self.reply_to_controller(
                                        message,
                                        ControllerResponseMessage::Bool(false),
                                    );
                                }
                            } else {
                                self.reply_to_controller(
                                    message,
                                    ControllerResponseMessage::Bool(false),
                                );
                            }
                        }
                        ControllerCommand::Users => {
                            if let Some(users) = &message.request.value {
                                // The command regex only guarantees digits, so a value too
                                // large for `usize` is rejected here instead of panicking.
                                match usize::from_str(users) {
                                    Ok(new_users) => {
                                        self.configuration.test_plan = None;
                                        match self.attack_phase {
                                            AttackPhase::Idle => {
                                                let current_users =
                                                    if !self.test_plan.steps.is_empty() {
                                                        self.test_plan.steps
                                                            [self.test_plan.current]
                                                            .0
                                                    } else {
                                                        self.configuration
                                                            .users
                                                            .unwrap_or_default()
                                                    };
                                                info!(
                                                    "changing users from {current_users:?} to {new_users}"
                                                );
                                                self.configuration.users = Some(new_users);
                                                self.reply_to_controller(
                                                    message,
                                                    ControllerResponseMessage::Bool(true),
                                                );
                                            }
                                            AttackPhase::Increase
                                            | AttackPhase::Decrease
                                            | AttackPhase::Maintain => {
                                                info!(
                                                    "changing users from {} to {new_users}",
                                                    goose_attack_run_state.active_users
                                                );
                                                let elapsed = self.step_elapsed() as usize;
                                                let hatch_rate = if let Some(hatch_rate) =
                                                    self.configuration.hatch_rate.as_ref()
                                                {
                                                    util::get_hatch_rate(Some(
                                                        hatch_rate.to_string(),
                                                    ))
                                                } else {
                                                    util::get_hatch_rate(None)
                                                };
                                                // Time, in milliseconds, to add or remove
                                                // one user at the configured hatch rate.
                                                let ms_hatch_rate = 1.0 / hatch_rate * 1_000.0;
                                                let user_difference =
                                                    (goose_attack_run_state.active_users
                                                        as isize
                                                        - new_users as isize)
                                                        .abs();
                                                let total_time = (ms_hatch_rate
                                                    * user_difference as f32)
                                                    as usize;
                                                // Replace the plan with a single step from
                                                // the current user count to the new one.
                                                self.test_plan.steps = vec![
                                                    (goose_attack_run_state.active_users, elapsed),
                                                    (new_users, total_time),
                                                ];
                                                self.test_plan.current = 0;
                                                if new_users > goose_attack_run_state.active_users {
                                                    self.weighted_users = self
                                                        .weight_scenario_users(
                                                            user_difference as usize,
                                                        )?;
                                                }
                                                self.configuration.users = Some(new_users);
                                                self.advance_test_plan(goose_attack_run_state);
                                                self.reply_to_controller(
                                                    message,
                                                    ControllerResponseMessage::Bool(true),
                                                );
                                            }
                                            _ => {
                                                self.reply_to_controller(
                                                    message,
                                                    ControllerResponseMessage::Bool(false),
                                                );
                                            }
                                        }
                                    }
                                    Err(e) => {
                                        warn!("[controller]: invalid users value {users:?}: {e}");
                                        self.reply_to_controller(
                                            message,
                                            ControllerResponseMessage::Bool(false),
                                        );
                                    }
                                }
                            } else {
                                warn!(
                                    "[controller]: didn't provide users: {:#?}",
                                    &message.request
                                );
                                self.reply_to_controller(
                                    message,
                                    ControllerResponseMessage::Bool(false),
                                );
                            }
                        }
                        ControllerCommand::HatchRate => {
                            if let Some(hatch_rate) = &message.request.value {
                                // hatch_rate and startup_time are alternatives; setting
                                // one resets the other.
                                if !self.configuration.startup_time.is_empty() {
                                    info!(
                                        "resetting startup_time from {} to 0",
                                        self.configuration.startup_time
                                    );
                                    self.configuration.startup_time = "0".to_string();
                                }
                                info!(
                                    "changing hatch_rate from {:?} to {}",
                                    self.configuration.hatch_rate, hatch_rate
                                );
                                self.configuration.hatch_rate = Some(hatch_rate.clone());
                                self.configuration.test_plan = None;
                                self.reply_to_controller(
                                    message,
                                    ControllerResponseMessage::Bool(true),
                                );
                            } else {
                                warn!(
                                    "Controller didn't provide hatch_rate: {:#?}",
                                    &message.request
                                );
                                // Always answer, so the client isn't left waiting on a
                                // dropped reply channel.
                                self.reply_to_controller(
                                    message,
                                    ControllerResponseMessage::Bool(false),
                                );
                            }
                        }
                        ControllerCommand::StartupTime => {
                            if self.attack_phase == AttackPhase::Idle {
                                if let Some(startup_time) = &message.request.value {
                                    if let Some(hatch_rate) = &self.configuration.hatch_rate {
                                        info!("resetting hatch_rate from {hatch_rate} to None");
                                        self.configuration.hatch_rate = None;
                                    }
                                    info!(
                                        "changing startup_rate from {} to {}",
                                        self.configuration.startup_time, startup_time
                                    );
                                    self.configuration.startup_time.clone_from(startup_time);
                                    self.configuration.test_plan = None;
                                    self.reply_to_controller(
                                        message,
                                        ControllerResponseMessage::Bool(true),
                                    );
                                } else {
                                    warn!(
                                        "Controller didn't provide startup_time: {:#?}",
                                        &message.request
                                    );
                                    self.reply_to_controller(
                                        message,
                                        ControllerResponseMessage::Bool(false),
                                    );
                                }
                            } else {
                                self.reply_to_controller(
                                    message,
                                    ControllerResponseMessage::Bool(false),
                                );
                            }
                        }
                        ControllerCommand::RunTime => {
                            if let Some(run_time) = &message.request.value {
                                info!(
                                    "changing run_time from {:?} to {}",
                                    self.configuration.run_time, run_time
                                );
                                self.configuration.run_time.clone_from(run_time);
                                self.configuration.test_plan = None;
                                self.reply_to_controller(
                                    message,
                                    ControllerResponseMessage::Bool(true),
                                );
                            } else {
                                warn!(
                                    "Controller didn't provide run_time: {:#?}",
                                    &message.request
                                );
                                // Always answer, so the client isn't left waiting on a
                                // dropped reply channel.
                                self.reply_to_controller(
                                    message,
                                    ControllerResponseMessage::Bool(false),
                                );
                            }
                        }
                        ControllerCommand::TestPlan => {
                            if let Some(value) = &message.request.value {
                                match value.parse::<TestPlan>() {
                                    Ok(t) => {
                                        // A test plan replaces the individual users,
                                        // hatch_rate, startup_time and run_time settings.
                                        self.configuration.test_plan = Some(t.clone());
                                        self.configuration.users = None;
                                        self.configuration.hatch_rate = None;
                                        self.configuration.startup_time = "0".to_string();
                                        self.configuration.run_time = "0".to_string();
                                        match self.attack_phase {
                                            AttackPhase::Idle => {
                                                self.reply_to_controller(
                                                    message,
                                                    ControllerResponseMessage::Bool(true),
                                                );
                                            }
                                            AttackPhase::Increase
                                            | AttackPhase::Decrease
                                            | AttackPhase::Maintain => {
                                                self.test_plan = t;
                                                self.weighted_users = self.weight_scenario_users(
                                                    self.test_plan.total_users(),
                                                )?;
                                                let elapsed = self.step_elapsed() as usize;
                                                // Start the new plan from the current user count.
                                                self.test_plan.steps.insert(
                                                    0,
                                                    (goose_attack_run_state.active_users, elapsed),
                                                );
                                                self.advance_test_plan(goose_attack_run_state);
                                                self.reply_to_controller(
                                                    message,
                                                    ControllerResponseMessage::Bool(true),
                                                );
                                            }
                                            _ => {
                                                // Reject rather than panic: a remote request
                                                // must not be able to abort the load test.
                                                warn!(
                                                    "Controller can't change test_plan in the current attack phase"
                                                );
                                                self.reply_to_controller(
                                                    message,
                                                    ControllerResponseMessage::Bool(false),
                                                );
                                            }
                                        }
                                    }
                                    Err(e) => {
                                        warn!("Controller provided invalid test_plan: {e}");
                                        self.reply_to_controller(
                                            message,
                                            ControllerResponseMessage::Bool(false),
                                        );
                                    }
                                }
                            } else {
                                warn!(
                                    "Controller didn't provide test_plan: {:#?}",
                                    &message.request
                                );
                                self.reply_to_controller(
                                    message,
                                    ControllerResponseMessage::Bool(false),
                                );
                            }
                        }
                        ControllerCommand::Help | ControllerCommand::Exit => {
                            warn!("Unexpected command: {:?}", &message.request);
                        }
                    }
                }
                Err(e) => {
                    debug!("error receiving message: {e}");
                }
            }
        }
        Ok(())
    }

    /// Sends `response` back over the request's one-shot channel, if it has one.
    ///
    /// A failed send (the receiving end was dropped) is logged and otherwise ignored.
    pub(crate) fn reply_to_controller(
        &mut self,
        request: ControllerRequest,
        response: ControllerResponseMessage,
    ) {
        if let Some(oneshot_tx) = request.response_channel {
            if oneshot_tx
                .send(ControllerResponse {
                    _client_id: request.client_id,
                    response,
                })
                .is_err()
            {
                warn!("failed to send response to controller via one-shot channel")
            }
        }
    }
}
pub(crate) async fn controller_main(
configuration: GooseConfiguration,
channel_tx: flume::Sender<ControllerRequest>,
protocol: ControllerProtocol,
) -> io::Result<()> {
let address = match &protocol {
ControllerProtocol::Telnet => format!(
"{}:{}",
configuration.telnet_host, configuration.telnet_port
),
ControllerProtocol::WebSocket => format!(
"{}:{}",
configuration.websocket_host, configuration.websocket_port
),
};
debug!("[controller]: preparing to bind {protocol:?} to: {address}");
let listener = TcpListener::bind(&address).await?;
info!("[controller]: {protocol:?} listening on: {address}");
let mut thread_id: u32 = 0;
while let Ok((stream, _)) = listener.accept().await {
thread_id += 1;
let peer_address = stream
.peer_addr()
.map_or("UNKNOWN ADDRESS".to_string(), |p| p.to_string());
let controller_state = ControllerState {
thread_id,
peer_address,
channel_tx: channel_tx.clone(),
protocol: protocol.clone(),
};
let _ignored_joinhandle = tokio::spawn(controller_state.accept_connections(stream));
}
Ok(())
}
impl FromStr for ControllerCommand {
    type Err = GooseError;

    /// Identifies which controller command `s` is.
    ///
    /// `s` is tested against every command's regex at once. Exactly one command must
    /// match. No match, or more than one, yields
    /// `GooseError::InvalidControllerCommand`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let candidates: Vec<ControllerCommand> = ControllerCommand::iter().collect();
        let patterns: Vec<&str> = candidates
            .iter()
            .map(|candidate| candidate.details().regex)
            .collect();
        let matcher = RegexSet::new(patterns)
            .expect("ControllerCommand::details().regex returned invalid regex");
        let hits: Vec<usize> = matcher.matches(s).into_iter().collect();
        match hits.as_slice() {
            [] => Err(GooseError::InvalidControllerCommand {
                detail: format!("unrecognized controller command: '{s}'."),
            }),
            [single] => Ok(candidates[*single].clone()),
            several => {
                let matched_commands: Vec<ControllerCommand> = several
                    .iter()
                    .map(|&index| candidates[index].clone())
                    .collect();
                Err(GooseError::InvalidControllerCommand {
                    detail: format!(
                        "matched multiple controller commands: '{s}' ({matched_commands:?})."
                    ),
                })
            }
        }
    }
}
/// Writes `message` to `socket` as-is, best effort: a failed write is logged and
/// otherwise ignored.
async fn write_to_socket_raw(socket: &mut tokio::net::TcpStream, message: &str) {
    let outcome = socket.write_all(message.as_bytes()).await;
    if outcome.is_err() {
        warn!("failed to write data to socket");
    }
}
/// Transport spoken by a controller listener.
#[derive(Clone, Debug)]
pub(crate) enum ControllerProtocol {
    /// Plain-text commands read from a raw TCP connection, one line per command.
    Telnet,
    /// JSON requests and responses exchanged over a WebSocket connection.
    WebSocket,
}
/// A command's entry in the `help` output.
#[derive(Clone, Debug)]
pub(crate) struct ControllerHelp<'a> {
    /// Command name and argument placeholder, left-aligned in the help listing.
    name: &'a str,
    /// Description printed after the name; embedded newlines control spacing.
    description: &'a str,
}
/// Everything needed to recognize a command and present its result.
pub(crate) struct ControllerCommandDetails<'a> {
    /// Help entry for this command.
    help: ControllerHelp<'a>,
    /// Regex that recognizes the command; capture group 2, when present, is its value.
    regex: &'a str,
    /// Turns the reply from the load test into client-facing text: `Ok` for success,
    /// `Err` for failure.
    process_response: Box<dyn Fn(ControllerResponseMessage) -> Result<String, String>>,
}
/// A parsed controller command and its optional, already-validated value.
#[derive(Debug)]
pub(crate) struct ControllerRequestMessage {
    pub command: ControllerCommand,
    /// `None` when the command takes no value or the value failed validation.
    pub value: Option<String>,
}
/// Reply from the load test to a forwarded controller request.
#[derive(Debug)]
pub(crate) enum ControllerResponseMessage {
    /// Whether the requested change or action succeeded.
    Bool(bool),
    /// A snapshot of the current configuration.
    Config(Box<GooseConfiguration>),
    /// A snapshot of the current metrics.
    Metrics(Box<GooseMetrics>),
}
/// A request sent from a controller connection to the load test.
#[derive(Debug)]
pub(crate) struct ControllerRequest {
    /// One-shot channel on which the load test sends its reply, if present.
    pub response_channel: Option<tokio::sync::oneshot::Sender<ControllerResponse>>,
    /// Id of the controller connection that sent the request.
    pub client_id: u32,
    pub request: ControllerRequestMessage,
}
/// The load test's reply, delivered over `ControllerRequest::response_channel`.
#[derive(Debug)]
pub(crate) struct ControllerResponse {
    /// Copied from the request; not read by the controller code in this file.
    pub _client_id: u32,
    pub response: ControllerResponseMessage,
}
/// JSON body of a WebSocket controller request, e.g. `{"request": "stop"}`.
#[derive(Debug, Deserialize, Serialize)]
pub struct ControllerWebSocketRequest {
    /// The command text, parsed the same way as a telnet command line.
    pub request: String,
}
/// JSON body of a WebSocket controller response.
#[derive(Debug, Deserialize, Serialize)]
pub struct ControllerWebSocketResponse {
    /// Human-readable result or error text.
    pub response: String,
    /// `true` if the command succeeded.
    pub success: bool,
}
/// `true` when the controller connection should be closed after a command.
type ControllerExit = bool;
/// Fixed-size buffer that each telnet read fills.
type ControllerTelnetMessage = [u8; 1024];
/// One item read from a WebSocket connection.
type ControllerWebSocketMessage = std::result::Result<
    tokio_tungstenite::tungstenite::Message,
    tokio_tungstenite::tungstenite::Error,
>;
/// Write half of a split WebSocket connection.
type ControllerWebSocketSender = futures::stream::SplitSink<
    tokio_tungstenite::WebSocketStream<tokio::net::TcpStream>,
    tokio_tungstenite::tungstenite::Message,
>;
/// State for a single controller client connection.
pub(crate) struct ControllerState {
    /// Connection id, incremented for each accepted client; also used as `client_id`.
    thread_id: u32,
    /// Client address, or "UNKNOWN ADDRESS" when it could not be determined.
    peer_address: String,
    /// Sends requests to the load test.
    channel_tx: flume::Sender<ControllerRequest>,
    protocol: ControllerProtocol,
}
impl ControllerState {
    /// Serves one controller client until it disconnects, exits, or shuts down the
    /// load test.
    ///
    /// Telnet clients get a `goose> ` prompt and are read in 1024-byte chunks; only
    /// the first line of each chunk is treated as a command. WebSocket clients are
    /// upgraded with a handshake and then send one JSON request per text message.
    async fn accept_connections(self, mut socket: tokio::net::TcpStream) {
        info!(
            "{:?} client [{}] connected from {}",
            self.protocol, self.thread_id, self.peer_address
        );
        match self.protocol {
            ControllerProtocol::Telnet => {
                let mut buf: ControllerTelnetMessage = [0; 1024];
                write_to_socket_raw(&mut socket, "goose> ").await;
                loop {
                    let n = match socket.read(&mut buf).await {
                        Ok(data) => data,
                        Err(_) => {
                            info!(
                                "Telnet client [{}] disconnected from {}",
                                self.thread_id, self.peer_address
                            );
                            break;
                        }
                    };
                    if n == 0 {
                        info!(
                            "Telnet client [{}] disconnected from {}",
                            self.thread_id, self.peer_address
                        );
                        break;
                    }
                    // `buf` is reused across reads: zero everything past this read so
                    // bytes left over from an earlier, longer read cannot leak into
                    // this command.
                    buf[n..].fill(0);
                    if let Ok(command_string) = self.get_command_string(buf).await {
                        // Input without a trailing newline keeps the zero padding on
                        // its first line, and `trim` alone does not remove NUL bytes.
                        let command_string = command_string.trim_end_matches('\0').trim();
                        if let Ok(request_message) = self.get_match(command_string).await {
                            if self.execute_command(&mut socket, request_message).await {
                                info!(
                                    "Telnet client [{}] disconnected from {}",
                                    self.thread_id, self.peer_address
                                );
                                break;
                            }
                        } else {
                            self.write_to_socket(
                                &mut socket,
                                Err("unrecognized command".to_string()),
                            )
                            .await;
                        }
                    } else {
                        info!(
                            "Telnet client [{}] disconnected from {}",
                            self.thread_id, self.peer_address
                        );
                        break;
                    }
                }
            }
            ControllerProtocol::WebSocket => {
                let stream = match tokio_tungstenite::accept_async(socket).await {
                    Ok(s) => s,
                    Err(e) => {
                        info!("invalid WebSocket handshake: {e}");
                        return;
                    }
                };
                let (mut ws_sender, mut ws_receiver) = stream.split();
                loop {
                    let data = match ws_receiver.next().await {
                        // Control frames carry no command and must not get an
                        // "invalid json" reply. After a close frame, keep reading so
                        // the close handshake can finish; the stream is then expected
                        // to end (`None`).
                        Some(Ok(Message::Ping(_)))
                        | Some(Ok(Message::Pong(_)))
                        | Some(Ok(Message::Close(_))) => continue,
                        Some(Err(e)) => {
                            info!(
                                "WebSocket client [{}] disconnected from {}: {e}",
                                self.thread_id, self.peer_address
                            );
                            break;
                        }
                        Some(d) => d,
                        None => {
                            info!(
                                "WebSocket client [{}] disconnected from {}",
                                self.thread_id, self.peer_address
                            );
                            break;
                        }
                    };
                    if let Ok(command_string) = self.get_command_string(data).await {
                        if let Ok(request_message) = self.get_match(command_string.trim()).await {
                            if self.execute_command(&mut ws_sender, request_message).await {
                                info!(
                                    "WebSocket client [{}] disconnected from {}",
                                    self.thread_id, self.peer_address
                                );
                                break;
                            }
                        } else {
                            self.write_to_socket(
                                &mut ws_sender,
                                Err(
                                    "unrecognized command, see Goose book https://book.goose.rs/controller/websocket.html"
                                        .to_string(),
                                ),
                            )
                            .await;
                        }
                    } else {
                        self.write_to_socket(
                            &mut ws_sender,
                            Err(
                                "invalid json, see Goose book https://book.goose.rs/controller/websocket.html"
                                    .to_string(),
                            ),
                        )
                        .await;
                    }
                }
            }
        }
    }

    /// Parses `command_string` into a command and its validated value.
    ///
    /// # Errors
    ///
    /// Returns `GooseError::InvalidControllerCommand` when the string matches no
    /// command or more than one.
    async fn get_match(
        &self,
        command_string: &str,
    ) -> Result<ControllerRequestMessage, GooseError> {
        let command: ControllerCommand = ControllerCommand::from_str(command_string)?;
        let value: Option<String> = command.get_value(command_string);
        Ok(ControllerRequestMessage { command, value })
    }

    /// Returns the reply for commands the connection answers itself (`help`, `exit`),
    /// or `None` if the command must be forwarded to the load test.
    fn process_local_command(&self, request_message: &ControllerRequestMessage) -> Option<String> {
        match request_message.command {
            ControllerCommand::Help => Some(ControllerCommand::display_help()),
            ControllerCommand::Exit => Some("goodbye!".to_string()),
            _ => None,
        }
    }

    /// Forwards `request` to the load test and waits for its reply.
    ///
    /// # Errors
    ///
    /// Returns a message if the request cannot be sent on `channel_tx`, or if the
    /// reply channel is dropped without an answer.
    async fn process_command(
        &self,
        request: ControllerRequestMessage,
    ) -> Result<ControllerResponseMessage, String> {
        let (response_tx, response_rx): (
            tokio::sync::oneshot::Sender<ControllerResponse>,
            tokio::sync::oneshot::Receiver<ControllerResponse>,
        ) = tokio::sync::oneshot::channel();
        if self
            .channel_tx
            .try_send(ControllerRequest {
                response_channel: Some(response_tx),
                client_id: self.thread_id,
                request,
            })
            .is_err()
        {
            return Err("parent process has closed the controller channel".to_string());
        }
        match response_rx.await {
            Ok(value) => Ok(value.response),
            Err(e) => Err(format!("one-shot channel dropped without reply: {e}")),
        }
    }
}
/// Extracts the command text from a raw protocol message of type `T`.
#[async_trait]
trait Controller<T> {
    /// Returns the command string contained in `raw_value`, or a client-facing error.
    async fn get_command_string(&self, raw_value: T) -> Result<String, String>;
}
#[async_trait]
impl Controller<ControllerTelnetMessage> for ControllerState {
    /// Decodes a telnet read buffer as UTF-8 and returns its first line.
    ///
    /// An empty string is returned when there is no line. A decoding failure is
    /// logged at info level and returned as the error text.
    async fn get_command_string(
        &self,
        raw_value: ControllerTelnetMessage,
    ) -> Result<String, String> {
        str::from_utf8(&raw_value)
            .map(|text| text.lines().next().unwrap_or_default().to_owned())
            .map_err(|decode_error| {
                let error =
                    format!("ignoring unexpected input from telnet controller: {decode_error}");
                info!("{error}");
                error
            })
    }
}
#[async_trait]
impl Controller<ControllerWebSocketMessage> for ControllerState {
    /// Extracts the `request` field from a JSON text message.
    ///
    /// Read errors, non-text messages, undecodable text, and invalid JSON are each
    /// reported with their own error string.
    async fn get_command_string(
        &self,
        raw_value: ControllerWebSocketMessage,
    ) -> Result<String, String> {
        let message = match raw_value {
            Ok(message) => message,
            Err(_) => return Err("WebSocket handshake error".to_string()),
        };
        if !message.is_text() {
            return Err("unsupported format, requests must be sent as text".to_string());
        }
        let text = match message.into_text() {
            Ok(text) => text,
            Err(_) => return Err("unsupported string format".to_string()),
        };
        debug!("websocket request: {:?}", text.trim());
        serde_json::from_str::<ControllerWebSocketRequest>(&text)
            .map(|parsed| parsed.request)
            .map_err(|_| {
                "invalid json, see Goose book https://book.goose.rs/controller/websocket.html"
                    .to_string()
            })
    }
}
/// Executes controller commands and writes responses over a transport of type `T`.
#[async_trait]
trait ControllerExecuteCommand<T> {
    /// Runs `request_message`, writes the result to `socket`, and returns `true`
    /// when the connection should be closed.
    async fn execute_command(
        &self,
        socket: &mut T,
        request_message: ControllerRequestMessage,
    ) -> ControllerExit;
    /// Writes a success (`Ok`) or failure (`Err`) message to `socket`.
    async fn write_to_socket(&self, socket: &mut T, response_message: Result<String, String>);
}
#[async_trait]
impl ControllerExecuteCommand<tokio::net::TcpStream> for ControllerState {
async fn execute_command(
&self,
socket: &mut tokio::net::TcpStream,
request_message: ControllerRequestMessage,
) -> ControllerExit {
if let Some(message) = self.process_local_command(&request_message) {
self.write_to_socket(socket, Ok(message)).await;
return request_message.command == ControllerCommand::Exit;
}
let command = request_message.command.clone();
let response = match self.process_command(request_message).await {
Ok(r) => r,
Err(e) => {
self.write_to_socket(socket, Err(e)).await;
return true;
}
};
let exit_controller = command == ControllerCommand::Shutdown;
let processed_response = (command.details().process_response)(response);
self.write_to_socket(socket, processed_response).await;
exit_controller
}
async fn write_to_socket(
&self,
socket: &mut tokio::net::TcpStream,
message: Result<String, String>,
) {
let response_message = match message {
Ok(m) => m,
Err(e) => e,
};
if socket
.write_all([&response_message, "\ngoose> "].concat().as_bytes())
.await
.is_err()
{
warn!("failed to write data to socker");
};
}
}
#[async_trait]
impl ControllerExecuteCommand<ControllerWebSocketSender> for ControllerState {
    /// Runs one WebSocket command and reports whether the connection should close.
    ///
    /// `help` and `exit` are answered locally. Every other command is forwarded to
    /// the load test. `config` and `metrics` replies are formatted with their JSON
    /// variants. A close frame follows `exit` and `shutdown`. Returns `true` after
    /// `exit`, after `shutdown`, or when forwarding fails.
    async fn execute_command(
        &self,
        socket: &mut ControllerWebSocketSender,
        request_message: ControllerRequestMessage,
    ) -> ControllerExit {
        if let Some(local_reply) = self.process_local_command(&request_message) {
            self.write_to_socket(socket, Ok(local_reply)).await;
            let is_exit = request_message.command == ControllerCommand::Exit;
            if is_exit {
                send_websocket_close(socket, "exit").await;
            }
            return is_exit;
        }
        // WebSocket clients always receive the JSON formatting of config and metrics.
        let formatter = match &request_message.command {
            ControllerCommand::Config => ControllerCommand::ConfigJson,
            ControllerCommand::Metrics => ControllerCommand::MetricsJson,
            other => other.clone(),
        };
        let reply = match self.process_command(request_message).await {
            Ok(reply) => reply,
            Err(failure) => {
                self.write_to_socket(socket, Err(failure)).await;
                return true;
            }
        };
        let is_shutdown = formatter == ControllerCommand::Shutdown;
        let rendered = (formatter.details().process_response)(reply);
        self.write_to_socket(socket, rendered).await;
        if is_shutdown {
            send_websocket_close(socket, "shutdown").await;
        }
        is_shutdown
    }

    /// Sends the result as a `ControllerWebSocketResponse` JSON text message.
    ///
    /// `Ok` text is sent with `success: true`, `Err` text with `success: false`.
    /// Encoding failures are logged at warn level; send failures at info level.
    async fn write_to_socket(
        &self,
        socket: &mut ControllerWebSocketSender,
        response_result: Result<String, String>,
    ) {
        let (response, success) = match response_result {
            Ok(text) => (text, true),
            Err(text) => (text, false),
        };
        let encoded = match serde_json::to_string(&ControllerWebSocketResponse { response, success }) {
            Ok(json) => json,
            Err(e) => {
                warn!("failed to json encode response: {e}");
                return;
            }
        };
        if let Err(e) = socket.send(Message::Text(encoded.into())).await {
            info!("failed to write data to websocket: {e}");
        }
    }
}

/// Sends a normal-closure close frame carrying `reason`; a failed send is logged.
async fn send_websocket_close(socket: &mut ControllerWebSocketSender, reason: &'static str) {
    let frame = tokio_tungstenite::tungstenite::protocol::CloseFrame {
        code: tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode::Normal,
        reason: reason.into(),
    };
    if socket.send(Message::Close(Some(frame))).await.is_err() {
        warn!("failed to write data to stream");
    }
}