1use snafu::{ResultExt, Snafu};
6use std::ops::{Deref, DerefMut};
7use std::sync::atomic::{AtomicU32, Ordering};
8use std::sync::Arc;
9use std::time::Duration;
10use tokio::sync::Mutex;
11use tokio::time::interval;
12use tokio_executor_trait::Tokio as TokioExecutor;
13use tokio_reactor_trait::Tokio as TokioReactor;
14
/// Errors returned by [`RabbitMqPool`] operations.
#[derive(Debug, Snafu)]
pub enum Error {
    /// Opening a new connection to the broker failed (see `make_connection`).
    #[snafu(display("Failed to connect to RabbitMQ: {source}"))]
    Connection { source: lapin::Error },

    /// Opening a channel on a pooled connection failed (see `create_channel`).
    #[snafu(display("Failed to create channel: {source}"))]
    Channel { source: lapin::Error },

    /// Closing a pooled connection failed (see `close`).
    #[snafu(display("Failed to close connection: {source}"))]
    Close { source: lapin::Error },
}
27
/// A channel handed out by [`RabbitMqPool::create_channel`].
///
/// Dereferences to the underlying [`lapin::Channel`]. While this value is
/// alive it is counted in its connection's channel counter; the pool uses that
/// counter to enforce `max_channels_per_connection` and the background reaper
/// uses it to decide whether a connection is idle. The counter is decremented
/// when this value is dropped.
pub struct RabbitMqChannel {
    /// The wrapped lapin channel.
    channel: lapin::Channel,
    /// Counter shared with the owning connection's pool entry; counts the
    /// `RabbitMqChannel`s currently alive on that connection.
    ref_counter: Arc<AtomicU32>,
}
33
34impl Drop for RabbitMqChannel {
35 fn drop(&mut self) {
36 self.ref_counter.fetch_sub(1, Ordering::Relaxed);
37 }
38}
39
40impl Deref for RabbitMqChannel {
41 type Target = lapin::Channel;
42
43 fn deref(&self) -> &Self::Target {
44 &self.channel
45 }
46}
47
48impl DerefMut for RabbitMqChannel {
49 fn deref_mut(&mut self) -> &mut Self::Target {
50 &mut self.channel
51 }
52}
53
/// A pool of RabbitMQ connections that hands out channels.
///
/// Connections are opened on demand by [`RabbitMqPool::create_channel`] and
/// retired by a background task spawned in [`RabbitMqPool::from_config`].
pub struct RabbitMqPool {
    /// Broker URL passed to `lapin::Connection::connect`.
    url: String,
    /// Lower bound on pool size that the background reaper consults before
    /// closing unused connections.
    min_connections: u32,
    /// A connection whose outstanding-channel count has reached this value is
    /// not chosen by `create_channel`; another (possibly new) one is used.
    max_channels_per_connection: u32,
    /// Pooled connections. An async (tokio) mutex is used because the guard is
    /// held across `.await` points in `create_channel` and `close`.
    connections: Mutex<Vec<ConnectionEntry>>,
}
63
/// A pooled connection together with its count of outstanding channels.
struct ConnectionEntry {
    /// The underlying broker connection.
    connection: lapin::Connection,
    /// Shared with every `RabbitMqChannel` created on `connection`; incremented
    /// by `create_channel` and decremented when such a channel is dropped.
    channels: Arc<AtomicU32>,
}
68
69impl RabbitMqPool {
70 pub fn from_config(
74 url: &str,
75 min_connections: u32,
76 max_channels_per_connection: u32,
77 ) -> Arc<Self> {
78 let this = Arc::new(Self {
79 url: url.into(),
80 min_connections,
81 max_channels_per_connection,
82 connections: Mutex::new(vec![]),
83 });
84
85 tokio::spawn(reap_unused_connections(this.clone()));
86
87 this
88 }
89
90 pub async fn make_connection(&self) -> Result<lapin::Connection, Error> {
95 let connection = lapin::Connection::connect(
96 &self.url,
97 lapin::ConnectionProperties::default()
98 .with_executor(TokioExecutor::current())
99 .with_reactor(TokioReactor),
100 )
101 .await
102 .context(ConnectionSnafu)?;
103
104 Ok(connection)
105 }
106
107 pub async fn create_channel(&self) -> Result<RabbitMqChannel, Error> {
112 let mut connections = self.connections.lock().await;
113
114 let entry = connections.iter().find(|entry| {
115 entry.channels.load(Ordering::Relaxed) < self.max_channels_per_connection
116 && entry.connection.status().connected()
117 });
118
119 let entry = if let Some(entry) = entry {
120 entry
121 } else {
122 let connection = self.make_connection().await?;
123
124 let channels = Arc::new(AtomicU32::new(0));
125
126 connections.push(ConnectionEntry {
127 connection,
128 channels,
129 });
130
131 connections.last().expect("Item was just pushed.")
132 };
133
134 let channel = entry
135 .connection
136 .create_channel()
137 .await
138 .context(ChannelSnafu)?;
139
140 entry.channels.fetch_add(1, Ordering::Relaxed);
141
142 Ok(RabbitMqChannel {
143 channel,
144 ref_counter: entry.channels.clone(),
145 })
146 }
147
148 pub async fn close(&self, reply_code: u16, reply_message: &str) -> Result<(), Error> {
150 let mut connections = self.connections.lock().await;
151
152 for entry in connections.drain(..) {
153 entry
154 .connection
155 .close(reply_code, reply_message)
156 .await
157 .context(CloseSnafu)?;
158 }
159
160 Ok(())
161 }
162}
163
164async fn reap_unused_connections(pool: Arc<RabbitMqPool>) {
165 let mut interval = interval(Duration::from_secs(10));
166
167 loop {
168 interval.tick().await;
169
170 let mut connections = pool.connections.lock().await;
171
172 if connections.len() <= pool.min_connections as usize {
173 continue;
174 }
175
176 let removed_entries = drain_filter(&mut connections, |entry| {
177 entry.channels.load(Ordering::Relaxed) == 0 || !entry.connection.status().connected()
178 });
179
180 drop(connections);
181
182 for entry in removed_entries {
183 if entry.connection.status().connected() {
184 if let Err(e) = entry.connection.close(0, "closing").await {
185 log::error!("Failed to close connection in gc {}", e);
186 }
187 }
188 }
189 }
190}
191
/// Moves every element of `vec` for which `predicate` returns `true` into a
/// new `Vec`, which is returned; the other elements stay in `vec`.
///
/// `predicate` is called exactly once per element, in order. Both the removed
/// and the retained elements keep their original relative order. Runs in O(n).
fn drain_filter<T>(vec: &mut Vec<T>, mut predicate: impl FnMut(&T) -> bool) -> Vec<T> {
    let (removed, kept): (Vec<T>, Vec<T>) = std::mem::take(vec)
        .into_iter()
        .partition(|item| predicate(item));
    *vec = kept;
    removed
}
205
#[cfg(test)]
mod test {
    use pretty_assertions::assert_eq;

    /// The predicate sees every element exactly once; matches are moved out
    /// and the rest stay behind.
    #[test]
    fn test_drain_filter() {
        let mut items = vec![0, 1, 2, 3, 4, 5];

        let mut iterations = 0;

        let removed = super::drain_filter(&mut items, |i| {
            iterations += 1;
            *i > 2
        });

        assert_eq!(iterations, 6);
        assert_eq!(items, vec![0, 1, 2]);
        assert_eq!(removed, vec![3, 4, 5]);
    }

    /// With interleaved matches, both halves keep their original relative order.
    #[test]
    fn test_drain_filter_preserves_order() {
        let mut items = vec![4, 0, 5, 1, 3, 2];

        let removed = super::drain_filter(&mut items, |i| *i > 2);

        assert_eq!(items, vec![0, 1, 2]);
        assert_eq!(removed, vec![4, 5, 3]);
    }

    /// An empty input yields two empty vectors and never calls the predicate.
    #[test]
    fn test_drain_filter_empty() {
        let mut items: Vec<i32> = Vec::new();
        let mut iterations = 0;

        let removed = super::drain_filter(&mut items, |_| {
            iterations += 1;
            true
        });

        assert_eq!(iterations, 0);
        assert!(items.is_empty());
        assert!(removed.is_empty());
    }
}
225}