use crate::error::{Error, Result};
use crate::storage::{AsyncSledBackend, AsyncStorageBackend, StorageStats};
use crate::types::{Edge, EdgeId, Node, NodeId, SessionId};
use async_trait::async_trait;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Semaphore;
use tokio::time::timeout;
/// Tuning knobs for a `PooledAsyncBackend`.
///
/// Build one with [`PoolConfig::new`] (or `Default::default()`) and adjust it
/// with the chainable `with_*` methods.
#[derive(Debug, Clone)]
pub struct PoolConfig {
    /// Number of semaphore permits, i.e. operations allowed to run at once.
    pub max_concurrent: usize,
    /// How long, in milliseconds, an operation may wait for a permit.
    pub acquire_timeout_ms: u64,
    /// Whether usage metrics should be collected.
    pub enable_metrics: bool,
}

impl Default for PoolConfig {
    /// 100 concurrent operations, a 5 000 ms acquire timeout, metrics enabled.
    fn default() -> Self {
        Self {
            max_concurrent: 100,
            acquire_timeout_ms: 5_000,
            enable_metrics: true,
        }
    }
}

impl PoolConfig {
    /// Same as [`PoolConfig::default`].
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns this config with `max_concurrent` replaced by `max`.
    pub fn with_max_concurrent(self, max: usize) -> Self {
        Self {
            max_concurrent: max,
            ..self
        }
    }

    /// Returns this config with `acquire_timeout_ms` replaced by `timeout_ms`.
    pub fn with_timeout(self, timeout_ms: u64) -> Self {
        Self {
            acquire_timeout_ms: timeout_ms,
            ..self
        }
    }

    /// Returns this config with `enable_metrics` replaced by `enable`.
    pub fn with_metrics(self, enable: bool) -> Self {
        Self {
            enable_metrics: enable,
            ..self
        }
    }
}
/// Lock-free usage counters for a pooled storage backend.
///
/// Every counter is updated with `Ordering::Relaxed`: they are independent
/// statistics and are not used to synchronise any other data, so a
/// [`snapshot`](Self::snapshot) may combine values read at slightly different
/// instants.
#[derive(Debug)]
pub struct PoolMetrics {
    /// Operations registered through `operation_started`.
    total_operations: AtomicU64,
    /// Started operations completed with `success == true`.
    successful_operations: AtomicU64,
    /// Started operations completed with `success == false`, plus every timeout.
    failed_operations: AtomicU64,
    /// Attempts recorded through `record_timeout`; these are NOT part of
    /// `total_operations`.
    timeouts: AtomicU64,
    /// Started operations that have not completed yet.
    active_operations: AtomicUsize,
    /// Highest value `active_operations` has reached.
    peak_concurrent: AtomicUsize,
    /// Sum of all values passed to `record_wait_time`, in microseconds.
    total_wait_time_us: AtomicU64,
}

impl PoolMetrics {
    /// Creates a metrics set with every counter at zero.
    pub fn new() -> Self {
        Self {
            total_operations: AtomicU64::new(0),
            successful_operations: AtomicU64::new(0),
            failed_operations: AtomicU64::new(0),
            timeouts: AtomicU64::new(0),
            active_operations: AtomicUsize::new(0),
            peak_concurrent: AtomicUsize::new(0),
            total_wait_time_us: AtomicU64::new(0),
        }
    }

    /// Registers a new in-flight operation and updates the concurrency peak.
    fn operation_started(&self) {
        self.total_operations.fetch_add(1, Ordering::Relaxed);
        let now_active = self.active_operations.fetch_add(1, Ordering::Relaxed) + 1;
        // Atomic max: equivalent to a compare-exchange loop, without the loop.
        self.peak_concurrent.fetch_max(now_active, Ordering::Relaxed);
    }

    /// Marks one in-flight operation as finished with the given outcome.
    ///
    /// Must be paired with a preceding `operation_started`; otherwise
    /// `active_operations` would wrap below zero.
    fn operation_completed(&self, success: bool) {
        if success {
            self.successful_operations.fetch_add(1, Ordering::Relaxed);
        } else {
            self.failed_operations.fetch_add(1, Ordering::Relaxed);
        }
        self.active_operations.fetch_sub(1, Ordering::Relaxed);
    }

    /// Records an attempt that gave up waiting; it counts as a failure but is
    /// deliberately left out of `total_operations` (it never started).
    fn record_timeout(&self) {
        self.timeouts.fetch_add(1, Ordering::Relaxed);
        self.failed_operations.fetch_add(1, Ordering::Relaxed);
    }

    /// Adds `wait_time_us` microseconds to the cumulative wait time.
    fn record_wait_time(&self, wait_time_us: u64) {
        self.total_wait_time_us
            .fetch_add(wait_time_us, Ordering::Relaxed);
    }

    /// Copies the current counter values into a plain [`PoolMetricsSnapshot`].
    pub fn snapshot(&self) -> PoolMetricsSnapshot {
        PoolMetricsSnapshot {
            total_operations: self.total_operations.load(Ordering::Relaxed),
            successful_operations: self.successful_operations.load(Ordering::Relaxed),
            failed_operations: self.failed_operations.load(Ordering::Relaxed),
            timeouts: self.timeouts.load(Ordering::Relaxed),
            active_operations: self.active_operations.load(Ordering::Relaxed),
            peak_concurrent: self.peak_concurrent.load(Ordering::Relaxed),
            total_wait_time_us: self.total_wait_time_us.load(Ordering::Relaxed),
        }
    }
}

impl Default for PoolMetrics {
    fn default() -> Self {
        Self::new()
    }
}

/// Point-in-time copy of [`PoolMetrics`] with derived rates.
#[derive(Debug, Clone, Copy)]
pub struct PoolMetricsSnapshot {
    /// Operations that were started (timed-out attempts excluded).
    pub total_operations: u64,
    /// Started operations that completed successfully.
    pub successful_operations: u64,
    /// Started operations that failed, plus timed-out attempts.
    pub failed_operations: u64,
    /// Attempts that timed out before starting.
    pub timeouts: u64,
    /// Operations in flight when the snapshot was taken.
    pub active_operations: usize,
    /// Highest number of simultaneously active operations observed.
    pub peak_concurrent: usize,
    /// Cumulative recorded wait time, in microseconds.
    pub total_wait_time_us: u64,
}

impl PoolMetricsSnapshot {
    /// Every attempt: started operations plus those that timed out.
    ///
    /// `PoolMetrics::record_timeout` does not bump `total_operations`, so the
    /// two counters are disjoint and must be summed for rate denominators.
    fn attempts(&self) -> u64 {
        self.total_operations.saturating_add(self.timeouts)
    }

    /// Mean wait time per started operation, in milliseconds (0.0 if none).
    ///
    /// Wait time is recorded once per started operation, so the denominator is
    /// `total_operations`, not all attempts.
    pub fn avg_wait_time_ms(&self) -> f64 {
        if self.total_operations == 0 {
            0.0
        } else {
            (self.total_wait_time_us as f64) / (self.total_operations as f64) / 1000.0
        }
    }

    /// Fraction of all attempts (including timeouts) that succeeded; 1.0 when
    /// nothing has been attempted yet.
    pub fn success_rate(&self) -> f64 {
        match self.attempts() {
            0 => 1.0,
            attempts => (self.successful_operations as f64) / (attempts as f64),
        }
    }

    /// Fraction of all attempts that timed out waiting for a permit; 0.0 when
    /// nothing has been attempted yet. Always within `0.0..=1.0`.
    pub fn timeout_rate(&self) -> f64 {
        match self.attempts() {
            0 => 0.0,
            attempts => (self.timeouts as f64) / (attempts as f64),
        }
    }
}
pub struct PooledAsyncBackend {
backend: Arc<AsyncSledBackend>,
semaphore: Arc<Semaphore>,
config: PoolConfig,
metrics: Arc<PoolMetrics>,
}
impl PooledAsyncBackend {
pub async fn open(path: &std::path::Path, config: PoolConfig) -> Result<Self> {
let backend = AsyncSledBackend::open(path).await?;
let semaphore = Arc::new(Semaphore::new(config.max_concurrent));
let metrics = Arc::new(PoolMetrics::new());
Ok(Self {
backend: Arc::new(backend),
semaphore,
config,
metrics,
})
}
pub fn metrics(&self) -> PoolMetricsSnapshot {
self.metrics.snapshot()
}
pub fn config(&self) -> &PoolConfig {
&self.config
}
pub fn available_permits(&self) -> usize {
self.semaphore.available_permits()
}
async fn acquire_permit(&self) -> Result<tokio::sync::SemaphorePermit<'_>> {
let start = std::time::Instant::now();
let permit = timeout(
Duration::from_millis(self.config.acquire_timeout_ms),
self.semaphore.acquire(),
)
.await
.map_err(|_| {
self.metrics.record_timeout();
Error::Storage("Pool acquire timeout".to_string())
})?
.map_err(|_| Error::Storage("Semaphore closed".to_string()))?;
let wait_time = start.elapsed().as_micros() as u64;
self.metrics.record_wait_time(wait_time);
self.metrics.operation_started();
Ok(permit)
}
async fn with_permit<F, T>(&self, f: F) -> Result<T>
where
F: std::future::Future<Output = Result<T>>,
{
let _permit = self.acquire_permit().await?;
let result = f.await;
self.metrics.operation_completed(result.is_ok());
result
}
}
/// Storage operations routed through the pool: each call builds the
/// corresponding backend future and runs it via `with_permit`. At most
/// `max_concurrent` backend calls made through this wrapper execute at once.
#[async_trait]
impl AsyncStorageBackend for PooledAsyncBackend {
    async fn store_node(&self, node: &Node) -> Result<()> {
        let pending = self.backend.store_node(node);
        self.with_permit(pending).await
    }

    async fn get_node(&self, id: &NodeId) -> Result<Option<Node>> {
        let pending = self.backend.get_node(id);
        self.with_permit(pending).await
    }

    async fn delete_node(&self, id: &NodeId) -> Result<()> {
        let pending = self.backend.delete_node(id);
        self.with_permit(pending).await
    }

    async fn store_edge(&self, edge: &Edge) -> Result<()> {
        let pending = self.backend.store_edge(edge);
        self.with_permit(pending).await
    }

    async fn get_edge(&self, id: &EdgeId) -> Result<Option<Edge>> {
        let pending = self.backend.get_edge(id);
        self.with_permit(pending).await
    }

    async fn delete_edge(&self, id: &EdgeId) -> Result<()> {
        let pending = self.backend.delete_edge(id);
        self.with_permit(pending).await
    }

    async fn get_session_nodes(&self, session_id: &SessionId) -> Result<Vec<Node>> {
        let pending = self.backend.get_session_nodes(session_id);
        self.with_permit(pending).await
    }

    async fn get_outgoing_edges(&self, node_id: &NodeId) -> Result<Vec<Edge>> {
        let pending = self.backend.get_outgoing_edges(node_id);
        self.with_permit(pending).await
    }

    async fn get_incoming_edges(&self, node_id: &NodeId) -> Result<Vec<Edge>> {
        let pending = self.backend.get_incoming_edges(node_id);
        self.with_permit(pending).await
    }

    async fn flush(&self) -> Result<()> {
        let pending = self.backend.flush();
        self.with_permit(pending).await
    }

    async fn stats(&self) -> Result<StorageStats> {
        let pending = self.backend.stats();
        self.with_permit(pending).await
    }

    /// The whole batch occupies a single permit (one pooled operation).
    async fn store_nodes_batch(&self, nodes: &[Node]) -> Result<Vec<NodeId>> {
        let pending = self.backend.store_nodes_batch(nodes);
        self.with_permit(pending).await
    }

    /// The whole batch occupies a single permit (one pooled operation).
    async fn store_edges_batch(&self, edges: &[Edge]) -> Result<Vec<EdgeId>> {
        let pending = self.backend.store_edges_batch(edges);
        self.with_permit(pending).await
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::ConversationSession;
    use tempfile::tempdir;
    use tokio::sync::Barrier;

    #[tokio::test]
    async fn test_pool_creation() {
        let dir = tempdir().unwrap();
        let config = PoolConfig::new().with_max_concurrent(10);
        let pool = PooledAsyncBackend::open(dir.path(), config).await.unwrap();
        assert_eq!(pool.available_permits(), 10);
        assert_eq!(pool.config().max_concurrent, 10);
    }

    #[tokio::test]
    async fn test_pool_operations() {
        let dir = tempdir().unwrap();
        let config = PoolConfig::new();
        let pool = PooledAsyncBackend::open(dir.path(), config).await.unwrap();
        let session = ConversationSession::new();
        let node = Node::Session(session.clone());
        pool.store_node(&node).await.unwrap();
        let retrieved = pool.get_node(&session.node_id).await.unwrap();
        assert!(retrieved.is_some());
        // Fresh pool, exactly one store and one get.
        let metrics = pool.metrics();
        assert_eq!(metrics.total_operations, 2);
        assert_eq!(metrics.successful_operations, 2);
    }

    #[tokio::test]
    async fn test_concurrent_operations() {
        let dir = tempdir().unwrap();
        let config = PoolConfig::new().with_max_concurrent(20);
        let pool = Arc::new(PooledAsyncBackend::open(dir.path(), config).await.unwrap());
        let mut handles = vec![];
        for _ in 0..50 {
            let pool_clone = Arc::clone(&pool);
            let handle = tokio::spawn(async move {
                let session = ConversationSession::new();
                let node = Node::Session(session);
                pool_clone.store_node(&node).await
            });
            handles.push(handle);
        }
        for handle in handles {
            handle.await.unwrap().unwrap();
        }
        let metrics = pool.metrics();
        assert_eq!(metrics.total_operations, 50);
        assert_eq!(metrics.successful_operations, 50);
        assert!(metrics.peak_concurrent <= 20);
    }

    /// With both permits of a 2-slot pool held, none may be available until
    /// they are released. Barriers replace fixed sleeps so the assertions
    /// cannot race with the holders acquiring or releasing their permits.
    #[tokio::test]
    async fn test_pool_backpressure() {
        let dir = tempdir().unwrap();
        let config = PoolConfig::new().with_max_concurrent(2).with_timeout(1000);
        let pool = Arc::new(PooledAsyncBackend::open(dir.path(), config).await.unwrap());
        // Three parties each: the two permit holders plus this test task.
        let acquired = Arc::new(Barrier::new(3));
        let release = Arc::new(Barrier::new(3));
        let holders: Vec<_> = (0..2)
            .map(|_| {
                let pool = Arc::clone(&pool);
                let acquired = Arc::clone(&acquired);
                let release = Arc::clone(&release);
                tokio::spawn(async move {
                    let _permit = pool.acquire_permit().await.unwrap();
                    acquired.wait().await;
                    release.wait().await;
                })
            })
            .collect();
        acquired.wait().await;
        assert_eq!(pool.available_permits(), 0);
        release.wait().await;
        for holder in holders {
            holder.await.unwrap();
        }
        assert_eq!(pool.available_permits(), 2);
    }

    /// An operation that cannot get a permit within the timeout fails and is
    /// recorded as a timeout.
    #[tokio::test]
    async fn test_acquire_timeout_is_recorded() {
        let dir = tempdir().unwrap();
        let config = PoolConfig::new().with_max_concurrent(1).with_timeout(50);
        let pool = PooledAsyncBackend::open(dir.path(), config).await.unwrap();
        // Occupy the only permit so the next operation has to wait it out.
        let held = pool.acquire_permit().await.unwrap();
        let node = Node::Session(ConversationSession::new());
        assert!(pool.store_node(&node).await.is_err());
        drop(held);
        let metrics = pool.metrics();
        assert_eq!(metrics.timeouts, 1);
        assert_eq!(metrics.successful_operations, 0);
        assert_eq!(pool.available_permits(), 1);
    }

    #[tokio::test]
    async fn test_metrics_tracking() {
        let dir = tempdir().unwrap();
        let config = PoolConfig::new().with_metrics(true);
        let pool = PooledAsyncBackend::open(dir.path(), config).await.unwrap();
        for _ in 0..10 {
            let session = ConversationSession::new();
            pool.store_node(&Node::Session(session)).await.unwrap();
        }
        let metrics = pool.metrics();
        assert_eq!(metrics.total_operations, 10);
        assert_eq!(metrics.successful_operations, 10);
        assert_eq!(metrics.failed_operations, 0);
        assert!(metrics.avg_wait_time_ms() >= 0.0);
        assert_eq!(metrics.success_rate(), 1.0);
        assert_eq!(metrics.timeout_rate(), 0.0);
    }

    #[tokio::test]
    async fn test_batch_operations_with_pool() {
        let dir = tempdir().unwrap();
        let config = PoolConfig::new();
        let pool = PooledAsyncBackend::open(dir.path(), config).await.unwrap();
        let nodes: Vec<Node> = (0..10)
            .map(|_| Node::Session(ConversationSession::new()))
            .collect();
        let ids = pool.store_nodes_batch(&nodes).await.unwrap();
        assert_eq!(ids.len(), 10);
        // A batch is a single pooled operation.
        let metrics = pool.metrics();
        assert_eq!(metrics.total_operations, 1);
    }
}