Skip to main content

attackstr/
lib.rs

1//! # attackstr
2//!
3//! Grammar-based security payload generation for the Santh ecosystem.
4//!
5//! Every security tool needs attack payloads — `SQLi`, XSS, command injection,
6//! SSTI, SSRF, XXE, and more. This crate provides a single, configurable
7//! engine that all Santh tools share. Upgrade payloads once, every tool
8//! benefits.
9//!
10//! # Architecture
11//!
12//! Payloads are defined in TOML grammar files. Each grammar specifies:
13//!
14//! - **Contexts**: injection points (string break, numeric, attribute, etc.)
15//! - **Techniques**: attack patterns with template variables
16//! - **Variables**: substitution values (tautologies, commands, etc.)
17//! - **Encodings**: transforms applied to final payloads (URL, hex, unicode, etc.)
18//!
19//! The engine computes the Cartesian product:
20//! `contexts × techniques × variable_combos × encodings`
21//!
22//! # Usage
23//!
24//! ```rust
25//! use attackstr::{PayloadDb, PayloadConfig};
26//!
27//! let mut db = PayloadDb::with_config(PayloadConfig::default());
28//! db.load_toml(r#"
29//! [grammar]
30//! name = "example"
31//! sink_category = "sql-injection"
32//!
33//! [[techniques]]
34//! name = "basic"
35//! template = "' OR 1=1 --"
36//! "#).unwrap();
37//!
38//! // Get payloads for a category
39//! let sqli = db.payloads("sql-injection");
40//! for payload in sqli {
41//!     println!("{}", payload.text);
42//! }
43//!
44//! // Get payloads with marker injection for taint tracking
45//! let marked = db.payloads_with_marker("xss", "SLN_MARKER_42");
46//! ```
47//!
48//! # Custom Encodings
49//!
50//! Register custom encoding transforms:
51//!
52//! ```rust
53//! use attackstr::PayloadDb;
54//!
55//! let mut db = PayloadDb::new();
56//! db.register_encoding("rot13", |s| {
57//!     s.chars().map(|c| match c {
58//!         'a'..='m' | 'A'..='M' => (c as u8 + 13) as char,
59//!         'n'..='z' | 'N'..='Z' => (c as u8 - 13) as char,
60//!         _ => c,
61//!     }).collect()
62//! });
63//! ```
64
65#![forbid(unsafe_code)]
66#![warn(missing_docs)]
67
68/// TOML-configurable settings.
69pub mod config;
70mod encoding;
71mod grammar;
72mod loader;
73mod mutate;
74/// Legacy payloads and custom validators imported from older suites.
75pub mod ports;
76/// Grammar validation.
77pub mod validate;
78
79pub use config::PayloadConfigFile;
80pub use encoding::{apply_encoding, BuiltinEncoding, CustomEncoder, Encoder};
81pub use grammar::{
82    Context, Encoding, Grammar, GrammarMeta, Technique, TemplateExpansionError, Variable,
83};
84pub use loader::PayloadDb;
85pub use mutate::{
86    mutate_all, mutate_case, mutate_encoding_mix, mutate_html, mutate_null_bytes,
87    mutate_sql_comments, mutate_unicode, mutate_whitespace,
88};
89use serde::{Deserialize, Serialize};
90use std::collections::BTreeMap;
91use std::hash::{Hash, Hasher};
92pub use validate::{validate, GrammarIssue, IssueLevel};
93
94/// A trait for sources that can provide payloads.
95///
96/// This trait abstracts over different payload storage and generation
97/// strategies, allowing users to swap implementations.
98///
99/// # Example
100///
101/// ```rust
102/// use attackstr::{PayloadSource, PayloadDb};
103///
104/// fn count_payloads(source: &mut dyn PayloadSource) -> usize {
105///     source.payload_count()
106/// }
107/// ```
108pub trait PayloadSource {
109    /// Get all payloads for a given category.
110    ///
111    /// The returned slice is cached on subsequent calls for the same category.
112    fn payloads(&mut self, category: &str) -> &[Payload];
113
114    /// Get all available category names.
115    fn categories(&self) -> Vec<&str>;
116
117    /// Get the total number of payloads across all categories.
118    fn payload_count(&self) -> usize;
119}
120
121/// A static payload source that holds payloads directly in memory.
122///
123/// This is useful for users who generate payloads externally and want
124/// to use them with the attackstr ecosystem.
125///
126/// # Example
127///
128/// ```rust
129/// use attackstr::{StaticPayloads, Payload, PayloadSource};
130///
131/// let payloads = vec![
132///     Payload {
133///         text: "test".into(),
134///         category: "custom".into(),
135///         technique: "manual".into(),
136///         context: "default".into(),
137///         encoding: "raw".into(),
138///         cwe: None,
139///         severity: None,
140///         confidence: 1.0,
141///         expected_pattern: None,
142///     },
143/// ];
144///
145/// let mut source = StaticPayloads::new(payloads);
146/// assert_eq!(source.payloads("custom").len(), 1);
147/// ```
148#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
149pub struct StaticPayloads {
150    payloads: Vec<Payload>,
151    #[serde(skip)]
152    category_ranges: BTreeMap<String, std::ops::Range<usize>>,
153}
154
155impl StaticPayloads {
156    /// Create a new `StaticPayloads` from a vector of payloads.
157    ///
158    /// Example:
159    /// ```rust
160    /// use attackstr::{Payload, StaticPayloads};
161    ///
162    /// let payloads = vec![Payload {
163    ///     text: "alert(1)".into(),
164    ///     category: "xss".into(),
165    ///     technique: "basic".into(),
166    ///     context: "default".into(),
167    ///     encoding: "raw".into(),
168    ///     cwe: None,
169    ///     severity: None,
170    ///     confidence: 1.0,
171    ///     expected_pattern: None,
172    /// }];
173    ///
174    /// let source = StaticPayloads::new(payloads);
175    /// assert_eq!(source.all_payloads().len(), 1);
176    /// ```
177    #[must_use]
178    pub fn new(mut payloads: Vec<Payload>) -> Self {
179        sort_payloads_by_category(&mut payloads);
180        Self {
181            category_ranges: build_category_ranges(&payloads),
182            payloads,
183        }
184    }
185
186    /// Add a single payload to this source.
187    ///
188    /// Example:
189    /// ```rust
190    /// use attackstr::{Payload, StaticPayloads};
191    ///
192    /// let mut source = StaticPayloads::default();
193    /// source.add(Payload {
194    ///     text: "test".into(),
195    ///     category: "custom".into(),
196    ///     technique: "manual".into(),
197    ///     context: "default".into(),
198    ///     encoding: "raw".into(),
199    ///     cwe: None,
200    ///     severity: None,
201    ///     confidence: 1.0,
202    ///     expected_pattern: None,
203    /// });
204    /// assert_eq!(source.all_payloads().len(), 1);
205    /// ```
206    pub fn add(&mut self, payload: Payload) {
207        self.payloads.push(payload);
208        sort_payloads_by_category(&mut self.payloads);
209        self.category_ranges = build_category_ranges(&self.payloads);
210    }
211
212    /// Get all payloads regardless of category.
213    ///
214    /// Example:
215    /// ```rust
216    /// use attackstr::StaticPayloads;
217    ///
218    /// let source = StaticPayloads::default();
219    /// assert!(source.all_payloads().is_empty());
220    /// ```
221    #[must_use]
222    pub fn all_payloads(&self) -> &[Payload] {
223        &self.payloads
224    }
225
226    /// Iterate over all payloads in this source.
227    ///
228    /// Example:
229    /// ```rust
230    /// use attackstr::StaticPayloads;
231    ///
232    /// let source = StaticPayloads::default();
233    /// assert_eq!(source.iter().count(), 0);
234    /// ```
235    pub fn iter(&self) -> impl Iterator<Item = &Payload> {
236        self.payloads.iter()
237    }
238
239    /// Iterate over payloads for a single category.
240    ///
241    /// Example:
242    /// ```rust
243    /// use attackstr::{Payload, StaticPayloads};
244    ///
245    /// let source = StaticPayloads::new(vec![Payload {
246    ///     text: "alert(1)".into(),
247    ///     category: "xss".into(),
248    ///     technique: "basic".into(),
249    ///     context: "default".into(),
250    ///     encoding: "raw".into(),
251    ///     cwe: None,
252    ///     severity: None,
253    ///     confidence: 1.0,
254    ///     expected_pattern: None,
255    /// }]);
256    /// assert_eq!(source.iter_category("xss").count(), 1);
257    /// ```
258    pub fn iter_category<'a>(
259        &'a self,
260        category: &'a str,
261    ) -> impl Iterator<Item = &'a Payload> + 'a {
262        self.payloads
263            .iter()
264            .filter(move |payload| payload.category == category)
265    }
266}
267
268impl From<Vec<Payload>> for StaticPayloads {
269    fn from(payloads: Vec<Payload>) -> Self {
270        Self::new(payloads)
271    }
272}
273
274impl PayloadSource for StaticPayloads {
275    fn payloads(&mut self, category: &str) -> &[Payload] {
276        if self.category_ranges.is_empty() && !self.payloads.is_empty() {
277            self.category_ranges = build_category_ranges(&self.payloads);
278        }
279        self.category_ranges
280            .get(category)
281            .map_or(&[], |range| &self.payloads[range.clone()])
282    }
283
284    fn categories(&self) -> Vec<&str> {
285        use std::collections::HashSet;
286        let mut seen = HashSet::new();
287        self.payloads
288            .iter()
289            .filter_map(|p| {
290                if seen.insert(p.category.clone()) {
291                    Some(p.category.as_str())
292                } else {
293                    None
294                }
295            })
296            .collect()
297    }
298
299    fn payload_count(&self) -> usize {
300        self.payloads.len()
301    }
302}
303
304fn build_category_ranges(payloads: &[Payload]) -> BTreeMap<String, std::ops::Range<usize>> {
305    let mut ranges = BTreeMap::new();
306    let mut start = 0;
307    while start < payloads.len() {
308        let category = payloads[start].category.clone();
309        let mut end = start + 1;
310        while end < payloads.len() && payloads[end].category == category {
311            end += 1;
312        }
313        ranges.insert(category, start..end);
314        start = end;
315    }
316    ranges
317}
318
319/// Configuration for payload generation behavior.
320#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
321pub struct PayloadConfig {
322    /// Maximum payloads per category before truncation (0 = unlimited).
323    pub max_per_category: usize,
324    /// Whether to deduplicate identical payloads within a category.
325    pub deduplicate: bool,
326    /// Default marker prefix for taint tracking (e.g. "SLN").
327    pub marker_prefix: String,
328    /// Categories to exclude from generation (e.g. for compliance).
329    pub exclude_categories: Vec<String>,
330    /// Categories to include exclusively (empty = all).
331    pub include_categories: Vec<String>,
332    /// Restrict loaded grammars to one or more runtimes (empty = all).
333    pub target_runtime: Option<Vec<String>>,
334    /// Where to place the taint marker in generated marker payloads.
335    pub marker_position: MarkerPosition,
336}
337
338impl PayloadConfig {
339    /// Create a builder for [`PayloadConfig`].
340    ///
341    /// Example:
342    /// ```rust
343    /// use attackstr::PayloadConfig;
344    ///
345    /// let config = PayloadConfig::builder().marker_prefix("TRACE").build();
346    /// assert_eq!(config.marker_prefix, "TRACE");
347    /// ```
348    #[must_use]
349    pub fn builder() -> PayloadConfigBuilder {
350        PayloadConfigBuilder::default()
351    }
352
353    /// Load a [`PayloadConfig`] from a TOML file.
354    ///
355    /// Example:
356    /// ```rust
357    /// use attackstr::PayloadConfig;
358    ///
359    /// let dir = tempfile::tempdir().unwrap();
360    /// let path = dir.path().join("payloads.toml");
361    /// std::fs::write(&path, "marker_prefix = \"TRACE\"\n").unwrap();
362    ///
363    /// let config = PayloadConfig::load(&path).unwrap();
364    /// assert_eq!(config.marker_prefix, "TRACE");
365    /// ```
366    ///
367    /// # Errors
368    /// Returns a [`PayloadError`] if reading or parsing the file fails.
369    pub fn load<P: AsRef<std::path::Path>>(path: P) -> Result<Self, PayloadError> {
370        Ok(PayloadConfigFile::load(path)?.into_config())
371    }
372
373    /// Parse a [`PayloadConfig`] directly from TOML text.
374    ///
375    /// Example:
376    /// ```rust
377    /// use attackstr::{MarkerPosition, PayloadConfig};
378    ///
379    /// let config = PayloadConfig::from_toml("marker_position = \"suffix\"", "<inline>").unwrap();
380    /// assert_eq!(config.marker_position, MarkerPosition::Suffix);
381    /// ```
382    ///
383    /// # Errors
384    /// Returns a [`PayloadError`] if parsing the TOML fails.
385    pub fn from_toml(toml_str: &str, source: impl Into<String>) -> Result<Self, PayloadError> {
386        Ok(PayloadConfigFile::from_toml(toml_str, source.into())?.into_config())
387    }
388}
389
390impl Default for PayloadConfig {
391    fn default() -> Self {
392        Self {
393            max_per_category: 0,
394            deduplicate: true,
395            marker_prefix: "SLN".into(),
396            exclude_categories: Vec::new(),
397            include_categories: Vec::new(),
398            target_runtime: None,
399            marker_position: MarkerPosition::Prefix,
400        }
401    }
402}
403
/// Placement strategy for marker-injected payloads.
///
/// The `Display` form is the lowercase variant name (`prefix`, `suffix`, `inline`), or
/// `replace:<value>` for [`MarkerPosition::Replace`].
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum MarkerPosition {
    /// Prepend the marker to the payload text.
    Prefix,
    /// Append the marker to the payload text.
    Suffix,
    /// Wrap the marker in braces and prepend it inline.
    Inline,
    /// Replace `{MARKER}` placeholders in the payload text.
    // NOTE(review): this file does not show what the carried `String` means (the
    // placeholder token or the replacement text). Confirm against the marker-injection code.
    Replace(String),
}
416
417impl std::fmt::Display for MarkerPosition {
418    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
419        match self {
420            Self::Prefix => f.write_str("prefix"),
421            Self::Suffix => f.write_str("suffix"),
422            Self::Inline => f.write_str("inline"),
423            Self::Replace(value) => write!(f, "replace:{value}"),
424        }
425    }
426}
427
428/// Builder for [`PayloadConfig`].
429#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
430pub struct PayloadConfigBuilder {
431    config: PayloadConfig,
432}
433
434impl PayloadConfigBuilder {
435    /// Set the maximum payload count per category.
436    #[must_use]
437    pub fn max_per_category(mut self, max_per_category: usize) -> Self {
438        self.config.max_per_category = max_per_category;
439        self
440    }
441
442    /// Set whether identical payloads should be deduplicated.
443    #[must_use]
444    pub fn deduplicate(mut self, deduplicate: bool) -> Self {
445        self.config.deduplicate = deduplicate;
446        self
447    }
448
449    /// Set the marker prefix.
450    #[must_use]
451    pub fn marker_prefix(mut self, marker_prefix: impl Into<String>) -> Self {
452        self.config.marker_prefix = marker_prefix.into();
453        self
454    }
455
456    /// Set the categories to exclude.
457    #[must_use]
458    pub fn exclude_categories(mut self, exclude_categories: Vec<String>) -> Self {
459        self.config.exclude_categories = exclude_categories;
460        self
461    }
462
463    /// Set the categories to include.
464    #[must_use]
465    pub fn include_categories(mut self, include_categories: Vec<String>) -> Self {
466        self.config.include_categories = include_categories;
467        self
468    }
469
470    /// Set the allowed target runtimes.
471    #[must_use]
472    pub fn target_runtime(mut self, target_runtime: Option<Vec<String>>) -> Self {
473        self.config.target_runtime = target_runtime;
474        self
475    }
476
477    /// Set the marker placement strategy.
478    #[must_use]
479    pub fn marker_position(mut self, marker_position: MarkerPosition) -> Self {
480        self.config.marker_position = marker_position;
481        self
482    }
483
484    /// Build the final [`PayloadConfig`].
485    ///
486    /// Example:
487    /// ```rust
488    /// use attackstr::PayloadConfig;
489    ///
490    /// let config = PayloadConfig::builder().max_per_category(10).build();
491    /// assert_eq!(config.max_per_category, 10);
492    /// ```
493    #[must_use]
494    pub fn build(self) -> PayloadConfig {
495        self.config
496    }
497}
498
499fn sort_payloads_by_category(payloads: &mut [Payload]) {
500    payloads.sort_by(|left, right| {
501        left.category
502            .cmp(&right.category)
503            .then_with(|| left.technique.cmp(&right.technique))
504            .then_with(|| left.context.cmp(&right.context))
505            .then_with(|| left.encoding.cmp(&right.encoding))
506            .then_with(|| left.text.cmp(&right.text))
507    });
508}
509
/// A generated payload with metadata about its origin.
///
/// `Eq` and `Hash` are implemented by hand rather than derived, because `confidence` is
/// an `f64`. The hash uses the float's bit pattern.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct Payload {
    /// The payload string.
    pub text: String,
    /// Which category this payload targets (e.g. "sql-injection").
    pub category: String,
    /// Which technique generated it (e.g. "union-based").
    pub technique: String,
    /// Which context it was generated in (e.g. "string-break").
    pub context: String,
    /// Which encoding was applied (e.g. "`url_encode`"); `Payload::default()` uses "raw".
    pub encoding: String,
    /// Optional CWE identifier inherited from the grammar.
    pub cwe: Option<String>,
    /// Optional severity hint inherited from the grammar.
    pub severity: Option<String>,
    /// Confidence score for this payload variant; `Payload::default()` uses `1.0`.
    pub confidence: f64,
    /// Optional regex pattern expected in the observed response.
    pub expected_pattern: Option<String>,
}
532
533impl Default for Payload {
534    fn default() -> Self {
535        Self {
536            text: String::new(),
537            category: String::new(),
538            technique: String::new(),
539            context: String::new(),
540            encoding: "raw".to_string(),
541            cwe: None,
542            severity: None,
543            confidence: 1.0,
544            expected_pattern: None,
545        }
546    }
547}
548
549impl Eq for Payload {}
550
551impl Hash for Payload {
552    fn hash<H: Hasher>(&self, state: &mut H) {
553        self.text.hash(state);
554        self.category.hash(state);
555        self.technique.hash(state);
556        self.context.hash(state);
557        self.encoding.hash(state);
558        self.cwe.hash(state);
559        self.severity.hash(state);
560        self.confidence.to_bits().hash(state);
561        self.expected_pattern.hash(state);
562    }
563}
564
/// Errors from payload operations.
///
/// Every `Display` message ends with a `Fix:` hint. For the struct-like variants, the
/// message is produced by the private `*_message` helpers on this type.
///
/// Serialized form: a two-entry map with a machine-readable `kind` tag and the rendered
/// `message`. Deserialization is lossy, because wrapped values such as TOML parse errors
/// and validation issues cannot be rebuilt from text.
#[derive(Debug, thiserror::Error)]
pub enum PayloadError {
    /// Failed to read a file.
    #[error("{0}. Fix: verify the file or directory exists and that the current process has permission to read it.")]
    Io(#[from] std::io::Error),
    /// Failed to parse TOML configuration.
    #[error("{message}", message = Self::config_parse_message(file, source))]
    ConfigParse {
        /// Which config file failed.
        file: String,
        /// The parse error.
        source: Box<toml::de::Error>,
    },
    /// Failed to parse TOML grammar.
    #[error("{message}", message = Self::grammar_parse_message(file, source))]
    GrammarParse {
        /// Which file failed.
        file: String,
        /// The parse error.
        source: Box<toml::de::Error>,
    },
    /// Grammar parsed but failed semantic validation.
    // The message summarises only the first issue; the full list stays in `issues`.
    #[error("{message}", message = Self::grammar_validation_message(file, issues))]
    GrammarValidation {
        /// Which file failed.
        file: String,
        /// Structured validation issues collected for the grammar.
        issues: Vec<GrammarIssue>,
    },
    /// Failed to expand template placeholders in a grammar.
    #[error("{message}", message = Self::template_expansion_message(file, source))]
    TemplateExpansion {
        /// Which file failed.
        file: String,
        /// The expansion error.
        source: TemplateExpansionError,
    },
    /// Path is not a directory.
    #[error("path '{0}' is not a directory. Fix: pass a directory that contains `.toml` grammar files or update `grammar_dirs` in your config.")]
    NotADirectory(String),
    /// Another directory load is already in progress for this database instance.
    #[error("payload database load is already in progress. Fix: wait for the current `load_dir` call to finish before starting another one on the same `PayloadDb`.")]
    ConcurrentLoad,
}
610
611impl Serialize for PayloadError {
612    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
613    where
614        S: serde::Serializer,
615    {
616        use serde::ser::SerializeMap;
617
618        let mut map = serializer.serialize_map(Some(2))?;
619        map.serialize_entry("kind", self.kind())?;
620        map.serialize_entry("message", &self.to_string())?;
621        map.end()
622    }
623}
624
625impl<'de> Deserialize<'de> for PayloadError {
626    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
627    where
628        D: serde::Deserializer<'de>,
629    {
630        #[derive(Deserialize)]
631        struct PayloadErrorWire {
632            message: String,
633        }
634
635        let wire = PayloadErrorWire::deserialize(deserializer)?;
636        Ok(Self::Io(std::io::Error::other(wire.message)))
637    }
638}
639
640impl PayloadError {
641    fn kind(&self) -> &'static str {
642        match self {
643            Self::Io(_) => "io",
644            Self::ConfigParse { .. } => "config_parse",
645            Self::GrammarParse { .. } => "grammar_parse",
646            Self::GrammarValidation { .. } => "grammar_validation",
647            Self::TemplateExpansion { .. } => "template_expansion",
648            Self::NotADirectory(_) => "not_a_directory",
649            Self::ConcurrentLoad => "concurrent_load",
650        }
651    }
652
653    fn config_parse_message(file: &str, source: &toml::de::Error) -> String {
654        format!(
655            "config parse error in {file}: {source}. Fix: make the file valid TOML and keep payload settings at the top level, for example `max_per_category = 100` and `grammar_dirs = [\"./grammars\"]`."
656        )
657    }
658
659    fn grammar_parse_message(file: &str, source: &toml::de::Error) -> String {
660        let detail = source.to_string();
661        let fix = if detail.contains("missing field `grammar`") {
662            "Fix: add a `[grammar]` table with at least `name` and `sink_category`."
663        } else if detail.contains("missing field `name`")
664            || detail.contains("missing field `sink_category`")
665        {
666            "Fix: every grammar needs a `[grammar]` section with both `name` and `sink_category` fields."
667        } else if detail.contains("missing field `template`") {
668            "Fix: every `[[techniques]]` entry needs a `name` and `template`."
669        } else {
670            "Fix: make the file valid TOML and include a `[grammar]` section plus at least one `[[techniques]]` entry."
671        };
672
673        format!("grammar parse error in {file}: {detail}. {fix}")
674    }
675
676    fn template_expansion_message(file: &str, source: &TemplateExpansionError) -> String {
677        let fix = match source {
678            TemplateExpansionError::UnclosedBrace { .. } => {
679                "Fix: close every `{placeholder}` with a matching `}` and escape literal braces by leaving them outside placeholder syntax."
680            }
681            TemplateExpansionError::RecursionLimitExceeded { max_depth } => {
682                return format!(
683                    "template expansion error in {file}: {source}. Fix: remove circular or self-referential variables so expansion stays below the recursion limit of {max_depth}."
684                );
685            }
686            TemplateExpansionError::PayloadLimitExceeded { limit } => {
687                return format!(
688                    "template expansion error in {file}: {source}. Fix: reduce Cartesian product size (contexts x techniques x variables) to stay below the {limit} limit."
689                );
690            }
691        };
692
693        format!("template expansion error in {file}: {source}. {fix}")
694    }
695
696    fn grammar_validation_message(file: &str, issues: &[GrammarIssue]) -> String {
697        let issue_count = issues.len();
698        let summary = issues
699            .first()
700            .map(|issue| format!("{}: {}", issue.level, issue.message))
701            .unwrap_or_else(|| "unknown validation failure".to_string());
702        format!(
703            "grammar validation error in {file}: {summary}. Fix: resolve the reported validation issue{plural} before loading the grammar.",
704            plural = if issue_count == 1 { "" } else { "s" }
705        )
706    }
707}
708
// Adversarial test cases live in their own module file, which is not shown here.
#[cfg(test)]
mod adversarial_tests;

// Unit tests for `Payload` serde support and the `PayloadConfig` construction paths:
// builder, inline TOML, and TOML file.
#[cfg(test)]
mod tests {
    use super::*;

    // A fully populated payload must survive a TOML serialize/deserialize round trip.
    #[test]
    fn payload_round_trips_with_serde() {
        let payload = Payload {
            text: "alert(1)".into(),
            category: "xss".into(),
            technique: "basic".into(),
            context: "default".into(),
            encoding: "raw".into(),
            cwe: Some("CWE-79".into()),
            severity: Some("high".into()),
            confidence: 0.9,
            expected_pattern: Some("alert".into()),
        };

        let encoded = toml::to_string(&payload).unwrap();
        let decoded: Payload = toml::from_str(&encoded).unwrap();
        assert_eq!(decoded, payload);
    }

    // Every builder setter must override its corresponding default.
    #[test]
    fn payload_config_builder_overrides_defaults() {
        let config = PayloadConfig::builder()
            .max_per_category(100)
            .deduplicate(false)
            .marker_prefix("TAINT")
            .exclude_categories(vec!["xxe".into()])
            .include_categories(vec!["xss".into()])
            .target_runtime(Some(vec!["php".into()]))
            .marker_position(MarkerPosition::Suffix)
            .build();

        assert_eq!(config.max_per_category, 100);
        assert!(!config.deduplicate);
        assert_eq!(config.marker_prefix, "TAINT");
        assert_eq!(config.exclude_categories, vec!["xxe"]);
        assert_eq!(config.include_categories, vec!["xss"]);
        assert_eq!(config.target_runtime, Some(vec!["php".into()]));
        assert_eq!(config.marker_position, MarkerPosition::Suffix);
    }

    // Values given in inline TOML are applied; the lowercase "suffix" maps to `Suffix`.
    #[test]
    fn payload_config_loads_from_toml() {
        let config = PayloadConfig::from_toml(
            r#"
max_per_category = 25
deduplicate = false
marker_position = "suffix"
"#,
            "<test>",
        )
        .unwrap();

        assert_eq!(config.max_per_category, 25);
        assert!(!config.deduplicate);
        assert_eq!(config.marker_position, MarkerPosition::Suffix);
    }

    // `PayloadConfig::load` reads a TOML config file from disk.
    #[test]
    fn payload_config_loads_from_file() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("payloads.toml");
        std::fs::write(&path, "marker_prefix = \"TRACE\"\n").unwrap();

        let config = PayloadConfig::load(&path).unwrap();

        assert_eq!(config.marker_prefix, "TRACE");
    }
}
784
// Tests for `StaticPayloads` and for the `PayloadSource` trait as implemented by both
// `StaticPayloads` and `PayloadDb`.
#[cfg(test)]
mod payload_source_tests {
    use super::*;

    // Build a payload with the given text and category; every other field is fixed.
    fn create_test_payload(text: &str, category: &str) -> Payload {
        Payload {
            text: text.into(),
            category: category.into(),
            technique: "test".into(),
            context: "default".into(),
            encoding: "raw".into(),
            cwe: None,
            severity: None,
            confidence: 1.0,
            expected_pattern: None,
        }
    }

    // An empty source reports no payloads and no categories.
    #[test]
    fn static_payloads_empty() {
        let source = StaticPayloads::new(vec![]);
        assert_eq!(source.payload_count(), 0);
        assert!(source.categories().is_empty());
    }

    // Several payloads in one category yield that category exactly once.
    #[test]
    fn static_payloads_single_category() {
        let payloads = vec![
            create_test_payload("payload1", "sqli"),
            create_test_payload("payload2", "sqli"),
        ];
        let source = StaticPayloads::new(payloads);

        assert_eq!(source.payload_count(), 2);
        let cats = source.categories();
        assert_eq!(cats.len(), 1);
        assert_eq!(cats[0], "sqli");
    }

    // Each distinct category is listed; the result is sorted because order is not asserted.
    #[test]
    fn static_payloads_multiple_categories() {
        let payloads = vec![
            create_test_payload("p1", "sqli"),
            create_test_payload("p2", "xss"),
            create_test_payload("p3", "rce"),
        ];
        let source = StaticPayloads::new(payloads);

        assert_eq!(source.payload_count(), 3);
        let mut cats = source.categories();
        cats.sort_unstable();
        assert_eq!(cats, vec!["rce", "sqli", "xss"]);
    }

    // `add` grows the stored payload count.
    #[test]
    fn static_payloads_add() {
        let mut source = StaticPayloads::new(vec![]);
        source.add(create_test_payload("test", "cat"));

        assert_eq!(source.payload_count(), 1);
    }

    // `From<Vec<Payload>>` keeps every payload.
    #[test]
    fn static_payloads_from_vec() {
        let payloads = vec![create_test_payload("test", "cat")];
        let source: StaticPayloads = payloads.into();

        assert_eq!(source.payload_count(), 1);
    }

    // `Default` produces an empty source.
    #[test]
    fn static_payloads_default() {
        let source = StaticPayloads::default();
        assert_eq!(source.payload_count(), 0);
    }

    // `all_payloads` spans every category.
    #[test]
    fn static_payloads_all_payloads() {
        let payloads = vec![
            create_test_payload("p1", "sqli"),
            create_test_payload("p2", "xss"),
        ];
        let source = StaticPayloads::new(payloads);

        assert_eq!(source.all_payloads().len(), 2);
    }

    // Interleaved input is regrouped, so `payloads(category)` returns every member of it.
    #[test]
    fn static_payloads_group_interleaved_categories() {
        let payloads = vec![
            create_test_payload("p1", "xss"),
            create_test_payload("p2", "sqli"),
            create_test_payload("p3", "xss"),
        ];
        let mut source = StaticPayloads::new(payloads);

        let xss = source.payloads("xss");
        assert_eq!(xss.len(), 2);
        assert!(xss.iter().all(|payload| payload.category == "xss"));
    }

    // `iter_category` yields only the matching category, in stored (sorted) order.
    #[test]
    fn static_payloads_iter_category_filters() {
        let payloads = vec![
            create_test_payload("p1", "xss"),
            create_test_payload("p2", "sqli"),
            create_test_payload("p3", "xss"),
        ];
        let source = StaticPayloads::new(payloads);

        let texts: Vec<_> = source
            .iter_category("xss")
            .map(|payload| payload.text.as_str())
            .collect();

        assert_eq!(texts, vec!["p1", "p3"]);
    }

    // `iter` yields everything in sorted order: "sqli" sorts before "xss".
    #[test]
    fn static_payloads_iter_returns_all_items() {
        let payloads = vec![
            create_test_payload("p1", "xss"),
            create_test_payload("p2", "sqli"),
        ];
        let source = StaticPayloads::new(payloads);

        let texts: Vec<_> = source.iter().map(|payload| payload.text.as_str()).collect();
        assert_eq!(texts, vec!["p2", "p1"]);
    }

    // `PayloadDb` works through the trait object: count, categories and payload lookup.
    #[test]
    fn payload_db_implements_payload_source() {
        fn use_trait(source: &mut dyn PayloadSource) -> usize {
            source.payload_count()
        }

        let mut db = PayloadDb::new();
        db.load_toml(
            r#"
[grammar]
name = "test"
sink_category = "test-cat"

[[contexts]]
name = "default"
prefix = ""
suffix = ""

[[techniques]]
name = "t1"
template = "hello"
"#,
        )
        .unwrap();

        // Test through the trait interface
        assert_eq!(use_trait(&mut db), 1);

        let cats = db.categories();
        assert_eq!(cats.len(), 1);
        assert_eq!(cats[0], "test-cat");

        let payloads = db.payloads("test-cat");
        assert_eq!(payloads.len(), 1);
        assert_eq!(payloads[0].text, "hello");
    }

    // `StaticPayloads` works through a `&mut dyn PayloadSource`.
    #[test]
    fn static_payloads_implements_payload_source() {
        fn use_trait(s: &mut dyn PayloadSource) -> usize {
            s.payload_count()
        }

        let payloads = vec![
            create_test_payload("p1", "cat1"),
            create_test_payload("p2", "cat2"),
        ];
        let mut source = StaticPayloads::new(payloads);

        // Test through the trait interface
        assert_eq!(use_trait(&mut source), 2);
    }

    // The `&self` trait methods are callable on a boxed trait object.
    #[test]
    fn payload_source_trait_object_works() {
        let payloads = vec![create_test_payload("test", "cat")];
        let source: Box<dyn PayloadSource> = Box::new(StaticPayloads::new(payloads));

        assert_eq!(source.payload_count(), 1);
        assert_eq!(source.categories(), vec!["cat"]);
    }
}
977
// Tests for the `Encoder` trait: `CustomEncoder`, plain functions and closures.
#[cfg(test)]
mod encoder_tests {
    use super::encoding::{CustomEncoder, Encoder};

    // A `CustomEncoder` built from a closure applies that closure.
    #[test]
    fn custom_encoder_new() {
        let encoder = CustomEncoder::new(|s: &str| s.to_uppercase());
        assert_eq!(encoder.encode("hello"), "HELLO");
    }

    // The default `CustomEncoder` leaves its input unchanged.
    #[test]
    fn custom_encoder_default() {
        let encoder = CustomEncoder::default();
        assert_eq!(encoder.encode("hello"), "hello");
    }

    // A plain `fn(&str) -> String` is usable as a `&dyn Encoder`.
    #[test]
    fn encoder_trait_for_fn() {
        fn upper(s: &str) -> String {
            s.to_uppercase()
        }
        let encoder: &dyn Encoder = &upper;
        assert_eq!(encoder.encode("hello"), "HELLO");
    }

    // A closure gets `Encoder::encode` directly.
    #[test]
    fn encoder_trait_for_closure() {
        let reverse = |s: &str| s.chars().rev().collect::<String>();
        assert_eq!(reverse.encode("hello"), "olleh");
    }

    // The ROT13 closure from the crate docs encodes as expected through the trait.
    #[test]
    fn encoder_trait_for_rot13() {
        let rot13 = |s: &str| {
            s.chars()
                .map(|c| match c {
                    'a'..='m' | 'A'..='M' => (c as u8 + 13) as char,
                    'n'..='z' | 'N'..='Z' => (c as u8 - 13) as char,
                    _ => c,
                })
                .collect::<String>()
        };
        assert_eq!(rot13.encode("hello"), "uryyb");
    }
}
1023
1024/// Convenience re-exports for common usage.
1025///
1026/// ```rust
1027/// use attackstr::prelude::*;
1028/// ```
1029pub mod prelude {
1030    pub use crate::config::PayloadConfigFile;
1031    pub use crate::validate::{validate, GrammarIssue};
1032    pub use crate::{apply_encoding, BuiltinEncoding};
1033    pub use crate::{mutate_all, mutate_case, mutate_whitespace};
1034    pub use crate::{Payload, PayloadConfig, PayloadDb, PayloadError};
1035}