companion/
lib.rs

1#![allow(unused_imports)]
2#![allow(unreachable_code)]
3#![allow(unused_variables)]
//! This crate implements a minimal abstraction over UDP sockets for the purpose of
//! inter-process communication (IPC). It lets you send serialized Rust objects between
//! processes. (UNIX domain sockets and passing file handles are not implemented yet.)
7//!
8use std::{
9    collections::HashMap,
10    env, fs,
11    net::UdpSocket,
12    os::unix::{
13        io::{FromRawFd, IntoRawFd},
14        net::UnixListener,
15    },
16    path::{Path, PathBuf},
17    process::Stdio,
18    time::Duration,
19};
20
21use sysinfo::{Pid, PidExt, SystemExt};
22
23#[cfg(feature = "log")]
24use log::*;
25#[cfg(feature = "log")]
26use syslog::{BasicLogger, Facility, Formatter3164};
27
28use serde::{Deserialize, Serialize};
29
/// Environment variable that, when set, overrides the companion daemon's socket address.
pub(crate) const ENV_VAR: &str = "RUST_COMPANION";
/// Program name: used to build the PID file name and as the syslog process name.
pub(crate) const PROGRAM_NAME: &str = "rust-companion";
32
/// Installs a syslog-backed logger as the global `log` logger.
///
/// Messages are sent with facility `LOG_USER` under the process name [`PROGRAM_NAME`], and the
/// maximum log level is set to `Info`.
///
/// # Panics
///
/// Panics if the syslog connection cannot be established or if a global logger has already
/// been registered.
#[cfg(feature = "log")]
fn setup_logger() {
    // `PROGRAM_NAME` is defined in this module and is used directly; a crate cannot import
    // itself by its package name (`companion::…`).
    let formatter = Formatter3164 {
        facility: Facility::LOG_USER,
        hostname: None,
        process: PROGRAM_NAME.into(),
        // NOTE(review): the PID field is always reported as 0, not the real process id.
        pid: 0,
    };

    let logger = syslog::unix(formatter).expect("could not connect to syslog");
    log::set_boxed_logger(Box::new(BasicLogger::new(logger)))
        .map(|()| log::set_max_level(LevelFilter::Info))
        .expect("could not register logger");
}
49
50#[derive(Serialize, Deserialize, Debug)]
51pub enum Response {
52    String(String),
53    List(Vec<String>),
54    Ok,
55    NotFound,
56}
57
58impl Response {
59    pub fn as_bytes(&self) -> Vec<u8> {
60        bincode::serialize(&self).unwrap()
61    }
62}
63
64impl From<&[u8]> for Response {
65    fn from(bytes: &[u8]) -> Self {
66        bincode::deserialize(bytes).unwrap()
67    }
68}
69
70#[derive(Serialize, Deserialize, Debug)]
71pub enum Task<'a> {
72    // Get data by name
73    Get(&'a str),
74    // Store data by name
75    Set(&'a str, &'a str),
76    // List stored names
77    List,
78    Sum(Vec<i64>),
79    Shutdown,
80}
81
82impl<'a> Task<'a> {
83    pub fn as_bytes(&self) -> Vec<u8> {
84        bincode::serialize(&self).unwrap()
85    }
86}
87
88pub fn companion_addr() -> String {
89    if let Ok(addr) = env::var(ENV_VAR) {
90        addr
91    } else {
92        // let mut dir = std::env::temp_dir();
93        // dir.push(&format!("{}.sock", PROGRAM_NAME));
94        // dir
95        "[::]:2000".into()
96    }
97}
98
99pub fn pid_path() -> PathBuf {
100    let mut dir = std::env::temp_dir();
101    dir.push(&format!("{}.pid", PROGRAM_NAME));
102    dir
103}
104
105fn check_started<P>(path: P) -> bool
106where
107    P: AsRef<Path>,
108{
109    if let Ok(pids) = fs::read_to_string(&path) {
110        // println!("pids: {pids}");
111        let sys = sysinfo::System::new_all();
112        let processes = sys.processes();
113        let pids: Vec<u32> = pids.lines().filter_map(|s| s.parse::<u32>().ok()).collect();
114        let mut started = false;
115        let mut new_pids = vec![];
116        for pid in pids.iter() {
117            if processes.contains_key(&Pid::from_u32(*pid)) {
118                started = true;
119                new_pids.push(*pid);
120            }
121        }
122
123        if started {
124            let contents = new_pids
125                .iter()
126                .map(|v| v.to_string())
127                .collect::<Vec<String>>()
128                .join("\n");
129
130            fs::write(&path, contents).unwrap();
131            return true;
132        }
133    }
134    false
135}
136
137pub fn launch<P>(path: P)
138where
139    P: AsRef<Path>,
140{
141    #[cfg(feature = "log")]
142    setup_logger();
143
144    let pid = std::process::id();
145
146    fs::write(path, pid.to_string()).unwrap();
147
148    let socket_path = companion_addr();
149
150    let mut storage: HashMap<String, String> = HashMap::new();
151
152    let sock = UdpSocket::bind(&socket_path).unwrap();
153
154    'outer: loop {
155        let mut buf = [0; 65507];
156        let sock = sock.try_clone().expect("Failed to clone socket");
157
158        let (len, src) = sock.recv_from(&mut buf).unwrap();
159        let buf = &mut buf[..len];
160
161        let task: Task = bincode::deserialize(buf).unwrap();
162        println!("{task:?}");
163        match task {
164            Task::Get(key) => {
165                #[cfg(feature = "log")]
166                log::info!("get {}", key);
167                match storage.get(key) {
168                    Some(data) => {
169                        let buf = bincode::serialize(&Response::String(data.clone())).unwrap();
170                        sock.send_to(&buf, src).unwrap();
171                    }
172                    None => {
173                        let buf = bincode::serialize(&Response::NotFound).unwrap();
174                        sock.send_to(&buf, src).unwrap();
175                    }
176                }
177            }
178            Task::Set(key, data) => {
179                #[cfg(feature = "log")]
180                log::info!("set {}", key);
181                storage.insert(key.into(), data.into());
182                let buf = bincode::serialize(&Response::Ok).unwrap();
183                sock.send_to(&buf, src).unwrap();
184            }
185            Task::List => {
186                let keys: Vec<String> = storage.keys().map(Clone::clone).collect();
187                let buf = bincode::serialize(&Response::List(keys)).unwrap();
188                sock.send_to(&buf, src).unwrap();
189            }
190            Task::Sum(_values) => {
191                #[cfg(feature = "log")]
192                log::info!("shutdown");
193                // tx.send(Response::NotFound).unwrap();
194            }
195            Task::Shutdown => {
196                #[cfg(feature = "log")]
197                log::info!("shutdown");
198                break 'outer;
199            }
200        }
201    }
202}
203
/// Returns the path of the companion lock file, derived from the `OUT_DIR` environment variable.
///
/// The path is `OUT_DIR` cut just before the first `debug` or `release` component that directly
/// follows a `target` component, with `companion.lock` appended. For example,
/// `…/target/debug/build/x/out` yields `…/target/companion.lock`. If no such pair of components
/// exists, the lock file is placed inside `OUT_DIR` itself.
///
/// # Panics
///
/// Panics if `OUT_DIR` is not set or is not valid Unicode. Cargo sets it when running build
/// scripts.
pub fn lockfile() -> String {
    let out_dir = PathBuf::from(std::env::var("OUT_DIR").unwrap());

    let mut path = PathBuf::new();
    let mut previous: Option<&std::ffi::OsStr> = None;
    for component in out_dir.iter() {
        let follows_target = previous.map_or(false, |p| p == "target");
        let is_profile_dir = component == "debug" || component == "release";
        if follows_target && is_profile_dir {
            break;
        }
        path.push(component);
        previous = Some(component);
    }

    path.push("companion.lock");
    path.to_string_lossy().into_owned()
}
219
220pub fn bootstrap() -> std::result::Result<String, Box<dyn std::error::Error>> {
221    let pid_path = pid_path();
222
223    let lockfile = lockfile();
224
225    if !check_started(&pid_path) {
226        match env::args().nth(1) {
227            Some(arg) => {
228                if arg == "-d" {
229                    launch(&pid_path);
230                }
231            }
232            None => {
233                match env::current_exe() {
234                    Ok(exe) => {
235                        let _child = std::process::Command::new(&exe)
236                            .arg("-d")
237                            .stderr(Stdio::null())
238                            .stdout(Stdio::null())
239                            .spawn()?;
240                        std::thread::sleep(Duration::from_micros(50));
241                        // write lock file
242                        let exe = exe.as_os_str().to_string_lossy().to_string();
243                        let _ = std::fs::write(&lockfile, &exe);
244                    }
245                    Err(e) => println!("failed to get current exe path: {e}"),
246                };
247            }
248        }
249    }
250
251    Ok(lockfile)
252}