use irondrop::cli::Cli;
use irondrop::server::run_server;
use reqwest::StatusCode;
use reqwest::blocking::Client;
use std::fs::File;
use std::io::{BufRead, Read, Write};
use std::net::SocketAddr;
use std::sync::mpsc;
use std::thread::{self, JoinHandle};
use tempfile::{TempDir, tempdir};
/// Handle to an irondrop server running on a background thread for one test.
///
/// Dropping it signals shutdown and joins the server thread (see the `Drop`
/// impl below). The temporary directory being served is a field, so it is
/// only removed after that join has completed.
struct TestServer {
    /// Address the server reported after binding (the tests request port 0).
    addr: SocketAddr,
    /// Sender half of the shutdown channel whose receiver was handed to `run_server`.
    shutdown_tx: mpsc::Sender<()>,
    /// Server thread; `None` once it has been taken for joining.
    handle: Option<JoinHandle<()>>,
    /// Held only so the served directory lives as long as the server does.
    _temp_dir: TempDir,
}
/// Starts an irondrop server on an ephemeral localhost port (port 0) serving a
/// fresh temporary directory, and returns a handle that shuts it down on drop.
///
/// The directory contains `test.txt` (`"hello from test file\n"`) and an empty
/// `test.zip`. Only `*.txt` is allowed, so `test.zip` exercises the
/// forbidden-extension path. Uploads and WebDAV are disabled.
///
/// # Panics
///
/// Panics if the fixture files cannot be created. It also panics if the server
/// thread does not report its bound address within `STARTUP_TIMEOUT`. For
/// example, `run_server` may have returned an error, which the thread prints
/// to stderr.
fn setup_test_server(username: Option<String>, password: Option<String>) -> TestServer {
    const STARTUP_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);

    let dir = tempdir().unwrap();
    let file_path = dir.path().join("test.txt");
    let mut file = File::create(&file_path).unwrap();
    writeln!(file, "hello from test file").unwrap();
    let forbidden_file_path = dir.path().join("test.zip");
    File::create(&forbidden_file_path).unwrap();

    let cli = Cli {
        directory: dir.path().to_path_buf(),
        listen: Some("127.0.0.1".to_string()),
        port: Some(0),
        allowed_extensions: Some("*.txt".to_string()),
        threads: Some(4),
        chunk_size: Some(1024),
        verbose: Some(false),
        detailed_logging: Some(false),
        username,
        password,
        enable_upload: Some(false),
        max_upload_size: Some(10240),
        enable_webdav: Some(false),
        config_file: None,
        log_dir: None,
        ssl_cert: None,
        ssl_key: None,
    };

    let (shutdown_tx, shutdown_rx) = mpsc::channel();
    let (addr_tx, addr_rx) = mpsc::channel();
    let server_handle = thread::spawn(move || {
        if let Err(e) = run_server(cli, Some(shutdown_rx), Some(addr_tx)) {
            eprintln!("Server thread failed: {e}");
        }
    });

    // A bounded wait turns a server that never starts into a clear failure
    // instead of a hung test. `Disconnected` means the thread exited (and
    // dropped `addr_tx`) without ever reporting an address.
    let server_addr = addr_rx
        .recv_timeout(STARTUP_TIMEOUT)
        .unwrap_or_else(|e| panic!("test server did not report its address: {e}"));

    TestServer {
        addr: server_addr,
        shutdown_tx,
        handle: Some(server_handle),
        _temp_dir: dir,
    }
}
impl Drop for TestServer {
    /// Signals the server to stop and waits for its thread to exit.
    ///
    /// A panicked server thread is reported by panicking here, except when
    /// this drop runs during unwinding (e.g. after a failed assertion).
    /// Panicking a second time then would abort the whole test binary and
    /// hide the original failure, so the problem is only logged instead.
    fn drop(&mut self) {
        if let Some(handle) = self.handle.take() {
            // If the server already exited, its receiver is gone and the
            // send fails; that is harmless because there is nothing to stop.
            let _ = self.shutdown_tx.send(());
            if handle.join().is_err() {
                if thread::panicking() {
                    eprintln!("test server thread panicked");
                } else {
                    panic!("test server thread panicked");
                }
            }
        }
    }
}
/// With no credentials configured, anonymous clients can list the root and
/// download `test.txt`.
#[test]
fn test_unauthenticated_access() {
    let server = setup_test_server(None, None);
    let client = Client::new();
    let base = format!("http://{}", server.addr);

    let listing = client.get(format!("{base}/")).send().unwrap();
    assert_eq!(listing.status(), StatusCode::OK);
    assert!(listing.text().unwrap().contains("test.txt"));

    let download = client.get(format!("{base}/test.txt")).send().unwrap();
    assert_eq!(download.status(), StatusCode::OK);
    assert_eq!(download.text().unwrap(), "hello from test file\n");
}
/// With credentials configured, the server returns `401 Unauthorized` and a
/// `WWW-Authenticate` challenge when no credentials are sent. It also returns
/// `401` whenever either half of the username/password pair is wrong.
#[test]
fn test_authentication_required() {
    let server = setup_test_server(Some("user".to_string()), Some("pass".to_string()));
    let client = Client::new();
    let url = format!("http://{}/", server.addr);

    let res = client.get(url.as_str()).send().unwrap();
    assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    assert!(res.headers().contains_key("www-authenticate"));

    // Cover each half being wrong on its own, so a server that checked only
    // the username (or only the password) would fail this test.
    let bad_credentials = [("wrong", "user"), ("user", "wrong"), ("wrong", "pass")];
    for (username, password) in bad_credentials {
        let res = client
            .get(url.as_str())
            .basic_auth(username, Some(password))
            .send()
            .unwrap();
        assert_eq!(
            res.status(),
            StatusCode::UNAUTHORIZED,
            "credentials {username}:{password} should be rejected"
        );
    }
}
/// Correct Basic credentials unlock both the directory listing and file
/// downloads.
#[test]
fn test_successful_authentication() {
    let server = setup_test_server(Some("user".to_string()), Some("pass".to_string()));
    let client = Client::new();
    let authed_get = |path: &str| {
        client
            .get(format!("http://{}{}", server.addr, path))
            .basic_auth("user", Some("pass"))
            .send()
            .unwrap()
    };

    let listing = authed_get("/");
    assert_eq!(listing.status(), StatusCode::OK);
    assert!(listing.text().unwrap().contains("test.txt"));

    let download = authed_get("/test.txt");
    assert_eq!(download.status(), StatusCode::OK);
    assert_eq!(download.text().unwrap(), "hello from test file\n");
}
/// The server returns 404 for missing files and 403 for files whose extension
/// is not allowed (`test.zip` under a `*.txt` allow-list). It returns 405 for
/// a `POST` to the root.
#[test]
fn test_error_responses() {
    let server = setup_test_server(None, None);
    let client = Client::new();
    let url = |path: &str| format!("http://{}{}", server.addr, path);

    let get_cases = [
        ("/nonexistent.txt", StatusCode::NOT_FOUND),
        ("/test.zip", StatusCode::FORBIDDEN),
    ];
    for (path, expected) in get_cases {
        let res = client.get(url(path)).send().unwrap();
        assert_eq!(res.status(), expected);
    }

    let res = client.post(url("/")).send().unwrap();
    assert_eq!(res.status(), StatusCode::METHOD_NOT_ALLOWED);
}
/// A `Range: bytes=0-9` request is answered with one of two responses. The
/// first is a 206 carrying the matching `Content-Range`. The second is a full
/// 200 that at least advertises `Accept-Ranges: bytes`.
#[test]
fn test_http_range_request_partial_content() {
    use std::net::TcpStream;

    let server = setup_test_server(None, None);
    let mut stream = TcpStream::connect(server.addr).expect("connect ok");
    // `Connection: close` lets `read_to_end` finish when the response is done.
    let request: &[u8] = b"GET /test.txt HTTP/1.1\r\n\
        Host: localhost\r\n\
        Range: bytes=0-9\r\n\
        Connection: close\r\n\r\n";
    stream.write_all(request).unwrap();
    stream.flush().unwrap();

    let mut raw = Vec::new();
    stream.read_to_end(&mut raw).unwrap();
    let text = String::from_utf8_lossy(&raw);

    let is_partial = text.contains("HTTP/1.1 206") && text.contains("Content-Range: bytes 0-9/");
    let is_full_with_ranges =
        text.contains("HTTP/1.1 200") && text.contains("Accept-Ranges: bytes");
    assert!(
        is_partial || is_full_with_ranges,
        "expected 206 with Content-Range or 200 with Accept-Ranges, got: {text}"
    );
}
/// The bundled stylesheet is served with a 200 status, a CSS content type and
/// an explicit `Content-Length`.
#[test]
fn test_static_asset_headers_and_lengths() {
    let server = setup_test_server(None, None);
    let mut stream = std::net::TcpStream::connect(server.addr).expect("connect ok");
    let request: &[u8] = b"GET /_irondrop/static/common/base.css HTTP/1.1\r\n\
        Host: localhost\r\n\
        Connection: close\r\n\r\n";
    stream.write_all(request).unwrap();
    stream.flush().unwrap();

    let mut raw = Vec::new();
    stream.read_to_end(&mut raw).unwrap();
    let text = String::from_utf8_lossy(&raw);

    for expected in ["HTTP/1.1 200", "Content-Type: text/css", "Content-Length:"] {
        assert!(text.contains(expected));
    }
}
/// Requests that try to climb above the served root with `..` are refused
/// with 403.
#[test]
fn test_path_traversal_prevention() {
    let server = setup_test_server(None, None);

    // NOTE(review): reqwest parses URLs per the WHATWG URL rules, which remove
    // dot segments client-side. This request therefore presumably reaches the
    // server as `/test.txt`, and the raw-socket request below is what actually
    // sends `..` over the wire.
    let res = Client::new()
        .get(format!("http://{}/subdir/../test.txt", server.addr))
        .send()
        .unwrap();
    assert_eq!(res.status(), StatusCode::OK);
    assert_eq!(res.text().unwrap(), "hello from test file\n");

    let mut stream = std::net::TcpStream::connect(server.addr).unwrap();
    stream
        .write_all(b"GET /../../../../../../etc/passwd HTTP/1.1\r\n\r\n")
        .unwrap();
    let mut status_line = String::new();
    std::io::BufReader::new(stream)
        .read_line(&mut status_line)
        .unwrap();
    assert!(status_line.starts_with("HTTP/1.1 403 Forbidden"));
}
/// A request line with no HTTP-version token is rejected with 400.
#[test]
fn test_malformed_request() {
    let server = setup_test_server(None, None);
    let mut stream = std::net::TcpStream::connect(server.addr).unwrap();
    stream
        .write_all(b"GET /not-a-valid-http-version\r\n\r\n")
        .unwrap();

    let mut reader = std::io::BufReader::new(stream);
    let mut status_line = String::new();
    reader.read_line(&mut status_line).unwrap();
    assert!(status_line.starts_with("HTTP/1.1 400 Bad Request"));
}
/// Ten simultaneous downloads of `test.txt` all succeed.
#[test]
fn test_concurrent_requests() {
    let server = setup_test_server(None, None);
    let client = Client::new();
    let url = format!("http://{}/test.txt", server.addr);

    let mut workers = Vec::with_capacity(10);
    for _ in 0..10 {
        let client = client.clone();
        let url = url.clone();
        workers.push(thread::spawn(move || client.get(url).send().unwrap()));
    }

    for worker in workers {
        let response = worker.join().unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }
}
/// A request carrying an 8 KiB header value is either served (200) or cleanly
/// rejected (413 or 400).
#[test]
fn test_large_header_handling() {
    let server = setup_test_server(None, None);
    let padding = "x".repeat(8192);
    let mut stream = std::net::TcpStream::connect(server.addr).unwrap();
    let request = format!(
        "GET /test.txt HTTP/1.1\r\nHost: localhost\r\nX-Large-Header: {padding}\r\n\r\n"
    );
    stream.write_all(request.as_bytes()).unwrap();

    let mut status_line = String::new();
    std::io::BufReader::new(stream)
        .read_line(&mut status_line)
        .unwrap();

    let accepted = ["HTTP/1.1 200", "HTTP/1.1 413", "HTTP/1.1 400"];
    assert!(accepted
        .iter()
        .any(|&prefix| status_line.starts_with(prefix)));
}
/// A client that connects and closes its write side without sending a single
/// byte must not take the server down. Whatever the server does with that
/// connection, a normal request afterwards still succeeds.
#[test]
fn test_empty_request_handling() {
    let server = setup_test_server(None, None);

    {
        let mut stream = std::net::TcpStream::connect(server.addr).unwrap();
        // Bound the read so a server that keeps the socket open cannot hang
        // the test.
        stream
            .set_read_timeout(Some(std::time::Duration::from_secs(2)))
            .unwrap();
        stream.shutdown(std::net::Shutdown::Write).unwrap();
        // Any response, an error, or a silent close is acceptable here; the
        // real check is the follow-up request below.
        let mut response = String::new();
        let _ = stream.read_to_string(&mut response);
    }

    let res = Client::new()
        .get(format!("http://{}/test.txt", server.addr))
        .send()
        .unwrap();
    assert_eq!(res.status(), StatusCode::OK);
}
/// Malformed or unsatisfiable `Range` headers are either ignored (200) or
/// rejected (416).
#[test]
fn test_invalid_range_requests() {
    let server = setup_test_server(None, None);
    // NOTE(review): `999_999` contains an underscore, so it tests an
    // unparseable number rather than a start offset past end-of-file. Confirm
    // which of the two was intended.
    let range_headers = [
        "Range: bytes=invalid",
        "Range: bytes=100-50",
        "Range: bytes=999_999-",
        "Range: units=0-10",
    ];
    for range_header in range_headers {
        let request =
            format!("GET /test.txt HTTP/1.1\r\nHost: localhost\r\n{range_header}\r\n\r\n");
        let mut stream = std::net::TcpStream::connect(server.addr).unwrap();
        stream.write_all(request.as_bytes()).unwrap();

        let mut status_line = String::new();
        std::io::BufReader::new(stream)
            .read_line(&mut status_line)
            .unwrap();
        let acceptable =
            status_line.starts_with("HTTP/1.1 200") || status_line.starts_with("HTTP/1.1 416");
        assert!(acceptable);
    }
}
/// A client that sends only a request line and then stalls must not hang the
/// test or wedge the server. The server may time out, answer with an error,
/// or disconnect; afterwards it must still serve a fresh request.
#[test]
fn test_connection_timeout_handling() {
    let server = setup_test_server(None, None);

    {
        let mut stream = std::net::TcpStream::connect(server.addr).unwrap();
        // The header block is never terminated. Without this cap, the read
        // would block for however long the server waits (or forever).
        stream
            .set_read_timeout(Some(std::time::Duration::from_secs(2)))
            .unwrap();
        stream.write_all(b"GET /test.txt HTTP/1.1\r\n").unwrap();
        thread::sleep(std::time::Duration::from_millis(100));
        // Timing out, an error response, or being disconnected are all fine.
        let mut response = String::new();
        let _ = stream.read_to_string(&mut response);
    }

    let res = Client::new()
        .get(format!("http://{}/test.txt", server.addr))
        .send()
        .unwrap();
    assert_eq!(res.status(), StatusCode::OK);
}
/// Percent-encoded names (space, `.`, `?`, `#`) that match no file in the
/// fixture directory yield 404. The fixture holds only `test.txt` and
/// `test.zip`.
#[test]
fn test_special_characters_in_paths() {
    let server = setup_test_server(None, None);
    let client = Client::new();
    let encoded_paths = [
        "/test%20file.txt",
        "/test%2Efile.txt",
        "/test%3Ffile.txt",
        "/test%23file.txt",
    ];
    for path in encoded_paths {
        let status = client
            .get(format!("http://{}{path}", server.addr))
            .send()
            .unwrap()
            .status();
        assert_eq!(status, StatusCode::NOT_FOUND);
    }
}
/// Like `setup_test_server`, with two differences: the caller can populate
/// the served directory, and every extension is allowed (`"*"`).
///
/// `populate` receives the temporary root before `test.txt` is written. If
/// it creates its own `test.txt`, that file is overwritten with
/// `"hello from test file\n"`.
///
/// # Panics
///
/// Panics if populating or writing the fixture fails. It also panics if the
/// server thread does not report its bound address within `STARTUP_TIMEOUT`.
/// For example, `run_server` may have returned an error, which the thread
/// prints to stderr.
fn setup_test_server_with_tree<F>(
    username: Option<String>,
    password: Option<String>,
    populate: F,
) -> TestServer
where
    F: FnOnce(&std::path::Path),
{
    const STARTUP_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);

    let dir = tempdir().unwrap();
    populate(dir.path());
    let file_path = dir.path().join("test.txt");
    let mut file = File::create(&file_path).unwrap();
    writeln!(file, "hello from test file").unwrap();

    let cli = Cli {
        directory: dir.path().to_path_buf(),
        listen: Some("127.0.0.1".to_string()),
        port: Some(0),
        allowed_extensions: Some("*".to_string()),
        threads: Some(4),
        chunk_size: Some(1024),
        verbose: Some(false),
        detailed_logging: Some(false),
        username,
        password,
        enable_upload: Some(false),
        max_upload_size: Some(10240),
        enable_webdav: Some(false),
        config_file: None,
        log_dir: None,
        ssl_cert: None,
        ssl_key: None,
    };

    let (shutdown_tx, shutdown_rx) = mpsc::channel();
    let (addr_tx, addr_rx) = mpsc::channel();
    let server_handle = thread::spawn(move || {
        if let Err(e) = run_server(cli, Some(shutdown_rx), Some(addr_tx)) {
            eprintln!("Server thread failed: {e}");
        }
    });

    // A bounded wait turns a server that never starts into a clear failure
    // instead of a hung test. `Disconnected` means the thread exited (and
    // dropped `addr_tx`) without ever reporting an address.
    let server_addr = addr_rx
        .recv_timeout(STARTUP_TIMEOUT)
        .unwrap_or_else(|e| panic!("test server did not report its address: {e}"));

    TestServer {
        addr: server_addr,
        shutdown_tx,
        handle: Some(server_handle),
        _temp_dir: dir,
    }
}
/// A directory URL without a trailing slash is redirected (301) to the
/// slash-terminated form. The listing's child links are then rooted at that
/// directory, not at its parent.
#[test]
fn test_directory_trailing_slash_redirect_and_links() {
    use std::fs::create_dir_all;

    let server = setup_test_server_with_tree(None, None, |root| {
        let ruby = root.join("call_rec_data").join("Ruby");
        let (d23, d45) = (ruby.join("23"), ruby.join("45"));
        create_dir_all(&d23).unwrap();
        create_dir_all(&d45).unwrap();
        let mut dummy = File::create(d23.join("dummy.txt")).unwrap();
        writeln!(dummy, "x").unwrap();
    });
    let base = format!("http://{}", server.addr);

    // Redirects must not be followed, or the 301 itself could not be inspected.
    let no_redirect_client = Client::builder()
        .redirect(reqwest::redirect::Policy::none())
        .build()
        .unwrap();
    let res = no_redirect_client
        .get(format!("{base}/call_rec_data/Ruby"))
        .send()
        .unwrap();
    assert_eq!(res.status(), StatusCode::MOVED_PERMANENTLY);
    let location = res.headers().get("location").unwrap().to_str().unwrap();
    assert_eq!(location, "/call_rec_data/Ruby/");

    let res = Client::new()
        .get(format!("{base}/call_rec_data/Ruby/"))
        .send()
        .unwrap();
    assert_eq!(res.status(), StatusCode::OK);
    let body = res.text().unwrap();
    assert!(body.contains("/call_rec_data/Ruby/23/"));
    assert!(body.contains("/call_rec_data/Ruby/45/"));
    assert!(!body.contains("/call_rec_data/23/"));
}
/// Directory entries returned by the search API have paths ending in `/`.
#[test]
fn test_search_api_directory_paths_have_trailing_slash() {
    use serde_json::Value;

    let server = setup_test_server_with_tree(None, None, |root| {
        std::fs::create_dir_all(root.join("call_rec_data").join("Ruby").join("23")).unwrap();
    });
    let res = Client::new()
        .get(format!(
            "http://{}/_irondrop/search?q=Ruby&path=/",
            server.addr
        ))
        .send()
        .unwrap();
    assert_eq!(res.status(), StatusCode::OK);
    let json: Value = serde_json::from_str(&res.text().unwrap()).unwrap();

    // NOTE(review): the check passes vacuously in two cases. One is a
    // response that is not a top-level JSON array. The other is a response
    // with no `"type": "directory"` items. Confirm the search API's response
    // shape and consider asserting at least one match.
    let directories = json
        .as_array()
        .into_iter()
        .flatten()
        .filter(|entry| entry["type"] == "directory");
    for entry in directories {
        let path = entry["path"].as_str().unwrap_or("");
        assert!(path.ends_with('/'));
    }
}