use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::VecDeque;
use std::env;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Mutex;
use tokio::time::interval;
use uuid::Uuid;
/// Runtime settings for the OpenSearch event-shipping pipeline.
///
/// Built either via `Default` (disabled, safe no-op) or from environment
/// variables via `ObservabilityConfig::from_env`.
#[derive(Clone, Debug, Deserialize)]
pub struct ObservabilityConfig {
/// Master switch; when `false` every client operation is a no-op.
pub enabled: bool,
/// Base URL of the OpenSearch endpoint; required at flush time when enabled.
pub opensearch_url: Option<String>,
/// Bearer token sent in the `Authorization` header; required when enabled.
pub api_token: Option<String>,
/// Index-name prefix; daily indices are named `<prefix>-YYYY.MM.DD`.
pub index_prefix: String,
/// Event count that triggers an inline flush, and the maximum number of
/// events drained per bulk request.
pub batch_size: usize,
/// Period of the background flush task, in seconds.
pub flush_interval_secs: u64,
/// Maximum events held in memory; the oldest event is dropped beyond this.
pub buffer_size: usize,
}
impl Default for ObservabilityConfig {
fn default() -> Self {
Self {
enabled: false,
opensearch_url: None,
api_token: None,
index_prefix: "pg-api".to_string(),
batch_size: 100,
flush_interval_secs: 5,
buffer_size: 10000,
}
}
}
impl ObservabilityConfig {
    /// Build a configuration from environment variables.
    ///
    /// `OPENSEARCH_ENABLED` is consulted first; when it is unset, unparsable,
    /// or not `true`, the disabled `Default` config is returned without
    /// reading any other variables. Otherwise each setting falls back to the
    /// same value `Default::default()` uses when its variable is missing or
    /// fails to parse.
    pub fn from_env() -> Self {
        let enabled = Self::env_or("OPENSEARCH_ENABLED", false);
        if !enabled {
            return Self::default();
        }
        Self {
            enabled,
            opensearch_url: env::var("OPENSEARCH_API_URL").ok(),
            api_token: env::var("OPENSEARCH_API_TOKEN").ok(),
            index_prefix: env::var("OPENSEARCH_INDEX_PREFIX")
                .unwrap_or_else(|_| "pg-api".to_string()),
            batch_size: Self::env_or("OPENSEARCH_BATCH_SIZE", 100),
            flush_interval_secs: Self::env_or("OPENSEARCH_FLUSH_INTERVAL", 5),
            // Previously hard-coded to 10 000; now overridable while keeping
            // the original value as the default (backward compatible).
            buffer_size: Self::env_or("OPENSEARCH_BUFFER_SIZE", 10_000),
        }
    }

    /// Read `name` from the environment and parse it, returning `default`
    /// when the variable is absent or does not parse as `T`.
    fn env_or<T: std::str::FromStr>(name: &str, default: T) -> T {
        env::var(name)
            .ok()
            .and_then(|raw| raw.parse().ok())
            .unwrap_or(default)
    }
}
/// Generic event envelope shipped to OpenSearch; the `data` payload's fields
/// are flattened alongside the fixed metadata at serialization time.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ObservabilityEvent {
/// Event time, serialized as the conventional `@timestamp` field.
#[serde(rename = "@timestamp")]
pub timestamp: DateTime<Utc>,
/// Discriminator for the kind of event (free-form string).
pub event_type: String,
/// Request-scoped id used to correlate events across records.
pub correlation_id: String,
/// Emitting service name.
pub service: String,
/// Deployment environment label (e.g. prod/staging — not set in this file).
pub environment: String,
/// Hostname of the emitting machine.
pub host: String,
/// Arbitrary extra fields, merged into the top-level JSON object.
#[serde(flatten)]
pub data: serde_json::Value,
}
/// Per-request metric record produced by `metrics_middleware`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct MetricEvent {
/// Request completion time, serialized as `@timestamp`.
#[serde(rename = "@timestamp")]
pub timestamp: DateTime<Utc>,
/// Always `"api_request"` when emitted by the middleware in this file.
pub event_type: String,
/// Id generated per request and also stored in the request extensions.
pub correlation_id: String,
/// HTTP method of the request.
pub method: String,
/// Request URI path (query string not included).
pub path: String,
/// HTTP status code of the response.
pub status_code: u16,
/// Wall-clock request duration in milliseconds.
pub duration_ms: u64,
// The middleware in this file leaves the four fields below as None;
// presumably richer callers fill them in — verify against other emitters.
pub client_ip: Option<String>,
pub api_key_id: Option<String>,
pub database: Option<String>,
pub query_type: Option<String>,
/// `Some("HTTP <code>")` for status >= 400, otherwise `None`.
pub error: Option<String>,
}
/// Structured log record for shipping application logs to OpenSearch.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct LogEvent {
/// Log emission time, serialized as `@timestamp`.
#[serde(rename = "@timestamp")]
pub timestamp: DateTime<Utc>,
/// Discriminator for the kind of event (free-form string).
pub event_type: String,
/// Id correlating this log line with a request's other events.
pub correlation_id: String,
/// Severity level (free-form string, e.g. "info"/"error").
pub level: String,
/// Human-readable log message.
pub message: String,
/// Source module that produced the log line.
pub module: String,
/// Optional structured context attached to the log line.
pub fields: Option<serde_json::Value>,
}
/// Buffered, batching shipper of JSON events to an OpenSearch `_bulk`
/// endpoint. Events accumulate in an in-memory queue and are drained by an
/// inline flush (when `batch_size` is reached) or a periodic background task.
pub struct ObservabilityClient {
// Settings controlling batching, buffering, and the target endpoint.
config: ObservabilityConfig,
// Shared FIFO of serialized events; guarded by an async mutex because it
// is held across code paths inside async fns.
buffer: Arc<Mutex<VecDeque<serde_json::Value>>>,
// Reused HTTP client for bulk POSTs.
http_client: reqwest::Client,
}
impl ObservabilityClient {
    /// Construct a client from `config`.
    ///
    /// The HTTP client carries a 10-second request timeout; the event buffer
    /// is pre-allocated to `config.buffer_size` entries.
    pub fn new(config: ObservabilityConfig) -> Self {
        let http_client = reqwest::Client::builder()
            .timeout(Duration::from_secs(10))
            .build()
            // NOTE(review): on builder failure this silently falls back to a
            // default client *without* the timeout — confirm acceptable.
            .unwrap_or_default();
        Self {
            // Fix: capacity was hard-coded to 10_000, ignoring the
            // configured buffer_size.
            buffer: Arc::new(Mutex::new(VecDeque::with_capacity(config.buffer_size))),
            config,
            http_client,
        }
    }

    /// Spawn the periodic background flush task.
    ///
    /// No-op when observability is disabled. The task runs forever, flushing
    /// every `flush_interval_secs` seconds and logging (not propagating)
    /// flush errors.
    pub fn start_flush_task(self: Arc<Self>) {
        if !self.config.enabled {
            tracing::info!("Observability disabled, not starting flush task");
            return;
        }
        let client = self.clone();
        tokio::spawn(async move {
            let mut flush_interval =
                interval(Duration::from_secs(client.config.flush_interval_secs));
            loop {
                flush_interval.tick().await;
                if let Err(e) = client.flush().await {
                    tracing::error!("Failed to flush observability buffer: {}", e);
                }
            }
        });
        tracing::info!("Observability flush task started");
    }

    /// Serialize `event` and append it to the buffer.
    ///
    /// When the buffer is full the *oldest* event is dropped so memory stays
    /// bounded. Once `batch_size` events are queued, a flush is attempted
    /// inline and its error (if any) is returned. Always `Ok(())` when
    /// observability is disabled.
    pub async fn add_event(&self, event: impl Serialize) -> Result<(), String> {
        if !self.config.enabled {
            return Ok(());
        }
        let json_value = serde_json::to_value(event)
            .map_err(|e| format!("Failed to serialize event: {}", e))?;
        let mut buffer = self.buffer.lock().await;
        if buffer.len() >= self.config.buffer_size {
            // Drop-oldest policy keeps memory bounded under backpressure.
            buffer.pop_front();
        }
        buffer.push_back(json_value);
        if buffer.len() >= self.config.batch_size {
            // Release the lock before flushing; flush() re-acquires it.
            drop(buffer);
            self.flush().await?;
        }
        Ok(())
    }

    /// Drain up to `batch_size` events and POST them as NDJSON to the
    /// OpenSearch `_bulk` endpoint, indexing into `<prefix>-YYYY.MM.DD`.
    ///
    /// # Errors
    /// Returns `Err` when the URL/token are missing, the request fails, or
    /// the server answers with a non-2xx status. On failure the drained
    /// events are re-queued at the front of the buffer; if they no longer
    /// fit, the oldest overflow is discarded (matching `add_event`'s
    /// drop-oldest policy).
    pub async fn flush(&self) -> Result<(), String> {
        if !self.config.enabled {
            return Ok(());
        }
        let opensearch_url = self.config.opensearch_url.as_ref()
            .ok_or_else(|| "OpenSearch URL not configured".to_string())?;
        let api_token = self.config.api_token.as_ref()
            .ok_or_else(|| "OpenSearch API token not configured".to_string())?;
        let mut buffer = self.buffer.lock().await;
        if buffer.is_empty() {
            return Ok(());
        }
        // Drain at most one batch; anything left waits for the next flush.
        let take = self.config.batch_size.min(buffer.len());
        let batch: Vec<_> = buffer.drain(..take).collect();
        drop(buffer);
        // Daily indices, e.g. "pg-api-2024.01.31".
        let index_name = format!(
            "{}-{}",
            self.config.index_prefix,
            Utc::now().format("%Y.%m.%d")
        );
        let url = format!("{}/_bulk", opensearch_url);
        // Bulk body is NDJSON: one action line + one document line per event.
        let mut bulk_body = String::new();
        for event in &batch {
            bulk_body.push_str(&format!(r#"{{"index":{{"_index":"{}"}}}}"#, index_name));
            bulk_body.push('\n');
            bulk_body.push_str(
                &serde_json::to_string(event)
                    .expect("serde_json::Value serialization cannot fail"),
            );
            bulk_body.push('\n');
        }
        let response = self.http_client
            .post(&url)
            .header("Authorization", format!("Bearer {}", api_token))
            .header("Content-Type", "application/x-ndjson")
            .body(bulk_body)
            .send()
            .await
            .map_err(|e| format!("Failed to send events to OpenSearch: {}", e))?;
        // NOTE(review): a 2xx bulk response can still carry per-item
        // failures ("errors": true in the body); those are not detected here.
        if !response.status().is_success() {
            let status = response.status();
            let body = response.text().await.unwrap_or_else(|_| "Unknown error".to_string());
            let mut buffer = self.buffer.lock().await;
            // Fix: the re-queue could previously grow the buffer past
            // buffer_size. Keep only the newest events that fit.
            let room = self.config.buffer_size.saturating_sub(buffer.len());
            let skip = batch.len().saturating_sub(room);
            // Push in reverse so the batch's original order is preserved
            // at the front of the queue.
            for event in batch.into_iter().skip(skip).rev() {
                buffer.push_front(event);
            }
            return Err(format!("OpenSearch returned error {}: {}", status, body));
        }
        tracing::debug!("Flushed {} events to OpenSearch", batch.len());
        Ok(())
    }

    /// Best-effort recording of an API request metric; errors are only
    /// logged at debug level so the request path is never disturbed.
    pub async fn record_metric(&self, metric: MetricEvent) {
        if let Err(e) = self.add_event(metric).await {
            tracing::debug!("Failed to record metric: {}", e);
        }
    }

    /// Best-effort recording of a structured log event.
    #[allow(dead_code)]
    pub async fn record_log(&self, log: LogEvent) {
        if let Err(e) = self.add_event(log).await {
            tracing::debug!("Failed to record log: {}", e);
        }
    }
}
use axum::{
extract::{Request, State},
middleware::Next,
response::Response,
};
use std::time::Instant;
/// Axum middleware that times each request, tags it with a fresh
/// correlation id (stored in the request extensions for downstream
/// handlers), runs the inner service, and records an `api_request`
/// metric before returning the response unchanged.
pub async fn metrics_middleware(
    State(client): State<Arc<ObservabilityClient>>,
    mut request: Request,
    next: Next,
) -> Response {
    let started_at = Instant::now();
    let correlation_id = Uuid::new_v4().to_string();
    let method = request.method().to_string();
    let path = request.uri().path().to_string();
    // Expose the correlation id to downstream extractors/handlers.
    request.extensions_mut().insert(correlation_id.clone());
    let response = next.run(request).await;
    let status_code = response.status().as_u16();
    let error = if status_code >= 400 {
        Some(format!("HTTP {}", status_code))
    } else {
        None
    };
    let metric = MetricEvent {
        timestamp: Utc::now(),
        event_type: "api_request".to_string(),
        correlation_id,
        method,
        path,
        status_code,
        duration_ms: started_at.elapsed().as_millis() as u64,
        client_ip: None,
        api_key_id: None,
        database: None,
        query_type: None,
        error,
    };
    client.record_metric(metric).await;
    response
}
/// Produce a fresh random (v4) UUID string for correlating events.
#[allow(dead_code)]
pub fn generate_correlation_id() -> String {
    let id = Uuid::new_v4();
    id.to_string()
}
/// Fetch the correlation id that `metrics_middleware` stashed in the
/// request extensions, if any.
#[allow(dead_code)]
pub fn get_correlation_id(request: &Request) -> Option<String> {
    request.extensions().get::<String>().map(|id| id.clone())
}
#[cfg(test)]
mod tests {
use super::*;
// NOTE(review): this assumes OPENSEARCH_ENABLED is unset (or not "true")
// in the test environment, in which case from_env returns Default.
// It will fail if the variable is exported — consider isolating env state.
#[test]
fn test_config_from_env() {
let config = ObservabilityConfig::from_env();
assert!(!config.enabled);
assert_eq!(config.index_prefix, "pg-api");
assert_eq!(config.batch_size, 100);
}
// One event below batch_size must stay buffered and not trigger a flush
// (a flush here would attempt a real HTTP call to localhost:9200).
#[tokio::test]
async fn test_buffer_management() {
let config = ObservabilityConfig {
enabled: true,
opensearch_url: Some("http://localhost:9200".to_string()),
api_token: Some("test-token".to_string()),
index_prefix: "test".to_string(),
batch_size: 10,
flush_interval_secs: 60,
buffer_size: 100,
};
let client = ObservabilityClient::new(config);
let event = serde_json::json!({
"test": "event",
"timestamp": Utc::now()
});
// add_event succeeds and leaves exactly one queued event.
assert!(client.add_event(event).await.is_ok());
let buffer = client.buffer.lock().await;
assert_eq!(buffer.len(), 1);
}
}