db_mover/databases/postgres/mod.rs

1use std::collections::HashMap;
2use std::io::Write;
3
4use anyhow::Context;
5use postgres::fallible_iterator::FallibleIterator;
6use postgres::{Client, NoTls};
7use tracing::debug;
8use value::PostgreColumn;
9
10use crate::databases::table::{Row, Value};
11use crate::databases::traits::{DBInfoProvider, DBReader, DBWriter};
12
13use super::table::{Column, TableInfo};
14use super::traits::{ReaderIterator, WriterError};
15
16mod value;
17
18pub struct PostgresDB {
19    uri: String,
20    client: Client,
21    table_columns_cache: HashMap<String, Vec<PostgreColumn>>,
22}
23
24impl PostgresDB {
25    pub fn new(uri: &str) -> anyhow::Result<Self> {
26        let client = Self::connect(uri)?;
27        debug!("Connected to postgres {uri}");
28        return Ok(Self {
29            client,
30            uri: uri.to_string(),
31            table_columns_cache: HashMap::default(),
32        });
33    }
34
35    fn connect(uri: &str) -> Result<Client, postgres::Error> {
36        return Client::connect(uri, NoTls);
37    }
38
39    fn get_num_rows(&mut self, table: &str) -> anyhow::Result<u64> {
40        let count_query = format!("SELECT count(1) FROM {table}");
41        return self
42            .client
43            .query_one(&count_query, &[])?
44            .get::<_, i64>(0)
45            .try_into()
46            .context("Failed to convert i64 to u64");
47    }
48
49    fn get_columns(&mut self, table: &str) -> anyhow::Result<Vec<PostgreColumn>> {
50        let mut columns = Vec::new();
51        let rows = self
52            .client
53            .query(
54                "SELECT column_name, is_nullable
55            FROM information_schema.columns 
56            WHERE table_name = $1 AND table_schema = current_schema
57            ORDER BY ordinal_position",
58                &[&table],
59            )
60            .context("Failed to query information about table")?;
61        for row in rows {
62            let is_nullable: &str = row.get(1);
63            columns.push(PostgreColumn {
64                name: row.get(0),
65                column_type: postgres::types::Type::UNKNOWN, // Temp default
66                nullable: is_nullable == "YES",
67            })
68        }
69        let column_names: Vec<&str> = columns.iter().map(|c| c.name.as_str()).collect();
70        let query = format!("SELECT {} FROM {}", column_names.join(", "), table);
71        let stmt = self
72            .client
73            .prepare(&query)
74            .context("Failed to prepare select statement")?;
75        assert!(
76            columns.len() == stmt.columns().len(),
77            "Broken invariant. Expected to get {} column infos, got {}",
78            columns.len(),
79            stmt.columns().len()
80        );
81        for (column, column_info) in std::iter::zip(columns.iter_mut(), stmt.columns()) {
82            assert!(
83                column.name == column_info.name(),
84                "Broken invariant. Expected to get {} column, got {}",
85                column.name,
86                column_info.name()
87            );
88            column.column_type = column_info.type_().clone();
89        }
90        return Ok(columns);
91    }
92
93    fn get_columns_cached(&mut self, table: &str) -> anyhow::Result<Vec<PostgreColumn>> {
94        return match self.table_columns_cache.get(table) {
95            Some(columns) => Ok(columns.clone()),
96            None => {
97                let columns = self.get_columns(table)?;
98                self.table_columns_cache
99                    .insert(table.to_string(), columns.clone());
100                Ok(columns)
101            }
102        };
103    }
104}
105
106impl DBInfoProvider for PostgresDB {
107    fn get_table_info(&mut self, table: &str, no_count: bool) -> anyhow::Result<TableInfo> {
108        let mut num_rows = None;
109        if !no_count {
110            num_rows = Some(
111                self.get_num_rows(table)
112                    .context("Failed to get number of rows in the table")?,
113            );
114        }
115        let postgres_columns = self
116            .get_columns_cached(table)
117            .context("Failed to get info about table columns")?;
118        let columns = postgres_columns
119            .into_iter()
120            .map(Column::try_from)
121            .collect::<anyhow::Result<Vec<Column>>>()?;
122        return Ok(TableInfo {
123            name: table.to_string(),
124            num_rows,
125            columns,
126        });
127    }
128
129    fn get_tables(&mut self) -> anyhow::Result<Vec<String>> {
130        let rows = self
131            .client
132            .query(
133                "SELECT table_name FROM information_schema.tables WHERE table_schema=current_schema",
134                &[]
135            )
136            .context("Failed to query tables")?;
137        return rows
138            .iter()
139            .map(|row| {
140                row.try_get::<_, String>(0)
141                    .context("Failed to read table name")
142            })
143            .collect::<anyhow::Result<Vec<String>>>();
144    }
145}
146
147struct PostgresRowsIter<'a> {
148    target_format: TableInfo,
149    rows: postgres::RowIter<'a>,
150}
151
152impl Iterator for PostgresRowsIter<'_> {
153    type Item = anyhow::Result<Row>;
154
155    fn next(&mut self) -> Option<Self::Item> {
156        return match self
157            .rows
158            .next()
159            .context("Error while reading data from postgres")
160        {
161            Ok(Some(row)) => {
162                let mut result: Row = Vec::with_capacity(self.target_format.columns.len());
163                for (idx, column) in self.target_format.columns.iter().enumerate() {
164                    match Value::try_from((column.column_type, &row, idx)) {
165                        Ok(val) => result.push(val),
166                        Err(e) => return Some(Err(e)),
167                    }
168                }
169                Some(Ok(result))
170            }
171            Ok(None) => None,
172            Err(e) => Some(Err(e)),
173        };
174    }
175}
176
177impl DBReader for PostgresDB {
178    fn read_iter(&mut self, target_format: TableInfo) -> anyhow::Result<ReaderIterator<'_>> {
179        let query = format!(
180            "SELECT {} FROM {}",
181            target_format.column_names().join(", "),
182            target_format.name
183        );
184        let stmt = self
185            .client
186            .prepare(&query)
187            .context("Failed to prepare select statement")?;
188        let rows = self
189            .client
190            .query_raw(&stmt, &[] as &[&str; 0])
191            .context("Failed to get data from postgres source")?;
192        return Ok(Box::new(PostgresRowsIter {
193            target_format,
194            rows,
195        }));
196    }
197}
198
/// Signature that opens every binary `COPY` stream: the 11-byte sequence
/// `PGCOPY\n\377\r\n\0`.
///
/// In the full header it is followed by a 4-byte flags field and a 4-byte
/// header-extension length, for 19 header bytes in total.
const BINARY_SIGNATURE: &[u8] = b"PGCOPY\n\xFF\r\n\0";
201
202impl DBWriter for PostgresDB {
203    fn opt_clone(&self) -> anyhow::Result<Box<dyn DBWriter>> {
204        return PostgresDB::new(&self.uri).map(|writer| Box::new(writer) as _);
205    }
206
207    fn write_batch(&mut self, batch: &[Row], table: &TableInfo) -> Result<(), WriterError> {
208        let columns = self.get_columns_cached(&table.name)?;
209        let query = format!("COPY {} FROM STDIN WITH BINARY", table.name);
210        let mut writer = self
211            .client
212            .copy_in(&query)
213            .context("Failed to start writing data into postgres")?;
214
215        writer.write_all(BINARY_SIGNATURE)?;
216
217        // Flags (4 bytes).
218        writer.write_all(&0_i32.to_be_bytes())?;
219
220        // Header extension length (4 bytes)
221        writer.write_all(&0_i32.to_be_bytes())?;
222
223        for row in batch {
224            // Count of fields
225            writer.write_all(&(row.len() as i16).to_be_bytes())?;
226            assert_eq!(
227                columns.len(),
228                row.len(),
229                "Number of columns should be equal number of value in a row"
230            );
231            for (value, column) in std::iter::zip(row, &columns) {
232                value.write_postgres_bytes(&mut writer, column)?;
233            }
234        }
235        writer.write_all(&(-1_i16).to_be_bytes())?;
236        writer
237            .finish()
238            .context("Failed to finish writing to postgres")?;
239        return Ok(());
240    }
241
242    fn recover(&mut self) -> anyhow::Result<()> {
243        debug!("Trying to reconnect to the postgres");
244        self.client = Self::connect(&self.uri)?;
245        debug!("Successfully reconnected to the postgres");
246        return Ok(());
247    }
248}