import http.server
import json
import os
import signal
import socket
import socketserver
import subprocess
import sys
import time
from contextlib import contextmanager
from threading import Thread

import requests

# Port the in-process JSON test backend listens on; WiseGate forwards to it.
BACKEND_PORT: int = 3001
# Port WiseGate itself listens on; all test traffic is sent here.
WISEGATE_PORT: int = 8080
# Timeout in seconds; presumably an upper bound for test waits — confirm.
TEST_TIMEOUT: int = 30
class Colors:
    """ANSI escape sequences used to colour terminal output.

    Each attribute is an SGR foreground-colour sequence; ``NC`` ("no colour")
    resets the terminal to its default attributes and is appended after
    every coloured message.
    """

    NC = '\033[0m'  # reset all attributes

    RED = '\033[0;31m'
    GREEN = '\033[0;32m'
    YELLOW = '\033[1;33m'  # bold yellow
    BLUE = '\033[0;34m'
    PURPLE = '\033[0;35m'
    CYAN = '\033[0;36m'
def print_colored(message, color):
    """Print *message* prefixed with the ANSI *color* code and followed by a reset."""
    line = "{}{}{}".format(color, message, Colors.NC)
    print(line)
def print_success(message):
    """Print *message* in green, prefixed with a check-mark emoji."""
    print_colored(f"✅ {message}", Colors.GREEN)
def print_error(message):
    """Print *message* in red, prefixed with a cross-mark emoji."""
    print_colored(f"❌ {message}", Colors.RED)
def print_info(message):
    """Print *message* in blue, prefixed with an information emoji."""
    print_colored(f"ℹ️ {message}", Colors.BLUE)
def print_test(message):
    """Print *message* in purple, prefixed with a test-tube emoji."""
    print_colored(f"🧪 {message}", Colors.PURPLE)
class TestBackendHandler(http.server.BaseHTTPRequestHandler):
    """Minimal JSON echo backend that WiseGate proxies to during the tests.

    GET, POST and PUT are answered with ``200``, a JSON body describing the
    request, and an ``X-Backend-Server: test-backend`` header so the tests can
    tell that a response really came from this backend.

    Derives from ``BaseHTTPRequestHandler`` rather than
    ``SimpleHTTPRequestHandler``: the latter would also answer ``HEAD`` by
    serving files from the current working directory, which a test backend
    should not expose.
    """

    def _read_body(self):
        """Read and return the request body (``b''`` when there is none).

        The body is drained even when the caller ignores it: closing the
        connection with unread request bytes can make the kernel reset the
        socket before the client has read the response.
        """
        try:
            length = int(self.headers.get('Content-Length', 0))
        except ValueError:
            # Malformed header: treat as "no body" instead of crashing the handler.
            length = 0
        return self.rfile.read(length) if length > 0 else b''

    def _send_json(self, payload, indent=None):
        """Send *payload* as a 200 JSON response with an explicit Content-Length."""
        body = json.dumps(payload, indent=indent).encode()
        self.send_response(200)
        self.send_header('Content-type', 'application/json')
        self.send_header('X-Backend-Server', 'test-backend')
        # Explicit length lets the proxy frame the response without relying
        # on connection close.
        self.send_header('Content-Length', str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def do_GET(self):
        """Echo the request path and headers."""
        self._send_json({
            'message': 'Hello from test backend!',
            'method': 'GET',
            'path': self.path,
            'headers': dict(self.headers),
            'server': 'test-backend',
        }, indent=2)

    def do_POST(self):
        """Echo path, headers, body size and the first 100 body bytes."""
        post_data = self._read_body()
        self._send_json({
            'message': 'POST received by test backend!',
            'method': 'POST',
            'path': self.path,
            'headers': dict(self.headers),
            'body_size': len(post_data),
            'body_preview': post_data[:100].decode('utf-8', errors='ignore') if post_data else None,
            'server': 'test-backend',
        }, indent=2)

    def do_PUT(self):
        """Acknowledge a PUT; the body is read but otherwise ignored."""
        self._read_body()
        self._send_json({'message': 'PUT request received', 'method': 'PUT'})

    def log_message(self, format, *args):
        """Silence the default per-request stderr logging."""
        pass
class TestEnvironment:
    """Build WiseGate, run it in front of a local test backend, and test it.

    ``run_all_tests`` builds the binary, starts the backend and WiseGate, and
    runs every ``test_*`` check through ``run_test``, which records
    ``(name, passed, error)`` tuples in ``test_results``. Call ``cleanup``
    afterwards to stop WiseGate and the backend.
    """

    # Rate-limit configuration passed to WiseGate via its environment.
    # test_rate_limiting sends one request more than the budget, and
    # run_all_tests waits out the window afterwards.
    RATE_LIMIT_REQUESTS = 5
    RATE_LIMIT_WINDOW_SECS = 10

    class _BackendServer(socketserver.TCPServer):
        """TCPServer with SO_REUSEADDR so quick re-runs can rebind the port."""

        allow_reuse_address = True

    def __init__(self):
        self.backend_server = None    # running _BackendServer, once started
        self.backend_thread = None    # daemon thread executing serve_forever()
        self.wisegate_process = None  # Popen handle of the WiseGate binary
        self.test_results = []        # (test_name, passed, error message or None)

    # ------------------------------------------------------------------ setup

    def build_wisegate(self):
        """Compile WiseGate with ``cargo build --release``; return True on success."""
        print_info("Building WiseGate...")
        try:
            subprocess.run(['cargo', 'build', '--release'],
                           capture_output=True, text=True, check=True)
        except FileNotFoundError:
            print_error("Build failed: 'cargo' executable not found")
            return False
        except subprocess.CalledProcessError as e:
            print_error(f"Build failed: {e.stderr}")
            return False
        print_success("Build completed successfully")
        return True

    def start_backend(self):
        """Serve TestBackendHandler on BACKEND_PORT in a daemon thread.

        Returns True once a GET to the backend answers with status 200.
        """
        print_info(f"Starting test backend on port {BACKEND_PORT}...")
        try:
            # Bind in this thread so errors such as "address already in use"
            # are reported here instead of silently killing the server thread.
            server = self._BackendServer(('localhost', BACKEND_PORT), TestBackendHandler)
        except OSError as e:
            print_error(f"Failed to start backend server: {e}")
            return False
        self.backend_server = server
        self.backend_thread = Thread(target=server.serve_forever, daemon=True)
        self.backend_thread.start()

        # The socket is already listening, so no fixed sleep is needed: the
        # probe connection waits in the backlog until serve_forever accepts it.
        try:
            response = requests.get(f'http://localhost:{BACKEND_PORT}/', timeout=5)
        except requests.RequestException as e:
            print_error(f"Failed to start backend server: {e}")
            return False
        if response.status_code != 200:
            print_error(f"Failed to start backend server: HTTP {response.status_code}")
            return False
        print_success(f"Backend server running on port {BACKEND_PORT}")
        return True

    def _wait_for_wisegate(self, timeout):
        """Poll until WiseGate accepts TCP connections on WISEGATE_PORT.

        Uses bare TCP connects rather than HTTP requests so that polling does
        not consume WiseGate's rate-limit budget. Returns False early if the
        WiseGate process has exited, or after *timeout* seconds.
        """
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            if self.wisegate_process.poll() is not None:
                return False
            try:
                with socket.create_connection(('localhost', WISEGATE_PORT), timeout=1):
                    return True
            except OSError:
                time.sleep(0.1)
        return False

    def start_wisegate(self):
        """Launch the release WiseGate binary in front of the backend.

        Returns True once WiseGate answers a GET with status 200.
        """
        print_info(f"Starting WiseGate on port {WISEGATE_PORT}...")
        env = os.environ.copy()
        env.update({
            'RATE_LIMIT_REQUESTS': str(self.RATE_LIMIT_REQUESTS),
            'RATE_LIMIT_WINDOW_SECS': str(self.RATE_LIMIT_WINDOW_SECS),
            'BLOCKED_METHODS': 'TRACE,CONNECT',
            'BLOCKED_PATTERNS': '.env,.git,admin',
            'PROXY_TIMEOUT_SECS': '5',
        })
        try:
            # stdout is discarded rather than piped: a PIPE that is never read
            # can fill up and block the proxy mid-run. stderr is inherited so
            # any startup error stays visible on the console.
            self.wisegate_process = subprocess.Popen(
                ['./target/release/wisegate',
                 '-l', str(WISEGATE_PORT),
                 '-f', str(BACKEND_PORT),
                 '--quiet'],
                env=env, stdout=subprocess.DEVNULL)
        except OSError as e:
            print_error(f"Failed to start WiseGate: {e}")
            return False

        if not self._wait_for_wisegate(TEST_TIMEOUT):
            exit_code = self.wisegate_process.poll()
            reason = (f"process exited with code {exit_code}" if exit_code is not None
                      else f"port {WISEGATE_PORT} not reachable within {TEST_TIMEOUT}s")
            print_error(f"Failed to start WiseGate: {reason}")
            return False

        try:
            response = requests.get(f'http://localhost:{WISEGATE_PORT}/', timeout=5)
        except requests.RequestException as e:
            print_error(f"Failed to start WiseGate: {e}")
            return False
        if response.status_code != 200:
            print_error(f"Failed to start WiseGate: HTTP {response.status_code}")
            return False
        print_success(f"WiseGate running on port {WISEGATE_PORT}")
        return True

    # ----------------------------------------------------------------- runner

    def run_test(self, test_name, test_func):
        """Run *test_func*, print the verdict and record it in ``test_results``.

        A falsy return counts as failure; an exception counts as an error and
        its message is recorded.
        """
        print_test(f"Running: {test_name}")
        try:
            result = test_func()
            if result:
                print_success(f"PASSED: {test_name}")
                self.test_results.append((test_name, True, None))
            else:
                print_error(f"FAILED: {test_name}")
                self.test_results.append((test_name, False, "Test returned False"))
        except Exception as e:
            print_error(f"ERROR: {test_name} - {str(e)}")
            self.test_results.append((test_name, False, str(e)))
        # Short pause between tests.
        time.sleep(0.5)

    # ------------------------------------------------------------------ tests

    def test_basic_proxy(self):
        """A GET through WiseGate returns the backend's JSON greeting."""
        response = requests.get(f'http://localhost:{WISEGATE_PORT}/', timeout=5)
        if response.status_code != 200:
            return False
        data = response.json()
        return (data.get('message') == 'Hello from test backend!' and
                data.get('server') == 'test-backend')

    def test_post_request(self):
        """A POST body reaches the backend intact (checked by byte length)."""
        test_data = "This is test POST data"
        response = requests.post(f'http://localhost:{WISEGATE_PORT}/api/test',
                                 data=test_data, timeout=5)
        if response.status_code != 200:
            return False
        data = response.json()
        return (data.get('method') == 'POST' and
                data.get('body_size') == len(test_data.encode()))

    def test_headers_forwarding(self):
        """Custom request headers are forwarded to the backend."""
        headers = {'X-Custom-Header': 'test-value', 'X-Client-ID': '12345'}
        response = requests.get(f'http://localhost:{WISEGATE_PORT}/headers',
                                headers=headers, timeout=5)
        if response.status_code != 200:
            return False
        received_headers = response.json().get('headers', {})
        # Case-insensitive match: the proxy may normalise header-name casing.
        has_custom = any('custom-header' in k.lower() for k in received_headers)
        has_client = any('client-id' in k.lower() for k in received_headers)
        return has_custom and has_client

    def test_real_ip_header(self):
        """Request is proxied and the backend returns valid JSON.

        NOTE(review): despite its name this does not assert that any
        real-IP header was injected; it only checks that the request is
        proxied and the body parses as JSON. TODO: assert the specific header
        WiseGate injects once that is confirmed.
        """
        response = requests.get(f'http://localhost:{WISEGATE_PORT}/ip-test', timeout=5)
        if response.status_code != 200:
            return False
        response.json()  # raises (-> recorded as ERROR) if the body is not JSON
        return True

    def test_blocked_method(self):
        """TRACE (configured in BLOCKED_METHODS) is rejected with 405."""
        try:
            response = requests.request('TRACE', f'http://localhost:{WISEGATE_PORT}/', timeout=5)
        except requests.RequestException:
            return False
        return response.status_code == 405

    def test_blocked_pattern(self):
        """Paths matching BLOCKED_PATTERNS are answered with 404."""
        test_paths = ['/.env', '/.git/config', '/admin/panel']
        for path in test_paths:
            response = requests.get(f'http://localhost:{WISEGATE_PORT}{path}', timeout=5)
            if response.status_code != 404:
                return False
        return True

    def test_rate_limiting(self):
        """Exceeding RATE_LIMIT_REQUESTS within the window yields a 429."""
        statuses = []
        for _ in range(self.RATE_LIMIT_REQUESTS + 1):
            try:
                response = requests.get(f'http://localhost:{WISEGATE_PORT}/rate-test', timeout=5)
                statuses.append(response.status_code)
            except requests.RequestException:
                statuses.append(0)  # connection-level failure
            time.sleep(0.05)
        # One request more than the budget must be throttled. (Accepting "all
        # statuses identical" would let an all-200 run, where the limiter is
        # broken, or an all-0 run, where the proxy is down, pass.)
        return 429 in statuses

    def test_large_request(self):
        """A 100 KiB POST body reaches the backend intact."""
        large_data = 'x' * 1024 * 100  # 100 KiB
        response = requests.post(f'http://localhost:{WISEGATE_PORT}/large',
                                 data=large_data, timeout=10)
        if response.status_code != 200:
            return False
        data = response.json()
        return data.get('body_size') == len(large_data.encode())

    # --------------------------------------------------------------- reporting

    def run_all_tests(self):
        """Build, start the services and run every test.

        Returns True only if setup succeeded and all tests passed.
        """
        print_colored("🛡️ WiseGate Automated Test Suite", Colors.CYAN)
        print_colored("=" * 40, Colors.CYAN)
        if not self.build_wisegate():
            return False
        if not self.start_backend():
            return False
        if not self.start_wisegate():
            return False
        print_colored("\n🧪 Running Tests...", Colors.YELLOW)
        print_colored("=" * 20, Colors.YELLOW)
        self.run_test("Basic Proxy Functionality", self.test_basic_proxy)
        self.run_test("POST Request Forwarding", self.test_post_request)
        self.run_test("Header Forwarding", self.test_headers_forwarding)
        self.run_test("Real IP Header Injection", self.test_real_ip_header)
        self.run_test("Blocked HTTP Methods", self.test_blocked_method)
        self.run_test("Blocked URL Patterns", self.test_blocked_pattern)
        self.run_test("Rate Limiting", self.test_rate_limiting)
        # The rate-limiting test deliberately exhausts the request budget;
        # wait for the window to expire so the next test is not answered 429.
        print_info(f"Waiting {self.RATE_LIMIT_WINDOW_SECS}s for the rate-limit window to reset...")
        time.sleep(self.RATE_LIMIT_WINDOW_SECS + 1)
        self.run_test("Large Request Handling", self.test_large_request)
        self.print_results()
        return all(result for _, result, _ in self.test_results)

    def print_results(self):
        """Print a per-test PASS/FAIL table and an overall summary line."""
        print_colored("\n📊 Test Results Summary", Colors.CYAN)
        print_colored("=" * 25, Colors.CYAN)
        passed = sum(1 for _, result, _ in self.test_results if result)
        total = len(self.test_results)
        for test_name, result, error in self.test_results:
            status = "PASS" if result else "FAIL"
            color = Colors.GREEN if result else Colors.RED
            print_colored(f" {status:4} | {test_name}", color)
            if error and not result:
                print_colored(f" | Error: {error}", Colors.RED)
        print_colored(f"\n🎯 Results: {passed}/{total} tests passed",
                      Colors.GREEN if passed == total else Colors.YELLOW)
        if passed == total:
            print_success("All tests passed! 🎉 WiseGate is working correctly.")
        else:
            print_error(f"{total - passed} test(s) failed. Please check the issues above.")

    def cleanup(self):
        """Stop WiseGate and the backend server. Safe to call more than once."""
        print_info("Cleaning up...")
        # Detach the handles first so a re-entrant call (e.g. from a signal
        # handler firing mid-cleanup) finds nothing left to do.
        process, self.wisegate_process = self.wisegate_process, None
        if process is not None:
            process.terminate()
            try:
                process.wait(timeout=5)
            except subprocess.TimeoutExpired:
                process.kill()
                process.wait()  # reap the killed process
        server, self.backend_server = self.backend_server, None
        thread, self.backend_thread = self.backend_thread, None
        if server is not None:
            # shutdown() blocks until serve_forever() returns, so only call it
            # while the serving thread is actually running.
            if thread is not None and thread.is_alive():
                server.shutdown()
            server.server_close()
def main():
    """Run the WiseGate test suite and exit with 0 on success, non-zero otherwise."""
    env = TestEnvironment()

    def signal_handler(sig, frame):
        # Tear down WiseGate and the backend before exiting so no stray
        # process or bound port is left behind.
        print_colored("\n🛑 Test interrupted by user", Colors.YELLOW)
        env.cleanup()
        # Conventional "terminated by signal" status (130 for SIGINT) so that
        # an interrupted run is never mistaken for a passing one.
        sys.exit(128 + sig)

    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    success = False
    try:
        success = env.run_all_tests()
    except Exception as e:
        print_error(f"Test suite failed: {e}")
    finally:
        env.cleanup()
    sys.exit(0 if success else 1)


if __name__ == "__main__":
    main()