falcorn-sdk 0.1.0

Falcorn SDK for interacting with the server IPC and plugins.
Documentation
use crate::error::{Error, Result};
use crate::wire::{read_frame, write_frame};
use falcorn_proto::control::{
    ActionKind, ActionPayload, ActionRequest, ActionRequestPayload, ActionResponse, ActionStatus,
    CONTROL_FRAME_MAGIC, CONTROL_PROTOCOL_VERSION, ControlErrorFrame, ControlFrameType,
    ControlPingFrame, ControlPongFrame, ControlSubscribeAck, ControlSubscribeRequest,
    ReloadConfigRequest, RestartWorkerRequest, ScaleToRequest, ShutdownRequest, StatusSnapshot,
    WorkerSnapshot,
};
use std::os::unix::net::UnixStream;
use std::time::Duration;

/// Collects the settings used to open a [`ControlClient`] connection.
///
/// Start with [`ControlClientBuilder::new`], chain any of the optional setters, and finish with
/// [`ControlClientBuilder::connect`].
#[derive(Clone, Debug)]
pub struct ControlClientBuilder {
    /// Path passed to `UnixStream::connect` when connecting.
    socket: String,
    /// Optional token placed in the subscribe request.
    auth_token: Option<String>,
    /// Optional client name placed in the subscribe request.
    client_name: Option<String>,
    /// Read timeout applied to the stream; `None` means reads block indefinitely.
    read_timeout: Option<Duration>,
    /// Write timeout applied to the stream; `None` means writes block indefinitely.
    write_timeout: Option<Duration>,
}

impl ControlClientBuilder {
    pub fn new(socket: impl Into<String>) -> Self {
        Self {
            socket: socket.into(),
            auth_token: None,
            client_name: None,
            read_timeout: None,
            write_timeout: None,
        }
    }

    pub fn auth_token(mut self, token: impl Into<String>) -> Self {
        self.auth_token = Some(token.into());
        self
    }

    pub fn client_name(mut self, name: impl Into<String>) -> Self {
        self.client_name = Some(name.into());
        self
    }

    pub fn read_timeout(mut self, timeout: Duration) -> Self {
        self.read_timeout = Some(timeout);
        self
    }

    pub fn write_timeout(mut self, timeout: Duration) -> Self {
        self.write_timeout = Some(timeout);
        self
    }

    pub fn connect(self) -> Result<ControlClient> {
        let mut stream = UnixStream::connect(&self.socket)?;
        stream.set_read_timeout(self.read_timeout)?;
        stream.set_write_timeout(self.write_timeout)?;

        let subscribe = ControlSubscribeRequest {
            auth_token: self.auth_token,
            client_name: self.client_name,
        };

        let payload = bincode::serialize(&subscribe)?;
        write_frame(
            &mut stream,
            CONTROL_FRAME_MAGIC,
            CONTROL_PROTOCOL_VERSION,
            ControlFrameType::Subscribe as u8,
            &payload,
        )?;

        let (ft, payload) = read_frame(&mut stream, CONTROL_FRAME_MAGIC, CONTROL_PROTOCOL_VERSION)?;
        match ControlFrameType::from_u8(ft) {
            Some(ControlFrameType::Ack) => {
                let _ack: ControlSubscribeAck = bincode::deserialize(&payload)?;
            }
            Some(ControlFrameType::Error) => {
                let err: ControlErrorFrame = bincode::deserialize(&payload)?;
                return Err(Error::Remote {
                    code: err.code,
                    message: err.message,
                });
            }
            _ => {
                return Err(Error::Protocol("unexpected first frame".to_string()));
            }
        }

        Ok(ControlClient { stream, next_id: 1 })
    }
}

/// Synchronous client that exchanges control frames over a connected Unix stream.
///
/// Instances are produced by [`ControlClientBuilder::connect`] once the subscribe handshake has
/// succeeded.
pub struct ControlClient {
    /// Connected stream used for every frame read and write.
    stream: UnixStream,
    /// Id assigned to the next action request; initialised to 1 on connect.
    next_id: u64,
}

impl ControlClient {
    pub fn builder(socket: impl Into<String>) -> ControlClientBuilder {
        ControlClientBuilder::new(socket)
    }

    pub fn send_action(
        &mut self,
        action: ActionKind,
        payload: Option<ActionRequestPayload>,
    ) -> Result<ActionResponse> {
        let request_id = self.next_id;
        self.next_id = self.next_id.saturating_add(1);

        let request = ActionRequest {
            id: request_id,
            action,
            payload,
        };

        let payload = bincode::serialize(&request)?;
        write_frame(
            &mut self.stream,
            CONTROL_FRAME_MAGIC,
            CONTROL_PROTOCOL_VERSION,
            ControlFrameType::ActionRequest as u8,
            &payload,
        )?;

        loop {
            let (ft, payload) = read_frame(
                &mut self.stream,
                CONTROL_FRAME_MAGIC,
                CONTROL_PROTOCOL_VERSION,
            )?;
            match ControlFrameType::from_u8(ft) {
                Some(ControlFrameType::ActionResponse) => {
                    let response: ActionResponse = bincode::deserialize(&payload)?;
                    if response.id != request_id {
                        return Err(Error::Protocol("mismatched action response id".to_string()));
                    }
                    return Ok(response);
                }
                Some(ControlFrameType::Error) => {
                    let err: ControlErrorFrame = bincode::deserialize(&payload)?;
                    return Err(Error::Remote {
                        code: err.code,
                        message: err.message,
                    });
                }
                Some(ControlFrameType::Ping) => {
                    let ping: ControlPingFrame = bincode::deserialize(&payload)?;
                    let pong = ControlPongFrame {
                        ts_millis: ping.ts_millis,
                    };
                    let payload = bincode::serialize(&pong)?;
                    write_frame(
                        &mut self.stream,
                        CONTROL_FRAME_MAGIC,
                        CONTROL_PROTOCOL_VERSION,
                        ControlFrameType::Pong as u8,
                        &payload,
                    )?;
                }
                Some(ControlFrameType::Pong) => {}
                _ => {
                    return Err(Error::Protocol("unexpected frame".to_string()));
                }
            }
        }
    }

    pub fn get_status(&mut self) -> Result<StatusSnapshot> {
        let response = self.send_action(ActionKind::GetStatus, None)?;
        self.expect_status(response)
    }

    pub fn get_workers(&mut self) -> Result<Vec<WorkerSnapshot>> {
        let response = self.send_action(ActionKind::GetWorkers, None)?;
        self.expect_workers(response)
    }

    pub fn show_config(&mut self) -> Result<String> {
        let response = self.send_action(ActionKind::ShowConfig, None)?;
        self.expect_config(response)
    }

    pub fn scale_to(&mut self, workers: usize) -> Result<String> {
        let payload = ActionRequestPayload::ScaleTo(ScaleToRequest { workers });
        let response = self.send_action(ActionKind::ScaleTo, Some(payload))?;
        self.expect_message(response, "scale")
    }

    pub fn restart_worker(&mut self, id: Option<u32>, graceful: bool) -> Result<String> {
        let payload = ActionRequestPayload::RestartWorker(RestartWorkerRequest { id, graceful });
        let response = self.send_action(ActionKind::RestartWorker, Some(payload))?;
        self.expect_message(response, "restart")
    }

    pub fn shutdown(&mut self, graceful: bool) -> Result<String> {
        let payload = ActionRequestPayload::Shutdown(ShutdownRequest { graceful });
        let response = self.send_action(ActionKind::Shutdown, Some(payload))?;
        self.expect_message(response, "shutdown")
    }

    pub fn reload_config(&mut self, path: Option<String>, rolling: bool) -> Result<String> {
        let payload = ActionRequestPayload::ReloadConfig(ReloadConfigRequest { path, rolling });
        let response = self.send_action(ActionKind::ReloadConfig, Some(payload))?;
        self.expect_message(response, "reload")
    }

    fn ensure_ok(&self, response: &ActionResponse) -> Result<()> {
        if let ActionStatus::Error = response.status {
            if let Some(err) = &response.error {
                return Err(Error::Remote {
                    code: err.code.clone(),
                    message: err.message.clone(),
                });
            }
            return Err(Error::Protocol(
                "action failed without error details".to_string(),
            ));
        }
        Ok(())
    }

    fn expect_status(&self, response: ActionResponse) -> Result<StatusSnapshot> {
        self.ensure_ok(&response)?;
        match response.payload {
            Some(ActionPayload::Status(value)) => Ok(value),
            Some(_) => Err(Error::Protocol("unexpected payload for status".to_string())),
            None => Err(Error::Protocol(
                "missing payload for status response".to_string(),
            )),
        }
    }

    fn expect_workers(&self, response: ActionResponse) -> Result<Vec<WorkerSnapshot>> {
        self.ensure_ok(&response)?;
        match response.payload {
            Some(ActionPayload::Workers(value)) => Ok(value),
            Some(_) => Err(Error::Protocol(
                "unexpected payload for workers".to_string(),
            )),
            None => Err(Error::Protocol(
                "missing payload for workers response".to_string(),
            )),
        }
    }

    fn expect_config(&self, response: ActionResponse) -> Result<String> {
        self.ensure_ok(&response)?;
        match response.payload {
            Some(ActionPayload::Config(value)) => Ok(value),
            Some(_) => Err(Error::Protocol("unexpected payload for config".to_string())),
            None => Err(Error::Protocol(
                "missing payload for config response".to_string(),
            )),
        }
    }

    fn expect_message(&self, response: ActionResponse, action: &str) -> Result<String> {
        self.ensure_ok(&response)?;
        match response.payload {
            Some(ActionPayload::Message(value)) => Ok(value),
            Some(_) => Err(Error::Protocol(format!(
                "unexpected payload for {} response",
                action
            ))),
            None => Err(Error::Protocol(format!(
                "missing payload for {} response",
                action
            ))),
        }
    }
}