db_up_sql_changelog/
lib.rs

1use std::path::{Path};
2use std::io::Read;
3use std::string::FromUtf8Error;
4use std::sync::Arc;
5use std::cmp::Ordering;
6
7use serde::{ Deserialize, Serialize };
8use std::error::Error;
9use std::fmt::{Display, Formatter};
10
/// Apostrophe (`'`), one of the quote bytes recognized by the statement iterator.
const SINGLE_QUOTE1: u8 = b'\'';
/// Backtick (`` ` ``), one of the quote bytes recognized by the statement iterator.
const SINGLE_QUOTE2: u8 = b'`';
//const SINGLE_QUOTE3: u8 = '´' as u8;
/// Double quote (`"`), one of the quote bytes recognized by the statement iterator.
const DOUBLE_QUOTE: u8 = b'"';
/// Semicolon (`;`), the statement terminator.
const SEMICOLON: u8 = b';';
/// Backslash (`\`), starts an escape sequence inside a quoted region.
const BACKSLASH: u8 = b'\\';
/// Minus (`-`), two of which in a row start a line comment.
const MINUS: u8 = b'-';
/// Line feed, terminates a line comment.
const LINEFEED: u8 = b'\n';
19
/// The categories of failure that can occur when processing a `ChangelogFile`
#[derive(Debug)]
pub enum ChangelogErrorKind {
    /// The changelog is empty.
    EmptyChangelog,
    /// The requested minimum version is not available.
    ///
    /// Fields: `(min_version, requested_min_version)`.
    MinVersionNotFound(String, String),
    /// The requested maximum version is not available.
    ///
    /// Fields: `(max_version, requested_max_version)`.
    MaxVersionNotFound(String, String),
    /// An underlying I/O operation failed.
    IoError(std::io::Error),
    /// Any other error, boxed.
    Other(Box<dyn Error + Send + Sync>),
}
31
32/// An error that occurred while processing a `ChangelogFile`
33#[derive(Debug)]
34pub struct ChangelogError {
35    kind: ChangelogErrorKind,
36}
37
38impl ChangelogError {
39    pub fn emtpy_change_log() -> ChangelogError {
40        return ChangelogError {
41            kind: ChangelogErrorKind::EmptyChangelog,
42        };
43    }
44
45    pub fn min_version_not_found(actual_min_version: &str, requested_min_version: &str) -> ChangelogError {
46        return ChangelogError {
47            kind: ChangelogErrorKind::MinVersionNotFound(actual_min_version.to_string(), requested_min_version.to_string()),
48        };
49    }
50
51    pub fn max_version_not_found(actual_max_version: String, requested_max_version: String) -> ChangelogError {
52        return ChangelogError {
53            kind: ChangelogErrorKind::MaxVersionNotFound(actual_max_version, requested_max_version),
54        };
55    }
56
57    pub fn io(io_error: std::io::Error) -> ChangelogError {
58        return ChangelogError {
59            kind: ChangelogErrorKind::IoError(io_error),
60        };
61    }
62
63    pub fn other(other_error: Box<dyn std::error::Error + Send + Sync>) -> ChangelogError {
64        return ChangelogError {
65            kind: ChangelogErrorKind::Other(other_error),
66        };
67    }
68
69    pub fn kind(&self) -> &ChangelogErrorKind {
70        &self.kind
71    }
72}
73
74impl From<std::io::Error> for ChangelogError {
75    fn from(io_error: std::io::Error) -> Self {
76        return ChangelogError::io(io_error);
77    }
78}
79
80impl Display for ChangelogError {
81    fn fmt(&self, fmt: &mut Formatter<'_>) -> std::fmt::Result {
82        match &self.kind {
83            ChangelogErrorKind::EmptyChangelog => {
84                return write!(fmt, "Database changelog is empty.");
85            }
86            ChangelogErrorKind::MinVersionNotFound(actual_min, requested_min) => {
87                return write!(fmt, "Requested minimum version {} not found in changelog. Minimum available version is {}.", requested_min, actual_min);
88            }
89            ChangelogErrorKind::MaxVersionNotFound(actual_max, requested_max) => {
90                return write!(fmt, "Requested maximum version {} not found in changelog. Maximum available version is {}.", requested_max, actual_max);
91            }
92            ChangelogErrorKind::IoError(io_error) => {
93                return io_error.fmt(fmt);
94            }
95            ChangelogErrorKind::Other(other_error) => {
96                return other_error.fmt(fmt);
97            }
98        };
99    }
100}
101
102impl Error for ChangelogError {
103    fn source(&self) -> Option<&(dyn Error + 'static)> {
104        match &self.kind {
105            ChangelogErrorKind::IoError(io_error) => {
106                return Some(io_error);
107            },
108            ChangelogErrorKind::Other(other_error) => {
109                return Some(&**other_error);
110            },
111            _ => return None
112        };
113    }
114}
115
116pub type Result<T> = std::result::Result<T, ChangelogError>;
117
/// A single changelog file: a version label plus the SQL text belonging to it
#[derive(Debug, Clone)]
pub struct ChangelogFile {
    /// Version label of this `ChangelogFile`
    version: String,

    /// Complete text of this `ChangelogFile`
    ///
    /// Held in an `Arc` so that cloning the file or creating an iterator
    /// shares the text instead of copying it.
    content: Arc<String>,
}
127
/// Parser state kept by the `SqlStatementIterator` between bytes
#[derive(Debug, Clone)]
enum SqlStatementIteratorState {
    /// Outside of any quoted region or comment
    Normal,
    /// Inside a quoted region
    ///
    /// The payload is the quote byte associated with the region.
    Quoted(u8),
    /// Directly after a backslash inside a quoted region
    ///
    /// The payload is the quote byte of the surrounding region.
    Escaped(u8),
    /// Inside a (potential) comment
    ///
    /// The first field is the `SqlStatementIteratorState` that was active
    /// before the comment started, the second field holds the bytes of the
    /// comment collected so far.
    Comment(Box<SqlStatementIteratorState>, Vec<u8>)
}
147
148/// The annotation of an SQL statement
149///
150/// Changelog files support annotating SQL statements so special error- and transaction-handling
151/// may be applied to the statement. Support for those annotations is not guaranteed by
152/// driver implementations.
153#[derive(Debug, Clone, Serialize, Deserialize)]
154pub struct SqlStatementAnnotation {
155    /// Continue the migration if the annotated statement fails
156    may_fail: Option<bool>,
157}
158
159/// A single, optionally annotated, SQL statement
160#[derive(Debug, Clone, Serialize, Deserialize)]
161pub struct SqlStatement {
162    /// The optional annotation of of the statement
163    pub annotation: Option<SqlStatementAnnotation>,
164    /// The actual SQL statement
165    pub statement: String,
166}
167
168/// An iterator for a `ChangelogFile`
169#[derive(Debug, Clone)]
170pub struct SqlStatementIterator {
171    /// `Arc` reference to the content of the changelog
172    content: Arc<String>,
173    /// Current position inside the content
174    position: usize,
175    /// Current state of the iterator
176    state: SqlStatementIteratorState,
177}
178
179impl ChangelogFile {
180    /// Load `ChangelogFile` from a given path
181    pub fn from_path(path: &Path) -> Result<ChangelogFile> {
182        let mut version = "".to_string();
183        let basename_opt = path.components().last();
184        if let Some(basename) = basename_opt {
185            let basename = basename.as_os_str().to_str().unwrap();
186            let index_opt = basename.find("_");
187            if let Some(index) = index_opt {
188                if index > 0 {
189                    version = (&basename[0..index]).to_string();
190                }
191            }
192        }
193
194        return std::fs::read_to_string(path)
195            .map(|content| ChangelogFile {
196                version,
197                content: Arc::new(content)
198            })
199            .or_else(|err| Err(err.into()));
200    }
201
202    /// Create `ChangelogFile` from a version and a string containing the contents
203    pub fn from_string(version: &str, sql: &str) -> Result<ChangelogFile> {
204        return Ok(ChangelogFile {
205            version: version.to_string(),
206            content: Arc::new(sql.to_string())
207        });
208    }
209
210    /// Create an iterator for the statements of this `ChangelogFile`
211    pub fn iter(&self) -> SqlStatementIterator {
212        return SqlStatementIterator::from_shared_string(self.content.clone());
213    }
214
215    /// Get the version of this `ChangelogFile`
216    pub fn version(&self) -> &str {
217        return self.version.as_str();
218    }
219
220    /// Get the raw text of the `ChangelogFile`
221    pub fn content(&self) -> &str {
222        return self.content.as_str();
223    }
224}
225
226impl PartialEq<Self> for ChangelogFile {
227    #[inline]
228    fn eq(&self, other: &Self) -> bool {
229        return self.version.eq(&other.version) &&
230            self.content.eq(&other.content);
231    }
232}
233
234impl PartialOrd<Self> for ChangelogFile {
235    #[inline]
236    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
237        return self.version.as_bytes().partial_cmp(other.version.as_bytes());
238    }
239}
240
241impl Eq for ChangelogFile { }
242
243impl Ord for ChangelogFile {
244    fn cmp(&self, other: &Self) -> Ordering {
245        return self.version.as_bytes().cmp(other.version.as_bytes());
246    }
247}
248
249impl SqlStatementIterator {
250    /// Create object by reading content from a given path
251    pub fn from_path(path: &Path) -> Result<SqlStatementIterator> {
252        let mut text = String::new();
253        std::fs::File::open(path)?.read_to_string(&mut text)?;
254
255        return Ok(Self::from_str(text.as_str()));
256    }
257
258    /// Create object from a string
259    pub fn from_str(content: &str) -> SqlStatementIterator {
260        return Self::from_shared_string(Arc::new(content.to_string()));
261    }
262
263    /// Create object from an `Arc<String>`
264    pub fn from_shared_string(content: Arc<String>) -> SqlStatementIterator {
265        return SqlStatementIterator {
266            content,
267            position: 0,
268            state: SqlStatementIteratorState::Normal,
269        };
270    }
271
272    /// Get the next byte of the content
273    fn next_byte(&mut self) -> Option<u8> {
274        if self.position < self.content.len() {
275            let ch = self.content.as_bytes()[self.position];
276            self.position += 1;
277            return Some(ch);
278        }
279
280        return None;
281    }
282}
283
284impl Iterator for SqlStatementIterator {
285    type Item = SqlStatement;
286
287    fn next(&mut self) -> Option<Self::Item> {
288        // println!("READING next statement: position={}, state={:?}", self.position, &self.state);
289
290        //let mut len = 0;
291        let mut statement: Vec<u8> = Vec::new();
292        let mut annotation: Vec<u8> = Vec::new();
293
294        let mut ch = self.next_byte();
295
296        while ch.is_some() {
297            //len += 1;
298            let current_char = ch.unwrap();
299            ch = self.next_byte();
300
301            //println!("ch={}", current_char);
302
303            match current_char {
304                LINEFEED => {
305                    match &self.state {
306                        SqlStatementIteratorState::Comment(prev_state, comment) => {
307                            let comment_string: String = String::from_utf8(comment.to_vec())
308                                .or_else::<FromUtf8Error, _>(|_: FromUtf8Error| Ok("(non-utf8)".to_string()))
309                                .unwrap();
310
311                            let comment_string = comment_string.trim_start();
312                            if comment_string.starts_with("--! ") {
313                                let comment_string = &comment_string[4..comment_string.len()];
314                                // println!("annotation line: {}", comment_string);
315                                for byte in comment_string.as_bytes() {
316                                    annotation.push(*byte);
317                                }
318                            } else {
319                                // println!("SQL comment: {}", comment_string);
320                            }
321                            self.state = *prev_state.clone();
322                        },
323                        _ => {
324                            statement.push(current_char);
325                        }
326                    }
327                },
328                MINUS => {
329                    match &self.state {
330                        SqlStatementIteratorState::Normal => {
331                            self.state = SqlStatementIteratorState::Comment(Box::new(self.state.clone()), "-".to_string().into_bytes());
332                        },
333                        SqlStatementIteratorState::Comment(prev_state, comment) => {
334                            self.state = SqlStatementIteratorState::Comment(
335                                prev_state.clone(),
336                                comment.to_vec().into_iter().chain(vec![current_char].into_iter()).collect()
337                            );
338                        },
339                        _ => {
340                            statement.push(current_char);
341                        }
342                    };
343                },
344                SINGLE_QUOTE1 => {
345                    match &self.state {
346                        SqlStatementIteratorState::Normal => {
347                            statement.push(current_char);
348                            self.state = SqlStatementIteratorState::Quoted(SINGLE_QUOTE1);
349                        },
350                        SqlStatementIteratorState::Escaped(q) => {
351                            statement.push(current_char);
352                            self.state = SqlStatementIteratorState::Quoted(*q);
353                        },
354                        SqlStatementIteratorState::Quoted(q) => {
355                            if current_char == *q {
356                                statement.push(current_char);
357                                self.state = SqlStatementIteratorState::Normal;
358                            }
359                        },
360                        SqlStatementIteratorState::Comment(prev_state, comment) => {
361                            if comment.len() < 2 {
362                                let mut comment_clone = comment.clone();
363                                statement.append(&mut comment_clone);
364                                self.state = *prev_state.clone();
365                            } else {
366                                self.state = SqlStatementIteratorState::Comment(
367                                    prev_state.clone(),
368                                    comment.to_vec().into_iter().chain(vec![current_char].into_iter()).collect()
369                                );
370                            }
371                        }
372                    }
373                },
374                SINGLE_QUOTE2 => {
375                    match &self.state {
376                        SqlStatementIteratorState::Normal => {
377                            statement.push(current_char);
378                            self.state = SqlStatementIteratorState::Quoted(SINGLE_QUOTE1);
379                        },
380                        SqlStatementIteratorState::Escaped(q) => {
381                            statement.push(current_char);
382                            self.state = SqlStatementIteratorState::Quoted(*q);
383                        },
384                        SqlStatementIteratorState::Quoted(q) => {
385                            statement.push(current_char);
386                            if current_char == *q {
387                                self.state = SqlStatementIteratorState::Normal;
388                            }
389                        },
390                        SqlStatementIteratorState::Comment(prev_state, comment) => {
391                            if comment.len() < 2 {
392                                let mut comment_clone = comment.clone();
393                                statement.append(&mut comment_clone);
394                                self.state = *prev_state.clone();
395                            } else {
396                                self.state = SqlStatementIteratorState::Comment(
397                                    prev_state.clone(),
398                                    comment.to_vec().into_iter().chain(vec![current_char].into_iter()).collect()
399                                );
400                            }
401                        }
402                    }
403                },
404                DOUBLE_QUOTE => {
405                    match &self.state {
406                        SqlStatementIteratorState::Normal => {
407                            statement.push(current_char);
408                            self.state = SqlStatementIteratorState::Quoted(SINGLE_QUOTE1);
409                        },
410                        SqlStatementIteratorState::Escaped(q) => {
411                            statement.push(current_char);
412                            self.state = SqlStatementIteratorState::Quoted(*q);
413                        },
414                        SqlStatementIteratorState::Quoted(q) => {
415                            statement.push(current_char);
416                            if current_char == *q {
417                                self.state = SqlStatementIteratorState::Normal;
418                            }
419                        },
420                        SqlStatementIteratorState::Comment(prev_state, comment) => {
421                            if comment.len() < 2 {
422                                let mut comment_clone = comment.clone();
423                                statement.append(&mut comment_clone);
424                                self.state = *prev_state.clone();
425                            } else {
426                                self.state = SqlStatementIteratorState::Comment(
427                                    prev_state.clone(),
428                                    comment.to_vec().into_iter().chain(vec![current_char].into_iter()).collect()
429                                );
430                            }
431                        }
432                    }
433                },
434                SEMICOLON => {
435                    match &self.state {
436                        SqlStatementIteratorState::Quoted(_) => {
437                            statement.push(current_char);
438                        },
439                        SqlStatementIteratorState::Comment(prev_state, comment) => {
440                            if comment.len() < 2 {
441                                let mut comment_clone = comment.clone();
442                                statement.append(&mut comment_clone);
443                                self.state = *prev_state.clone();
444                            } else {
445                                self.state = SqlStatementIteratorState::Comment(
446                                    prev_state.clone(),
447                                    comment.to_vec().into_iter().chain(vec![current_char].into_iter()).collect()
448                                );
449                            }
450                        },
451                        _ => {
452                            break;
453                        }
454                    };
455                },
456                BACKSLASH => {
457                    match &self.state {
458                        SqlStatementIteratorState::Quoted(q) => {
459                            statement.push(current_char);
460                            self.state = SqlStatementIteratorState::Escaped(*q);
461                        },
462                        SqlStatementIteratorState::Escaped(q) => {
463                            statement.push(current_char);
464                            self.state = SqlStatementIteratorState::Quoted(*q);
465                        },
466                        SqlStatementIteratorState::Comment(prev_state, comment) => {
467                            if comment.len() < 2 {
468                                let mut comment_clone = comment.clone();
469                                statement.append(&mut comment_clone);
470                                self.state = *prev_state.clone();
471                            } else {
472                                self.state = SqlStatementIteratorState::Comment(
473                                    prev_state.clone(),
474                                    comment.to_vec().into_iter().chain(vec![current_char].into_iter()).collect()
475                                );
476                            }
477                        },
478                        _ => {
479                            statement.push(current_char);
480                        }
481                    };
482                },
483                _ => {
484                    match &self.state {
485                        SqlStatementIteratorState::Comment(prev_state, comment) => {
486                            if comment.len() < 2 {
487                                let mut comment_clone = comment.clone();
488                                statement.append(&mut comment_clone);
489                                self.state = *prev_state.clone();
490                            } else {
491                                self.state = SqlStatementIteratorState::Comment(
492                                    prev_state.clone(),
493                                    comment.to_vec().into_iter().chain(vec![current_char].into_iter()).collect()
494                                );
495                            }
496                        },
497                        _ => {
498                            statement.push(current_char);
499                        }
500                    }
501                }
502            }
503        }
504
505        for byte in statement.as_slice() {
506            if *byte > 127 {
507                println!("invalid byte: {:#02x}", byte);
508            }
509        }
510
511        // println!("FINISHED READING: statement={}", String::from_utf8(statement.clone()).unwrap());
512        if statement.len() > 0 {
513            //self.position += len;
514            // println!("FINISHED READING: position={}", self.position);
515            return String::from_utf8(statement)
516                .map(|value| value.trim().to_string())
517                .ok()
518                .map_or_else(|| None, |value| {
519                    if value.len() > 0 {
520                        // println!("annotation length: {}", annotation.len());
521                        let annotation = if annotation.len() > 0 {
522                            serde_yaml::from_slice::<SqlStatementAnnotation>(annotation.as_slice())
523                                .or_else(|err| {
524                                    // println!("Error parsing annotations: {:?}", err);
525                                    return Err(err);
526                                })
527                                .ok()
528                        } else {
529                            None
530                        };
531                        // println!("returning annotation: {:?}", &annotation);
532                        // println!("returning statement:  {}", &value);
533                        let result = SqlStatement {
534                            statement: value,
535                            annotation
536                        };
537                        Some(result)
538                    } else {
539                        None
540                    }
541                });
542        } else {
543            return None;
544        }
545    }
546}
547
#[cfg(test)]
mod test {
    use std::path::Path;
    use crate::ChangelogFile;

    /// Load an example changelog relative to the current directory.
    ///
    /// Fails the calling test if the file cannot be loaded.
    fn load(relative_path: &str) -> ChangelogFile {
        let path = Path::new(".").join(relative_path);
        ChangelogFile::from_path(&path)
            .unwrap_or_else(|err| panic!("Changelog file loading failed: {}", err))
    }

    /// The first example file has version `V1` and the expected text.
    #[test]
    pub fn test_load_changelog_file1() {
        let changelog = load("examples/migrations/V1_test1.sql");

        assert_eq!(changelog.version, "V1");
        assert!(changelog.content().trim_start().starts_with("CREATE TABLE lorem"));
        assert!(changelog.content().trim_end().ends_with("ipsum VARCHAR(16));"));
    }

    /// The second example file has version `V2` and the expected text.
    #[test]
    pub fn test_load_changelog_file2() {
        let changelog = load("examples/migrations/V2_test2.sql");

        assert_eq!(changelog.version, "V2");
        assert!(changelog.content().trim_start().starts_with("CREATE INDEX idx_lorem_ipsum"));
        assert!(changelog.content().trim_end().ends_with("sit INTEGER, ahmed BIGINT);"));
    }

    /// Iterating the first example file yields exactly one statement.
    #[test]
    pub fn test_changelog_file1_iterator() {
        let changelog = load("examples/migrations/V1_test1.sql");
        let mut statements = changelog.iter();

        let first = statements.next();
        assert!(first.is_some(), "Found first statement.");
        assert_eq!(first.unwrap().statement.trim(),
                   "CREATE TABLE lorem(id SERIAL, ipsum VARCHAR(16))",
                   "Correct first statement returned.");

        let second = statements.next();
        assert!(second.is_none(), "Only one statement found in iterator.");
    }

    /// Iterating the second example file yields exactly two statements.
    #[test]
    pub fn test_changelog_file2_iterator() {
        let changelog = load("examples/migrations/V2_test2.sql");
        let mut statements = changelog.iter();

        let first = statements.next();
        assert!(first.is_some(), "Found first statement.");
        assert_eq!(first.unwrap().statement.trim(),
                   "CREATE INDEX idx_lorem_ipsum ON lorem(ipsum)",
                   "Correct first statement returned.");

        let second = statements.next();
        assert!(second.is_some(), "Found second statement.");
        assert_eq!(second.unwrap().statement.trim(),
                   "CREATE TABLE dolor(id BIGSERIAL PRIMARY KEY, sit INTEGER, ahmed BIGINT)",
                   "Correct second statement returned.");

        let third = statements.next();
        assert!(third.is_none(), "Exactly two statements found in iterator.");
    }
}