1#![allow(unused_imports)]
2#![allow(unreachable_code)]
3#![allow(unused_variables)]
4use std::{
9 collections::HashMap,
10 env, fs,
11 net::UdpSocket,
12 os::unix::{
13 io::{FromRawFd, IntoRawFd},
14 net::UnixListener,
15 },
16 path::{Path, PathBuf},
17 process::Stdio,
18 time::Duration,
19};
20
21use sysinfo::{Pid, PidExt, SystemExt};
22
23#[cfg(feature = "log")]
24use log::*;
25#[cfg(feature = "log")]
26use syslog::{BasicLogger, Facility, Formatter3164};
27
28use serde::{Deserialize, Serialize};
29
/// Environment variable that, when set, overrides the companion server's socket address.
pub(crate) const ENV_VAR: &str = "RUST_COMPANION";
/// Program name; used to build the pid file name.
pub(crate) const PROGRAM_NAME: &str = "rust-companion";
32
/// Installs a global `log` logger that forwards records at `Info` level and above to the
/// local syslog daemon, tagged with [`PROGRAM_NAME`] and this process's id.
///
/// # Panics
///
/// Panics if syslog cannot be reached over its Unix socket, or if registering the logger
/// fails (e.g. a global logger was already installed).
#[cfg(feature = "log")]
fn setup_logger() {
    let formatter = Formatter3164 {
        facility: Facility::LOG_USER,
        hostname: None,
        process: PROGRAM_NAME.into(),
        // Report the real pid so syslog entries identify this server process;
        // `as _` adapts to whichever integer type the formatter's `pid` field uses.
        pid: std::process::id() as _,
    };

    let logger = syslog::unix(formatter).expect("could not connect to syslog");
    log::set_boxed_logger(Box::new(BasicLogger::new(logger)))
        .map(|()| log::set_max_level(LevelFilter::Info))
        .expect("could not register logger");
}
49
50#[derive(Serialize, Deserialize, Debug)]
51pub enum Response {
52 String(String),
53 List(Vec<String>),
54 Ok,
55 NotFound,
56}
57
58impl Response {
59 pub fn as_bytes(&self) -> Vec<u8> {
60 bincode::serialize(&self).unwrap()
61 }
62}
63
64impl From<&[u8]> for Response {
65 fn from(bytes: &[u8]) -> Self {
66 bincode::deserialize(bytes).unwrap()
67 }
68}
69
70#[derive(Serialize, Deserialize, Debug)]
71pub enum Task<'a> {
72 Get(&'a str),
74 Set(&'a str, &'a str),
76 List,
78 Sum(Vec<i64>),
79 Shutdown,
80}
81
82impl<'a> Task<'a> {
83 pub fn as_bytes(&self) -> Vec<u8> {
84 bincode::serialize(&self).unwrap()
85 }
86}
87
88pub fn companion_addr() -> String {
89 if let Ok(addr) = env::var(ENV_VAR) {
90 addr
91 } else {
92 "[::]:2000".into()
96 }
97}
98
99pub fn pid_path() -> PathBuf {
100 let mut dir = std::env::temp_dir();
101 dir.push(&format!("{}.pid", PROGRAM_NAME));
102 dir
103}
104
105fn check_started<P>(path: P) -> bool
106where
107 P: AsRef<Path>,
108{
109 if let Ok(pids) = fs::read_to_string(&path) {
110 let sys = sysinfo::System::new_all();
112 let processes = sys.processes();
113 let pids: Vec<u32> = pids.lines().filter_map(|s| s.parse::<u32>().ok()).collect();
114 let mut started = false;
115 let mut new_pids = vec![];
116 for pid in pids.iter() {
117 if processes.contains_key(&Pid::from_u32(*pid)) {
118 started = true;
119 new_pids.push(*pid);
120 }
121 }
122
123 if started {
124 let contents = new_pids
125 .iter()
126 .map(|v| v.to_string())
127 .collect::<Vec<String>>()
128 .join("\n");
129
130 fs::write(&path, contents).unwrap();
131 return true;
132 }
133 }
134 false
135}
136
137pub fn launch<P>(path: P)
138where
139 P: AsRef<Path>,
140{
141 #[cfg(feature = "log")]
142 setup_logger();
143
144 let pid = std::process::id();
145
146 fs::write(path, pid.to_string()).unwrap();
147
148 let socket_path = companion_addr();
149
150 let mut storage: HashMap<String, String> = HashMap::new();
151
152 let sock = UdpSocket::bind(&socket_path).unwrap();
153
154 'outer: loop {
155 let mut buf = [0; 65507];
156 let sock = sock.try_clone().expect("Failed to clone socket");
157
158 let (len, src) = sock.recv_from(&mut buf).unwrap();
159 let buf = &mut buf[..len];
160
161 let task: Task = bincode::deserialize(buf).unwrap();
162 println!("{task:?}");
163 match task {
164 Task::Get(key) => {
165 #[cfg(feature = "log")]
166 log::info!("get {}", key);
167 match storage.get(key) {
168 Some(data) => {
169 let buf = bincode::serialize(&Response::String(data.clone())).unwrap();
170 sock.send_to(&buf, src).unwrap();
171 }
172 None => {
173 let buf = bincode::serialize(&Response::NotFound).unwrap();
174 sock.send_to(&buf, src).unwrap();
175 }
176 }
177 }
178 Task::Set(key, data) => {
179 #[cfg(feature = "log")]
180 log::info!("set {}", key);
181 storage.insert(key.into(), data.into());
182 let buf = bincode::serialize(&Response::Ok).unwrap();
183 sock.send_to(&buf, src).unwrap();
184 }
185 Task::List => {
186 let keys: Vec<String> = storage.keys().map(Clone::clone).collect();
187 let buf = bincode::serialize(&Response::List(keys)).unwrap();
188 sock.send_to(&buf, src).unwrap();
189 }
190 Task::Sum(_values) => {
191 #[cfg(feature = "log")]
192 log::info!("shutdown");
193 }
195 Task::Shutdown => {
196 #[cfg(feature = "log")]
197 log::info!("shutdown");
198 break 'outer;
199 }
200 }
201 }
202}
203
/// Returns the path of the companion lockfile, derived from the `OUT_DIR` environment
/// variable.
///
/// `OUT_DIR` is copied component by component up to, but not including, a `debug` or
/// `release` component that directly follows a `target` component; `companion.lock` is then
/// appended. For example, `/proj/target/debug/build/foo-1234/out` yields
/// `/proj/target/companion.lock`. If no such pair exists (e.g. `target/<triple>/debug`), the
/// lockfile ends up inside `OUT_DIR` itself. Non-UTF-8 path bytes are replaced lossily in
/// the returned string.
///
/// # Panics
///
/// Panics if `OUT_DIR` is not set (Cargo sets it for build scripts).
pub fn lockfile() -> String {
    let out_dir = PathBuf::from(
        std::env::var_os("OUT_DIR")
            .expect("OUT_DIR is not set; lockfile() is meant to run from a build script"),
    );

    let mut path = PathBuf::new();
    let mut prev: Option<&std::ffi::OsStr> = None;
    for part in out_dir.iter() {
        if prev.map_or(false, |p| p == "target") && (part == "debug" || part == "release") {
            break;
        }
        prev = Some(part);
        path.push(part);
    }
    path.push("companion.lock");
    path.to_string_lossy().into_owned()
}
219
220pub fn bootstrap() -> std::result::Result<String, Box<dyn std::error::Error>> {
221 let pid_path = pid_path();
222
223 let lockfile = lockfile();
224
225 if !check_started(&pid_path) {
226 match env::args().nth(1) {
227 Some(arg) => {
228 if arg == "-d" {
229 launch(&pid_path);
230 }
231 }
232 None => {
233 match env::current_exe() {
234 Ok(exe) => {
235 let _child = std::process::Command::new(&exe)
236 .arg("-d")
237 .stderr(Stdio::null())
238 .stdout(Stdio::null())
239 .spawn()?;
240 std::thread::sleep(Duration::from_micros(50));
241 let exe = exe.as_os_str().to_string_lossy().to_string();
243 let _ = std::fs::write(&lockfile, &exe);
244 }
245 Err(e) => println!("failed to get current exe path: {e}"),
246 };
247 }
248 }
249 }
250
251 Ok(lockfile)
252}