// flyway_sql_changelog/lib.rs

1use std::path::{Path};
2use std::io::Read;
3use std::string::FromUtf8Error;
4use std::sync::Arc;
5use std::cmp::Ordering;
6
7use serde::{ Deserialize, Serialize };
8use std::error::Error;
9use std::fmt::{Display, Formatter};
10use std::hash::{Hash, Hasher};
11use siphasher::sip128::SipHasher13;
12
/// Single quote: opens and closes SQL string literals.
const SINGLE_QUOTE1: u8 = b'\'';
/// Backtick: quotes identifiers in some SQL dialects.
const SINGLE_QUOTE2: u8 = b'`';
// The acute accent (´, U+00B4) is deliberately not a quote here: it is two bytes long in
// UTF-8 and therefore cannot be represented as a single `u8` constant.
/// Double quote: quotes identifiers (or strings, depending on the dialect).
const DOUBLE_QUOTE: u8 = b'"';
/// Statement terminator.
const SEMICOLON: u8 = b';';
/// Escape character inside quoted regions.
const BACKSLASH: u8 = b'\\';
/// Minus sign; two in a row start a line comment.
const MINUS: u8 = b'-';
/// Line feed; ends a line comment.
const LINEFEED: u8 = b'\n';
21
/// Kinds of errors that can occur when processing a `ChangelogFile`.
#[derive(Debug)]
pub enum ChangelogErrorKind {
    /// The changelog is empty.
    EmptyChangelog,
    /// A requested minimum version is not in the changelog.
    ///
    /// Fields: `(min_version, requested_min_version)`.
    MinVersionNotFound(String, String),
    /// A requested maximum version is not in the changelog.
    ///
    /// Fields: `(max_version, requested_max_version)`.
    MaxVersionNotFound(String, String),
    /// An underlying I/O error.
    IoError(std::io::Error),
    /// Any other error.
    Other(Box<dyn std::error::Error + Send + Sync>),
}

/// An error that occurred while processing a `ChangelogFile`.
#[derive(Debug)]
pub struct ChangelogError {
    kind: ChangelogErrorKind,
}

impl ChangelogError {
    /// Create an [`ChangelogErrorKind::EmptyChangelog`] error.
    pub fn empty_changelog() -> ChangelogError {
        ChangelogError {
            kind: ChangelogErrorKind::EmptyChangelog,
        }
    }

    /// Misspelled variant of [`ChangelogError::empty_changelog`].
    ///
    /// Kept so that existing callers continue to compile; prefer `empty_changelog`.
    pub fn emtpy_change_log() -> ChangelogError {
        Self::empty_changelog()
    }

    /// Create a [`ChangelogErrorKind::MinVersionNotFound`] error.
    ///
    /// `actual_min_version` is the smallest version in the changelog,
    /// `requested_min_version` the version that was asked for.
    pub fn min_version_not_found(actual_min_version: &str, requested_min_version: &str) -> ChangelogError {
        ChangelogError {
            kind: ChangelogErrorKind::MinVersionNotFound(
                actual_min_version.to_string(),
                requested_min_version.to_string(),
            ),
        }
    }

    /// Create a [`ChangelogErrorKind::MaxVersionNotFound`] error.
    ///
    /// `actual_max_version` is the largest version in the changelog,
    /// `requested_max_version` the version that was asked for.
    pub fn max_version_not_found(actual_max_version: String, requested_max_version: String) -> ChangelogError {
        ChangelogError {
            kind: ChangelogErrorKind::MaxVersionNotFound(actual_max_version, requested_max_version),
        }
    }

    /// Wrap an I/O error.
    pub fn io(io_error: std::io::Error) -> ChangelogError {
        ChangelogError {
            kind: ChangelogErrorKind::IoError(io_error),
        }
    }

    /// Wrap an arbitrary error.
    pub fn other(other_error: Box<dyn std::error::Error + Send + Sync>) -> ChangelogError {
        ChangelogError {
            kind: ChangelogErrorKind::Other(other_error),
        }
    }

    /// The kind of this error.
    pub fn kind(&self) -> &ChangelogErrorKind {
        &self.kind
    }
}

impl From<std::io::Error> for ChangelogError {
    fn from(io_error: std::io::Error) -> Self {
        ChangelogError::io(io_error)
    }
}

impl Display for ChangelogError {
    fn fmt(&self, fmt: &mut Formatter<'_>) -> std::fmt::Result {
        match &self.kind {
            ChangelogErrorKind::EmptyChangelog => write!(fmt, "Database changelog is empty."),
            ChangelogErrorKind::MinVersionNotFound(actual_min, requested_min) => write!(
                fmt,
                "Requested minimum version {} not found in changelog. Minimum available version is {}.",
                requested_min, actual_min
            ),
            ChangelogErrorKind::MaxVersionNotFound(actual_max, requested_max) => write!(
                fmt,
                "Requested maximum version {} not found in changelog. Maximum available version is {}.",
                requested_max, actual_max
            ),
            // Wrapped errors are displayed with their own message.
            ChangelogErrorKind::IoError(io_error) => Display::fmt(io_error, fmt),
            ChangelogErrorKind::Other(other_error) => Display::fmt(other_error, fmt),
        }
    }
}

impl Error for ChangelogError {
    /// The wrapped error for `IoError` and `Other`; `None` for all other kinds.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match &self.kind {
            ChangelogErrorKind::IoError(io_error) => Some(io_error),
            ChangelogErrorKind::Other(other_error) => Some(&**other_error),
            ChangelogErrorKind::EmptyChangelog
            | ChangelogErrorKind::MinVersionNotFound(..)
            | ChangelogErrorKind::MaxVersionNotFound(..) => None,
        }
    }
}

/// Result type used throughout this crate.
pub type Result<T> = std::result::Result<T, ChangelogError>;
119
/// One changelog (migration) file: its version, name, checksum and SQL text.
#[derive(Debug, Clone)]
pub struct ChangelogFile {
    /// Version number of this changelog file.
    pub version: u64,
    /// Name of this changelog file.
    pub name: String,
    /// Checksum computed by the constructors over name, version and content.
    pub checksum: u64,

    /// Complete SQL text of the file, shared (not copied) with the statement iterators.
    pub content: Arc<String>,
}
133
/// Parser state of a `SqlStatementIterator`, stored in the iterator between reads.
#[derive(Debug, Clone)]
enum SqlStatementIteratorState {
    /// Outside of any quoted region and comment.
    Normal,
    /// Inside a quoted region.
    ///
    /// Holds the quote byte that ends the region.
    Quoted(u8),
    /// Directly after a backslash inside a quoted region.
    ///
    /// Holds the quote byte of the surrounding quoted region.
    Escaped(u8),
    /// Inside a comment.
    ///
    /// Holds the state to go back to once the comment ends, followed by the comment bytes
    /// read so far.
    Comment(Box<SqlStatementIteratorState>, Vec<u8>),
}
153
154/// The annotation of an SQL statement
155///
156/// Changelog files support annotating SQL statements so special error- and transaction-handling
157/// may be applied to the statement. Support for those annotations is not guaranteed by
158/// driver implementations.
159#[derive(Debug, Clone, Serialize, Deserialize)]
160pub struct SqlStatementAnnotation {
161    /// Continue the migration if the annotated statement fails
162    may_fail: Option<bool>,
163}
164
165/// A single, optionally annotated, SQL statement
166#[derive(Debug, Clone, Serialize, Deserialize)]
167pub struct SqlStatement {
168    /// The optional annotation of of the statement
169    pub annotation: Option<SqlStatementAnnotation>,
170    /// The actual SQL statement
171    pub statement: String,
172}
173
174/// An iterator for a `ChangelogFile`
175#[derive(Debug, Clone)]
176pub struct SqlStatementIterator {
177    /// `Arc` reference to the content of the changelog
178    content: Arc<String>,
179    /// Current position inside the content
180    position: usize,
181    /// Current state of the iterator
182    state: SqlStatementIteratorState,
183}
184
185impl ChangelogFile {
186    /// Load `ChangelogFile` from a given path
187    pub fn from_path(path: &Path) -> Result<ChangelogFile> {
188        let mut version = 0;
189        let mut name="".to_string();
190        let basename_opt = path.components().last();
191        if let Some(basename) = basename_opt {
192            let basename = basename.as_os_str().to_str().unwrap();
193            let index_opt = basename.find("_");
194            if let Some(index) = index_opt {
195                if index > 0 {
196                    version = (&basename[0..index]).to_string().parse().unwrap_or_default();
197                }
198            }
199        }
200
201        return std::fs::read_to_string(path)
202            .map(|content| {
203                let mut hasher = SipHasher13::new();
204                name.hash(&mut hasher);
205                version.hash(&mut hasher);
206                content.hash(&mut hasher);
207                let checksum = hasher.finish();
208
209                ChangelogFile {
210                    version,
211                    name,
212                    checksum,
213                    content: Arc::new(content)
214                }
215            }
216            )
217            .or_else(|err| Err(err.into()));
218    }
219
220    /// Create `ChangelogFile` from a version and a string containing the contents
221    pub fn from_string(version: u64,name:&str, sql: &str) -> Result<ChangelogFile> {
222
223        let mut hasher = SipHasher13::new();
224        name.hash(&mut hasher);
225        version.hash(&mut hasher);
226        sql.hash(&mut hasher);
227        let checksum = hasher.finish();
228
229        return Ok(ChangelogFile {
230            version,
231            name: name.to_string(),
232            checksum,
233            content: Arc::new(sql.to_string())
234        });
235    }
236
237    /// Create an iterator for the statements of this `ChangelogFile`
238    pub fn iter(&self) -> SqlStatementIterator {
239        return SqlStatementIterator::from_shared_string(self.content.clone());
240    }
241
242    /// Get the version of this `ChangelogFile`
243    pub fn version(&self) -> u64 {
244        return self.version;
245    }
246
247    /// Get the raw text of the `ChangelogFile`
248    pub fn content(&self) -> &str {
249        return self.content.as_str();
250    }
251}
252
253impl PartialEq<Self> for ChangelogFile {
254    #[inline]
255    fn eq(&self, other: &Self) -> bool {
256        return self.version.eq(&other.version) &&
257            self.content.eq(&other.content);
258    }
259}
260
261impl PartialOrd<Self> for ChangelogFile {
262    #[inline]
263    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
264        return self.version.partial_cmp(&other.version);
265    }
266}
267
268impl Eq for ChangelogFile { }
269
270impl Ord for ChangelogFile {
271    fn cmp(&self, other: &Self) -> Ordering {
272        return self.version.cmp(&other.version);
273    }
274}
275
276impl SqlStatementIterator {
277    /// Create object by reading content from a given path
278    pub fn from_path(path: &Path) -> Result<SqlStatementIterator> {
279        let mut text = String::new();
280        std::fs::File::open(path)?.read_to_string(&mut text)?;
281
282        return Ok(Self::from_str(text.as_str()));
283    }
284
285    /// Create object from a string
286    pub fn from_str(content: &str) -> SqlStatementIterator {
287        return Self::from_shared_string(Arc::new(content.to_string()));
288    }
289
290    /// Create object from an `Arc<String>`
291    pub fn from_shared_string(content: Arc<String>) -> SqlStatementIterator {
292        return SqlStatementIterator {
293            content,
294            position: 0,
295            state: SqlStatementIteratorState::Normal,
296        };
297    }
298
299    /// Get the next byte of the content
300    fn next_byte(&mut self) -> Option<u8> {
301        if self.position < self.content.len() {
302            let ch = self.content.as_bytes()[self.position];
303            self.position += 1;
304            return Some(ch);
305        }
306
307        return None;
308    }
309}
310
311impl Iterator for SqlStatementIterator {
312    type Item = SqlStatement;
313
314    fn next(&mut self) -> Option<Self::Item> {
315        // println!("READING next statement: position={}, state={:?}", self.position, &self.state);
316
317        //let mut len = 0;
318        let mut statement: Vec<u8> = Vec::new();
319        let mut annotation: Vec<u8> = Vec::new();
320
321        let mut ch = self.next_byte();
322
323        while ch.is_some() {
324            //len += 1;
325            let current_char = ch.unwrap();
326            ch = self.next_byte();
327
328            //println!("ch={}", current_char);
329
330            match current_char {
331                LINEFEED => {
332                    match &self.state {
333                        SqlStatementIteratorState::Comment(prev_state, comment) => {
334                            let comment_string: String = String::from_utf8(comment.to_vec())
335                                .or_else::<FromUtf8Error, _>(|_: FromUtf8Error| Ok("(non-utf8)".to_string()))
336                                .unwrap();
337
338                            let comment_string = comment_string.trim_start();
339                            if comment_string.starts_with("--! ") {
340                                let comment_string = &comment_string[4..comment_string.len()];
341                                // println!("annotation line: {}", comment_string);
342                                for byte in comment_string.as_bytes() {
343                                    annotation.push(*byte);
344                                }
345                            } else {
346                                // println!("SQL comment: {}", comment_string);
347                            }
348                            self.state = *prev_state.clone();
349                        },
350                        _ => {
351                            statement.push(current_char);
352                        }
353                    }
354                },
355                MINUS => {
356                    match &self.state {
357                        SqlStatementIteratorState::Normal => {
358                            self.state = SqlStatementIteratorState::Comment(Box::new(self.state.clone()), "-".to_string().into_bytes());
359                        },
360                        SqlStatementIteratorState::Comment(prev_state, comment) => {
361                            self.state = SqlStatementIteratorState::Comment(
362                                prev_state.clone(),
363                                comment.to_vec().into_iter().chain(vec![current_char].into_iter()).collect()
364                            );
365                        },
366                        _ => {
367                            statement.push(current_char);
368                        }
369                    };
370                },
371                SINGLE_QUOTE1 => {
372                    match &self.state {
373                        SqlStatementIteratorState::Normal => {
374                            statement.push(current_char);
375                            self.state = SqlStatementIteratorState::Quoted(SINGLE_QUOTE1);
376                        },
377                        SqlStatementIteratorState::Escaped(q) => {
378                            statement.push(current_char);
379                            self.state = SqlStatementIteratorState::Quoted(*q);
380                        },
381                        SqlStatementIteratorState::Quoted(q) => {
382                            if current_char == *q {
383                                statement.push(current_char);
384                                self.state = SqlStatementIteratorState::Normal;
385                            }
386                        },
387                        SqlStatementIteratorState::Comment(prev_state, comment) => {
388                            if comment.len() < 2 {
389                                let mut comment_clone = comment.clone();
390                                statement.append(&mut comment_clone);
391                                self.state = *prev_state.clone();
392                            } else {
393                                self.state = SqlStatementIteratorState::Comment(
394                                    prev_state.clone(),
395                                    comment.to_vec().into_iter().chain(vec![current_char].into_iter()).collect()
396                                );
397                            }
398                        }
399                    }
400                },
401                SINGLE_QUOTE2 => {
402                    match &self.state {
403                        SqlStatementIteratorState::Normal => {
404                            statement.push(current_char);
405                            self.state = SqlStatementIteratorState::Quoted(SINGLE_QUOTE1);
406                        },
407                        SqlStatementIteratorState::Escaped(q) => {
408                            statement.push(current_char);
409                            self.state = SqlStatementIteratorState::Quoted(*q);
410                        },
411                        SqlStatementIteratorState::Quoted(q) => {
412                            statement.push(current_char);
413                            if current_char == *q {
414                                self.state = SqlStatementIteratorState::Normal;
415                            }
416                        },
417                        SqlStatementIteratorState::Comment(prev_state, comment) => {
418                            if comment.len() < 2 {
419                                let mut comment_clone = comment.clone();
420                                statement.append(&mut comment_clone);
421                                self.state = *prev_state.clone();
422                            } else {
423                                self.state = SqlStatementIteratorState::Comment(
424                                    prev_state.clone(),
425                                    comment.to_vec().into_iter().chain(vec![current_char].into_iter()).collect()
426                                );
427                            }
428                        }
429                    }
430                },
431                DOUBLE_QUOTE => {
432                    match &self.state {
433                        SqlStatementIteratorState::Normal => {
434                            statement.push(current_char);
435                            self.state = SqlStatementIteratorState::Quoted(SINGLE_QUOTE1);
436                        },
437                        SqlStatementIteratorState::Escaped(q) => {
438                            statement.push(current_char);
439                            self.state = SqlStatementIteratorState::Quoted(*q);
440                        },
441                        SqlStatementIteratorState::Quoted(q) => {
442                            statement.push(current_char);
443                            if current_char == *q {
444                                self.state = SqlStatementIteratorState::Normal;
445                            }
446                        },
447                        SqlStatementIteratorState::Comment(prev_state, comment) => {
448                            if comment.len() < 2 {
449                                let mut comment_clone = comment.clone();
450                                statement.append(&mut comment_clone);
451                                self.state = *prev_state.clone();
452                            } else {
453                                self.state = SqlStatementIteratorState::Comment(
454                                    prev_state.clone(),
455                                    comment.to_vec().into_iter().chain(vec![current_char].into_iter()).collect()
456                                );
457                            }
458                        }
459                    }
460                },
461                SEMICOLON => {
462                    match &self.state {
463                        SqlStatementIteratorState::Quoted(_) => {
464                            statement.push(current_char);
465                        },
466                        SqlStatementIteratorState::Comment(prev_state, comment) => {
467                            if comment.len() < 2 {
468                                let mut comment_clone = comment.clone();
469                                statement.append(&mut comment_clone);
470                                self.state = *prev_state.clone();
471                            } else {
472                                self.state = SqlStatementIteratorState::Comment(
473                                    prev_state.clone(),
474                                    comment.to_vec().into_iter().chain(vec![current_char].into_iter()).collect()
475                                );
476                            }
477                        },
478                        _ => {
479                            break;
480                        }
481                    };
482                },
483                BACKSLASH => {
484                    match &self.state {
485                        SqlStatementIteratorState::Quoted(q) => {
486                            statement.push(current_char);
487                            self.state = SqlStatementIteratorState::Escaped(*q);
488                        },
489                        SqlStatementIteratorState::Escaped(q) => {
490                            statement.push(current_char);
491                            self.state = SqlStatementIteratorState::Quoted(*q);
492                        },
493                        SqlStatementIteratorState::Comment(prev_state, comment) => {
494                            if comment.len() < 2 {
495                                let mut comment_clone = comment.clone();
496                                statement.append(&mut comment_clone);
497                                self.state = *prev_state.clone();
498                            } else {
499                                self.state = SqlStatementIteratorState::Comment(
500                                    prev_state.clone(),
501                                    comment.to_vec().into_iter().chain(vec![current_char].into_iter()).collect()
502                                );
503                            }
504                        },
505                        _ => {
506                            statement.push(current_char);
507                        }
508                    };
509                },
510                _ => {
511                    match &self.state {
512                        SqlStatementIteratorState::Comment(prev_state, comment) => {
513                            if comment.len() < 2 {
514                                let mut comment_clone = comment.clone();
515                                statement.append(&mut comment_clone);
516                                self.state = *prev_state.clone();
517                            } else {
518                                self.state = SqlStatementIteratorState::Comment(
519                                    prev_state.clone(),
520                                    comment.to_vec().into_iter().chain(vec![current_char].into_iter()).collect()
521                                );
522                            }
523                        },
524                        _ => {
525                            statement.push(current_char);
526                        }
527                    }
528                }
529            }
530        }
531
532        for byte in statement.as_slice() {
533            if *byte > 127 {
534                log::error!("invalid byte: {:#02x}", byte);
535            }
536        }
537
538        // println!("FINISHED READING: statement={}", String::from_utf8(statement.clone()).unwrap());
539        if statement.len() > 0 {
540            //self.position += len;
541            // println!("FINISHED READING: position={}", self.position);
542            return String::from_utf8(statement)
543                .map(|value| value.trim().to_string())
544                .ok()
545                .map_or_else(|| None, |value| {
546                    if value.len() > 0 {
547                        // println!("annotation length: {}", annotation.len());
548                        let annotation = if annotation.len() > 0 {
549                            serde_yaml::from_slice::<SqlStatementAnnotation>(annotation.as_slice())
550                                .or_else(|err| {
551                                    // println!("Error parsing annotations: {:?}", err);
552                                    return Err(err);
553                                })
554                                .ok()
555                        } else {
556                            None
557                        };
558                        // println!("returning annotation: {:?}", &annotation);
559                        // println!("returning statement:  {}", &value);
560                        let result = SqlStatement {
561                            statement: value,
562                            annotation
563                        };
564                        Some(result)
565                    } else {
566                        None
567                    }
568                });
569        } else {
570            return None;
571        }
572    }
573}
574
#[cfg(test)]
mod test {
    use std::path::Path;
    use crate::ChangelogFile;

    /// Load a changelog from the example migrations directory, failing the test on any error.
    fn load(file_name: &str) -> ChangelogFile {
        let path = Path::new("../").join("example/migrations").join(file_name);
        ChangelogFile::from_path(&path)
            .unwrap_or_else(|err| panic!("Changelog file loading failed: {}", err))
    }

    #[test]
    fn test_load_changelog_file1() {
        let changelog = load("V1_test1.sql");
        // `version` is a `u64`; the `V` of the file-name prefix is not part of it.
        assert_eq!(changelog.version, 1);
        assert!(changelog.content().trim_start().starts_with("CREATE TABLE lorem"));
        assert!(changelog.content().trim_end().ends_with("ipsum VARCHAR(16));"));
    }

    #[test]
    fn test_load_changelog_file2() {
        let changelog = load("V2_test2.sql");
        assert_eq!(changelog.version, 2);
        assert!(changelog.content().trim_start().starts_with("CREATE INDEX idx_lorem_ipsum"));
        assert!(changelog.content().trim_end().ends_with("sit INTEGER, ahmed BIGINT);"));
    }

    #[test]
    fn test_changelog_file1_iterator() {
        let changelog = load("V1_test1.sql");
        let mut iterator = changelog.iter();

        let statement1 = iterator.next().expect("Found first statement.");
        assert_eq!(statement1.statement.trim(),
                   "CREATE TABLE lorem(id SERIAL, ipsum VARCHAR(16))",
                   "Correct first statement returned.");

        assert!(iterator.next().is_none(), "Only one statement found in iterator.");
    }

    #[test]
    fn test_changelog_file2_iterator() {
        let changelog = load("V2_test2.sql");
        let mut iterator = changelog.iter();

        let statement1 = iterator.next().expect("Found first statement.");
        assert_eq!(statement1.statement.trim(),
                   "CREATE INDEX idx_lorem_ipsum ON lorem(ipsum)",
                   "Correct first statement returned.");

        let statement2 = iterator.next().expect("Found second statement.");
        assert_eq!(statement2.statement.trim(),
                   "CREATE TABLE dolor(id BIGSERIAL PRIMARY KEY, sit INTEGER, ahmed BIGINT)",
                   "Correct second statement returned.");

        assert!(iterator.next().is_none(), "Exactly two statements found in iterator.");
    }
}