mod scheduler;
pub use scheduler::{DynamicThreadPool, RpsTracker};
use std::sync::Arc;
use std::time::Duration;
use tokio::net::{TcpListener, TcpStream};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::sync::Semaphore;
use tracing::{info, warn, error};
/// Shared request handler: a synchronous function from [`Request`] to [`Response`].
///
/// The `Arc` lets the same handler be cloned into every connection, and the
/// `Send + Sync` bounds let it be called from any task. The handler is invoked
/// directly (not awaited), so a handler that blocks also blocks the task that
/// is serving the connection.
pub type HandlerFn = Arc<dyn Fn(Request) -> Response + Send + Sync>;
/// A parsed HTTP request as handed to a [`HandlerFn`].
#[derive(Debug, Clone)]
pub struct Request {
/// Request method taken from the first token of the request line
/// (`"GET"` when the request line has no tokens).
pub method: String,
/// Request target taken from the second token of the request line
/// (`"/"` when it is missing).
pub path: String,
/// Header `(name, value)` pairs in the order they were received, both
/// trimmed of surrounding whitespace. Names keep their original casing.
pub headers: Vec<(String, String)>,
/// Text that follows the blank line ending the headers, or `None` when
/// there is none.
pub body: Option<String>,
}
/// An HTTP response produced by a [`HandlerFn`].
#[derive(Debug, Clone)]
pub struct Response {
/// Numeric HTTP status code written on the status line.
pub status: u16,
/// Header `(name, value)` pairs, written in this order. A `Content-Length`
/// header derived from `body` is added when the response is serialized.
pub headers: Vec<(String, String)>,
/// Response payload, written after the headers.
pub body: String,
}
impl Response {
    /// Builds a `200` response with a `text/plain` content type and the given body.
    pub fn ok(body: impl Into<String>) -> Self {
        Self::status(200)
            .with_header("Content-Type", "text/plain")
            .body(body)
    }

    /// Builds a `200` response with an `application/json` content type.
    ///
    /// The body is used as-is; it is not validated or serialized here.
    pub fn json(body: impl Into<String>) -> Self {
        Self::status(200)
            .with_header("Content-Type", "application/json")
            .body(body)
    }

    /// Builds a response with the given status code, no headers and an empty body.
    pub fn status(code: u16) -> Self {
        Self {
            status: code,
            headers: Vec::new(),
            body: String::new(),
        }
    }

    /// Builds a `404` response whose body is `Not Found`, with no headers.
    pub fn not_found() -> Self {
        Self::status(404).body("Not Found")
    }

    /// Builds a `500` response whose body is `msg`, with no headers.
    pub fn internal_error(msg: impl Into<String>) -> Self {
        Self::status(500).body(msg)
    }

    /// Appends a header and returns the response for further chaining.
    ///
    /// Existing headers with the same name are kept; nothing is replaced.
    pub fn with_header(self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let Self {
            status,
            mut headers,
            body,
        } = self;
        headers.push((key.into(), value.into()));
        Self {
            status,
            headers,
            body,
        }
    }

    /// Replaces the body and returns the response for further chaining.
    pub fn body(self, body: impl Into<String>) -> Self {
        Self {
            body: body.into(),
            ..self
        }
    }
}
/// A small HTTP server: configuration plus the shared state used while serving.
pub struct Vapor {
// Scaling source: its thread count sizes the connection semaphore in `run`.
pool: Arc<DynamicThreadPool>,
// `record()` is called on this tracker for every accepted connection.
rps_tracker: Arc<RpsTracker>,
// Handler invoked once per request.
handler: HandlerFn,
// TCP port bound on 0.0.0.0 by `run`.
port: u16,
}
impl Vapor {
pub fn new(handler: HandlerFn) -> Self {
Self {
pool: Arc::new(DynamicThreadPool::new()),
rps_tracker: Arc::new(RpsTracker::new(1000)),
handler,
port: 8080,
}
}
pub fn port(mut self, port: u16) -> Self {
self.port = port;
self
}
pub fn pool(mut self, pool: DynamicThreadPool) -> Self {
self.pool = Arc::new(pool);
self
}
pub fn rps_tracker(&self) -> &RpsTracker {
&self.rps_tracker
}
pub async fn run(self) -> Result<(), Box<dyn std::error::Error>> {
let addr = format!("0.0.0.0:{}", self.port);
let listener = TcpListener::bind(&addr).await?;
let listener = Arc::new(listener);
info!("Vapor starting on {} with {} threads", addr, self.pool.current_threads());
let pool = self.pool.clone();
let rps_tracker = self.rps_tracker.clone();
let handler = self.handler.clone();
let sem = Arc::new(Semaphore::new(pool.current_threads()));
let sem_for_scale = sem.clone();
tokio::spawn(async move {
loop {
if let Some(scale_up) = pool.should_scale() {
let current = pool.current_threads();
if scale_up {
let new_count = (current + 1).min(pool.max_threads());
pool.set_thread_count(new_count);
let additional = new_count.saturating_sub(sem_for_scale.available_permits());
if additional > 0 {
sem_for_scale.add_permits(additional);
}
info!("Scaling up: {} -> {} threads", current, new_count);
} else {
let new_count = (current - 1).max(pool.min_threads());
pool.set_thread_count(new_count);
}
}
tokio::time::sleep(Duration::from_millis(100)).await;
}
});
let sem2 = sem.clone();
let listener2 = listener.clone();
tokio::spawn(async move {
loop {
let sem_clone = sem2.clone();
let permit = match sem_clone.acquire().await {
Ok(p) => p,
Err(_) => continue,
};
match listener2.accept().await {
Ok((stream, _)) => {
rps_tracker.record();
let h = handler.clone();
if let Err(e) = handle_connection(stream, h).await {
warn!("Connection error: {}", e);
}
drop(permit);
}
Err(e) => {
drop(permit);
error!("Accept error: {}", e);
}
}
}
});
use std::future::pending;
pending::<()>().await;
unreachable!();
}
}
async fn handle_connection(
mut stream: TcpStream,
handler: HandlerFn,
) -> Result<(), Box<dyn std::error::Error>> {
let mut buffer = vec![0u8; 8192];
let n = stream.read(&mut buffer).await?;
if n == 0 {
return Ok(());
}
let request_str = String::from_utf8_lossy(&buffer[..n]);
let request = parse_request(&request_str);
let response = handler(request);
let response_bytes = build_response(&response);
stream.write_all(&response_bytes).await?;
stream.flush().await?;
Ok(())
}
/// Parses the text of an HTTP request into a [`Request`].
///
/// Parsing is lenient and never fails:
/// - the method and path come from the first two whitespace-separated tokens
///   of the first line, defaulting to `"GET"` and `"/"` when missing;
/// - each following line up to the first blank line is split at its first
///   `:` into a trimmed `(name, value)` header; lines without `:` are skipped;
/// - everything after that blank line is the body, kept exactly as received
///   (line endings included). An empty or missing body is `None`.
///
/// Lines may end in `\r\n` or `\n`.
fn parse_request(request_str: &str) -> Request {
    let mut method = "GET".to_string();
    let mut path = "/".to_string();
    let mut headers = Vec::new();
    let mut body_start = None;
    // Byte offset just past the line currently being examined.
    let mut offset = 0;

    for (index, raw) in request_str.split_inclusive('\n').enumerate() {
        offset += raw.len();
        let line = raw.strip_suffix('\n').unwrap_or(raw);
        let line = line.strip_suffix('\r').unwrap_or(line);

        if index == 0 {
            let mut tokens = line.split_whitespace();
            if let Some(token) = tokens.next() {
                method = token.to_string();
            }
            if let Some(token) = tokens.next() {
                path = token.to_string();
            }
            continue;
        }

        if line.is_empty() {
            body_start = Some(offset);
            break;
        }

        if let Some((key, value)) = line.split_once(':') {
            headers.push((key.trim().to_string(), value.trim().to_string()));
        }
    }

    // Slice the body out of the original text instead of re-joining lines,
    // which would rewrite line endings and drop a trailing newline.
    let body = body_start
        .map(|start| &request_str[start..])
        .filter(|rest| !rest.is_empty())
        .map(str::to_string);

    Request {
        method,
        path,
        headers,
        body,
    }
}
/// Serializes `response` into the bytes of an HTTP/1.1 response message.
///
/// The output is the status line, the response's headers in order, a
/// `Content-Length` header giving the body's length in bytes, a blank line,
/// and the body. Any `Content-Length` header already present on the response
/// is dropped so the message never carries two conflicting lengths.
fn build_response(response: &Response) -> Vec<u8> {
    let mut out = String::with_capacity(64 + response.body.len());

    out.push_str("HTTP/1.1 ");
    out.push_str(&response.status.to_string());
    out.push(' ');
    out.push_str(reason_phrase(response.status));
    out.push_str("\r\n");

    for (key, value) in &response.headers {
        // The length computed below is authoritative for the body we send.
        if key.trim().eq_ignore_ascii_case("content-length") {
            continue;
        }
        out.push_str(key);
        out.push_str(": ");
        out.push_str(value);
        out.push_str("\r\n");
    }

    out.push_str("Content-Length: ");
    out.push_str(&response.body.len().to_string());
    out.push_str("\r\n\r\n");
    out.push_str(&response.body);
    out.into_bytes()
}

/// Returns the standard reason phrase for `status`, or `"Unknown"` when the
/// code is not in the table.
fn reason_phrase(status: u16) -> &'static str {
    match status {
        100 => "Continue",
        101 => "Switching Protocols",
        200 => "OK",
        201 => "Created",
        202 => "Accepted",
        204 => "No Content",
        206 => "Partial Content",
        301 => "Moved Permanently",
        302 => "Found",
        303 => "See Other",
        304 => "Not Modified",
        307 => "Temporary Redirect",
        308 => "Permanent Redirect",
        400 => "Bad Request",
        401 => "Unauthorized",
        403 => "Forbidden",
        404 => "Not Found",
        405 => "Method Not Allowed",
        408 => "Request Timeout",
        409 => "Conflict",
        410 => "Gone",
        411 => "Length Required",
        413 => "Content Too Large",
        414 => "URI Too Long",
        415 => "Unsupported Media Type",
        422 => "Unprocessable Content",
        429 => "Too Many Requests",
        431 => "Request Header Fields Too Large",
        500 => "Internal Server Error",
        501 => "Not Implemented",
        502 => "Bad Gateway",
        503 => "Service Unavailable",
        504 => "Gateway Timeout",
        _ => "Unknown",
    }
}
pub async fn shutdown() {
tokio::signal::ctrl_c().await.ok();
}