Skip to main content

lapin_pool/
lib.rs

1// SPDX-FileCopyrightText: OpenTalk GmbH <mail@opentalk.eu>
2//
3// SPDX-License-Identifier: EUPL-1.2
4
5use std::{
6    ops::{Deref, DerefMut},
7    sync::{
8        Arc,
9        atomic::{AtomicU32, Ordering},
10    },
11    time::Duration,
12};
13
14use snafu::{ResultExt, Snafu};
15use tokio::{sync::Mutex, time::interval};
16
/// Errors that occur while using the RabbitMQ connection pool.
///
/// Each variant wraps the underlying [`lapin::Error`]; snafu generates the `ConnectionSnafu`,
/// `ChannelSnafu` and `CloseSnafu` context selectors used by the pool.
#[derive(Debug, Snafu)]
pub enum Error {
    /// Opening a new AMQP connection failed (raised by `RabbitMqPool::make_connection`).
    #[snafu(display("Failed to connect to RabbitMQ: {source}"))]
    Connection { source: lapin::Error },

    /// Opening a channel on a pooled connection failed (raised by `RabbitMqPool::create_channel`).
    #[snafu(display("Failed to create channel: {source}"))]
    Channel { source: lapin::Error },

    /// Closing a pooled connection failed (raised by `RabbitMqPool::close`).
    #[snafu(display("Failed to close connection: {source}"))]
    Close { source: lapin::Error },
}
29
/// [`lapin::Channel`] wrapper which maintains a ref counter to the channel's underlying connection.
///
/// The counter is shared with the pool's bookkeeping entry for the connection this channel was
/// opened on. The pool increments it when it hands out the channel, and it is decremented again
/// when this wrapper is dropped. The wrapper dereferences to [`lapin::Channel`], so it can be
/// used wherever a channel is expected.
pub struct RabbitMqChannel {
    /// The wrapped channel.
    channel: lapin::Channel,
    /// Number of channels handed out on the underlying connection and not yet dropped; shared
    /// with the pool.
    ref_counter: Arc<AtomicU32>,
}
35
36impl Drop for RabbitMqChannel {
37    fn drop(&mut self) {
38        self.ref_counter.fetch_sub(1, Ordering::Relaxed);
39    }
40}
41
42impl Deref for RabbitMqChannel {
43    type Target = lapin::Channel;
44
45    fn deref(&self) -> &Self::Target {
46        &self.channel
47    }
48}
49
50impl DerefMut for RabbitMqChannel {
51    fn deref_mut(&mut self) -> &mut Self::Target {
52        &mut self.channel
53    }
54}
55
/// RabbitMQ connection pool which manages connections based on the number of channels used per
/// connection.
///
/// Connections are created lazily: a new one is opened only when a channel is requested and
/// every pooled connection is either at its channel cap or no longer connected. A background
/// reaper periodically removes idle (zero-channel) or disconnected connections while the pool
/// holds more than `min_connections` entries.
pub struct RabbitMqPool {
    /// AMQP URL used for every connection the pool opens.
    url: String,
    /// Connection count at or below which the reaper leaves the pool untouched.
    min_connections: u32,
    /// Channel count at which a connection stops being chosen for new channels.
    max_channels_per_connection: u32,
    /// Pooled connections with their channel counters, behind an async (tokio) mutex.
    connections: Mutex<Vec<ConnectionEntry>>,
}
65
/// A pooled connection together with the number of channels currently handed out on it.
struct ConnectionEntry {
    /// The pooled AMQP connection.
    connection: lapin::Connection,
    /// Shared with every [`RabbitMqChannel`] created on `connection`; decremented when one drops.
    channels: Arc<AtomicU32>,
}
70
71impl RabbitMqPool {
72    /// Creates a new [`RabbitMqPool`] from given parameters
73    ///
74    /// Spawns a connection reaper task on the tokio runtime
75    pub fn from_config(
76        url: &str,
77        min_connections: u32,
78        max_channels_per_connection: u32,
79    ) -> Arc<Self> {
80        let this = Arc::new(Self {
81            url: url.into(),
82            min_connections,
83            max_channels_per_connection,
84            connections: Mutex::new(vec![]),
85        });
86
87        tokio::spawn(reap_unused_connections(this.clone()));
88
89        this
90    }
91
92    /// Create a connection with the pools given params
93    ///
94    /// This just creates a connection and does not add it to its pool. Connections will automatically be created when
95    /// creating channels.
96    pub async fn make_connection(&self) -> Result<lapin::Connection, Error> {
97        let connection = lapin::Connection::connect(
98            &self.url,
99            lapin::ConnectionProperties::default().enable_auto_recover(),
100        )
101        .await
102        .context(ConnectionSnafu)?;
103
104        Ok(connection)
105    }
106
107    /// Create a rabbitmq channel using one of the connections of the pool
108    ///
109    /// If there are no connections available or all connections are at the channel cap
110    /// a new connection will be created
111    pub async fn create_channel(&self) -> Result<RabbitMqChannel, Error> {
112        let mut connections = self.connections.lock().await;
113
114        let entry = connections.iter().find(|entry| {
115            entry.channels.load(Ordering::Relaxed) < self.max_channels_per_connection
116                && entry.connection.status().connected()
117        });
118
119        let entry = if let Some(entry) = entry {
120            entry
121        } else {
122            let connection = self.make_connection().await?;
123
124            let channels = Arc::new(AtomicU32::new(0));
125
126            connections.push(ConnectionEntry {
127                connection,
128                channels,
129            });
130
131            connections.last().expect("Item was just pushed.")
132        };
133
134        let channel = entry
135            .connection
136            .create_channel()
137            .await
138            .context(ChannelSnafu)?;
139
140        entry.channels.fetch_add(1, Ordering::Relaxed);
141
142        Ok(RabbitMqChannel {
143            channel,
144            ref_counter: entry.channels.clone(),
145        })
146    }
147
148    /// Close all connections managed by the pool with the given code and message
149    pub async fn close(&self, reply_code: u16, reply_message: &str) -> Result<(), Error> {
150        let mut connections = self.connections.lock().await;
151
152        for entry in connections.drain(..) {
153            entry
154                .connection
155                .close(reply_code, reply_message.into())
156                .await
157                .context(CloseSnafu)?;
158        }
159
160        Ok(())
161    }
162}
163
164async fn reap_unused_connections(pool: Arc<RabbitMqPool>) {
165    let mut interval = interval(Duration::from_secs(10));
166
167    loop {
168        interval.tick().await;
169
170        let mut connections = pool.connections.lock().await;
171
172        if connections.len() <= pool.min_connections as usize {
173            continue;
174        }
175
176        let removed_entries = drain_filter(&mut connections, |entry| {
177            entry.channels.load(Ordering::Relaxed) == 0 || !entry.connection.status().connected()
178        });
179
180        drop(connections);
181
182        for entry in removed_entries {
183            if entry.connection.status().connected()
184                && let Err(e) = entry.connection.close(0, "closing".into()).await
185            {
186                log::error!("Failed to close connection in gc {e}");
187            }
188        }
189    }
190}
191
/// Removes every element matching `predicate` from `vec` and returns the removed elements.
///
/// Stand-in for the standard library's `Vec::extract_if` (the API formerly proposed as
/// `Vec::drain_filter`). `predicate` is called exactly once per element, front to back. Both the
/// returned elements and the ones left in `vec` keep their original relative order.
///
/// Runs in O(n): the elements are partitioned in a single pass. Removing them one by one would
/// shift the tail of the vector on every removal.
fn drain_filter<T>(vec: &mut Vec<T>, predicate: impl FnMut(&T) -> bool) -> Vec<T> {
    let (removed, kept): (Vec<T>, Vec<T>) = std::mem::take(vec).into_iter().partition(predicate);
    *vec = kept;
    removed
}
205
#[cfg(test)]
mod test {
    use pretty_assertions::assert_eq;

    /// `drain_filter` must visit every element exactly once. It moves the matching elements out
    /// in their original order and leaves the rest in place.
    #[test]
    fn test_drain_filter() {
        let mut remaining = vec![0, 1, 2, 3, 4, 5];
        let mut calls = 0;

        let drained = super::drain_filter(&mut remaining, |&value| {
            calls += 1;
            value > 2
        });

        assert_eq!(calls, 6);
        assert_eq!(remaining, vec![0, 1, 2]);
        assert_eq!(drained, vec![3, 4, 5]);
    }
}
225}