use std::sync::{Arc, Mutex, RwLock};
use std::collections::HashMap;
use crate::handle::{DiscoveredDomain, VerificationResult};
use crate::local_struct::LocalStruct;
use crate::stack::Stack;
/// Shared state for a brute-force run, with every collection behind a lock.
///
/// Each field is an `Arc`, so the derived `clone()` is shallow: every clone
/// refers to the same underlying lists, status table and stack.
#[derive(Clone, Debug)]
pub struct BruteForceState {
    /// Discovered domains; appended via `add_discovered_domain`.
    pub discovered_domains: Arc<Mutex<Vec<DiscoveredDomain>>>,
    /// Verification results; appended via `add_verification_result`.
    pub verification_results: Arc<Mutex<Vec<VerificationResult>>>,
    /// Per-index status table; written by `append_status`, read by
    /// `get_timeout_data` / `search_from_index_and_delete`.
    pub local_status: Arc<RwLock<LocalStruct>>,
    /// Stack of `usize` indices fed by `push_to_stack`.
    pub local_stack: Arc<RwLock<Stack<usize>>>,
}
impl BruteForceState {
pub fn new() -> Self {
BruteForceState {
discovered_domains: Arc::new(Mutex::new(Vec::new())),
verification_results: Arc::new(Mutex::new(Vec::new())),
local_status: Arc::new(RwLock::new(LocalStruct::new())),
local_stack: Arc::new(RwLock::new(Stack::new())),
}
}
pub fn add_discovered_domain(&self, domain: DiscoveredDomain) {
if let Ok(mut domains) = self.discovered_domains.lock() {
domains.push(domain);
}
}
pub fn get_discovered_domains(&self) -> Vec<DiscoveredDomain> {
if let Ok(domains) = self.discovered_domains.lock() {
domains.clone()
} else {
Vec::new()
}
}
pub fn clear_discovered_domains(&self) {
if let Ok(mut domains) = self.discovered_domains.lock() {
domains.clear();
}
}
pub fn add_verification_result(&self, result: VerificationResult) {
if let Ok(mut results) = self.verification_results.lock() {
results.push(result);
}
}
pub fn get_verification_results(&self) -> Vec<VerificationResult> {
if let Ok(results) = self.verification_results.lock() {
results.clone()
} else {
Vec::new()
}
}
pub fn clear_verification_results(&self) {
if let Ok(mut results) = self.verification_results.lock() {
results.clear();
}
}
pub fn is_local_status_empty(&self) -> bool {
match self.local_status.read() {
Ok(local_status) => local_status.empty(),
Err(_) => true,
}
}
pub fn get_timeout_data(&self, max_length: usize) -> Vec<crate::local_struct::LocalRetryStruct> {
match self.local_status.write() {
Ok(mut local_status) => local_status.get_timeout_data(max_length),
Err(_) => Vec::new(),
}
}
pub fn search_from_index_and_delete(&self, index: u32) -> Result<crate::local_struct::LocalRetryStruct, Box<dyn std::error::Error>> {
match self.local_status.write() {
Ok(mut local_status) => {
local_status.search_from_index_and_delete(index)
.map_err(|e| e)
}
Err(e) => Err(Box::new(std::io::Error::new(
std::io::ErrorKind::Other,
format!("Failed to acquire write lock: {}", e)
)))
}
}
pub fn append_status(&self, value: crate::model::StatusTable, index: u32) {
if let Ok(mut local_status) = self.local_status.write() {
local_status.append(value, index);
}
}
pub fn push_to_stack(&self, index: usize) {
if let Ok(mut stack) = self.local_stack.try_write() {
if stack.length <= 50000 {
stack.push(index);
}
}
}
}
impl Default for BruteForceState {
fn default() -> Self {
Self::new()
}
}
// SAFETY: every field of `BruteForceState` is an `Arc<Mutex<_>>` or an
// `Arc<RwLock<_>>`. These manual impls are only sound if:
// - `DiscoveredDomain` and `VerificationResult` are `Send`;
// - `LocalStruct` and `Stack<usize>` are `Send + Sync`, because `RwLock<T>` is
//   only `Sync` when `T: Send + Sync`.
// NOTE(review): those types are defined outside this file. If they already
// satisfy these bounds, the compiler derives `Send`/`Sync` automatically and
// these `unsafe` impls are redundant. If so, consider deleting them so the
// compiler enforces the invariant instead of trusting it.
unsafe impl Send for BruteForceState {}
unsafe impl Sync for BruteForceState {}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    /// Builds an `A`-record `DiscoveredDomain` stamped with the current UTC time.
    fn a_record(domain: String, ip: String) -> DiscoveredDomain {
        DiscoveredDomain {
            domain,
            ip,
            record_type: "A".to_string(),
            timestamp: chrono::Utc::now().timestamp() as u64,
        }
    }

    #[test]
    fn test_thread_safety() {
        let state = Arc::new(BruteForceState::new());

        // Ten writers, each adding one distinct domain concurrently.
        let workers: Vec<_> = (0..10)
            .map(|i| {
                let shared = Arc::clone(&state);
                thread::spawn(move || {
                    shared.add_discovered_domain(a_record(
                        format!("test{}.example.com", i),
                        format!("192.168.1.{}", i),
                    ));
                })
            })
            .collect();

        for worker in workers {
            worker.join().unwrap();
        }

        assert_eq!(state.get_discovered_domains().len(), 10);
    }

    #[test]
    fn test_state_isolation() {
        let first = BruteForceState::new();
        let second = BruteForceState::new();

        first.add_discovered_domain(a_record(
            "test1.example.com".to_string(),
            "192.168.1.1".to_string(),
        ));
        second.add_discovered_domain(a_record(
            "test2.example.com".to_string(),
            "192.168.1.2".to_string(),
        ));

        // Independently constructed states must not share storage.
        let first_domains = first.get_discovered_domains();
        let second_domains = second.get_discovered_domains();
        assert_eq!(first_domains.len(), 1);
        assert_eq!(second_domains.len(), 1);
        assert_ne!(first_domains[0].domain, second_domains[0].domain);
    }
}