1use cdbc::describe::Describe;
2use cdbc::error::Error;
3use cdbc::executor::{Execute, Executor};
4use cdbc::pool::PoolOptions;
5use cdbc::pool::{Pool, PoolConnection};
6use crate::message::{MessageFormat, Notification};
7use crate::{PgConnection, PgQueryResult, PgRow, PgStatement, PgTypeInfo, Postgres};
8use either::Either;
9use std::fmt::{self, Debug};
10use std::io;
11use std::str::from_utf8;
12use mco::{chan, co};
13use mco::std::sync::channel::{Receiver, Sender};
14use mco::std::sync::channel;
15use cdbc::io::chan_stream::ChanStream;
16
/// A listener for PostgreSQL notifications delivered through `LISTEN` / `NOTIFY`.
///
/// The listener holds at most one pooled connection at a time and remembers the
/// channels it has subscribed to so they can be replayed on a fresh connection.
pub struct PgListener {
    /// Pool from which the listener acquires its connection (initially and on reconnect).
    pool: Pool<Postgres>,
    /// Connection currently in use; set to `None` by `try_recv` after the
    /// connection is reported as aborted.
    connection: Option<PoolConnection<Postgres>>,
    /// Receiving half of the notification channel whose sender is installed in
    /// the connection's stream (`stream.notifications`).
    /// NOTE(review): presumably the stream pushes notifications it encounters
    /// while running ordinary queries into this channel; confirm in the stream code.
    buffer_rx: Receiver<Notification>,
    /// Parking spot for the sender while there is no connection to hold it.
    buffer_tx: Option<Sender<Notification>>,
    /// Channel names to re-`LISTEN` on when a new connection is established.
    channels: Vec<String>,
}
30
/// A notification received by a [`PgListener`]; a thin wrapper around the
/// decoded protocol-level `Notification` message.
pub struct PgNotification(Notification);
33
34impl PgListener {
35 pub fn connect(uri: &str) -> Result<Self, Error> {
36 let pool = PoolOptions::<Postgres>::new()
39 .max_connections(1)
40 .max_lifetime(None)
41 .idle_timeout(None)
42 .connect(uri)
43 ?;
44
45 Self::connect_with(&pool)
46 }
47
48 pub fn connect_with(pool: &Pool<Postgres>) -> Result<Self, Error> {
49 let mut connection = pool.acquire()?;
51
52 let (sender, receiver) = chan!();
54 connection.stream.notifications = Some(sender);
55
56 Ok(Self {
57 pool: pool.clone(),
58 connection: Some(connection),
59 buffer_rx: receiver,
60 buffer_tx: None,
61 channels: Vec::new(),
62 })
63 }
64
65 pub fn listen(&mut self, channel: &str) -> Result<(), Error> {
68 self.connection()
69 .execute(&*format!(r#"LISTEN "{}""#, ident(channel)))
70 ?;
71
72 self.channels.push(channel.to_owned());
73
74 Ok(())
75 }
76
77 pub fn listen_all<'a>(
79 &mut self,
80 channels: impl IntoIterator<Item = &'a str>,
81 ) -> Result<(), Error> {
82 let beg = self.channels.len();
83 self.channels.extend(channels.into_iter().map(|s| s.into()));
84
85 self.connection
86 .as_mut()
87 .unwrap()
88 .execute(&*build_listen_all_query(&self.channels[beg..]))
89 ?;
90
91 Ok(())
92 }
93
94 pub fn unlisten(&mut self, channel: &str) -> Result<(), Error> {
97 self.connection()
98 .execute(&*format!(r#"UNLISTEN "{}""#, ident(channel)))
99 ?;
100
101 if let Some(pos) = self.channels.iter().position(|s| s == channel) {
102 self.channels.remove(pos);
103 }
104
105 Ok(())
106 }
107
108 pub fn unlisten_all(&mut self) -> Result<(), Error> {
110 self.connection().execute("UNLISTEN *")?;
111
112 self.channels.clear();
113
114 Ok(())
115 }
116
117 #[inline]
118 fn connect_if_needed(&mut self) -> Result<(), Error> {
119 if self.connection.is_none() {
120 let mut connection = self.pool.acquire()?;
121 connection.stream.notifications = self.buffer_tx.take();
122
123 connection
124 .execute(&*build_listen_all_query(&self.channels))
125 ?;
126
127 self.connection = Some(connection);
128 }
129
130 Ok(())
131 }
132
133 #[inline]
134 fn connection(&mut self) -> &mut PgConnection {
135 self.connection.as_mut().unwrap()
136 }
137
138 pub fn recv(&mut self) -> Result<PgNotification, Error> {
167 loop {
168 if let Some(notification) = self.try_recv()? {
169 return Ok(notification);
170 }
171 }
172 }
173
174 pub fn try_recv(&mut self) -> Result<Option<PgNotification>, Error> {
200 if let Ok(notification) = self.buffer_rx.try_recv() {
203 return Ok(Some(PgNotification(notification)));
204 }
205
206 loop {
207 self.connect_if_needed()?;
209
210 let message = match self.connection().stream.recv_unchecked() {
211 Ok(message) => message,
212
213 Err(Error::Io(err)) if err.kind() == io::ErrorKind::ConnectionAborted => {
216 self.buffer_tx = self.connection().stream.notifications.take();
217 self.connection = None;
218
219 return Ok(None);
221 }
222
223 Err(error) => {
225 return Err(error);
226 }
227 };
228
229 match message.format {
230 MessageFormat::NotificationResponse => {
232 return Ok(Some(PgNotification(message.decode()?)));
233 }
234
235 MessageFormat::ReadyForQuery => {
237 self.connection().pending_ready_for_query_count -= 1;
238 }
239
240 _ => {}
242 }
243 }
244 }
245
246 pub fn into_stream(mut self) -> ChanStream<PgNotification> {
253 chan_stream!( {
254 loop {
255 r#yield!(self.recv()?);
256 }
257 })
258 }
259}
260
261impl Drop for PgListener {
262 fn drop(&mut self) {
263 if let Some(mut conn) = self.connection.take() {
264 let fut = move || {
265 let _ = conn.execute("UNLISTEN *");
266
267 conn.return_to_pool();
271 };
272
273 co!(fut);
275 }
276 }
277}
278
279impl<'c> Executor for &'c mut PgListener {
280 type Database = Postgres;
281
282 fn fetch_many<'q, E: 'q>(
283 &mut self,
284 query: E,
285 ) -> ChanStream<Either<PgQueryResult, PgRow>>
286 where
287 E: Execute<'q, Self::Database>,
288 {
289 self.connection().fetch_many(query)
290 }
291
292 fn fetch_optional<'q, E: 'q>(
293 &mut self,
294 query: E,
295 ) -> Result<Option<PgRow>, Error>
296 where E: Execute<'q, Self::Database>,
297 {
298 self.connection().fetch_optional(query)
299 }
300
301 fn prepare_with<'q>(
302 &mut self,
303 query: &'q str,
304 parameters: &'q [PgTypeInfo],
305 ) -> Result<PgStatement, Error>
306 where
307 {
308 self.connection().prepare_with(query, parameters)
309 }
310
311 #[doc(hidden)]
312 fn describe< 'q>(
313 &mut self,
314 query: &'q str,
315 ) -> Result<Describe<Self::Database>, Error>
316 where
317 {
318 self.connection().describe(query)
319 }
320}
321
322impl PgNotification {
323 #[inline]
325 pub fn process_id(&self) -> u32 {
326 self.0.process_id
327 }
328
329 #[inline]
332 pub fn channel(&self) -> &str {
333 from_utf8(&self.0.channel).unwrap()
334 }
335
336 #[inline]
339 pub fn payload(&self) -> &str {
340 from_utf8(&self.0.payload).unwrap()
341 }
342}
343
344impl Debug for PgListener {
345 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
346 f.debug_struct("PgListener").finish()
347 }
348}
349
350impl Debug for PgNotification {
351 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
352 f.debug_struct("PgNotification")
353 .field("process_id", &self.process_id())
354 .field("channel", &self.channel())
355 .field("payload", &self.payload())
356 .finish()
357 }
358}
359
/// Prepares `name` for use inside a double-quoted SQL identifier.
///
/// Everything from the first NUL character onwards is discarded, and every
/// `"` in the remainder is doubled. The surrounding quotes are not added here.
fn ident(name: &str) -> String {
    // `split` always yields at least one piece: the text before the first NUL.
    let visible = name.split('\0').next().unwrap_or(name);

    visible.replace('"', "\"\"")
}
370
371fn build_listen_all_query(channels: impl IntoIterator<Item = impl AsRef<str>>) -> String {
372 channels.into_iter().fold(String::new(), |mut acc, chan| {
373 acc.push_str(r#"LISTEN ""#);
374 acc.push_str(&ident(chan.as_ref()));
375 acc.push_str(r#"";"#);
376 acc
377 })
378}
379
/// A single channel yields exactly one quoted, semicolon-terminated `LISTEN`.
#[test]
fn test_build_listen_all_query_with_single_channel() {
    let channels = ["test"];
    let query = build_listen_all_query(&channels);

    assert_eq!(query, r#"LISTEN "test";"#);
}
385
/// Several channels yield their `LISTEN` statements concatenated in input order.
#[test]
fn test_build_listen_all_query_with_multiple_channels() {
    let channels = ["channel.0", "channel.1"];
    let query = build_listen_all_query(&channels);

    assert_eq!(query, r#"LISTEN "channel.0";LISTEN "channel.1";"#);
}
390}