elefant_tools/storage/
mod.rs

1use crate::models::PostgresDatabase;
2use crate::*;
3use bytes::Bytes;
4use futures::Stream;
5use std::sync::Arc;
6
7mod data_format;
8mod elefant_file;
9mod postgres;
10mod sql_file;
11mod table_data;
12
13// pub use elefant_file::ElefantFileDestinationStorage;
14use crate::models::PostgresSchema;
15use crate::models::PostgresTable;
16use crate::quoting::IdentifierQuoter;
17pub use data_format::*;
18pub use postgres::PostgresInstanceStorage;
19pub use sql_file::{apply_sql_file, apply_sql_string, SqlDataMode, SqlFile, SqlFileOptions};
20pub use table_data::*;
21
/// A trait for things that are either a [`CopyDestination`] or a [`CopySource`].
pub trait BaseCopyTarget {
    /// Which data formats are supported by this destination/source.
    /// Used to negotiate a common format between a source and a destination.
    fn supported_data_format(
        &self,
    ) -> impl std::future::Future<Output = Result<Vec<DataFormat>>> + Send;
}
29
/// A factory for providing copy sources. This is used to create a source that can be used to read data from.
pub trait CopySourceFactory: BaseCopyTarget {
    /// A type that can be used to read data from the source. This type has to support
    /// single threaded reading, but can support multiple threads reading at the same time.
    type SequentialSource: CopySource;

    /// A type that can be used to read data from the source. This type has to support
    /// multiple threads reading at the same time.
    type ParallelSource: CopySource + Clone + Sync;

    /// Should create whatever type is needed to be able to read data from the source.
    /// The implementation decides whether the sequential or parallel variant is
    /// returned; callers can match on the returned [`SequentialOrParallel`].
    fn create_source(
        &self,
    ) -> impl std::future::Future<
        Output = Result<SequentialOrParallel<Self::SequentialSource, Self::ParallelSource>>,
    > + Send;

    /// Should create a datasource that works with single threaded reading.
    fn create_sequential_source(
        &self,
    ) -> impl std::future::Future<Output = Result<Self::SequentialSource>> + Send;

    /// Should return what kind of parallelism is supported by the source. This is used
    /// for negotiation with the destination.
    fn supported_parallelism(&self) -> SupportedParallelism;
}
56
/// A copy source is something that can be used to read data from a source.
pub trait CopySource: Send {
    /// The type of the specific data stream provided when reading data.
    type DataStream: Stream<Item = Result<Bytes>> + Send;

    /// The type of the cleanup that is returned when reading data. Can be `()` if no cleanup is needed.
    type Cleanup: AsyncCleanup;

    /// Should provide introspection data of the source. This means poking the `pg_catalog` tables when
    /// working with Postgres, for example.
    fn get_introspection(
        &self,
    ) -> impl std::future::Future<Output = Result<PostgresDatabase>> + Send;

    /// Should return a data-stream for the specified table in the specified format.
    fn get_data(
        &self,
        schema: &PostgresSchema,
        table: &PostgresTable,
        data_format: &DataFormat,
    ) -> impl std::future::Future<Output = Result<TableData<Self::DataStream, Self::Cleanup>>> + Send;
}
79
/// A factory for providing copy destinations. This is used to create a destination that can be used to write data to.
pub trait CopyDestinationFactory<'a>: BaseCopyTarget {
    /// The implementation type when dealing with single-threaded workloads. It can optionally
    /// support multi-threading, but that is not required.
    type SequentialDestination: CopyDestination;

    /// The implementation type when dealing with multithreaded workloads. This type has to support
    /// multi-threading.
    type ParallelDestination: CopyDestination + Clone + Sync;

    /// Should create whatever type is needed to be able to write data to the destination.
    /// The implementation decides whether the sequential or parallel variant is
    /// returned; callers can match on the returned [`SequentialOrParallel`].
    fn create_destination(
        &'a mut self,
    ) -> impl std::future::Future<
        Output = Result<
            SequentialOrParallel<Self::SequentialDestination, Self::ParallelDestination>,
        >,
    > + Send;

    /// Should create a destination that works with single threaded writing.
    fn create_sequential_destination(
        &'a mut self,
    ) -> impl std::future::Future<Output = Result<Self::SequentialDestination>> + Send;

    /// Should return what kind of parallelism is supported by the destination. This is used
    /// for negotiation with the source.
    fn supported_parallelism(&self) -> SupportedParallelism;
}
108
/// A copy destination is something that data can be written to, for example a live
/// Postgres instance or an SQL file (see the `postgres` and `sql_file` submodules).
pub trait CopyDestination: Send {
    /// This should apply the data to the destination. The data is expected to be in the
    /// format returned by `supported_data_format`, if possible.
    fn apply_data<S: Stream<Item = Result<Bytes>> + Send, C: AsyncCleanup>(
        &mut self,
        schema: &PostgresSchema,
        table: &PostgresTable,
        data: TableData<S, C>,
    ) -> impl std::future::Future<Output = Result<()>> + Send;

    /// This should apply the DDL statements to the destination.
    fn apply_transactional_statement(
        &mut self,
        statement: &str,
    ) -> impl std::future::Future<Output = Result<()>> + Send;

    /// This should apply the DDL statements to the destination.
    /// These commands have to be run outside a transaction, as they might fail otherwise.
    fn apply_non_transactional_statement(
        &mut self,
        statement: &str,
    ) -> impl std::future::Future<Output = Result<()>> + Send;

    /// Should begin a new transaction.
    fn begin_transaction(&mut self) -> impl std::future::Future<Output = Result<()>> + Send;

    /// Should commit a running transaction.
    fn commit_transaction(&mut self) -> impl std::future::Future<Output = Result<()>> + Send;

    /// Should get the identifier quoter that works with this destination. This ensures
    /// quoting respects the rules of the destination, not the source.
    fn get_identifier_quoter(&self) -> Arc<IdentifierQuoter>;

    /// Optional finalization hook for the destination. The default implementation
    /// does nothing and succeeds.
    fn finish(&mut self) -> impl std::future::Future<Output = Result<()>> + Send {
        async { Ok(()) }
    }

    /// Should try to introspect the destination. If introspection is not supported, this should return `Ok(None)`,
    /// not an error. Errors should only be returned if introspection is supported, but failed.
    /// The default implementation reports introspection as unsupported.
    fn try_introspect(
        &self,
    ) -> impl std::future::Future<Output = Result<Option<PostgresDatabase>>> + Send {
        async { Ok(None) }
    }

    /// Should report whether the given table already contains data in the destination.
    /// The default implementation assumes it does not.
    fn has_data_in_table(
        &self,
        _schema: &PostgresSchema,
        _table: &PostgresTable,
    ) -> impl std::future::Future<Output = Result<bool>> + Send {
        async { Ok(false) }
    }
}
162
/// A type that can be either a sequential or parallel source or destination.
pub enum SequentialOrParallel<S: Send, P: Send + Clone + Sync> {
    /// A source/destination that only supports single-threaded use.
    Sequential(S),
    /// A source/destination that supports multithreaded use.
    Parallel(P),
}
168
/// Indicates if parallelism is supported.
// A fieldless enum is trivially copyable; deriving `Copy` (backward-compatible,
// since `Clone` was already derived) lets callers pass it by value without
// explicit clones.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum SupportedParallelism {
    /// Only sequential single-threaded operations are available.
    Sequential,
    /// Parallel multithreaded operations are available.
    Parallel,
}
177
178impl SupportedParallelism {
179    /// Negotiate the parallelism between two sources or destinations.
180    pub fn negotiate_parallelism(&self, other: SupportedParallelism) -> SupportedParallelism {
181        match (self, other) {
182            (SupportedParallelism::Parallel, SupportedParallelism::Parallel) => {
183                SupportedParallelism::Parallel
184            }
185            _ => SupportedParallelism::Sequential,
186        }
187    }
188}
189
190impl<S: CopySource, P: CopySource + Clone + Sync> SequentialOrParallel<S, P> {
191    pub(crate) async fn get_introspection(&self) -> Result<PostgresDatabase> {
192        match self {
193            SequentialOrParallel::Sequential(s) => s.get_introspection().await,
194            SequentialOrParallel::Parallel(p) => p.get_introspection().await,
195        }
196    }
197}
198
199impl<S: CopyDestination, P: CopyDestination + Clone + Sync> SequentialOrParallel<S, P> {
200    pub(crate) async fn begin_transaction(&mut self) -> Result<()> {
201        match self {
202            SequentialOrParallel::Sequential(s) => s.begin_transaction().await,
203            SequentialOrParallel::Parallel(p) => p.begin_transaction().await,
204        }
205    }
206
207    pub(crate) async fn commit_transaction(&mut self) -> Result<()> {
208        match self {
209            SequentialOrParallel::Sequential(s) => s.commit_transaction().await,
210            SequentialOrParallel::Parallel(p) => p.commit_transaction().await,
211        }
212    }
213
214    pub(crate) async fn finish(&mut self) -> Result<()> {
215        match self {
216            SequentialOrParallel::Sequential(s) => s.finish().await,
217            SequentialOrParallel::Parallel(p) => p.finish().await,
218        }
219    }
220
221    pub(crate) async fn try_get_introspeciton(&self) -> Result<Option<PostgresDatabase>> {
222        match self {
223            SequentialOrParallel::Sequential(s) => s.try_introspect().await,
224            SequentialOrParallel::Parallel(p) => p.try_introspect().await,
225        }
226    }
227}
228
/// A CopyDestination that panics when used.
/// Cannot be constructed outside this module, but is available for type reference
/// to indicate Parallel copy is not supported.
#[derive(Copy, Clone)]
pub struct ParallelCopyDestinationNotAvailable {
    // Private zero-sized field prevents construction outside this module.
    _private: (),
}
236
// Every method panics via `unreachable!`: this impl exists only so that
// factories without real parallel support can name a `ParallelDestination`
// type. NOTE(review): presumably parallelism negotiation guarantees these
// methods are never invoked at runtime — confirm against the copy driver.
impl CopyDestination for ParallelCopyDestinationNotAvailable {
    async fn apply_data<S: Stream<Item = Result<Bytes>> + Send, C: AsyncCleanup>(
        &mut self,
        _schema: &PostgresSchema,
        _table: &PostgresTable,
        _data: TableData<S, C>,
    ) -> Result<()> {
        unreachable!("Parallel copy destination not available")
    }

    async fn apply_transactional_statement(&mut self, _statement: &str) -> Result<()> {
        unreachable!("Parallel copy destination not available")
    }

    async fn apply_non_transactional_statement(&mut self, _statement: &str) -> Result<()> {
        unreachable!("Parallel copy destination not available")
    }

    async fn begin_transaction(&mut self) -> Result<()> {
        unreachable!("Parallel copy destination not available")
    }

    async fn commit_transaction(&mut self) -> Result<()> {
        unreachable!("Parallel copy destination not available")
    }

    fn get_identifier_quoter(&self) -> Arc<IdentifierQuoter> {
        unreachable!("Parallel copy destination not available")
    }
}
267
268#[cfg(test)]
269mod tests {
270    use crate::test_helpers::{assert_pg_error, TestHelper};
271    use tokio_postgres::error::SqlState;
272
273    pub fn get_copy_source_database_create_script(version: i32) -> &'static str {
274        if version >= 150 {
275            r#"
276        create extension btree_gin;
277
278        create table people(
279            id serial primary key,
280            name text not null unique,
281            age int not null check (age > 0),
282            constraint multi_check check (name != 'fsgsdfgsdf' and age < 9999)
283        );
284
285        create index people_age_idx on people (age desc) include (name, id) where (age % 2 = 0);
286        create index people_age_brin_idx on people using brin (age);
287        create index people_name_lower_idx on people (lower(name));
288
289        insert into people(name, age)
290        values
291            ('foo', 42),
292            ('bar', 89),
293            ('nice', 69),
294            (E'str\nange', 420),
295            (E't\t\tap', 421),
296            (E'q''t', 12)
297            ;
298
299        create table field(
300            id serial primary key
301        );
302
303        create table tree_node(
304            id serial primary key,
305            field_id int not null references field(id),
306            name text not null,
307            parent_id int,
308            constraint field_id_id_unique unique (field_id, id),
309            foreign key (field_id, parent_id) references tree_node(field_id, id),
310            constraint unique_name_per_level unique nulls not distinct (field_id, parent_id, name)
311        );
312
313        create view people_who_cant_drink as select * from people where age < 18;
314
315        create table ext_test_table(
316            id serial primary key,
317            name text not null,
318            search_vector tsvector generated always as (to_tsvector('english', name)) stored
319        );
320
321        create index ext_test_table_name_idx on ext_test_table using gin (id, search_vector);
322
323        create table array_test(
324            name text[] not null
325        );
326
327        insert into array_test(name)
328        values
329            ('{"foo", "bar"}'),
330            ('{"baz", "qux"}'),
331            ('{"quux", "corge"}');
332
333        create table my_partitioned_table(
334            value int not null
335        ) partition by range (value);
336
337        create table my_partitioned_table_1 partition of my_partitioned_table for values from (1) to (10);
338        create table my_partitioned_table_2 partition of my_partitioned_table for values from (10) to (20);
339
340        insert into my_partitioned_table(value)
341        values (1), (9), (11), (19);
342
343        create table pets (
344            id serial primary key,
345            name text not null check(length(name) > 1)
346        );
347
348        create table dogs(
349            breed text not null check(length(breed) > 1)
350        ) inherits (pets);
351
352        create table cats(
353            color text not null
354        ) inherits (pets);
355
356        insert into dogs(name, breed) values('Fido', 'beagle');
357        insert into cats(name, color) values('Fluffy', 'white');
358        insert into pets(name) values('Remy');
359            "#
360        } else {
361            r#"
362        create extension btree_gin;
363
364        create table people(
365            id serial primary key,
366            name text not null unique,
367            age int not null check (age > 0),
368            constraint multi_check check (name != 'fsgsdfgsdf' and age < 9999)
369        );
370
371        create index people_age_idx on people (age desc) include (name, id) where (age % 2 = 0);
372        create index people_age_brin_idx on people using brin (age);
373        create index people_name_lower_idx on people (lower(name));
374
375        insert into people(name, age)
376        values
377            ('foo', 42),
378            ('bar', 89),
379            ('nice', 69),
380            (E'str\nange', 420),
381            (E't\t\tap', 421),
382            (E'q''t', 12)
383            ;
384
385        create table field(
386            id serial primary key
387        );
388
389        create table tree_node(
390            id serial primary key,
391            field_id int not null references field(id),
392            name text not null,
393            parent_id int,
394            constraint field_id_id_unique unique (field_id, id),
395            foreign key (field_id, parent_id) references tree_node(field_id, id),
396            constraint unique_name_per_level unique (field_id, parent_id, name)
397        );
398
399        create view people_who_cant_drink as select * from people where age < 18;
400
401        create table ext_test_table(
402            id serial primary key,
403            name text not null,
404            search_vector tsvector generated always as (to_tsvector('english', name)) stored
405        );
406
407        create index ext_test_table_name_idx on ext_test_table using gin (id, search_vector);
408
409        create table array_test(
410            name text[] not null
411        );
412
413        insert into array_test(name)
414        values
415            ('{"foo", "bar"}'),
416            ('{"baz", "qux"}'),
417            ('{"quux", "corge"}');
418
419        create table my_partitioned_table(
420            value int not null
421        ) partition by range (value);
422
423        create table my_partitioned_table_1 partition of my_partitioned_table for values from (1) to (10);
424        create table my_partitioned_table_2 partition of my_partitioned_table for values from (10) to (20);
425
426        insert into my_partitioned_table(value)
427        values (1), (9), (11), (19);
428
429        create table pets (
430            id serial primary key,
431            name text not null check(length(name) > 1)
432        );
433
434        create table dogs(
435            breed text not null check(length(breed) > 1)
436        ) inherits (pets);
437
438        create table cats(
439            color text not null
440        ) inherits (pets);
441
442        insert into dogs(name, breed) values('Fido', 'beagle');
443        insert into cats(name, color) values('Fluffy', 'white');
444        insert into pets(name) values('Remy');
445            "#
446        }
447    }
448
449    pub fn get_expected_people_data() -> Vec<(i32, String, i32)> {
450        vec![
451            (1, "foo".to_string(), 42),
452            (2, "bar".to_string(), 89),
453            (3, "nice".to_string(), 69),
454            (4, "str\nange".to_string(), 420),
455            (5, "t\t\tap".to_string(), 421),
456            (6, "q't".to_string(), 12),
457        ]
458    }
459
460    pub fn get_expected_array_test_data() -> Vec<(Vec<String>,)> {
461        vec![
462            (vec!["foo".to_string(), "bar".to_string()],),
463            (vec!["baz".to_string(), "qux".to_string()],),
464            (vec!["quux".to_string(), "corge".to_string()],),
465        ]
466    }
467
468    pub async fn validate_pets(connection: &TestHelper) {
469        let pets = connection
470            .get_results::<(i32, String)>("select id, name from pets order by id")
471            .await;
472        assert_eq!(
473            pets,
474            vec![
475                (1, "Fido".to_string()),
476                (2, "Fluffy".to_string()),
477                (3, "Remy".to_string()),
478            ]
479        );
480
481        let dogs = connection
482            .get_results::<(i32, String, String)>("select id, name, breed from dogs order by id")
483            .await;
484        assert_eq!(dogs, vec![(1, "Fido".to_string(), "beagle".to_string()),]);
485
486        let cats = connection
487            .get_results::<(i32, String, String)>("select id, name, color from cats order by id")
488            .await;
489        assert_eq!(cats, vec![(2, "Fluffy".to_string(), "white".to_string()),]);
490    }
491
    /// End-to-end check that a copied database behaves like the source: the data
    /// arrived intact, and constraints, views, partitions, and inheritance were
    /// all restored. Statements run in order; later checks insert into tables
    /// populated by earlier ones.
    pub async fn validate_copy_state(destination: &TestHelper) {
        // The `people` rows must have been copied verbatim.
        let items = destination
            .get_results::<(i32, String, i32)>("select id, name, age from people;")
            .await;

        assert_eq!(items, get_expected_people_data());

        // The `multi_check` constraint (`age < 9999`) must have been restored.
        let result = destination
            .get_conn()
            .execute_non_query("insert into people (name, age) values ('new-value', 10000)")
            .await;
        assert_pg_error(result, SqlState::CHECK_VIOLATION);

        // The unique constraint on `people.name` must have been restored.
        let result = destination
            .get_conn()
            .execute_non_query("insert into people (name, age) values ('foo', 100)")
            .await;
        assert_pg_error(result, SqlState::UNIQUE_VIOLATION);

        // Seed a `field` row so the tree_node inserts below can reference it.
        destination
            .execute_not_query("insert into field (id) values (1);")
            .await;

        destination.execute_not_query("insert into tree_node(id, field_id, name, parent_id) values (1, 1, 'foo', null), (2, 1, 'bar', 1)").await;
        // `unique nulls not distinct` is only part of the schema on Postgres 15+,
        // so the duplicate insert with a null parent only conflicts there.
        if destination.get_conn().version() >= 150 {
            let result = destination.get_conn().execute_non_query("insert into tree_node(id, field_id, name, parent_id) values (3, 1, 'foo', null)").await;
            assert_pg_error(result, SqlState::UNIQUE_VIOLATION);
        }

        // The foreign key from tree_node to field must have been restored.
        let result = destination.get_conn().execute_non_query("insert into tree_node(id, field_id, name, parent_id) values (9999, 9999, 'foobarbaz', null)").await;
        assert_pg_error(result, SqlState::FOREIGN_KEY_VIOLATION);

        // The view must exist and filter with its `age < 18` predicate.
        let people_who_cant_drink = destination
            .get_results::<(i32, String, i32)>("select id, name, age from people_who_cant_drink;")
            .await;
        assert_eq!(people_who_cant_drink, vec![(6, "q't".to_string(), 12)]);

        // Array-typed columns must round-trip.
        let array_test_data = destination
            .get_results::<(Vec<String>,)>("select name from array_test;")
            .await;

        assert_eq!(array_test_data, get_expected_array_test_data());

        // Rows must be readable through the partitioned parent table.
        let partition_test_data = destination
            .get_results::<(i32,)>("select value from my_partitioned_table order by value;")
            .await;

        assert_eq!(partition_test_data, vec![(1,), (9,), (11,), (19,)]);

        // Finally, check the table-inheritance hierarchy.
        validate_pets(destination).await;
    }
543}