1use std::{
2 path::PathBuf,
3 sync::{Arc, OnceLock},
4};
5
6use axum::{
7 Json, Router,
8 extract::{ConnectInfo, connect_info},
9 routing::{get, post},
10 serve::IncomingStream,
11};
12use color_eyre::eyre::Result;
13use serde::{Deserialize, Serialize};
14use tokio::{
15 net::{UnixListener, unix::UCred},
16 sync::RwLock,
17};
18
19use crate::vault::Vault;
20use message::MsgData;
21
22pub mod message;
23
24#[derive(Serialize, Deserialize, Debug)]
25pub struct ServerResult {
26 status: String,
27 pub msg: String,
28}
29
30impl ServerResult {
31 fn success(msg: &str) -> ServerResult {
32 ServerResult {
33 status: "success".to_string(),
34 msg: msg.to_string(),
35 }
36 }
37
38 fn fail(msg: &str) -> ServerResult {
39 ServerResult {
40 status: "fail".to_string(),
41 msg: msg.to_string(),
42 }
43 }
44
45 pub fn is_success(&self) -> bool {
46 self.status == "success"
47 }
48}
49
50static VAULT: OnceLock<RwLock<Vault>> = OnceLock::new();
51fn vault_data() -> &'static RwLock<Vault> {
52 VAULT.get().unwrap()
53}
54
55pub async fn start_server(socket: &PathBuf, data_file: &PathBuf) -> Result<()> {
57 let v = Vault::from(data_file)?;
59 VAULT.set(RwLock::new(v)).unwrap();
60 let listener = UnixListener::bind(socket)?;
62 let app = Router::new()
64 .route("/unlock", post(handle_unlock))
65 .route("/lock", get(handle_lock))
66 .route("/encrypt", post(handle_encrypt))
67 .route("/decrypt", post(handle_decrypt))
68 .into_make_service_with_connect_info::<UdsConnectInfo>();
69 axum::serve(listener, app).await?;
70 Ok(())
71}
72
73async fn handle_unlock(
75 ConnectInfo(_): ConnectInfo<UdsConnectInfo>,
76 Json(payload): Json<MsgData>,
77) -> Json<ServerResult> {
78 if let MsgData::Unlock(data) = payload {
79 if cfg!(debug_assertions) {
80 println!("received: {:?}", data);
81 }
82 let mut vault = vault_data().write().await;
83 if vault.unlock(data.username, data.password).is_ok() {
84 return Json(ServerResult::success("unlocked"));
85 }
86 return Json(ServerResult::fail("failed to unlock vault"));
87 }
88 Json(ServerResult::fail("invalid data"))
89}
90
91async fn handle_lock(ConnectInfo(_): ConnectInfo<UdsConnectInfo>) -> Json<ServerResult> {
92 if cfg!(debug_assertions) {
93 println!("locking..");
94 }
95 let mut vault = vault_data().write().await;
96 vault.lock();
97 Json(ServerResult::success("locked"))
98}
99
100async fn handle_encrypt(
101 ConnectInfo(_): ConnectInfo<UdsConnectInfo>,
102 Json(payload): Json<MsgData>,
103) -> Json<ServerResult> {
104 if let MsgData::Encrypt(ref data) = payload {
105 if cfg!(debug_assertions) {
106 println!("encrypting {:?}..", payload);
107 }
108 let vault = vault_data().read().await;
109 if let Ok(encrypted) = vault.encrypt_from_base64(&data.data) {
110 return Json(ServerResult::success(&encrypted));
111 }
112 if !vault.is_unlocked() {
113 return Json(ServerResult::fail("vault is locked"));
114 }
115 }
116 Json(ServerResult::fail("invalid data"))
117}
118
119async fn handle_decrypt(
120 ConnectInfo(_): ConnectInfo<UdsConnectInfo>,
121 Json(payload): Json<MsgData>,
122) -> Json<ServerResult> {
123 if let MsgData::Decrypt(ref data) = payload {
124 if cfg!(debug_assertions) {
125 println!("decrypting {:?}..", payload);
126 }
127 let vault = vault_data().read().await;
128 if let Ok(decrypted) = vault.decrypt_to_base64(&data.data) {
129 return Json(ServerResult::success(&decrypted));
130 }
131 if !vault.is_unlocked() {
132 return Json(ServerResult::fail("vault is locked"));
133 }
134 }
135 Json(ServerResult::fail("invalid data"))
136}
137#[derive(Clone, Debug)]
140#[allow(dead_code)]
141struct UdsConnectInfo {
142 peer_addr: Arc<tokio::net::unix::SocketAddr>,
143 peer_cred: UCred,
144}
145
146impl connect_info::Connected<IncomingStream<'_, UnixListener>> for UdsConnectInfo {
147 fn connect_info(stream: IncomingStream<'_, UnixListener>) -> Self {
148 let peer_addr = stream.io().peer_addr().unwrap();
149 let peer_cred = stream.io().peer_cred().unwrap();
150
151 Self {
152 peer_addr: Arc::new(peer_addr),
153 peer_cred,
154 }
155 }
156}