1use std::{
6 ops::{Deref, DerefMut},
7 sync::{
8 atomic::{AtomicU32, Ordering},
9 Arc,
10 },
11 time::Duration,
12};
13
14use snafu::{ResultExt, Snafu};
15use tokio::{sync::Mutex, time::interval};
16use tokio_executor_trait::Tokio as TokioExecutor;
17use tokio_reactor_trait::Tokio as TokioReactor;
18
/// Errors returned by [`RabbitMqPool`] operations.
#[derive(Debug, Snafu)]
pub enum Error {
    /// Opening a new AMQP connection failed (raised by `RabbitMqPool::make_connection`).
    #[snafu(display("Failed to connect to RabbitMQ: {source}"))]
    Connection { source: lapin::Error },

    /// Opening a channel on a pooled connection failed (raised by `RabbitMqPool::create_channel`).
    #[snafu(display("Failed to create channel: {source}"))]
    Channel { source: lapin::Error },

    /// Closing a pooled connection failed (raised by `RabbitMqPool::close`).
    #[snafu(display("Failed to close connection: {source}"))]
    Close { source: lapin::Error },
}
31
/// A channel handed out by [`RabbitMqPool::create_channel`].
///
/// Dereferences to the wrapped [`lapin::Channel`]. While alive it counts against the
/// live-channel counter of the connection it was created on; dropping it decrements that
/// counter.
///
/// NOTE(review): dropping this wrapper does not explicitly close the AMQP channel in this
/// file; whether lapin closes it on drop is not visible here — confirm.
pub struct RabbitMqChannel {
    /// The underlying lapin channel.
    channel: lapin::Channel,
    /// Live-channel counter shared with the owning connection's pool entry.
    ref_counter: Arc<AtomicU32>,
}
37
38impl Drop for RabbitMqChannel {
39 fn drop(&mut self) {
40 self.ref_counter.fetch_sub(1, Ordering::Relaxed);
41 }
42}
43
44impl Deref for RabbitMqChannel {
45 type Target = lapin::Channel;
46
47 fn deref(&self) -> &Self::Target {
48 &self.channel
49 }
50}
51
52impl DerefMut for RabbitMqChannel {
53 fn deref_mut(&mut self) -> &mut Self::Target {
54 &mut self.channel
55 }
56}
57
/// A lazily grown pool of RabbitMQ connections that hands out [`RabbitMqChannel`]s.
///
/// Connections are opened on demand by `create_channel` and pruned by the background task
/// started in `from_config`.
pub struct RabbitMqPool {
    /// Connection URI passed to `lapin::Connection::connect` for every new connection.
    url: String,
    /// Connection-count threshold consulted by the background reaper when pruning.
    min_connections: u32,
    /// A connection with this many (or more) live channels is not reused by `create_channel`.
    max_channels_per_connection: u32,
    /// Pooled connections; locked by `create_channel`, `close`, and the reaper task.
    connections: Mutex<Vec<ConnectionEntry>>,
}
67
/// One pooled connection plus the number of live channels created on it.
struct ConnectionEntry {
    /// The underlying AMQP connection.
    connection: lapin::Connection,
    /// Live-channel counter: incremented by `create_channel`, decremented when a
    /// `RabbitMqChannel` created on this connection is dropped.
    channels: Arc<AtomicU32>,
}
72
73impl RabbitMqPool {
74 pub fn from_config(
78 url: &str,
79 min_connections: u32,
80 max_channels_per_connection: u32,
81 ) -> Arc<Self> {
82 let this = Arc::new(Self {
83 url: url.into(),
84 min_connections,
85 max_channels_per_connection,
86 connections: Mutex::new(vec![]),
87 });
88
89 tokio::spawn(reap_unused_connections(this.clone()));
90
91 this
92 }
93
94 pub async fn make_connection(&self) -> Result<lapin::Connection, Error> {
99 let connection = lapin::Connection::connect(
100 &self.url,
101 lapin::ConnectionProperties::default()
102 .with_executor(TokioExecutor::current())
103 .with_reactor(TokioReactor),
104 )
105 .await
106 .context(ConnectionSnafu)?;
107
108 Ok(connection)
109 }
110
111 pub async fn create_channel(&self) -> Result<RabbitMqChannel, Error> {
116 let mut connections = self.connections.lock().await;
117
118 let entry = connections.iter().find(|entry| {
119 entry.channels.load(Ordering::Relaxed) < self.max_channels_per_connection
120 && entry.connection.status().connected()
121 });
122
123 let entry = if let Some(entry) = entry {
124 entry
125 } else {
126 let connection = self.make_connection().await?;
127
128 let channels = Arc::new(AtomicU32::new(0));
129
130 connections.push(ConnectionEntry {
131 connection,
132 channels,
133 });
134
135 connections.last().expect("Item was just pushed.")
136 };
137
138 let channel = entry
139 .connection
140 .create_channel()
141 .await
142 .context(ChannelSnafu)?;
143
144 entry.channels.fetch_add(1, Ordering::Relaxed);
145
146 Ok(RabbitMqChannel {
147 channel,
148 ref_counter: entry.channels.clone(),
149 })
150 }
151
152 pub async fn close(&self, reply_code: u16, reply_message: &str) -> Result<(), Error> {
154 let mut connections = self.connections.lock().await;
155
156 for entry in connections.drain(..) {
157 entry
158 .connection
159 .close(reply_code, reply_message)
160 .await
161 .context(CloseSnafu)?;
162 }
163
164 Ok(())
165 }
166}
167
168async fn reap_unused_connections(pool: Arc<RabbitMqPool>) {
169 let mut interval = interval(Duration::from_secs(10));
170
171 loop {
172 interval.tick().await;
173
174 let mut connections = pool.connections.lock().await;
175
176 if connections.len() <= pool.min_connections as usize {
177 continue;
178 }
179
180 let removed_entries = drain_filter(&mut connections, |entry| {
181 entry.channels.load(Ordering::Relaxed) == 0 || !entry.connection.status().connected()
182 });
183
184 drop(connections);
185
186 for entry in removed_entries {
187 if entry.connection.status().connected() {
188 if let Err(e) = entry.connection.close(0, "closing").await {
189 log::error!("Failed to close connection in gc {e}");
190 }
191 }
192 }
193 }
194}
195
/// Removes every element of `vec` for which `predicate` returns `true`.
///
/// The predicate is called exactly once per element, front to back. Kept elements stay in
/// `vec` and removed elements are returned, both in their original relative order.
fn drain_filter<T>(vec: &mut Vec<T>, mut predicate: impl FnMut(&T) -> bool) -> Vec<T> {
    let mut drained = Vec::new();
    let mut cursor = 0;

    while let Some(item) = vec.get(cursor) {
        if predicate(item) {
            // Removing shifts the next element into `cursor`, so do not advance.
            drained.push(vec.remove(cursor));
        } else {
            cursor += 1;
        }
    }

    drained
}
209
#[cfg(test)]
mod test {
    use pretty_assertions::assert_eq;

    /// Removing a contiguous tail visits every element exactly once.
    #[test]
    fn test_drain_filter() {
        let mut items = vec![0, 1, 2, 3, 4, 5];

        let mut iterations = 0;

        let removed = super::drain_filter(&mut items, |i| {
            iterations += 1;
            *i > 2
        });

        assert_eq!(iterations, 6);
        assert_eq!(items, vec![0, 1, 2]);
        assert_eq!(removed, vec![3, 4, 5]);
    }

    /// Alternating matches exercise a removal immediately followed by a kept element.
    /// Both outputs must keep their original relative order, and elements are visited
    /// front to back.
    #[test]
    fn test_drain_filter_interleaved() {
        let mut items = vec![1, 2, 3, 4, 5, 6];
        let mut visited = Vec::new();

        let removed = super::drain_filter(&mut items, |i| {
            visited.push(*i);
            i % 2 == 0
        });

        assert_eq!(visited, vec![1, 2, 3, 4, 5, 6]);
        assert_eq!(items, vec![1, 3, 5]);
        assert_eq!(removed, vec![2, 4, 6]);
    }

    /// An empty input never calls the predicate and yields nothing.
    #[test]
    fn test_drain_filter_empty() {
        let mut items: Vec<i32> = Vec::new();
        let mut iterations = 0;

        let removed = super::drain_filter(&mut items, |_| {
            iterations += 1;
            true
        });

        assert_eq!(iterations, 0);
        assert_eq!(items, Vec::<i32>::new());
        assert_eq!(removed, Vec::<i32>::new());
    }
}
229}