lapin_pool/
lib.rs

1// SPDX-FileCopyrightText: OpenTalk GmbH <mail@opentalk.eu>
2//
3// SPDX-License-Identifier: EUPL-1.2
4
5use snafu::{ResultExt, Snafu};
6use std::ops::{Deref, DerefMut};
7use std::sync::atomic::{AtomicU32, Ordering};
8use std::sync::Arc;
9use std::time::Duration;
10use tokio::sync::Mutex;
11use tokio::time::interval;
12use tokio_executor_trait::Tokio as TokioExecutor;
13use tokio_reactor_trait::Tokio as TokioReactor;
14
15/// Errors that occur while using the RabbitMQ connection pool.
16#[derive(Debug, Snafu)]
17pub enum Error {
18    #[snafu(display("Failed to connect to RabbitMQ: {source}"))]
19    Connection { source: lapin::Error },
20
21    #[snafu(display("Failed to create channel: {source}"))]
22    Channel { source: lapin::Error },
23
24    #[snafu(display("Failed to close connection: {source}"))]
25    Close { source: lapin::Error },
26}
27
28/// [`lapin::Channel`] wrapper which maintains a ref counter to the channels underlying connection
29pub struct RabbitMqChannel {
30    channel: lapin::Channel,
31    ref_counter: Arc<AtomicU32>,
32}
33
34impl Drop for RabbitMqChannel {
35    fn drop(&mut self) {
36        self.ref_counter.fetch_sub(1, Ordering::Relaxed);
37    }
38}
39
40impl Deref for RabbitMqChannel {
41    type Target = lapin::Channel;
42
43    fn deref(&self) -> &Self::Target {
44        &self.channel
45    }
46}
47
48impl DerefMut for RabbitMqChannel {
49    fn deref_mut(&mut self) -> &mut Self::Target {
50        &mut self.channel
51    }
52}
53
54/// RabbitMQ connection pool which manages connection based on the amount of channels used per connection
55///
56/// Keeps a configured minimum of connections alive but creates them lazily when enough channels are requested
57pub struct RabbitMqPool {
58    url: String,
59    min_connections: u32,
60    max_channels_per_connection: u32,
61    connections: Mutex<Vec<ConnectionEntry>>,
62}
63
64struct ConnectionEntry {
65    connection: lapin::Connection,
66    channels: Arc<AtomicU32>,
67}
68
69impl RabbitMqPool {
70    /// Creates a new [`RabbitMqPool`] from given parameters
71    ///
72    /// Spawns a connection reaper task on the tokio runtime
73    pub fn from_config(
74        url: &str,
75        min_connections: u32,
76        max_channels_per_connection: u32,
77    ) -> Arc<Self> {
78        let this = Arc::new(Self {
79            url: url.into(),
80            min_connections,
81            max_channels_per_connection,
82            connections: Mutex::new(vec![]),
83        });
84
85        tokio::spawn(reap_unused_connections(this.clone()));
86
87        this
88    }
89
90    /// Create a connection with the pools given params
91    ///
92    /// This just creates a connection and does not add it to its pool. Connections will automatically be created when
93    /// creating channels.
94    pub async fn make_connection(&self) -> Result<lapin::Connection, Error> {
95        let connection = lapin::Connection::connect(
96            &self.url,
97            lapin::ConnectionProperties::default()
98                .with_executor(TokioExecutor::current())
99                .with_reactor(TokioReactor),
100        )
101        .await
102        .context(ConnectionSnafu)?;
103
104        Ok(connection)
105    }
106
107    /// Create a rabbitmq channel using one of the connections of the pool
108    ///
109    /// If there are no connections available or all connections are at the channel cap
110    /// a new connection will be created
111    pub async fn create_channel(&self) -> Result<RabbitMqChannel, Error> {
112        let mut connections = self.connections.lock().await;
113
114        let entry = connections.iter().find(|entry| {
115            entry.channels.load(Ordering::Relaxed) < self.max_channels_per_connection
116                && entry.connection.status().connected()
117        });
118
119        let entry = if let Some(entry) = entry {
120            entry
121        } else {
122            let connection = self.make_connection().await?;
123
124            let channels = Arc::new(AtomicU32::new(0));
125
126            connections.push(ConnectionEntry {
127                connection,
128                channels,
129            });
130
131            connections.last().expect("Item was just pushed.")
132        };
133
134        let channel = entry
135            .connection
136            .create_channel()
137            .await
138            .context(ChannelSnafu)?;
139
140        entry.channels.fetch_add(1, Ordering::Relaxed);
141
142        Ok(RabbitMqChannel {
143            channel,
144            ref_counter: entry.channels.clone(),
145        })
146    }
147
148    /// Close all connections managed by the pool with the given code and message
149    pub async fn close(&self, reply_code: u16, reply_message: &str) -> Result<(), Error> {
150        let mut connections = self.connections.lock().await;
151
152        for entry in connections.drain(..) {
153            entry
154                .connection
155                .close(reply_code, reply_message)
156                .await
157                .context(CloseSnafu)?;
158        }
159
160        Ok(())
161    }
162}
163
164async fn reap_unused_connections(pool: Arc<RabbitMqPool>) {
165    let mut interval = interval(Duration::from_secs(10));
166
167    loop {
168        interval.tick().await;
169
170        let mut connections = pool.connections.lock().await;
171
172        if connections.len() <= pool.min_connections as usize {
173            continue;
174        }
175
176        let removed_entries = drain_filter(&mut connections, |entry| {
177            entry.channels.load(Ordering::Relaxed) == 0 || !entry.connection.status().connected()
178        });
179
180        drop(connections);
181
182        for entry in removed_entries {
183            if entry.connection.status().connected() {
184                if let Err(e) = entry.connection.close(0, "closing").await {
185                    log::error!("Failed to close connection in gc {}", e);
186                }
187            }
188        }
189    }
190}
191
/// Removes every element of `vec` matching `predicate` and returns the removed elements.
///
/// Ordering guarantees:
/// - The elements left in `vec` keep their relative order.
/// - The returned elements keep their relative order.
/// - `predicate` is called exactly once per element, front to back.
///
/// This stands in for the former unstable `Vec::drain_filter`, which std later stabilized as
/// `Vec::extract_if`.
///
/// # Panics
///
/// If `predicate` panics, the panic propagates and `vec` is left empty.
fn drain_filter<T>(vec: &mut Vec<T>, mut predicate: impl FnMut(&T) -> bool) -> Vec<T> {
    // One pass over the taken vector instead of repeated `Vec::remove` calls, which shift the
    // tail on every removal and make the whole operation O(n²).
    let (removed, kept): (Vec<T>, Vec<T>) =
        std::mem::take(vec).into_iter().partition(|item| predicate(item));

    *vec = kept;

    removed
}
205
#[cfg(test)]
mod test {
    use pretty_assertions::assert_eq;

    /// Removing a contiguous suffix calls the predicate once per element.
    #[test]
    fn test_drain_filter() {
        let mut items = vec![0, 1, 2, 3, 4, 5];

        let mut iterations = 0;

        let removed = super::drain_filter(&mut items, |i| {
            iterations += 1;
            *i > 2
        });

        assert_eq!(iterations, 6);
        assert_eq!(items, vec![0, 1, 2]);
        assert_eq!(removed, vec![3, 4, 5]);
    }

    /// Interleaved matches must not skip elements, and both halves must keep their order.
    #[test]
    fn test_drain_filter_interleaved() {
        let mut items = vec![1, 2, 3, 4, 5, 6];

        let removed = super::drain_filter(&mut items, |i| i % 2 == 0);

        assert_eq!(items, vec![1, 3, 5]);
        assert_eq!(removed, vec![2, 4, 6]);
    }

    /// An empty input yields nothing and stays empty.
    #[test]
    fn test_drain_filter_empty() {
        let mut items: Vec<i32> = Vec::new();

        let removed = super::drain_filter(&mut items, |_| true);

        assert!(items.is_empty());
        assert!(removed.is_empty());
    }
}