use std::sync::Arc;
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
use tokio::sync::watch;
use tracing::{debug, error, warn};
use crate::stats::StreamStats;
use crate::tcp_info::get_tcp_info;
/// Application I/O buffer size used by the default profile (128 KiB).
const DEFAULT_BUFFER_SIZE: usize = 128 * 1024;

/// Application I/O buffer and socket window size used by the high-speed profile (4 MiB).
const HIGH_SPEED_BUFFER: usize = 4 * 1024 * 1024;

/// Tunables applied to a TCP test stream.
#[derive(Clone)]
pub struct TcpConfig {
    /// Length in bytes of the buffer handed to each read/write call.
    pub buffer_size: usize,
    /// Value passed to `TcpStream::set_nodelay`.
    pub nodelay: bool,
    /// When set, requested size for both `SO_SNDBUF` and `SO_RCVBUF`.
    pub window_size: Option<usize>,
    /// When set, name of the congestion-control algorithm to request for the socket.
    pub congestion: Option<String>,
}

impl Default for TcpConfig {
    /// Default profile: 128 KiB buffer, `nodelay` off, no socket tuning.
    fn default() -> Self {
        Self {
            buffer_size: DEFAULT_BUFFER_SIZE,
            nodelay: false,
            window_size: None,
            congestion: None,
        }
    }
}

/// An explicit window larger than this many bytes counts as a high-speed request.
const HIGH_SPEED_WINDOW_THRESHOLD: usize = 1_000_000;

impl TcpConfig {
    /// High-speed profile: 4 MiB buffer and window with `nodelay` enabled.
    pub fn high_speed() -> Self {
        Self {
            buffer_size: HIGH_SPEED_BUFFER,
            nodelay: true,
            window_size: Some(HIGH_SPEED_BUFFER),
            congestion: None,
        }
    }

    /// Picks a profile from the caller's window and bitrate settings.
    ///
    /// * No explicit window and no positive bitrate limit (`None` or `Some(0)`):
    ///   the [`high_speed`](Self::high_speed) profile, with `nodelay` taken from
    ///   the argument.
    /// * Otherwise: the I/O buffer is sized to the explicit window, or to the
    ///   default buffer size when no window was given. `window_size` is passed
    ///   through unchanged.
    ///
    /// A window of `Some(0)` falls back to the default buffer size: a
    /// zero-length I/O buffer would make every read return `Ok(0)` (which the
    /// receive loops treat as EOF) and every write a no-op.
    pub fn with_auto_detect(
        nodelay: bool,
        window_size: Option<usize>,
        bitrate_limit: Option<u64>,
    ) -> Self {
        let large_window = matches!(window_size, Some(w) if w > HIGH_SPEED_WINDOW_THRESHOLD);
        let unlimited = !matches!(bitrate_limit, Some(bps) if bps > 0);
        let use_high_speed = large_window || unlimited;

        match window_size {
            // An explicit window always takes the explicit path below, so in
            // practice only the unlimited-bitrate case reaches this arm.
            None if use_high_speed => Self {
                nodelay,
                ..Self::high_speed()
            },
            _ => Self {
                buffer_size: window_size
                    .filter(|&w| w > 0)
                    .unwrap_or(DEFAULT_BUFFER_SIZE),
                nodelay,
                window_size,
                congestion: None,
            },
        }
    }
}
/// Requests `buffer_size` bytes for both the send and receive socket buffers.
///
/// This is best-effort: a failing `setsockopt` is logged at debug level and
/// otherwise ignored, so the function always returns `Ok(())`.
#[cfg(unix)]
fn configure_socket_buffers(stream: &TcpStream, buffer_size: usize) -> std::io::Result<()> {
    use std::os::unix::io::AsRawFd;
    use tracing::debug;

    let fd = stream.as_raw_fd();
    // `setsockopt` takes a C `int`. Saturate instead of letting an `as` cast
    // wrap a very large request into a negative or tiny value.
    let size = libc::c_int::try_from(buffer_size).unwrap_or(libc::c_int::MAX);

    for (option, label) in [
        (libc::SO_SNDBUF, "SO_SNDBUF"),
        (libc::SO_RCVBUF, "SO_RCVBUF"),
    ] {
        // SAFETY: `fd` is owned by `stream`, which outlives this call, and the
        // option value points at a live `c_int` whose exact size is passed as
        // the option length.
        let ret = unsafe {
            libc::setsockopt(
                fd,
                libc::SOL_SOCKET,
                option,
                &size as *const libc::c_int as *const libc::c_void,
                std::mem::size_of::<libc::c_int>() as libc::socklen_t,
            )
        };
        if ret != 0 {
            debug!(
                "Failed to set {} to {}: {}",
                label,
                buffer_size,
                std::io::Error::last_os_error()
            );
        }
    }
    Ok(())
}
/// Non-Unix fallback: socket buffer sizes are left untouched and the call
/// always succeeds.
#[cfg(not(unix))]
fn configure_socket_buffers(_stream: &TcpStream, _buffer_size: usize) -> std::io::Result<()> {
Ok(())
}
/// Requests congestion-control algorithm `algo` for `stream` via the
/// `TCP_CONGESTION` socket option.
///
/// # Errors
///
/// Returns the OS error reported for the failed `setsockopt` call.
#[cfg(target_os = "linux")]
fn set_tcp_congestion(stream: &TcpStream, algo: &str) -> std::io::Result<()> {
    use std::os::unix::io::AsRawFd;

    let name = algo.as_bytes();
    // SAFETY: the descriptor belongs to `stream`, which outlives this call, and
    // the option value is a pointer/length pair describing the live `algo` bytes.
    let status = unsafe {
        libc::setsockopt(
            stream.as_raw_fd(),
            libc::IPPROTO_TCP,
            libc::TCP_CONGESTION,
            name.as_ptr().cast::<libc::c_void>(),
            name.len() as libc::socklen_t,
        )
    };
    match status {
        0 => Ok(()),
        _ => Err(std::io::Error::last_os_error()),
    }
}
/// Non-Linux fallback: the requested algorithm is ignored and the call always
/// succeeds.
#[cfg(not(target_os = "linux"))]
fn set_tcp_congestion(_stream: &TcpStream, _algo: &str) -> std::io::Result<()> {
Ok(())
}
/// Checks whether congestion-control algorithm `algo` can be selected, by
/// trying it on a throwaway IPv4 TCP socket.
///
/// # Errors
///
/// Returns a human-readable message when the probe socket cannot be created or
/// when the kernel rejects the algorithm. In the latter case the message lists
/// the contents of `/proc/sys/net/ipv4/tcp_available_congestion_control` when
/// that file is readable.
#[cfg(target_os = "linux")]
pub fn validate_congestion(algo: &str) -> Result<(), String> {
    // SAFETY: `socket` takes only integer arguments.
    let probe = unsafe { libc::socket(libc::AF_INET, libc::SOCK_STREAM, 0) };
    if probe < 0 {
        let os_err = std::io::Error::last_os_error();
        return Err(format!("failed to create test socket: {}", os_err));
    }

    let name = algo.as_bytes();
    // SAFETY: `probe` is the descriptor opened above and the option value is a
    // pointer/length pair describing the live `algo` bytes.
    let status = unsafe {
        libc::setsockopt(
            probe,
            libc::IPPROTO_TCP,
            libc::TCP_CONGESTION,
            name.as_ptr().cast::<libc::c_void>(),
            name.len() as libc::socklen_t,
        )
    };

    let verdict = match status {
        0 => Ok(()),
        _ => Err(
            match std::fs::read_to_string("/proc/sys/net/ipv4/tcp_available_congestion_control") {
                Ok(available) => format!("not available (available: {})", available.trim()),
                Err(_) => "not available on this kernel".to_string(),
            },
        ),
    };

    // The probe socket is closed on both the success and the failure path.
    // SAFETY: `probe` is a descriptor this function opened and has not closed yet.
    unsafe { libc::close(probe) };
    verdict
}
/// Non-Linux fallback: no validation is performed and every algorithm name is
/// accepted.
#[cfg(not(target_os = "linux"))]
pub fn validate_congestion(_algo: &str) -> Result<(), String> {
Ok(())
}
/// Applies `config` to `stream`: the `nodelay` flag first, then the socket
/// buffer sizes when a window is configured, then the congestion-control
/// algorithm when one is named.
///
/// # Errors
///
/// Stops at, and returns, the first error reported by any of those steps.
pub fn configure_stream(stream: &TcpStream, config: &TcpConfig) -> std::io::Result<()> {
    stream.set_nodelay(config.nodelay)?;

    if let Some(window_bytes) = config.window_size {
        configure_socket_buffers(stream, window_bytes)?;
    }

    match config.congestion.as_deref() {
        Some(algorithm) => set_tcp_congestion(stream, algorithm),
        None => Ok(()),
    }
}
/// Writes zero-filled buffers to `stream` until the test duration elapses,
/// cancellation is signalled, or a write fails, then shuts the stream down.
///
/// * `duration` - how long to send; `Duration::ZERO` means no deadline.
/// * `config` - applied to the socket via `configure_stream` before sending.
/// * `cancel` - sending stops once this channel holds `true`, or changes
///   during a pacing sleep.
/// * `bitrate` - `Some(bps)` with `bps > 0` paces the sender to that many bits
///   per second; `None` or `Some(0)` sends as fast as the socket allows.
/// * `pause` - while this channel holds `true`, sending is suspended.
///
/// Returns the TCP info snapshot captured just before shutdown, or `None` when
/// it could not be read.
///
/// # Errors
///
/// Fails if the socket cannot be configured, a write fails, or the final
/// shutdown fails.
pub async fn send_data(
    mut stream: TcpStream,
    stats: Arc<StreamStats>,
    duration: Duration,
    config: TcpConfig,
    mut cancel: watch::Receiver<bool>,
    bitrate: Option<u64>,
    mut pause: watch::Receiver<bool>,
) -> anyhow::Result<Option<crate::protocol::TcpInfoSnapshot>> {
    configure_stream(&stream, &config)?;

    // A zero bitrate means "unlimited", same as `None`.
    let rate_limit = bitrate.filter(|&bps| bps > 0);

    // When pacing, cap each write at roughly 1/10 s worth of data. Never use
    // an empty buffer: a zero-length write returns `Ok(0)` and the loop would
    // spin without sending anything.
    let buf_size = match rate_limit {
        Some(bps) => {
            let chunk = usize::try_from((bps / 8 / 10).max(1)).unwrap_or(usize::MAX);
            config.buffer_size.min(chunk)
        }
        None => config.buffer_size,
    }
    .max(1);
    let buffer = vec![0u8; buf_size];

    let start = tokio::time::Instant::now();
    let deadline = start + duration;
    let is_infinite = duration == Duration::ZERO;

    // Pacing is measured relative to this baseline, which is reset after a pause.
    let mut pace_start = start;
    let mut pace_bytes_offset: u64 = 0;
    // Cleared once the pause sender is gone. A closed watch channel resolves
    // `changed()` immediately, which would otherwise end every pacing sleep
    // early and defeat the bitrate limit.
    let mut pause_open = true;

    loop {
        if *cancel.borrow() {
            debug!("Send cancelled for stream {}", stats.stream_id);
            break;
        }
        if *pause.borrow() {
            // NOTE(review): `wait_while_paused` presumably returns `true` when
            // cancelled while paused - confirm against `crate::pause`.
            if crate::pause::wait_while_paused(&mut pause, &mut cancel).await {
                break;
            }
            // Restart pacing so the paused interval is not counted as send time.
            pace_start = tokio::time::Instant::now();
            pace_bytes_offset = stats.bytes_sent.load(std::sync::atomic::Ordering::Relaxed);
            continue;
        }
        if !is_infinite && tokio::time::Instant::now() >= deadline {
            break;
        }

        match stream.write(&buffer).await {
            Ok(n) => {
                stats.add_bytes_sent(n as u64);
                if let Some(bps) = rate_limit {
                    let bytes_per_sec = bps as f64 / 8.0;
                    let elapsed = pace_start.elapsed().as_secs_f64();
                    let expected = elapsed * bytes_per_sec;
                    let total = (stats.bytes_sent.load(std::sync::atomic::Ordering::Relaxed)
                        - pace_bytes_offset) as f64;
                    if total > expected {
                        // Ahead of schedule: sleep off the surplus, waking
                        // early on cancel or on a pause-state change.
                        let overshoot =
                            Duration::from_secs_f64((total - expected) / bytes_per_sec);
                        tokio::select! {
                            biased;
                            _ = cancel.changed() => {
                                debug!("Send cancelled during pacing sleep for stream {}", stats.stream_id);
                                break;
                            }
                            changed = pause.changed(), if pause_open => {
                                if changed.is_err() {
                                    pause_open = false;
                                }
                            }
                            _ = tokio::time::sleep(overshoot) => {}
                        }
                    }
                }
            }
            Err(e) => {
                error!("Send error on stream {}: {}", stats.stream_id, e);
                return Err(e.into());
            }
        }
    }

    // Capture TCP info while the socket is still open.
    let tcp_info = get_stream_tcp_info(&stream);
    if let Some(ref info) = tcp_info {
        stats.add_retransmits(info.retransmits);
    }
    stream.shutdown().await?;
    debug!(
        "Stream {} send complete: {} bytes",
        stats.stream_id,
        stats.bytes_sent.load(std::sync::atomic::Ordering::Relaxed)
    );
    Ok(tcp_info)
}
/// Reads and discards data from `stream`, counting received bytes in `stats`,
/// until EOF, a read error, or cancellation.
///
/// * `cancel` - receiving stops once a change on this channel leaves it
///   holding `true`.
/// * `config` - applied to the socket via `configure_stream`; `buffer_size`
///   sets the read buffer length.
///
/// Returns the TCP info snapshot captured after the loop ends, or `None` when
/// it could not be read.
///
/// # Errors
///
/// Fails if the socket cannot be configured or a read fails with anything
/// other than `WouldBlock`.
pub async fn receive_data(
    mut stream: TcpStream,
    stats: Arc<StreamStats>,
    mut cancel: watch::Receiver<bool>,
    config: TcpConfig,
) -> anyhow::Result<Option<crate::protocol::TcpInfoSnapshot>> {
    configure_stream(&stream, &config)?;

    // Never read into an empty buffer: that returns `Ok(0)`, which the loop
    // below would mistake for EOF.
    let mut buffer = vec![0u8; config.buffer_size.max(1)];
    // Cleared once the cancel sender is gone. A closed watch channel resolves
    // `changed()` immediately on every poll, which would turn this loop into a
    // busy spin.
    let mut cancel_open = true;

    loop {
        tokio::select! {
            result = stream.read(&mut buffer) => {
                match result {
                    Ok(0) => {
                        debug!("Stream {} EOF", stats.stream_id);
                        break;
                    }
                    Ok(n) => {
                        stats.add_bytes_received(n as u64);
                    }
                    Err(e) => {
                        if e.kind() == std::io::ErrorKind::WouldBlock {
                            continue;
                        }
                        warn!("Receive error on stream {}: {}", stats.stream_id, e);
                        return Err(e.into());
                    }
                }
            }
            changed = cancel.changed(), if cancel_open => {
                if *cancel.borrow() {
                    debug!("Receive cancelled for stream {}", stats.stream_id);
                    break;
                }
                if changed.is_err() {
                    // Sender dropped without cancelling: keep receiving until
                    // EOF, but stop polling the closed channel.
                    cancel_open = false;
                }
            }
        }
    }

    let tcp_info = get_stream_tcp_info(&stream);
    if let Some(ref info) = tcp_info {
        stats.add_retransmits(info.retransmits);
    }
    debug!(
        "Stream {} receive complete: {} bytes",
        stats.stream_id,
        stats
            .bytes_received
            .load(std::sync::atomic::Ordering::Relaxed)
    );
    Ok(tcp_info)
}
/// Returns the result of `get_tcp_info` for `stream`, or `None` when that call
/// fails (the error itself is discarded).
pub fn get_stream_tcp_info(stream: &TcpStream) -> Option<crate::protocol::TcpInfoSnapshot> {
get_tcp_info(stream).ok()
}
/// Writes zero-filled buffers to `write_half` until the test duration elapses,
/// cancellation is signalled, or a write fails, then shuts the write side down
/// and hands the half back to the caller.
///
/// Unlike `send_data`, this does not configure the socket, and an error from
/// the final shutdown is ignored.
///
/// * `duration` - how long to send; `Duration::ZERO` means no deadline.
/// * `config` - only `buffer_size` is used here, to size the write buffer.
/// * `cancel` - sending stops once this channel holds `true`, or changes
///   during a pacing sleep.
/// * `bitrate` - `Some(bps)` with `bps > 0` paces the sender to that many bits
///   per second; `None` or `Some(0)` sends as fast as the socket allows.
/// * `pause` - while this channel holds `true`, sending is suspended.
///
/// # Errors
///
/// Fails if a write fails.
pub async fn send_data_half(
    mut write_half: OwnedWriteHalf,
    stats: Arc<StreamStats>,
    duration: Duration,
    config: TcpConfig,
    mut cancel: watch::Receiver<bool>,
    bitrate: Option<u64>,
    mut pause: watch::Receiver<bool>,
) -> anyhow::Result<OwnedWriteHalf> {
    // A zero bitrate means "unlimited", same as `None`.
    let rate_limit = bitrate.filter(|&bps| bps > 0);

    // When pacing, cap each write at roughly 1/10 s worth of data. Never use
    // an empty buffer: a zero-length write returns `Ok(0)` and the loop would
    // spin without sending anything.
    let buf_size = match rate_limit {
        Some(bps) => {
            let chunk = usize::try_from((bps / 8 / 10).max(1)).unwrap_or(usize::MAX);
            config.buffer_size.min(chunk)
        }
        None => config.buffer_size,
    }
    .max(1);
    let buffer = vec![0u8; buf_size];

    let start = tokio::time::Instant::now();
    let deadline = start + duration;
    let is_infinite = duration == Duration::ZERO;

    // Pacing is measured relative to this baseline, which is reset after a pause.
    let mut pace_start = start;
    let mut pace_bytes_offset: u64 = 0;
    // Cleared once the pause sender is gone. A closed watch channel resolves
    // `changed()` immediately, which would otherwise end every pacing sleep
    // early and defeat the bitrate limit.
    let mut pause_open = true;

    loop {
        if *cancel.borrow() {
            debug!("Send cancelled for stream {}", stats.stream_id);
            break;
        }
        if *pause.borrow() {
            // NOTE(review): `wait_while_paused` presumably returns `true` when
            // cancelled while paused - confirm against `crate::pause`.
            if crate::pause::wait_while_paused(&mut pause, &mut cancel).await {
                break;
            }
            // Restart pacing so the paused interval is not counted as send time.
            pace_start = tokio::time::Instant::now();
            pace_bytes_offset = stats.bytes_sent.load(std::sync::atomic::Ordering::Relaxed);
            continue;
        }
        if !is_infinite && tokio::time::Instant::now() >= deadline {
            break;
        }

        match write_half.write(&buffer).await {
            Ok(n) => {
                stats.add_bytes_sent(n as u64);
                if let Some(bps) = rate_limit {
                    let bytes_per_sec = bps as f64 / 8.0;
                    let elapsed = pace_start.elapsed().as_secs_f64();
                    let expected = elapsed * bytes_per_sec;
                    let total = (stats.bytes_sent.load(std::sync::atomic::Ordering::Relaxed)
                        - pace_bytes_offset) as f64;
                    if total > expected {
                        // Ahead of schedule: sleep off the surplus, waking
                        // early on cancel or on a pause-state change.
                        let overshoot =
                            Duration::from_secs_f64((total - expected) / bytes_per_sec);
                        tokio::select! {
                            biased;
                            _ = cancel.changed() => {
                                debug!("Send cancelled during pacing sleep for stream {}", stats.stream_id);
                                break;
                            }
                            changed = pause.changed(), if pause_open => {
                                if changed.is_err() {
                                    pause_open = false;
                                }
                            }
                            _ = tokio::time::sleep(overshoot) => {}
                        }
                    }
                }
            }
            Err(e) => {
                error!("Send error on stream {}: {}", stats.stream_id, e);
                return Err(e.into());
            }
        }
    }

    // Best-effort shutdown of the write side; a failure here is ignored.
    let _ = write_half.shutdown().await;
    debug!(
        "Stream {} send complete: {} bytes",
        stats.stream_id,
        stats.bytes_sent.load(std::sync::atomic::Ordering::Relaxed)
    );
    Ok(write_half)
}
/// Reads and discards data from `read_half`, counting received bytes in
/// `stats`, until EOF, a read error, or cancellation, then hands the half back
/// to the caller.
///
/// * `cancel` - receiving stops once a change on this channel leaves it
///   holding `true`.
/// * `config` - only `buffer_size` is used here, to size the read buffer.
///
/// # Errors
///
/// Fails if a read fails with anything other than `WouldBlock`.
pub async fn receive_data_half(
    mut read_half: OwnedReadHalf,
    stats: Arc<StreamStats>,
    mut cancel: watch::Receiver<bool>,
    config: TcpConfig,
) -> anyhow::Result<OwnedReadHalf> {
    // Never read into an empty buffer: that returns `Ok(0)`, which the loop
    // below would mistake for EOF.
    let mut buffer = vec![0u8; config.buffer_size.max(1)];
    // Cleared once the cancel sender is gone. A closed watch channel resolves
    // `changed()` immediately on every poll, which would turn this loop into a
    // busy spin.
    let mut cancel_open = true;

    loop {
        tokio::select! {
            result = read_half.read(&mut buffer) => {
                match result {
                    Ok(0) => {
                        debug!("Stream {} EOF", stats.stream_id);
                        break;
                    }
                    Ok(n) => {
                        stats.add_bytes_received(n as u64);
                    }
                    Err(e) => {
                        if e.kind() == std::io::ErrorKind::WouldBlock {
                            continue;
                        }
                        warn!("Receive error on stream {}: {}", stats.stream_id, e);
                        return Err(e.into());
                    }
                }
            }
            changed = cancel.changed(), if cancel_open => {
                if *cancel.borrow() {
                    debug!("Receive cancelled for stream {}", stats.stream_id);
                    break;
                }
                if changed.is_err() {
                    // Sender dropped without cancelling: keep receiving until
                    // EOF, but stop polling the closed channel.
                    cancel_open = false;
                }
            }
        }
    }

    debug!(
        "Stream {} receive complete: {} bytes",
        stats.stream_id,
        stats
            .bytes_received
            .load(std::sync::atomic::Ordering::Relaxed)
    );
    Ok(read_half)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default profile uses the standard buffer size with `nodelay` off.
    #[test]
    fn test_default_config() {
        let TcpConfig {
            buffer_size,
            nodelay,
            ..
        } = TcpConfig::default();
        assert_eq!(buffer_size, DEFAULT_BUFFER_SIZE);
        assert!(!nodelay);
    }

    /// The high-speed profile uses the large buffer with `nodelay` on.
    #[test]
    fn test_high_speed_config() {
        let TcpConfig {
            buffer_size,
            nodelay,
            ..
        } = TcpConfig::high_speed();
        assert_eq!(buffer_size, HIGH_SPEED_BUFFER);
        assert!(nodelay);
    }

    /// `cubic` is expected to be selectable on the Linux host running the tests.
    #[test]
    #[cfg(target_os = "linux")]
    fn test_validate_congestion_cubic() {
        let outcome = validate_congestion("cubic");
        assert!(outcome.is_ok());
    }

    /// An unknown algorithm name is rejected with a "not available" message.
    #[test]
    #[cfg(target_os = "linux")]
    fn test_validate_congestion_invalid() {
        let outcome = validate_congestion("nonexistent_algo_xyz");
        assert!(outcome.is_err());
        let message = outcome.unwrap_err();
        assert!(
            message.contains("not available"),
            "unexpected error: {}",
            message
        );
    }
}