1use std::collections::VecDeque;
6use std::ops::{Deref, DerefMut};
7use std::sync::Arc;
8
9use tokio::sync::{Mutex, Semaphore};
10
11use crate::driver::RedisDriver;
12use crate::error::{RedisError, RedisResult};
13
/// Settings that describe how a pool connects to Redis and how large it may grow.
#[derive(Clone, Debug)]
pub struct PoolConfig {
    /// Maximum number of connections; `RedisPool` sizes its semaphore from this value.
    pub max_connections: usize,
    /// Host passed to `RedisDriver::connect` when a new connection is opened.
    pub host: String,
    /// Port passed to `RedisDriver::connect` when a new connection is opened.
    pub port: u16,
}

impl Default for PoolConfig {
    /// Returns a configuration for `127.0.0.1:6379` with `max_connections` set to 10.
    fn default() -> Self {
        // Delegating keeps the default connection limit defined in a single place.
        Self::new("127.0.0.1", 6379)
    }
}

impl PoolConfig {
    /// Creates a configuration for `host:port` with `max_connections` set to 10.
    pub fn new(host: impl Into<String>, port: u16) -> Self {
        let host = host.into();
        Self {
            max_connections: 10,
            host,
            port,
        }
    }

    /// Builder-style setter: consumes the configuration and returns it with
    /// `max_connections` replaced by `n`.
    pub fn max_connections(self, n: usize) -> Self {
        Self {
            max_connections: n,
            ..self
        }
    }
}
51
52pub struct RedisPool {
54 config: PoolConfig,
55 connections: Arc<Mutex<VecDeque<RedisDriver>>>,
56 semaphore: Arc<Semaphore>,
57}
58
59impl RedisPool {
60 pub fn new(config: PoolConfig) -> Self {
62 let semaphore = Arc::new(Semaphore::new(config.max_connections));
63 Self {
64 config,
65 connections: Arc::new(Mutex::new(VecDeque::new())),
66 semaphore,
67 }
68 }
69
70 pub async fn get(&self) -> RedisResult<PooledConnection> {
72 let permit = self
74 .semaphore
75 .clone()
76 .acquire_owned()
77 .await
78 .map_err(|_| RedisError::Pool("Failed to acquire pool permit".into()))?;
79
80 let driver = {
82 let mut conns = self.connections.lock().await;
83 conns.pop_front()
84 };
85
86 let driver = match driver {
87 Some(d) => d,
88 None => {
89 RedisDriver::connect(&self.config.host, self.config.port).await?
91 }
92 };
93
94 Ok(PooledConnection {
95 driver: Some(driver),
96 pool: self.connections.clone(),
97 _permit: permit,
98 })
99 }
100}
101
102pub struct PooledConnection {
104 driver: Option<RedisDriver>,
105 pool: Arc<Mutex<VecDeque<RedisDriver>>>,
106 _permit: tokio::sync::OwnedSemaphorePermit,
107}
108
109impl Deref for PooledConnection {
110 type Target = RedisDriver;
111
112 fn deref(&self) -> &Self::Target {
113 self.driver.as_ref().unwrap()
114 }
115}
116
117impl DerefMut for PooledConnection {
118 fn deref_mut(&mut self) -> &mut Self::Target {
119 self.driver.as_mut().unwrap()
120 }
121}
122
123impl Drop for PooledConnection {
124 fn drop(&mut self) {
125 if let Some(driver) = self.driver.take() {
126 let pool = self.pool.clone();
127 tokio::spawn(async move {
128 let mut conns = pool.lock().await;
129 conns.push_back(driver);
130 });
131 }
132 }
133}
134
#[cfg(test)]
mod tests {
    use super::*;

    /// `PoolConfig::default()` yields `127.0.0.1:6379` with 10 connections.
    #[test]
    fn test_pool_config_default() {
        let PoolConfig {
            max_connections,
            host,
            port,
        } = PoolConfig::default();

        assert_eq!(max_connections, 10);
        assert_eq!(host, "127.0.0.1");
        assert_eq!(port, 6379);
    }
}