use mctrust::{Environment, Outcome, Reward};
use wafrift_encoding::encoding;
use wafrift_grammar::grammar;
use wafrift_types::{Request, Technique};
/// A single move in the evasion search: one technique from one of four
/// families, identified by name.
///
/// Every variant carries the technique's name as a `String`; the family is
/// encoded by the variant itself.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum TechniqueAction {
/// Re-encode the request body with the named encoding strategy.
Encode(String),
/// Substitute a pre-computed grammar mutation, identified by the first rule
/// that produced it.
GrammarMutate(String),
/// Re-express a form body under a different content type (for example
/// `"Multipart"` or `"XmlCdata"`).
ContentTypeSwitch(String),
/// Obfuscate the `Content-Type` header using the named trick (for example
/// `"CaseMixing"`).
HeaderTrick(String),
}
impl TechniqueAction {
    /// Converts this search action into the equivalent [`Technique`] record.
    ///
    /// The technique name is cloned into the returned value; the action itself
    /// is left untouched.
    #[must_use]
    pub fn to_technique(&self) -> Technique {
        // Choose the `Technique` constructor matching this action's family.
        let build: fn(String) -> Technique = match self {
            Self::Encode(_) => Technique::PayloadEncoding,
            Self::GrammarMutate(_) => Technique::GrammarMutation,
            Self::ContentTypeSwitch(_) => Technique::ContentTypeSwitch,
            Self::HeaderTrick(_) => Technique::HeaderObfuscation,
        };
        // Every variant carries exactly one name, so it can be pulled out uniformly.
        let label = match self {
            Self::Encode(label)
            | Self::GrammarMutate(label)
            | Self::ContentTypeSwitch(label)
            | Self::HeaderTrick(label) => label,
        };
        build(label.clone())
    }
}
/// Search state for the MCTS bridge: a request plus the evasion techniques
/// applied to it so far.
#[derive(Clone)]
pub struct WafRiftEnv {
/// The request being transformed; `apply` rewrites its body and headers.
pub req: Request,
/// Techniques applied so far, in application order.
pub applied_techniques: Vec<Technique>,
/// Maximum number of techniques on one path; once reached, no further
/// actions are legal.
pub max_depth: usize,
// Grammar mutations computed once, at construction, from the original body.
grammar_mutations: Vec<grammar::GrammarMutation>,
// Set once a grammar mutation has been applied on this path.
grammar_applied: bool,
// Set once a content-type switch has been applied on this path.
content_type_applied: bool,
// Set once a header trick has been applied on this path.
header_applied: bool,
/// SQL dialect used when checking SQL payloads during evaluation.
pub sql_dialect: wafrift_oracle::sql::DatabaseDialect,
}
impl WafRiftEnv {
    /// Builds an environment for `req` using the generic SQL dialect.
    ///
    /// Shorthand for [`WafRiftEnv::with_dialect`] with
    /// `DatabaseDialect::Generic`.
    #[must_use]
    pub fn new(req: Request, max_depth: usize) -> Self {
        let default_dialect = wafrift_oracle::sql::DatabaseDialect::Generic;
        Self::with_dialect(req, max_depth, default_dialect)
    }

    /// Builds an environment for `req` that checks SQL payloads against
    /// `sql_dialect`.
    ///
    /// Grammar mutations are computed here, once. When the request has a text
    /// body, the values of its `key=value` pairs are joined with spaces and
    /// handed to `grammar::mutate`; if that yields an empty string, the raw
    /// body is used instead. A missing, non-text or non-UTF-8 body produces no
    /// grammar mutations (the non-UTF-8 case is logged as a warning).
    #[must_use]
    pub fn with_dialect(
        req: Request,
        max_depth: usize,
        sql_dialect: wafrift_oracle::sql::DatabaseDialect,
    ) -> Self {
        let grammar_mutations = match req.body.as_ref() {
            Some(raw) if crate::strategy::is_text_payload(&req) => match std::str::from_utf8(raw) {
                Ok(text) => {
                    // Only the value side of each form pair is mutation material.
                    let values: Vec<&str> = text
                        .split('&')
                        .filter_map(|pair| pair.split_once('='))
                        .map(|(_, value)| value)
                        .collect();
                    let joined = values.join(" ");
                    // No `key=value` pairs at all: treat the whole body as the payload.
                    let seed = if joined.is_empty() {
                        text
                    } else {
                        joined.as_str()
                    };
                    grammar::mutate(seed, 8)
                }
                Err(e) => {
                    tracing::warn!(error = %e, "MCTS bridge skipped non-UTF-8 body");
                    Vec::new()
                }
            },
            _ => Vec::new(),
        };

        Self {
            req,
            applied_techniques: Vec::new(),
            max_depth,
            grammar_mutations,
            grammar_applied: false,
            content_type_applied: false,
            header_applied: false,
            sql_dialect,
        }
    }
}
/// Exposes the evasion search space to the MCTS engine.
///
/// A state is the request plus the list of techniques applied so far. Encoding
/// actions may repeat (though never the same encoding twice in a row), while
/// grammar mutation, content-type switching and header tricks are each offered
/// at most once per path.
impl Environment for WafRiftEnv {
    type Action = TechniqueAction;

    /// Lists every action that may be taken from the current state.
    ///
    /// Returns an empty list once `max_depth` techniques have been applied.
    fn legal_actions(&self) -> Vec<Self::Action> {
        // Names must match the `Debug` rendering of the variant technique,
        // which is how `apply` resolves them.
        const CONTENT_TYPE_SWITCHES: &[&str] = &[
            "Multipart",
            "JsonUnicodeEscape",
            "XmlCdata",
            "MultipartQuotedBoundary",
        ];
        const HEADER_TRICKS: &[&str] = &[
            "CaseMixing",
            "TabSeparator",
            "WhitespacePadding",
            "LineFolding",
            "LfOnlyLineFolding",
            "DuplicateHeader",
            "UnderscoreSubstitution",
            "NullByteInjection",
            "TrailingSpace",
            "MultiLineFolding",
            "LfOnlyMultiLineFolding",
            "CommaJoin",
        ];

        let mut actions = Vec::new();
        if self.applied_techniques.len() >= self.max_depth {
            return actions;
        }

        // Applying the same encoding twice in a row is never offered.
        let last_encoding = match self.applied_techniques.last() {
            Some(Technique::PayloadEncoding(name)) => Some(name.as_str()),
            _ => None,
        };
        for strat in encoding::all_strategies() {
            let tech_name = strat.as_str().to_string();
            if last_encoding != Some(tech_name.as_str()) {
                actions.push(TechniqueAction::Encode(tech_name));
            }
        }

        if !self.grammar_applied {
            for mutation in &self.grammar_mutations {
                let action = TechniqueAction::GrammarMutate(grammar_action_label(mutation).to_string());
                // `apply` resolves a label to the first mutation carrying it, so
                // a second action with the same label would be an exact duplicate.
                if !actions.contains(&action) {
                    actions.push(action);
                }
            }
        }

        if !self.content_type_applied
            && let Some(ref body) = self.req.body
            && crate::strategy::is_text_payload(&self.req)
            && !wafrift_content_type::parse_form_body(body).is_empty()
        {
            actions.extend(
                CONTENT_TYPE_SWITCHES
                    .iter()
                    .map(|name| TechniqueAction::ContentTypeSwitch((*name).to_string())),
            );
        }

        if !self.header_applied {
            actions.extend(
                HEADER_TRICKS
                    .iter()
                    .map(|trick| TechniqueAction::HeaderTrick((*trick).to_string())),
            );
        }

        actions
    }

    /// Applies `action` to the request and records it in `applied_techniques`.
    ///
    /// The technique is always recorded, and the once-per-path flag for its
    /// family is always set, even when the request cannot actually be
    /// transformed (no body, non-text payload, non-UTF-8 body, unknown name).
    /// Transformations are best-effort and never fail.
    fn apply(&mut self, action: &Self::Action) {
        self.applied_techniques.push(action.to_technique());
        match action {
            TechniqueAction::Encode(encoding_name) => {
                // Best-effort: an unknown strategy or a failed encode leaves the body as is.
                if let Some(strategy) = encoding::all_strategies()
                    .iter()
                    .copied()
                    .find(|s| s.as_str() == *encoding_name)
                    && let Some(ref body) = self.req.body
                    && crate::strategy::is_text_payload(&self.req)
                    && let Ok(encoded) = encoding::encode(body.as_slice(), strategy)
                {
                    self.req.body = Some(encoded.into_bytes());
                }
            }
            TechniqueAction::GrammarMutate(rule_name) => {
                // Set the flag up front so that no bail-out path below can leave
                // grammar actions legal after one has been recorded.
                self.grammar_applied = true;
                if let Some(mutation) = self
                    .grammar_mutations
                    .iter()
                    .find(|m| grammar_action_label(m) == rule_name.as_str())
                    && let Some(ref body) = self.req.body
                    && crate::strategy::is_text_payload(&self.req)
                    && let Ok(body_str) = std::str::from_utf8(body)
                    && let Some(new_body) = replace_first_form_value(body_str, &mutation.payload)
                {
                    self.req.body = Some(new_body.into_bytes());
                }
            }
            TechniqueAction::ContentTypeSwitch(technique_name) => {
                self.content_type_applied = true;
                if let Some(ref body) = self.req.body
                    && crate::strategy::is_text_payload(&self.req)
                {
                    let params = wafrift_content_type::parse_form_body(body);
                    if !params.is_empty() {
                        let variants = wafrift_content_type::generate_variants(&params);
                        if let Some(variant) = variants
                            .into_iter()
                            .find(|v| format!("{:?}", v.technique) == *technique_name)
                        {
                            // Replace any existing Content-Type header rather than duplicating it.
                            self.req
                                .headers
                                .retain(|(k, _)| !k.eq_ignore_ascii_case("content-type"));
                            self.req
                                .headers
                                .push(("Content-Type".into(), variant.content_type));
                            self.req.body = Some(variant.body);
                        }
                    }
                }
            }
            // NOTE(review): the trick name is ignored; every header trick
            // currently performs the same case-mixing of the `Content-Type`
            // header name. Confirm whether per-trick behavior is intended.
            TechniqueAction::HeaderTrick(_) => {
                self.header_applied = true;
                if let Some((name, _)) = self
                    .req
                    .headers
                    .iter_mut()
                    .find(|(k, _)| k.eq_ignore_ascii_case("content-type"))
                {
                    *name = wafrift_encoding::header::case_mix("Content-Type");
                }
            }
        }
    }

    /// Scores the current state.
    ///
    /// Once at least one technique has been applied to a text body, each
    /// URL-decoded form value is classified and checked with the oracle for
    /// its payload type. Any value that fails its check, or a body that is not
    /// UTF-8, yields `Outcome::Failure`. Otherwise the state is a success
    /// (reward `0.5 + 0.5 * diversity`) once `max_depth` techniques have been
    /// applied, and ongoing before that.
    fn evaluate(&self) -> Outcome {
        use wafrift_oracle::traits::PayloadOracle;

        if !self.applied_techniques.is_empty()
            && let Some(ref body) = self.req.body
            && crate::strategy::is_text_payload(&self.req)
        {
            let Ok(body_str) = std::str::from_utf8(body) else {
                return Outcome::Failure;
            };
            for (_, raw_value) in body_str.split('&').filter_map(|pair| pair.split_once('=')) {
                // Fall back to the raw value when it is not valid percent-encoding.
                let decoded = urlencoding::decode(raw_value)
                    .unwrap_or_else(|_| raw_value.into())
                    .into_owned();
                let still_valid = match grammar::classify(&decoded) {
                    grammar::PayloadType::Sql => {
                        // Only values with SQL-injection markers are held to the SQL oracle.
                        let looks_like_sql = decoded.contains('\'')
                            || decoded.contains('=')
                            || decoded.contains("--");
                        !looks_like_sql
                            || wafrift_oracle::sql::is_valid_expression_injection(
                                &decoded,
                                self.sql_dialect,
                            )
                    }
                    grammar::PayloadType::Xss => {
                        wafrift_oracle::xss::XssOracle.is_semantically_valid(&decoded, &decoded)
                    }
                    grammar::PayloadType::TemplateInjection => {
                        wafrift_oracle::ssti::SstiOracle.is_semantically_valid(&decoded, &decoded)
                    }
                    grammar::PayloadType::CommandInjection => {
                        wafrift_oracle::cmdi::CmdiOracle.is_semantically_valid(&decoded, &decoded)
                    }
                    grammar::PayloadType::PathTraversal => {
                        wafrift_oracle::path::PathOracle.is_semantically_valid(&decoded, &decoded)
                    }
                    // Other payload types have no oracle here and are not checked.
                    _ => true,
                };
                if !still_valid {
                    return Outcome::Failure;
                }
            }
        }

        if self.applied_techniques.len() >= self.max_depth {
            let diversity = technique_diversity(&self.applied_techniques);
            Outcome::Success(Reward::new(0.5 + diversity * 0.5))
        } else {
            Outcome::Ongoing
        }
    }

    /// Returns the configured depth limit.
    fn max_depth(&self) -> Option<usize> {
        Some(self.max_depth)
    }
}

/// Name under which a grammar mutation is exposed as an action: its first
/// applied rule, or `"grammar"` when it lists none.
///
/// `legal_actions` and `apply` both go through this helper so that every
/// offered label can be resolved back to a mutation.
fn grammar_action_label(mutation: &grammar::GrammarMutation) -> &str {
    mutation.rules_applied.first().copied().unwrap_or("grammar")
}

/// Replaces the value of the first `key=value` pair of a form body with
/// `payload`, keeping all later pairs untouched.
///
/// A body with neither `&` nor `=` is replaced by `payload` wholesale. Returns
/// `None` when the body has several pairs but the first one has no `=`, in
/// which case the caller leaves the body unchanged.
fn replace_first_form_value(body: &str, payload: &str) -> Option<String> {
    match body.split_once('&') {
        Some((first_pair, rest)) => first_pair
            .split_once('=')
            .map(|(key, _)| format!("{key}={payload}&{rest}")),
        None => Some(match body.split_once('=') {
            Some((key, _)) => format!("{key}={payload}"),
            None => payload.to_owned(),
        }),
    }
}
/// Fraction of the four technique families (payload encoding, grammar
/// mutation, content-type switch, header obfuscation) that appear in
/// `techniques`.
///
/// Returns a value in `0.0..=1.0` in steps of `0.25`; an empty slice gives
/// `0.0`. Technique kinds outside those four families are ignored.
fn technique_diversity(techniques: &[Technique]) -> f64 {
    // One slot per family: [encoding, grammar, content-type, header].
    let mut seen = [false; 4];
    for tech in techniques {
        let slot = match tech {
            Technique::PayloadEncoding(_) => 0,
            Technique::GrammarMutation(_) => 1,
            Technique::ContentTypeSwitch(_) => 2,
            Technique::HeaderObfuscation(_) => 3,
            _ => continue,
        };
        seen[slot] = true;
    }
    let families_used: u32 = seen.iter().map(|&hit| u32::from(hit)).sum();
    f64::from(families_used) / 4.0
}
#[cfg(test)]
mod tests {
    use super::*;
    use mctrust::{SearchConfig, TreeSearch};

    /// Form POST to `/search` whose `q` value looks like a SQL injection.
    fn sqli_search_request() -> Request {
        Request::post("http://example.com/search", b"q=admin' OR 1=1--".to_vec())
    }

    /// Form POST to `/login` whose `user` value is a quoted SQL tautology.
    fn sqli_login_request() -> Request {
        Request::post(
            "http://example.com/login",
            b"user=admin' OR '1'='1".to_vec(),
        )
    }

    fn is_encode(action: &TechniqueAction) -> bool {
        matches!(action, TechniqueAction::Encode(_))
    }

    fn is_grammar(action: &TechniqueAction) -> bool {
        matches!(action, TechniqueAction::GrammarMutate(_))
    }

    fn is_content_type(action: &TechniqueAction) -> bool {
        matches!(action, TechniqueAction::ContentTypeSwitch(_))
    }

    fn is_header(action: &TechniqueAction) -> bool {
        matches!(action, TechniqueAction::HeaderTrick(_))
    }

    /// True when `actual` and `expected` differ by less than `f64::EPSILON`.
    fn approx_eq(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < f64::EPSILON
    }

    #[test]
    fn mcts_bridge_finds_encoding_technique() {
        let env = WafRiftEnv::new(sqli_login_request(), 2);
        let config = SearchConfig::builder()
            .iterations(100)
            .exploration_constant(1.41)
            .max_depth(2)
            .build();
        let mut engine = TreeSearch::new(env, config);
        let optimal_action = engine.run();
        assert!(
            optimal_action.is_some(),
            "MCTS should discover at least one valid progression"
        );
    }

    #[test]
    fn action_space_includes_multiple_dimensions() {
        let env = WafRiftEnv::new(sqli_search_request(), 3);
        let actions = env.legal_actions();

        assert!(
            actions.iter().any(is_encode),
            "action space must include encoding"
        );
        assert!(
            actions.iter().any(is_grammar),
            "action space must include grammar mutations"
        );
        assert!(
            actions.iter().any(is_content_type),
            "action space must include content-type switching"
        );
        assert!(
            actions.iter().any(is_header),
            "action space must include header tricks"
        );
    }

    #[test]
    fn grammar_only_applied_once() {
        let mut env = WafRiftEnv::new(sqli_search_request(), 4);
        let grammar_action = env
            .legal_actions()
            .into_iter()
            .find(is_grammar)
            .expect("expected at least one grammar action for SQL-like payload");
        env.apply(&grammar_action);

        let still_offered = env.legal_actions().iter().any(is_grammar);
        assert!(
            !still_offered,
            "grammar mutations should only be available once per path"
        );
    }

    #[test]
    fn content_type_only_applied_once() {
        let req = Request::post("http://example.com/search", b"q=test".to_vec());
        let mut env = WafRiftEnv::new(req, 4);
        let ct_action = env
            .legal_actions()
            .into_iter()
            .find(is_content_type)
            .expect("expected at least one content-type action for form body");
        env.apply(&ct_action);

        let still_offered = env.legal_actions().iter().any(is_content_type);
        assert!(
            !still_offered,
            "content-type switch should only be available once per path"
        );
    }

    #[test]
    fn technique_diversity_scoring() {
        assert!(approx_eq(technique_diversity(&[]), 0.0));

        let single = [Technique::PayloadEncoding("test".into())];
        assert!(approx_eq(technique_diversity(&single), 0.25));

        let dual = [
            Technique::PayloadEncoding("test".into()),
            Technique::GrammarMutation("sql".into()),
        ];
        assert!(approx_eq(technique_diversity(&dual), 0.5));

        let triple = [
            Technique::PayloadEncoding("test".into()),
            Technique::GrammarMutation("sql".into()),
            Technique::ContentTypeSwitch("json".into()),
        ];
        assert!(approx_eq(technique_diversity(&triple), 0.75));
    }

    #[test]
    fn multi_step_mcts_explores_combinations() {
        let env = WafRiftEnv::new(sqli_login_request(), 3);
        let config = SearchConfig::builder()
            .iterations(200)
            .exploration_constant(1.41)
            .max_depth(3)
            .build();
        let mut engine = TreeSearch::new(env, config);
        let result = engine.run();
        assert!(result.is_some(), "MCTS should find a multi-step path");
    }

    #[test]
    fn to_technique_conversion() {
        let encode = TechniqueAction::Encode("UrlEncode".into());
        assert!(matches!(
            encode.to_technique(),
            Technique::PayloadEncoding(_)
        ));

        let grammar = TechniqueAction::GrammarMutate("tautology_swap".into());
        assert!(matches!(
            grammar.to_technique(),
            Technique::GrammarMutation(_)
        ));

        let ct = TechniqueAction::ContentTypeSwitch("Multipart".into());
        assert!(matches!(ct.to_technique(), Technique::ContentTypeSwitch(_)));

        let header = TechniqueAction::HeaderTrick("CaseMixing".into());
        assert!(matches!(
            header.to_technique(),
            Technique::HeaderObfuscation(_)
        ));
    }

    /// Smoke test: a search over an XSS-shaped payload must run to completion
    /// without panicking. The result itself is not inspected.
    #[test]
    fn xss_payload_uses_xss_oracle() {
        let req = Request::post(
            "http://example.com/comment",
            b"msg=<script>alert(1)</script>".to_vec(),
        );
        let env = WafRiftEnv::new(req, 2);
        let config = SearchConfig::builder()
            .iterations(100)
            .exploration_constant(1.41)
            .max_depth(2)
            .build();
        let mut engine = TreeSearch::new(env, config);
        let _ = engine.run();
    }
}