postgres_cursor/
lib.rs

1//! Provides a Cursor abstraction for use with the `postgres` crate.
2//!
3//! # Examples
4//!
5//! ```no_run
6//! extern crate postgres;
7//! extern crate postgres_cursor;
8//!
9//! use postgres::{Client, NoTls};
10//! use postgres_cursor::Cursor;
11//!
12//! # fn main() {
13//!
14//! // First, establish a connection with postgres
15//! let mut client = Client::connect("postgres://jwilm@127.0.0.1/foo", NoTls)
16//!     .expect("connect");
17//!
18//! // Build the cursor
19//! let mut cursor = Cursor::build(&mut client)
20//!     // Batch size determines rows returned in each FETCH call
21//!     .batch_size(10)
22//!     // Query is the statement to build a cursor for
23//!     .query("SELECT id FROM products")
24//!     // Finalize turns this builder into a cursor
25//!     .finalize()
26//!     .expect("cursor creation succeeded");
27//!
28//! // Iterate over batches of rows
29//! for result in &mut cursor {
30//!     // Each item returned from the iterator is a Result<Vec<Row>, postgres::Error>.
31//!     // This is because each call to `next()` makes a query
32//!     // to the database.
33//!     let rows = result.unwrap();
34//!
35//!     // After handling errors, rows returned in this iteration
36//!     // can be iterated over.
37//!     for row in &rows {
38//!         println!("{:?}", row);
39//!     }
40//! }
41//!
42//! # }
43//! ```
44extern crate postgres;
45extern crate rand;
46
47#[macro_use]
48#[cfg(test)]
49extern crate lazy_static;
50
51use std::iter::IntoIterator;
52use std::{fmt, mem};
53
54use postgres::row::Row;
55use postgres::types::ToSql;
56use postgres::Client;
57use rand::thread_rng;
58use rand::RngCore;
59
/// Display adapter that renders a byte slice as lowercase hexadecimal,
/// two digits per byte, with no separators.
struct Hex<'a>(&'a [u8]);

impl<'a> fmt::Display for Hex<'a> {
    /// Writes every byte of the wrapped slice as `{:02x}`, stopping at the
    /// first formatter error.
    fn fmt(&self, out: &mut fmt::Formatter) -> fmt::Result {
        self.0
            .iter()
            .try_for_each(|octet| write!(out, "{:02x}", octet))
    }
}
70
71/// Represents a PostgreSQL cursor.
72///
73/// The actual cursor in the database is only created and active _while_
74/// `Iter` is in scope and calls to `next()` return `Some`.
75pub struct Cursor<'client> {
76    client: &'client mut Client,
77    closed: bool,
78    cursor_name: String,
79    fetch_query: String,
80    batch_size: u32,
81}
82
83impl<'client> Cursor<'client> {
84    fn new<'c, 'a, D>(builder: Builder<'c, 'a, D>) -> Result<Cursor<'c>, postgres::Error>
85    where
86        D: fmt::Display + ?Sized,
87    {
88        let mut bytes: [u8; 8] = unsafe { *mem::MaybeUninit::uninit().assume_init_ref() };
89        let mut rng = thread_rng();
90        rng.fill_bytes(&mut bytes[..]);
91
92        let cursor_name = format!("cursor:{}:{}", builder.tag, Hex(&bytes));
93        let query = format!("DECLARE \"{}\" CURSOR FOR {}", cursor_name, builder.query);
94        let fetch_query = format!("FETCH {} FROM \"{}\"", builder.batch_size, cursor_name);
95
96        builder.client.execute("BEGIN", &[])?;
97        builder.client.execute(&query[..], builder.params)?;
98
99        Ok(Cursor {
100            closed: false,
101            client: builder.client,
102            cursor_name,
103            fetch_query,
104            batch_size: builder.batch_size,
105        })
106    }
107
108    pub fn build<'b>(client: &'b mut Client) -> Builder<'b, 'static, str> {
109        Builder::<str>::new(client)
110    }
111}
112
113/// Iterator returning `Vec<Row>` for every call to `next()`.
114pub struct Iter<'b, 'a: 'b> {
115    cursor: &'b mut Cursor<'a>,
116}
117
118impl<'b, 'a: 'b> Iterator for Iter<'b, 'a> {
119    type Item = Result<Vec<Row>, postgres::Error>;
120
121    fn next(&mut self) -> Option<Result<Vec<Row>, postgres::Error>> {
122        if self.cursor.closed {
123            None
124        } else {
125            Some(self.cursor.next_batch())
126        }
127    }
128}
129
130impl<'a, 'client> IntoIterator for &'a mut Cursor<'client> {
131    type Item = Result<Vec<Row>, postgres::Error>;
132    type IntoIter = Iter<'a, 'client>;
133
134    fn into_iter(self) -> Iter<'a, 'client> {
135        self.iter()
136    }
137}
138
139impl<'a> Cursor<'a> {
140    pub fn iter<'b>(&'b mut self) -> Iter<'b, 'a> {
141        Iter { cursor: self }
142    }
143
144    fn next_batch(&mut self) -> Result<Vec<Row>, postgres::Error> {
145        let rows = self.client.query(&self.fetch_query[..], &[])?;
146        if rows.len() < (self.batch_size as usize) {
147            self.close()?;
148        }
149        Ok(rows)
150    }
151
152    fn close(&mut self) -> Result<(), postgres::Error> {
153        if !self.closed {
154            let close_query = format!("CLOSE \"{}\"", self.cursor_name);
155            self.client.execute(&close_query[..], &[])?;
156            self.client.execute("COMMIT", &[])?;
157            self.closed = true;
158        }
159
160        Ok(())
161    }
162}
163
164impl<'a> Drop for Cursor<'a> {
165    fn drop(&mut self) {
166        let _ = self.close();
167    }
168}
169
170/// Builds a Cursor
171///
172/// This type is constructed by calling `Cursor::build`.
173pub struct Builder<'client, 'builder, D: ?Sized + 'builder> {
174    batch_size: u32,
175    query: &'builder str,
176    client: &'client mut Client,
177    tag: &'builder D,
178    params: &'builder [&'builder (dyn ToSql + Sync)],
179}
180
181impl<'client, 'builder, D: fmt::Display + ?Sized + 'builder> Builder<'client, 'builder, D> {
182    fn new<'c>(client: &'c mut Client) -> Builder<'c, 'static, str> {
183        Builder {
184            client,
185            batch_size: 5_000,
186            query: "SELECT 1 as one",
187            tag: "default",
188            params: &[],
189        }
190    }
191
192    /// Set query params for cursor creation
193    pub fn query_params(mut self, params: &'builder [&'builder (dyn ToSql + Sync)]) -> Self {
194        self.params = params;
195        self
196    }
197
198    /// Set the batch size passed to `FETCH` on each iteration.
199    ///
200    /// Default is 5,000.
201    pub fn batch_size(mut self, batch_size: u32) -> Self {
202        self.batch_size = batch_size;
203        self
204    }
205
206    /// Set the tag for cursor name.
207    ///
208    /// Adding a tag to the cursor name can be helpful for identifying where
209    /// cursors originate when viewing `pg_stat_activity`.
210    ///
211    /// Default is `default`.
212    ///
213    /// # Examples
214    ///
215    /// Any type that implements `fmt::Display` may be provided as a tag. For example, a simple
216    /// string literal is one option.
217    ///
218    /// ```no_run
219    /// # extern crate postgres;
220    /// # extern crate postgres_cursor;
221    /// # use postgres::Client;
222    /// # use postgres::NoTls;
223    /// # use postgres_cursor::Cursor;
224    /// # fn main() {
225    /// # let mut client = Client::connect("postgres://jwilm@127.0.0.1/foo", NoTls)
226    /// #     .expect("connect");
227    /// let mut cursor = Cursor::build(&mut client)
228    ///     .tag("custom-cursor-tag")
229    ///     .finalize();
230    /// # }
231    /// ```
232    ///
233    /// Or maybe you want to build a tag at run-time without incurring an extra allocation:
234    ///
235    /// ```no_run
236    /// # extern crate postgres;
237    /// # extern crate postgres_cursor;
238    /// # use postgres::Client;
239    /// # use postgres::NoTls;
240    /// # use postgres_cursor::Cursor;
241    /// # fn main() {
242    /// # let mut client = Client::connect("postgres://jwilm@127.0.0.1/foo", NoTls)
243    /// #     .expect("connect");
244    /// use std::fmt;
245    ///
246    /// struct Pid(i32);
247    /// impl fmt::Display for Pid {
248    ///     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
249    ///         write!(f, "pid-{}", self.0)
250    ///     }
251    /// }
252    ///
253    /// let tag = Pid(8123);
254    /// let mut cursor = Cursor::build(&mut client)
255    ///     .tag(&tag)
256    ///     .finalize();
257    /// # }
258    /// ```
259    pub fn tag<D2: fmt::Display + ?Sized>(
260        self,
261        tag: &'builder D2,
262    ) -> Builder<'client, 'builder, D2> {
263        Builder {
264            batch_size: self.batch_size,
265            query: self.query,
266            client: self.client,
267            tag,
268            params: self.params,
269        }
270    }
271
272    /// Set the query to create a cursor for.
273    ///
274    /// Default is `SELECT 1`.
275    pub fn query(mut self, query: &'builder str) -> Self {
276        self.query = query;
277        self
278    }
279
280    /// Turn the builder into a `Cursor`.
281    pub fn finalize(self) -> Result<Cursor<'client>, postgres::Error> {
282        Cursor::new(self)
283    }
284}
285
#[cfg(test)]
mod tests {
    use std::sync::Mutex;

    use super::Cursor;
    use postgres::Client;
    use postgres::NoTls;

    lazy_static! {
        static ref LOCK: Mutex<u8> = Mutex::new(0);
    }

    /// Runs `func` while holding the global test lock, so tests sharing the
    /// `products` table never overlap. A lock poisoned by an earlier panicking
    /// test is still acquired.
    fn synchronized<F: FnOnce() -> T, T>(func: F) -> T {
        let _guard = match LOCK.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        func()
    }

    /// Empties `products`, inserts the ids `0..items`, then hands the
    /// connection to `func`.
    fn with_items<F: FnOnce(&mut Client) -> T, T>(items: i32, func: F) -> T {
        synchronized(|| {
            let mut client = get_client();
            client
                .execute("TRUNCATE TABLE products", &[])
                .expect("truncate");
            // Highly inefficient (one round trip per row); should optimize.
            for id in 0..items {
                client
                    .execute("INSERT INTO products (id) VALUES ($1)", &[&id])
                    .expect("insert");
            }
            func(&mut client)
        })
    }

    /// Connects to the local test database.
    fn get_client() -> Client {
        let connection = Client::connect(
            "postgres://postgres@127.0.0.1/postgresql_cursor_test",
            NoTls,
        );
        connection.expect("connect")
    }

    /// Drains `cursor` and returns the total number of rows across all
    /// batches, panicking on the first failed fetch.
    fn count_rows(cursor: &mut Cursor) -> usize {
        cursor.iter().map(|batch| batch.unwrap().len()).sum()
    }

    #[test]
    fn test_framework_works() {
        let count = 183;
        with_items(count, |client| {
            let rows = client
                .query("SELECT COUNT(*) FROM products", &[])
                .unwrap();
            for row in &rows {
                let got: i64 = row.get(0);
                assert_eq!(got, count as i64);
            }
        });
    }

    #[test]
    fn cursor_iter_works_when_batch_size_divisible() {
        with_items(200, |client| {
            let mut cursor = Cursor::build(client)
                .batch_size(10)
                .query("SELECT id FROM products")
                .finalize()
                .unwrap();

            assert_eq!(count_rows(&mut cursor), 200);
        });
    }

    #[test]
    fn cursor_iter_works_when_batch_size_remainder() {
        with_items(197, |client| {
            let mut cursor = Cursor::build(client)
                .batch_size(10)
                .query("SELECT id FROM products")
                .finalize()
                .unwrap();

            assert_eq!(count_rows(&mut cursor), 197);
        });
    }

    #[test]
    fn build_cursor_with_tag() {
        use std::fmt;

        struct Foo;
        impl fmt::Display for Foo {
            fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
                write!(f, "foo-{}", 1)
            }
        }

        with_items(1, |client| {
            // A string literal as the tag.
            {
                let cursor = Cursor::build(client).tag("foobar").finalize().unwrap();

                assert!(cursor.cursor_name.starts_with("cursor:foobar"));
            }

            // A custom `Display` type as the tag.
            {
                let foo = Foo;
                let cursor = Cursor::build(client).tag(&foo).finalize().unwrap();

                println!("{}", cursor.cursor_name);
                assert!(cursor.cursor_name.starts_with("cursor:foo-1"));
            }
        });
    }

    #[test]
    fn cursor_with_long_tag() {
        with_items(100, |client| {
            let mut cursor = Cursor::build(client)
                .tag("really-long-tag-damn-that-was-only-three-words-foo-bar-baz")
                .query("SELECT id FROM products")
                .finalize()
                .unwrap();

            assert_eq!(count_rows(&mut cursor), 100);
        });
    }

    #[test]
    fn cursor_with_params() {
        with_items(100, |client| {
            let mut cursor = Cursor::build(client)
                .query("SELECT id FROM products WHERE id > $1 AND id < $2")
                .query_params(&[&1, &10])
                .finalize()
                .unwrap();

            assert_eq!(count_rows(&mut cursor), 8);
        });
    }
}