radicle_node/
control.rs

1//! Client control socket implementation.
2use std::io::prelude::*;
3use std::io::BufReader;
4use std::io::LineWriter;
5use std::path::PathBuf;
6use std::{io, net, time};
7
8#[cfg(unix)]
9use std::os::unix::net::{UnixListener as Listener, UnixStream as Stream};
10#[cfg(windows)]
11use winpipe::{WinListener as Listener, WinStream as Stream};
12
13use radicle::node::Handle;
14use serde_json as json;
15
16use crate::identity::RepoId;
17use crate::node::NodeId;
18use crate::node::{Command, CommandResult};
19use crate::runtime;
20use crate::runtime::thread;
21
/// Maximum timeout for waiting for node events.
///
/// Passed to `Handle::subscribe` when serving `Command::Subscribe`. Using
/// [`time::Duration::MAX`] makes the wait effectively unbounded.
/// NOTE(review): assumes `subscribe` treats this value as "no timeout" — confirm.
const MAX_TIMEOUT: time::Duration = time::Duration::MAX;
24
25#[derive(thiserror::Error, Debug)]
26pub enum Error {
27    #[error("failed to bind control socket listener: {0}")]
28    Bind(io::Error),
29    #[error("invalid socket path specified: {0}")]
30    InvalidPath(PathBuf),
31    #[error("node: {0}")]
32    Node(#[from] runtime::HandleError),
33}
34
35/// Listen for commands on the control socket, and process them.
36pub fn listen<E, H>(listener: Listener, handle: H) -> Result<(), Error>
37where
38    H: Handle<Error = runtime::HandleError> + 'static,
39    H::Sessions: serde::Serialize,
40    CommandResult<E>: From<H::Event>,
41    E: serde::Serialize,
42{
43    log::debug!(target: "control", "Control thread listening on socket..");
44    let nid = handle.nid()?;
45
46    for incoming in listener.incoming() {
47        match incoming {
48            Ok(stream) => {
49                let handle = handle.clone();
50
51                thread::spawn(&nid, "control", move || {
52                    if let Err((e, mut stream)) = command(stream, handle) {
53                        log::error!(target: "control", "Command returned error: {e}");
54
55                        CommandResult::error(e).to_writer(&mut stream).ok();
56
57                        stream.flush().ok();
58                        stream.shutdown(net::Shutdown::Both).ok();
59                    }
60                });
61            }
62            Err(e) => log::error!(target: "control", "Failed to accept incoming connection: {e}"),
63        }
64    }
65    log::debug!(target: "control", "Exiting control loop..");
66
67    Ok(())
68}
69
70#[derive(thiserror::Error, Debug)]
71enum CommandError {
72    #[error("(de)serialization failed: {0}")]
73    Serialization(#[from] json::Error),
74    #[error("runtime error: {0}")]
75    Runtime(#[from] runtime::HandleError),
76    #[error("i/o error: {0}")]
77    Io(#[from] io::Error),
78}
79
80#[cfg(unix)]
81fn command<E, H>(stream: Stream, handle: H) -> Result<(), (CommandError, Stream)>
82where
83    H: Handle<Error = runtime::HandleError> + 'static,
84    H::Sessions: serde::Serialize,
85    CommandResult<E>: From<H::Event>,
86    E: serde::Serialize,
87{
88    let reader = BufReader::new(&stream);
89    let writer = LineWriter::new(&stream);
90
91    command_internal(reader, writer, handle).map_err(|e| (e, stream))
92}
93
/// Serve a single request on `stream` (Windows).
///
/// Unlike the Unix version, the reader and writer cannot both borrow `stream`,
/// so two clones are made: the first backs the reader, the second the writer.
/// The original `stream` is kept so it can be handed back together with any
/// error, letting the caller report the error to the client.
///
/// # Errors
///
/// As of winpipe 0.1.1, [`WinStream::try_clone`] is actually infallible.
#[cfg(windows)]
fn command<E, H>(stream: Stream, handle: H) -> Result<(), (CommandError, Stream)>
where
    H: Handle<Error = runtime::HandleError> + 'static,
    H::Sessions: serde::Serialize,
    CommandResult<E>: From<H::Event>,
    E: serde::Serialize,
{
    // Clone for the reader first, then for the writer; stop at the first failure.
    let clones = stream
        .try_clone()
        .and_then(|read_half| stream.try_clone().map(|write_half| (read_half, write_half)));
    let (mut read_half, mut write_half) = match clones {
        Ok(halves) => halves,
        Err(err) => return Err((err.into(), stream)),
    };

    let outcome = command_internal(
        BufReader::new(&mut read_half),
        LineWriter::new(&mut write_half),
        handle,
    );

    match outcome {
        Ok(()) => Ok(()),
        Err(err) => Err((err, stream)),
    }
}
122
123#[inline(always)]
124fn command_internal<E, H, R, W>(
125    mut reader: BufReader<R>,
126    mut writer: LineWriter<W>,
127    mut handle: H,
128) -> Result<(), CommandError>
129where
130    H: Handle<Error = runtime::HandleError> + 'static,
131    H::Sessions: serde::Serialize,
132    CommandResult<E>: From<H::Event>,
133    E: serde::Serialize,
134    R: io::Read,
135    W: io::Write,
136{
137    let mut line = String::new();
138
139    reader.read_line(&mut line)?;
140    let input = line.trim_end();
141
142    log::debug!(target: "control", "Received `{input}` on control socket");
143    let cmd: Command = json::from_str(input)?;
144
145    match cmd {
146        Command::Connect { addr, opts } => {
147            let (nid, addr) = addr.into();
148            match handle.connect(nid, addr, opts) {
149                Err(e) => return Err(CommandError::Runtime(e)),
150                Ok(result) => {
151                    json::to_writer(&mut writer, &result)?;
152                    writer.write_all(b"\n")?;
153                }
154            }
155        }
156        Command::Disconnect { nid } => match handle.disconnect(nid) {
157            Err(e) => return Err(CommandError::Runtime(e)),
158            Ok(()) => {
159                CommandResult::ok().to_writer(writer).ok();
160            }
161        },
162        Command::Fetch { rid, nid, timeout } => {
163            fetch(rid, nid, timeout, writer, &mut handle)?;
164        }
165        Command::Config => {
166            let config = handle.config()?;
167
168            CommandResult::Okay(config).to_writer(writer)?;
169        }
170        Command::ListenAddrs => {
171            let addrs = handle.listen_addrs()?;
172
173            CommandResult::Okay(addrs).to_writer(writer)?;
174        }
175        Command::Seeds { rid } => {
176            let seeds = handle.seeds(rid)?;
177
178            CommandResult::Okay(seeds).to_writer(writer)?;
179        }
180        Command::Sessions => {
181            let sessions = handle.sessions()?;
182
183            CommandResult::Okay(sessions).to_writer(writer)?;
184        }
185        Command::Session { nid } => {
186            let session = handle.session(nid)?;
187
188            CommandResult::Okay(session).to_writer(writer)?;
189        }
190        Command::Seed { rid, scope } => match handle.seed(rid, scope) {
191            Ok(result) => {
192                CommandResult::updated(result).to_writer(writer)?;
193            }
194            Err(e) => {
195                return Err(CommandError::Runtime(e));
196            }
197        },
198        Command::Unseed { rid } => match handle.unseed(rid) {
199            Ok(result) => {
200                CommandResult::updated(result).to_writer(writer)?;
201            }
202            Err(e) => {
203                return Err(CommandError::Runtime(e));
204            }
205        },
206        Command::Follow { nid, alias } => match handle.follow(nid, alias) {
207            Ok(result) => {
208                CommandResult::updated(result).to_writer(writer)?;
209            }
210            Err(e) => {
211                return Err(CommandError::Runtime(e));
212            }
213        },
214        Command::Unfollow { nid } => match handle.unfollow(nid) {
215            Ok(result) => {
216                CommandResult::updated(result).to_writer(writer)?;
217            }
218            Err(e) => {
219                return Err(CommandError::Runtime(e));
220            }
221        },
222        Command::AnnounceRefs { rid } => {
223            let refs = handle.announce_refs(rid)?;
224
225            CommandResult::Okay(refs).to_writer(writer)?;
226        }
227        Command::AnnounceInventory => {
228            if let Err(e) = handle.announce_inventory() {
229                return Err(CommandError::Runtime(e));
230            }
231            CommandResult::ok().to_writer(writer).ok();
232        }
233        Command::AddInventory { rid } => match handle.add_inventory(rid) {
234            Ok(result) => {
235                CommandResult::updated(result).to_writer(writer)?;
236            }
237            Err(e) => {
238                return Err(CommandError::Runtime(e));
239            }
240        },
241        Command::Subscribe => match handle.subscribe(MAX_TIMEOUT) {
242            Ok(events) => {
243                for e in events {
244                    CommandResult::from(e).to_writer(&mut writer)?;
245                }
246            }
247            Err(e) => return Err(CommandError::Runtime(e)),
248        },
249        Command::Status => {
250            CommandResult::ok().to_writer(writer).ok();
251        }
252        Command::NodeId => match handle.nid() {
253            Ok(nid) => {
254                CommandResult::Okay(nid).to_writer(writer)?;
255            }
256            Err(e) => return Err(CommandError::Runtime(e)),
257        },
258        Command::Debug => {
259            let debug = handle.debug()?;
260
261            CommandResult::Okay(debug).to_writer(writer)?;
262        }
263        Command::Shutdown => {
264            log::debug!(target: "control", "Shutdown requested..");
265            // Channel might already be disconnected if shutdown
266            // came from somewhere else. Ignore errors.
267            handle.shutdown().ok();
268            CommandResult::ok().to_writer(writer).ok();
269        }
270    }
271    Ok(())
272}
273
274fn fetch<W: Write, H: Handle<Error = runtime::HandleError>>(
275    id: RepoId,
276    node: NodeId,
277    timeout: time::Duration,
278    mut writer: W,
279    handle: &mut H,
280) -> Result<(), CommandError> {
281    match handle.fetch(id, node, timeout) {
282        Ok(result) => {
283            json::to_writer(&mut writer, &result)?;
284        }
285        Err(e) => {
286            return Err(CommandError::Runtime(e));
287        }
288    }
289    Ok(())
290}
291
#[cfg(test)]
mod tests {
    use std::io::prelude::*;
    use std::thread;

    use super::*;
    use crate::identity::RepoId;
    use crate::node::policy::Scope;
    use crate::node::Handle;
    use crate::node::{Alias, Node, NodeId};
    use crate::test;

    /// Each `AnnounceRefs` request sent over the socket gets the expected JSON
    /// answer, and every announced repository ends up in the handle's updates.
    #[test]
    fn test_control_socket() {
        let dir = tempfile::tempdir().unwrap();
        let handle = test::handle::Handle::default();
        let socket = dir.path().join("alice.sock");
        let repos = test::arbitrary::set::<RepoId>(1..3);
        let listener = Listener::bind(&socket).unwrap();

        let server = handle.clone();
        thread::spawn(move || listen(listener, server));

        for repo in &repos {
            // Retry until the listener thread accepts connections.
            let mut stream = loop {
                match Stream::connect(&socket) {
                    Ok(stream) => break stream,
                    Err(_) => continue,
                }
            };
            let request = json::to_string(&Command::AnnounceRefs {
                rid: repo.to_owned(),
            })
            .unwrap();
            writeln!(&mut stream, "{request}").unwrap();

            let response = BufReader::new(stream).lines().next().unwrap().unwrap();
            let expected = json::json!({
                "remote": handle.nid().unwrap(),
                "at": "0000000000000000000000000000000000000000"
            })
            .to_string();

            assert_eq!(response, expected);
        }

        let updates = handle.updates.lock().unwrap();
        for repo in &repos {
            assert!(updates.contains(repo));
        }
    }

    /// Seed/unseed and follow/unfollow over the control socket: the first call
    /// returns `true`, an immediate repeat of the same call returns `false`.
    #[test]
    fn test_seed_unseed() {
        let dir = tempfile::tempdir().unwrap();
        let socket = dir.path().join("node.sock");
        let repo = test::arbitrary::gen::<RepoId>(1);
        let remote = test::arbitrary::gen::<NodeId>(1);
        let listener = Listener::bind(&socket).unwrap();
        let mut node = Node::new(&socket);

        let server = crate::test::handle::Handle::default();
        thread::spawn(move || crate::control::listen(listener, server));

        // Spin until the control socket answers.
        while !node.is_running() {}

        assert!(node.seed(repo, Scope::default()).unwrap());
        assert!(!node.seed(repo, Scope::default()).unwrap());
        assert!(node.unseed(repo).unwrap());
        assert!(!node.unseed(repo).unwrap());

        let alias = || Some(Alias::new("alice"));
        assert!(node.follow(remote, alias()).unwrap());
        assert!(!node.follow(remote, alias()).unwrap());
        assert!(node.unfollow(remote).unwrap());
        assert!(!node.unfollow(remote).unwrap());
    }
}
378        assert!(!handle.unfollow(peer).unwrap());
379    }
380}