#![allow(clippy::unused_io_amount)]
#![cfg(not(tarpaulin_include))]
use extrasafe::builtins::{danger_zone::Threads, Networking, SystemIO};
use extrasafe::SafetyContext;
use std::io::prelude::*;
use warp::Filter;
use std::os::unix::net::{UnixListener, UnixStream};
use std::os::unix::process::CommandExt;
use std::sync::{Arc, Mutex};
/// A request understood by the database process, decoded from the text
/// messages it receives over the unix socket.
#[derive(Debug, Clone, PartialEq, Eq)]
enum DBMsg {
    /// Ask for every stored message.
    List,
    /// Store the contained string as a new message.
    Write(String),
}
/// Shared handle to the webserver's unix-socket connection to the db process.
/// Request handlers lock the mutex before writing to or reading from the stream.
type DbConn = Arc<Mutex<UnixStream>>;
/// Re-executes the current binary as a child process running the given subcommand.
///
/// `cmd[0]` is used both as the child's `argv[0]` and as its first real argument,
/// followed by the rest of `cmd`, so the child finds the subcommand name at
/// `args[1]` and its parameter (if any) at `args[2]`.
///
/// # Panics
///
/// Panics if `cmd` is empty, if the path of the current executable cannot be
/// determined, or if the child process fails to start.
fn run_subprocess(cmd: &[&str]) -> std::process::Child {
    let exe_path = std::env::current_exe().unwrap();
    // `Command::new` accepts any `AsRef<OsStr>`, so the path is handed over as-is
    // instead of going through `to_str()`, which would panic on a non-UTF-8 path.
    std::process::Command::new(&exe_path)
        .arg0(cmd[0])
        .args(cmd)
        .spawn()
        .unwrap_or_else(|e| {
            panic!(
                "subcommand `{}` failed to start: {:?}",
                cmd.join(" "),
                e
            )
        })
}
/// Builds a warp filter that always matches and extracts a fresh clone of the
/// shared db connection handle for the route it is combined with.
fn with_db(
    db: DbConn,
) -> impl Filter<Extract = (DbConn,), Error = std::convert::Infallible> + Clone {
    // The closure owns `db`; every invocation hands out another reference to it.
    let provide_conn = move || Arc::clone(&db);
    warp::any().map(provide_conn)
}
/// Runs the HTTP front end.
///
/// Connects to the db process over the unix socket at `db_socket_path`, then
/// serves two routes on `127.0.0.1:5576`:
///
/// * `POST /write` forwards the request body to the db as `write <body>` and
///   replies `ok`.
/// * `GET /read` sends `list` to the db and replies with whatever the db sends
///   back in a single read of at most 100 bytes.
///
/// The db connection, the tokio runtime and the TCP listener are all created
/// before the safety context is applied to the current thread.
///
/// # Panics
///
/// Panics if the db socket cannot be reached, the runtime or listener cannot be
/// created, the safety context cannot be applied, or the server stops with an
/// error. Request handlers panic on a poisoned lock, a non-UTF-8 body or reply,
/// or a failed socket read/write.
fn run_webserver(db_socket_path: &str) {
    println!("webserver thread connecting to db unix socket");
    let socket = UnixStream::connect(db_socket_path).expect("failed to connect to db socket");
    let db_socket: DbConn = Arc::new(Mutex::new(socket));

    let runtime = tokio::runtime::Builder::new_current_thread()
        .enable_all()
        .build()
        .unwrap();
    let listener = std::net::TcpListener::bind("127.0.0.1:5576").unwrap();

    SafetyContext::new()
        .enable(Networking::nothing().allow_running_tcp_servers())
        .unwrap()
        .apply_to_current_thread()
        .unwrap();

    let write_route = warp::path("write")
        .and(warp::post())
        .and(warp::body::bytes())
        .and(with_db(db_socket.clone()))
        .map(|body: bytes::Bytes, db_conn: DbConn| {
            println!("webserver got write request");
            let mut conn = db_conn.lock().unwrap();
            let text = std::str::from_utf8(&body).unwrap();
            conn.write_all(format!("write {}", text).as_bytes())
                .expect("failed to send write message to db");
            "ok"
        });

    let read_route = warp::path("read")
        .and(warp::get())
        .and(with_db(db_socket))
        .map(|db_conn: DbConn| {
            println!("webserver got read request");
            let mut conn = db_conn.lock().unwrap();
            println!("sending list command to db");
            conn.write_all("list".as_bytes())
                .expect("failed to send read message to db");
            println!("waiting for response from db");
            let mut buf = [0u8; 100];
            // Use the byte count reported by `read` rather than guessing the
            // reply length by stripping NUL padding from the whole buffer.
            // NOTE(review): a reply longer than the buffer is cut off here.
            let len = conn
                .read(&mut buf)
                .expect("failed to read response from db");
            println!("got response from db");
            String::from_utf8(buf[..len].to_vec()).unwrap()
        });

    let routes = write_route.or(read_route);
    let svc = warp::service(routes);
    let make_svc = hyper::service::make_service_fn(move |_| {
        let warp_svc = svc.clone();
        async move { Ok::<_, std::convert::Infallible>(warp_svc) }
    });

    // Keep the runtime context entered while the hyper server is constructed.
    let _in_runtime = runtime.enter();
    let server = hyper::Server::from_tcp(listener).unwrap();
    println!("Server about to start listening...");
    runtime.block_on(server.serve(make_svc)).unwrap();
}
fn run_db(socket_path: &str) {
let socket = UnixListener::bind(socket_path).unwrap();
let dir = tempfile::tempdir().unwrap();
let mut path = dir.path().to_path_buf();
path.push("testdb.sql3");
let db = rusqlite::Connection::open(&path).unwrap();
db.pragma_update(None, "locking_mode", "exclusive").unwrap();
db.pragma_update(None, "journal_mode", "wal").unwrap();
db.execute("CREATE TABLE messages ( msg TEXT NOT NULL );", []).unwrap();
let mut get_rows = db.prepare("SELECT msg FROM messages;").unwrap();
let mut insert_row = db.prepare("INSERT INTO messages VALUES (?)").unwrap();
SafetyContext::new()
.enable(Networking::nothing()
.allow_running_unix_servers()
).unwrap()
.enable(SystemIO::nothing()
.allow_read()
.allow_write()
.allow_metadata()
.allow_ioctl()
.allow_close()).unwrap()
.enable(Threads::nothing()
.allow_sleep().yes_really()).unwrap()
.apply_to_current_thread()
.unwrap();
println!("database opened at {:?}", &path);
println!("db server waiting to accept connection");
let conn = socket.accept();
if let Err(err) = conn {
panic!("Error accepting db connection: {:?}", err);
}
let (mut conn, _) = conn.unwrap();
println!("db server got connection on unix socket");
loop {
println!("db server waiting for unix socket message");
let mut buf: [u8; 100] = [0; 100];
conn.read(&mut buf)
.expect("failed reading request to db server");
let buf = String::from_utf8(buf.to_vec())
.unwrap()
.trim_end_matches('\0')
.to_string();
println!("db got unix socket message: '{}'", buf);
let msg: DBMsg;
if buf == "list" {
msg = DBMsg::List;
}
else if buf.starts_with("write") {
msg = DBMsg::Write(buf[6..].to_string());
}
else {
panic!("unknown message recieved in db: {}", buf);
}
match msg {
DBMsg::List => {
let messages: Vec<String> = get_rows
.query_map([], |row| row.get(0)).unwrap()
.map(Result::unwrap)
.collect();
conn.write_all(messages.join("\n").as_bytes())
.expect("failed writing response from db server");
}
DBMsg::Write(s) => {
insert_row.execute([s]).unwrap();
}
}
}
}
/// Client subcommand: POSTs `msg` to the webserver's `/write` endpoint and
/// asserts that the reply body is `ok`.
///
/// The tokio runtime is built first; the safety context is applied to the
/// current thread before the request is made.
///
/// # Panics
///
/// Panics if the runtime or safety context cannot be set up, if the request
/// fails, or if the reply is not `ok`.
fn run_client_write(msg: &str) {
    let runtime = tokio::runtime::Builder::new_current_thread()
        .enable_all()
        .build()
        .unwrap();

    let sandbox = SafetyContext::new()
        .enable(Networking::nothing().allow_start_tcp_clients())
        .unwrap();
    sandbox.apply_to_current_thread().unwrap();

    println!("about to make request with msg {}", msg);
    let payload = msg.to_string();
    let request = async move {
        let response = reqwest::Client::new()
            .post("http://127.0.0.1:5576/write")
            .body(payload)
            .send()
            .await
            .unwrap_or_else(|err| panic!("Error writing to server db: {:?}", err));
        let reply = response.text().await.unwrap();
        assert_eq!(reply, "ok");
    };
    runtime.block_on(request);
}
/// Client subcommand: fetches `https://example.org/`, then GETs the webserver's
/// `/read` endpoint and asserts that it returns `hello\nextrasafe`.
///
/// The tokio runtime and the HTTP client are both created before the safety
/// context is applied to the current thread.
///
/// # Panics
///
/// Panics if setup fails, if either request fails, or if the `/read` reply
/// differs from the expected text.
fn run_client_read() {
    let runtime = tokio::runtime::Builder::new_current_thread()
        .worker_threads(1)
        .enable_all()
        .build()
        .unwrap();
    let client = reqwest::Client::new();

    let sandbox = SafetyContext::new()
        .enable(Networking::nothing().allow_start_tcp_clients())
        .unwrap()
        .enable(Threads::nothing().allow_create())
        .unwrap()
        .enable(
            SystemIO::nothing()
                .allow_open_readonly()
                .allow_read()
                .allow_metadata()
                .allow_close(),
        )
        .unwrap();
    sandbox.apply_to_current_thread().unwrap();

    let session = async {
        // First an external HTTPS request whose body only has to arrive intact.
        let external = client.get("https://example.org/").send().await.unwrap();
        if let Err(err) = external.text().await {
            panic!("failed getting example.org response: {:?}", err);
        }

        println!("about to make read request to webserver");
        let text = match client.get("http://127.0.0.1:5576/read").send().await {
            Ok(response) => response.text().await.unwrap(),
            Err(err) => panic!("Error reading from server db: {:?}", err),
        };
        assert_eq!(text, "hello\nextrasafe");
        println!("got response: {}", text);
    };
    runtime.block_on(session);
}
/// Owns a spawned helper process and kills and reaps it when dropped, so the
/// helper does not outlive the parent if a later step panics.
struct ChildGuard(std::process::Child);

impl Drop for ChildGuard {
    fn drop(&mut self) {
        // Errors are deliberately ignored: the child may already be gone, and a
        // panic inside `drop` during unwinding would abort the process.
        let _ = self.0.kill();
        let _ = self.0.wait();
    }
}

/// Returns the parameter that follows the subcommand name (`args[2]`).
///
/// Panics with a message naming `subcommand` when the parameter is missing.
fn required_arg<'a>(args: &'a [String], subcommand: &str) -> &'a str {
    args.get(2)
        .map(String::as_str)
        .unwrap_or_else(|| panic!("subcommand `{}` requires an argument", subcommand))
}

/// Spawns one client subcommand, waits for it, and panics unless it exits
/// successfully. `label` identifies the client in the panic message.
fn run_client_to_completion(label: &str, cmd: &[&str]) {
    let status = run_subprocess(cmd)
        .wait()
        .unwrap_or_else(|err| panic!("{} failed to finish: {:?}", label, err));
    assert!(
        status.success(),
        "{} exited unsuccessfully: {:?}",
        label,
        status
    );
}

/// Entry point.
///
/// With a subcommand (`db`, `webserver`, `read_client`, `write_client`) it runs
/// that role in this process and returns. Without arguments it orchestrates the
/// whole example: it starts the db and webserver as child processes, runs two
/// write clients and one read client in sequence, and finally stops the
/// long-running children.
///
/// # Panics
///
/// Panics on an unknown subcommand, a missing subcommand parameter, a failure
/// to spawn a child, or a client that does not exit successfully.
fn main() {
    let args: Vec<String> = std::env::args().collect();
    if let Some(subcommand) = args.get(1) {
        match subcommand.as_str() {
            "db" => run_db(required_arg(&args, "db")),
            "webserver" => run_webserver(required_arg(&args, "webserver")),
            "read_client" => run_client_read(),
            "write_client" => run_client_write(required_arg(&args, "write_client")),
            other => panic!("unknown subcommand {}", other),
        }
        return;
    }

    let dir = tempfile::TempDir::new().unwrap();
    let socket_file = dir.path().join("db.sock");
    let socket_path = socket_file.to_str().unwrap();

    let db_child = ChildGuard(run_subprocess(&["db", socket_path]));
    // Give the db time to bind its socket before the webserver connects.
    std::thread::sleep(std::time::Duration::from_millis(100));
    let webserver_child = ChildGuard(run_subprocess(&["webserver", socket_path]));
    // Give the webserver time to start listening before clients connect.
    std::thread::sleep(std::time::Duration::from_millis(100));

    run_client_to_completion("client1", &["write_client", "hello"]);
    run_client_to_completion("client2", &["write_client", "extrasafe"]);
    run_client_to_completion("client3", &["read_client"]);

    // Stop the db first, then the webserver. If a client check above panicked,
    // the guards still kill both children while unwinding.
    drop(db_child);
    drop(webserver_child);
}