1use crate::models::PostgresDatabase;
2use crate::*;
3use bytes::Bytes;
4use futures::Stream;
5use std::sync::Arc;
6
7mod data_format;
8mod elefant_file;
9mod postgres;
10mod sql_file;
11mod table_data;
12
13use crate::models::PostgresSchema;
15use crate::models::PostgresTable;
16use crate::quoting::IdentifierQuoter;
17pub use data_format::*;
18pub use postgres::PostgresInstanceStorage;
19pub use sql_file::{apply_sql_file, apply_sql_string, SqlDataMode, SqlFile, SqlFileOptions};
20pub use table_data::*;
21
/// Behaviour shared by both copy sources and copy destinations.
pub trait BaseCopyTarget {
    /// Returns the data formats this target can produce or consume, used to
    /// negotiate a common format between a source and a destination.
    fn supported_data_format(
        &self,
    ) -> impl std::future::Future<Output = Result<Vec<DataFormat>>> + Send;
}
29
/// Factory for creating [`CopySource`] instances.
pub trait CopySourceFactory: BaseCopyTarget {
    /// Source type used when copying sequentially.
    type SequentialSource: CopySource;

    /// Source type used when copying in parallel; `Clone + Sync` so each
    /// worker can hold its own handle.
    type ParallelSource: CopySource + Clone + Sync;

    /// Creates either a sequential or a parallel source, depending on what
    /// this factory supports.
    fn create_source(
        &self,
    ) -> impl std::future::Future<
        Output = Result<SequentialOrParallel<Self::SequentialSource, Self::ParallelSource>>,
    > + Send;

    /// Creates a source that always operates sequentially.
    fn create_sequential_source(
        &self,
    ) -> impl std::future::Future<Output = Result<Self::SequentialSource>> + Send;

    /// Reports whether this factory supports parallel copying.
    fn supported_parallelism(&self) -> SupportedParallelism;
}
56
/// A source that schema information and table data can be copied from.
pub trait CopySource: Send {
    /// Stream of raw bytes carrying table data in the negotiated [`DataFormat`].
    type DataStream: Stream<Item = Result<Bytes>> + Send;

    /// Cleanup action associated with a data stream.
    type Cleanup: AsyncCleanup;

    /// Introspects the source and returns its schema model.
    fn get_introspection(
        &self,
    ) -> impl std::future::Future<Output = Result<PostgresDatabase>> + Send;

    /// Returns the row data of `table` in `schema`, encoded as `data_format`.
    fn get_data(
        &self,
        schema: &PostgresSchema,
        table: &PostgresTable,
        data_format: &DataFormat,
    ) -> impl std::future::Future<Output = Result<TableData<Self::DataStream, Self::Cleanup>>> + Send;
}
79
/// Factory for creating [`CopyDestination`] instances.
///
/// The `'a` lifetime ties a created destination to a mutable borrow of the
/// factory itself.
pub trait CopyDestinationFactory<'a>: BaseCopyTarget {
    /// Destination type used when copying sequentially.
    type SequentialDestination: CopyDestination;

    /// Destination type used when copying in parallel; `Clone + Sync` so each
    /// worker can hold its own handle.
    type ParallelDestination: CopyDestination + Clone + Sync;

    /// Creates either a sequential or a parallel destination, depending on
    /// what this factory supports.
    fn create_destination(
        &'a mut self,
    ) -> impl std::future::Future<
        Output = Result<
            SequentialOrParallel<Self::SequentialDestination, Self::ParallelDestination>,
        >,
    > + Send;

    /// Creates a destination that always operates sequentially.
    fn create_sequential_destination(
        &'a mut self,
    ) -> impl std::future::Future<Output = Result<Self::SequentialDestination>> + Send;

    /// Reports whether this factory supports parallel copying.
    fn supported_parallelism(&self) -> SupportedParallelism;
}
108
/// A destination that schema information and table data can be copied to.
pub trait CopyDestination: Send {
    /// Writes the streamed data for `table` in `schema` to the destination.
    fn apply_data<S: Stream<Item = Result<Bytes>> + Send, C: AsyncCleanup>(
        &mut self,
        schema: &PostgresSchema,
        table: &PostgresTable,
        data: TableData<S, C>,
    ) -> impl std::future::Future<Output = Result<()>> + Send;

    /// Applies a statement that may run inside a transaction.
    fn apply_transactional_statement(
        &mut self,
        statement: &str,
    ) -> impl std::future::Future<Output = Result<()>> + Send;

    /// Applies a statement outside any transaction.
    // NOTE(review): presumably for statements PostgreSQL refuses to run in a
    // transaction block — confirm against implementors.
    fn apply_non_transactional_statement(
        &mut self,
        statement: &str,
    ) -> impl std::future::Future<Output = Result<()>> + Send;

    /// Begins a transaction on the destination.
    fn begin_transaction(&mut self) -> impl std::future::Future<Output = Result<()>> + Send;

    /// Commits the current transaction.
    fn commit_transaction(&mut self) -> impl std::future::Future<Output = Result<()>> + Send;

    /// Returns the quoter used to quote identifiers for this destination.
    fn get_identifier_quoter(&self) -> Arc<IdentifierQuoter>;

    /// Hook invoked when copying is complete; the default is a no-op.
    fn finish(&mut self) -> impl std::future::Future<Output = Result<()>> + Send {
        async { Ok(()) }
    }

    /// Attempts to introspect the destination's existing schema. Defaults to
    /// `Ok(None)` for destinations that cannot be introspected.
    fn try_introspect(
        &self,
    ) -> impl std::future::Future<Output = Result<Option<PostgresDatabase>>> + Send {
        async { Ok(None) }
    }

    /// Reports whether the given table already contains rows. Defaults to
    /// `Ok(false)` for destinations that cannot check.
    fn has_data_in_table(
        &self,
        _schema: &PostgresSchema,
        _table: &PostgresTable,
    ) -> impl std::future::Future<Output = Result<bool>> + Send {
        async { Ok(false) }
    }
}
162
/// Holds either a single sequential worker (`S`) or a cloneable parallel
/// worker (`P`), as negotiated between source and destination.
pub enum SequentialOrParallel<S: Send, P: Send + Clone + Sync> {
    /// One worker processes everything in order.
    Sequential(S),
    /// The worker may be cloned and fanned out across tasks.
    Parallel(P),
}
168
/// How much parallelism a copy source or destination supports.
///
/// `Copy` is derived (in addition to the existing traits) because this is a
/// fieldless flag enum: callers can pass it by value without explicit clones.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum SupportedParallelism {
    /// Only one operation at a time is supported.
    Sequential,
    /// Multiple operations may run concurrently.
    Parallel,
}
177
178impl SupportedParallelism {
179 pub fn negotiate_parallelism(&self, other: SupportedParallelism) -> SupportedParallelism {
181 match (self, other) {
182 (SupportedParallelism::Parallel, SupportedParallelism::Parallel) => {
183 SupportedParallelism::Parallel
184 }
185 _ => SupportedParallelism::Sequential,
186 }
187 }
188}
189
impl<S: CopySource, P: CopySource + Clone + Sync> SequentialOrParallel<S, P> {
    /// Delegates introspection to whichever source variant is active.
    pub(crate) async fn get_introspection(&self) -> Result<PostgresDatabase> {
        match self {
            SequentialOrParallel::Sequential(s) => s.get_introspection().await,
            SequentialOrParallel::Parallel(p) => p.get_introspection().await,
        }
    }
}
198
impl<S: CopyDestination, P: CopyDestination + Clone + Sync> SequentialOrParallel<S, P> {
    /// Delegates `begin_transaction` to whichever variant is active.
    pub(crate) async fn begin_transaction(&mut self) -> Result<()> {
        match self {
            SequentialOrParallel::Sequential(s) => s.begin_transaction().await,
            SequentialOrParallel::Parallel(p) => p.begin_transaction().await,
        }
    }

    /// Delegates `commit_transaction` to whichever variant is active.
    pub(crate) async fn commit_transaction(&mut self) -> Result<()> {
        match self {
            SequentialOrParallel::Sequential(s) => s.commit_transaction().await,
            SequentialOrParallel::Parallel(p) => p.commit_transaction().await,
        }
    }

    /// Delegates `finish` to whichever variant is active.
    pub(crate) async fn finish(&mut self) -> Result<()> {
        match self {
            SequentialOrParallel::Sequential(s) => s.finish().await,
            SequentialOrParallel::Parallel(p) => p.finish().await,
        }
    }

    /// Delegates `try_introspect` to whichever variant is active.
    // NOTE(review): the name is misspelled ("introspeciton"); renaming would
    // touch every crate-internal caller, so it is flagged rather than fixed here.
    pub(crate) async fn try_get_introspeciton(&self) -> Result<Option<PostgresDatabase>> {
        match self {
            SequentialOrParallel::Sequential(s) => s.try_introspect().await,
            SequentialOrParallel::Parallel(p) => p.try_introspect().await,
        }
    }
}
228
/// Placeholder [`CopyDestination`] for factories that never hand out a
/// parallel destination; all of its trait methods are unreachable.
#[derive(Copy, Clone)]
pub struct ParallelCopyDestinationNotAvailable {
    // Private unit field: prevents construction outside this module.
    _private: (),
}
236
// Every method is `unreachable!`. NOTE(review): presumably this impl exists
// only so the type can satisfy a `ParallelDestination` associated type for
// sequential-only factories, so no method is ever invoked — confirm at the
// factory implementations.
impl CopyDestination for ParallelCopyDestinationNotAvailable {
    async fn apply_data<S: Stream<Item = Result<Bytes>> + Send, C: AsyncCleanup>(
        &mut self,
        _schema: &PostgresSchema,
        _table: &PostgresTable,
        _data: TableData<S, C>,
    ) -> Result<()> {
        unreachable!("Parallel copy destination not available")
    }

    async fn apply_transactional_statement(&mut self, _statement: &str) -> Result<()> {
        unreachable!("Parallel copy destination not available")
    }

    async fn apply_non_transactional_statement(&mut self, _statement: &str) -> Result<()> {
        unreachable!("Parallel copy destination not available")
    }

    async fn begin_transaction(&mut self) -> Result<()> {
        unreachable!("Parallel copy destination not available")
    }

    async fn commit_transaction(&mut self) -> Result<()> {
        unreachable!("Parallel copy destination not available")
    }

    fn get_identifier_quoter(&self) -> Arc<IdentifierQuoter> {
        unreachable!("Parallel copy destination not available")
    }
}
267
268#[cfg(test)]
269mod tests {
270 use crate::test_helpers::{assert_pg_error, TestHelper};
271 use tokio_postgres::error::SqlState;
272
273 pub fn get_copy_source_database_create_script(version: i32) -> &'static str {
274 if version >= 150 {
275 r#"
276 create extension btree_gin;
277
278 create table people(
279 id serial primary key,
280 name text not null unique,
281 age int not null check (age > 0),
282 constraint multi_check check (name != 'fsgsdfgsdf' and age < 9999)
283 );
284
285 create index people_age_idx on people (age desc) include (name, id) where (age % 2 = 0);
286 create index people_age_brin_idx on people using brin (age);
287 create index people_name_lower_idx on people (lower(name));
288
289 insert into people(name, age)
290 values
291 ('foo', 42),
292 ('bar', 89),
293 ('nice', 69),
294 (E'str\nange', 420),
295 (E't\t\tap', 421),
296 (E'q''t', 12)
297 ;
298
299 create table field(
300 id serial primary key
301 );
302
303 create table tree_node(
304 id serial primary key,
305 field_id int not null references field(id),
306 name text not null,
307 parent_id int,
308 constraint field_id_id_unique unique (field_id, id),
309 foreign key (field_id, parent_id) references tree_node(field_id, id),
310 constraint unique_name_per_level unique nulls not distinct (field_id, parent_id, name)
311 );
312
313 create view people_who_cant_drink as select * from people where age < 18;
314
315 create table ext_test_table(
316 id serial primary key,
317 name text not null,
318 search_vector tsvector generated always as (to_tsvector('english', name)) stored
319 );
320
321 create index ext_test_table_name_idx on ext_test_table using gin (id, search_vector);
322
323 create table array_test(
324 name text[] not null
325 );
326
327 insert into array_test(name)
328 values
329 ('{"foo", "bar"}'),
330 ('{"baz", "qux"}'),
331 ('{"quux", "corge"}');
332
333 create table my_partitioned_table(
334 value int not null
335 ) partition by range (value);
336
337 create table my_partitioned_table_1 partition of my_partitioned_table for values from (1) to (10);
338 create table my_partitioned_table_2 partition of my_partitioned_table for values from (10) to (20);
339
340 insert into my_partitioned_table(value)
341 values (1), (9), (11), (19);
342
343 create table pets (
344 id serial primary key,
345 name text not null check(length(name) > 1)
346 );
347
348 create table dogs(
349 breed text not null check(length(breed) > 1)
350 ) inherits (pets);
351
352 create table cats(
353 color text not null
354 ) inherits (pets);
355
356 insert into dogs(name, breed) values('Fido', 'beagle');
357 insert into cats(name, color) values('Fluffy', 'white');
358 insert into pets(name) values('Remy');
359 "#
360 } else {
361 r#"
362 create extension btree_gin;
363
364 create table people(
365 id serial primary key,
366 name text not null unique,
367 age int not null check (age > 0),
368 constraint multi_check check (name != 'fsgsdfgsdf' and age < 9999)
369 );
370
371 create index people_age_idx on people (age desc) include (name, id) where (age % 2 = 0);
372 create index people_age_brin_idx on people using brin (age);
373 create index people_name_lower_idx on people (lower(name));
374
375 insert into people(name, age)
376 values
377 ('foo', 42),
378 ('bar', 89),
379 ('nice', 69),
380 (E'str\nange', 420),
381 (E't\t\tap', 421),
382 (E'q''t', 12)
383 ;
384
385 create table field(
386 id serial primary key
387 );
388
389 create table tree_node(
390 id serial primary key,
391 field_id int not null references field(id),
392 name text not null,
393 parent_id int,
394 constraint field_id_id_unique unique (field_id, id),
395 foreign key (field_id, parent_id) references tree_node(field_id, id),
396 constraint unique_name_per_level unique (field_id, parent_id, name)
397 );
398
399 create view people_who_cant_drink as select * from people where age < 18;
400
401 create table ext_test_table(
402 id serial primary key,
403 name text not null,
404 search_vector tsvector generated always as (to_tsvector('english', name)) stored
405 );
406
407 create index ext_test_table_name_idx on ext_test_table using gin (id, search_vector);
408
409 create table array_test(
410 name text[] not null
411 );
412
413 insert into array_test(name)
414 values
415 ('{"foo", "bar"}'),
416 ('{"baz", "qux"}'),
417 ('{"quux", "corge"}');
418
419 create table my_partitioned_table(
420 value int not null
421 ) partition by range (value);
422
423 create table my_partitioned_table_1 partition of my_partitioned_table for values from (1) to (10);
424 create table my_partitioned_table_2 partition of my_partitioned_table for values from (10) to (20);
425
426 insert into my_partitioned_table(value)
427 values (1), (9), (11), (19);
428
429 create table pets (
430 id serial primary key,
431 name text not null check(length(name) > 1)
432 );
433
434 create table dogs(
435 breed text not null check(length(breed) > 1)
436 ) inherits (pets);
437
438 create table cats(
439 color text not null
440 ) inherits (pets);
441
442 insert into dogs(name, breed) values('Fido', 'beagle');
443 insert into cats(name, color) values('Fluffy', 'white');
444 insert into pets(name) values('Remy');
445 "#
446 }
447 }
448
449 pub fn get_expected_people_data() -> Vec<(i32, String, i32)> {
450 vec![
451 (1, "foo".to_string(), 42),
452 (2, "bar".to_string(), 89),
453 (3, "nice".to_string(), 69),
454 (4, "str\nange".to_string(), 420),
455 (5, "t\t\tap".to_string(), 421),
456 (6, "q't".to_string(), 12),
457 ]
458 }
459
460 pub fn get_expected_array_test_data() -> Vec<(Vec<String>,)> {
461 vec![
462 (vec!["foo".to_string(), "bar".to_string()],),
463 (vec!["baz".to_string(), "qux".to_string()],),
464 (vec!["quux".to_string(), "corge".to_string()],),
465 ]
466 }
467
    /// Asserts that the inheritance hierarchy rooted at `pets` (with `dogs`
    /// and `cats` child tables) was copied with the expected rows.
    pub async fn validate_pets(connection: &TestHelper) {
        // Selecting from the inheritance root also returns the rows that
        // physically live in the child tables.
        let pets = connection
            .get_results::<(i32, String)>("select id, name from pets order by id")
            .await;
        assert_eq!(
            pets,
            vec![
                (1, "Fido".to_string()),
                (2, "Fluffy".to_string()),
                (3, "Remy".to_string()),
            ]
        );

        let dogs = connection
            .get_results::<(i32, String, String)>("select id, name, breed from dogs order by id")
            .await;
        assert_eq!(dogs, vec![(1, "Fido".to_string(), "beagle".to_string()),]);

        let cats = connection
            .get_results::<(i32, String, String)>("select id, name, color from cats order by id")
            .await;
        assert_eq!(cats, vec![(2, "Fluffy".to_string(), "white".to_string()),]);
    }
491
    /// Runs the full post-copy validation suite against the destination:
    /// row data, check/unique/foreign-key constraints, views, arrays,
    /// partitions and table inheritance.
    pub async fn validate_copy_state(destination: &TestHelper) {
        // NOTE(review): no `order by` here — this relies on row order matching
        // insert order; confirm this is stable enough for the assertion below.
        let items = destination
            .get_results::<(i32, String, i32)>("select id, name, age from people;")
            .await;

        assert_eq!(items, get_expected_people_data());

        // Check constraints were copied: age 10000 violates the copied check.
        let result = destination
            .get_conn()
            .execute_non_query("insert into people (name, age) values ('new-value', 10000)")
            .await;
        assert_pg_error(result, SqlState::CHECK_VIOLATION);

        // Unique constraint on name was copied: duplicate 'foo' is rejected.
        let result = destination
            .get_conn()
            .execute_non_query("insert into people (name, age) values ('foo', 100)")
            .await;
        assert_pg_error(result, SqlState::UNIQUE_VIOLATION);

        // NOTE(review): `execute_not_query` looks like a typo of
        // `execute_non_query` on TestHelper — confirm against its definition.
        destination
            .execute_not_query("insert into field (id) values (1);")
            .await;

        destination.execute_not_query("insert into tree_node(id, field_id, name, parent_id) values (1, 1, 'foo', null), (2, 1, 'bar', 1)").await;
        // `unique nulls not distinct` only exists on PostgreSQL 15+, so the
        // null-parent duplicate check is version-gated.
        if destination.get_conn().version() >= 150 {
            let result = destination.get_conn().execute_non_query("insert into tree_node(id, field_id, name, parent_id) values (3, 1, 'foo', null)").await;
            assert_pg_error(result, SqlState::UNIQUE_VIOLATION);
        }

        // Foreign keys were copied: a dangling field_id is rejected.
        let result = destination.get_conn().execute_non_query("insert into tree_node(id, field_id, name, parent_id) values (9999, 9999, 'foobarbaz', null)").await;
        assert_pg_error(result, SqlState::FOREIGN_KEY_VIOLATION);

        // Views were copied and produce the expected filtered rows.
        let people_who_cant_drink = destination
            .get_results::<(i32, String, i32)>("select id, name, age from people_who_cant_drink;")
            .await;
        assert_eq!(people_who_cant_drink, vec![(6, "q't".to_string(), 12)]);

        let array_test_data = destination
            .get_results::<(Vec<String>,)>("select name from array_test;")
            .await;

        assert_eq!(array_test_data, get_expected_array_test_data());

        // Partitioned table data ended up in the right partitions.
        let partition_test_data = destination
            .get_results::<(i32,)>("select value from my_partitioned_table order by value;")
            .await;

        assert_eq!(partition_test_data, vec![(1,), (9,), (11,), (19,)]);

        validate_pets(destination).await;
    }
543}