1use std::io::prelude::*;
3use std::io::BufReader;
4use std::io::LineWriter;
5use std::path::PathBuf;
6use std::{io, net, time};
7
8#[cfg(unix)]
9use std::os::unix::net::{UnixListener as Listener, UnixStream as Stream};
10#[cfg(windows)]
11use winpipe::{WinListener as Listener, WinStream as Stream};
12
13use radicle::node::Handle;
14use serde_json as json;
15
16use crate::identity::RepoId;
17use crate::node::NodeId;
18use crate::node::{Command, CommandResult};
19use crate::runtime;
20use crate::runtime::thread;
21
/// Timeout passed to `Handle::subscribe` when serving the `Subscribe` command.
///
/// This is the largest duration `std::time::Duration` can represent.
const MAX_TIMEOUT: time::Duration = time::Duration::MAX;
24
25#[derive(thiserror::Error, Debug)]
26pub enum Error {
27 #[error("failed to bind control socket listener: {0}")]
28 Bind(io::Error),
29 #[error("invalid socket path specified: {0}")]
30 InvalidPath(PathBuf),
31 #[error("node: {0}")]
32 Node(#[from] runtime::HandleError),
33}
34
35pub fn listen<E, H>(listener: Listener, handle: H) -> Result<(), Error>
37where
38 H: Handle<Error = runtime::HandleError> + 'static,
39 H::Sessions: serde::Serialize,
40 CommandResult<E>: From<H::Event>,
41 E: serde::Serialize,
42{
43 log::debug!(target: "control", "Control thread listening on socket..");
44 let nid = handle.nid()?;
45
46 for incoming in listener.incoming() {
47 match incoming {
48 Ok(stream) => {
49 let handle = handle.clone();
50
51 thread::spawn(&nid, "control", move || {
52 if let Err((e, mut stream)) = command(stream, handle) {
53 log::error!(target: "control", "Command returned error: {e}");
54
55 CommandResult::error(e).to_writer(&mut stream).ok();
56
57 stream.flush().ok();
58 stream.shutdown(net::Shutdown::Both).ok();
59 }
60 });
61 }
62 Err(e) => log::error!(target: "control", "Failed to accept incoming connection: {e}"),
63 }
64 }
65 log::debug!(target: "control", "Exiting control loop..");
66
67 Ok(())
68}
69
70#[derive(thiserror::Error, Debug)]
71enum CommandError {
72 #[error("(de)serialization failed: {0}")]
73 Serialization(#[from] json::Error),
74 #[error("runtime error: {0}")]
75 Runtime(#[from] runtime::HandleError),
76 #[error("i/o error: {0}")]
77 Io(#[from] io::Error),
78}
79
80#[cfg(unix)]
81fn command<E, H>(stream: Stream, handle: H) -> Result<(), (CommandError, Stream)>
82where
83 H: Handle<Error = runtime::HandleError> + 'static,
84 H::Sessions: serde::Serialize,
85 CommandResult<E>: From<H::Event>,
86 E: serde::Serialize,
87{
88 let reader = BufReader::new(&stream);
89 let writer = LineWriter::new(&stream);
90
91 command_internal(reader, writer, handle).map_err(|e| (e, stream))
92}
93
/// Handle one command arriving on `stream` (Windows implementation).
///
/// Reader and writer each operate on their own clone of the stream, obtained
/// through `try_clone`, instead of sharing a borrow of `stream`.
/// NOTE(review): presumably the Windows stream type cannot be read or written
/// through a shared reference — confirm against `winpipe`.
///
/// # Errors
///
/// On failure, including a failed `try_clone`, the error is returned together
/// with the original stream so the caller can report it on the same connection.
#[cfg(windows)]
fn command<E, H>(stream: Stream, handle: H) -> Result<(), (CommandError, Stream)>
where
    H: Handle<Error = runtime::HandleError> + 'static,
    H::Sessions: serde::Serialize,
    CommandResult<E>: From<H::Event>,
    E: serde::Serialize,
{
    let mut read_half = match stream.try_clone() {
        Ok(half) => half,
        Err(err) => return Err((err.into(), stream)),
    };
    let mut write_half = match stream.try_clone() {
        Ok(half) => half,
        Err(err) => return Err((err.into(), stream)),
    };

    let outcome = command_internal(
        BufReader::new(&mut read_half),
        LineWriter::new(&mut write_half),
        handle,
    );

    match outcome {
        Ok(()) => Ok(()),
        Err(err) => Err((err, stream)),
    }
}
122
123#[inline(always)]
124fn command_internal<E, H, R, W>(
125 mut reader: BufReader<R>,
126 mut writer: LineWriter<W>,
127 mut handle: H,
128) -> Result<(), CommandError>
129where
130 H: Handle<Error = runtime::HandleError> + 'static,
131 H::Sessions: serde::Serialize,
132 CommandResult<E>: From<H::Event>,
133 E: serde::Serialize,
134 R: io::Read,
135 W: io::Write,
136{
137 let mut line = String::new();
138
139 reader.read_line(&mut line)?;
140 let input = line.trim_end();
141
142 log::debug!(target: "control", "Received `{input}` on control socket");
143 let cmd: Command = json::from_str(input)?;
144
145 match cmd {
146 Command::Connect { addr, opts } => {
147 let (nid, addr) = addr.into();
148 match handle.connect(nid, addr, opts) {
149 Err(e) => return Err(CommandError::Runtime(e)),
150 Ok(result) => {
151 json::to_writer(&mut writer, &result)?;
152 writer.write_all(b"\n")?;
153 }
154 }
155 }
156 Command::Disconnect { nid } => match handle.disconnect(nid) {
157 Err(e) => return Err(CommandError::Runtime(e)),
158 Ok(()) => {
159 CommandResult::ok().to_writer(writer).ok();
160 }
161 },
162 Command::Fetch { rid, nid, timeout } => {
163 fetch(rid, nid, timeout, writer, &mut handle)?;
164 }
165 Command::Config => {
166 let config = handle.config()?;
167
168 CommandResult::Okay(config).to_writer(writer)?;
169 }
170 Command::ListenAddrs => {
171 let addrs = handle.listen_addrs()?;
172
173 CommandResult::Okay(addrs).to_writer(writer)?;
174 }
175 Command::Seeds { rid } => {
176 let seeds = handle.seeds(rid)?;
177
178 CommandResult::Okay(seeds).to_writer(writer)?;
179 }
180 Command::Sessions => {
181 let sessions = handle.sessions()?;
182
183 CommandResult::Okay(sessions).to_writer(writer)?;
184 }
185 Command::Session { nid } => {
186 let session = handle.session(nid)?;
187
188 CommandResult::Okay(session).to_writer(writer)?;
189 }
190 Command::Seed { rid, scope } => match handle.seed(rid, scope) {
191 Ok(result) => {
192 CommandResult::updated(result).to_writer(writer)?;
193 }
194 Err(e) => {
195 return Err(CommandError::Runtime(e));
196 }
197 },
198 Command::Unseed { rid } => match handle.unseed(rid) {
199 Ok(result) => {
200 CommandResult::updated(result).to_writer(writer)?;
201 }
202 Err(e) => {
203 return Err(CommandError::Runtime(e));
204 }
205 },
206 Command::Follow { nid, alias } => match handle.follow(nid, alias) {
207 Ok(result) => {
208 CommandResult::updated(result).to_writer(writer)?;
209 }
210 Err(e) => {
211 return Err(CommandError::Runtime(e));
212 }
213 },
214 Command::Unfollow { nid } => match handle.unfollow(nid) {
215 Ok(result) => {
216 CommandResult::updated(result).to_writer(writer)?;
217 }
218 Err(e) => {
219 return Err(CommandError::Runtime(e));
220 }
221 },
222 Command::AnnounceRefs { rid } => {
223 let refs = handle.announce_refs(rid)?;
224
225 CommandResult::Okay(refs).to_writer(writer)?;
226 }
227 Command::AnnounceInventory => {
228 if let Err(e) = handle.announce_inventory() {
229 return Err(CommandError::Runtime(e));
230 }
231 CommandResult::ok().to_writer(writer).ok();
232 }
233 Command::AddInventory { rid } => match handle.add_inventory(rid) {
234 Ok(result) => {
235 CommandResult::updated(result).to_writer(writer)?;
236 }
237 Err(e) => {
238 return Err(CommandError::Runtime(e));
239 }
240 },
241 Command::Subscribe => match handle.subscribe(MAX_TIMEOUT) {
242 Ok(events) => {
243 for e in events {
244 CommandResult::from(e).to_writer(&mut writer)?;
245 }
246 }
247 Err(e) => return Err(CommandError::Runtime(e)),
248 },
249 Command::Status => {
250 CommandResult::ok().to_writer(writer).ok();
251 }
252 Command::NodeId => match handle.nid() {
253 Ok(nid) => {
254 CommandResult::Okay(nid).to_writer(writer)?;
255 }
256 Err(e) => return Err(CommandError::Runtime(e)),
257 },
258 Command::Debug => {
259 let debug = handle.debug()?;
260
261 CommandResult::Okay(debug).to_writer(writer)?;
262 }
263 Command::Shutdown => {
264 log::debug!(target: "control", "Shutdown requested..");
265 handle.shutdown().ok();
268 CommandResult::ok().to_writer(writer).ok();
269 }
270 }
271 Ok(())
272}
273
274fn fetch<W: Write, H: Handle<Error = runtime::HandleError>>(
275 id: RepoId,
276 node: NodeId,
277 timeout: time::Duration,
278 mut writer: W,
279 handle: &mut H,
280) -> Result<(), CommandError> {
281 match handle.fetch(id, node, timeout) {
282 Ok(result) => {
283 json::to_writer(&mut writer, &result)?;
284 }
285 Err(e) => {
286 return Err(CommandError::Runtime(e));
287 }
288 }
289 Ok(())
290}
291
#[cfg(test)]
mod tests {
    use std::io::prelude::*;
    use std::thread;

    use super::*;
    use crate::identity::RepoId;
    use crate::node::policy::Scope;
    use crate::node::Handle;
    use crate::node::{Alias, Node, NodeId};
    use crate::test;

    /// Sends an `AnnounceRefs` command per repository over a raw socket
    /// connection and checks both the reply line and that the test handle
    /// recorded the repository.
    #[test]
    fn test_control_socket() {
        let tmp = tempfile::tempdir().unwrap();
        let handle = test::handle::Handle::default();
        let socket = tmp.path().join("alice.sock");
        let rids = test::arbitrary::set::<RepoId>(1..3);
        let listener = Listener::bind(&socket).unwrap();

        let server = handle.clone();
        thread::spawn(move || listen(listener, server));

        for rid in &rids {
            // Retry until the connection is accepted.
            let mut stream = loop {
                match Stream::connect(&socket) {
                    Ok(stream) => break stream,
                    Err(_) => continue,
                }
            };
            let request = json::to_string(&Command::AnnounceRefs {
                rid: rid.to_owned(),
            })
            .unwrap();
            writeln!(&mut stream, "{}", request).unwrap();

            let line = BufReader::new(stream).lines().next().unwrap().unwrap();
            let expected = json::json!({
                "remote": handle.nid().unwrap(),
                "at": "0000000000000000000000000000000000000000"
            })
            .to_string();

            assert_eq!(line, expected);
        }

        for rid in &rids {
            assert!(handle.updates.lock().unwrap().contains(rid));
        }
    }

    /// Drives seed/unseed and follow/unfollow through the `Node` client and
    /// checks that the first call reports an update while the repeated call
    /// does not.
    #[test]
    fn test_seed_unseed() {
        let tmp = tempfile::tempdir().unwrap();
        let socket = tmp.path().join("node.sock");
        let proj = test::arbitrary::gen::<RepoId>(1);
        let peer = test::arbitrary::gen::<NodeId>(1);
        let listener = Listener::bind(&socket).unwrap();
        let mut handle = Node::new(&socket);

        let server = crate::test::handle::Handle::default();
        thread::spawn(move || crate::control::listen(listener, server));

        // Block until the client reports the node as running.
        loop {
            if handle.is_running() {
                break;
            }
        }

        assert!(handle.seed(proj, Scope::default()).unwrap());
        assert!(!handle.seed(proj, Scope::default()).unwrap());
        assert!(handle.unseed(proj).unwrap());
        assert!(!handle.unseed(proj).unwrap());

        assert!(handle.follow(peer, Some(Alias::new("alice"))).unwrap());
        assert!(!handle.follow(peer, Some(Alias::new("alice"))).unwrap());
        assert!(handle.unfollow(peer).unwrap());
        assert!(!handle.unfollow(peer).unwrap());
    }
}