pg-api 0.1.0

A high-performance PostgreSQL REST API driver with rate limiting, connection pooling, and observability
#!/usr/bin/env python3
"""
Load testing script for pg-api using locust
"""

import json
import time
import random
import argparse
from locust import HttpUser, task, between
from datetime import datetime

class PgApiUser(HttpUser):
    """Simulated regular API consumer for pg-api load testing.

    Each task hits one endpoint and marks the request as failed unless the
    server answers with HTTP 200. The ``@task`` weights set how often each
    task is picked relative to the others.
    """

    wait_time = between(0.1, 1.0)  # Wait between requests

    def on_start(self):
        """Prepare the auth headers and the pool of user ids used by tasks."""
        self.api_token = "sk_test_load_testing_token"
        self.headers = {
            "Authorization": f"Bearer {self.api_token}",
            "Content-Type": "application/json"
        }
        self.user_ids = list(range(1, 1001))  # Simulate 1000 users

    @staticmethod
    def _expect_ok(response):
        """Mark a ``catch_response`` response as success only on HTTP 200.

        Args:
            response: The response context object yielded by
                ``self.client.<method>(..., catch_response=True)``.
        """
        if response.status_code == 200:
            response.success()
        else:
            response.failure(f"Got status code {response.status_code}")

    @task(10)
    def single_query(self):
        """Run one parameterized SELECT for a random user id via /v1/query."""
        user_id = random.choice(self.user_ids)
        payload = {
            "query": "SELECT * FROM users WHERE id = $1",
            "params": [user_id],
            "database": "main"
        }

        with self.client.post(
            "/v1/query",
            json=payload,
            headers=self.headers,
            catch_response=True
        ) as response:
            self._expect_ok(response)

    @task(5)
    def batch_query(self):
        """Send a batch of 2-10 SELECTs for random user ids via /v1/batch."""
        batch_size = random.randint(2, 10)
        queries = [
            {
                "query": "SELECT * FROM users WHERE id = $1",
                "params": [random.choice(self.user_ids)]
            }
            for _ in range(batch_size)
        ]

        payload = {
            "queries": queries,
            "database": "main"
        }

        with self.client.post(
            "/v1/batch",
            json=payload,
            headers=self.headers,
            catch_response=True
        ) as response:
            self._expect_ok(response)

    @task(2)
    def transaction_query(self):
        """Post a two-statement debit (UPDATE + INSERT) via /v1/transaction."""
        user_id = random.choice(self.user_ids)
        amount = random.uniform(10, 1000)

        queries = [
            {
                "query": "UPDATE accounts SET balance = balance - $1 WHERE user_id = $2",
                "params": [amount, user_id]
            },
            {
                "query": "INSERT INTO transactions (user_id, amount, type) VALUES ($1, $2, $3)",
                "params": [user_id, amount, "debit"]
            }
        ]

        payload = {
            "queries": queries,
            "database": "main"
        }

        with self.client.post(
            "/v1/transaction",
            json=payload,
            headers=self.headers,
            catch_response=True
        ) as response:
            self._expect_ok(response)

    @task(3)
    def list_tables(self):
        """Request the table listing of the "main" database."""
        with self.client.get(
            "/v1/databases/main/tables",
            headers=self.headers,
            catch_response=True
        ) as response:
            self._expect_ok(response)

    @task(1)
    def health_check(self):
        """Probe /health; sent without the auth headers."""
        with self.client.get(
            "/health",
            catch_response=True
        ) as response:
            self._expect_ok(response)


class AdminUser(HttpUser):
    """Load-test persona that exercises the admin-style read endpoints."""

    wait_time = between(2, 5)  # Admins make fewer requests

    def on_start(self):
        """Store the admin token and build the headers sent with each call."""
        token = "sk_test_admin_token"
        self.api_token = token
        self.headers = {
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json"
        }

    @task(1)
    def get_account_info(self):
        """Fetch /v1/account and report failure on any non-200 status."""
        request = self.client.get(
            "/v1/account",
            headers=self.headers,
            catch_response=True
        )
        with request as resp:
            if resp.status_code != 200:
                resp.failure(f"Got status code {resp.status_code}")
            else:
                resp.success()

    @task(1)
    def get_database_schema(self):
        """Fetch the schema of the "main" database; non-200 counts as failure."""
        request = self.client.get(
            "/v1/databases/main/schema",
            headers=self.headers,
            catch_response=True
        )
        with request as resp:
            if resp.status_code != 200:
                resp.failure(f"Got status code {resp.status_code}")
            else:
                resp.success()


def _report_load_test_results(results, wall_time):
    """Print latency statistics and an error breakdown for one test round.

    Args:
        results: List of per-request dicts. Each has "success" and "time",
            plus "status" (HTTP response received) or "error" (request failed).
        wall_time: Wall-clock seconds the whole round took; used as the
            denominator for throughput.
    """
    import statistics
    from collections import Counter

    successful = [r for r in results if r["success"]]
    failed = [r for r in results if not r["success"]]
    times = [r["time"] for r in successful]

    if times:
        print(f"  ✅ Success: {len(successful)}/{len(results)}")
        print(f"  ⏱️  Min: {min(times):.3f}s")
        print(f"  ⏱️  Max: {max(times):.3f}s")
        print(f"  ⏱️  Avg: {statistics.mean(times):.3f}s")
        print(f"  ⏱️  Median: {statistics.median(times):.3f}s")
        if len(times) > 1:
            print(f"  ⏱️  StdDev: {statistics.stdev(times):.3f}s")
        # Throughput is completed requests per wall-clock second. Dividing by
        # the sum of individual latencies would under-report it by roughly
        # the concurrency factor, because concurrent requests overlap in time.
        if wall_time > 0:
            print(f"  📈 Throughput: {len(successful) / wall_time:.1f} req/s")

    if failed:
        print(f"  ❌ Failed: {len(failed)}")
        errors = Counter(
            r.get("error", f"Status {r.get('status', 'unknown')}")
            for r in failed
        )
        for error, count in errors.items():
            print(f"    - {error}: {count}")


def run_standalone_test(
    base_url="http://localhost:8580",
    api_token="sk_test_load_testing_token",
    concurrency_levels=(1, 10, 50, 100),
    requests_per_worker=10,
):
    """Run a simple load test without locust.

    For each concurrency level, a thread pool of that size issues
    ``concurrency * requests_per_worker`` POSTs of ``SELECT 1`` to
    ``/v1/query`` and the results are summarized on stdout.

    Args:
        base_url: Scheme, host and port of the pg-api instance under test.
        api_token: Bearer token sent in the Authorization header.
        concurrency_levels: Iterable of thread-pool sizes to test in turn.
        requests_per_worker: Requests submitted per worker thread per level.

    Raises:
        ImportError: If the ``requests`` package is not installed.
    """
    import requests
    import concurrent.futures

    headers = {
        "Authorization": f"Bearer {api_token}",
        "Content-Type": "application/json"
    }

    def make_request():
        """Make a single request and measure time"""
        # perf_counter is monotonic, so intervals are immune to clock changes.
        start = time.perf_counter()
        try:
            response = requests.post(
                f"{base_url}/v1/query",
                json={
                    "query": "SELECT 1 as test",
                    "params": [],
                    "database": "main"
                },
                headers=headers,
                timeout=5
            )
        except requests.RequestException as e:
            # Only transport-level failures (timeout, connection refused, ...)
            # count as failed requests; programming errors still propagate.
            return {
                "success": False,
                "time": time.perf_counter() - start,
                "error": str(e)
            }
        elapsed = time.perf_counter() - start
        return {
            "success": response.status_code == 200,
            "time": elapsed,
            "status": response.status_code
        }

    print("🚀 Running standalone load test...")
    print(f"Target: {base_url}")
    print("-" * 50)

    # Test different concurrency levels
    for concurrency in concurrency_levels:
        print(f"\n📊 Testing with {concurrency} concurrent requests...")

        round_start = time.perf_counter()
        with concurrent.futures.ThreadPoolExecutor(max_workers=concurrency) as executor:
            futures = [
                executor.submit(make_request)
                for _ in range(concurrency * requests_per_worker)
            ]
            results = [
                future.result()
                for future in concurrent.futures.as_completed(futures)
            ]
        wall_time = time.perf_counter() - round_start

        _report_load_test_results(results, wall_time)

    print("\n✅ Load test complete!")


if __name__ == "__main__":
    # Command-line entry point: either run the built-in standalone test or
    # print instructions for driving the file with locust.
    cli = argparse.ArgumentParser(description="Load test pg-api")
    cli.add_argument(
        "--standalone",
        action="store_true",
        help="Run standalone test without locust"
    )
    options = cli.parse_args()

    if not options.standalone:
        usage_lines = (
            "To run with locust:",
            "  1. Install: pip install locust",
            "  2. Run: locust -f load_test.py --host=http://localhost:8580",
            "  3. Open: http://localhost:8089",
            "\nOr run standalone test:",
            "  python load_test.py --standalone",
        )
        for usage_line in usage_lines:
            print(usage_line)
    else:
        run_standalone_test()