Skip to main content

falcorn_sdk/
control.rs

1use crate::error::{Error, Result};
2use crate::wire::{read_frame, write_frame};
3use falcorn_proto::control::{
4    ActionKind, ActionPayload, ActionRequest, ActionRequestPayload, ActionResponse, ActionStatus,
5    CONTROL_FRAME_MAGIC, CONTROL_PROTOCOL_VERSION, ControlErrorFrame, ControlFrameType,
6    ControlPingFrame, ControlPongFrame, ControlSubscribeAck, ControlSubscribeRequest,
7    ReloadConfigRequest, RestartWorkerRequest, ScaleToRequest, ShutdownRequest, StatusSnapshot,
8    WorkerSnapshot,
9};
10use std::os::unix::net::UnixStream;
11use std::time::Duration;
12
/// Configuration for opening a [`ControlClient`] connection over a Unix domain socket.
///
/// All options except the socket path are optional; nothing is sent until the
/// builder is consumed by `connect`.
///
/// `Debug` is implemented by hand (below) so that formatting a builder with
/// `{:?}` — e.g. in logs — never reveals the auth token.
#[derive(Clone)]
pub struct ControlClientBuilder {
    /// Filesystem path of the control socket.
    socket: String,
    /// Secret sent in the subscribe handshake, if any.
    auth_token: Option<String>,
    /// Client identifier sent in the subscribe handshake, if any.
    client_name: Option<String>,
    /// Read timeout applied to the socket; `None` means no timeout.
    read_timeout: Option<Duration>,
    /// Write timeout applied to the socket; `None` means no timeout.
    write_timeout: Option<Duration>,
}

// Hand-written instead of derived: the derived impl would print the auth token verbatim.
impl std::fmt::Debug for ControlClientBuilder {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ControlClientBuilder")
            .field("socket", &self.socket)
            // Only reveal whether a token is configured, never its value.
            .field("auth_token", &self.auth_token.as_ref().map(|_| "<redacted>"))
            .field("client_name", &self.client_name)
            .field("read_timeout", &self.read_timeout)
            .field("write_timeout", &self.write_timeout)
            .finish()
    }
}
21
22impl ControlClientBuilder {
23    pub fn new(socket: impl Into<String>) -> Self {
24        Self {
25            socket: socket.into(),
26            auth_token: None,
27            client_name: None,
28            read_timeout: None,
29            write_timeout: None,
30        }
31    }
32
33    pub fn auth_token(mut self, token: impl Into<String>) -> Self {
34        self.auth_token = Some(token.into());
35        self
36    }
37
38    pub fn client_name(mut self, name: impl Into<String>) -> Self {
39        self.client_name = Some(name.into());
40        self
41    }
42
43    pub fn read_timeout(mut self, timeout: Duration) -> Self {
44        self.read_timeout = Some(timeout);
45        self
46    }
47
48    pub fn write_timeout(mut self, timeout: Duration) -> Self {
49        self.write_timeout = Some(timeout);
50        self
51    }
52
53    pub fn connect(self) -> Result<ControlClient> {
54        let mut stream = UnixStream::connect(&self.socket)?;
55        stream.set_read_timeout(self.read_timeout)?;
56        stream.set_write_timeout(self.write_timeout)?;
57
58        let subscribe = ControlSubscribeRequest {
59            auth_token: self.auth_token,
60            client_name: self.client_name,
61        };
62
63        let payload = bincode::serialize(&subscribe)?;
64        write_frame(
65            &mut stream,
66            CONTROL_FRAME_MAGIC,
67            CONTROL_PROTOCOL_VERSION,
68            ControlFrameType::Subscribe as u8,
69            &payload,
70        )?;
71
72        let (ft, payload) = read_frame(&mut stream, CONTROL_FRAME_MAGIC, CONTROL_PROTOCOL_VERSION)?;
73        match ControlFrameType::from_u8(ft) {
74            Some(ControlFrameType::Ack) => {
75                let _ack: ControlSubscribeAck = bincode::deserialize(&payload)?;
76            }
77            Some(ControlFrameType::Error) => {
78                let err: ControlErrorFrame = bincode::deserialize(&payload)?;
79                return Err(Error::Remote {
80                    code: err.code,
81                    message: err.message,
82                });
83            }
84            _ => {
85                return Err(Error::Protocol("unexpected first frame".to_string()));
86            }
87        }
88
89        Ok(ControlClient { stream, next_id: 1 })
90    }
91}
92
/// A control-channel client whose subscribe handshake has already succeeded.
///
/// Created via [`ControlClient::builder`] and [`ControlClientBuilder::connect`].
/// Requests go out one at a time over the owned socket; every request method
/// takes `&mut self`.
#[derive(Debug)]
pub struct ControlClient {
    /// Connected socket; the subscribe handshake has already completed on it.
    stream: UnixStream,
    /// Id assigned to the next action request (the builder starts it at 1).
    next_id: u64,
}
97
98impl ControlClient {
99    pub fn builder(socket: impl Into<String>) -> ControlClientBuilder {
100        ControlClientBuilder::new(socket)
101    }
102
103    pub fn send_action(
104        &mut self,
105        action: ActionKind,
106        payload: Option<ActionRequestPayload>,
107    ) -> Result<ActionResponse> {
108        let request_id = self.next_id;
109        self.next_id = self.next_id.saturating_add(1);
110
111        let request = ActionRequest {
112            id: request_id,
113            action,
114            payload,
115        };
116
117        let payload = bincode::serialize(&request)?;
118        write_frame(
119            &mut self.stream,
120            CONTROL_FRAME_MAGIC,
121            CONTROL_PROTOCOL_VERSION,
122            ControlFrameType::ActionRequest as u8,
123            &payload,
124        )?;
125
126        loop {
127            let (ft, payload) = read_frame(
128                &mut self.stream,
129                CONTROL_FRAME_MAGIC,
130                CONTROL_PROTOCOL_VERSION,
131            )?;
132            match ControlFrameType::from_u8(ft) {
133                Some(ControlFrameType::ActionResponse) => {
134                    let response: ActionResponse = bincode::deserialize(&payload)?;
135                    if response.id != request_id {
136                        return Err(Error::Protocol("mismatched action response id".to_string()));
137                    }
138                    return Ok(response);
139                }
140                Some(ControlFrameType::Error) => {
141                    let err: ControlErrorFrame = bincode::deserialize(&payload)?;
142                    return Err(Error::Remote {
143                        code: err.code,
144                        message: err.message,
145                    });
146                }
147                Some(ControlFrameType::Ping) => {
148                    let ping: ControlPingFrame = bincode::deserialize(&payload)?;
149                    let pong = ControlPongFrame {
150                        ts_millis: ping.ts_millis,
151                    };
152                    let payload = bincode::serialize(&pong)?;
153                    write_frame(
154                        &mut self.stream,
155                        CONTROL_FRAME_MAGIC,
156                        CONTROL_PROTOCOL_VERSION,
157                        ControlFrameType::Pong as u8,
158                        &payload,
159                    )?;
160                }
161                Some(ControlFrameType::Pong) => {}
162                _ => {
163                    return Err(Error::Protocol("unexpected frame".to_string()));
164                }
165            }
166        }
167    }
168
169    pub fn get_status(&mut self) -> Result<StatusSnapshot> {
170        let response = self.send_action(ActionKind::GetStatus, None)?;
171        self.expect_status(response)
172    }
173
174    pub fn get_workers(&mut self) -> Result<Vec<WorkerSnapshot>> {
175        let response = self.send_action(ActionKind::GetWorkers, None)?;
176        self.expect_workers(response)
177    }
178
179    pub fn show_config(&mut self) -> Result<String> {
180        let response = self.send_action(ActionKind::ShowConfig, None)?;
181        self.expect_config(response)
182    }
183
184    pub fn scale_to(&mut self, workers: usize) -> Result<String> {
185        let payload = ActionRequestPayload::ScaleTo(ScaleToRequest { workers });
186        let response = self.send_action(ActionKind::ScaleTo, Some(payload))?;
187        self.expect_message(response, "scale")
188    }
189
190    pub fn restart_worker(&mut self, id: Option<u32>, graceful: bool) -> Result<String> {
191        let payload = ActionRequestPayload::RestartWorker(RestartWorkerRequest { id, graceful });
192        let response = self.send_action(ActionKind::RestartWorker, Some(payload))?;
193        self.expect_message(response, "restart")
194    }
195
196    pub fn shutdown(&mut self, graceful: bool) -> Result<String> {
197        let payload = ActionRequestPayload::Shutdown(ShutdownRequest { graceful });
198        let response = self.send_action(ActionKind::Shutdown, Some(payload))?;
199        self.expect_message(response, "shutdown")
200    }
201
202    pub fn reload_config(&mut self, path: Option<String>, rolling: bool) -> Result<String> {
203        let payload = ActionRequestPayload::ReloadConfig(ReloadConfigRequest { path, rolling });
204        let response = self.send_action(ActionKind::ReloadConfig, Some(payload))?;
205        self.expect_message(response, "reload")
206    }
207
208    fn ensure_ok(&self, response: &ActionResponse) -> Result<()> {
209        if let ActionStatus::Error = response.status {
210            if let Some(err) = &response.error {
211                return Err(Error::Remote {
212                    code: err.code.clone(),
213                    message: err.message.clone(),
214                });
215            }
216            return Err(Error::Protocol(
217                "action failed without error details".to_string(),
218            ));
219        }
220        Ok(())
221    }
222
223    fn expect_status(&self, response: ActionResponse) -> Result<StatusSnapshot> {
224        self.ensure_ok(&response)?;
225        match response.payload {
226            Some(ActionPayload::Status(value)) => Ok(value),
227            Some(_) => Err(Error::Protocol("unexpected payload for status".to_string())),
228            None => Err(Error::Protocol(
229                "missing payload for status response".to_string(),
230            )),
231        }
232    }
233
234    fn expect_workers(&self, response: ActionResponse) -> Result<Vec<WorkerSnapshot>> {
235        self.ensure_ok(&response)?;
236        match response.payload {
237            Some(ActionPayload::Workers(value)) => Ok(value),
238            Some(_) => Err(Error::Protocol(
239                "unexpected payload for workers".to_string(),
240            )),
241            None => Err(Error::Protocol(
242                "missing payload for workers response".to_string(),
243            )),
244        }
245    }
246
247    fn expect_config(&self, response: ActionResponse) -> Result<String> {
248        self.ensure_ok(&response)?;
249        match response.payload {
250            Some(ActionPayload::Config(value)) => Ok(value),
251            Some(_) => Err(Error::Protocol("unexpected payload for config".to_string())),
252            None => Err(Error::Protocol(
253                "missing payload for config response".to_string(),
254            )),
255        }
256    }
257
258    fn expect_message(&self, response: ActionResponse, action: &str) -> Result<String> {
259        self.ensure_ok(&response)?;
260        match response.payload {
261            Some(ActionPayload::Message(value)) => Ok(value),
262            Some(_) => Err(Error::Protocol(format!(
263                "unexpected payload for {} response",
264                action
265            ))),
266            None => Err(Error::Protocol(format!(
267                "missing payload for {} response",
268                action
269            ))),
270        }
271    }
272}