hypermangle_core/
console.rs

1use std::{ffi::OsString, mem::take};
2
3use clap::{crate_name, Parser};
4use futures::AsyncReadExt;
5use interprocess::local_socket::tokio::{LocalSocketListener, LocalSocketStream};
6use log::error;
7use serde::{Deserialize, Serialize};
8
9use futures::AsyncWriteExt;
10use tokio::sync::mpsc;
11
12pub struct RemoteClient {
13    stream: Option<LocalSocketStream>,
14}
15
16impl RemoteClient {
17    pub async fn send(&mut self, msg: String) {
18        if let Err(e) = send_msg(BaseCommand::Packet(msg), self.stream.as_mut().unwrap()).await {
19            error!("Faced the following error while responding to remote client: {e}");
20        }
21    }
22}
23
24impl Drop for RemoteClient {
25    fn drop(&mut self) {
26        let mut stream = take(&mut self.stream).unwrap();
27        tokio::spawn(async move {
28            if let Err(e) = send_msg(BaseCommand::CloseSocket, &mut stream).await {
29                error!("Faced the following error while ending connection to remote client: {e}");
30            }
31        });
32    }
33}
34
/// Messages exchanged over the command socket.
///
/// Each message travels as one frame produced by `send_msg`: a native-endian `usize` payload
/// length followed by the bincode encoding of the value. bincode identifies enum variants by
/// their index, so reordering, inserting, or removing variants changes the wire format; keep
/// the variant order stable.
#[derive(Serialize, Deserialize)]
enum BaseCommand {
    /// Client → server: ask for the server's process ID.
    IdRequest,
    /// Server → client: answer to `IdRequest`, carrying the server's process ID.
    IdResponse(u32),
    /// Client → server: the client's full argv (program name included), to be parsed as a command.
    Args(Vec<OsString>),
    /// Server → client: a chunk of output text, which the client prints as-is.
    Packet(String),
    /// Server → client: no further messages follow; the client stops reading.
    CloseSocket,
}
43
44fn get_socket_name() -> String {
45    format!("/run/{}.sock", crate_name!())
46}
47
48#[tokio::main(flavor = "current_thread")]
49pub async fn does_remote_exist() -> Option<u32> {
50    let Ok(mut stream) = LocalSocketStream::connect(get_socket_name()).await else {
51        return None;
52    };
53    send_msg(BaseCommand::IdRequest, &mut stream).await.ok()?;
54    let Ok(BaseCommand::IdResponse(id)) = recv_msg(&mut stream).await else {
55        panic!("Remote service should have responded with is Process ID")
56    };
57    Some(id)
58}
59
60async fn send_msg(msg: BaseCommand, stream: &mut LocalSocketStream) -> std::io::Result<()> {
61    let mut msg = bincode::serialize(&msg).unwrap();
62
63    let mut tmp = msg.len().to_ne_bytes().to_vec();
64    tmp.append(&mut msg);
65    msg = tmp;
66
67    stream.write_all(&msg).await
68}
69
70async fn recv_msg(
71    stream: &mut LocalSocketStream,
72) -> Result<BaseCommand, Box<dyn std::error::Error>> {
73    let mut msg_size = [0u8; (usize::BITS / 8) as usize];
74    stream.read_exact(&mut msg_size).await.map_err(Box::new)?;
75    let msg_size = usize::from_ne_bytes(msg_size);
76    let mut msg = vec![0u8; msg_size];
77    stream.read_exact(&mut msg).await.map_err(Box::new)?;
78
79    bincode::deserialize(&msg).map_err(Into::into)
80}
81
82#[tokio::main(flavor = "current_thread")]
83pub async fn send_args_to_remote() {
84    let mut stream = LocalSocketStream::connect(get_socket_name())
85        .await
86        .expect("Connection to remote service should have succeeded");
87
88    send_msg(
89        BaseCommand::Args(std::env::args_os().collect()),
90        &mut stream,
91    )
92    .await
93    .expect("Remote service should have accepted the given arguments");
94
95    loop {
96        let msg = recv_msg(&mut stream)
97            .await
98            .expect("Remote service should have sent a valid message");
99
100        match msg {
101            BaseCommand::Packet(msg) => print!("{msg}"),
102            BaseCommand::CloseSocket => break,
103            _ => {}
104        }
105    }
106}
107
/// A clap argument type that can be executed on behalf of a remote client.
///
/// The command listener parses each client's argv into `Self` with `try_parse_from` and then
/// calls [`execute`](ExecutableArgs::execute).
pub trait ExecutableArgs: Parser + Send + 'static {
    /// Runs the parsed command, streaming any output to the caller through `writer`.
    ///
    /// `writer` is consumed; when it is dropped, the client is sent `CloseSocket` and stops
    /// reading. Resolve to `true` to make the listener stop accepting further commands, or
    /// `false` to keep serving.
    fn execute(self, writer: RemoteClient) -> impl std::future::Future<Output=bool> + Send;
}
111
112pub fn listen_for_commands<P: ExecutableArgs>() -> impl std::future::Future<Output=()> {
113    let (sender, receiver) = mpsc::channel(1);
114    tokio::spawn(listen_for_commands_inner::<P>(receiver));
115    async move {
116        let _sender = sender;
117        std::future::pending::<()>().await;
118    }
119}
120
121
122async fn listen_for_commands_inner<P: ExecutableArgs + Send>(mut receiver: mpsc::Receiver<()>) {
123    #[cfg(unix)]
124    let _ = std::fs::remove_file(get_socket_name());
125
126    let listener = LocalSocketListener::bind(get_socket_name())
127        .expect("Command listener should have started successfully");
128
129    loop {
130        let mut stream;
131
132        macro_rules! unwrap {
133            ($result: expr) => {
134                match $result {
135                    Ok(x) => x,
136                    Err(e) => {
137                        error!("Faced the following error while listening for commands: {e}");
138                        // let _ = send_msg(BaseCommand::Packet(e.to_string()), &mut stream).await;
139                        continue;
140                    }
141                }
142            };
143        }
144
145        tokio::select! {
146            _ = receiver.recv() => {
147                break
148            }
149            result = listener.accept() => {
150                stream = unwrap!(result);
151            }
152        }
153
154        let msg: BaseCommand = unwrap!(recv_msg(&mut stream).await);
155
156        match msg {
157            BaseCommand::IdRequest => {
158                unwrap!(send_msg(BaseCommand::IdResponse(std::process::id()), &mut stream).await);
159            }
160            BaseCommand::Args(args) => {
161                let args = match P::try_parse_from(args) {
162                    Ok(x) => x,
163                    Err(e) => {
164                        unwrap!(send_msg(BaseCommand::Packet(e.to_string()), &mut stream).await);
165                        let _ = stream.close().await;
166                        continue;
167                    }
168                };
169                if args
170                    .execute(RemoteClient {
171                        stream: Some(stream),
172                    })
173                    .await
174                {
175                    break;
176                }
177                continue;
178            }
179            _ => {}
180        }
181
182        unwrap!(send_msg(BaseCommand::CloseSocket, &mut stream).await);
183    }
184}