//! pg-api 0.1.0
//!
//! A high-performance PostgreSQL REST API driver with rate limiting,
//! connection pooling, and observability.
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::VecDeque;
use std::env;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Mutex;
use tokio::time::interval;
use uuid::Uuid;

/// Observability configuration, typically loaded from `OPENSEARCH_*`
/// environment variables via [`ObservabilityConfig::from_env`].
#[derive(Clone, Debug, Deserialize)]
pub struct ObservabilityConfig {
    /// Master switch; when `false` the client buffers and sends nothing.
    pub enabled: bool,
    /// Base URL of the OpenSearch endpoint (required when `enabled` is true).
    pub opensearch_url: Option<String>,
    /// Bearer token sent in the `Authorization` header (required when `enabled` is true).
    pub api_token: Option<String>,
    /// Prefix of the daily index name (`<prefix>-YYYY.MM.DD`).
    pub index_prefix: String,
    /// Max events per bulk request; also the pending-event count that triggers
    /// an eager flush from `add_event`.
    pub batch_size: usize,
    /// Period of the background flush task, in seconds.
    pub flush_interval_secs: u64,
    /// Max buffered events; the oldest event is dropped beyond this.
    pub buffer_size: usize,
}

impl Default for ObservabilityConfig {
    fn default() -> Self {
        Self {
            enabled: false,
            opensearch_url: None,
            api_token: None,
            index_prefix: "pg-api".to_string(),
            batch_size: 100,
            flush_interval_secs: 5,
            buffer_size: 10000,
        }
    }
}

impl ObservabilityConfig {
    /// Build a configuration from `OPENSEARCH_*` environment variables.
    ///
    /// Returns the all-default (disabled) configuration unless
    /// `OPENSEARCH_ENABLED` parses as `true`. Unset or malformed values fall
    /// back to the same defaults as [`Default`]. `OPENSEARCH_BUFFER_SIZE` is
    /// also honored (previously the buffer size was hard-coded to 10000;
    /// that remains the default, so this is backward-compatible).
    pub fn from_env() -> Self {
        // Read an env var and parse it, falling back to `default` when the
        // variable is unset or unparsable. Replaces three copies of the same
        // unwrap_or_else/parse/unwrap_or chain.
        fn env_parse<T: std::str::FromStr>(key: &str, default: T) -> T {
            env::var(key)
                .ok()
                .and_then(|v| v.parse().ok())
                .unwrap_or(default)
        }

        if !env_parse("OPENSEARCH_ENABLED", false) {
            return Self::default();
        }

        Self {
            enabled: true,
            opensearch_url: env::var("OPENSEARCH_API_URL").ok(),
            api_token: env::var("OPENSEARCH_API_TOKEN").ok(),
            index_prefix: env::var("OPENSEARCH_INDEX_PREFIX")
                .unwrap_or_else(|_| "pg-api".to_string()),
            batch_size: env_parse("OPENSEARCH_BATCH_SIZE", 100),
            flush_interval_secs: env_parse("OPENSEARCH_FLUSH_INTERVAL", 5),
            buffer_size: env_parse("OPENSEARCH_BUFFER_SIZE", 10_000),
        }
    }
}

/// Base event structure for all observability events
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ObservabilityEvent {
    /// Event time; serialized as `@timestamp` for OpenSearch time-based indices.
    #[serde(rename = "@timestamp")]
    pub timestamp: DateTime<Utc>,
    /// Event discriminator (e.g. "api_request").
    pub event_type: String,
    /// Request-scoped ID used to correlate related events.
    pub correlation_id: String,
    pub service: String,
    pub environment: String,
    pub host: String,
    /// Arbitrary event payload, flattened into the top-level JSON document.
    #[serde(flatten)]
    pub data: serde_json::Value,
}

/// Metric event for API requests
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct MetricEvent {
    /// Event time; serialized as `@timestamp` for OpenSearch time-based indices.
    #[serde(rename = "@timestamp")]
    pub timestamp: DateTime<Utc>,
    /// Event discriminator (set to "api_request" by the middleware).
    pub event_type: String,
    /// Request-scoped ID used to correlate related events.
    pub correlation_id: String,
    /// HTTP method of the request (e.g. "GET").
    pub method: String,
    /// Request URI path.
    pub path: String,
    /// HTTP response status.
    pub status_code: u16,
    /// Wall-clock request duration, in milliseconds.
    pub duration_ms: u64,
    // The remaining fields are optional context; the middleware currently
    // leaves them as None (see metrics_middleware).
    pub client_ip: Option<String>,
    pub api_key_id: Option<String>,
    pub database: Option<String>,
    pub query_type: Option<String>,
    /// Error summary; populated with "HTTP <code>" for 4xx/5xx responses.
    pub error: Option<String>,
}

/// Log event for structured logging
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct LogEvent {
    /// Event time; serialized as `@timestamp` for OpenSearch time-based indices.
    #[serde(rename = "@timestamp")]
    pub timestamp: DateTime<Utc>,
    /// Event discriminator.
    pub event_type: String,
    /// Request-scoped ID used to correlate related events.
    pub correlation_id: String,
    /// Log severity (e.g. "info", "error").
    pub level: String,
    /// Human-readable log message.
    pub message: String,
    /// Originating module name.
    pub module: String,
    /// Optional structured key/value context.
    pub fields: Option<serde_json::Value>,
}

/// Observability client for sending events to OpenSearch
pub struct ObservabilityClient {
    config: ObservabilityConfig,
    /// In-memory FIFO of pending events; shared with the background flush task.
    buffer: Arc<Mutex<VecDeque<serde_json::Value>>>,
    /// Reused HTTP client for bulk-index requests.
    http_client: reqwest::Client,
}

impl ObservabilityClient {
    /// Create a client from `config`.
    ///
    /// The HTTP client is built with a 10-second request timeout; if the
    /// builder fails (e.g. TLS backend initialization), we fall back to the
    /// default client. NOTE(review): the fallback client has no timeout.
    pub fn new(config: ObservabilityConfig) -> Self {
        let http_client = reqwest::Client::builder()
            .timeout(Duration::from_secs(10))
            .build()
            .unwrap_or_default();

        // Fix: the buffer capacity previously ignored `config.buffer_size`
        // and was hard-coded to 10000.
        let capacity = config.buffer_size;

        Self {
            config,
            buffer: Arc::new(Mutex::new(VecDeque::with_capacity(capacity))),
            http_client,
        }
    }

    /// Start the background flush task. No-op when observability is disabled.
    pub fn start_flush_task(self: Arc<Self>) {
        if !self.config.enabled {
            tracing::info!("Observability disabled, not starting flush task");
            return;
        }

        let client = self.clone();
        tokio::spawn(async move {
            let mut flush_interval =
                interval(Duration::from_secs(client.config.flush_interval_secs));

            loop {
                flush_interval.tick().await;
                if let Err(e) = client.flush().await {
                    tracing::error!("Failed to flush observability buffer: {}", e);
                }
            }
        });

        tracing::info!("Observability flush task started");
    }

    /// Add an event to the buffer. When the buffer is at capacity the oldest
    /// event is dropped; once `batch_size` events are pending, an eager flush
    /// is triggered.
    ///
    /// # Errors
    /// Returns an error if the event cannot be serialized or a triggered
    /// flush fails.
    pub async fn add_event(&self, event: impl Serialize) -> Result<(), String> {
        if !self.config.enabled {
            return Ok(());
        }

        let json_value = serde_json::to_value(event)
            .map_err(|e| format!("Failed to serialize event: {}", e))?;

        let mut buffer = self.buffer.lock().await;

        // Drop the oldest event if the buffer is full.
        if buffer.len() >= self.config.buffer_size {
            buffer.pop_front();
        }

        buffer.push_back(json_value);

        // Trigger a flush once a full batch is pending.
        if buffer.len() >= self.config.batch_size {
            drop(buffer); // Release lock before flush
            self.flush().await?;
        }

        Ok(())
    }

    /// Flush up to `batch_size` buffered events to OpenSearch via the bulk
    /// API. On failure — transport error or non-success status — the batch is
    /// pushed back to the front of the buffer, in order, for a later retry.
    ///
    /// # Errors
    /// Returns an error when the URL/token are not configured, the request
    /// fails, or OpenSearch responds with a non-success status.
    pub async fn flush(&self) -> Result<(), String> {
        if !self.config.enabled {
            return Ok(());
        }

        let opensearch_url = self
            .config
            .opensearch_url
            .as_ref()
            .ok_or_else(|| "OpenSearch URL not configured".to_string())?;

        let api_token = self
            .config
            .api_token
            .as_ref()
            .ok_or_else(|| "OpenSearch API token not configured".to_string())?;

        let mut buffer = self.buffer.lock().await;

        if buffer.is_empty() {
            return Ok(());
        }

        // Take up to batch_size events (drain replaces the pop_front loop).
        let take = self.config.batch_size.min(buffer.len());
        let batch: Vec<_> = buffer.drain(..take).collect();

        drop(buffer); // Release lock during network call

        // Daily index: "<prefix>-YYYY.MM.DD".
        let index_name = format!(
            "{}-{}",
            self.config.index_prefix,
            Utc::now().format("%Y.%m.%d")
        );

        // Use OpenSearch bulk API format
        let url = format!("{}/_bulk", opensearch_url);

        // Convert events to OpenSearch bulk format (NDJSON):
        // one action line followed by one document line per event.
        let mut bulk_body = String::new();
        for event in &batch {
            bulk_body.push_str(&format!(r#"{{"index":{{"_index":"{}"}}}}"#, index_name));
            bulk_body.push('\n');
            bulk_body.push_str(
                // Invariant: a serde_json::Value always serializes.
                &serde_json::to_string(event)
                    .expect("serde_json::Value serialization cannot fail"),
            );
            bulk_body.push('\n');
        }

        let response = match self
            .http_client
            .post(&url)
            .header("Authorization", format!("Bearer {}", api_token))
            .header("Content-Type", "application/x-ndjson")
            .body(bulk_body)
            .send()
            .await
        {
            Ok(resp) => resp,
            Err(e) => {
                // Fix: previously a transport error dropped the whole batch;
                // re-queue it just like the HTTP-error path below.
                self.requeue(batch).await;
                return Err(format!("Failed to send events to OpenSearch: {}", e));
            }
        };

        if !response.status().is_success() {
            let status = response.status();
            let body = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());

            // Re-add events to buffer on failure
            self.requeue(batch).await;
            return Err(format!("OpenSearch returned error {}: {}", status, body));
        }

        tracing::debug!("Flushed {} events to OpenSearch", batch.len());
        Ok(())
    }

    /// Push a failed batch back to the front of the buffer, preserving the
    /// original event order.
    async fn requeue(&self, batch: Vec<serde_json::Value>) {
        let mut buffer = self.buffer.lock().await;
        for event in batch.into_iter().rev() {
            buffer.push_front(event);
        }
    }

    /// Record an API request metric. Best-effort: failures are only logged.
    pub async fn record_metric(&self, metric: MetricEvent) {
        if let Err(e) = self.add_event(metric).await {
            tracing::debug!("Failed to record metric: {}", e);
        }
    }

    /// Record a structured log event. Best-effort: failures are only logged.
    #[allow(dead_code)]
    pub async fn record_log(&self, log: LogEvent) {
        if let Err(e) = self.add_event(log).await {
            tracing::debug!("Failed to record log: {}", e);
        }
    }
}

// Middleware for tracking request metrics
use axum::{
    extract::{Request, State},
    middleware::Next,
    response::Response,
};
use std::time::Instant;

/// Axum middleware that times each request, tags it with a fresh correlation
/// ID (stored in the request extensions for downstream handlers), and records
/// an `api_request` metric once the response is ready. Responses with a 4xx
/// or 5xx status get an "HTTP <code>" error summary.
pub async fn metrics_middleware(
    State(client): State<Arc<ObservabilityClient>>,
    mut request: Request,
    next: Next,
) -> Response {
    let started_at = Instant::now();
    let correlation_id = Uuid::new_v4().to_string();
    let method = request.method().to_string();
    let path = request.uri().path().to_string();

    // Make the correlation ID available to downstream handlers.
    request.extensions_mut().insert(correlation_id.clone());

    // Run the rest of the stack.
    let response = next.run(request).await;

    let duration_ms = started_at.elapsed().as_millis() as u64;
    let status_code = response.status().as_u16();
    let error = (status_code >= 400).then(|| format!("HTTP {}", status_code));

    client
        .record_metric(MetricEvent {
            timestamp: Utc::now(),
            event_type: "api_request".to_string(),
            correlation_id,
            method,
            path,
            status_code,
            duration_ms,
            client_ip: None,  // would be extracted from headers
            api_key_id: None, // would be extracted from auth
            database: None,
            query_type: None,
            error,
        })
        .await;

    response
}

/// Produce a fresh correlation ID: a random UUID v4 rendered as a string.
#[allow(dead_code)]
pub fn generate_correlation_id() -> String {
    let id = Uuid::new_v4();
    id.to_string()
}

/// Look up the correlation ID previously stashed in the request extensions
/// (by `metrics_middleware`), if present.
#[allow(dead_code)]
pub fn get_correlation_id(request: &Request) -> Option<String> {
    request.extensions().get::<String>().map(|id| id.to_owned())
}

#[cfg(test)]
mod tests {
    use super::*;

    /// With no OPENSEARCH_* variables set, `from_env` yields the defaults.
    #[test]
    fn test_config_from_env() {
        let config = ObservabilityConfig::from_env();
        assert!(!config.enabled);
        assert_eq!(config.index_prefix, "pg-api");
        assert_eq!(config.batch_size, 100);
    }

    /// An event added while enabled lands in the in-memory buffer.
    #[tokio::test]
    async fn test_buffer_management() {
        let config = ObservabilityConfig {
            enabled: true,
            opensearch_url: Some("http://localhost:9200".to_string()),
            api_token: Some("test-token".to_string()),
            index_prefix: "test".to_string(),
            batch_size: 10,
            flush_interval_secs: 60,
            buffer_size: 100,
        };
        let client = ObservabilityClient::new(config);

        let event = serde_json::json!({
            "test": "event",
            "timestamp": Utc::now()
        });
        assert!(client.add_event(event).await.is_ok());

        // The single event should now be queued.
        assert_eq!(client.buffer.lock().await.len(), 1);
    }
}