minivault/
server.rs

1use std::{
2    path::PathBuf,
3    sync::{Arc, OnceLock},
4};
5
6use axum::{
7    Json, Router,
8    extract::{ConnectInfo, connect_info},
9    routing::{get, post},
10    serve::IncomingStream,
11};
12use color_eyre::eyre::Result;
13use serde::{Deserialize, Serialize};
14use tokio::{
15    net::{UnixListener, unix::UCred},
16    sync::RwLock,
17};
18
19use crate::vault::Vault;
20use message::MsgData;
21
22pub mod message;
23
24#[derive(Serialize, Deserialize, Debug)]
25pub struct ServerResult {
26    status: String,
27    pub msg: String,
28}
29
30impl ServerResult {
31    fn success(msg: &str) -> ServerResult {
32        ServerResult {
33            status: "success".to_string(),
34            msg: msg.to_string(),
35        }
36    }
37
38    fn fail(msg: &str) -> ServerResult {
39        ServerResult {
40            status: "fail".to_string(),
41            msg: msg.to_string(),
42        }
43    }
44
45    pub fn is_success(&self) -> bool {
46        self.status == "success"
47    }
48}
49
50static VAULT: OnceLock<RwLock<Vault>> = OnceLock::new();
51fn vault_data() -> &'static RwLock<Vault> {
52    VAULT.get().unwrap()
53}
54
55/// Run server on a UNIX socket.
56pub async fn start_server(socket: &PathBuf, data_file: &PathBuf) -> Result<()> {
57    // setup vault
58    let v = Vault::from(data_file)?;
59    VAULT.set(RwLock::new(v)).unwrap();
60    // setup unix socket
61    let listener = UnixListener::bind(socket)?;
62    // define paths
63    let app = Router::new()
64        .route("/unlock", post(handle_unlock))
65        .route("/lock", get(handle_lock))
66        .route("/encrypt", post(handle_encrypt))
67        .route("/decrypt", post(handle_decrypt))
68        .into_make_service_with_connect_info::<UdsConnectInfo>();
69    axum::serve(listener, app).await?;
70    Ok(())
71}
72
73/* BEGIN HANDLERS */
74async fn handle_unlock(
75    ConnectInfo(_): ConnectInfo<UdsConnectInfo>,
76    Json(payload): Json<MsgData>,
77) -> Json<ServerResult> {
78    if let MsgData::Unlock(data) = payload {
79        if cfg!(debug_assertions) {
80            println!("received: {:?}", data);
81        }
82        let mut vault = vault_data().write().await;
83        if vault.unlock(data.username, data.password).is_ok() {
84            return Json(ServerResult::success("unlocked"));
85        }
86        return Json(ServerResult::fail("failed to unlock vault"));
87    }
88    Json(ServerResult::fail("invalid data"))
89}
90
91async fn handle_lock(ConnectInfo(_): ConnectInfo<UdsConnectInfo>) -> Json<ServerResult> {
92    if cfg!(debug_assertions) {
93        println!("locking..");
94    }
95    let mut vault = vault_data().write().await;
96    vault.lock();
97    Json(ServerResult::success("locked"))
98}
99
100async fn handle_encrypt(
101    ConnectInfo(_): ConnectInfo<UdsConnectInfo>,
102    Json(payload): Json<MsgData>,
103) -> Json<ServerResult> {
104    if let MsgData::Encrypt(ref data) = payload {
105        if cfg!(debug_assertions) {
106            println!("encrypting {:?}..", payload);
107        }
108        let vault = vault_data().read().await;
109        if let Ok(encrypted) = vault.encrypt_from_base64(&data.data) {
110            return Json(ServerResult::success(&encrypted));
111        }
112        if !vault.is_unlocked() {
113            return Json(ServerResult::fail("vault is locked"));
114        }
115    }
116    Json(ServerResult::fail("invalid data"))
117}
118
119async fn handle_decrypt(
120    ConnectInfo(_): ConnectInfo<UdsConnectInfo>,
121    Json(payload): Json<MsgData>,
122) -> Json<ServerResult> {
123    if let MsgData::Decrypt(ref data) = payload {
124        if cfg!(debug_assertions) {
125            println!("decrypting {:?}..", payload);
126        }
127        let vault = vault_data().read().await;
128        if let Ok(decrypted) = vault.decrypt_to_base64(&data.data) {
129            return Json(ServerResult::success(&decrypted));
130        }
131        if !vault.is_unlocked() {
132            return Json(ServerResult::fail("vault is locked"));
133        }
134    }
135    Json(ServerResult::fail("invalid data"))
136}
137/* END HANDLERS */
138
139#[derive(Clone, Debug)]
140#[allow(dead_code)]
141struct UdsConnectInfo {
142    peer_addr: Arc<tokio::net::unix::SocketAddr>,
143    peer_cred: UCred,
144}
145
146impl connect_info::Connected<IncomingStream<'_, UnixListener>> for UdsConnectInfo {
147    fn connect_info(stream: IncomingStream<'_, UnixListener>) -> Self {
148        let peer_addr = stream.io().peer_addr().unwrap();
149        let peer_cred = stream.io().peer_cred().unwrap();
150
151        Self {
152            peer_addr: Arc::new(peer_addr),
153            peer_cred,
154        }
155    }
156}