#![allow(dead_code)]
use rquickjs::{Context, Runtime, Value};
use std::collections::HashMap;
use crate::resources::Module as DGateModule;
/// The JavaScript hooks a module may define as global functions.
///
/// `Eq` and `Hash` are derived so the kind can be used as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ModuleFuncType {
    /// `requestModifier(ctx)`: may mutate `ctx.request` and friends.
    RequestModifier,
    /// `responseModifier(ctx, res)`: may mutate the response object `res`.
    ResponseModifier,
    /// `requestHandler(ctx)`: writes a response through the `ctx` helpers.
    RequestHandler,
    /// `errorHandler(ctx, err)`: writes a response for an error `err`.
    ErrorHandler,
    /// `fetchUpstreamUrl(ctx)`: returns an upstream URL string.
    FetchUpstreamUrl,
}

impl ModuleFuncType {
    /// Name of the global JS function that implements this hook.
    ///
    /// `const` so the name can also be used in constant contexts.
    pub const fn export_name(&self) -> &'static str {
        match self {
            ModuleFuncType::RequestModifier => "requestModifier",
            ModuleFuncType::ResponseModifier => "responseModifier",
            ModuleFuncType::RequestHandler => "requestHandler",
            ModuleFuncType::ErrorHandler => "errorHandler",
            ModuleFuncType::FetchUpstreamUrl => "fetchUpstreamUrl",
        }
    }
}
/// Request data handed to JavaScript modules as the global `__ctx__` (the
/// `ctx` argument of the hooks) and, for `requestModifier`, read back after
/// the module has run.
#[derive(Debug, Clone)]
pub struct RequestContext {
/// HTTP method (`ctx.request.method` in JS).
pub method: String,
/// Request path (`ctx.request.path` in JS).
pub path: String,
/// Query parameters by name (`ctx.request.query` in JS).
pub query: HashMap<String, String>,
/// Request headers by name (`ctx.request.headers` in JS).
pub headers: HashMap<String, String>,
/// Raw request body, if any. It is handed to JS as a lossily-decoded UTF-8
/// string, so non-UTF-8 bytes do not survive a round trip.
pub body: Option<Vec<u8>>,
/// Path parameters by name (`ctx.params` in JS).
pub params: HashMap<String, String>,
/// Route name (`ctx.route` in JS).
pub route_name: String,
/// Namespace name (`ctx.namespace` in JS).
pub namespace: String,
/// Service name, if any (`ctx.service` in JS).
pub service_name: Option<String>,
/// Document-store entries. The JS `getDocument`/`setDocument` helpers key
/// them as `"collection:id"`.
pub documents: HashMap<String, serde_json::Value>,
}
/// A response passed to the `responseModifier` hook as the global `__res__`
/// (its `res` argument) and read back after the hook has run.
#[derive(Debug, Clone)]
pub struct ResponseContext {
/// HTTP status code (`res.statusCode` in JS).
pub status_code: u16,
/// Response headers by name (`res.headers` in JS).
pub headers: HashMap<String, String>,
/// Raw body, if any. It is handed to JS as a lossily-decoded UTF-8 string.
pub body: Option<Vec<u8>>,
}
/// Response written by a `requestHandler` or `errorHandler` module through
/// the `ctx` helpers (`setStatus`, `setHeader`, `write`, `json`, ...).
#[derive(Debug, Clone)]
pub struct ModuleResponse {
/// HTTP status code chosen by the handler.
pub status_code: u16,
/// Response headers set by the handler.
pub headers: HashMap<String, String>,
/// Response body bytes (the UTF-8 encoding of the JS body string).
pub body: Vec<u8>,
/// Document-store contents after the handler ran, keyed `"collection:id"`.
pub documents: HashMap<String, serde_json::Value>,
}
impl Default for ModuleResponse {
    /// An empty `200` response: no headers, no body and no documents.
    fn default() -> Self {
        let headers = HashMap::default();
        let body = Vec::default();
        let documents = HashMap::default();
        Self {
            status_code: 200,
            headers,
            body,
            documents,
        }
    }
}
/// A module whose payload has been decoded into JavaScript source text.
///
/// NOTE(review): despite the name, nothing is compiled here. The source is
/// stored as text and evaluated afresh by `ModuleContext` on each call.
pub struct CompiledModule {
/// Decoded JavaScript source.
source: String,
/// Module name.
pub name: String,
/// Namespace the module was registered under.
pub namespace: String,
}
impl CompiledModule {
    /// Decodes `module`'s payload and captures its name and namespace.
    ///
    /// # Errors
    /// Returns [`ModuleError::DecodeError`] carrying the decoder's message
    /// when the payload cannot be decoded.
    pub fn new(module: &DGateModule) -> Result<Self, ModuleError> {
        module
            .decode_payload()
            .map(|source| Self {
                source,
                name: module.name.clone(),
                namespace: module.namespace.clone(),
            })
            .map_err(|err| ModuleError::DecodeError(err.to_string()))
    }
}
/// Errors produced while loading or running JavaScript modules.
#[derive(Debug, thiserror::Error)]
pub enum ModuleError {
/// The module payload could not be decoded into source text.
#[error("Failed to decode module payload: {0}")]
DecodeError(String),
/// A script failed to parse or threw, or its result could not be read back.
#[error("JavaScript error: {0}")]
JsError(String),
/// The named hook (e.g. `"requestHandler"`) is not available.
#[error("Module function not found: {0}")]
FunctionNotFound(String),
/// A module produced a value that cannot be used.
#[error("Invalid return value from module")]
InvalidReturnValue,
/// Module compilation failed. NOTE(review): nothing in this section
/// constructs this variant.
#[error("Module compilation failed: {0}")]
CompilationError(String),
/// The QuickJS runtime or context could not be created.
#[error("Runtime error: {0}")]
RuntimeError(String),
}
/// Registry of decoded modules, looked up by namespace and name.
pub struct ModuleExecutor {
    // Keyed by `(namespace, name)` rather than a `"namespace:name"` string.
    // The joined string is ambiguous when either part contains ':', e.g.
    // ("a:b", "c") and ("a", "b:c") would map to the same entry.
    modules: HashMap<(String, String), CompiledModule>,
}

impl ModuleExecutor {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self {
            modules: HashMap::new(),
        }
    }

    /// Builds the registry key for `namespace`/`name`.
    fn key(namespace: &str, name: &str) -> (String, String) {
        (namespace.to_owned(), name.to_owned())
    }

    /// Decodes `module` and registers it, replacing any module already
    /// registered under the same namespace and name.
    ///
    /// # Errors
    /// Returns [`ModuleError::DecodeError`] if the payload cannot be decoded.
    /// The registry is left unchanged in that case.
    pub fn add_module(&mut self, module: &DGateModule) -> Result<(), ModuleError> {
        let compiled = CompiledModule::new(module)?;
        self.modules
            .insert(Self::key(&module.namespace, &module.name), compiled);
        Ok(())
    }

    /// Unregisters the module `name` in `namespace`; a no-op if it is absent.
    pub fn remove_module(&mut self, namespace: &str, name: &str) {
        self.modules.remove(&Self::key(namespace, name));
    }

    /// Looks up the module `name` in `namespace`.
    pub fn get_module(&self, namespace: &str, name: &str) -> Option<&CompiledModule> {
        self.modules.get(&Self::key(namespace, name))
    }

    /// Builds a [`ModuleContext`] from the named modules of `namespace`, in
    /// the order given. Names that are not registered are skipped silently.
    pub fn create_context(&self, modules: &[String], namespace: &str) -> ModuleContext {
        let sources: Vec<String> = modules
            .iter()
            .filter_map(|name| self.get_module(namespace, name))
            .map(|m| m.source.clone())
            .collect();
        ModuleContext::new(sources)
    }
}

impl Default for ModuleExecutor {
    fn default() -> Self {
        Self::new()
    }
}
/// The module sources to run for one call; built by
/// `ModuleExecutor::create_context` or directly via `ModuleContext::new`.
pub struct ModuleContext {
/// Module source texts, in execution order.
sources: Vec<String>,
}
impl ModuleContext {
pub fn new(sources: Vec<String>) -> Self {
Self { sources }
}
pub fn execute_request_modifier(
&self,
req_ctx: &RequestContext,
) -> Result<RequestContext, ModuleError> {
if self.sources.is_empty() {
return Ok(req_ctx.clone());
}
let runtime = Runtime::new().map_err(|e| ModuleError::RuntimeError(e.to_string()))?;
let context =
Context::full(&runtime).map_err(|e| ModuleError::RuntimeError(e.to_string()))?;
let mut result = req_ctx.clone();
context.with(|ctx| -> Result<(), ModuleError> {
for source in &self.sources {
self.setup_globals(&ctx)?;
self.setup_request_context(&ctx, &result)?;
let wrapped = format!(
r#"
{}
if (typeof requestModifier === 'function') {{
requestModifier(__ctx__);
}}
"#,
source
);
ctx.eval::<Value, _>(wrapped.as_str())
.map_err(|e| ModuleError::JsError(e.to_string()))?;
result = self.extract_request_context(&ctx)?;
}
Ok(())
})?;
Ok(result)
}
pub fn execute_request_handler(
&self,
req_ctx: &RequestContext,
) -> Result<ModuleResponse, ModuleError> {
if self.sources.is_empty() {
return Err(ModuleError::FunctionNotFound("requestHandler".to_string()));
}
let runtime = Runtime::new().map_err(|e| ModuleError::RuntimeError(e.to_string()))?;
let context =
Context::full(&runtime).map_err(|e| ModuleError::RuntimeError(e.to_string()))?;
context.with(|ctx| {
self.setup_globals(&ctx)?;
self.setup_request_context(&ctx, req_ctx)?;
self.setup_response_writer(&ctx)?;
let combined: String = self.sources.join("\n");
let wrapped = format!(
r#"
{}
if (typeof requestHandler === 'function') {{
requestHandler(__ctx__);
}} else {{
throw new Error('requestHandler function not found');
}}
"#,
combined
);
ctx.eval::<Value, _>(wrapped.as_str())
.map_err(|e| ModuleError::JsError(e.to_string()))?;
self.extract_response(&ctx)
})
}
pub fn execute_response_modifier(
&self,
req_ctx: &RequestContext,
res_ctx: &ResponseContext,
) -> Result<ResponseContext, ModuleError> {
if self.sources.is_empty() {
return Ok(res_ctx.clone());
}
let runtime = Runtime::new().map_err(|e| ModuleError::RuntimeError(e.to_string()))?;
let context =
Context::full(&runtime).map_err(|e| ModuleError::RuntimeError(e.to_string()))?;
context.with(|ctx| {
self.setup_globals(&ctx)?;
self.setup_request_context(&ctx, req_ctx)?;
self.setup_response_context(&ctx, res_ctx)?;
let combined: String = self.sources.join("\n");
let wrapped = format!(
r#"
{}
if (typeof responseModifier === 'function') {{
responseModifier(__ctx__, __res__);
}}
"#,
combined
);
ctx.eval::<Value, _>(wrapped.as_str())
.map_err(|e| ModuleError::JsError(e.to_string()))?;
self.extract_response_context(&ctx)
})
}
pub fn execute_error_handler(
&self,
req_ctx: &RequestContext,
error: &str,
) -> Result<ModuleResponse, ModuleError> {
if self.sources.is_empty() {
return Err(ModuleError::FunctionNotFound("errorHandler".to_string()));
}
let runtime = Runtime::new().map_err(|e| ModuleError::RuntimeError(e.to_string()))?;
let context =
Context::full(&runtime).map_err(|e| ModuleError::RuntimeError(e.to_string()))?;
context.with(|ctx| {
self.setup_globals(&ctx)?;
self.setup_request_context(&ctx, req_ctx)?;
self.setup_response_writer(&ctx)?;
let combined: String = self.sources.join("\n");
let wrapped = format!(
r#"
{}
if (typeof errorHandler === 'function') {{
errorHandler(__ctx__, new Error({}));
}}
"#,
combined,
serde_json::to_string(error).unwrap_or_else(|_| "\"Unknown error\"".to_string())
);
ctx.eval::<Value, _>(wrapped.as_str())
.map_err(|e| ModuleError::JsError(e.to_string()))?;
self.extract_response(&ctx)
})
}
pub fn execute_fetch_upstream_url(
&self,
req_ctx: &RequestContext,
) -> Result<Option<String>, ModuleError> {
if self.sources.is_empty() {
return Ok(None);
}
let runtime = Runtime::new().map_err(|e| ModuleError::RuntimeError(e.to_string()))?;
let context =
Context::full(&runtime).map_err(|e| ModuleError::RuntimeError(e.to_string()))?;
context.with(|ctx| {
self.setup_globals(&ctx)?;
self.setup_request_context(&ctx, req_ctx)?;
let combined: String = self.sources.join("\n");
let wrapped = format!(
r#"
{}
var __upstream_url__ = null;
if (typeof fetchUpstreamUrl === 'function') {{
__upstream_url__ = fetchUpstreamUrl(__ctx__);
}}
__upstream_url__;
"#,
combined
);
let result: Value = ctx
.eval(wrapped.as_str())
.map_err(|e| ModuleError::JsError(e.to_string()))?;
if result.is_null() || result.is_undefined() {
Ok(None)
} else if let Some(s) = result.as_string() {
Ok(Some(
s.to_string()
.map_err(|e| ModuleError::JsError(e.to_string()))?,
))
} else {
Ok(None)
}
})
}
fn setup_globals(&self, ctx: &rquickjs::Ctx) -> Result<(), ModuleError> {
let console_code = r#"
var console = {
log: function() {
// No-op in production, could be wired to tracing
},
error: function() {},
warn: function() {},
info: function() {}
};
"#;
ctx.eval::<Value, _>(console_code)
.map_err(|e| ModuleError::JsError(e.to_string()))?;
let base64_code = r#"
function btoa(str) {
var chars = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/=';
var encoded = '';
var i = 0;
while (i < str.length) {
var a = str.charCodeAt(i++);
var b = str.charCodeAt(i++);
var c = str.charCodeAt(i++);
var enc1 = a >> 2;
var enc2 = ((a & 3) << 4) | (b >> 4);
var enc3 = ((b & 15) << 2) | (c >> 6);
var enc4 = c & 63;
if (isNaN(b)) { enc3 = enc4 = 64; }
else if (isNaN(c)) { enc4 = 64; }
encoded += chars.charAt(enc1) + chars.charAt(enc2) + chars.charAt(enc3) + chars.charAt(enc4);
}
return encoded;
}
function atob(str) {
var chars = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/=';
var decoded = '';
var i = 0;
str = str.replace(/[^A-Za-z0-9\+\/\=]/g, '');
while (i < str.length) {
var enc1 = chars.indexOf(str.charAt(i++));
var enc2 = chars.indexOf(str.charAt(i++));
var enc3 = chars.indexOf(str.charAt(i++));
var enc4 = chars.indexOf(str.charAt(i++));
var a = (enc1 << 2) | (enc2 >> 4);
var b = ((enc2 & 15) << 4) | (enc3 >> 2);
var c = ((enc3 & 3) << 6) | enc4;
decoded += String.fromCharCode(a);
if (enc3 !== 64) decoded += String.fromCharCode(b);
if (enc4 !== 64) decoded += String.fromCharCode(c);
}
return decoded;
}
"#;
ctx.eval::<Value, _>(base64_code)
.map_err(|e| ModuleError::JsError(e.to_string()))?;
Ok(())
}
fn setup_request_context(
&self,
ctx: &rquickjs::Ctx,
req_ctx: &RequestContext,
) -> Result<(), ModuleError> {
let ctx_json = serde_json::json!({
"request": {
"method": req_ctx.method,
"path": req_ctx.path,
"query": req_ctx.query,
"headers": req_ctx.headers,
"body": req_ctx.body.as_ref().map(|b| String::from_utf8_lossy(b).to_string()),
},
"params": req_ctx.params,
"route": req_ctx.route_name,
"namespace": req_ctx.namespace,
"service": req_ctx.service_name,
"documents": req_ctx.documents,
});
let setup_code = format!(
r#"var __ctx__ = {};
__ctx__.response = {{ statusCode: 200, headers: {{}}, body: '' }};
__ctx__._docStore = __ctx__.documents || {{}};
"#,
ctx_json
);
ctx.eval::<Value, _>(setup_code.as_str())
.map_err(|e| ModuleError::JsError(e.to_string()))?;
Ok(())
}
fn setup_response_writer(&self, ctx: &rquickjs::Ctx) -> Result<(), ModuleError> {
let setup_code = r#"
__ctx__.response = { statusCode: 200, headers: {}, body: '' };
__ctx__.setHeader = function(name, value) {
__ctx__.response.headers[name] = value;
};
__ctx__.setStatus = function(code) {
__ctx__.response.statusCode = code;
};
__ctx__.write = function(data) {
__ctx__.response.body += data;
};
__ctx__.json = function(data) {
__ctx__.response.headers['Content-Type'] = 'application/json';
__ctx__.response.body = JSON.stringify(data);
};
__ctx__.redirect = function(url, status) {
__ctx__.response.statusCode = status || 302;
__ctx__.response.headers['Location'] = url;
__ctx__.response.body = '';
};
__ctx__.pathParam = function(name) {
return __ctx__.params[name] || null;
};
__ctx__.queryParam = function(name) {
return __ctx__.request.query[name] || null;
};
__ctx__.status = function(code) {
__ctx__.response.statusCode = code;
return __ctx__;
};
// Document storage functions
__ctx__.getDocument = function(collection, id) {
var key = collection + ':' + id;
var doc = __ctx__._docStore[key];
return doc ? { data: doc } : null;
};
__ctx__.setDocument = function(collection, id, data) {
var key = collection + ':' + id;
__ctx__._docStore[key] = data;
__ctx__._docsModified = true;
};
__ctx__.deleteDocument = function(collection, id) {
var key = collection + ':' + id;
delete __ctx__._docStore[key];
__ctx__._docsModified = true;
};
// Simple hash function for URL shortener
__ctx__.hashString = function(str) {
var hash = 0;
for (var i = 0; i < str.length; i++) {
var char = str.charCodeAt(i);
hash = ((hash << 5) - hash) + char;
hash = hash & hash; // Convert to 32bit integer
}
// Convert to base36 and take last 8 characters
return Math.abs(hash).toString(36).slice(-8).padStart(8, '0');
};
"#;
ctx.eval::<Value, _>(setup_code)
.map_err(|e| ModuleError::JsError(e.to_string()))?;
Ok(())
}
fn setup_response_context(
&self,
ctx: &rquickjs::Ctx,
res_ctx: &ResponseContext,
) -> Result<(), ModuleError> {
let res_json = serde_json::json!({
"statusCode": res_ctx.status_code,
"headers": res_ctx.headers,
"body": res_ctx.body.as_ref().map(|b| String::from_utf8_lossy(b).to_string()),
});
let setup_code = format!("var __res__ = {};", res_json);
ctx.eval::<Value, _>(setup_code.as_str())
.map_err(|e| ModuleError::JsError(e.to_string()))?;
Ok(())
}
fn extract_request_context(&self, ctx: &rquickjs::Ctx) -> Result<RequestContext, ModuleError> {
let extract_code = r#"
JSON.stringify({
method: __ctx__.request.method,
path: __ctx__.request.path,
query: __ctx__.request.query,
headers: __ctx__.request.headers,
body: __ctx__.request.body,
params: __ctx__.params,
route: __ctx__.route,
namespace: __ctx__.namespace,
service: __ctx__.service,
documents: __ctx__._docStore || {}
})
"#;
let result: String = ctx
.eval(extract_code)
.map_err(|e| ModuleError::JsError(e.to_string()))?;
let parsed: serde_json::Value =
serde_json::from_str(&result).map_err(|e| ModuleError::JsError(e.to_string()))?;
Ok(RequestContext {
method: parsed["method"].as_str().unwrap_or("GET").to_string(),
path: parsed["path"].as_str().unwrap_or("/").to_string(),
query: parsed["query"]
.as_object()
.map(|o| {
o.iter()
.filter_map(|(k, v)| v.as_str().map(|s| (k.clone(), s.to_string())))
.collect()
})
.unwrap_or_default(),
headers: parsed["headers"]
.as_object()
.map(|o| {
o.iter()
.filter_map(|(k, v)| v.as_str().map(|s| (k.clone(), s.to_string())))
.collect()
})
.unwrap_or_default(),
body: parsed["body"].as_str().map(|s| s.as_bytes().to_vec()),
params: parsed["params"]
.as_object()
.map(|o| {
o.iter()
.filter_map(|(k, v)| v.as_str().map(|s| (k.clone(), s.to_string())))
.collect()
})
.unwrap_or_default(),
route_name: parsed["route"].as_str().unwrap_or("").to_string(),
namespace: parsed["namespace"]
.as_str()
.unwrap_or("default")
.to_string(),
service_name: parsed["service"].as_str().map(|s| s.to_string()),
documents: parsed["documents"]
.as_object()
.map(|o| o.iter().map(|(k, v)| (k.clone(), v.clone())).collect())
.unwrap_or_default(),
})
}
fn extract_response(&self, ctx: &rquickjs::Ctx) -> Result<ModuleResponse, ModuleError> {
let extract_code = r#"
JSON.stringify({
response: __ctx__.response,
documents: __ctx__._docStore || {}
})
"#;
let result: String = ctx
.eval(extract_code)
.map_err(|e| ModuleError::JsError(e.to_string()))?;
let parsed: serde_json::Value =
serde_json::from_str(&result).map_err(|e| ModuleError::JsError(e.to_string()))?;
let response = &parsed["response"];
Ok(ModuleResponse {
status_code: response["statusCode"].as_u64().unwrap_or(200) as u16,
headers: response["headers"]
.as_object()
.map(|o| {
o.iter()
.filter_map(|(k, v)| v.as_str().map(|s| (k.clone(), s.to_string())))
.collect()
})
.unwrap_or_default(),
body: response["body"]
.as_str()
.map(|s| s.as_bytes().to_vec())
.unwrap_or_default(),
documents: parsed["documents"]
.as_object()
.map(|o| o.iter().map(|(k, v)| (k.clone(), v.clone())).collect())
.unwrap_or_default(),
})
}
fn extract_response_context(
&self,
ctx: &rquickjs::Ctx,
) -> Result<ResponseContext, ModuleError> {
let extract_code = r#"
JSON.stringify(__res__)
"#;
let result: String = ctx
.eval(extract_code)
.map_err(|e| ModuleError::JsError(e.to_string()))?;
let parsed: serde_json::Value =
serde_json::from_str(&result).map_err(|e| ModuleError::JsError(e.to_string()))?;
Ok(ResponseContext {
status_code: parsed["statusCode"].as_u64().unwrap_or(200) as u16,
headers: parsed["headers"]
.as_object()
.map(|o| {
o.iter()
.filter_map(|(k, v)| v.as_str().map(|s| (k.clone(), s.to_string())))
.collect()
})
.unwrap_or_default(),
body: parsed["body"].as_str().map(|s| s.as_bytes().to_vec()),
})
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal `GET /test` request on route `test` in `default`.
    fn sample_request() -> RequestContext {
        RequestContext {
            method: "GET".to_string(),
            path: "/test".to_string(),
            query: HashMap::new(),
            headers: HashMap::new(),
            body: None,
            params: HashMap::new(),
            route_name: "test".to_string(),
            namespace: "default".to_string(),
            service_name: None,
            documents: HashMap::new(),
        }
    }

    #[test]
    fn test_request_modifier() {
        let source = r#"
function requestModifier(ctx) {
ctx.request.headers['X-Modified'] = 'true';
}
"#;
        let ctx = ModuleContext::new(vec![source.to_string()]);
        let result = ctx.execute_request_modifier(&sample_request()).unwrap();
        assert_eq!(result.headers.get("X-Modified"), Some(&"true".to_string()));
        assert_eq!(result.path, "/test");
    }

    #[test]
    fn test_request_handler() {
        let source = r#"
function requestHandler(ctx) {
ctx.setStatus(201);
ctx.setHeader('Content-Type', 'text/plain');
ctx.write('Hello, World!');
}
"#;
        let ctx = ModuleContext::new(vec![source.to_string()]);
        let result = ctx.execute_request_handler(&sample_request()).unwrap();
        assert_eq!(result.status_code, 201);
        assert_eq!(
            result.headers.get("Content-Type"),
            Some(&"text/plain".to_string())
        );
        assert_eq!(String::from_utf8_lossy(&result.body), "Hello, World!");
    }

    #[test]
    fn test_request_handler_missing_is_error() {
        let ctx = ModuleContext::new(vec!["function unrelated() {}".to_string()]);
        assert!(ctx.execute_request_handler(&sample_request()).is_err());
    }

    #[test]
    fn test_empty_sources() {
        let ctx = ModuleContext::new(Vec::new());
        let req = sample_request();
        assert_eq!(ctx.execute_request_modifier(&req).unwrap().path, "/test");
        assert!(matches!(
            ctx.execute_request_handler(&req),
            Err(ModuleError::FunctionNotFound(_))
        ));
        assert!(matches!(
            ctx.execute_error_handler(&req, "boom"),
            Err(ModuleError::FunctionNotFound(_))
        ));
        assert_eq!(ctx.execute_fetch_upstream_url(&req).unwrap(), None);
    }

    #[test]
    fn test_response_modifier() {
        let source = r#"
function responseModifier(ctx, res) {
res.statusCode = 404;
res.headers['X-Route'] = ctx.route;
}
"#;
        let ctx = ModuleContext::new(vec![source.to_string()]);
        let res = ResponseContext {
            status_code: 200,
            headers: HashMap::new(),
            body: Some(b"hi".to_vec()),
        };
        let out = ctx.execute_response_modifier(&sample_request(), &res).unwrap();
        assert_eq!(out.status_code, 404);
        assert_eq!(out.headers.get("X-Route"), Some(&"test".to_string()));
        assert_eq!(out.body.as_deref(), Some(&b"hi"[..]));
    }

    #[test]
    fn test_error_handler() {
        let source = r#"
function errorHandler(ctx, err) {
ctx.setStatus(502);
ctx.write(err.message);
}
"#;
        let ctx = ModuleContext::new(vec![source.to_string()]);
        // Quotes in the message must survive the trip into a JS string literal.
        let result = ctx
            .execute_error_handler(&sample_request(), "upstream \"down\"")
            .unwrap();
        assert_eq!(result.status_code, 502);
        assert_eq!(String::from_utf8_lossy(&result.body), "upstream \"down\"");
    }

    #[test]
    fn test_fetch_upstream_url() {
        let source = r#"
function fetchUpstreamUrl(ctx) {
return 'http://backend:' + ctx.params.port;
}
"#;
        let mut req = sample_request();
        req.params.insert("port".to_string(), "8080".to_string());
        let ctx = ModuleContext::new(vec![source.to_string()]);
        assert_eq!(
            ctx.execute_fetch_upstream_url(&req).unwrap(),
            Some("http://backend:8080".to_string())
        );
        let other = ModuleContext::new(vec!["function unrelated() {}".to_string()]);
        assert_eq!(other.execute_fetch_upstream_url(&req).unwrap(), None);
    }

    #[test]
    fn test_document_storage() {
        let source = r#"
function requestHandler(ctx) {
ctx.setDocument('urls', 'abc', { target: 'https://example.com' });
ctx.json({ id: ctx.hashString('abc') });
}
"#;
        let ctx = ModuleContext::new(vec![source.to_string()]);
        let result = ctx.execute_request_handler(&sample_request()).unwrap();
        assert_eq!(
            result.documents.get("urls:abc"),
            Some(&serde_json::json!({ "target": "https://example.com" }))
        );
        assert_eq!(
            result.headers.get("Content-Type"),
            Some(&"application/json".to_string())
        );
    }
}