use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::atomic::{AtomicBool, Ordering};
use super::{Controller, Error};
use crate::http::{Request, Response};
use async_trait::async_trait;
use rayon::{ThreadPool, ThreadPoolBuilder};
use tokio::sync::oneshot::channel;
use tracing::{info, warn};
use tokio::fs::{metadata, File};
use rwf_ruby::{RackRequest, RackResponseOwned, Ruby};
use std::sync::Arc;
/// Controller that hands requests to an embedded Rack (Ruby) application.
///
/// All Ruby work happens on a private rayon pool owned by the controller; the
/// application entry point is loaded the first time a request is handled.
pub struct RackController {
    /// Worker pool on which Rack requests are executed (built with one thread by `new`).
    pool: ThreadPool,
    /// Location of the application's `config/environment.rb`.
    path: PathBuf,
    /// Becomes `true` once the application has been loaded; shared with pool jobs.
    loaded: Arc<AtomicBool>,
}

impl RackController {
    /// Create a controller for the Rack application rooted at `path`.
    ///
    /// `path` is expected to be the application's root directory:
    /// `config/environment.rb` is appended to it. No Ruby code runs here.
    ///
    /// # Panics
    ///
    /// Panics if the backing thread pool cannot be built.
    pub fn new(path: &str) -> Self {
        // Single worker: every Rack job is serialized onto the same thread.
        let pool = Self::runtime(1);
        let entry_point = PathBuf::from(path).join("config/environment.rb");
        let loaded = Arc::new(AtomicBool::new(false));

        Self {
            pool,
            path: entry_point,
            loaded,
        }
    }

    /// Build a rayon pool with `threads` workers for running Rack jobs.
    ///
    /// A panic escaping a spawned job is reported through `warn!` by the
    /// installed panic handler (rayon's default for such panics is to abort).
    fn runtime(threads: usize) -> ThreadPool {
        let builder = ThreadPoolBuilder::new()
            .num_threads(threads)
            .panic_handler(|_| {
                warn!("Rack thread panicked. This is a bug in the Rack application.");
            });

        builder.build().unwrap()
    }
}
#[async_trait]
impl Controller for RackController {
fn skip_csrf(&self) -> bool {
true
}
async fn handle(&self, request: &Request) -> Result<Response, Error> {
let (tx, rx) = channel();
let path = PathBuf::from(&self.path);
let loaded = self.loaded.clone();
let req_path = request.path().path().to_string();
let method = request.method().to_string();
let query = request.query().to_string();
let req_uri = format!("{}{}", req_path, query);
let body = request.body().to_vec();
let content_type = request
.headers()
.get("content-type")
.unwrap_or(&String::from("application/x-www-form-urlencoded"))
.to_string();
let content_length = request
.headers()
.get("content-length")
.unwrap_or(&String::from(body.len().to_string().as_str()))
.to_string();
let mut env = HashMap::from([
("REQUEST_URI".into(), req_uri),
("PATH_INFO".into(), req_path.clone()),
("REQUEST_PATH".into(), req_path),
("SERVER_PROTOCOL".into(), "HTTP/1.1".into()),
("REQUEST_METHOD".into(), method),
("QUERY_STRING".into(), query.replace("?", "")),
("CONTENT_TYPE".into(), content_type),
("CONTENT_LENGTH".into(), content_length),
]);
for (key, value) in request.headers().iter() {
env.insert(
format!("HTTP_{}", crate::snake_case(key).to_ascii_uppercase()),
value.to_string(),
);
}
self.pool.spawn(move || {
if !loaded.load(Ordering::Relaxed) {
info!("Loading the Rack application, this may take a while...");
Ruby::load_app(&path).unwrap();
loaded.store(true, Ordering::Relaxed);
info!("Rack application loaded");
}
let response = RackRequest::send(env, &body).unwrap();
let owned = RackResponseOwned::from(response);
let _ = tx.send(owned);
});
let response = rx.await.unwrap();
if response.is_file() {
let path = PathBuf::from(String::from_utf8_lossy(response.body()).to_string());
let meta = if let Ok(meta) = metadata(&path).await {
meta
} else {
return Ok(Response::not_found());
};
let file = if let Ok(file) = File::open(&path).await {
file
} else {
return Ok(Response::not_found());
};
Ok(Response::new().body((path, file, meta)))
} else {
let mut res = Response::new().body(response.body());
for (key, value) in response.headers() {
res = res.header(key, value);
}
Ok(res.code(response.code()))
}
}
}