use clap::{Parser, Subcommand};
use serde_json::{json, Map as JsonMap, Value};
use std::collections::HashMap;
use std::io::{BufRead, BufReader, Read, Write};
use std::net::{TcpListener, TcpStream};
use std::path::PathBuf;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use std::thread;
use std::time::Duration;
use virtuus::table::ValidationMode;
use virtuus::{Database, Table, VERSION};
/// Returns the peak resident set size of this process in bytes, or 0 if unavailable.
///
/// Despite its name this reads `ru_maxrss` from `getrusage(RUSAGE_SELF)`, which is
/// the high-water mark of resident memory rather than the instantaneous RSS.
/// On macOS `ru_maxrss` is used as-is (bytes); every other Unix target is treated
/// as reporting kilobytes and scaled by 1024.
#[cfg(unix)]
fn current_rss_bytes() -> u64 {
    // SAFETY: `libc::rusage` consists only of integer and `timeval` fields, so the
    // all-zero bit pattern is a valid value. `getrusage` only writes through the
    // pointer to this local, which lives for the whole call.
    let (rc, usage) = unsafe {
        let mut usage: libc::rusage = std::mem::zeroed();
        let rc = libc::getrusage(libc::RUSAGE_SELF, &mut usage);
        (rc, usage)
    };
    if rc != 0 {
        return 0;
    }
    // `ru_maxrss` is a signed C `long`. A negative value is nonsensical, so report it
    // as unknown instead of letting the cast wrap it into a huge `u64`.
    let max_rss = u64::try_from(usage.ru_maxrss).unwrap_or(0);
    if cfg!(target_os = "macos") {
        max_rss
    } else {
        max_rss.saturating_mul(1024)
    }
}

/// Fallback for non-Unix targets, where no RSS figure is collected; always returns 0.
#[cfg(not(unix))]
fn current_rss_bytes() -> u64 {
    0
}
// Top-level command-line arguments of the `virtuus` binary.
// The program name is fixed to "virtuus" and `--version` prints the crate's `VERSION`.
#[derive(Parser)]
#[command(name = "virtuus", version = VERSION)]
struct Cli {
    // The selected subcommand; `run` dispatches on it.
    #[command(subcommand)]
    command: Commands,
}
// Subcommands of the `virtuus` CLI. The `///` comments on variants and fields
// below are picked up by clap as the `--help` text.
#[derive(Subcommand)]
enum Commands {
    /// Run one query against on-disk data and print the result as JSON
    Query {
        /// Data directory; without --schema the table is read from <DIR>/<TABLE>
        #[arg(long)]
        dir: PathBuf,
        /// Schema file describing the tables; when omitted only <DIR>/<TABLE> is loaded
        #[arg(long)]
        schema: Option<PathBuf>,
        /// Name of the table to query
        #[arg(long)]
        table: String,
        /// Index to query; requires --pk or --where
        #[arg(long)]
        index: Option<String>,
        /// Value sent as the query's `pk` field
        #[arg(long)]
        pk: Option<String>,
        /// Filter in key=value form, sent as the query's `where` field
        #[arg(long, value_name = "key=value")]
        r#where: Option<String>,
    },
    /// Load a schema and serve the database as JSON over HTTP on 127.0.0.1
    Serve {
        /// Data directory passed to the schema loader
        #[arg(long)]
        dir: PathBuf,
        /// Schema file describing the tables to load
        #[arg(long)]
        schema: PathBuf,
        /// TCP port to listen on
        #[arg(long, default_value = "8080")]
        port: u16,
    },
    /// Print the /memory report of a running `virtuus serve` instance
    Memory {
        /// Port of the running server on 127.0.0.1
        #[arg(long, default_value = "8080")]
        port: u16,
    },
}
/// Binary entry point: parses the command line, runs it, and exits with status 1
/// after printing the error message to stderr if the command fails.
fn main() {
    match run(Cli::parse()) {
        Ok(()) => {}
        Err(message) => {
            eprintln!("{message}");
            std::process::exit(1);
        }
    }
}
fn run(cli: Cli) -> Result<(), String> {
match cli.command {
Commands::Query {
dir,
schema,
table,
index,
pk,
r#where,
} => run_query(dir, schema, table, index, pk, r#where),
Commands::Serve { dir, schema, port } => run_serve(dir, schema, port),
Commands::Memory { port } => run_memory(port),
}
}
fn run_query(
dir: PathBuf,
schema: Option<PathBuf>,
table: String,
index: Option<String>,
pk: Option<String>,
r#where: Option<String>,
) -> Result<(), String> {
let where_pair = if let Some(where_clause) = r#where.as_deref() {
Some(parse_where(where_clause)?)
} else {
None
};
let mut db = if let Some(schema_path) = schema {
Database::from_schema(&schema_path, Some(dir.as_path()))
} else {
let mut db = Database::new();
let table_dir = dir.join(&table);
if !table_dir.exists() {
return Err(format!("table \"{table}\" not found"));
}
let mut tbl = Table::new(
&table,
Some("id"),
None,
None,
Some(table_dir),
ValidationMode::Silent,
);
if let (Some(index_name), Some((where_key, _))) = (index.as_deref(), where_pair.as_ref()) {
tbl.add_gsi(index_name, where_key, None);
}
tbl.load_from_dir(None);
db.add_table(&table, tbl);
db
};
let mut directive = JsonMap::new();
if let Some(pk_value) = pk {
directive.insert("pk".to_string(), Value::String(pk_value));
}
if let Some(index_name) = index {
directive.insert("index".to_string(), Value::String(index_name));
}
if let Some((key, value)) = where_pair {
let mut where_map = JsonMap::new();
where_map.insert(key, Value::String(value));
directive.insert("where".to_string(), Value::Object(where_map));
} else if directive.get("index").is_some() && directive.get("pk").is_none() {
return Err("missing --where for index query".to_string());
}
let query = Value::Object(JsonMap::from_iter([(
table.clone(),
Value::Object(directive),
)]));
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| db.execute(&query)))
.map_err(|err| format!("query failed: {}", panic_message(err)))?;
let output = if let Some(items) = result.get("items") {
items.clone()
} else {
result
};
let json_text = serde_json::to_string(&output).map_err(|err| err.to_string())?;
println!("{json_text}");
Ok(())
}
/// A minimal parsed HTTP request, as produced by `read_request`.
struct HttpRequest {
    /// Request method token, e.g. `GET` or `POST` (empty string if missing).
    method: String,
    /// Request target exactly as sent, including any query string (defaults to `/`).
    path: String,
    /// Raw body bytes: `Content-Length` bytes, or empty when the header is absent.
    body: Vec<u8>,
}
fn run_serve(dir: PathBuf, schema: PathBuf, port: u16) -> Result<(), String> {
let db = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
Database::from_schema(&schema, Some(dir.as_path()))
}))
.map_err(|err| format!("failed to load schema: {}", panic_message(err)))?;
let state = Arc::new(Mutex::new(db));
let load_count = Arc::new(AtomicUsize::new(1));
let refresh_count = Arc::new(AtomicUsize::new(0));
let listener = TcpListener::bind(("127.0.0.1", port)).map_err(|err| err.to_string())?;
for stream in listener.incoming() {
let stream = stream.map_err(|err| err.to_string())?;
let state = Arc::clone(&state);
let load_count = Arc::clone(&load_count);
let refresh_count = Arc::clone(&refresh_count);
thread::spawn(move || {
if let Err(err) = handle_connection(stream, state, load_count, refresh_count) {
eprintln!("server error: {err}");
}
});
}
Ok(())
}
fn handle_connection(
mut stream: TcpStream,
state: Arc<Mutex<Database>>,
load_count: Arc<AtomicUsize>,
refresh_count: Arc<AtomicUsize>,
) -> std::io::Result<()> {
stream.set_read_timeout(Some(Duration::from_secs(5)))?;
let request = match read_request(&mut stream) {
Ok(req) => req,
Err(_) => return Ok(()),
};
let path = request
.path
.split('?')
.next()
.unwrap_or(request.path.as_str());
let mut status = 200u16;
let response = match (request.method.as_str(), path) {
("GET", "/health") => {
json!({
"status": "ok",
"load_count": load_count.load(Ordering::SeqCst),
"refresh_count": refresh_count.load(Ordering::SeqCst)
})
}
("POST", "/query") => {
let text = String::from_utf8_lossy(&request.body);
match serde_json::from_str::<Value>(&text) {
Ok(query) => {
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
let mut db = state.lock().expect("db lock");
db.execute(&query)
}));
match result {
Ok(value) => value,
Err(err) => {
status = 400;
json!({ "error": panic_message(err) })
}
}
}
Err(err) => {
status = 400;
json!({ "error": format!("invalid json: {err}") })
}
}
}
("POST", "/describe") => {
let mut db = state.lock().expect("db lock");
let describe = db.describe();
serde_json::to_value(describe).unwrap_or_else(|_| json!({}))
}
("POST", "/validate") => {
let mut db = state.lock().expect("db lock");
let violations = db.validate();
json!({
"valid": violations.is_empty(),
"violations": violations
})
}
("POST", "/warm") => {
let mut db = state.lock().expect("db lock");
db.warm();
refresh_count.fetch_add(1, Ordering::SeqCst);
let tables: Vec<String> = db.tables().keys().cloned().collect();
json!({
"status": "ok",
"tables": tables
})
}
("GET", "/memory") => {
let rss = current_rss_bytes();
json!({
"rss_bytes": rss,
"rss_kb": rss / 1024
})
}
_ => {
status = 404;
json!({ "error": "not found" })
}
};
let body = serde_json::to_string(&response).unwrap_or_else(|_| "{}".to_string());
write_response(&mut stream, status, &body)
}
/// Fetches `/memory` from a `virtuus serve` instance on `127.0.0.1:port` and prints
/// the response body to stdout.
///
/// If the response has no header/body separator, the raw response is printed instead.
///
/// # Errors
/// Returns a message prefixed with `connect:`, `write:` or `read:` when the
/// corresponding socket operation fails.
fn run_memory(port: u16) -> Result<(), String> {
    let mut stream =
        TcpStream::connect(("127.0.0.1", port)).map_err(|e| format!("connect: {e}"))?;
    stream
        .write_all(b"GET /memory HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n")
        .map_err(|e| format!("write: {e}"))?;
    // The request asks for `Connection: close`, so reading to EOF yields the whole response.
    let mut raw = String::new();
    stream
        .read_to_string(&mut raw)
        .map_err(|e| format!("read: {e}"))?;
    println!("{}", http_response_body(&raw));
    Ok(())
}

/// Returns everything after the first blank line (`\r\n\r\n`) of a raw HTTP response,
/// or the whole input when there is no such separator.
///
/// Only the first separator is split on, so a body that itself contains a blank
/// line is returned intact.
fn http_response_body(raw: &str) -> &str {
    raw.split_once("\r\n\r\n").map_or(raw, |(_, body)| body)
}
/// Reads one HTTP/1.x request (request line, headers, optional body) from `stream`.
///
/// Only the method, the request target and a `Content-Length`-delimited body are
/// kept. Other headers are parsed but discarded, and chunked transfer encoding is
/// not supported. A missing or unparsable `Content-Length` means an empty body.
///
/// # Errors
/// Returns `UnexpectedEof` for an empty request line or for a body shorter than
/// `Content-Length`. Any socket error (including read timeouts) is propagated.
fn read_request(stream: &mut TcpStream) -> std::io::Result<HttpRequest> {
    let mut reader = BufReader::new(stream.try_clone()?);
    let mut request_line = String::new();
    reader.read_line(&mut request_line)?;
    if request_line.trim().is_empty() {
        return Err(std::io::Error::new(
            std::io::ErrorKind::UnexpectedEof,
            "empty request",
        ));
    }
    let mut parts = request_line.split_whitespace();
    let method = parts.next().unwrap_or("").to_string();
    let path = parts.next().unwrap_or("/").to_string();
    let mut headers = HashMap::new();
    loop {
        let mut line = String::new();
        reader.read_line(&mut line)?;
        if line == "\r\n" || line.is_empty() {
            break;
        }
        if let Some((name, value)) = line.split_once(':') {
            headers.insert(name.trim().to_lowercase(), value.trim().to_string());
        }
    }
    let length = headers
        .get("content-length")
        .and_then(|v| v.parse::<u64>().ok())
        .unwrap_or(0);
    // `Content-Length` is client-controlled. Reading through `take` lets the buffer grow
    // only with bytes actually received; pre-allocating `vec![0; length]` would panic or
    // abort the whole process for a huge value.
    let mut body = Vec::new();
    reader.take(length).read_to_end(&mut body)?;
    if (body.len() as u64) < length {
        return Err(std::io::Error::new(
            std::io::ErrorKind::UnexpectedEof,
            "request body shorter than Content-Length",
        ));
    }
    Ok(HttpRequest { method, path, body })
}
/// Writes a complete `HTTP/1.1` JSON response carrying `body`, then flushes the stream.
///
/// The response always has `Content-Type: application/json`, an exact byte
/// `Content-Length`, and `Connection: close`.
///
/// # Errors
/// Propagates any write or flush error from the socket.
fn write_response(stream: &mut TcpStream, status: u16, body: &str) -> std::io::Result<()> {
    let reason = reason_phrase(status);
    let response = format!(
        "HTTP/1.1 {status} {reason}\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{body}",
        body.len()
    );
    stream.write_all(response.as_bytes())?;
    stream.flush()
}

/// Returns the standard reason phrase for the status codes this server may emit.
///
/// Unknown codes get an empty phrase rather than a misleading "OK". HTTP/1.1
/// allows this: `status-line = HTTP-version SP status-code SP [ reason-phrase ]`.
fn reason_phrase(status: u16) -> &'static str {
    match status {
        200 => "OK",
        400 => "Bad Request",
        404 => "Not Found",
        405 => "Method Not Allowed",
        413 => "Content Too Large",
        500 => "Internal Server Error",
        503 => "Service Unavailable",
        _ => "",
    }
}
/// Parses a `--where` argument of the form `key=value` into a trimmed `(key, value)` pair.
///
/// The input is split on the first `=`, so the value may itself contain `=`.
///
/// # Errors
/// Returns `"invalid --where; expected key=value"` when there is no `=` or when the
/// key or the value is empty after trimming.
fn parse_where(input: &str) -> Result<(String, String), String> {
    const USAGE: &str = "invalid --where; expected key=value";
    let (raw_key, raw_value) = input.split_once('=').ok_or_else(|| USAGE.to_string())?;
    let (key, value) = (raw_key.trim(), raw_value.trim());
    if key.is_empty() || value.is_empty() {
        return Err(USAGE.to_string());
    }
    Ok((key.to_owned(), value.to_owned()))
}
/// Extracts a human-readable message from a panic payload caught by `catch_unwind`.
///
/// `panic!` payloads are a `&'static str` for literal messages and a `String` for
/// formatted ones; any other payload type yields `"unknown error"`.
fn panic_message(err: Box<dyn std::any::Any + Send>) -> String {
    match err.downcast::<String>() {
        Ok(owned) => *owned,
        Err(payload) => match payload.downcast_ref::<&str>() {
            Some(text) => (*text).to_string(),
            None => "unknown error".to_string(),
        },
    }
}