use crate::error::{Error, Result};
use crate::wire::{read_frame, write_frame};
use falcorn_proto::control::{
ActionKind, ActionPayload, ActionRequest, ActionRequestPayload, ActionResponse, ActionStatus,
CONTROL_FRAME_MAGIC, CONTROL_PROTOCOL_VERSION, ControlErrorFrame, ControlFrameType,
ControlPingFrame, ControlPongFrame, ControlSubscribeAck, ControlSubscribeRequest,
ReloadConfigRequest, RestartWorkerRequest, ScaleToRequest, ShutdownRequest, StatusSnapshot,
WorkerSnapshot,
};
use std::os::unix::net::UnixStream;
use std::time::Duration;
/// Connection settings for a [`ControlClient`].
///
/// Create one with [`ControlClientBuilder::new`], refine it with the chained
/// setters, then open the connection with [`ControlClientBuilder::connect`].
#[derive(Debug, Clone)]
pub struct ControlClientBuilder {
    /// Path of the control Unix socket to connect to.
    socket: String,
    /// Token placed in the subscribe request, if any.
    auth_token: Option<String>,
    /// Client name placed in the subscribe request, if any.
    client_name: Option<String>,
    /// Read timeout applied to the stream; `None` leaves reads blocking.
    read_timeout: Option<Duration>,
    /// Write timeout applied to the stream; `None` leaves writes blocking.
    write_timeout: Option<Duration>,
}
impl ControlClientBuilder {
pub fn new(socket: impl Into<String>) -> Self {
Self {
socket: socket.into(),
auth_token: None,
client_name: None,
read_timeout: None,
write_timeout: None,
}
}
pub fn auth_token(mut self, token: impl Into<String>) -> Self {
self.auth_token = Some(token.into());
self
}
pub fn client_name(mut self, name: impl Into<String>) -> Self {
self.client_name = Some(name.into());
self
}
pub fn read_timeout(mut self, timeout: Duration) -> Self {
self.read_timeout = Some(timeout);
self
}
pub fn write_timeout(mut self, timeout: Duration) -> Self {
self.write_timeout = Some(timeout);
self
}
pub fn connect(self) -> Result<ControlClient> {
let mut stream = UnixStream::connect(&self.socket)?;
stream.set_read_timeout(self.read_timeout)?;
stream.set_write_timeout(self.write_timeout)?;
let subscribe = ControlSubscribeRequest {
auth_token: self.auth_token,
client_name: self.client_name,
};
let payload = bincode::serialize(&subscribe)?;
write_frame(
&mut stream,
CONTROL_FRAME_MAGIC,
CONTROL_PROTOCOL_VERSION,
ControlFrameType::Subscribe as u8,
&payload,
)?;
let (ft, payload) = read_frame(&mut stream, CONTROL_FRAME_MAGIC, CONTROL_PROTOCOL_VERSION)?;
match ControlFrameType::from_u8(ft) {
Some(ControlFrameType::Ack) => {
let _ack: ControlSubscribeAck = bincode::deserialize(&payload)?;
}
Some(ControlFrameType::Error) => {
let err: ControlErrorFrame = bincode::deserialize(&payload)?;
return Err(Error::Remote {
code: err.code,
message: err.message,
});
}
_ => {
return Err(Error::Protocol("unexpected first frame".to_string()));
}
}
Ok(ControlClient { stream, next_id: 1 })
}
}
/// A control-channel client that has completed the subscribe handshake.
///
/// Obtained from [`ControlClientBuilder::connect`]. Each action request sent
/// over the stream is tagged with an id taken from `next_id`.
#[derive(Debug)]
pub struct ControlClient {
    /// Stream to the control socket, already past the subscribe handshake.
    stream: UnixStream,
    /// Id to assign to the next action request; starts at 1.
    next_id: u64,
}
impl ControlClient {
pub fn builder(socket: impl Into<String>) -> ControlClientBuilder {
ControlClientBuilder::new(socket)
}
pub fn send_action(
&mut self,
action: ActionKind,
payload: Option<ActionRequestPayload>,
) -> Result<ActionResponse> {
let request_id = self.next_id;
self.next_id = self.next_id.saturating_add(1);
let request = ActionRequest {
id: request_id,
action,
payload,
};
let payload = bincode::serialize(&request)?;
write_frame(
&mut self.stream,
CONTROL_FRAME_MAGIC,
CONTROL_PROTOCOL_VERSION,
ControlFrameType::ActionRequest as u8,
&payload,
)?;
loop {
let (ft, payload) = read_frame(
&mut self.stream,
CONTROL_FRAME_MAGIC,
CONTROL_PROTOCOL_VERSION,
)?;
match ControlFrameType::from_u8(ft) {
Some(ControlFrameType::ActionResponse) => {
let response: ActionResponse = bincode::deserialize(&payload)?;
if response.id != request_id {
return Err(Error::Protocol("mismatched action response id".to_string()));
}
return Ok(response);
}
Some(ControlFrameType::Error) => {
let err: ControlErrorFrame = bincode::deserialize(&payload)?;
return Err(Error::Remote {
code: err.code,
message: err.message,
});
}
Some(ControlFrameType::Ping) => {
let ping: ControlPingFrame = bincode::deserialize(&payload)?;
let pong = ControlPongFrame {
ts_millis: ping.ts_millis,
};
let payload = bincode::serialize(&pong)?;
write_frame(
&mut self.stream,
CONTROL_FRAME_MAGIC,
CONTROL_PROTOCOL_VERSION,
ControlFrameType::Pong as u8,
&payload,
)?;
}
Some(ControlFrameType::Pong) => {}
_ => {
return Err(Error::Protocol("unexpected frame".to_string()));
}
}
}
}
pub fn get_status(&mut self) -> Result<StatusSnapshot> {
let response = self.send_action(ActionKind::GetStatus, None)?;
self.expect_status(response)
}
pub fn get_workers(&mut self) -> Result<Vec<WorkerSnapshot>> {
let response = self.send_action(ActionKind::GetWorkers, None)?;
self.expect_workers(response)
}
pub fn show_config(&mut self) -> Result<String> {
let response = self.send_action(ActionKind::ShowConfig, None)?;
self.expect_config(response)
}
pub fn scale_to(&mut self, workers: usize) -> Result<String> {
let payload = ActionRequestPayload::ScaleTo(ScaleToRequest { workers });
let response = self.send_action(ActionKind::ScaleTo, Some(payload))?;
self.expect_message(response, "scale")
}
pub fn restart_worker(&mut self, id: Option<u32>, graceful: bool) -> Result<String> {
let payload = ActionRequestPayload::RestartWorker(RestartWorkerRequest { id, graceful });
let response = self.send_action(ActionKind::RestartWorker, Some(payload))?;
self.expect_message(response, "restart")
}
pub fn shutdown(&mut self, graceful: bool) -> Result<String> {
let payload = ActionRequestPayload::Shutdown(ShutdownRequest { graceful });
let response = self.send_action(ActionKind::Shutdown, Some(payload))?;
self.expect_message(response, "shutdown")
}
pub fn reload_config(&mut self, path: Option<String>, rolling: bool) -> Result<String> {
let payload = ActionRequestPayload::ReloadConfig(ReloadConfigRequest { path, rolling });
let response = self.send_action(ActionKind::ReloadConfig, Some(payload))?;
self.expect_message(response, "reload")
}
fn ensure_ok(&self, response: &ActionResponse) -> Result<()> {
if let ActionStatus::Error = response.status {
if let Some(err) = &response.error {
return Err(Error::Remote {
code: err.code.clone(),
message: err.message.clone(),
});
}
return Err(Error::Protocol(
"action failed without error details".to_string(),
));
}
Ok(())
}
fn expect_status(&self, response: ActionResponse) -> Result<StatusSnapshot> {
self.ensure_ok(&response)?;
match response.payload {
Some(ActionPayload::Status(value)) => Ok(value),
Some(_) => Err(Error::Protocol("unexpected payload for status".to_string())),
None => Err(Error::Protocol(
"missing payload for status response".to_string(),
)),
}
}
fn expect_workers(&self, response: ActionResponse) -> Result<Vec<WorkerSnapshot>> {
self.ensure_ok(&response)?;
match response.payload {
Some(ActionPayload::Workers(value)) => Ok(value),
Some(_) => Err(Error::Protocol(
"unexpected payload for workers".to_string(),
)),
None => Err(Error::Protocol(
"missing payload for workers response".to_string(),
)),
}
}
fn expect_config(&self, response: ActionResponse) -> Result<String> {
self.ensure_ok(&response)?;
match response.payload {
Some(ActionPayload::Config(value)) => Ok(value),
Some(_) => Err(Error::Protocol("unexpected payload for config".to_string())),
None => Err(Error::Protocol(
"missing payload for config response".to_string(),
)),
}
}
fn expect_message(&self, response: ActionResponse, action: &str) -> Result<String> {
self.ensure_ok(&response)?;
match response.payload {
Some(ActionPayload::Message(value)) => Ok(value),
Some(_) => Err(Error::Protocol(format!(
"unexpected payload for {} response",
action
))),
None => Err(Error::Protocol(format!(
"missing payload for {} response",
action
))),
}
}
}