radicle_node/
control.rs

1//! Client control socket implementation.
2use std::io::prelude::*;
3use std::io::BufReader;
4use std::io::LineWriter;
5use std::os::unix::net::UnixListener;
6use std::os::unix::net::UnixStream;
7use std::path::PathBuf;
8use std::{io, net, time};
9
10use radicle::node::Handle;
11use serde_json as json;
12
13use crate::identity::RepoId;
14use crate::node::NodeId;
15use crate::node::{Command, CommandResult};
16use crate::runtime;
17use crate::runtime::thread;
18
/// Timeout handed to [`Handle::subscribe`] when a client subscribes to node events.
///
/// This is the largest representable [`time::Duration`].
const MAX_TIMEOUT: time::Duration = time::Duration::MAX;
21
22#[derive(thiserror::Error, Debug)]
23pub enum Error {
24    #[error("failed to bind control socket listener: {0}")]
25    Bind(io::Error),
26    #[error("invalid socket path specified: {0}")]
27    InvalidPath(PathBuf),
28    #[error("node: {0}")]
29    Node(#[from] runtime::HandleError),
30}
31
32/// Listen for commands on the control socket, and process them.
33pub fn listen<E, H>(listener: UnixListener, handle: H) -> Result<(), Error>
34where
35    H: Handle<Error = runtime::HandleError> + 'static,
36    H::Sessions: serde::Serialize,
37    CommandResult<E>: From<H::Event>,
38    E: serde::Serialize,
39{
40    log::debug!(target: "control", "Control thread listening on socket..");
41    let nid = handle.nid()?;
42
43    for incoming in listener.incoming() {
44        match incoming {
45            Ok(mut stream) => {
46                let handle = handle.clone();
47
48                thread::spawn(&nid, "control", move || {
49                    if let Err(e) = command(&stream, handle) {
50                        log::error!(target: "control", "Command returned error: {e}");
51
52                        CommandResult::error(e).to_writer(&mut stream).ok();
53
54                        stream.flush().ok();
55                        stream.shutdown(net::Shutdown::Both).ok();
56                    }
57                });
58            }
59            Err(e) => log::error!(target: "control", "Failed to accept incoming connection: {}", e),
60        }
61    }
62    log::debug!(target: "control", "Exiting control loop..");
63
64    Ok(())
65}
66
67#[derive(thiserror::Error, Debug)]
68enum CommandError {
69    #[error("(de)serialization failed: {0}")]
70    Serialization(#[from] json::Error),
71    #[error("runtime error: {0}")]
72    Runtime(#[from] runtime::HandleError),
73    #[error("i/o error: {0}")]
74    Io(#[from] io::Error),
75}
76
77fn command<E, H>(stream: &UnixStream, mut handle: H) -> Result<(), CommandError>
78where
79    H: Handle<Error = runtime::HandleError> + 'static,
80    H::Sessions: serde::Serialize,
81    CommandResult<E>: From<H::Event>,
82    E: serde::Serialize,
83{
84    let mut reader = BufReader::new(stream);
85    let mut writer = LineWriter::new(stream);
86    let mut line = String::new();
87
88    reader.read_line(&mut line)?;
89    let input = line.trim_end();
90
91    log::debug!(target: "control", "Received `{input}` on control socket");
92    let cmd: Command = json::from_str(input)?;
93
94    match cmd {
95        Command::Connect { addr, opts } => {
96            let (nid, addr) = addr.into();
97            match handle.connect(nid, addr, opts) {
98                Err(e) => return Err(CommandError::Runtime(e)),
99                Ok(result) => {
100                    json::to_writer(&mut writer, &result)?;
101                    writer.write_all(b"\n")?;
102                }
103            }
104        }
105        Command::Disconnect { nid } => match handle.disconnect(nid) {
106            Err(e) => return Err(CommandError::Runtime(e)),
107            Ok(()) => {
108                CommandResult::ok().to_writer(writer).ok();
109            }
110        },
111        Command::Fetch { rid, nid, timeout } => {
112            fetch(rid, nid, timeout, writer, &mut handle)?;
113        }
114        Command::Config => {
115            let config = handle.config()?;
116
117            CommandResult::Okay(config).to_writer(writer)?;
118        }
119        Command::ListenAddrs => {
120            let addrs = handle.listen_addrs()?;
121
122            CommandResult::Okay(addrs).to_writer(writer)?;
123        }
124        Command::Seeds { rid } => {
125            let seeds = handle.seeds(rid)?;
126
127            CommandResult::Okay(seeds).to_writer(writer)?;
128        }
129        Command::Sessions => {
130            let sessions = handle.sessions()?;
131
132            CommandResult::Okay(sessions).to_writer(writer)?;
133        }
134        Command::Session { nid } => {
135            let session = handle.session(nid)?;
136
137            CommandResult::Okay(session).to_writer(writer)?;
138        }
139        Command::Seed { rid, scope } => match handle.seed(rid, scope) {
140            Ok(result) => {
141                CommandResult::updated(result).to_writer(writer)?;
142            }
143            Err(e) => {
144                return Err(CommandError::Runtime(e));
145            }
146        },
147        Command::Unseed { rid } => match handle.unseed(rid) {
148            Ok(result) => {
149                CommandResult::updated(result).to_writer(writer)?;
150            }
151            Err(e) => {
152                return Err(CommandError::Runtime(e));
153            }
154        },
155        Command::Follow { nid, alias } => match handle.follow(nid, alias) {
156            Ok(result) => {
157                CommandResult::updated(result).to_writer(writer)?;
158            }
159            Err(e) => {
160                return Err(CommandError::Runtime(e));
161            }
162        },
163        Command::Unfollow { nid } => match handle.unfollow(nid) {
164            Ok(result) => {
165                CommandResult::updated(result).to_writer(writer)?;
166            }
167            Err(e) => {
168                return Err(CommandError::Runtime(e));
169            }
170        },
171        Command::AnnounceRefs { rid } => {
172            let refs = handle.announce_refs(rid)?;
173
174            CommandResult::Okay(refs).to_writer(writer)?;
175        }
176        Command::AnnounceInventory => {
177            if let Err(e) = handle.announce_inventory() {
178                return Err(CommandError::Runtime(e));
179            }
180            CommandResult::ok().to_writer(writer).ok();
181        }
182        Command::AddInventory { rid } => match handle.add_inventory(rid) {
183            Ok(result) => {
184                CommandResult::updated(result).to_writer(writer)?;
185            }
186            Err(e) => {
187                return Err(CommandError::Runtime(e));
188            }
189        },
190        Command::Subscribe => match handle.subscribe(MAX_TIMEOUT) {
191            Ok(events) => {
192                for e in events {
193                    CommandResult::from(e).to_writer(&mut writer)?;
194                }
195            }
196            Err(e) => return Err(CommandError::Runtime(e)),
197        },
198        Command::Status => {
199            CommandResult::ok().to_writer(writer).ok();
200        }
201        Command::NodeId => match handle.nid() {
202            Ok(nid) => {
203                CommandResult::Okay(nid).to_writer(writer)?;
204            }
205            Err(e) => return Err(CommandError::Runtime(e)),
206        },
207        Command::Debug => {
208            let debug = handle.debug()?;
209
210            CommandResult::Okay(debug).to_writer(writer)?;
211        }
212        Command::Shutdown => {
213            log::debug!(target: "control", "Shutdown requested..");
214            // Channel might already be disconnected if shutdown
215            // came from somewhere else. Ignore errors.
216            handle.shutdown().ok();
217            CommandResult::ok().to_writer(writer).ok();
218        }
219    }
220    Ok(())
221}
222
223fn fetch<W: Write, H: Handle<Error = runtime::HandleError>>(
224    id: RepoId,
225    node: NodeId,
226    timeout: time::Duration,
227    mut writer: W,
228    handle: &mut H,
229) -> Result<(), CommandError> {
230    match handle.fetch(id, node, timeout) {
231        Ok(result) => {
232            json::to_writer(&mut writer, &result)?;
233        }
234        Err(e) => {
235            return Err(CommandError::Runtime(e));
236        }
237    }
238    Ok(())
239}
240
#[cfg(test)]
mod tests {
    use std::io::prelude::*;
    use std::os::unix::net::UnixStream;
    use std::thread;

    use super::*;
    use crate::identity::RepoId;
    use crate::node::Handle;
    use crate::node::{Alias, Node, NodeId};
    use crate::service::policy::Scope;
    use crate::test;

    /// Sends one `AnnounceRefs` command per connection and checks both the JSON reply
    /// and that the test handle recorded every announced repository.
    #[test]
    fn test_control_socket() {
        let dir = tempfile::tempdir().unwrap();
        let handle = test::handle::Handle::default();
        let socket = dir.path().join("alice.sock");
        let repos = test::arbitrary::set::<RepoId>(1..3);
        let listener = UnixListener::bind(&socket).unwrap();
        let server = handle.clone();

        thread::spawn(move || listen(listener, server));

        for rid in &repos {
            // Keep retrying until the control socket accepts the connection.
            let conn = std::iter::repeat_with(|| UnixStream::connect(&socket))
                .find_map(Result::ok)
                .unwrap();
            let request = json::to_string(&Command::AnnounceRefs {
                rid: rid.to_owned(),
            })
            .unwrap();
            writeln!(&conn, "{}", request).unwrap();

            let reply = BufReader::new(conn).lines().next().unwrap().unwrap();
            let expected = json::json!({
                "remote": handle.nid().unwrap(),
                "at": "0000000000000000000000000000000000000000"
            })
            .to_string();

            assert_eq!(reply, expected);
        }

        let updates = handle.updates.lock().unwrap();
        for rid in &repos {
            assert!(updates.contains(rid));
        }
    }

    /// Drives seed/unseed and follow/unfollow through a [`Node`] client: the first
    /// call of each pair reports a change, and the repeated call does not.
    #[test]
    fn test_seed_unseed() {
        let dir = tempfile::tempdir().unwrap();
        let socket = dir.path().join("node.sock");
        let rid = test::arbitrary::gen::<RepoId>(1);
        let peer = test::arbitrary::gen::<NodeId>(1);
        let listener = UnixListener::bind(&socket).unwrap();
        let mut node = Node::new(&socket);
        let server = crate::test::handle::Handle::default();

        thread::spawn(move || crate::control::listen(listener, server));

        // Spin until the node answers on the control socket.
        while !node.is_running() {}

        assert!(node.seed(rid, Scope::default()).unwrap());
        assert!(!node.seed(rid, Scope::default()).unwrap());
        assert!(node.unseed(rid).unwrap());
        assert!(!node.unseed(rid).unwrap());

        let alias = || Some(Alias::new("alice"));

        assert!(node.follow(peer, alias()).unwrap());
        assert!(!node.follow(peer, alias()).unwrap());
        assert!(node.unfollow(peer).unwrap());
        assert!(!node.unfollow(peer).unwrap());
    }
}