use crate::rule::model::BodySource;
use crate::rule::engine::executor::ExecutionContext;
use relay_core_api::flow::{BodyData, Flow, Layer};
use relay_core_api::policy::ProxyPolicy;
use crate::utils::path::PathSanitizer;
use chrono::Utc;
use uuid::Uuid;
/// Resolves a rule's [`BodySource`] into a concrete [`BodyData`] payload.
///
/// * `Text` is returned as-is with `utf-8` encoding; `size` is its byte length.
/// * `Base64` is returned as-is with `base64` encoding; `size` is the length of
///   the encoded string.
/// * `File` is sanitized by [`PathSanitizer`] against a root directory (the
///   policy's `sandbox_root`, falling back to the process working directory).
///   It is limited to the policy's `max_local_file_bytes` (10 MiB when no policy
///   is given). It is returned as `utf-8` text when the bytes are valid UTF-8,
///   and base64-encoded otherwise. `size` is always the raw file length in bytes.
///
/// Returns `None` when:
/// * the sanitizer rejects the path;
/// * the file cannot be read;
/// * the file exceeds the size limit.
pub async fn resolve_body_source(source: &BodySource, policy: Option<&ProxyPolicy>) -> Option<BodyData> {
    match source {
        BodySource::Text(t) => Some(BodyData {
            encoding: "utf-8".to_string(),
            content: t.clone(),
            size: t.len() as u64,
        }),
        BodySource::Base64(b) => Some(BodyData {
            encoding: "base64".to_string(),
            content: b.clone(),
            // NOTE(review): this is the *encoded* length, whereas the `File`
            // branch reports the raw byte length for base64 content — confirm
            // which one consumers of `size` expect.
            size: b.len() as u64,
        }),
        BodySource::File(path) => {
            // Without a policy, or without a configured sandbox root, the path
            // is sanitized against the process working directory.
            let root = policy
                .and_then(|p| p.sandbox_root.clone())
                .unwrap_or_else(|| std::env::current_dir().unwrap_or_default());
            let sanitizer = PathSanitizer::new(root);
            let canon_path = sanitizer.sanitize(path).ok()?;

            let max_bytes = policy
                .map(|p| p.max_local_file_bytes)
                .unwrap_or(10 * 1024 * 1024) as u64;

            // Cheap pre-check so an oversized file is rejected before it is read.
            if let Ok(metadata) = tokio::fs::metadata(&canon_path).await {
                if metadata.len() > max_bytes {
                    return None;
                }
            }

            let bytes = tokio::fs::read(&canon_path).await.ok()?;
            // Enforce the limit on what was actually read as well. The metadata
            // call may have failed, or the file may have grown since the check.
            let size = bytes.len() as u64;
            if size > max_bytes {
                return None;
            }

            // Try UTF-8 without copying; on failure, take the buffer back from
            // the error and fall back to base64.
            match String::from_utf8(bytes) {
                Ok(text) => Some(BodyData {
                    encoding: "utf-8".to_string(),
                    content: text,
                    size,
                }),
                Err(err) => {
                    use data_encoding::BASE64;
                    let raw = err.into_bytes();
                    Some(BodyData {
                        encoding: "base64".to_string(),
                        content: BASE64.encode(&raw),
                        size,
                    })
                }
            }
        }
    }
}
/// Expands `{{name}}` placeholders in `template` in a single left-to-right pass.
///
/// Each placeholder name is resolved in this order:
/// 1. `ctx.variables` — user variables take precedence over built-ins.
/// 2. `previous` → `previous_value`, or `""` when it is `None`.
/// 3. `timestamp` (Unix epoch milliseconds) and `uuid` (random v4). Each is
///    generated at most once per call, so repeated occurrences share one value.
/// 4. `client.ip` / `client_ip`, `server.ip` / `server_ip`,
///    `server.port` / `server_port` from `flow.network`.
/// 5. For HTTP flows and WebSocket handshakes: `request.method`, `request.host`,
///    `request.url`, `request.path`, `request.query`. A missing host or query
///    expands to `""`.
///
/// Unknown placeholders are left untouched.
///
/// Substituted values are never re-scanned. A variable value or
/// `previous_value` that itself contains `{{...}}` is inserted verbatim, so the
/// result does not depend on `HashMap` iteration order. Placeholder names
/// cannot contain `}}`.
pub fn substitute_variables(
    template: &str,
    flow: &Flow,
    ctx: &ExecutionContext,
    previous_value: Option<&str>,
) -> String {
    let request = match &flow.layer {
        Layer::Http(http) => Some(&http.request),
        Layer::WebSocket(ws) => Some(&ws.handshake_request),
        _ => None,
    };

    // Lazily generated so every `{{timestamp}}` / `{{uuid}}` in one template
    // expands to the same value.
    let mut timestamp_cache: Option<String> = None;
    let mut uuid_cache: Option<String> = None;

    let mut resolve = |name: &str| -> Option<String> {
        if let Some(value) = ctx.variables.get(name) {
            return Some(value.to_string());
        }
        let value = match name {
            "previous" => previous_value.unwrap_or("").to_string(),
            "timestamp" => timestamp_cache
                .get_or_insert_with(|| Utc::now().timestamp_millis().to_string())
                .clone(),
            "uuid" => uuid_cache
                .get_or_insert_with(|| Uuid::new_v4().to_string())
                .clone(),
            "client.ip" | "client_ip" => flow.network.client_ip.to_string(),
            "server.ip" | "server_ip" => flow.network.server_ip.to_string(),
            "server.port" | "server_port" => flow.network.server_port.to_string(),
            other => {
                // Request placeholders only exist for flows that carry a request;
                // otherwise they stay literal.
                let req = request?;
                match other {
                    "request.method" => req.method.to_string(),
                    "request.host" => req.url.host_str().unwrap_or("").to_string(),
                    "request.url" => req.url.as_str().to_string(),
                    "request.path" => req.url.path().to_string(),
                    "request.query" => req.url.query().unwrap_or("").to_string(),
                    _ => return None,
                }
            }
        };
        Some(value)
    };

    let mut out = String::with_capacity(template.len());
    let mut rest = template;
    while let Some(open) = rest.find("{{") {
        out.push_str(&rest[..open]);
        let tail = &rest[open + 2..];
        let name = match tail.find("}}") {
            Some(close) => &tail[..close],
            None => {
                // Unterminated `{{`: the remainder is copied verbatim below.
                rest = &rest[open..];
                break;
            }
        };
        match resolve(name) {
            Some(value) => {
                out.push_str(&value);
                rest = &tail[name.len() + 2..];
            }
            None => {
                // Not a known placeholder. Keep one '{' and rescan from the next
                // character, so input like `{{{name}}}` still matches at offset 1.
                out.push('{');
                rest = &rest[open + 1..];
            }
        }
    }
    out.push_str(rest);
    out
}
#[cfg(test)]
mod tests {
    use super::{resolve_body_source, substitute_variables};
    use crate::rule::engine::executor::ExecutionContext;
    use crate::rule::engine::state::InMemoryRuleStateStore;
    use crate::rule::model::BodySource;
    use crate::rule::model::RuleTraceSummary;
    use chrono::Utc;
    use relay_core_api::flow::{
        Flow, HttpLayer, HttpRequest, HttpResponse, Layer, NetworkInfo, ResponseTiming,
        TransportProtocol, WebSocketLayer,
    };
    use relay_core_api::policy::ProxyPolicy;
    use std::collections::HashMap;
    use std::path::PathBuf;
    use std::sync::Arc;
    use url::Url;
    use uuid::Uuid;

    /// A unique directory under the system temp dir that is removed on drop.
    /// Cleanup therefore also happens when an assertion fails.
    struct TempSandbox {
        dir: PathBuf,
    }

    impl TempSandbox {
        fn new() -> Self {
            let dir = std::env::temp_dir().join(format!("relay-utils-test-{}", Uuid::new_v4()));
            std::fs::create_dir_all(&dir).expect("create dir");
            Self { dir }
        }

        /// Writes `contents` to `name` inside the sandbox and returns the path
        /// as a `String`, ready for `BodySource::File`.
        fn write_file(&self, name: &str, contents: &[u8]) -> String {
            let file = self.dir.join(name);
            std::fs::write(&file, contents).expect("write file");
            file.to_string_lossy().to_string()
        }
    }

    impl Drop for TempSandbox {
        fn drop(&mut self) {
            // Best effort: a cleanup failure must not mask the real test outcome.
            let _ = std::fs::remove_dir_all(&self.dir);
        }
    }

    fn sample_flow() -> Flow {
        Flow {
            id: Uuid::new_v4(),
            start_time: Utc::now(),
            end_time: None,
            network: NetworkInfo {
                client_ip: "127.0.0.1".to_string(),
                client_port: 12345,
                server_ip: "1.1.1.1".to_string(),
                server_port: 443,
                protocol: TransportProtocol::TCP,
                tls: true,
                tls_version: Some("TLS1.3".to_string()),
                sni: Some("example.com".to_string()),
            },
            layer: Layer::Http(HttpLayer {
                request: HttpRequest {
                    method: "GET".to_string(),
                    url: Url::parse("https://example.com/path?q=abc").expect("url"),
                    version: "HTTP/1.1".to_string(),
                    headers: vec![],
                    cookies: vec![],
                    query: vec![],
                    body: None,
                },
                response: None,
                error: None,
            }),
            tags: vec![],
            meta: HashMap::new(),
        }
    }

    fn sample_ctx() -> ExecutionContext {
        ExecutionContext {
            trace: vec![],
            variables: HashMap::new(),
            policy: None,
            summary: RuleTraceSummary::NoMatch,
            state_store: Arc::new(InMemoryRuleStateStore::new()),
        }
    }

    fn sample_ws_flow() -> Flow {
        Flow {
            id: Uuid::new_v4(),
            start_time: Utc::now(),
            end_time: None,
            network: NetworkInfo {
                client_ip: "127.0.0.1".to_string(),
                client_port: 23456,
                server_ip: "2.2.2.2".to_string(),
                server_port: 443,
                protocol: TransportProtocol::TCP,
                tls: true,
                tls_version: Some("TLS1.3".to_string()),
                sni: Some("ws.example.com".to_string()),
            },
            layer: Layer::WebSocket(WebSocketLayer {
                handshake_request: HttpRequest {
                    method: "GET".to_string(),
                    url: Url::parse("wss://ws.example.com/socket?q=1").expect("url"),
                    version: "HTTP/1.1".to_string(),
                    headers: vec![],
                    cookies: vec![],
                    query: vec![],
                    body: None,
                },
                handshake_response: HttpResponse {
                    status: 101,
                    status_text: "Switching Protocols".to_string(),
                    version: "HTTP/1.1".to_string(),
                    headers: vec![],
                    cookies: vec![],
                    body: None,
                    timing: ResponseTiming {
                        time_to_first_byte: None,
                        time_to_last_byte: None,
                        connect_time_ms: None,
                        ssl_time_ms: None,
                    },
                },
                messages: vec![],
                closed: false,
            }),
            tags: vec![],
            meta: HashMap::new(),
        }
    }

    #[test]
    fn test_substitute_previous_and_request_fields() {
        let flow = sample_flow();
        let mut ctx = sample_ctx();
        ctx.variables.insert("env".to_string(), "dev".to_string());
        let out = substitute_variables(
            "v={{previous}},m={{request.method}},h={{request.host}},p={{request.path}},e={{env}}",
            &flow,
            &ctx,
            Some("old"),
        );
        assert_eq!(out, "v=old,m=GET,h=example.com,p=/path,e=dev");
    }

    #[test]
    fn test_substitute_timestamp_is_unix_millis() {
        let flow = sample_flow();
        let ctx = sample_ctx();
        let before = Utc::now().timestamp_millis();
        let out = substitute_variables("ts={{timestamp}}", &flow, &ctx, None);
        let after = Utc::now().timestamp_millis();
        let ts = out.strip_prefix("ts=").expect("prefix");
        let millis = ts.parse::<i64>().expect("timestamp millis");
        assert!(
            (before..=after).contains(&millis),
            "timestamp {} not within [{}, {}]",
            millis,
            before,
            after
        );
    }

    #[test]
    fn test_substitute_network_legacy_aliases() {
        let flow = sample_flow();
        let ctx = sample_ctx();
        let out = substitute_variables(
            "c={{client_ip}},s={{server_ip}},p={{server_port}}",
            &flow,
            &ctx,
            None,
        );
        assert_eq!(out, "c=127.0.0.1,s=1.1.1.1,p=443");
    }

    #[test]
    fn test_substitute_request_fields_for_websocket_handshake() {
        let flow = sample_ws_flow();
        let ctx = sample_ctx();
        let out = substitute_variables(
            "m={{request.method}},h={{request.host}},p={{request.path}},q={{request.query}}",
            &flow,
            &ctx,
            None,
        );
        assert_eq!(out, "m=GET,h=ws.example.com,p=/socket,q=q=1");
    }

    #[test]
    fn test_substitute_keeps_unknown_and_clears_missing_previous() {
        let flow = sample_flow();
        let ctx = sample_ctx();
        let out = substitute_variables("a={{unknown}},v={{previous}}", &flow, &ctx, None);
        assert_eq!(out, "a={{unknown}},v=");
    }

    #[test]
    fn test_substitute_repeated_uuid_shares_one_value() {
        let flow = sample_flow();
        let ctx = sample_ctx();
        let out = substitute_variables("{{uuid}}|{{uuid}}", &flow, &ctx, None);
        let (first, second) = out.split_once('|').expect("separator");
        assert_eq!(first, second);
        Uuid::parse_str(first).expect("valid uuid");
    }

    #[tokio::test]
    async fn test_resolve_body_source_inline_text_and_base64() {
        let text = resolve_body_source(&BodySource::Text("héllo".to_string()), None)
            .await
            .expect("text body");
        assert_eq!(text.encoding, "utf-8");
        assert_eq!(text.content, "héllo");
        assert_eq!(text.size, 6, "size is the UTF-8 byte length");

        let b64 = resolve_body_source(&BodySource::Base64("aGk=".to_string()), None)
            .await
            .expect("base64 body");
        assert_eq!(b64.encoding, "base64");
        assert_eq!(b64.content, "aGk=");
    }

    #[tokio::test]
    async fn test_resolve_body_source_file_too_large_returns_none() {
        let sandbox = TempSandbox::new();
        let path = sandbox.write_file("large.txt", &[b'a'; 33]);
        let policy = ProxyPolicy {
            sandbox_root: Some(sandbox.dir.clone()),
            max_local_file_bytes: 32,
            ..Default::default()
        };
        let out = resolve_body_source(&BodySource::File(path), Some(&policy)).await;
        assert!(out.is_none(), "large file should be rejected");
    }

    #[tokio::test]
    async fn test_resolve_body_source_file_at_limit_is_accepted() {
        let sandbox = TempSandbox::new();
        let path = sandbox.write_file("exact.txt", &[b'a'; 32]);
        let policy = ProxyPolicy {
            sandbox_root: Some(sandbox.dir.clone()),
            max_local_file_bytes: 32,
            ..Default::default()
        };
        let out = resolve_body_source(&BodySource::File(path), Some(&policy))
            .await
            .expect("file exactly at the limit should load");
        assert_eq!(out.encoding, "utf-8");
        assert_eq!(out.size, 32);
    }

    #[tokio::test]
    async fn test_resolve_body_source_utf8_file_is_text() {
        let sandbox = TempSandbox::new();
        let path = sandbox.write_file("body.json", b"{\"ok\":true}");
        let policy = ProxyPolicy {
            sandbox_root: Some(sandbox.dir.clone()),
            max_local_file_bytes: 1024,
            ..Default::default()
        };
        let out = resolve_body_source(&BodySource::File(path), Some(&policy))
            .await
            .expect("utf-8 file should load");
        assert_eq!(out.encoding, "utf-8");
        assert_eq!(out.content, "{\"ok\":true}");
        assert_eq!(out.size, 11);
    }

    #[tokio::test]
    async fn test_resolve_body_source_binary_falls_back_to_base64() {
        let sandbox = TempSandbox::new();
        let path = sandbox.write_file("bin.dat", &[0xff, 0xfe, 0xfd, 0x00]);
        let policy = ProxyPolicy {
            sandbox_root: Some(sandbox.dir.clone()),
            max_local_file_bytes: 1024,
            ..Default::default()
        };
        let out = resolve_body_source(&BodySource::File(path), Some(&policy))
            .await
            .expect("binary file should still be loadable");
        assert_eq!(out.encoding, "base64");
        assert_eq!(out.size, 4);
        assert!(!out.content.is_empty(), "base64 payload should not be empty");
    }
}