hypermangle_core/
console.rs1use std::{ffi::OsString, mem::take};
2
3use clap::{crate_name, Parser};
4use futures::AsyncReadExt;
5use interprocess::local_socket::tokio::{LocalSocketListener, LocalSocketStream};
6use log::error;
7use serde::{Deserialize, Serialize};
8
9use futures::AsyncWriteExt;
10use tokio::sync::mpsc;
11
12pub struct RemoteClient {
13 stream: Option<LocalSocketStream>,
14}
15
16impl RemoteClient {
17 pub async fn send(&mut self, msg: String) {
18 if let Err(e) = send_msg(BaseCommand::Packet(msg), self.stream.as_mut().unwrap()).await {
19 error!("Faced the following error while responding to remote client: {e}");
20 }
21 }
22}
23
24impl Drop for RemoteClient {
25 fn drop(&mut self) {
26 let mut stream = take(&mut self.stream).unwrap();
27 tokio::spawn(async move {
28 if let Err(e) = send_msg(BaseCommand::CloseSocket, &mut stream).await {
29 error!("Faced the following error while ending connection to remote client: {e}");
30 }
31 });
32 }
33}
34
35#[derive(Serialize, Deserialize)]
36enum BaseCommand {
37 IdRequest,
38 IdResponse(u32),
39 Args(Vec<OsString>),
40 Packet(String),
41 CloseSocket,
42}
43
44fn get_socket_name() -> String {
45 format!("/run/{}.sock", crate_name!())
46}
47
48#[tokio::main(flavor = "current_thread")]
49pub async fn does_remote_exist() -> Option<u32> {
50 let Ok(mut stream) = LocalSocketStream::connect(get_socket_name()).await else {
51 return None;
52 };
53 send_msg(BaseCommand::IdRequest, &mut stream).await.ok()?;
54 let Ok(BaseCommand::IdResponse(id)) = recv_msg(&mut stream).await else {
55 panic!("Remote service should have responded with is Process ID")
56 };
57 Some(id)
58}
59
60async fn send_msg(msg: BaseCommand, stream: &mut LocalSocketStream) -> std::io::Result<()> {
61 let mut msg = bincode::serialize(&msg).unwrap();
62
63 let mut tmp = msg.len().to_ne_bytes().to_vec();
64 tmp.append(&mut msg);
65 msg = tmp;
66
67 stream.write_all(&msg).await
68}
69
70async fn recv_msg(
71 stream: &mut LocalSocketStream,
72) -> Result<BaseCommand, Box<dyn std::error::Error>> {
73 let mut msg_size = [0u8; (usize::BITS / 8) as usize];
74 stream.read_exact(&mut msg_size).await.map_err(Box::new)?;
75 let msg_size = usize::from_ne_bytes(msg_size);
76 let mut msg = vec![0u8; msg_size];
77 stream.read_exact(&mut msg).await.map_err(Box::new)?;
78
79 bincode::deserialize(&msg).map_err(Into::into)
80}
81
82#[tokio::main(flavor = "current_thread")]
83pub async fn send_args_to_remote() {
84 let mut stream = LocalSocketStream::connect(get_socket_name())
85 .await
86 .expect("Connection to remote service should have succeeded");
87
88 send_msg(
89 BaseCommand::Args(std::env::args_os().collect()),
90 &mut stream,
91 )
92 .await
93 .expect("Remote service should have accepted the given arguments");
94
95 loop {
96 let msg = recv_msg(&mut stream)
97 .await
98 .expect("Remote service should have sent a valid message");
99
100 match msg {
101 BaseCommand::Packet(msg) => print!("{msg}"),
102 BaseCommand::CloseSocket => break,
103 _ => {}
104 }
105 }
106}
107
108pub trait ExecutableArgs: Parser + Send + 'static {
109 fn execute(self, writer: RemoteClient) -> impl std::future::Future<Output=bool> + Send;
110}
111
112pub fn listen_for_commands<P: ExecutableArgs>() -> impl std::future::Future<Output=()> {
113 let (sender, receiver) = mpsc::channel(1);
114 tokio::spawn(listen_for_commands_inner::<P>(receiver));
115 async move {
116 let _sender = sender;
117 std::future::pending::<()>().await;
118 }
119}
120
121
122async fn listen_for_commands_inner<P: ExecutableArgs + Send>(mut receiver: mpsc::Receiver<()>) {
123 #[cfg(unix)]
124 let _ = std::fs::remove_file(get_socket_name());
125
126 let listener = LocalSocketListener::bind(get_socket_name())
127 .expect("Command listener should have started successfully");
128
129 loop {
130 let mut stream;
131
132 macro_rules! unwrap {
133 ($result: expr) => {
134 match $result {
135 Ok(x) => x,
136 Err(e) => {
137 error!("Faced the following error while listening for commands: {e}");
138 continue;
140 }
141 }
142 };
143 }
144
145 tokio::select! {
146 _ = receiver.recv() => {
147 break
148 }
149 result = listener.accept() => {
150 stream = unwrap!(result);
151 }
152 }
153
154 let msg: BaseCommand = unwrap!(recv_msg(&mut stream).await);
155
156 match msg {
157 BaseCommand::IdRequest => {
158 unwrap!(send_msg(BaseCommand::IdResponse(std::process::id()), &mut stream).await);
159 }
160 BaseCommand::Args(args) => {
161 let args = match P::try_parse_from(args) {
162 Ok(x) => x,
163 Err(e) => {
164 unwrap!(send_msg(BaseCommand::Packet(e.to_string()), &mut stream).await);
165 let _ = stream.close().await;
166 continue;
167 }
168 };
169 if args
170 .execute(RemoteClient {
171 stream: Some(stream),
172 })
173 .await
174 {
175 break;
176 }
177 continue;
178 }
179 _ => {}
180 }
181
182 unwrap!(send_msg(BaseCommand::CloseSocket, &mut stream).await);
183 }
184}