1use crate::error::{Error, Result};
2use crate::wire::{read_frame, write_frame};
3use falcorn_proto::control::{
4 ActionKind, ActionPayload, ActionRequest, ActionRequestPayload, ActionResponse, ActionStatus,
5 CONTROL_FRAME_MAGIC, CONTROL_PROTOCOL_VERSION, ControlErrorFrame, ControlFrameType,
6 ControlPingFrame, ControlPongFrame, ControlSubscribeAck, ControlSubscribeRequest,
7 ReloadConfigRequest, RestartWorkerRequest, ScaleToRequest, ShutdownRequest, StatusSnapshot,
8 WorkerSnapshot,
9};
10use std::os::unix::net::UnixStream;
11use std::time::Duration;
12
/// Configuration collected before opening a control connection.
///
/// Values are gathered through the chained setter methods and consumed by
/// `connect`. Every optional field starts out as `None`.
#[derive(Clone)]
pub struct ControlClientBuilder {
    /// Path of the Unix domain socket to connect to.
    socket: String,
    /// Token placed in the subscribe request, if any.
    auth_token: Option<String>,
    /// Client name placed in the subscribe request, if any.
    client_name: Option<String>,
    /// Read timeout applied to the stream; `None` leaves reads blocking.
    read_timeout: Option<Duration>,
    /// Write timeout applied to the stream; `None` leaves writes blocking.
    write_timeout: Option<Duration>,
}

// `Debug` is written by hand instead of derived so that formatting a builder
// with `{:?}` (for example in a log line) never prints the auth token itself.
impl std::fmt::Debug for ControlClientBuilder {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Keep the Some/None distinction visible while hiding the secret.
        let auth_token = self.auth_token.as_ref().map(|_| "<redacted>");
        f.debug_struct("ControlClientBuilder")
            .field("socket", &self.socket)
            .field("auth_token", &auth_token)
            .field("client_name", &self.client_name)
            .field("read_timeout", &self.read_timeout)
            .field("write_timeout", &self.write_timeout)
            .finish()
    }
}
21
22impl ControlClientBuilder {
23 pub fn new(socket: impl Into<String>) -> Self {
24 Self {
25 socket: socket.into(),
26 auth_token: None,
27 client_name: None,
28 read_timeout: None,
29 write_timeout: None,
30 }
31 }
32
33 pub fn auth_token(mut self, token: impl Into<String>) -> Self {
34 self.auth_token = Some(token.into());
35 self
36 }
37
38 pub fn client_name(mut self, name: impl Into<String>) -> Self {
39 self.client_name = Some(name.into());
40 self
41 }
42
43 pub fn read_timeout(mut self, timeout: Duration) -> Self {
44 self.read_timeout = Some(timeout);
45 self
46 }
47
48 pub fn write_timeout(mut self, timeout: Duration) -> Self {
49 self.write_timeout = Some(timeout);
50 self
51 }
52
53 pub fn connect(self) -> Result<ControlClient> {
54 let mut stream = UnixStream::connect(&self.socket)?;
55 stream.set_read_timeout(self.read_timeout)?;
56 stream.set_write_timeout(self.write_timeout)?;
57
58 let subscribe = ControlSubscribeRequest {
59 auth_token: self.auth_token,
60 client_name: self.client_name,
61 };
62
63 let payload = bincode::serialize(&subscribe)?;
64 write_frame(
65 &mut stream,
66 CONTROL_FRAME_MAGIC,
67 CONTROL_PROTOCOL_VERSION,
68 ControlFrameType::Subscribe as u8,
69 &payload,
70 )?;
71
72 let (ft, payload) = read_frame(&mut stream, CONTROL_FRAME_MAGIC, CONTROL_PROTOCOL_VERSION)?;
73 match ControlFrameType::from_u8(ft) {
74 Some(ControlFrameType::Ack) => {
75 let _ack: ControlSubscribeAck = bincode::deserialize(&payload)?;
76 }
77 Some(ControlFrameType::Error) => {
78 let err: ControlErrorFrame = bincode::deserialize(&payload)?;
79 return Err(Error::Remote {
80 code: err.code,
81 message: err.message,
82 });
83 }
84 _ => {
85 return Err(Error::Protocol("unexpected first frame".to_string()));
86 }
87 }
88
89 Ok(ControlClient { stream, next_id: 1 })
90 }
91}
92
/// A subscribed, synchronous connection to the control socket.
///
/// Instances are produced by `ControlClientBuilder::connect`, which creates
/// the client with `next_id` set to 1.
#[derive(Debug)]
pub struct ControlClient {
    /// Stream used for all frame reads and writes.
    stream: UnixStream,
    /// Id assigned to the next outgoing action request.
    next_id: u64,
}
97
98impl ControlClient {
99 pub fn builder(socket: impl Into<String>) -> ControlClientBuilder {
100 ControlClientBuilder::new(socket)
101 }
102
103 pub fn send_action(
104 &mut self,
105 action: ActionKind,
106 payload: Option<ActionRequestPayload>,
107 ) -> Result<ActionResponse> {
108 let request_id = self.next_id;
109 self.next_id = self.next_id.saturating_add(1);
110
111 let request = ActionRequest {
112 id: request_id,
113 action,
114 payload,
115 };
116
117 let payload = bincode::serialize(&request)?;
118 write_frame(
119 &mut self.stream,
120 CONTROL_FRAME_MAGIC,
121 CONTROL_PROTOCOL_VERSION,
122 ControlFrameType::ActionRequest as u8,
123 &payload,
124 )?;
125
126 loop {
127 let (ft, payload) = read_frame(
128 &mut self.stream,
129 CONTROL_FRAME_MAGIC,
130 CONTROL_PROTOCOL_VERSION,
131 )?;
132 match ControlFrameType::from_u8(ft) {
133 Some(ControlFrameType::ActionResponse) => {
134 let response: ActionResponse = bincode::deserialize(&payload)?;
135 if response.id != request_id {
136 return Err(Error::Protocol("mismatched action response id".to_string()));
137 }
138 return Ok(response);
139 }
140 Some(ControlFrameType::Error) => {
141 let err: ControlErrorFrame = bincode::deserialize(&payload)?;
142 return Err(Error::Remote {
143 code: err.code,
144 message: err.message,
145 });
146 }
147 Some(ControlFrameType::Ping) => {
148 let ping: ControlPingFrame = bincode::deserialize(&payload)?;
149 let pong = ControlPongFrame {
150 ts_millis: ping.ts_millis,
151 };
152 let payload = bincode::serialize(&pong)?;
153 write_frame(
154 &mut self.stream,
155 CONTROL_FRAME_MAGIC,
156 CONTROL_PROTOCOL_VERSION,
157 ControlFrameType::Pong as u8,
158 &payload,
159 )?;
160 }
161 Some(ControlFrameType::Pong) => {}
162 _ => {
163 return Err(Error::Protocol("unexpected frame".to_string()));
164 }
165 }
166 }
167 }
168
169 pub fn get_status(&mut self) -> Result<StatusSnapshot> {
170 let response = self.send_action(ActionKind::GetStatus, None)?;
171 self.expect_status(response)
172 }
173
174 pub fn get_workers(&mut self) -> Result<Vec<WorkerSnapshot>> {
175 let response = self.send_action(ActionKind::GetWorkers, None)?;
176 self.expect_workers(response)
177 }
178
179 pub fn show_config(&mut self) -> Result<String> {
180 let response = self.send_action(ActionKind::ShowConfig, None)?;
181 self.expect_config(response)
182 }
183
184 pub fn scale_to(&mut self, workers: usize) -> Result<String> {
185 let payload = ActionRequestPayload::ScaleTo(ScaleToRequest { workers });
186 let response = self.send_action(ActionKind::ScaleTo, Some(payload))?;
187 self.expect_message(response, "scale")
188 }
189
190 pub fn restart_worker(&mut self, id: Option<u32>, graceful: bool) -> Result<String> {
191 let payload = ActionRequestPayload::RestartWorker(RestartWorkerRequest { id, graceful });
192 let response = self.send_action(ActionKind::RestartWorker, Some(payload))?;
193 self.expect_message(response, "restart")
194 }
195
196 pub fn shutdown(&mut self, graceful: bool) -> Result<String> {
197 let payload = ActionRequestPayload::Shutdown(ShutdownRequest { graceful });
198 let response = self.send_action(ActionKind::Shutdown, Some(payload))?;
199 self.expect_message(response, "shutdown")
200 }
201
202 pub fn reload_config(&mut self, path: Option<String>, rolling: bool) -> Result<String> {
203 let payload = ActionRequestPayload::ReloadConfig(ReloadConfigRequest { path, rolling });
204 let response = self.send_action(ActionKind::ReloadConfig, Some(payload))?;
205 self.expect_message(response, "reload")
206 }
207
208 fn ensure_ok(&self, response: &ActionResponse) -> Result<()> {
209 if let ActionStatus::Error = response.status {
210 if let Some(err) = &response.error {
211 return Err(Error::Remote {
212 code: err.code.clone(),
213 message: err.message.clone(),
214 });
215 }
216 return Err(Error::Protocol(
217 "action failed without error details".to_string(),
218 ));
219 }
220 Ok(())
221 }
222
223 fn expect_status(&self, response: ActionResponse) -> Result<StatusSnapshot> {
224 self.ensure_ok(&response)?;
225 match response.payload {
226 Some(ActionPayload::Status(value)) => Ok(value),
227 Some(_) => Err(Error::Protocol("unexpected payload for status".to_string())),
228 None => Err(Error::Protocol(
229 "missing payload for status response".to_string(),
230 )),
231 }
232 }
233
234 fn expect_workers(&self, response: ActionResponse) -> Result<Vec<WorkerSnapshot>> {
235 self.ensure_ok(&response)?;
236 match response.payload {
237 Some(ActionPayload::Workers(value)) => Ok(value),
238 Some(_) => Err(Error::Protocol(
239 "unexpected payload for workers".to_string(),
240 )),
241 None => Err(Error::Protocol(
242 "missing payload for workers response".to_string(),
243 )),
244 }
245 }
246
247 fn expect_config(&self, response: ActionResponse) -> Result<String> {
248 self.ensure_ok(&response)?;
249 match response.payload {
250 Some(ActionPayload::Config(value)) => Ok(value),
251 Some(_) => Err(Error::Protocol("unexpected payload for config".to_string())),
252 None => Err(Error::Protocol(
253 "missing payload for config response".to_string(),
254 )),
255 }
256 }
257
258 fn expect_message(&self, response: ActionResponse, action: &str) -> Result<String> {
259 self.ensure_ok(&response)?;
260 match response.payload {
261 Some(ActionPayload::Message(value)) => Ok(value),
262 Some(_) => Err(Error::Protocol(format!(
263 "unexpected payload for {} response",
264 action
265 ))),
266 None => Err(Error::Protocol(format!(
267 "missing payload for {} response",
268 action
269 ))),
270 }
271 }
272}