use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use tracing::{debug, info, warn};
use crate::crawler::{DiscoveredForm, FormInput};
use crate::headless_crawler::{CsrfTokenInfo, FormSubmissionResult, HeadlessCrawler};
/// A single recorded HTML form submission: where it goes, how it is sent,
/// its fields, and the request context (headers/cookies) captured with it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FormSubmission {
/// Unique identifier; `FormSubmission::new` assigns a random UUID v4 string.
pub id: String,
/// URL the form submits to.
pub action_url: String,
/// HTTP method as recorded (e.g. "POST"); `signature` compares it upper-cased.
pub method: String,
/// Form fields, in the order they were added.
pub fields: Vec<FormField>,
/// Request headers captured with the submission (empty after `new`).
pub headers: HashMap<String, String>,
/// Cookies captured with the submission (empty after `new`).
pub cookies: HashMap<String, String>,
/// URL of the page the form was found on.
pub source_url: String,
/// RFC 3339 UTC timestamp taken when `new` constructed the value.
pub recorded_at: String,
/// URL reached after submitting, if known (`None` after `new`).
pub response_url: Option<String>,
/// HTTP status of the submission response, if known (`None` after `new`).
pub response_status: Option<u16>,
/// Position inside a `FormSequence`; assigned by `FormSequence::add_submission`.
pub sequence_index: Option<usize>,
/// Whether this is a step of a wizard sequence; assigned by `FormSequence::add_submission`.
pub is_wizard_step: bool,
}
/// One input of a recorded form, with heuristic flags describing whether it
/// carries a server-generated token and whether it may be replaced with payloads.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FormField {
/// The input's `name` attribute.
pub name: String,
/// Input type as a string (e.g. "text", "hidden", "email", "password").
pub field_type: String,
/// Recorded value, `None` when the input had no value.
pub value: Option<String>,
/// Presumably the selectable option values for select-like inputs — TODO confirm against the crawler.
pub options: Option<Vec<String>>,
/// Whether the input was marked as required.
pub required: bool,
/// Whether the field looks like a dynamic token (detected from its name, or from its value at record time).
pub is_dynamic_token: bool,
/// Token classification when `is_dynamic_token` is set.
pub token_type: Option<TokenType>,
/// Whether replay may substitute payloads into this field (`FormReplayer::apply_modifications` skips non-injectable fields).
pub is_injectable: bool,
/// Input-format label such as "email", "phone", "url", "postal_code" or "credit_card".
pub validation_pattern: Option<String>,
}
/// Classification of a dynamic token field, derived from the field name by
/// `FormField::classify_token`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum TokenType {
/// Anti-forgery token (names containing csrf/xsrf/authenticity/verification/antiforgery).
Csrf,
/// Name contains "nonce".
Nonce,
/// Timestamp-like name.
Timestamp,
/// Name contains "session".
Session,
/// Name contains "captcha".
Captcha,
/// Token detected without a more specific class (also used for value-based detection in `FormRecorder::record_submission`).
Unknown,
}
impl FormSubmission {
    /// Creates a submission with no fields, headers or cookies.
    ///
    /// The `id` is a fresh UUID v4 and `recorded_at` is the current UTC time in
    /// RFC 3339 format; response data and sequence placement start unset.
    pub fn new(action_url: String, method: String, source_url: String) -> Self {
        let id = uuid::Uuid::new_v4().to_string();
        Self {
            id,
            action_url,
            method,
            fields: Vec::new(),
            headers: HashMap::new(),
            cookies: HashMap::new(),
            source_url,
            recorded_at: chrono::Utc::now().to_rfc3339(),
            response_url: None,
            response_status: None,
            sequence_index: None,
            is_wizard_step: false,
        }
    }

    /// Builds a submission from a crawler-discovered form.
    ///
    /// The form's `discovered_at` value becomes `source_url`, and each input is
    /// converted with [`FormField::from_form_input`].
    pub fn from_discovered_form(form: &DiscoveredForm) -> Self {
        let mut submission = Self::new(
            form.action.clone(),
            form.method.clone(),
            form.discovered_at.clone(),
        );
        submission
            .fields
            .extend(form.inputs.iter().map(FormField::from_form_input));
        submission
    }

    /// Appends `field` to the end of the field list.
    pub fn add_field(&mut self, field: FormField) {
        self.fields.push(field);
    }

    /// Borrows every field flagged as injectable.
    pub fn get_injectable_fields(&self) -> Vec<&FormField> {
        self.fields.iter().filter(|field| field.is_injectable).collect()
    }

    /// Borrows every field flagged as a dynamic token.
    pub fn get_token_fields(&self) -> Vec<&FormField> {
        self.fields.iter().filter(|field| field.is_dynamic_token).collect()
    }

    /// Returns a copy whose first field named `field_name` has value `new_value`.
    ///
    /// The copy is otherwise identical (including `id`); if no field matches,
    /// it is an unmodified clone.
    pub fn with_modified_field(&self, field_name: &str, new_value: &str) -> Self {
        let mut copy = self.clone();
        if let Some(target) = copy.fields.iter_mut().find(|f| f.name == field_name) {
            target.value = Some(new_value.to_owned());
        }
        copy
    }

    /// Returns `(name, value)` pairs for every field that has a value, in field order.
    pub fn to_form_data(&self) -> Vec<(String, String)> {
        let mut pairs = Vec::new();
        for field in &self.fields {
            if let Some(value) = &field.value {
                pairs.push((field.name.clone(), value.clone()));
            }
        }
        pairs
    }

    /// Structural fingerprint used for de-duplication.
    ///
    /// Hashes the action URL, the upper-cased method and the sorted field
    /// names; field values are ignored. Uses `DefaultHasher`, so the result is
    /// only meaningful within one build/process.
    pub fn signature(&self) -> u64 {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};
        let mut field_names: Vec<&str> = self.fields.iter().map(|f| f.name.as_str()).collect();
        field_names.sort_unstable();
        let mut hasher = DefaultHasher::new();
        self.action_url.hash(&mut hasher);
        self.method.to_uppercase().hash(&mut hasher);
        field_names.iter().for_each(|name| name.hash(&mut hasher));
        hasher.finish()
    }
}
impl FormField {
    /// Creates a field with no value, inferring token/injectability flags.
    ///
    /// A field whose name looks like a token is classified via
    /// [`FormField::classify_token`]. Tokens and every `hidden` field are marked
    /// non-injectable.
    pub fn new(name: String, field_type: String) -> Self {
        let is_dynamic_token = Self::detect_token_field(&name, &field_type);
        let token_type = is_dynamic_token.then(|| Self::classify_token(&name));
        let is_injectable = !is_dynamic_token && field_type != "hidden";
        Self {
            name,
            field_type,
            value: None,
            options: None,
            required: false,
            is_dynamic_token,
            token_type,
            is_injectable,
            validation_pattern: None,
        }
    }

    /// Converts a crawler [`FormInput`] into a field.
    ///
    /// Unlike [`FormField::new`], hidden inputs stay injectable unless they are
    /// recognised framework/system fields (see `is_system_field`).
    pub fn from_form_input(input: &FormInput) -> Self {
        let is_dynamic_token = Self::detect_token_field(&input.name, &input.input_type);
        let token_type = is_dynamic_token.then(|| Self::classify_token(&input.name));
        let is_injectable = !is_dynamic_token
            && (input.input_type != "hidden" || !Self::is_system_field(&input.name));
        Self {
            name: input.name.clone(),
            field_type: input.input_type.clone(),
            value: input.value.clone(),
            options: input.options.clone(),
            required: input.required,
            is_dynamic_token,
            token_type,
            is_injectable,
            validation_pattern: Self::detect_validation_pattern(&input.name, &input.input_type),
        }
    }

    /// Heuristically decides whether a field carries a dynamic token, based on
    /// its name (case-insensitive) and type.
    fn detect_token_field(name: &str, field_type: &str) -> bool {
        // Substring patterns. "ts" is intentionally absent: as a substring it
        // matches ordinary names like "contacts", "comments" or "settings", which
        // would wrongly exclude them from injection. It is matched as a whole
        // word segment instead (see `has_ts_segment`).
        const TOKEN_PATTERNS: &[&str] = &[
            "csrf",
            "_csrf",
            "xsrf",
            "_xsrf",
            "_token",
            "authenticity_token",
            "verification_token",
            "requestverificationtoken",
            "__requestverificationtoken",
            "csrfmiddlewaretoken",
            "anti-forgery",
            "antiforgery",
            "form_token",
            "formtoken",
            "security_token",
            "nonce",
            "_nonce",
            "__nonce",
            "timestamp",
            "_timestamp",
            "captcha",
            "recaptcha",
            "hcaptcha",
        ];
        let name_lower = name.to_lowercase();
        if TOKEN_PATTERNS.iter().any(|pattern| name_lower.contains(pattern))
            || Self::has_ts_segment(name)
        {
            return true;
        }
        // Hidden inputs whose name ends in "token" (this also covers "_token").
        field_type == "hidden" && name_lower.ends_with("token")
    }

    /// Maps a token field name to a [`TokenType`]; first matching rule wins.
    fn classify_token(name: &str) -> TokenType {
        let name_lower = name.to_lowercase();
        let has = |needle: &str| name_lower.contains(needle);
        if has("csrf")
            || has("xsrf")
            || has("authenticity")
            || has("verification")
            || has("antiforgery")
        {
            TokenType::Csrf
        } else if has("nonce") {
            TokenType::Nonce
        } else if has("timestamp") || Self::has_ts_segment(name) {
            TokenType::Timestamp
        } else if has("session") {
            TokenType::Session
        } else if has("captcha") {
            TokenType::Captcha
        } else {
            TokenType::Unknown
        }
    }

    /// Splits `name` into lowercase word segments, breaking on any
    /// non-alphanumeric character and on lower-to-upper camelCase transitions
    /// (e.g. `requestTs` -> `["request", "ts"]`).
    fn name_segments(name: &str) -> Vec<String> {
        let mut segments = Vec::new();
        let mut current = String::new();
        let mut prev_is_lower_or_digit = false;
        for c in name.chars() {
            if !c.is_alphanumeric() {
                if !current.is_empty() {
                    segments.push(std::mem::take(&mut current));
                }
                prev_is_lower_or_digit = false;
                continue;
            }
            if c.is_uppercase() && prev_is_lower_or_digit && !current.is_empty() {
                segments.push(std::mem::take(&mut current));
            }
            prev_is_lower_or_digit = c.is_lowercase() || c.is_numeric();
            current.extend(c.to_lowercase());
        }
        if !current.is_empty() {
            segments.push(current);
        }
        segments
    }

    /// Whether `name` contains `ts` as a standalone word
    /// (`ts`, `_ts`, `form-ts`, `requestTs`), but not inside e.g. `contacts`.
    fn has_ts_segment(name: &str) -> bool {
        Self::name_segments(name)
            .iter()
            .any(|segment| segment.as_str() == "ts")
    }

    /// Recognises framework plumbing fields (`__*`, ViewState, EventValidation,
    /// Rails' `utf8`) that should not be treated as injectable.
    fn is_system_field(name: &str) -> bool {
        let name_lower = name.to_lowercase();
        name_lower.starts_with("__")
            || name_lower.contains("viewstate")
            || name_lower.contains("eventvalidation")
            || name_lower == "utf8"
    }

    /// Derives an input-format label from the field type or name, if any.
    fn detect_validation_pattern(name: &str, field_type: &str) -> Option<String> {
        let name_lower = name.to_lowercase();
        if field_type == "email" || name_lower.contains("email") {
            Some("email".to_string())
        } else if field_type == "tel"
            || name_lower.contains("phone")
            || name_lower.contains("mobile")
        {
            Some("phone".to_string())
        } else if field_type == "url"
            || name_lower.contains("url")
            || name_lower.contains("website")
        {
            Some("url".to_string())
        } else if name_lower.contains("zip") || name_lower.contains("postal") {
            Some("postal_code".to_string())
        } else if name_lower.contains("credit") || name_lower.contains("card_number") {
            Some("credit_card".to_string())
        } else {
            None
        }
    }
}
/// An ordered group of form submissions that are replayed together (a
/// multi-step wizard flow, or a single standalone form).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FormSequence {
/// Unique identifier; `FormSequence::new` assigns a random UUID v4 string.
pub id: String,
/// Human-readable label (e.g. "Wizard Flow #1" or "Form: <action_url>" from `FormRecorder`).
pub name: String,
/// Steps in replay order.
pub submissions: Vec<FormSubmission>,
/// URL where the flow begins.
pub start_url: String,
/// Final response URL of the last step, when `FormRecorder` knows it.
pub end_url: Option<String>,
/// Whether the sequence is a multi-step wizard.
pub is_wizard: bool,
/// Heuristic flow classification (see `detect_flow_type`), or `None` if not yet classified.
pub flow_type: Option<FlowType>,
/// Number of steps; kept equal to `submissions.len()` by `add_submission`.
pub step_count: usize,
/// Session state snapshot associated with the start of the flow.
pub initial_state: SequenceState,
/// RFC 3339 UTC timestamp taken when `new` constructed the value.
pub recorded_at: String,
}
/// Kind of user flow a sequence represents, as guessed from URL and field-name
/// keywords by `FormSequence::detect_flow_type`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum FlowType {
Checkout,
Registration,
Login,
PasswordReset,
ProfileUpdate,
Survey,
/// Multi-step wizard without a more specific keyword match.
Wizard,
Unknown,
}
/// Browser/session state captured alongside a sequence so a replay can start
/// from the same context.
/// NOTE(review): `FormReplayer::restore_state` in this file only logs this state; nothing is applied yet.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct SequenceState {
/// Cookie name -> value.
pub cookies: HashMap<String, String>,
/// localStorage key -> value.
pub local_storage: HashMap<String, String>,
/// sessionStorage key -> value.
pub session_storage: HashMap<String, String>,
/// Authentication token, if one was set via `FormRecorder::set_auth_token`.
pub auth_token: Option<String>,
/// Free-form additional state.
pub custom: HashMap<String, String>,
}
impl FormSequence {
    /// Creates an empty, non-wizard sequence with a fresh UUID and the current
    /// UTC time (RFC 3339) as `recorded_at`.
    pub fn new(name: String, start_url: String) -> Self {
        let id = uuid::Uuid::new_v4().to_string();
        Self {
            id,
            name,
            submissions: Vec::new(),
            start_url,
            end_url: None,
            is_wizard: false,
            flow_type: None,
            step_count: 0,
            initial_state: SequenceState::default(),
            recorded_at: chrono::Utc::now().to_rfc3339(),
        }
    }

    /// Appends a step, stamping its position and wizard flag, and refreshes `step_count`.
    ///
    /// `is_wizard_step` copies the sequence's `is_wizard` at the time of the call.
    pub fn add_submission(&mut self, mut submission: FormSubmission) {
        let position = self.submissions.len();
        submission.sequence_index = Some(position);
        submission.is_wizard_step = self.is_wizard;
        self.submissions.push(submission);
        self.step_count = position + 1;
    }

    /// Lists every injectable field across all steps, paired with its step, in order.
    pub fn get_all_injectable_fields(&self) -> Vec<(&FormSubmission, &FormField)> {
        let mut targets = Vec::new();
        for submission in &self.submissions {
            for field in submission.fields.iter().filter(|field| field.is_injectable) {
                targets.push((submission, field));
            }
        }
        targets
    }

    /// Classifies the flow from keywords in the step action URLs, the start URL
    /// and the field names; always sets `flow_type` to `Some(..)`.
    ///
    /// Rules are checked in priority order; the first match wins.
    pub fn detect_flow_type(&mut self) {
        fn contains_any(haystack: &str, needles: &[&str]) -> bool {
            needles.iter().any(|needle| haystack.contains(needle))
        }

        let url_text = self
            .submissions
            .iter()
            .map(|s| s.action_url.as_str())
            .chain(std::iter::once(self.start_url.as_str()))
            .collect::<Vec<&str>>()
            .join(" ")
            .to_lowercase();
        let field_text = self
            .submissions
            .iter()
            .flat_map(|s| s.fields.iter())
            .map(|f| f.name.as_str())
            .collect::<Vec<&str>>()
            .join(" ")
            .to_lowercase();

        let detected = if contains_any(&url_text, &["checkout", "cart", "payment", "order"])
            || contains_any(&field_text, &["credit", "shipping"])
        {
            FlowType::Checkout
        } else if contains_any(&url_text, &["register", "signup", "create-account"])
            || (field_text.contains("password") && field_text.contains("confirm"))
        {
            FlowType::Registration
        } else if contains_any(&url_text, &["login", "signin", "authenticate"]) {
            FlowType::Login
        } else if url_text.contains("password") && contains_any(&url_text, &["reset", "forgot"]) {
            FlowType::PasswordReset
        } else if contains_any(&url_text, &["profile", "settings", "account"]) {
            FlowType::ProfileUpdate
        } else if contains_any(&url_text, &["survey", "questionnaire", "quiz"]) {
            FlowType::Survey
        } else if self.submissions.len() > 1 && self.is_wizard {
            FlowType::Wizard
        } else {
            FlowType::Unknown
        };
        self.flow_type = Some(detected);
    }

    /// Whether any step contains a field classified as a CSRF token.
    pub fn requires_csrf_refresh(&self) -> bool {
        self.submissions
            .iter()
            .flat_map(|submission| submission.fields.iter())
            .any(|field| matches!(field.token_type, Some(TokenType::Csrf)))
    }
}
/// Tuning knobs for [`FormRecorder`].
/// NOTE(review): `max_submissions_per_form`, `record_hidden_fields` and
/// `capture_headers` are not read anywhere in this file — confirm whether
/// another module honours them or they are still TODO.
#[derive(Debug, Clone)]
pub struct FormRecorderConfig {
pub max_submissions_per_form: usize,
/// Hard cap on the number of submissions `record_submission` accepts.
pub max_total_submissions: usize,
/// Whether recorded submissions feed multi-step (wizard) flow detection.
pub detect_wizards: bool,
/// Maximum gap, in seconds, between two steps of the same flow.
pub flow_detection_window_secs: u64,
pub record_hidden_fields: bool,
pub capture_headers: bool,
}
impl Default for FormRecorderConfig {
    /// Defaults: at most 100 submissions overall (10 per form), wizard
    /// detection on with a 300-second window between steps, and
    /// `record_hidden_fields` / `capture_headers` enabled.
    fn default() -> Self {
        let flow_detection_window_secs = 300;
        Self {
            detect_wizards: true,
            flow_detection_window_secs,
            max_total_submissions: 100,
            max_submissions_per_form: 10,
            record_hidden_fields: true,
            capture_headers: true,
        }
    }
}
/// Records form submissions, de-duplicates them by structural signature and
/// groups consecutive related submissions into wizard [`FormSequence`]s.
#[derive(Debug)]
pub struct FormRecorder {
config: FormRecorderConfig,
/// Every accepted submission, in recording order.
submissions: Vec<FormSubmission>,
/// Finished sequences (wizard flows; standalone ones are added by `finalize`).
sequences: Vec<FormSequence>,
/// `FormSubmission::signature` values already accepted, for de-duplication.
seen_signatures: HashSet<u64>,
/// Multi-step flow currently being accumulated, if any.
active_flow: Option<ActiveFlowState>,
/// Session-state snapshot copied into each new sequence's `initial_state`.
current_state: SequenceState,
/// (signature, time accepted) per submission.
/// NOTE(review): only appended to in this file, never read.
submission_times: Vec<(u64, Instant)>, }
/// In-progress wizard flow tracked by [`FormRecorder`] until it is finalized.
#[derive(Debug)]
struct ActiveFlowState {
/// `source_url` of the first submission; used for the same-host check.
start_url: String,
/// Steps collected so far (clones of the recorded submissions).
submissions: Vec<FormSubmission>,
/// NOTE(review): set but never read in this file.
started_at: Instant,
/// Time of the latest step; compared against the detection window.
last_submission_at: Instant,
/// NOTE(review): generated but never consulted in this file.
expected_patterns: Vec<String>,
}
impl FormRecorder {
    /// Creates a recorder with [`FormRecorderConfig::default`].
    pub fn new() -> Self {
        Self::with_config(FormRecorderConfig::default())
    }

    /// Creates an empty recorder using `config`.
    pub fn with_config(config: FormRecorderConfig) -> Self {
        Self {
            config,
            submissions: Vec::new(),
            sequences: Vec::new(),
            seen_signatures: HashSet::new(),
            active_flow: None,
            current_state: SequenceState::default(),
            submission_times: Vec::new(),
        }
    }

    /// Records a submission; returns `true` if it was stored.
    ///
    /// Returns `false` once `max_total_submissions` is reached, or when a
    /// submission with the same [`FormSubmission::signature`] was already
    /// recorded.
    ///
    /// Fields whose *values* look like tokens (see `is_token_value`) are flagged
    /// as dynamic tokens, with type `Unknown` if not already classified. When
    /// `detect_wizards` is enabled, the submission also feeds flow detection.
    pub fn record_submission(&mut self, mut submission: FormSubmission) -> bool {
        if self.submissions.len() >= self.config.max_total_submissions {
            debug!("[FormRecorder] Max submissions reached, skipping");
            return false;
        }
        let sig = submission.signature();
        if self.seen_signatures.contains(&sig) {
            debug!(
                "[FormRecorder] Duplicate submission, skipping: {}",
                submission.action_url
            );
            return false;
        }
        for field in &mut submission.fields {
            if Self::is_token_value(&field.value) {
                field.is_dynamic_token = true;
                if field.token_type.is_none() {
                    field.token_type = Some(TokenType::Unknown);
                }
            }
        }
        if self.config.detect_wizards {
            self.update_flow_detection(&submission);
        }
        self.seen_signatures.insert(sig);
        self.submission_times.push((sig, Instant::now()));
        info!(
            "[FormRecorder] Recorded submission #{} to {}",
            self.submissions.len() + 1,
            submission.action_url
        );
        self.submissions.push(submission);
        true
    }

    /// Converts a crawler-discovered form and records it (see `record_submission`).
    pub fn record_from_discovered_form(&mut self, form: &DiscoveredForm) -> bool {
        self.record_submission(FormSubmission::from_discovered_form(form))
    }

    /// Replaces the session-state snapshot copied into subsequently created sequences.
    pub fn update_state(&mut self, state: SequenceState) {
        self.current_state = state;
    }

    /// Replaces only the cookie part of the current session state.
    pub fn update_cookies(&mut self, cookies: HashMap<String, String>) {
        self.current_state.cookies = cookies;
    }

    /// Sets (or clears) the auth token in the current session state.
    pub fn set_auth_token(&mut self, token: Option<String>) {
        self.current_state.auth_token = token;
    }

    /// Heuristically decides whether a field value looks like a server token.
    ///
    /// Matches:
    /// - long (>= 32) hex/base64url strings;
    /// - padded or 4-aligned standard base64 (>= 20 chars);
    /// - JWT-shaped values;
    /// - integers of 10+ digits (presumably epoch timestamps).
    fn is_token_value(value: &Option<String>) -> bool {
        let v = match value {
            Some(v) => v,
            None => return false,
        };
        let len = v.len();
        if len >= 32
            && v.chars()
                .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_')
        {
            return true;
        }
        if len >= 20
            && v.chars()
                .all(|c| c.is_ascii_alphanumeric() || c == '+' || c == '/' || c == '=')
            && (v.ends_with('=') || (len % 4 == 0 && len >= 24))
        {
            return true;
        }
        if Self::looks_like_jwt(v) {
            return true;
        }
        len >= 10 && v.parse::<i64>().is_ok()
    }

    /// JWT shape check: three `.`-separated base64url segments, non-empty header
    /// and payload, header starting with `eyJ` (base64url of `{"`). The signature
    /// may be empty (unsecured JWT).
    ///
    /// A bare "exactly two dots" test would also flag hostnames
    /// ("www.example.com") and version strings ("1.2.3").
    fn looks_like_jwt(v: &str) -> bool {
        let parts: Vec<&str> = v.split('.').collect();
        if parts.len() != 3 || parts[1].is_empty() || !parts[0].starts_with("eyJ") {
            return false;
        }
        parts.iter().all(|part| {
            part.chars()
                .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_')
        })
    }

    /// Extends, closes or starts the active wizard flow for a new submission.
    ///
    /// The flow continues when the previous step was within the detection
    /// window and `is_flow_continuation` accepts the submission. Otherwise
    /// the current flow is finalized, and a new one may start.
    fn update_flow_detection(&mut self, submission: &FormSubmission) {
        let now = Instant::now();
        let window = Duration::from_secs(self.config.flow_detection_window_secs);
        let should_continue = match &self.active_flow {
            Some(flow) => {
                now.duration_since(flow.last_submission_at) <= window
                    && self.is_flow_continuation(&flow.start_url, submission)
            }
            None => false,
        };
        if should_continue {
            if let Some(flow) = self.active_flow.as_mut() {
                flow.submissions.push(submission.clone());
                flow.last_submission_at = now;
                debug!(
                    "[FormRecorder] Continuing wizard flow, step {}",
                    flow.submissions.len()
                );
                return;
            }
        } else if self.active_flow.is_some() {
            self.finalize_active_flow();
        }
        if self.could_start_wizard(submission) {
            debug!(
                "[FormRecorder] Starting potential wizard flow at {}",
                submission.source_url
            );
            self.active_flow = Some(ActiveFlowState {
                start_url: submission.source_url.clone(),
                submissions: vec![submission.clone()],
                started_at: now,
                last_submission_at: now,
                expected_patterns: self.generate_expected_patterns(&submission.action_url),
            });
        }
    }

    /// Whether a submission may be the first step of a wizard.
    ///
    /// True for wizard-ish action-URL keywords, or for any form with at most 5 fields.
    fn could_start_wizard(&self, submission: &FormSubmission) -> bool {
        let url_lower = submission.action_url.to_lowercase();
        url_lower.contains("step")
            || url_lower.contains("wizard")
            || url_lower.contains("checkout")
            || url_lower.contains("register")
            || url_lower.contains("signup")
            || url_lower.contains("onboard")
            || submission.fields.len() <= 5
    }

    /// Whether `submission` continues the active flow that began at `start_url`.
    ///
    /// - Rejected when both URLs parse and the hosts differ.
    /// - Accepted on step/next/continue/proceed action URLs.
    /// - Otherwise accepted when the page it came from follows the previous
    ///   step's action or response URL.
    fn is_flow_continuation(&self, start_url: &str, submission: &FormSubmission) -> bool {
        if let (Ok(start), Ok(current)) = (
            url::Url::parse(start_url),
            url::Url::parse(&submission.source_url),
        ) {
            if start.host_str() != current.host_str() {
                return false;
            }
        }
        let url_lower = submission.action_url.to_lowercase();
        if url_lower.contains("step")
            || url_lower.contains("next")
            || url_lower.contains("continue")
            || url_lower.contains("proceed")
        {
            return true;
        }
        if let Some(last) = self.active_flow.as_ref().and_then(|f| f.submissions.last()) {
            if submission.source_url.starts_with(&last.action_url) {
                return true;
            }
            if last.response_url.as_deref() == Some(submission.source_url.as_str()) {
                return true;
            }
        }
        false
    }

    /// URL fragments that would plausibly identify the next wizard step.
    ///
    /// For a path containing `step<N>`, `step<N+1>` is added. The generic
    /// next/continue/proceed markers are always added for a parseable URL.
    fn generate_expected_patterns(&self, action_url: &str) -> Vec<String> {
        let mut patterns = Vec::new();
        if let Ok(url) = url::Url::parse(action_url) {
            let path = url.path();
            if let Some(idx) = path.find("step") {
                // "step" is 4 ASCII bytes, so this slice is on a char boundary.
                let after_step = &path[idx + 4..];
                // checked_add: "step4294967295" must not overflow (panics in debug builds).
                if let Some(next) = after_step
                    .chars()
                    .take_while(|c| c.is_ascii_digit())
                    .collect::<String>()
                    .parse::<u32>()
                    .ok()
                    .and_then(|num| num.checked_add(1))
                {
                    patterns.push(format!("step{}", next));
                }
            }
            patterns.push("next".to_string());
            patterns.push("continue".to_string());
            patterns.push("proceed".to_string());
        }
        patterns
    }

    /// Turns the active flow into a wizard [`FormSequence`] if it has 2+ steps;
    /// single-step flows are dropped. Always clears `active_flow`.
    fn finalize_active_flow(&mut self) {
        if let Some(flow) = self.active_flow.take() {
            if flow.submissions.len() >= 2 {
                let mut sequence = FormSequence::new(
                    format!("Wizard Flow #{}", self.sequences.len() + 1),
                    flow.start_url,
                );
                sequence.is_wizard = true;
                sequence.initial_state = self.current_state.clone();
                for submission in flow.submissions {
                    sequence.add_submission(submission);
                }
                if let Some(last) = sequence.submissions.last() {
                    sequence.end_url = last.response_url.clone();
                }
                sequence.detect_flow_type();
                info!(
                    "[FormRecorder] Finalized wizard sequence with {} steps: {:?}",
                    sequence.step_count, sequence.flow_type
                );
                self.sequences.push(sequence);
            }
        }
    }

    /// Consumes the recorder and returns all sequences.
    ///
    /// Output holds the wizard flows plus one single-step sequence for every
    /// recorded submission that is not already part of a wizard flow.
    pub fn finalize(mut self) -> FormRecorderResults {
        self.finalize_active_flow();
        let total_submissions = self.submissions.len();
        // Wizard sequences hold clones (same `id`) of submissions that also sit
        // in `self.submissions`. The recorded copies never get a
        // `sequence_index`, so without this set every wizard step would be
        // emitted a second time as a standalone sequence.
        let in_wizard: HashSet<String> = self
            .sequences
            .iter()
            .flat_map(|sequence| sequence.submissions.iter().map(|s| s.id.clone()))
            .collect();
        let recorded = std::mem::take(&mut self.submissions);
        for submission in recorded
            .into_iter()
            .filter(|s| s.sequence_index.is_none() && !in_wizard.contains(&s.id))
        {
            let mut sequence = FormSequence::new(
                format!("Form: {}", submission.action_url),
                submission.source_url.clone(),
            );
            sequence.initial_state = self.current_state.clone();
            sequence.add_submission(submission);
            sequence.detect_flow_type();
            self.sequences.push(sequence);
        }
        let wizard_flows_detected = self.sequences.iter().filter(|s| s.is_wizard).count();
        FormRecorderResults {
            sequences: self.sequences,
            total_submissions,
            wizard_flows_detected,
        }
    }

    /// Snapshot of recorder counters.
    pub fn stats(&self) -> FormRecorderStats {
        FormRecorderStats {
            total_submissions: self.submissions.len(),
            unique_forms: self.seen_signatures.len(),
            active_flow_steps: self
                .active_flow
                .as_ref()
                .map_or(0, |flow| flow.submissions.len()),
            sequences_detected: self.sequences.len(),
        }
    }
}
impl Default for FormRecorder {
fn default() -> Self {
Self::new()
}
}
/// Output of `FormRecorder::finalize`.
#[derive(Debug, Clone)]
pub struct FormRecorderResults {
/// Wizard flows followed by single-step sequences for standalone submissions.
pub sequences: Vec<FormSequence>,
/// Number of submissions the recorder accepted.
pub total_submissions: usize,
/// Number of `sequences` with `is_wizard` set.
pub wizard_flows_detected: usize,
}
/// Point-in-time counters returned by `FormRecorder::stats`.
#[derive(Debug, Clone)]
pub struct FormRecorderStats {
/// Submissions accepted so far.
pub total_submissions: usize,
/// Distinct submission signatures seen.
pub unique_forms: usize,
/// Steps collected in the in-progress flow (0 if none).
pub active_flow_steps: usize,
/// Sequences finalized so far.
pub sequences_detected: usize,
}
/// Settings for [`FormReplayer`].
#[derive(Debug, Clone)]
pub struct FormReplayConfig {
/// Timeout handed to `HeadlessCrawler::with_headers`; presumably seconds per request — confirm in headless_crawler.
pub submission_timeout_secs: u64,
/// Pause between consecutive steps of a sequence, in milliseconds.
pub step_delay_ms: u64,
/// Fetch a fresh CSRF token before submitting steps that have a CSRF field.
pub auto_refresh_csrf: bool,
/// Call `restore_state` with the sequence's initial state before replaying.
pub restore_state: bool,
/// NOTE(review): not read anywhere in this file.
pub max_retries: u32,
/// Auth token handed to the headless crawler.
pub auth_token: Option<String>,
/// Extra headers handed to the headless crawler.
pub custom_headers: HashMap<String, String>,
}
impl Default for FormReplayConfig {
    /// Defaults:
    /// - 30-second submission timeout and 500 ms between steps;
    /// - CSRF refresh and state restoration on;
    /// - 2 retries;
    /// - no auth token or extra headers.
    fn default() -> Self {
        let custom_headers = HashMap::new();
        Self {
            auth_token: None,
            custom_headers,
            auto_refresh_csrf: true,
            restore_state: true,
            max_retries: 2,
            step_delay_ms: 500,
            submission_timeout_secs: 30,
        }
    }
}
/// Replays recorded [`FormSequence`]s through a [`HeadlessCrawler`], optionally
/// substituting payloads into injectable fields.
pub struct FormReplayer {
config: FormReplayConfig,
/// Crawler built from `config` in `with_config`.
crawler: HeadlessCrawler,
}
/// Outcome of replaying one step of a sequence.
#[derive(Debug, Clone)]
pub struct ReplaySubmissionResult {
/// The submission actually sent (after modifications).
pub submission: FormSubmission,
/// Success flag reported by the crawler (`false` if submission errored).
pub success: bool,
/// Final URL after submitting; empty when the submission errored.
pub response_url: String,
/// NOTE(review): always `None` in this file.
pub status_code: Option<u16>,
/// Error flag reported by the crawler, or `true` when submission errored.
pub has_errors: bool,
/// Error text when submission errored; otherwise `None`.
pub response_excerpt: Option<String>,
/// Wall-clock time spent in the submit call.
pub duration_ms: u64,
/// Modifications relevant to this step.
pub modifications: Vec<FieldModification>,
}
/// A payload to substitute into every injectable field with a given name.
#[derive(Debug, Clone)]
pub struct FieldModification {
pub field_name: String,
/// Informational only; `replay_with_injection` leaves it `None`.
pub original_value: Option<String>,
pub injected_value: String,
}
/// Outcome of replaying a whole sequence.
#[derive(Debug, Clone)]
pub struct ReplaySequenceResult {
pub sequence_id: String,
/// Per-step results; stops after the first step whose submission errored.
pub submission_results: Vec<ReplaySubmissionResult>,
/// `false` if any step failed or errored.
pub sequence_success: bool,
pub total_duration_ms: u64,
/// Names of CSRF fields whose tokens were refreshed.
pub refreshed_tokens: Vec<String>,
/// Human-readable per-step failure messages.
pub errors: Vec<String>,
}
impl FormReplayer {
pub fn new() -> Self {
Self::with_config(FormReplayConfig::default())
}
pub fn with_config(config: FormReplayConfig) -> Self {
let crawler = HeadlessCrawler::with_headers(
config.submission_timeout_secs,
config.auth_token.clone(),
config.custom_headers.clone(),
);
Self { config, crawler }
}
pub async fn replay_baseline(&self, sequence: &FormSequence) -> Result<ReplaySequenceResult> {
info!(
"[FormReplayer] Starting baseline replay of sequence: {}",
sequence.name
);
self.replay_sequence_internal(sequence, vec![]).await
}
pub async fn replay_with_injection(
&self,
sequence: &FormSequence,
field_name: &str,
payload: &str,
) -> Result<ReplaySequenceResult> {
info!(
"[FormReplayer] Replaying with injection: {} = {}",
field_name, payload
);
let modification = FieldModification {
field_name: field_name.to_string(),
original_value: None, injected_value: payload.to_string(),
};
self.replay_sequence_internal(sequence, vec![modification])
.await
}
pub async fn replay_with_modifications(
&self,
sequence: &FormSequence,
modifications: Vec<FieldModification>,
) -> Result<ReplaySequenceResult> {
info!(
"[FormReplayer] Replaying with {} modifications",
modifications.len()
);
self.replay_sequence_internal(sequence, modifications).await
}
async fn replay_sequence_internal(
&self,
sequence: &FormSequence,
modifications: Vec<FieldModification>,
) -> Result<ReplaySequenceResult> {
let start_time = Instant::now();
let mut submission_results = Vec::new();
let mut refreshed_tokens = Vec::new();
let mut errors = Vec::new();
let mut sequence_success = true;
if self.config.restore_state {
self.restore_state(&sequence.initial_state).await?;
}
debug!(
"[FormReplayer] Navigating to start URL: {}",
sequence.start_url
);
for (idx, submission) in sequence.submissions.iter().enumerate() {
info!(
"[FormReplayer] Processing step {}/{}: {}",
idx + 1,
sequence.submissions.len(),
submission.action_url
);
if self.config.auto_refresh_csrf
&& submission
.fields
.iter()
.any(|f| f.token_type == Some(TokenType::Csrf))
{
if let Some(csrf) = self.refresh_csrf_for_submission(submission).await? {
refreshed_tokens.push(csrf.field_name.clone());
}
}
let modified_submission = self.apply_modifications(submission, &modifications);
let submit_start = Instant::now();
let result = self.submit_form(&modified_submission).await;
let duration_ms = submit_start.elapsed().as_millis() as u64;
match result {
Ok(submit_result) => {
let has_errors = submit_result.has_error;
submission_results.push(ReplaySubmissionResult {
submission: modified_submission.clone(),
success: submit_result.success,
response_url: submit_result.final_url,
status_code: None, has_errors,
response_excerpt: None,
duration_ms,
modifications: modifications
.iter()
.filter(|m| {
modified_submission
.fields
.iter()
.any(|f| f.name == m.field_name)
})
.cloned()
.collect(),
});
if !submit_result.success {
sequence_success = false;
errors.push(format!(
"Step {} failed: {}",
idx + 1,
submit_result.submit_status
));
}
}
Err(e) => {
sequence_success = false;
errors.push(format!("Step {} error: {}", idx + 1, e));
submission_results.push(ReplaySubmissionResult {
submission: modified_submission.clone(),
success: false,
response_url: String::new(),
status_code: None,
has_errors: true,
response_excerpt: Some(e.to_string()),
duration_ms,
modifications: vec![],
});
break;
}
}
if idx < sequence.submissions.len() - 1 {
tokio::time::sleep(Duration::from_millis(self.config.step_delay_ms)).await;
}
}
let total_duration_ms = start_time.elapsed().as_millis() as u64;
Ok(ReplaySequenceResult {
sequence_id: sequence.id.clone(),
submission_results,
sequence_success,
total_duration_ms,
refreshed_tokens,
errors,
})
}
async fn restore_state(&self, state: &SequenceState) -> Result<()> {
debug!(
"[FormReplayer] Restoring session state: {} cookies, {} localStorage items",
state.cookies.len(),
state.local_storage.len()
);
Ok(())
}
async fn refresh_csrf_for_submission(
&self,
submission: &FormSubmission,
) -> Result<Option<CsrfTokenInfo>> {
let csrf_field = submission
.fields
.iter()
.find(|f| f.token_type == Some(TokenType::Csrf));
if csrf_field.is_none() {
return Ok(None);
}
let token = self
.crawler
.extract_csrf_token(&submission.source_url)
.await?;
if let Some(ref t) = token {
debug!(
"[FormReplayer] Refreshed CSRF token: {} = {}...",
t.field_name,
&t.value[..t.value.char_indices().nth(20).map_or(t.value.len(), |(i, _)| i)]
);
}
Ok(token)
}
fn apply_modifications(
&self,
submission: &FormSubmission,
modifications: &[FieldModification],
) -> FormSubmission {
let mut modified = submission.clone();
for modification in modifications {
for field in &mut modified.fields {
if field.name == modification.field_name && field.is_injectable {
field.value = Some(modification.injected_value.clone());
}
}
}
modified
}
async fn submit_form(&self, submission: &FormSubmission) -> Result<FormSubmissionResult> {
let form_data = submission.to_form_data();
self.crawler
.submit_form_with_csrf(&submission.source_url, &submission.action_url, &form_data)
.await
}
pub fn get_injection_targets(&self, sequence: &FormSequence) -> Vec<InjectionTarget> {
sequence
.submissions
.iter()
.enumerate()
.flat_map(|(step_idx, submission)| {
submission
.fields
.iter()
.filter(|f| f.is_injectable)
.map(move |field| InjectionTarget {
sequence_id: sequence.id.clone(),
step_index: step_idx,
field_name: field.name.clone(),
field_type: field.field_type.clone(),
current_value: field.value.clone(),
validation_pattern: field.validation_pattern.clone(),
})
})
.collect()
}
}
impl Default for FormReplayer {
fn default() -> Self {
Self::new()
}
}
/// An injectable field located within a sequence, as produced by
/// `FormReplayer::get_injection_targets`.
#[derive(Debug, Clone)]
pub struct InjectionTarget {
pub sequence_id: String,
/// Zero-based index of the step containing the field.
pub step_index: usize,
pub field_name: String,
pub field_type: String,
/// Value recorded for the field, if any.
pub current_value: Option<String>,
/// Input-format label copied from the field.
pub validation_pattern: Option<String>,
}
/// Fluent builder for a [`FormSequence`].
#[derive(Debug)]
pub struct FormSequenceBuilder {
sequence: FormSequence,
}
impl FormSequenceBuilder {
    /// Starts a sequence called `name` that begins at `start_url`.
    pub fn new(name: &str, start_url: &str) -> Self {
        Self {
            sequence: FormSequence::new(name.to_string(), start_url.to_string()),
        }
    }

    /// Marks the sequence as a wizard.
    ///
    /// Only steps added *after* this call get `is_wizard_step = true`.
    pub fn wizard(mut self) -> Self {
        self.sequence.is_wizard = true;
        self
    }

    /// Sets the flow type explicitly; `build` keeps it instead of auto-detecting.
    pub fn flow_type(mut self, flow_type: FlowType) -> Self {
        self.sequence.flow_type = Some(flow_type);
        self
    }

    /// Sets the initial session state.
    pub fn with_state(mut self, state: SequenceState) -> Self {
        self.sequence.initial_state = state;
        self
    }

    /// Appends a step (see `FormSequence::add_submission`).
    pub fn add_step(mut self, submission: FormSubmission) -> Self {
        self.sequence.add_submission(submission);
        self
    }

    /// Finishes the sequence.
    ///
    /// The flow type is auto-detected with `FormSequence::detect_flow_type`
    /// only when none was set via [`FormSequenceBuilder::flow_type`].
    pub fn build(mut self) -> FormSequence {
        // detect_flow_type always assigns, so running it unconditionally would
        // silently discard an explicitly chosen flow type.
        if self.sequence.flow_type.is_none() {
            self.sequence.detect_flow_type();
        }
        self.sequence
    }
}
/// Fluent builder for a [`FormSubmission`], mainly for constructing fixtures.
#[derive(Debug)]
pub struct FormSubmissionBuilder {
submission: FormSubmission,
}
impl FormSubmissionBuilder {
pub fn new(action_url: &str, method: &str, source_url: &str) -> Self {
Self {
submission: FormSubmission::new(
action_url.to_string(),
method.to_string(),
source_url.to_string(),
),
}
}
pub fn text_field(mut self, name: &str, value: &str) -> Self {
let mut field = FormField::new(name.to_string(), "text".to_string());
field.value = Some(value.to_string());
self.submission.add_field(field);
self
}
pub fn email_field(mut self, name: &str, value: &str) -> Self {
let mut field = FormField::new(name.to_string(), "email".to_string());
field.value = Some(value.to_string());
field.validation_pattern = Some("email".to_string());
self.submission.add_field(field);
self
}
pub fn password_field(mut self, name: &str, value: &str) -> Self {
let mut field = FormField::new(name.to_string(), "password".to_string());
field.value = Some(value.to_string());
self.submission.add_field(field);
self
}
pub fn hidden_field(mut self, name: &str, value: &str) -> Self {
let mut field = FormField::new(name.to_string(), "hidden".to_string());
field.value = Some(value.to_string());
field.is_injectable = false;
self.submission.add_field(field);
self
}
pub fn csrf_field(mut self, name: &str, value: &str) -> Self {
let mut field = FormField::new(name.to_string(), "hidden".to_string());
field.value = Some(value.to_string());
field.is_dynamic_token = true;
field.token_type = Some(TokenType::Csrf);
field.is_injectable = false;
self.submission.add_field(field);
self
}
pub fn with_headers(mut self, headers: HashMap<String, String>) -> Self {
self.submission.headers = headers;
self
}
pub fn with_cookies(mut self, cookies: HashMap<String, String>) -> Self {
self.submission.cookies = cookies;
self
}
pub fn build(self) -> FormSubmission {
self.submission
}
}
/// Cloneable, thread-safe handle to a [`FormRecorder`]; all clones share one
/// recorder behind `Arc<Mutex<..>>`.
#[derive(Debug, Clone)]
pub struct SharedFormRecorder {
inner: Arc<Mutex<FormRecorder>>,
}
impl SharedFormRecorder {
    /// Creates a shared recorder with the default configuration.
    pub fn new() -> Self {
        Self {
            inner: Arc::new(Mutex::new(FormRecorder::new())),
        }
    }

    /// Creates a shared recorder using `config`.
    pub fn with_config(config: FormRecorderConfig) -> Self {
        Self {
            inner: Arc::new(Mutex::new(FormRecorder::with_config(config))),
        }
    }

    /// Records a submission (see `FormRecorder::record_submission`).
    ///
    /// Returns `false`, with a warning, if the lock is poisoned.
    pub fn record_submission(&self, submission: FormSubmission) -> bool {
        match self.inner.lock() {
            Ok(mut recorder) => recorder.record_submission(submission),
            Err(_) => {
                warn!("[SharedFormRecorder] Failed to acquire lock for recording");
                false
            }
        }
    }

    /// Records a discovered form (see `FormRecorder::record_from_discovered_form`).
    ///
    /// Returns `false`, with a warning, if the lock is poisoned.
    pub fn record_from_discovered_form(&self, form: &DiscoveredForm) -> bool {
        match self.inner.lock() {
            Ok(mut recorder) => recorder.record_from_discovered_form(form),
            Err(_) => {
                warn!("[SharedFormRecorder] Failed to acquire lock for recording");
                false
            }
        }
    }

    /// Replaces the recorder's session state.
    ///
    /// Logs a warning instead of silently dropping the update if the lock is poisoned.
    pub fn update_state(&self, state: SequenceState) {
        match self.inner.lock() {
            Ok(mut recorder) => recorder.update_state(state),
            Err(_) => warn!("[SharedFormRecorder] Failed to acquire lock for state update"),
        }
    }

    /// Current recorder counters, or `None` if the lock is poisoned.
    pub fn stats(&self) -> Option<FormRecorderStats> {
        self.inner.lock().ok().map(|recorder| recorder.stats())
    }

    /// Consumes this handle and finalizes the recorder.
    ///
    /// Returns `None`, with a warning, if other clones of the handle are still
    /// alive or the lock is poisoned.
    pub fn finalize(self) -> Option<FormRecorderResults> {
        let mutex = match Arc::try_unwrap(self.inner) {
            Ok(mutex) => mutex,
            Err(_) => {
                warn!("[SharedFormRecorder] Cannot finalize: other handles to the recorder are still alive");
                return None;
            }
        };
        match mutex.into_inner() {
            Ok(recorder) => Some(recorder.finalize()),
            Err(_) => {
                warn!("[SharedFormRecorder] Cannot finalize: recorder lock is poisoned");
                None
            }
        }
    }
}
impl Default for SharedFormRecorder {
fn default() -> Self {
Self::new()
}
}
// Unit tests for field heuristics, builders, recording, flow detection and
// modification application. None of them perform network I/O.
#[cfg(test)]
mod tests {
use super::*;
// Name-based token detection: common CSRF field names are tokens; ordinary fields are not.
#[test]
fn test_form_field_token_detection() {
assert!(FormField::detect_token_field("csrf_token", "hidden"));
assert!(FormField::detect_token_field("_csrf", "hidden"));
assert!(FormField::detect_token_field(
"authenticity_token",
"hidden"
));
assert!(FormField::detect_token_field(
"__RequestVerificationToken",
"hidden"
));
assert!(!FormField::detect_token_field("email", "text"));
assert!(!FormField::detect_token_field("username", "text"));
assert!(!FormField::detect_token_field("password", "password"));
}
// Token-name classification into TokenType variants.
#[test]
fn test_token_classification() {
assert_eq!(FormField::classify_token("csrf_token"), TokenType::Csrf);
assert_eq!(FormField::classify_token("xsrf-token"), TokenType::Csrf);
assert_eq!(FormField::classify_token("nonce"), TokenType::Nonce);
assert_eq!(FormField::classify_token("timestamp"), TokenType::Timestamp);
assert_eq!(
FormField::classify_token("random_field"),
TokenType::Unknown
);
}
// Builder output: the CSRF field is the only token and is excluded from injectable fields.
#[test]
fn test_form_submission_builder() {
let submission =
FormSubmissionBuilder::new("https://example.com/login", "POST", "https://example.com/")
.email_field("email", "test@example.com")
.password_field("password", "secret123")
.csrf_field("_csrf", "abc123")
.build();
assert_eq!(submission.fields.len(), 3);
// Expect 2 injectable fields (email, password) and 1 token field (_csrf).
assert_eq!(submission.get_injectable_fields().len(), 2); assert_eq!(submission.get_token_fields().len(), 1); }
// Expects `build` to keep the explicitly set flow type; keyword detection alone
// would classify these step URLs as `Wizard`.
#[test]
fn test_form_sequence_builder() {
let step1 = FormSubmissionBuilder::new(
"https://example.com/step1",
"POST",
"https://example.com/wizard",
)
.text_field("name", "John")
.build();
let step2 = FormSubmissionBuilder::new(
"https://example.com/step2",
"POST",
"https://example.com/step1",
)
.email_field("email", "john@example.com")
.build();
let sequence = FormSequenceBuilder::new("Registration", "https://example.com/wizard")
.wizard()
.flow_type(FlowType::Registration)
.add_step(step1)
.add_step(step2)
.build();
assert_eq!(sequence.step_count, 2);
assert!(sequence.is_wizard);
assert_eq!(sequence.flow_type, Some(FlowType::Registration));
}
// Recording the same discovered form twice is de-duplicated by signature.
#[test]
fn test_form_recorder() {
let mut recorder = FormRecorder::new();
let form = DiscoveredForm {
action: "https://example.com/submit".to_string(),
method: "POST".to_string(),
inputs: vec![FormInput {
name: "email".to_string(),
input_type: "email".to_string(),
value: Some("test@example.com".to_string()),
options: None,
required: true,
}],
discovered_at: "https://example.com/".to_string(),
};
assert!(recorder.record_from_discovered_form(&form));
assert!(!recorder.record_from_discovered_form(&form));
let stats = recorder.stats();
assert_eq!(stats.total_submissions, 1);
assert_eq!(stats.unique_forms, 1);
}
// Value-based token heuristics: long alphanumeric, JWT-shaped, and padded base64 values.
#[test]
fn test_is_token_value() {
assert!(FormRecorder::is_token_value(&Some(
"a1b2c3d4e5f6g7h8i9j0k1l2m3n4o5p6".to_string()
)));
assert!(FormRecorder::is_token_value(&Some(
"eyJhbGciOiJIUzI1NiJ9.eyJzdWIiOiIxIn0.sig".to_string()
)));
assert!(FormRecorder::is_token_value(&Some(
"YWJjZGVmZ2hpamtsbW5vcHFyc3R1dnd4eXo=".to_string()
)));
assert!(!FormRecorder::is_token_value(&Some("hello".to_string())));
assert!(!FormRecorder::is_token_value(&None));
}
// Signatures ignore field values but depend on field names.
#[test]
fn test_submission_signature() {
let sub1 = FormSubmissionBuilder::new("/submit", "POST", "/form")
.text_field("name", "value1")
.build();
let sub2 = FormSubmissionBuilder::new("/submit", "POST", "/form")
.text_field("name", "value2") .build();
let sub3 = FormSubmissionBuilder::new("/submit", "POST", "/form")
.text_field("other", "value") .build();
assert_eq!(sub1.signature(), sub2.signature());
assert_ne!(sub1.signature(), sub3.signature());
}
// Injection targets include user-editable fields and exclude the CSRF token.
#[test]
fn test_injection_target_extraction() {
let submission = FormSubmissionBuilder::new("/submit", "POST", "/form")
.text_field("username", "test")
.email_field("email", "test@example.com")
.csrf_field("_csrf", "token123")
.build();
let sequence = FormSequenceBuilder::new("Test", "/form")
.add_step(submission)
.build();
let replayer = FormReplayer::new();
let targets = replayer.get_injection_targets(&sequence);
assert_eq!(targets.len(), 2);
assert!(targets.iter().any(|t| t.field_name == "username"));
assert!(targets.iter().any(|t| t.field_name == "email"));
assert!(!targets.iter().any(|t| t.field_name == "_csrf"));
}
// Keyword-based flow classification for checkout and registration URLs/fields.
#[test]
fn test_flow_type_detection() {
let checkout_step = FormSubmissionBuilder::new(
"https://shop.example.com/checkout/payment",
"POST",
"https://shop.example.com/checkout",
)
.text_field("card_number", "4111111111111111")
.build();
let mut checkout_seq =
FormSequenceBuilder::new("Checkout", "https://shop.example.com/checkout")
.add_step(checkout_step)
.build();
checkout_seq.detect_flow_type();
assert_eq!(checkout_seq.flow_type, Some(FlowType::Checkout));
let register_step = FormSubmissionBuilder::new(
"https://example.com/register",
"POST",
"https://example.com/signup",
)
.email_field("email", "test@example.com")
.password_field("password", "secret")
.password_field("confirm_password", "secret")
.build();
let mut register_seq = FormSequenceBuilder::new("Register", "https://example.com/signup")
.add_step(register_step)
.build();
register_seq.detect_flow_type();
assert_eq!(register_seq.flow_type, Some(FlowType::Registration));
}
// Clones of a SharedFormRecorder observe the same underlying recorder.
#[test]
fn test_shared_recorder() {
let recorder = SharedFormRecorder::new();
let form = DiscoveredForm {
action: "https://example.com/submit".to_string(),
method: "POST".to_string(),
inputs: vec![FormInput {
name: "test".to_string(),
input_type: "text".to_string(),
value: None,
options: None,
required: false,
}],
discovered_at: "https://example.com/".to_string(),
};
let recorder_clone = recorder.clone();
assert!(recorder.record_from_discovered_form(&form));
let stats = recorder_clone.stats().unwrap();
assert_eq!(stats.total_submissions, 1);
}
// Modifications touch only the named injectable field; others (including the CSRF token) are unchanged.
#[test]
fn test_modification_application() {
let submission = FormSubmissionBuilder::new("/submit", "POST", "/form")
.text_field("name", "original")
.email_field("email", "original@example.com")
.csrf_field("_csrf", "token")
.build();
let replayer = FormReplayer::new();
let modifications = vec![FieldModification {
field_name: "name".to_string(),
original_value: Some("original".to_string()),
injected_value: "<script>alert(1)</script>".to_string(),
}];
let modified = replayer.apply_modifications(&submission, &modifications);
let name_field = modified.fields.iter().find(|f| f.name == "name").unwrap();
assert_eq!(
name_field.value,
Some("<script>alert(1)</script>".to_string())
);
let email_field = modified.fields.iter().find(|f| f.name == "email").unwrap();
assert_eq!(email_field.value, Some("original@example.com".to_string()));
let csrf_field = modified.fields.iter().find(|f| f.name == "_csrf").unwrap();
assert_eq!(csrf_field.value, Some("token".to_string()));
}
}