// db_mover/databases/postgres/mod.rs
use std::collections::HashMap;
2use std::io::Write;
3
4use anyhow::Context;
5use postgres::fallible_iterator::FallibleIterator;
6use postgres::{Client, NoTls};
7use tracing::debug;
8use value::PostgreColumn;
9
10use crate::databases::table::{Row, Value};
11use crate::databases::traits::{DBInfoProvider, DBReader, DBWriter};
12
13use super::table::{Column, TableInfo};
14use super::traits::{ReaderIterator, WriterError};
15
16mod value;
17
18pub struct PostgresDB {
19 uri: String,
20 client: Client,
21 table_columns_cache: HashMap<String, Vec<PostgreColumn>>,
22}
23
24impl PostgresDB {
25 pub fn new(uri: &str) -> anyhow::Result<Self> {
26 let client = Self::connect(uri)?;
27 debug!("Connected to postgres {uri}");
28 return Ok(Self {
29 client,
30 uri: uri.to_string(),
31 table_columns_cache: HashMap::default(),
32 });
33 }
34
35 fn connect(uri: &str) -> Result<Client, postgres::Error> {
36 return Client::connect(uri, NoTls);
37 }
38
39 fn get_num_rows(&mut self, table: &str) -> anyhow::Result<u64> {
40 let count_query = format!("SELECT count(1) FROM {table}");
41 return self
42 .client
43 .query_one(&count_query, &[])?
44 .get::<_, i64>(0)
45 .try_into()
46 .context("Failed to convert i64 to u64");
47 }
48
49 fn get_columns(&mut self, table: &str) -> anyhow::Result<Vec<PostgreColumn>> {
50 let mut columns = Vec::new();
51 let rows = self
52 .client
53 .query(
54 "SELECT column_name, is_nullable
55 FROM information_schema.columns
56 WHERE table_name = $1 AND table_schema = current_schema
57 ORDER BY ordinal_position",
58 &[&table],
59 )
60 .context("Failed to query information about table")?;
61 for row in rows {
62 let is_nullable: &str = row.get(1);
63 columns.push(PostgreColumn {
64 name: row.get(0),
65 column_type: postgres::types::Type::UNKNOWN, nullable: is_nullable == "YES",
67 })
68 }
69 let column_names: Vec<&str> = columns.iter().map(|c| c.name.as_str()).collect();
70 let query = format!("SELECT {} FROM {}", column_names.join(", "), table);
71 let stmt = self
72 .client
73 .prepare(&query)
74 .context("Failed to prepare select statement")?;
75 assert!(
76 columns.len() == stmt.columns().len(),
77 "Broken invariant. Expected to get {} column infos, got {}",
78 columns.len(),
79 stmt.columns().len()
80 );
81 for (column, column_info) in std::iter::zip(columns.iter_mut(), stmt.columns()) {
82 assert!(
83 column.name == column_info.name(),
84 "Broken invariant. Expected to get {} column, got {}",
85 column.name,
86 column_info.name()
87 );
88 column.column_type = column_info.type_().clone();
89 }
90 return Ok(columns);
91 }
92
93 fn get_columns_cached(&mut self, table: &str) -> anyhow::Result<Vec<PostgreColumn>> {
94 return match self.table_columns_cache.get(table) {
95 Some(columns) => Ok(columns.clone()),
96 None => {
97 let columns = self.get_columns(table)?;
98 self.table_columns_cache
99 .insert(table.to_string(), columns.clone());
100 Ok(columns)
101 }
102 };
103 }
104}
105
106impl DBInfoProvider for PostgresDB {
107 fn get_table_info(&mut self, table: &str, no_count: bool) -> anyhow::Result<TableInfo> {
108 let mut num_rows = None;
109 if !no_count {
110 num_rows = Some(
111 self.get_num_rows(table)
112 .context("Failed to get number of rows in the table")?,
113 );
114 }
115 let postgres_columns = self
116 .get_columns_cached(table)
117 .context("Failed to get info about table columns")?;
118 let columns = postgres_columns
119 .into_iter()
120 .map(Column::try_from)
121 .collect::<anyhow::Result<Vec<Column>>>()?;
122 return Ok(TableInfo {
123 name: table.to_string(),
124 num_rows,
125 columns,
126 });
127 }
128
129 fn get_tables(&mut self) -> anyhow::Result<Vec<String>> {
130 let rows = self
131 .client
132 .query(
133 "SELECT table_name FROM information_schema.tables WHERE table_schema=current_schema",
134 &[]
135 )
136 .context("Failed to query tables")?;
137 return rows
138 .iter()
139 .map(|row| {
140 row.try_get::<_, String>(0)
141 .context("Failed to read table name")
142 })
143 .collect::<anyhow::Result<Vec<String>>>();
144 }
145}
146
147struct PostgresRowsIter<'a> {
148 target_format: TableInfo,
149 rows: postgres::RowIter<'a>,
150}
151
152impl Iterator for PostgresRowsIter<'_> {
153 type Item = anyhow::Result<Row>;
154
155 fn next(&mut self) -> Option<Self::Item> {
156 return match self
157 .rows
158 .next()
159 .context("Error while reading data from postgres")
160 {
161 Ok(Some(row)) => {
162 let mut result: Row = Vec::with_capacity(self.target_format.columns.len());
163 for (idx, column) in self.target_format.columns.iter().enumerate() {
164 match Value::try_from((column.column_type, &row, idx)) {
165 Ok(val) => result.push(val),
166 Err(e) => return Some(Err(e)),
167 }
168 }
169 Some(Ok(result))
170 }
171 Ok(None) => None,
172 Err(e) => Some(Err(e)),
173 };
174 }
175}
176
177impl DBReader for PostgresDB {
178 fn read_iter(&mut self, target_format: TableInfo) -> anyhow::Result<ReaderIterator<'_>> {
179 let query = format!(
180 "SELECT {} FROM {}",
181 target_format.column_names().join(", "),
182 target_format.name
183 );
184 let stmt = self
185 .client
186 .prepare(&query)
187 .context("Failed to prepare select statement")?;
188 let rows = self
189 .client
190 .query_raw(&stmt, &[] as &[&str; 0])
191 .context("Failed to get data from postgres source")?;
192 return Ok(Box::new(PostgresRowsIter {
193 target_format,
194 rows,
195 }));
196 }
197}
198
/// Fixed 11-byte signature that opens a PostgreSQL binary `COPY` stream:
/// `PGCOPY\n\377\r\n\0`.
const BINARY_SIGNATURE: &[u8] = &[
    b'P', b'G', b'C', b'O', b'P', b'Y', b'\n', 0xFF, b'\r', b'\n', 0x00,
];
201
202impl DBWriter for PostgresDB {
203 fn opt_clone(&self) -> anyhow::Result<Box<dyn DBWriter>> {
204 return PostgresDB::new(&self.uri).map(|writer| Box::new(writer) as _);
205 }
206
207 fn write_batch(&mut self, batch: &[Row], table: &TableInfo) -> Result<(), WriterError> {
208 let columns = self.get_columns_cached(&table.name)?;
209 let query = format!("COPY {} FROM STDIN WITH BINARY", table.name);
210 let mut writer = self
211 .client
212 .copy_in(&query)
213 .context("Failed to start writing data into postgres")?;
214
215 writer.write_all(BINARY_SIGNATURE)?;
216
217 writer.write_all(&0_i32.to_be_bytes())?;
219
220 writer.write_all(&0_i32.to_be_bytes())?;
222
223 for row in batch {
224 writer.write_all(&(row.len() as i16).to_be_bytes())?;
226 assert_eq!(
227 columns.len(),
228 row.len(),
229 "Number of columns should be equal number of value in a row"
230 );
231 for (value, column) in std::iter::zip(row, &columns) {
232 value.write_postgres_bytes(&mut writer, column)?;
233 }
234 }
235 writer.write_all(&(-1_i16).to_be_bytes())?;
236 writer
237 .finish()
238 .context("Failed to finish writing to postgres")?;
239 return Ok(());
240 }
241
242 fn recover(&mut self) -> anyhow::Result<()> {
243 debug!("Trying to reconnect to the postgres");
244 self.client = Self::connect(&self.uri)?;
245 debug!("Successfully reconnected to the postgres");
246 return Ok(());
247 }
248}