1use std::path::{Path};
2use std::io::Read;
3use std::string::FromUtf8Error;
4use std::sync::Arc;
5use std::cmp::Ordering;
6
7use serde::{ Deserialize, Serialize };
8use std::error::Error;
9use std::fmt::{Display, Formatter};
10
/// Byte value of `'`, one of the quote characters recognised by the statement splitter.
const SINGLE_QUOTE1: u8 = b'\'';
/// Byte value of the backtick `` ` ``, also recognised as a quote character.
const SINGLE_QUOTE2: u8 = b'`';
/// Byte value of `"`, also recognised as a quote character.
const DOUBLE_QUOTE: u8 = b'"';
/// Byte value of `;`, the statement separator.
const SEMICOLON: u8 = b';';
/// Byte value of `\`, used to escape characters inside quoted sections.
const BACKSLASH: u8 = b'\\';
/// Byte value of `-` (`--` introduces a line comment).
const MINUS: u8 = b'-';
/// Byte value of `\n`, which terminates a line comment.
const LINEFEED: u8 = b'\n';
19
/// The category of a [`ChangelogError`], together with the data describing it.
#[derive(Debug)]
pub enum ChangelogErrorKind {
    /// The database changelog contains no entries.
    EmptyChangelog,
    /// A requested minimum version is not in the changelog.
    /// Fields: `(actual_min_version, requested_min_version)`.
    MinVersionNotFound(String, String),
    /// A requested maximum version is not in the changelog.
    /// Fields: `(actual_max_version, requested_max_version)`.
    MaxVersionNotFound(String, String),
    /// An underlying I/O failure (e.g. reading a changelog file).
    IoError(std::io::Error),
    /// Any other error, boxed so arbitrary error types can be wrapped.
    Other(Box<dyn std::error::Error + Send + Sync>),
}
31
/// Error returned by changelog operations; inspect the cause with `kind()`.
#[derive(Debug)]
pub struct ChangelogError {
    // The specific failure; kept private so errors are built via the constructor functions.
    kind: ChangelogErrorKind,
}
37
38impl ChangelogError {
39 pub fn emtpy_change_log() -> ChangelogError {
40 return ChangelogError {
41 kind: ChangelogErrorKind::EmptyChangelog,
42 };
43 }
44
45 pub fn min_version_not_found(actual_min_version: &str, requested_min_version: &str) -> ChangelogError {
46 return ChangelogError {
47 kind: ChangelogErrorKind::MinVersionNotFound(actual_min_version.to_string(), requested_min_version.to_string()),
48 };
49 }
50
51 pub fn max_version_not_found(actual_max_version: String, requested_max_version: String) -> ChangelogError {
52 return ChangelogError {
53 kind: ChangelogErrorKind::MaxVersionNotFound(actual_max_version, requested_max_version),
54 };
55 }
56
57 pub fn io(io_error: std::io::Error) -> ChangelogError {
58 return ChangelogError {
59 kind: ChangelogErrorKind::IoError(io_error),
60 };
61 }
62
63 pub fn other(other_error: Box<dyn std::error::Error + Send + Sync>) -> ChangelogError {
64 return ChangelogError {
65 kind: ChangelogErrorKind::Other(other_error),
66 };
67 }
68
69 pub fn kind(&self) -> &ChangelogErrorKind {
70 &self.kind
71 }
72}
73
74impl From<std::io::Error> for ChangelogError {
75 fn from(io_error: std::io::Error) -> Self {
76 return ChangelogError::io(io_error);
77 }
78}
79
80impl Display for ChangelogError {
81 fn fmt(&self, fmt: &mut Formatter<'_>) -> std::fmt::Result {
82 match &self.kind {
83 ChangelogErrorKind::EmptyChangelog => {
84 return write!(fmt, "Database changelog is empty.");
85 }
86 ChangelogErrorKind::MinVersionNotFound(actual_min, requested_min) => {
87 return write!(fmt, "Requested minimum version {} not found in changelog. Minimum available version is {}.", requested_min, actual_min);
88 }
89 ChangelogErrorKind::MaxVersionNotFound(actual_max, requested_max) => {
90 return write!(fmt, "Requested maximum version {} not found in changelog. Maximum available version is {}.", requested_max, actual_max);
91 }
92 ChangelogErrorKind::IoError(io_error) => {
93 return io_error.fmt(fmt);
94 }
95 ChangelogErrorKind::Other(other_error) => {
96 return other_error.fmt(fmt);
97 }
98 };
99 }
100}
101
102impl Error for ChangelogError {
103 fn source(&self) -> Option<&(dyn Error + 'static)> {
104 match &self.kind {
105 ChangelogErrorKind::IoError(io_error) => {
106 return Some(io_error);
107 },
108 ChangelogErrorKind::Other(other_error) => {
109 return Some(&**other_error);
110 },
111 _ => return None
112 };
113 }
114}
115
116pub type Result<T> = std::result::Result<T, ChangelogError>;
117
/// A single changelog (migration) file: a version label plus its SQL text.
#[derive(Debug, Clone)]
pub struct ChangelogFile {
    // Version label: the file-name prefix before the first `_` when loaded with `from_path`
    // (empty if there is none), or the value passed to `from_string`.
    version: String,

    // Full SQL text; shared via `Arc` with the iterators returned by `iter()`.
    content: Arc<String>,
}
127
/// Scanner state of [`SqlStatementIterator`] while splitting SQL text into statements.
#[derive(Debug, Clone)]
enum SqlStatementIteratorState {
    /// Outside of any quoted section or comment.
    Normal,
    /// Inside a quoted section; the byte is the quote character that ends it.
    Quoted(u8),
    /// Just after a backslash inside a quoted section; carries the quote byte so the
    /// scanner can return to `Quoted`.
    Escaped(u8),
    /// Comment state: the state to return to once the comment ends, and the comment
    /// bytes collected so far (beginning with `-`).
    Comment(Box<SqlStatementIteratorState>, Vec<u8>)
}
147
148#[derive(Debug, Clone, Serialize, Deserialize)]
154pub struct SqlStatementAnnotation {
155 may_fail: Option<bool>,
157}
158
/// One SQL statement produced by [`SqlStatementIterator`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SqlStatement {
    /// Options parsed from `--! ` comment lines belonging to this statement; `None` when there
    /// were none or they could not be parsed.
    pub annotation: Option<SqlStatementAnnotation>,

    /// The statement text, trimmed, without its terminating `;`.
    pub statement: String,
}
167
/// Iterator that splits SQL text into individual [`SqlStatement`]s.
#[derive(Debug, Clone)]
pub struct SqlStatementIterator {
    // The SQL text being scanned (shared, not copied, when created from a `ChangelogFile`).
    content: Arc<String>,

    // Byte offset of the next byte to read from `content`.
    position: usize,

    // Scanner state; stored on the struct so it persists between `next()` calls.
    state: SqlStatementIteratorState,
}
178
179impl ChangelogFile {
180 pub fn from_path(path: &Path) -> Result<ChangelogFile> {
182 let mut version = "".to_string();
183 let basename_opt = path.components().last();
184 if let Some(basename) = basename_opt {
185 let basename = basename.as_os_str().to_str().unwrap();
186 let index_opt = basename.find("_");
187 if let Some(index) = index_opt {
188 if index > 0 {
189 version = (&basename[0..index]).to_string();
190 }
191 }
192 }
193
194 return std::fs::read_to_string(path)
195 .map(|content| ChangelogFile {
196 version,
197 content: Arc::new(content)
198 })
199 .or_else(|err| Err(err.into()));
200 }
201
202 pub fn from_string(version: &str, sql: &str) -> Result<ChangelogFile> {
204 return Ok(ChangelogFile {
205 version: version.to_string(),
206 content: Arc::new(sql.to_string())
207 });
208 }
209
210 pub fn iter(&self) -> SqlStatementIterator {
212 return SqlStatementIterator::from_shared_string(self.content.clone());
213 }
214
215 pub fn version(&self) -> &str {
217 return self.version.as_str();
218 }
219
220 pub fn content(&self) -> &str {
222 return self.content.as_str();
223 }
224}
225
226impl PartialEq<Self> for ChangelogFile {
227 #[inline]
228 fn eq(&self, other: &Self) -> bool {
229 return self.version.eq(&other.version) &&
230 self.content.eq(&other.content);
231 }
232}
233
234impl PartialOrd<Self> for ChangelogFile {
235 #[inline]
236 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
237 return self.version.as_bytes().partial_cmp(other.version.as_bytes());
238 }
239}
240
241impl Eq for ChangelogFile { }
242
243impl Ord for ChangelogFile {
244 fn cmp(&self, other: &Self) -> Ordering {
245 return self.version.as_bytes().cmp(other.version.as_bytes());
246 }
247}
248
249impl SqlStatementIterator {
250 pub fn from_path(path: &Path) -> Result<SqlStatementIterator> {
252 let mut text = String::new();
253 std::fs::File::open(path)?.read_to_string(&mut text)?;
254
255 return Ok(Self::from_str(text.as_str()));
256 }
257
258 pub fn from_str(content: &str) -> SqlStatementIterator {
260 return Self::from_shared_string(Arc::new(content.to_string()));
261 }
262
263 pub fn from_shared_string(content: Arc<String>) -> SqlStatementIterator {
265 return SqlStatementIterator {
266 content,
267 position: 0,
268 state: SqlStatementIteratorState::Normal,
269 };
270 }
271
272 fn next_byte(&mut self) -> Option<u8> {
274 if self.position < self.content.len() {
275 let ch = self.content.as_bytes()[self.position];
276 self.position += 1;
277 return Some(ch);
278 }
279
280 return None;
281 }
282}
283
284impl Iterator for SqlStatementIterator {
285 type Item = SqlStatement;
286
287 fn next(&mut self) -> Option<Self::Item> {
288 let mut statement: Vec<u8> = Vec::new();
292 let mut annotation: Vec<u8> = Vec::new();
293
294 let mut ch = self.next_byte();
295
296 while ch.is_some() {
297 let current_char = ch.unwrap();
299 ch = self.next_byte();
300
301 match current_char {
304 LINEFEED => {
305 match &self.state {
306 SqlStatementIteratorState::Comment(prev_state, comment) => {
307 let comment_string: String = String::from_utf8(comment.to_vec())
308 .or_else::<FromUtf8Error, _>(|_: FromUtf8Error| Ok("(non-utf8)".to_string()))
309 .unwrap();
310
311 let comment_string = comment_string.trim_start();
312 if comment_string.starts_with("--! ") {
313 let comment_string = &comment_string[4..comment_string.len()];
314 for byte in comment_string.as_bytes() {
316 annotation.push(*byte);
317 }
318 } else {
319 }
321 self.state = *prev_state.clone();
322 },
323 _ => {
324 statement.push(current_char);
325 }
326 }
327 },
328 MINUS => {
329 match &self.state {
330 SqlStatementIteratorState::Normal => {
331 self.state = SqlStatementIteratorState::Comment(Box::new(self.state.clone()), "-".to_string().into_bytes());
332 },
333 SqlStatementIteratorState::Comment(prev_state, comment) => {
334 self.state = SqlStatementIteratorState::Comment(
335 prev_state.clone(),
336 comment.to_vec().into_iter().chain(vec![current_char].into_iter()).collect()
337 );
338 },
339 _ => {
340 statement.push(current_char);
341 }
342 };
343 },
344 SINGLE_QUOTE1 => {
345 match &self.state {
346 SqlStatementIteratorState::Normal => {
347 statement.push(current_char);
348 self.state = SqlStatementIteratorState::Quoted(SINGLE_QUOTE1);
349 },
350 SqlStatementIteratorState::Escaped(q) => {
351 statement.push(current_char);
352 self.state = SqlStatementIteratorState::Quoted(*q);
353 },
354 SqlStatementIteratorState::Quoted(q) => {
355 if current_char == *q {
356 statement.push(current_char);
357 self.state = SqlStatementIteratorState::Normal;
358 }
359 },
360 SqlStatementIteratorState::Comment(prev_state, comment) => {
361 if comment.len() < 2 {
362 let mut comment_clone = comment.clone();
363 statement.append(&mut comment_clone);
364 self.state = *prev_state.clone();
365 } else {
366 self.state = SqlStatementIteratorState::Comment(
367 prev_state.clone(),
368 comment.to_vec().into_iter().chain(vec![current_char].into_iter()).collect()
369 );
370 }
371 }
372 }
373 },
374 SINGLE_QUOTE2 => {
375 match &self.state {
376 SqlStatementIteratorState::Normal => {
377 statement.push(current_char);
378 self.state = SqlStatementIteratorState::Quoted(SINGLE_QUOTE1);
379 },
380 SqlStatementIteratorState::Escaped(q) => {
381 statement.push(current_char);
382 self.state = SqlStatementIteratorState::Quoted(*q);
383 },
384 SqlStatementIteratorState::Quoted(q) => {
385 statement.push(current_char);
386 if current_char == *q {
387 self.state = SqlStatementIteratorState::Normal;
388 }
389 },
390 SqlStatementIteratorState::Comment(prev_state, comment) => {
391 if comment.len() < 2 {
392 let mut comment_clone = comment.clone();
393 statement.append(&mut comment_clone);
394 self.state = *prev_state.clone();
395 } else {
396 self.state = SqlStatementIteratorState::Comment(
397 prev_state.clone(),
398 comment.to_vec().into_iter().chain(vec![current_char].into_iter()).collect()
399 );
400 }
401 }
402 }
403 },
404 DOUBLE_QUOTE => {
405 match &self.state {
406 SqlStatementIteratorState::Normal => {
407 statement.push(current_char);
408 self.state = SqlStatementIteratorState::Quoted(SINGLE_QUOTE1);
409 },
410 SqlStatementIteratorState::Escaped(q) => {
411 statement.push(current_char);
412 self.state = SqlStatementIteratorState::Quoted(*q);
413 },
414 SqlStatementIteratorState::Quoted(q) => {
415 statement.push(current_char);
416 if current_char == *q {
417 self.state = SqlStatementIteratorState::Normal;
418 }
419 },
420 SqlStatementIteratorState::Comment(prev_state, comment) => {
421 if comment.len() < 2 {
422 let mut comment_clone = comment.clone();
423 statement.append(&mut comment_clone);
424 self.state = *prev_state.clone();
425 } else {
426 self.state = SqlStatementIteratorState::Comment(
427 prev_state.clone(),
428 comment.to_vec().into_iter().chain(vec![current_char].into_iter()).collect()
429 );
430 }
431 }
432 }
433 },
434 SEMICOLON => {
435 match &self.state {
436 SqlStatementIteratorState::Quoted(_) => {
437 statement.push(current_char);
438 },
439 SqlStatementIteratorState::Comment(prev_state, comment) => {
440 if comment.len() < 2 {
441 let mut comment_clone = comment.clone();
442 statement.append(&mut comment_clone);
443 self.state = *prev_state.clone();
444 } else {
445 self.state = SqlStatementIteratorState::Comment(
446 prev_state.clone(),
447 comment.to_vec().into_iter().chain(vec![current_char].into_iter()).collect()
448 );
449 }
450 },
451 _ => {
452 break;
453 }
454 };
455 },
456 BACKSLASH => {
457 match &self.state {
458 SqlStatementIteratorState::Quoted(q) => {
459 statement.push(current_char);
460 self.state = SqlStatementIteratorState::Escaped(*q);
461 },
462 SqlStatementIteratorState::Escaped(q) => {
463 statement.push(current_char);
464 self.state = SqlStatementIteratorState::Quoted(*q);
465 },
466 SqlStatementIteratorState::Comment(prev_state, comment) => {
467 if comment.len() < 2 {
468 let mut comment_clone = comment.clone();
469 statement.append(&mut comment_clone);
470 self.state = *prev_state.clone();
471 } else {
472 self.state = SqlStatementIteratorState::Comment(
473 prev_state.clone(),
474 comment.to_vec().into_iter().chain(vec![current_char].into_iter()).collect()
475 );
476 }
477 },
478 _ => {
479 statement.push(current_char);
480 }
481 };
482 },
483 _ => {
484 match &self.state {
485 SqlStatementIteratorState::Comment(prev_state, comment) => {
486 if comment.len() < 2 {
487 let mut comment_clone = comment.clone();
488 statement.append(&mut comment_clone);
489 self.state = *prev_state.clone();
490 } else {
491 self.state = SqlStatementIteratorState::Comment(
492 prev_state.clone(),
493 comment.to_vec().into_iter().chain(vec![current_char].into_iter()).collect()
494 );
495 }
496 },
497 _ => {
498 statement.push(current_char);
499 }
500 }
501 }
502 }
503 }
504
505 for byte in statement.as_slice() {
506 if *byte > 127 {
507 println!("invalid byte: {:#02x}", byte);
508 }
509 }
510
511 if statement.len() > 0 {
513 return String::from_utf8(statement)
516 .map(|value| value.trim().to_string())
517 .ok()
518 .map_or_else(|| None, |value| {
519 if value.len() > 0 {
520 let annotation = if annotation.len() > 0 {
522 serde_yaml::from_slice::<SqlStatementAnnotation>(annotation.as_slice())
523 .or_else(|err| {
524 return Err(err);
526 })
527 .ok()
528 } else {
529 None
530 };
531 let result = SqlStatement {
534 statement: value,
535 annotation
536 };
537 Some(result)
538 } else {
539 None
540 }
541 });
542 } else {
543 return None;
544 }
545 }
546}
547
#[cfg(test)]
mod test {
    use std::path::Path;
    use crate::ChangelogFile;

    /// Loads `examples/migrations/<file_name>` relative to the working directory, failing the
    /// test with the loader's error message if it cannot be read.
    fn load_example(file_name: &str) -> ChangelogFile {
        let path = Path::new(".").join(format!("examples/migrations/{}", file_name));
        ChangelogFile::from_path(&path)
            .unwrap_or_else(|err| panic!("Changelog file loading failed: {}", err))
    }

    #[test]
    fn test_load_changelog_file1() {
        let changelog = load_example("V1_test1.sql");
        assert_eq!(changelog.version(), "V1");
        assert!(changelog.content().trim_start().starts_with("CREATE TABLE lorem"));
        assert!(changelog.content().trim_end().ends_with("ipsum VARCHAR(16));"));
    }

    #[test]
    fn test_load_changelog_file2() {
        let changelog = load_example("V2_test2.sql");
        assert_eq!(changelog.version(), "V2");
        assert!(changelog.content().trim_start().starts_with("CREATE INDEX idx_lorem_ipsum"));
        assert!(changelog.content().trim_end().ends_with("sit INTEGER, ahmed BIGINT);"));
    }

    #[test]
    fn test_changelog_file1_iterator() {
        let changelog = load_example("V1_test1.sql");
        let mut iterator = changelog.iter();

        let statement1 = iterator.next();
        assert!(statement1.is_some(), "Found first statement.");
        assert_eq!(statement1.unwrap().statement.trim(),
                   "CREATE TABLE lorem(id SERIAL, ipsum VARCHAR(16))",
                   "Correct first statement returned.");

        assert!(iterator.next().is_none(), "Only one statement found in iterator.");
    }

    #[test]
    fn test_changelog_file2_iterator() {
        let changelog = load_example("V2_test2.sql");
        let mut iterator = changelog.iter();

        let statement1 = iterator.next();
        assert!(statement1.is_some(), "Found first statement.");
        assert_eq!(statement1.unwrap().statement.trim(),
                   "CREATE INDEX idx_lorem_ipsum ON lorem(ipsum)",
                   "Correct first statement returned.");

        let statement2 = iterator.next();
        assert!(statement2.is_some(), "Found second statement.");
        assert_eq!(statement2.unwrap().statement.trim(),
                   "CREATE TABLE dolor(id BIGSERIAL PRIMARY KEY, sit INTEGER, ahmed BIGINT)",
                   "Correct second statement returned.");

        assert!(iterator.next().is_none(), "Exactly two statements found in iterator.");
    }
}