use crate::socket_server::{
Connection, ConnectionHandler, Message, SocketClient, SocketServer, SocketServerConfig,
};
use crate::IpcError;
use parking_lot::RwLock;
use serde_json::Value as JsonValue;
use std::collections::HashMap;
use std::io::{BufRead, BufReader, Read};
use std::sync::Arc;
/// HTTP request methods understood by the router and the client.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Method {
    GET,
    POST,
    PUT,
    DELETE,
    PATCH,
    OPTIONS,
    HEAD,
}

impl Method {
    /// Every variant, in declaration order; used by [`Method::parse`] for lookup.
    const ALL: [Method; 7] = [
        Method::GET,
        Method::POST,
        Method::PUT,
        Method::DELETE,
        Method::PATCH,
        Method::OPTIONS,
        Method::HEAD,
    ];

    /// Parses a method token, ignoring ASCII case (`"get"` → `GET`).
    ///
    /// Returns `None` for unknown tokens. Only ASCII case folding is applied,
    /// so non-ASCII look-alikes such as `"poſt"` (long s) or `"optıons"`
    /// (dotless i) are rejected. Full Unicode upper-casing would have mapped
    /// them onto `POST` / `OPTIONS`.
    pub fn parse(s: &str) -> Option<Self> {
        Self::ALL
            .iter()
            .copied()
            .find(|m| m.as_str().eq_ignore_ascii_case(s))
    }

    /// Canonical upper-case token for this method, as written on the request line.
    pub fn as_str(&self) -> &'static str {
        match self {
            Method::GET => "GET",
            Method::POST => "POST",
            Method::PUT => "PUT",
            Method::DELETE => "DELETE",
            Method::PATCH => "PATCH",
            Method::OPTIONS => "OPTIONS",
            Method::HEAD => "HEAD",
        }
    }
}
/// An HTTP request, either built directly with `Request::new` or decoded from
/// raw bytes by `Request::parse`.
#[derive(Debug)]
pub struct Request {
/// Request method from the request line.
pub method: Method,
/// Path component of the request target, without the `?query` suffix.
pub path: String,
/// Query-string parameters, percent/`+`-decoded by `parse`.
pub query: HashMap<String, String>,
/// Header map; `parse` stores names lower-cased and `header()` looks them up lower-cased.
pub headers: HashMap<String, String>,
/// JSON-decoded body; `parse` only fills this when the content type mentions
/// `application/json` and decoding succeeds.
pub body: Option<JsonValue>,
/// Raw body bytes as read by `parse` according to `Content-Length`.
pub raw_body: Vec<u8>,
/// Path parameters captured by the matching route pattern (filled in by `Router::handle`).
pub params: HashMap<String, String>,
}
impl Request {
pub fn new(method: Method, path: &str) -> Self {
Self {
method,
path: path.to_string(),
query: HashMap::new(),
headers: HashMap::new(),
body: None,
raw_body: Vec::new(),
params: HashMap::new(),
}
}
pub fn query_param(&self, name: &str) -> Option<&str> {
self.query.get(name).map(|s| s.as_str())
}
pub fn path_param(&self, name: &str) -> Option<&str> {
self.params.get(name).map(|s| s.as_str())
}
pub fn header(&self, name: &str) -> Option<&str> {
self.headers.get(&name.to_lowercase()).map(|s| s.as_str())
}
pub fn content_type(&self) -> Option<&str> {
self.header("content-type")
}
pub fn accepts_json(&self) -> bool {
self.header("accept")
.map(|s| s.contains("application/json") || s.contains("*/*"))
.unwrap_or(true)
}
pub fn parse(data: &[u8]) -> Result<Self, ParseError> {
let mut reader = BufReader::new(data);
let mut first_line = String::new();
reader.read_line(&mut first_line)?;
let parts: Vec<&str> = first_line.split_whitespace().collect();
if parts.len() < 2 {
return Err(ParseError::InvalidRequestLine);
}
let method = Method::parse(parts[0]).ok_or(ParseError::InvalidMethod)?;
let full_path = parts[1];
let (path, query) = if let Some(idx) = full_path.find('?') {
let path = &full_path[..idx];
let query_str = &full_path[idx + 1..];
(path.to_string(), parse_query_string(query_str))
} else {
(full_path.to_string(), HashMap::new())
};
let mut headers = HashMap::new();
loop {
let mut line = String::new();
reader.read_line(&mut line)?;
let line = line.trim();
if line.is_empty() {
break;
}
if let Some(idx) = line.find(':') {
let key = line[..idx].trim().to_lowercase();
let value = line[idx + 1..].trim().to_string();
headers.insert(key, value);
}
}
let mut raw_body = Vec::new();
if let Some(len_str) = headers.get("content-length") {
if let Ok(len) = len_str.parse::<usize>() {
raw_body.resize(len, 0);
reader.read_exact(&mut raw_body)?;
}
}
let body = if !raw_body.is_empty() {
let content_type = headers.get("content-type").map(|s| s.as_str());
if content_type
.map(|s| s.contains("application/json"))
.unwrap_or(false)
{
serde_json::from_slice(&raw_body).ok()
} else {
None
}
} else {
None
};
Ok(Self {
method,
path,
query,
headers,
body,
raw_body,
params: HashMap::new(),
})
}
}
/// Reasons `Request::parse` can reject raw request bytes.
#[derive(Debug)]
pub enum ParseError {
    /// The first line did not contain at least a method and a request target.
    InvalidRequestLine,
    /// The method token is not a supported `Method`.
    InvalidMethod,
    /// Reading the request failed (e.g. invalid UTF-8 in the head, or a body
    /// shorter than its declared length).
    IoError(std::io::Error),
}

impl From<std::io::Error> for ParseError {
    fn from(e: std::io::Error) -> Self {
        ParseError::IoError(e)
    }
}

impl std::fmt::Display for ParseError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ParseError::InvalidRequestLine => write!(f, "Invalid request line"),
            ParseError::InvalidMethod => write!(f, "Invalid HTTP method"),
            ParseError::IoError(e) => write!(f, "IO error: {}", e),
        }
    }
}

impl std::error::Error for ParseError {
    /// Exposes the wrapped I/O error so error reporters can walk the cause chain.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            ParseError::IoError(e) => Some(e),
            ParseError::InvalidRequestLine | ParseError::InvalidMethod => None,
        }
    }
}
/// Splits an `a=1&b=2` query string into decoded key/value pairs.
///
/// A bare key (`flag`) maps to an empty string and empty segments (`&&`) are
/// ignored. When a key repeats, the last occurrence wins.
fn parse_query_string(query: &str) -> HashMap<String, String> {
    query
        .split('&')
        .filter_map(|segment| match segment.split_once('=') {
            Some((raw_key, raw_value)) => {
                Some((urlencoding_decode(raw_key), urlencoding_decode(raw_value)))
            }
            None if segment.is_empty() => None,
            None => Some((urlencoding_decode(segment), String::new())),
        })
        .collect()
}
/// Decodes `application/x-www-form-urlencoded` text: `+` becomes a space and
/// `%XX` (exactly two hex digits) becomes the byte `0xXX`.
///
/// Decoded bytes are reassembled as UTF-8, so `%C3%A9` yields `é`. Byte
/// sequences that are not valid UTF-8 become U+FFFD. A `%` that is not
/// followed by two hex digits is kept literally, and the characters after it
/// are decoded normally.
fn urlencoding_decode(s: &str) -> String {
    /// Value of one ASCII hex digit, or `None` for anything else (including `+`/`-`).
    fn hex_val(b: u8) -> Option<u8> {
        match b {
            b'0'..=b'9' => Some(b - b'0'),
            b'a'..=b'f' => Some(b - b'a' + 10),
            b'A'..=b'F' => Some(b - b'A' + 10),
            _ => None,
        }
    }

    let bytes = s.as_bytes();
    let mut out = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        match bytes[i] {
            b'%' => {
                let hi = bytes.get(i + 1).copied().and_then(hex_val);
                let lo = bytes.get(i + 2).copied().and_then(hex_val);
                if let (Some(hi), Some(lo)) = (hi, lo) {
                    out.push(hi << 4 | lo);
                    i += 3;
                    continue;
                }
                out.push(b'%');
            }
            b'+' => out.push(b' '),
            other => out.push(other),
        }
        i += 1;
    }
    String::from_utf8_lossy(&out).into_owned()
}
/// An HTTP response, serialized to wire format by `Response::to_bytes`.
#[derive(Debug)]
pub struct Response {
/// Numeric status code written on the status line.
pub status: u16,
/// Reason phrase written after the status code; `Response::new` fills it from `status_message`.
pub status_message: String,
/// Header map written as `key: value` lines by `to_bytes` (keys are case-sensitive here).
pub headers: HashMap<String, String>,
/// Response payload.
pub body: ResponseBody,
}
/// Payload carried by a `Response`.
#[derive(Debug)]
pub enum ResponseBody {
/// Serialized with `serde_json::to_vec` when the response is written.
Json(JsonValue),
/// Raw bytes written as-is.
Bytes(Vec<u8>),
/// UTF-8 text written as its bytes.
Text(String),
/// No body.
Empty,
}
impl Response {
    /// Creates a response with `status`, its standard reason phrase, no
    /// headers and an empty body.
    pub fn new(status: u16) -> Self {
        Self {
            status,
            status_message: status_message(status).to_string(),
            headers: HashMap::new(),
            body: ResponseBody::Empty,
        }
    }

    /// Shared constructor for the JSON-bodied helpers: sets the JSON content type and body.
    fn with_json(status: u16, body: JsonValue) -> Self {
        Self::new(status).json(body)
    }

    /// Builds the `{"error": ..., "message": ...}` envelope used by the error helpers.
    fn error_json(status: u16, error: &str, message: &str) -> Self {
        Self::with_json(
            status,
            serde_json::json!({
                "error": error,
                "message": message
            }),
        )
    }

    /// `200 OK` with a JSON body.
    pub fn ok(body: JsonValue) -> Self {
        Self::with_json(200, body)
    }

    /// `201 Created` with a JSON body.
    pub fn created(body: JsonValue) -> Self {
        Self::with_json(201, body)
    }

    /// `204 No Content` with no headers and no body.
    pub fn no_content() -> Self {
        Self::new(204)
    }

    /// `400 Bad Request` with a JSON error envelope.
    pub fn bad_request(message: &str) -> Self {
        Self::error_json(400, "Bad Request", message)
    }

    /// `401 Unauthorized` with a JSON error envelope.
    pub fn unauthorized(message: &str) -> Self {
        Self::error_json(401, "Unauthorized", message)
    }

    /// `403 Forbidden` with a JSON error envelope.
    pub fn forbidden(message: &str) -> Self {
        Self::error_json(403, "Forbidden", message)
    }

    /// `404 Not Found` with `{"error": "Not Found"}` (no `message` field).
    pub fn not_found() -> Self {
        Self::with_json(
            404,
            serde_json::json!({
                "error": "Not Found"
            }),
        )
    }

    /// `500 Internal Server Error` with a JSON error envelope.
    pub fn internal_error(message: &str) -> Self {
        Self::error_json(500, "Internal Server Error", message)
    }

    /// Builder: sets (or overwrites) header `key` to `value`.
    pub fn header(mut self, key: &str, value: &str) -> Self {
        self.headers.insert(key.to_string(), value.to_string());
        self
    }

    /// Builder: sets a JSON body and `Content-Type: application/json`.
    pub fn json(mut self, body: JsonValue) -> Self {
        self.headers
            .insert("Content-Type".to_string(), "application/json".to_string());
        self.body = ResponseBody::Json(body);
        self
    }

    /// Builder: sets a plain-text body and `Content-Type: text/plain`.
    pub fn text(mut self, body: &str) -> Self {
        self.headers
            .insert("Content-Type".to_string(), "text/plain".to_string());
        self.body = ResponseBody::Text(body.to_string());
        self
    }

    /// Builder: sets a raw byte body with the given content type.
    pub fn bytes(mut self, body: Vec<u8>, content_type: &str) -> Self {
        self.headers
            .insert("Content-Type".to_string(), content_type.to_string());
        self.body = ResponseBody::Bytes(body);
        self
    }

    /// Serializes the response as HTTP/1.1: status line, headers, computed
    /// `Content-Length`, blank line, body.
    ///
    /// `Content-Length` is always derived from the body. Any caller-supplied
    /// `Content-Length` header is skipped so it cannot be duplicated or
    /// contradict the body. For 1xx and 204 responses, no `Content-Length`
    /// and no body are written (RFC 9110 §8.6 / §15.3.5).
    pub fn to_bytes(&self) -> Vec<u8> {
        let bodyless = self.status == 204 || (100..200).contains(&self.status);
        let body_bytes = if bodyless {
            Vec::new()
        } else {
            match &self.body {
                ResponseBody::Json(v) => serde_json::to_vec(v).unwrap_or_default(),
                ResponseBody::Bytes(b) => b.clone(),
                ResponseBody::Text(s) => s.as_bytes().to_vec(),
                ResponseBody::Empty => Vec::new(),
            }
        };

        let mut head = format!("HTTP/1.1 {} {}\r\n", self.status, self.status_message);
        for (key, value) in &self.headers {
            if key.eq_ignore_ascii_case("content-length") {
                continue;
            }
            head.push_str(key);
            head.push_str(": ");
            head.push_str(value);
            head.push_str("\r\n");
        }
        if !bodyless {
            head.push_str(&format!("Content-Length: {}\r\n", body_bytes.len()));
        }
        head.push_str("\r\n");

        let mut bytes = head.into_bytes();
        bytes.extend(body_bytes);
        bytes
    }
}
/// Standard reason phrase for an HTTP status code (RFC 9110 / RFC 6585
/// wording), or `"Unknown"` for codes not listed here.
fn status_message(status: u16) -> &'static str {
    match status {
        100 => "Continue",
        101 => "Switching Protocols",
        200 => "OK",
        201 => "Created",
        202 => "Accepted",
        204 => "No Content",
        206 => "Partial Content",
        301 => "Moved Permanently",
        302 => "Found",
        303 => "See Other",
        304 => "Not Modified",
        307 => "Temporary Redirect",
        308 => "Permanent Redirect",
        400 => "Bad Request",
        401 => "Unauthorized",
        403 => "Forbidden",
        404 => "Not Found",
        405 => "Method Not Allowed",
        406 => "Not Acceptable",
        408 => "Request Timeout",
        409 => "Conflict",
        410 => "Gone",
        411 => "Length Required",
        413 => "Content Too Large",
        415 => "Unsupported Media Type",
        422 => "Unprocessable Content",
        429 => "Too Many Requests",
        500 => "Internal Server Error",
        501 => "Not Implemented",
        502 => "Bad Gateway",
        503 => "Service Unavailable",
        504 => "Gateway Timeout",
        _ => "Unknown",
    }
}
/// One `/`-separated piece of a route pattern.
#[derive(Debug, Clone)]
enum PathSegment {
    /// Literal text that must match exactly.
    Static(String),
    /// `{name}`: captures exactly one path segment.
    Param(String),
    /// `{*name}`: captures every remaining segment, re-joined with `/`.
    Wildcard(String),
}

/// A compiled route pattern such as `/v1/tasks/{id}` or `/files/{*path}`.
#[derive(Debug, Clone)]
pub struct PathPattern {
    segments: Vec<PathSegment>,
    // Kept for debugging output only; never read by the matcher.
    #[allow(dead_code)]
    original: String,
}

impl PathPattern {
    /// Compiles `pattern`. Empty segments (leading, trailing or doubled `/`) are ignored.
    pub fn parse(pattern: &str) -> Self {
        let segments = pattern
            .split('/')
            .filter(|piece| !piece.is_empty())
            .map(|piece| {
                if let Some(name) = piece.strip_prefix("{*").and_then(|r| r.strip_suffix('}')) {
                    return PathSegment::Wildcard(name.to_string());
                }
                if let Some(name) = piece.strip_prefix('{').and_then(|r| r.strip_suffix('}')) {
                    return PathSegment::Param(name.to_string());
                }
                PathSegment::Static(piece.to_string())
            })
            .collect();
        PathPattern {
            segments,
            original: pattern.to_string(),
        }
    }

    /// Matches `path` against the pattern and returns the captured parameters.
    ///
    /// Empty path segments are ignored, just as in `parse`. Without a
    /// wildcard, the path must have exactly as many segments as the pattern.
    /// A wildcard ends matching: it captures the remainder (possibly empty),
    /// and any pattern segments after it are not checked.
    pub fn matches(&self, path: &str) -> Option<HashMap<String, String>> {
        let parts: Vec<&str> = path.split('/').filter(|piece| !piece.is_empty()).collect();
        let mut captured = HashMap::new();
        let mut cursor = 0usize;

        for segment in &self.segments {
            match segment {
                PathSegment::Wildcard(name) => {
                    captured.insert(name.clone(), parts[cursor..].join("/"));
                    return Some(captured);
                }
                PathSegment::Static(literal) => {
                    if parts.get(cursor).copied() != Some(literal.as_str()) {
                        return None;
                    }
                }
                PathSegment::Param(name) => {
                    let value = parts.get(cursor).copied()?;
                    captured.insert(name.clone(), value.to_owned());
                }
            }
            cursor += 1;
        }

        if cursor == parts.len() {
            Some(captured)
        } else {
            None
        }
    }
}
/// A boxed route handler: consumes the request and produces the response.
pub type HandlerFn = Box<dyn Fn(Request) -> Response + Send + Sync>;
/// A registered route: the method and compiled path pattern it answers, plus its handler.
struct Route {
method: Method,
pattern: PathPattern,
handler: HandlerFn,
}
/// A boxed middleware. It receives the request and a `next` callback; calling
/// `next` runs the remaining middleware and then the route handler. Middleware
/// registered first runs outermost (see `Router::handle`).
pub type MiddlewareFn =
Box<dyn Fn(Request, &dyn Fn(Request) -> Response) -> Response + Send + Sync>;
pub struct Router {
routes: Vec<Route>,
middlewares: Vec<MiddlewareFn>,
not_found_handler: Option<HandlerFn>,
}
impl Default for Router {
fn default() -> Self {
Self::new()
}
}
impl Router {
pub fn new() -> Self {
Self {
routes: Vec::new(),
middlewares: Vec::new(),
not_found_handler: None,
}
}
pub fn get<F>(&mut self, path: &str, handler: F) -> &mut Self
where
F: Fn(Request) -> Response + Send + Sync + 'static,
{
self.route(Method::GET, path, handler)
}
pub fn post<F>(&mut self, path: &str, handler: F) -> &mut Self
where
F: Fn(Request) -> Response + Send + Sync + 'static,
{
self.route(Method::POST, path, handler)
}
pub fn put<F>(&mut self, path: &str, handler: F) -> &mut Self
where
F: Fn(Request) -> Response + Send + Sync + 'static,
{
self.route(Method::PUT, path, handler)
}
pub fn delete<F>(&mut self, path: &str, handler: F) -> &mut Self
where
F: Fn(Request) -> Response + Send + Sync + 'static,
{
self.route(Method::DELETE, path, handler)
}
pub fn patch<F>(&mut self, path: &str, handler: F) -> &mut Self
where
F: Fn(Request) -> Response + Send + Sync + 'static,
{
self.route(Method::PATCH, path, handler)
}
pub fn route<F>(&mut self, method: Method, path: &str, handler: F) -> &mut Self
where
F: Fn(Request) -> Response + Send + Sync + 'static,
{
self.routes.push(Route {
method,
pattern: PathPattern::parse(path),
handler: Box::new(handler),
});
self
}
pub fn middleware<F>(&mut self, middleware: F) -> &mut Self
where
F: Fn(Request, &dyn Fn(Request) -> Response) -> Response + Send + Sync + 'static,
{
self.middlewares.push(Box::new(middleware));
self
}
pub fn not_found<F>(&mut self, handler: F) -> &mut Self
where
F: Fn(Request) -> Response + Send + Sync + 'static,
{
self.not_found_handler = Some(Box::new(handler));
self
}
pub fn handle(&self, mut req: Request) -> Response {
for route in &self.routes {
if route.method == req.method {
if let Some(params) = route.pattern.matches(&req.path) {
req.params = params;
if self.middlewares.is_empty() {
return (route.handler)(req);
} else {
let handler = &route.handler;
let mut chain: Box<dyn Fn(Request) -> Response + '_> = Box::new(handler);
for middleware in self.middlewares.iter().rev() {
let next = chain;
chain = Box::new(move |r| middleware(r, &*next));
}
return chain(req);
}
}
}
}
if let Some(ref handler) = self.not_found_handler {
handler(req)
} else {
Response::not_found()
}
}
}
#[derive(Debug, Clone)]
pub struct ApiServerConfig {
pub socket_config: SocketServerConfig,
pub enable_cors: bool,
pub cors_origins: Vec<String>,
}
impl Default for ApiServerConfig {
fn default() -> Self {
Self {
socket_config: SocketServerConfig::default(),
enable_cors: true,
cors_origins: vec!["*".to_string()],
}
}
}
/// Connection handler that parses each incoming message as an HTTP/1.1
/// request, dispatches it through the shared router and replies with the
/// serialized HTTP response as a binary message.
#[derive(Clone)]
struct ApiHandler {
    router: Arc<RwLock<Router>>,
    config: ApiServerConfig,
}

impl ConnectionHandler for ApiHandler {
    fn on_message(&self, _conn: &mut Connection, msg: Message) -> crate::Result<Option<Message>> {
        // Accept the request as raw bytes, as text, or (last resort) as the
        // JSON payload re-encoded to bytes.
        let data = if let Some(binary_data) = msg.as_binary() {
            binary_data
        } else if let Some(text) = msg.as_text() {
            text.as_bytes().to_vec()
        } else {
            serde_json::to_vec(&msg.payload).unwrap_or_default()
        };

        let request = match Request::parse(&data) {
            Ok(req) => req,
            Err(e) => {
                let resp = Response::bad_request(&e.to_string());
                return Ok(Some(Message::binary(resp.to_bytes())));
            }
        };

        // The router consumes the request, so capture Origin now for the CORS headers.
        let origin = request.header("origin").map(str::to_owned);

        if request.method == Method::OPTIONS && self.config.enable_cors {
            let resp = self.cors_preflight_response(origin.as_deref());
            return Ok(Some(Message::binary(resp.to_bytes())));
        }

        let mut response = self.router.read().handle(request);
        if self.config.enable_cors {
            self.add_cors_headers(&mut response, origin.as_deref());
        }
        Ok(Some(Message::binary(response.to_bytes())))
    }
}

impl ApiHandler {
    /// `true` when the configuration allows every origin (`"*"`).
    fn allows_any_origin(&self) -> bool {
        self.config.cors_origins.iter().any(|o| o == "*")
    }

    /// Chooses the single `Access-Control-Allow-Origin` value for a request.
    ///
    /// The header may carry only `*` or one origin, never a list. So:
    /// - with `"*"` configured, the answer is `*`;
    /// - otherwise an allow-listed request `Origin` is echoed back;
    /// - otherwise a sole configured origin is sent as-is;
    /// - otherwise the header is omitted (`None`).
    fn allowed_origin(&self, request_origin: Option<&str>) -> Option<String> {
        if self.allows_any_origin() {
            return Some("*".to_string());
        }
        let origins = &self.config.cors_origins;
        if let Some(origin) = request_origin {
            if origins.iter().any(|o| o == origin) {
                return Some(origin.to_string());
            }
        }
        match origins.as_slice() {
            [only] => Some(only.clone()),
            _ => None,
        }
    }

    /// Builds the `204` answer to a CORS preflight (`OPTIONS`) request.
    fn cors_preflight_response(&self, request_origin: Option<&str>) -> Response {
        let mut resp = Response::new(204)
            .header(
                "Access-Control-Allow-Methods",
                "GET, POST, PUT, DELETE, PATCH, OPTIONS",
            )
            .header(
                "Access-Control-Allow-Headers",
                "Content-Type, Authorization",
            )
            .header("Access-Control-Max-Age", "86400");
        self.add_cors_headers(&mut resp, request_origin);
        resp
    }

    /// Adds `Access-Control-Allow-Origin` (when an origin is allowed) and, for
    /// non-wildcard configurations, `Vary: Origin` so shared caches key
    /// responses on the requesting origin.
    fn add_cors_headers(&self, response: &mut Response, request_origin: Option<&str>) {
        if let Some(origin) = self.allowed_origin(request_origin) {
            response
                .headers
                .insert("Access-Control-Allow-Origin".to_string(), origin);
        }
        if !self.allows_any_origin() {
            let vary = response.headers.entry("Vary".to_string()).or_default();
            let already_listed = vary
                .split(',')
                .any(|token| token.trim().eq_ignore_ascii_case("origin"));
            if !already_listed {
                if !vary.is_empty() {
                    vary.push_str(", ");
                }
                vary.push_str("Origin");
            }
        }
    }
}
/// HTTP-style API server on top of the local socket server.
///
/// Register routes through [`ApiServer::router`], then start it with `run`
/// (blocking) or `spawn` (on a new thread).
pub struct ApiServer {
    config: ApiServerConfig,
    router: Arc<RwLock<Router>>,
}

impl ApiServer {
    /// Creates a server with an empty router.
    pub fn new(config: ApiServerConfig) -> Self {
        let router = Arc::new(RwLock::new(Router::default()));
        ApiServer { config, router }
    }

    /// Returns a write guard on the router for registering routes and middleware.
    /// Drop the guard before calling `run`/`spawn`; request handling takes a read lock.
    pub fn router(&self) -> impl std::ops::DerefMut<Target = Router> + '_ {
        self.router.write()
    }

    /// Binds the socket server and serves requests on the current thread.
    pub fn run(self) -> crate::Result<()> {
        let ApiServer { config, router } = self;
        let socket_config = config.socket_config.clone();
        let handler = ApiHandler { router, config };
        SocketServer::new(socket_config)?.run(handler)
    }

    /// Runs the server on a new OS thread; the join handle yields `run`'s result.
    pub fn spawn(self) -> std::thread::JoinHandle<crate::Result<()>> {
        let server = self;
        std::thread::spawn(move || server.run())
    }
}
/// Minimal client for [`ApiServer`]: sends one HTTP/1.1 request per call over
/// the local socket and decodes the JSON response body.
///
/// The HTTP status code is not inspected. Error responses come back as their
/// JSON body (e.g. `{"error": "Not Found"}`).
pub struct ApiClient {
    socket_path: String,
    timeout: Option<std::time::Duration>,
}

impl ApiClient {
    /// Client for the socket at `socket_path`, with no timeout.
    pub fn new(socket_path: &str) -> Self {
        Self {
            socket_path: socket_path.to_string(),
            timeout: None,
        }
    }

    /// Client for the socket at `socket_path` with a connection timeout.
    pub fn with_timeout(socket_path: &str, timeout: std::time::Duration) -> Self {
        Self {
            socket_path: socket_path.to_string(),
            timeout: Some(timeout),
        }
    }

    /// Client for the default socket path.
    pub fn connect() -> Self {
        Self::new(&SocketServerConfig::default().path)
    }

    /// Client for the default socket path with a connection timeout.
    pub fn connect_timeout(timeout: std::time::Duration) -> Self {
        Self::with_timeout(&SocketServerConfig::default().path, timeout)
    }

    /// Sets or clears the timeout used for subsequent requests.
    pub fn set_timeout(&mut self, timeout: Option<std::time::Duration>) {
        self.timeout = timeout;
    }

    /// Current timeout, if any.
    pub fn get_timeout(&self) -> Option<std::time::Duration> {
        self.timeout
    }

    /// `GET path`.
    pub fn get(&self, path: &str) -> crate::Result<JsonValue> {
        self.request(Method::GET, path, None)
    }

    /// `POST path` with an optional JSON body.
    pub fn post(&self, path: &str, body: Option<JsonValue>) -> crate::Result<JsonValue> {
        self.request(Method::POST, path, body)
    }

    /// `PUT path` with an optional JSON body.
    pub fn put(&self, path: &str, body: Option<JsonValue>) -> crate::Result<JsonValue> {
        self.request(Method::PUT, path, body)
    }

    /// `PATCH path` with an optional JSON body (mirrors `Router::patch`).
    pub fn patch(&self, path: &str, body: Option<JsonValue>) -> crate::Result<JsonValue> {
        self.request(Method::PATCH, path, body)
    }

    /// `DELETE path`.
    pub fn delete(&self, path: &str) -> crate::Result<JsonValue> {
        self.request(Method::DELETE, path, None)
    }

    /// Sends one request and decodes the reply.
    ///
    /// - Binary replies are treated as raw HTTP: the part after the blank line
    ///   is decoded as JSON.
    /// - A missing or empty body (e.g. `204 No Content`) yields `JsonValue::Null`.
    /// - Text replies are decoded as JSON.
    /// - Any other reply returns its payload unchanged.
    ///
    /// # Errors
    /// Connection and transport errors from the socket client are propagated.
    /// Undecodable JSON bodies yield `IpcError::Deserialization`.
    fn request(
        &self,
        method: Method,
        path: &str,
        body: Option<JsonValue>,
    ) -> crate::Result<JsonValue> {
        let mut client = match self.timeout {
            Some(timeout) => SocketClient::connect_timeout(&self.socket_path, timeout)?,
            None => SocketClient::connect(&self.socket_path)?,
        };

        let body_bytes = body
            .as_ref()
            .map(|b| serde_json::to_vec(b).unwrap_or_default())
            .unwrap_or_default();
        let request_str = format!(
            "{} {} HTTP/1.1\r\nHost: localhost\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n",
            method.as_str(),
            path,
            body_bytes.len()
        );
        let mut request_bytes = request_str.into_bytes();
        request_bytes.extend(body_bytes);

        let msg = Message::binary(request_bytes);
        client.send(&msg)?;
        let response = client.recv()?;

        if let Some(binary_data) = response.as_binary() {
            match find_body_start(&binary_data) {
                Some(start) if start < binary_data.len() => {
                    serde_json::from_slice(&binary_data[start..])
                        .map_err(|e| IpcError::Deserialization(e.to_string()))
                }
                // No header terminator, or a bodyless reply such as 204.
                _ => Ok(JsonValue::Null),
            }
        } else if let Some(text) = response.as_text() {
            serde_json::from_str(text).map_err(|e| IpcError::Deserialization(e.to_string()))
        } else {
            Ok(response.payload)
        }
    }
}
/// Returns the offset just past the first `\r\n\r\n` (the end of an HTTP
/// head), i.e. where the body begins, or `None` if no terminator is present.
fn find_body_start(data: &[u8]) -> Option<usize> {
    const TERMINATOR: &[u8] = b"\r\n\r\n";
    data.windows(TERMINATOR.len())
        .position(|window| window == TERMINATOR)
        .map(|offset| offset + TERMINATOR.len())
}
// Unit tests for the pattern matcher, router, response serializer and request parser.
#[cfg(test)]
mod tests {
use super::*;
// A static pattern matches its exact path; a trailing slash is tolerated.
#[test]
fn test_path_pattern_static() {
let pattern = PathPattern::parse("/v1/tasks");
assert!(pattern.matches("/v1/tasks").is_some());
assert!(pattern.matches("/v1/tasks/").is_some());
assert!(pattern.matches("/v1/other").is_none());
}
// `{id}` captures exactly one segment; missing or extra segments do not match.
#[test]
fn test_path_pattern_param() {
let pattern = PathPattern::parse("/v1/tasks/{id}");
let params = pattern.matches("/v1/tasks/123").unwrap();
assert_eq!(params.get("id"), Some(&"123".to_string()));
let params = pattern.matches("/v1/tasks/abc").unwrap();
assert_eq!(params.get("id"), Some(&"abc".to_string()));
assert!(pattern.matches("/v1/tasks").is_none());
assert!(pattern.matches("/v1/tasks/123/extra").is_none());
}
// `{*path}` captures the remaining segments re-joined with `/`.
#[test]
fn test_path_pattern_wildcard() {
let pattern = PathPattern::parse("/files/{*path}");
let params = pattern.matches("/files/a/b/c").unwrap();
assert_eq!(params.get("path"), Some(&"a/b/c".to_string()));
let params = pattern.matches("/files/single").unwrap();
assert_eq!(params.get("path"), Some(&"single".to_string()));
}
// Routing end to end: static route, parameterised route, and the default 404 fallback.
#[test]
fn test_router() {
let mut router = Router::new();
router.get("/v1/tasks", |_| Response::ok(serde_json::json!([])));
router.get("/v1/tasks/{id}", |req| {
let id = req.params.get("id").unwrap();
Response::ok(serde_json::json!({"id": id}))
});
let req = Request::new(Method::GET, "/v1/tasks");
let resp = router.handle(req);
assert_eq!(resp.status, 200);
let req = Request::new(Method::GET, "/v1/tasks/123");
let resp = router.handle(req);
assert_eq!(resp.status, 200);
let req = Request::new(Method::GET, "/not/found");
let resp = router.handle(req);
assert_eq!(resp.status, 404);
}
// Serialized output contains the status line, JSON content type and compact JSON body.
#[test]
fn test_response_to_bytes() {
let resp = Response::ok(serde_json::json!({"key": "value"}));
let bytes = resp.to_bytes();
let text = String::from_utf8_lossy(&bytes);
assert!(text.contains("HTTP/1.1 200 OK"));
assert!(text.contains("Content-Type: application/json"));
assert!(text.contains("\"key\":\"value\""));
}
// The request line is split into method, path and decoded query parameters.
#[test]
fn test_request_parse() {
let raw = b"GET /v1/tasks?limit=10 HTTP/1.1\r\nHost: localhost\r\n\r\n";
let req = Request::parse(raw).unwrap();
assert_eq!(req.method, Method::GET);
assert_eq!(req.path, "/v1/tasks");
assert_eq!(req.query.get("limit"), Some(&"10".to_string()));
}
}