1extern crate postgres;
45extern crate rand;
46
47#[macro_use]
48#[cfg(test)]
49extern crate lazy_static;
50
51use std::iter::IntoIterator;
52use std::{fmt, mem};
53
54use postgres::row::Row;
55use postgres::types::ToSql;
56use postgres::Client;
57use rand::thread_rng;
58use rand::RngCore;
59
/// Displays a byte slice as lowercase hexadecimal: two digits per byte, no separators.
struct Hex<'a>(&'a [u8]);

impl<'a> fmt::Display for Hex<'a> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // Stop at the first formatter error, exactly like a `?` inside a loop would.
        self.0.iter().try_for_each(|octet| write!(f, "{:02x}", octet))
    }
}
70
71pub struct Cursor<'client> {
76 client: &'client mut Client,
77 closed: bool,
78 cursor_name: String,
79 fetch_query: String,
80 batch_size: u32,
81}
82
83impl<'client> Cursor<'client> {
84 fn new<'c, 'a, D>(builder: Builder<'c, 'a, D>) -> Result<Cursor<'c>, postgres::Error>
85 where
86 D: fmt::Display + ?Sized,
87 {
88 let mut bytes: [u8; 8] = unsafe { *mem::MaybeUninit::uninit().assume_init_ref() };
89 let mut rng = thread_rng();
90 rng.fill_bytes(&mut bytes[..]);
91
92 let cursor_name = format!("cursor:{}:{}", builder.tag, Hex(&bytes));
93 let query = format!("DECLARE \"{}\" CURSOR FOR {}", cursor_name, builder.query);
94 let fetch_query = format!("FETCH {} FROM \"{}\"", builder.batch_size, cursor_name);
95
96 builder.client.execute("BEGIN", &[])?;
97 builder.client.execute(&query[..], builder.params)?;
98
99 Ok(Cursor {
100 closed: false,
101 client: builder.client,
102 cursor_name,
103 fetch_query,
104 batch_size: builder.batch_size,
105 })
106 }
107
108 pub fn build<'b>(client: &'b mut Client) -> Builder<'b, 'static, str> {
109 Builder::<str>::new(client)
110 }
111}
112
113pub struct Iter<'b, 'a: 'b> {
115 cursor: &'b mut Cursor<'a>,
116}
117
118impl<'b, 'a: 'b> Iterator for Iter<'b, 'a> {
119 type Item = Result<Vec<Row>, postgres::Error>;
120
121 fn next(&mut self) -> Option<Result<Vec<Row>, postgres::Error>> {
122 if self.cursor.closed {
123 None
124 } else {
125 Some(self.cursor.next_batch())
126 }
127 }
128}
129
130impl<'a, 'client> IntoIterator for &'a mut Cursor<'client> {
131 type Item = Result<Vec<Row>, postgres::Error>;
132 type IntoIter = Iter<'a, 'client>;
133
134 fn into_iter(self) -> Iter<'a, 'client> {
135 self.iter()
136 }
137}
138
139impl<'a> Cursor<'a> {
140 pub fn iter<'b>(&'b mut self) -> Iter<'b, 'a> {
141 Iter { cursor: self }
142 }
143
144 fn next_batch(&mut self) -> Result<Vec<Row>, postgres::Error> {
145 let rows = self.client.query(&self.fetch_query[..], &[])?;
146 if rows.len() < (self.batch_size as usize) {
147 self.close()?;
148 }
149 Ok(rows)
150 }
151
152 fn close(&mut self) -> Result<(), postgres::Error> {
153 if !self.closed {
154 let close_query = format!("CLOSE \"{}\"", self.cursor_name);
155 self.client.execute(&close_query[..], &[])?;
156 self.client.execute("COMMIT", &[])?;
157 self.closed = true;
158 }
159
160 Ok(())
161 }
162}
163
164impl<'a> Drop for Cursor<'a> {
165 fn drop(&mut self) {
166 let _ = self.close();
167 }
168}
169
170pub struct Builder<'client, 'builder, D: ?Sized + 'builder> {
174 batch_size: u32,
175 query: &'builder str,
176 client: &'client mut Client,
177 tag: &'builder D,
178 params: &'builder [&'builder (dyn ToSql + Sync)],
179}
180
181impl<'client, 'builder, D: fmt::Display + ?Sized + 'builder> Builder<'client, 'builder, D> {
182 fn new<'c>(client: &'c mut Client) -> Builder<'c, 'static, str> {
183 Builder {
184 client,
185 batch_size: 5_000,
186 query: "SELECT 1 as one",
187 tag: "default",
188 params: &[],
189 }
190 }
191
192 pub fn query_params(mut self, params: &'builder [&'builder (dyn ToSql + Sync)]) -> Self {
194 self.params = params;
195 self
196 }
197
198 pub fn batch_size(mut self, batch_size: u32) -> Self {
202 self.batch_size = batch_size;
203 self
204 }
205
206 pub fn tag<D2: fmt::Display + ?Sized>(
260 self,
261 tag: &'builder D2,
262 ) -> Builder<'client, 'builder, D2> {
263 Builder {
264 batch_size: self.batch_size,
265 query: self.query,
266 client: self.client,
267 tag,
268 params: self.params,
269 }
270 }
271
272 pub fn query(mut self, query: &'builder str) -> Self {
276 self.query = query;
277 self
278 }
279
280 pub fn finalize(self) -> Result<Cursor<'client>, postgres::Error> {
282 Cursor::new(self)
283 }
284}
285
#[cfg(test)]
mod tests {
    use std::sync::Mutex;

    use super::Cursor;
    use postgres::Client;
    use postgres::NoTls;

    lazy_static! {
        // Serialises the tests: they all truncate and refill the same table.
        static ref LOCK: Mutex<u8> = Mutex::new(0);
    }

    /// Runs `func` while holding the global test lock, recovering it if poisoned.
    fn synchronized<F: FnOnce() -> T, T>(func: F) -> T {
        let _guard = LOCK.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
        func()
    }

    /// Refills `products` with ids `0..items`, then hands a client to `func`.
    fn with_items<F: FnOnce(&mut Client) -> T, T>(items: i32, func: F) -> T {
        synchronized(|| {
            let mut client = get_client();
            client
                .execute("TRUNCATE TABLE products", &[])
                .expect("truncate");
            for id in 0..items {
                client
                    .execute("INSERT INTO products (id) VALUES ($1)", &[&id])
                    .expect("insert");
            }
            func(&mut client)
        })
    }

    /// Connects to the local test database.
    fn get_client() -> Client {
        Client::connect(
            "postgres://postgres@127.0.0.1/postgresql_cursor_test",
            NoTls,
        )
        .expect("connect")
    }

    /// Drains `cursor` and returns how many rows it yielded in total.
    fn count_rows(cursor: &mut Cursor<'_>) -> usize {
        cursor.iter().map(|batch| batch.unwrap().len()).sum()
    }

    #[test]
    fn test_framework_works() {
        let count = 183;
        with_items(count, |client| {
            let rows = client.query("SELECT COUNT(*) FROM products", &[]).unwrap();
            for row in &rows {
                let got: i64 = row.get(0);
                assert_eq!(got, i64::from(count));
            }
        });
    }

    #[test]
    fn cursor_iter_works_when_batch_size_divisible() {
        with_items(200, |client| {
            let mut cursor = Cursor::build(client)
                .batch_size(10)
                .query("SELECT id FROM products")
                .finalize()
                .unwrap();

            assert_eq!(count_rows(&mut cursor), 200);
        });
    }

    #[test]
    fn cursor_iter_works_when_batch_size_remainder() {
        with_items(197, |client| {
            let mut cursor = Cursor::build(client)
                .batch_size(10)
                .query("SELECT id FROM products")
                .finalize()
                .unwrap();

            assert_eq!(count_rows(&mut cursor), 197);
        });
    }

    #[test]
    fn build_cursor_with_tag() {
        use std::fmt;

        struct Foo;
        impl fmt::Display for Foo {
            fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
                write!(f, "foo-{}", 1)
            }
        }

        with_items(1, |client| {
            {
                let cursor = Cursor::build(client).tag("foobar").finalize().unwrap();

                assert!(cursor.cursor_name.starts_with("cursor:foobar"));
            }

            {
                let foo = Foo;
                let cursor = Cursor::build(client).tag(&foo).finalize().unwrap();

                println!("{}", cursor.cursor_name);
                assert!(cursor.cursor_name.starts_with("cursor:foo-1"));
            }
        });
    }

    #[test]
    fn cursor_with_long_tag() {
        with_items(100, |client| {
            let mut cursor = Cursor::build(client)
                .tag("really-long-tag-damn-that-was-only-three-words-foo-bar-baz")
                .query("SELECT id FROM products")
                .finalize()
                .unwrap();

            assert_eq!(count_rows(&mut cursor), 100);
        });
    }

    #[test]
    fn cursor_with_params() {
        with_items(100, |client| {
            let mut cursor = Cursor::build(client)
                .query("SELECT id FROM products WHERE id > $1 AND id < $2")
                .query_params(&[&1, &10])
                .finalize()
                .unwrap();

            assert_eq!(count_rows(&mut cursor), 8);
        });
    }
}