1use postgres_types::FromSql;
4use std::collections::{BTreeSet, HashSet};
5use std::fmt::{Display, Write};
6use std::hash::Hash;
7use std::error::Error as StdError;
8use std::iter;
9use std::ops::Range;
10use thiserror::Error;
11use postgres_types::WasNull;
12use tokio_postgres::{error::Error as SqlError, row::RowIndex, Column};
13
14#[derive(Debug, Error)]
16pub enum Error {
17 #[error("{msg}")]
18 Custom { msg: String },
19
20 #[error("invalid number of columns, found {found} but expected {expected}")]
21 ColumnCount { found: usize, expected: usize },
22
23 #[error("failed to get column: `{index}` (columns were: {columns})")]
24 SliceLookup { index: String, columns: String },
25
26 #[error("failed to split on: `{split}` (columns were: {columns})")]
27 InvalidSplit { split: String, columns: String },
28
29 #[error(
30 "failed to slice row on: `{start}..{end}` (len was: {len})",
31 start = range.start,
32 end = range.end
33 )]
34 SliceIndex { range: Range<usize>, len: usize },
35
36 #[error("internal postgres error")]
38 Sql(#[from] SqlError),
39}
40
41impl Error {
42 pub fn new<D>(msg: D) -> Error
44 where
45 D: Display,
46 {
47 Error::Custom {
48 msg: msg.to_string(),
49 }
50 }
51
52 fn is_soft(&self) -> bool {
54 match self {
55 Error::Sql(sql) => {
56 let mut error: &dyn StdError = sql;
57 loop {
58 if let Some(WasNull) = error.downcast_ref() {
59 break true;
60 }
61
62 match error.source() {
63 Some(source) => error = source,
64 None => break false,
65 }
66 }
67 }
68
69 _ => false,
70 }
71 }
72}
73
// Sealed-trait pattern: `Row` has `private::row::Sealed` as a supertrait, and
// `private` is not reachable from outside this module, so code elsewhere can
// name `Row` but cannot implement it.
mod private {
    pub mod row {
        /// Marker supertrait of `Row`; see the module comment above.
        pub trait Sealed {}
    }
}
79
80pub trait Row: private::row::Sealed {
84 fn columns(&self) -> &[Column];
86
87 fn try_get<'a, I, T>(&'a self, index: I) -> Result<T, Error>
89 where
90 I: RowIndex + Display,
91 T: FromSql<'a>;
92
93 fn len(&self) -> usize {
95 self.columns().len()
96 }
97
98 fn is_empty(&self) -> bool {
100 self.len() == 0
101 }
102
103 fn get<'a, I, T>(&'a self, index: I) -> T
109 where
110 I: RowIndex + Display,
111 T: FromSql<'a>,
112 {
113 match self.try_get::<I, T>(index) {
114 Ok(value) => value,
115 Err(err) => panic!("failed to retrieve column: {}", err),
116 }
117 }
118
119 fn slice(&self, range: Range<usize>) -> Result<RowSlice<Self>, Error>
121 where
122 Self: Sized,
123 {
124 if range.end > self.len() {
125 Err(Error::SliceIndex {
126 range,
127 len: self.len(),
128 })
129 } else {
130 let slice = RowSlice { row: self, range };
131 Ok(slice)
132 }
133 }
134}
135
136pub struct RowSlice<'a, R>
138where
139 R: Row,
140{
141 row: &'a R,
142 range: Range<usize>,
143}
144
145pub trait FromSqlRow: Sized {
162 const COLUMN_COUNT: usize;
166
167 fn from_row<R>(row: &R) -> Result<Self, Error>
169 where
170 R: Row;
171
172 fn from_row_multi<R>(rows: &[R]) -> Result<Vec<Self>, Error>
180 where
181 R: Row,
182 {
183 rows.iter().map(Self::from_row).collect()
184 }
185}
186
/// A collection that accepts items one at a time.
pub trait Merge {
    /// The type of element stored in the collection.
    type Item;

    /// Add `item` to the collection.
    fn insert(&mut self, item: Self::Item);
}
197
198impl<T> Merge for Vec<T> {
199 type Item = T;
200 fn insert(&mut self, item: T) {
201 self.push(item)
202 }
203}
204
205impl<T> Merge for HashSet<T>
206where
207 T: Hash + Eq,
208{
209 type Item = T;
210 fn insert(&mut self, item: T) {
211 HashSet::insert(self, item);
212 }
213}
214
215impl<T> Merge for BTreeSet<T>
216where
217 T: Ord,
218{
219 type Item = T;
220 fn insert(&mut self, item: T) {
221 BTreeSet::insert(self, item);
222 }
223}
224
225impl private::row::Sealed for tokio_postgres::Row {}
226
227impl Row for tokio_postgres::Row {
228 fn columns(&self) -> &[Column] {
229 tokio_postgres::Row::columns(self)
230 }
231
232 fn try_get<'a, I, T>(&'a self, index: I) -> Result<T, Error>
233 where
234 I: RowIndex + Display,
235 T: FromSql<'a>,
236 {
237 tokio_postgres::Row::try_get(self, index).map_err(Error::from)
238 }
239
240 fn len(&self) -> usize {
241 tokio_postgres::Row::len(self)
242 }
243
244 fn is_empty(&self) -> bool {
245 tokio_postgres::Row::is_empty(self)
246 }
247
248 fn get<'a, I, T>(&'a self, index: I) -> T
249 where
250 I: RowIndex + Display,
251 T: FromSql<'a>,
252 {
253 tokio_postgres::Row::get(self, index)
254 }
255}
256
257impl<R> private::row::Sealed for RowSlice<'_, R> where R: Row {}
258
259impl<R> Row for RowSlice<'_, R>
260where
261 R: Row,
262{
263 fn columns(&self) -> &[Column] {
264 &self.row.columns()[self.range.clone()]
265 }
266
267 fn try_get<'a, I, T>(&'a self, index: I) -> Result<T, Error>
268 where
269 I: RowIndex + Display,
270 T: FromSql<'a>,
271 {
272 if let Some(index) = index.__idx(self.columns()) {
273 self.row.try_get(self.range.start + index)
274 } else {
275 Err(Error::SliceLookup {
276 index: index.to_string(),
277 columns: format_columns(self.columns()),
278 })
279 }
280 }
281}
282
283impl<R> RowSlice<'_, R>
284where
285 R: Row,
286{
287 pub fn slice(&self, range: Range<usize>) -> Result<RowSlice<R>, Error>
292 where
293 Self: Sized,
294 {
295 if range.end > self.range.end {
296 Err(Error::SliceIndex {
297 range,
298 len: self.range.end,
299 })
300 } else {
301 let slice = RowSlice {
302 row: self.row,
303 range,
304 };
305 Ok(slice)
306 }
307 }
308}
309
310pub fn split_columns_many<'a, S>(
342 columns: &'a [Column],
343 splits: &'a [S],
344) -> impl Iterator<Item = Result<Range<usize>, Error>> + 'a
345where
346 S: AsRef<str>,
347{
348 let column_names = columns.iter().map(|col| col.name());
349 partition_many(column_names, splits.iter()).map(move |split| match split {
350 SplitResult::Range(range) => Ok(range),
351 SplitResult::NotFound { split, start } => Err(Error::InvalidSplit {
352 split,
353 columns: format_columns(&columns[start..]),
354 }),
355 })
356}
357
/// Outcome of one step of `partition_many`.
#[cfg_attr(test, derive(Debug, PartialEq))]
enum SplitResult {
    /// A partition, as a half-open range of column indices.
    Range(Range<usize>),
    /// The split column name `split` was not found. `start` is the index at
    /// which the partition being built began.
    NotFound { split: String, start: usize },
}
363
364fn partition_many<'a>(
365 columns: impl Iterator<Item = impl AsRef<str> + 'a> + 'a,
366 splits: impl Iterator<Item = impl AsRef<str> + 'a> + 'a,
367) -> impl Iterator<Item = SplitResult> + 'a {
368 let mut columns = columns.enumerate();
369 let mut splits = splits;
370
371 let mut previous_end = 0;
372
373 iter::from_fn(move || -> Option<_> {
374 if let Some(split) = splits.next() {
375 let split = split.as_ref();
376 if let Some((end, _)) = columns.find(|(_, name)| name.as_ref() == split) {
377 let range = previous_end..end;
378 previous_end = end;
379 Some(SplitResult::Range(range))
380 } else {
381 Some(SplitResult::NotFound {
382 split: split.to_owned(),
383 start: previous_end,
384 })
385 }
386 } else {
387 let (last, _) = columns.by_ref().last()?;
388 let len = last + 1;
389 Some(SplitResult::Range(previous_end..len))
390 }
391 })
392}
393
394fn format_columns(columns: &[Column]) -> String {
395 let mut total = String::with_capacity(16 * columns.len());
396 for col in columns {
397 if !total.is_empty() {
398 total.push_str(", ");
399 }
400 write!(total, "`{}`", col.name()).unwrap();
401 }
402 total
403}
404
405mod from_row_sql_impls {
406 use super::*;
407
408 use std::rc::Rc;
409 use std::sync::Arc;
410
411 macro_rules! impl_from_row_for_tuple {
412 (($($elem:ident),+)) => {
413 impl<$($elem),+> FromSqlRow for ($($elem,)+)
414 where $($elem: for<'a> FromSql<'a> + std::fmt::Display),+
415 {
416 const COLUMN_COUNT: usize = impl_from_row_for_tuple!(@count ($($elem),*));
417
418 fn from_row<R>(row: &R) -> Result<Self, Error>
419 where R: Row {
420 if row.len() != Self::COLUMN_COUNT {
421 Err(Error::ColumnCount {
422 expected: Self::COLUMN_COUNT,
423 found: row.len(),
424 })
425 } else {
426 let result = (
427 $(
428 row.try_get::<usize, $elem>(
429 impl_from_row_for_tuple!(@index $elem)
430 )?,
431 )+
432 );
433
434 Ok(result)
435 }
436 }
437 }
438 };
439
440 (@index A) => { 0 };
441 (@index B) => { 1 };
442 (@index C) => { 2 };
443 (@index D) => { 3 };
444 (@index E) => { 4 };
445 (@index F) => { 5 };
446 (@index G) => { 6 };
447 (@index H) => { 7 };
448
449 (@count ()) => { 0 };
450 (@count ($head:ident $(, $tail:ident)*)) => {{
451 1 + impl_from_row_for_tuple!(@count ($($tail),*))
452 }};
453 }
454
455 impl_from_row_for_tuple!((A));
456 impl_from_row_for_tuple!((A, B));
457 impl_from_row_for_tuple!((A, B, C));
458 impl_from_row_for_tuple!((A, B, C, D));
459 impl_from_row_for_tuple!((A, B, C, D, E));
460 impl_from_row_for_tuple!((A, B, C, D, E, F));
461 impl_from_row_for_tuple!((A, B, C, D, E, F, G));
462 impl_from_row_for_tuple!((A, B, C, D, E, F, G, H));
463
464 impl<T> FromSqlRow for Option<T>
465 where
466 T: FromSqlRow,
467 {
468 const COLUMN_COUNT: usize = T::COLUMN_COUNT;
469
470 fn from_row<R>(row: &R) -> Result<Self, Error>
471 where
472 R: Row,
473 {
474 match T::from_row(row) {
475 Ok(value) => Ok(Some(value)),
476 Err(error) if error.is_soft() => Ok(None),
477 Err(error) => Err(error),
478 }
479 }
480 }
481
482 impl<T, E> FromSqlRow for Result<T, E>
483 where
484 T: FromSqlRow,
485 E: From<Error>,
486 {
487 const COLUMN_COUNT: usize = T::COLUMN_COUNT;
488
489 fn from_row<R>(row: &R) -> Result<Self, Error>
490 where
491 R: Row,
492 {
493 match T::from_row(row) {
494 Ok(value) => Ok(Ok(value)),
495 Err(error) => Ok(Err(E::from(error))),
496 }
497 }
498 }
499
500 macro_rules! impl_from_row_for_wrapper {
501 ($wrapper:ident, $constructor:expr) => {
502 impl<T> FromSqlRow for $wrapper<T>
503 where
504 T: FromSqlRow,
505 {
506 const COLUMN_COUNT: usize = T::COLUMN_COUNT;
507
508 fn from_row<R>(row: &R) -> Result<Self, Error>
509 where
510 R: Row,
511 {
512 let value = T::from_row(row)?;
513 Ok($constructor(value))
514 }
515 }
516 };
517 }
518
519 impl_from_row_for_wrapper!(Box, Box::new);
520 impl_from_row_for_wrapper!(Rc, Rc::new);
521 impl_from_row_for_wrapper!(Arc, Arc::new);
522}
523
#[cfg(test)]
mod tests {
    use super::*;

    /// Partition `columns` at `splits`, treating every character as a
    /// one-letter column name.
    fn split_chars_fallible<'a>(
        columns: &'a str,
        splits: &'a str,
    ) -> impl Iterator<Item = SplitResult> + 'a {
        partition_many(
            columns.chars().map(|ch| ch.to_string()),
            splits.chars().map(|ch| ch.to_string()),
        )
    }

    /// Like `split_chars_fallible`, but panics on a split that cannot be found.
    fn split_chars<'a>(
        columns: &'a str,
        splits: &'a str,
    ) -> impl Iterator<Item = Range<usize>> + 'a {
        split_chars_fallible(columns, splits).map(move |outcome| match outcome {
            SplitResult::Range(range) => range,
            SplitResult::NotFound { split, start } => {
                let remaining: String = columns.chars().skip(start).collect();
                panic!("failed to split {:?} on {:?}", remaining, split,)
            }
        })
    }

    #[test]
    fn split_columns_many_no_excess() {
        let expected = vec![0..0, 0..3, 3..6, 6..8];
        let actual: Vec<_> = split_chars("abcabdab", "aaa").collect();
        assert_eq!(actual, expected)
    }

    #[test]
    fn split_columns_many_leading_columns() {
        let expected = vec![0..2, 2..5, 5..8, 8..10];
        let actual: Vec<_> = split_chars("deabcabdab", "aaa").collect();
        assert_eq!(actual, expected)
    }

    #[test]
    fn split_columns_many_too_many_splits() {
        let expected = vec![
            SplitResult::Range(0..0),
            SplitResult::Range(0..3),
            SplitResult::NotFound {
                split: "a".to_owned(),
                start: 3,
            },
        ];
        let actual: Vec<_> = split_chars_fallible("abcabc", "aaa").collect();
        assert_eq!(actual, expected)
    }
}
580}