use extrasafe::builtins::{danger_zone::Threads, Networking, SystemIO};
use extrasafe::SafetyContext;
use crossbeam::channel;
use crossbeam_queue::SegQueue;
use warp::Filter;
use std::sync::Arc;
/// Requests handled by the database thread (`run_db`), delivered through a [`DbConn`] queue.
#[derive(Debug)]
enum DBMsg {
    /// Ask for every stored message; the database thread sends them back on this channel.
    List(channel::Sender<Vec<String>>),
    /// Store the given message.
    Write(String),
}
/// Shared handle to the request queue: HTTP handlers push [`DBMsg`]s, the database thread pops them.
type DbConn = Arc<SegQueue<DBMsg>>;
fn with_db(
db: DbConn,
) -> impl Filter<Extract = (DbConn,), Error = std::convert::Infallible> + Clone {
warp::any().map(move || db.clone())
}
fn run_server() {
let queue: DbConn = Arc::new(SegQueue::new());
let db_queue = queue.clone();
let read_queue = queue.clone();
let write_queue = queue;
std::thread::Builder::new()
.name("db".into())
.spawn(move || run_db(&db_queue)).unwrap();
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build().unwrap();
let listener = std::net::TcpListener::bind("127.0.0.1:5575").unwrap();
SafetyContext::new()
.enable(Networking::nothing()
.allow_running_tcp_servers()).unwrap()
.apply_to_current_thread()
.unwrap();
let routes = warp::path("write")
.and(warp::post())
.and(warp::body::bytes())
.and(with_db(write_queue))
.map(|param: bytes::Bytes, msg_queue: DbConn| {
let s = std::str::from_utf8(¶m).unwrap();
msg_queue.push(DBMsg::Write(s.into()));
"ok"
})
.or(warp::path("read")
.and(warp::get())
.and(with_db(read_queue))
.map(|msg_queue: DbConn| {
let (send, recv) = crossbeam_channel::bounded(1);
msg_queue.push(DBMsg::List(send));
let messages = recv.recv().unwrap();
messages.join("\n")
})
);
let svc = warp::service(routes);
let make_svc = hyper::service::make_service_fn(move |_| {
let warp_svc = svc.clone();
async move { Ok::<_, std::convert::Infallible>(warp_svc) }
});
let _in_runtime = runtime.enter();
let server = hyper::Server::from_tcp(listener).unwrap();
println!("Server about to start listening...");
runtime.block_on(server.serve(make_svc)).unwrap();
}
/// Owns the SQLite database and serves requests popped from `queue`; never returns.
///
/// The database file is created in a fresh temporary directory and configured (exclusive
/// locking, WAL journal). Its table is created and both statements are prepared *before*
/// the seccomp filter is applied. The `SystemIO` set enabled afterwards (read, write,
/// metadata, ioctl, close) does not include opening files; `Threads` only allows sleeping.
///
/// # Panics
/// Panics if the database cannot be created or configured, or if a query fails.
fn run_db(queue: &DbConn) {
    // `dir` must stay alive for as long as the database is used: dropping the `TempDir`
    // removes the directory. The loop below never exits, so it lives forever.
    let dir = tempfile::tempdir().unwrap();
    let path = dir.path().join("testdb.sql3");

    let db = rusqlite::Connection::open(&path).unwrap();
    db.pragma_update(None, "locking_mode", "exclusive").unwrap();
    db.pragma_update(None, "journal_mode", "wal").unwrap();
    db.execute("CREATE TABLE messages ( msg TEXT NOT NULL );", []).unwrap();
    let mut get_rows = db.prepare("SELECT msg FROM messages;").unwrap();
    let mut insert_row = db.prepare("INSERT INTO messages VALUES (?)").unwrap();

    SafetyContext::new()
        .enable(
            SystemIO::nothing()
                .allow_read()
                .allow_write()
                .allow_metadata()
                .allow_ioctl()
                .allow_close(),
        )
        .unwrap()
        .enable(Threads::nothing().allow_sleep().yes_really())
        .unwrap()
        .apply_to_current_thread()
        .unwrap();

    println!("database opened at {:?}", &path);

    loop {
        // Poll the queue; when it is empty, sleep briefly instead of spinning.
        let msg = match queue.pop() {
            Some(msg) => msg,
            None => {
                std::thread::sleep(std::time::Duration::from_millis(55));
                continue;
            }
        };

        match msg {
            DBMsg::List(send) => {
                let messages: Vec<String> = get_rows
                    .query_map([], |row| row.get(0))
                    .unwrap()
                    .map(Result::unwrap)
                    .collect();
                // Best-effort reply: if the requester has gone away there is nobody to answer.
                // Panicking here would kill the only consumer of `queue`.
                let _ = send.send(messages);
            }
            DBMsg::Write(s) => {
                insert_row.execute([s]).unwrap();
            }
        }
    }
}
/// Posts `msg` to the server's `/write` endpoint and asserts that it answers `ok`.
///
/// The Tokio runtime is built first; then a seccomp filter that only enables starting TCP
/// clients is applied to this thread before the request is made.
///
/// # Panics
/// Panics if the request fails or the response body is not `ok`.
fn run_client_write(msg: &str) {
    let rt = tokio::runtime::Builder::new_current_thread()
        .enable_all()
        .build()
        .unwrap();

    let tcp_clients_only = Networking::nothing().allow_start_tcp_clients();
    SafetyContext::new()
        .enable(tcp_clients_only)
        .unwrap()
        .apply_to_current_thread()
        .unwrap();

    println!("about to make request with msg {}", msg);
    let body = msg.to_string();
    rt.block_on(async move {
        let client = reqwest::Client::new();
        let sent = client
            .post("http://127.0.0.1:5575/write")
            .body(body)
            .send()
            .await;
        let response = match sent {
            Ok(response) => response,
            Err(err) => panic!("Error writing to server db: {:?}", err),
        };
        let text = response.text().await.unwrap();
        assert_eq!(text, "ok");
    });
}
/// Fetches `https://example.org/`, then reads every stored message from the server's `/read`
/// endpoint and asserts it is exactly `hello` followed by `extrasafe`.
///
/// The Tokio runtime and the `reqwest` client are built before a seccomp filter is applied to
/// this thread. The filter enables starting TCP clients, creating threads, and opening files
/// read-only / reading / metadata / closing.
///
/// # Panics
/// Panics if either request fails or the stored messages differ from the expected ones.
fn run_client_read() {
    // Single-threaded runtime: futures run on this thread inside `block_on`.
    // (`worker_threads` is ignored for `new_current_thread`, so it is not set.)
    let runtime = tokio::runtime::Builder::new_current_thread()
        .enable_all()
        .build()
        .unwrap();
    let client = reqwest::Client::new();

    SafetyContext::new()
        .enable(Networking::nothing().allow_start_tcp_clients())
        .unwrap()
        .enable(Threads::nothing().allow_create())
        .unwrap()
        .enable(
            SystemIO::nothing()
                .allow_open_readonly()
                .allow_read()
                .allow_metadata()
                .allow_close(),
        )
        .unwrap()
        .apply_to_current_thread()
        .unwrap();

    runtime.block_on(async {
        let resp = client.get("https://example.org/").send().await.unwrap();
        if let Err(err) = resp.text().await {
            panic!("failed getting example.org response: {:?}", err);
        }

        let resp = match client.get("http://127.0.0.1:5575/read").send().await {
            Ok(resp) => resp,
            Err(err) => panic!("Error reading from server db: {:?}", err),
        };
        let text = resp.text().await.unwrap();
        assert_eq!(text, "hello\nextrasafe");
        println!("got response: {}", text);
    });
}
/// Starts the server, then runs three clients one after another: two writers (`hello`, then
/// `extrasafe`) followed by a reader that checks both messages were stored.
///
/// Each client runs on its own named thread so it can install its own seccomp filter; a panic
/// in any client is re-raised here.
fn main() {
    // Runs `job` on a new thread called `name`, waits for it, and panics if it panicked.
    fn run_named<F>(name: &str, job: F)
    where
        F: FnOnce() + Send + 'static,
    {
        let outcome = std::thread::Builder::new()
            .name(name.into())
            .spawn(job)
            .unwrap()
            .join();
        if let Err(payload) = outcome {
            panic!("{} failed: {:?}", name, payload);
        }
    }

    // The server thread is never joined; the process exits when `main` returns.
    let _server_thread = std::thread::Builder::new()
        .name("server".into())
        .spawn(run_server)
        .unwrap();
    // Crude start-up delay so the server can bind its port before the first client connects.
    std::thread::sleep(std::time::Duration::from_millis(100));

    run_named("client1", || run_client_write("hello"));
    run_named("client2", || run_client_write("extrasafe"));
    run_named("client3", run_client_read);
}