lapin_pool/
lib.rs

1// SPDX-FileCopyrightText: OpenTalk GmbH <mail@opentalk.eu>
2//
3// SPDX-License-Identifier: EUPL-1.2
4
5use std::{
6    ops::{Deref, DerefMut},
7    sync::{
8        atomic::{AtomicU32, Ordering},
9        Arc,
10    },
11    time::Duration,
12};
13
14use snafu::{ResultExt, Snafu};
15use tokio::{sync::Mutex, time::interval};
16use tokio_executor_trait::Tokio as TokioExecutor;
17use tokio_reactor_trait::Tokio as TokioReactor;
18
19/// Errors that occur while using the RabbitMQ connection pool.
20#[derive(Debug, Snafu)]
21pub enum Error {
22    #[snafu(display("Failed to connect to RabbitMQ: {source}"))]
23    Connection { source: lapin::Error },
24
25    #[snafu(display("Failed to create channel: {source}"))]
26    Channel { source: lapin::Error },
27
28    #[snafu(display("Failed to close connection: {source}"))]
29    Close { source: lapin::Error },
30}
31
32/// [`lapin::Channel`] wrapper which maintains a ref counter to the channels underlying connection
33pub struct RabbitMqChannel {
34    channel: lapin::Channel,
35    ref_counter: Arc<AtomicU32>,
36}
37
38impl Drop for RabbitMqChannel {
39    fn drop(&mut self) {
40        self.ref_counter.fetch_sub(1, Ordering::Relaxed);
41    }
42}
43
44impl Deref for RabbitMqChannel {
45    type Target = lapin::Channel;
46
47    fn deref(&self) -> &Self::Target {
48        &self.channel
49    }
50}
51
52impl DerefMut for RabbitMqChannel {
53    fn deref_mut(&mut self) -> &mut Self::Target {
54        &mut self.channel
55    }
56}
57
58/// RabbitMQ connection pool which manages connection based on the amount of channels used per connection
59///
60/// Keeps a configured minimum of connections alive but creates them lazily when enough channels are requested
61pub struct RabbitMqPool {
62    url: String,
63    min_connections: u32,
64    max_channels_per_connection: u32,
65    connections: Mutex<Vec<ConnectionEntry>>,
66}
67
68struct ConnectionEntry {
69    connection: lapin::Connection,
70    channels: Arc<AtomicU32>,
71}
72
73impl RabbitMqPool {
74    /// Creates a new [`RabbitMqPool`] from given parameters
75    ///
76    /// Spawns a connection reaper task on the tokio runtime
77    pub fn from_config(
78        url: &str,
79        min_connections: u32,
80        max_channels_per_connection: u32,
81    ) -> Arc<Self> {
82        let this = Arc::new(Self {
83            url: url.into(),
84            min_connections,
85            max_channels_per_connection,
86            connections: Mutex::new(vec![]),
87        });
88
89        tokio::spawn(reap_unused_connections(this.clone()));
90
91        this
92    }
93
94    /// Create a connection with the pools given params
95    ///
96    /// This just creates a connection and does not add it to its pool. Connections will automatically be created when
97    /// creating channels.
98    pub async fn make_connection(&self) -> Result<lapin::Connection, Error> {
99        let connection = lapin::Connection::connect(
100            &self.url,
101            lapin::ConnectionProperties::default()
102                .with_executor(TokioExecutor::current())
103                .with_reactor(TokioReactor),
104        )
105        .await
106        .context(ConnectionSnafu)?;
107
108        Ok(connection)
109    }
110
111    /// Create a rabbitmq channel using one of the connections of the pool
112    ///
113    /// If there are no connections available or all connections are at the channel cap
114    /// a new connection will be created
115    pub async fn create_channel(&self) -> Result<RabbitMqChannel, Error> {
116        let mut connections = self.connections.lock().await;
117
118        let entry = connections.iter().find(|entry| {
119            entry.channels.load(Ordering::Relaxed) < self.max_channels_per_connection
120                && entry.connection.status().connected()
121        });
122
123        let entry = if let Some(entry) = entry {
124            entry
125        } else {
126            let connection = self.make_connection().await?;
127
128            let channels = Arc::new(AtomicU32::new(0));
129
130            connections.push(ConnectionEntry {
131                connection,
132                channels,
133            });
134
135            connections.last().expect("Item was just pushed.")
136        };
137
138        let channel = entry
139            .connection
140            .create_channel()
141            .await
142            .context(ChannelSnafu)?;
143
144        entry.channels.fetch_add(1, Ordering::Relaxed);
145
146        Ok(RabbitMqChannel {
147            channel,
148            ref_counter: entry.channels.clone(),
149        })
150    }
151
152    /// Close all connections managed by the pool with the given code and message
153    pub async fn close(&self, reply_code: u16, reply_message: &str) -> Result<(), Error> {
154        let mut connections = self.connections.lock().await;
155
156        for entry in connections.drain(..) {
157            entry
158                .connection
159                .close(reply_code, reply_message)
160                .await
161                .context(CloseSnafu)?;
162        }
163
164        Ok(())
165    }
166}
167
168async fn reap_unused_connections(pool: Arc<RabbitMqPool>) {
169    let mut interval = interval(Duration::from_secs(10));
170
171    loop {
172        interval.tick().await;
173
174        let mut connections = pool.connections.lock().await;
175
176        if connections.len() <= pool.min_connections as usize {
177            continue;
178        }
179
180        let removed_entries = drain_filter(&mut connections, |entry| {
181            entry.channels.load(Ordering::Relaxed) == 0 || !entry.connection.status().connected()
182        });
183
184        drop(connections);
185
186        for entry in removed_entries {
187            if entry.connection.status().connected() {
188                if let Err(e) = entry.connection.close(0, "closing").await {
189                    log::error!("Failed to close connection in gc {e}");
190                }
191            }
192        }
193    }
194}
195
/// Stand-in for the standard library's `Vec::extract_if` (formerly `Vec::drain_filter`)
///
/// Removes every element for which `predicate` returns `true` from `vec` and returns the removed elements.
/// The relative order of both the remaining and the removed elements is preserved and `predicate` is called
/// exactly once per element, in order.
///
/// Runs in a single pass over the elements instead of shifting the tail of the vector for every removed element.
fn drain_filter<T>(vec: &mut Vec<T>, predicate: impl FnMut(&T) -> bool) -> Vec<T> {
    // `partition` puts the elements matching the predicate into the first collection.
    let (removed, kept): (Vec<T>, Vec<T>) = std::mem::take(vec).into_iter().partition(predicate);

    *vec = kept;

    removed
}
209
#[cfg(test)]
mod test {
    use pretty_assertions::assert_eq;

    /// `drain_filter` must visit every element once and split the vector without reordering
    #[test]
    fn test_drain_filter() {
        let mut values = vec![0, 1, 2, 3, 4, 5];
        let mut predicate_calls = 0;

        let drained = super::drain_filter(&mut values, |value| {
            predicate_calls += 1;
            *value > 2
        });

        assert_eq!(predicate_calls, 6);
        assert_eq!(values, vec![0, 1, 2]);
        assert_eq!(drained, vec![3, 4, 5]);
    }
}
229}