1use std::collections::{HashMap, HashSet};
2use std::error::Error;
3use std::fmt::{Debug, Display};
4use std::fs::File;
5use std::io::{Read, Write};
6use std::path::{Path, PathBuf};
7
8use anyhow::{bail, Context};
9use cargo_toml::Manifest;
10use cot::db::migrations::{DynMigration, MigrationEngine};
11use cot_codegen::model::{Field, Model, ModelArgs, ModelOpts, ModelType};
12use cot_codegen::symbol_resolver::SymbolResolver;
13use darling::FromMeta;
14use petgraph::graph::DiGraph;
15use petgraph::visit::EdgeRef;
16use proc_macro2::TokenStream;
17use quote::{format_ident, quote, ToTokens};
18use syn::{parse_quote, Meta};
19use tracing::{debug, info, trace};
20
21use crate::utils::find_cargo_toml;
22
23pub fn make_migrations(path: &Path, options: MigrationGeneratorOptions) -> anyhow::Result<()> {
24 match find_cargo_toml(
25 &path
26 .canonicalize()
27 .with_context(|| "unable to canonicalize Cargo.toml path")?,
28 ) {
29 Some(cargo_toml_path) => {
30 let manifest = Manifest::from_path(&cargo_toml_path)
31 .with_context(|| "unable to read Cargo.toml")?;
32 let crate_name = manifest
33 .package
34 .with_context(|| "unable to find package in Cargo.toml")?
35 .name;
36
37 let mut generator = MigrationGenerator::new(cargo_toml_path, crate_name, options);
38 generator
39 .generate_and_write_migrations()
40 .with_context(|| "unable to generate migrations")?;
41 generator
42 .write_migrations_module()
43 .with_context(|| "unable to write migrations.rs")?;
44 }
45 None => {
46 bail!("Cargo.toml not found in the specified directory or any parent directory.")
47 }
48 }
49
50 Ok(())
51}
52
/// Options controlling how and where migrations are generated.
#[derive(Debug, Clone, Default)]
pub struct MigrationGeneratorOptions {
    /// App name written into generated migrations (`APP_NAME`); when `None`,
    /// the crate name is used instead.
    pub app_name: Option<String>,
    /// Directory that receives the `migrations/` directory and the
    /// `migrations.rs` module; when `None`, the `src` directory next to
    /// `Cargo.toml` is used.
    pub output_dir: Option<PathBuf>,
}
58
/// Generates migration files for one crate by comparing the models in its
/// source code with the model state recorded in its existing migrations.
#[derive(Debug)]
pub struct MigrationGenerator {
    /// Path to the crate's `Cargo.toml`; its parent directory is the crate root.
    cargo_toml_path: PathBuf,
    /// Crate name; `make_migrations` passes the package name from `Cargo.toml`.
    crate_name: String,
    options: MigrationGeneratorOptions,
}
65
/// Name of the module holding generated migrations: both the `migrations/`
/// directory and the `migrations.rs` file in the source directory.
const MIGRATIONS_MODULE_NAME: &str = "migrations";
/// Prefix of migration file names; only `.rs` files starting with it are
/// listed in the generated migrations module.
const MIGRATIONS_MODULE_PREFIX: &str = "m_";
68
69impl MigrationGenerator {
70 #[must_use]
71 pub fn new(
72 cargo_toml_path: PathBuf,
73 crate_name: String,
74 options: MigrationGeneratorOptions,
75 ) -> Self {
76 Self {
77 cargo_toml_path,
78 crate_name,
79 options,
80 }
81 }
82
83 fn generate_and_write_migrations(&mut self) -> anyhow::Result<()> {
84 let source_files = self.get_source_files()?;
85
86 if let Some(migration) = self.generate_migrations_to_write(source_files)? {
87 self.write_migration(&migration)?;
88 }
89
90 Ok(())
91 }
92
93 pub fn generate_migrations_to_write(
95 &mut self,
96 source_files: Vec<SourceFile>,
97 ) -> anyhow::Result<Option<MigrationAsSource>> {
98 if let Some(migration) = self.generate_migrations(source_files)? {
99 let migration_name = migration.migration_name.clone();
100 let content = self.generate_migration_file_content(migration);
101 Ok(Some(MigrationAsSource::new(migration_name, content)))
102 } else {
103 Ok(None)
104 }
105 }
106
107 pub fn generate_migrations(
110 &mut self,
111 source_files: Vec<SourceFile>,
112 ) -> anyhow::Result<Option<GeneratedMigration>> {
113 let AppState { models, migrations } = self.process_source_files(source_files)?;
114 let migration_processor = MigrationProcessor::new(migrations)?;
115 let migration_models = migration_processor.latest_models();
116
117 let (modified_models, operations) = self.generate_operations(&models, &migration_models);
118 if operations.is_empty() {
119 Ok(None)
120 } else {
121 let migration_name = migration_processor.next_migration_name()?;
122 let dependencies = migration_processor.base_dependencies();
123
124 let migration =
125 GeneratedMigration::new(migration_name, modified_models, dependencies, operations);
126 Ok(Some(migration))
127 }
128 }
129
130 fn get_source_files(&mut self) -> anyhow::Result<Vec<SourceFile>> {
131 let src_dir = self
132 .cargo_toml_path
133 .parent()
134 .with_context(|| "unable to find parent dir")?
135 .join("src");
136 let src_dir = src_dir
137 .canonicalize()
138 .with_context(|| "unable to canonicalize src dir")?;
139
140 let source_file_paths = Self::find_source_files(&src_dir)?;
141 let source_files = source_file_paths
142 .into_iter()
143 .map(|path| {
144 Self::parse_file(&src_dir, path.clone())
145 .with_context(|| format!("unable to parse file: {path:?}"))
146 })
147 .collect::<anyhow::Result<Vec<_>>>()?;
148 Ok(source_files)
149 }
150
151 fn find_source_files(src_dir: &Path) -> anyhow::Result<Vec<PathBuf>> {
152 let mut paths = Vec::new();
153 for entry in glob::glob(src_dir.join("**/*.rs").to_str().unwrap())
154 .with_context(|| "unable to find Rust source files with glob")?
155 {
156 let path = entry?;
157 paths.push(
158 path.strip_prefix(src_dir)
159 .expect("path must be in src dir")
160 .to_path_buf(),
161 );
162 }
163
164 Ok(paths)
165 }
166
167 fn process_source_files(&self, source_files: Vec<SourceFile>) -> anyhow::Result<AppState> {
168 let mut app_state = AppState::new();
169
170 for source_file in source_files {
171 let path = source_file.path.clone();
172 self.process_parsed_file(source_file, &mut app_state)
173 .with_context(|| format!("unable to find models in file: {path:?}"))?;
174 }
175
176 Ok(app_state)
177 }
178
179 fn parse_file(src_dir: &Path, path: PathBuf) -> anyhow::Result<SourceFile> {
180 let full_path = src_dir.join(&path);
181 debug!("Parsing file: {:?}", &full_path);
182 let mut file = File::open(&full_path).with_context(|| "unable to open file")?;
183
184 let mut src = String::new();
185 file.read_to_string(&mut src)
186 .with_context(|| format!("unable to read file: {full_path:?}"))?;
187
188 SourceFile::parse(path, &src)
189 }
190
191 fn process_parsed_file(
192 &self,
193 SourceFile {
194 path,
195 content: file,
196 }: SourceFile,
197 app_state: &mut AppState,
198 ) -> anyhow::Result<()> {
199 trace!("Processing file: {:?}", &path);
200
201 let symbol_resolver = SymbolResolver::from_file(&file, &path);
202
203 let mut migration_models = Vec::new();
204 for item in file.items {
205 if let syn::Item::Struct(mut item) = item {
206 for attr in &item.attrs.clone() {
207 if is_model_attr(attr) {
208 symbol_resolver.resolve_struct(&mut item);
209
210 let args = Self::args_from_attr(&path, attr)?;
211 let model_in_source =
212 ModelInSource::from_item(item, &args, &symbol_resolver)?;
213
214 match args.model_type {
215 ModelType::Application => {
216 trace!(
217 "Found an Application model: {}",
218 model_in_source.model.name.to_string()
219 );
220 app_state.models.push(model_in_source);
221 }
222 ModelType::Migration => {
223 trace!(
224 "Found a Migration model: {}",
225 model_in_source.model.name.to_string()
226 );
227 migration_models.push(model_in_source);
228 }
229 ModelType::Internal => {}
230 }
231
232 break;
233 }
234 }
235 }
236 }
237
238 if !migration_models.is_empty() {
239 let migration_name = path
240 .file_stem()
241 .with_context(|| format!("unable to get migration file name: {}", path.display()))?
242 .to_string_lossy()
243 .to_string();
244 app_state.migrations.push(Migration {
245 app_name: self.crate_name.clone(),
246 name: migration_name,
247 models: migration_models,
248 });
249 }
250
251 Ok(())
252 }
253
254 fn args_from_attr(path: &Path, attr: &syn::Attribute) -> Result<ModelArgs, ParsingError> {
255 match attr.meta {
256 Meta::Path(_) => {
257 Ok(ModelArgs::default())
259 }
260 _ => ModelArgs::from_meta(&attr.meta).map_err(|e| {
261 ParsingError::from_darling(
262 "couldn't parse model macro arguments",
263 path.to_owned(),
264 &e,
265 )
266 }),
267 }
268 }
269
270 #[must_use]
271 fn generate_operations(
272 &self,
273 app_models: &Vec<ModelInSource>,
274 migration_models: &Vec<ModelInSource>,
275 ) -> (Vec<ModelInSource>, Vec<DynOperation>) {
276 let mut operations = Vec::new();
277 let mut modified_models = Vec::new();
278
279 let mut all_model_names = HashSet::new();
280 let mut app_models_map = HashMap::new();
281 for model in app_models {
282 all_model_names.insert(model.model.table_name.clone());
283 app_models_map.insert(model.model.table_name.clone(), model);
284 }
285 let mut migration_models_map = HashMap::new();
286 for model in migration_models {
287 all_model_names.insert(model.model.table_name.clone());
288 migration_models_map.insert(model.model.table_name.clone(), model);
289 }
290 let mut all_model_names: Vec<_> = all_model_names.into_iter().collect();
291 all_model_names.sort();
292
293 for model_name in all_model_names {
294 let app_model = app_models_map.get(&model_name);
295 let migration_model = migration_models_map.get(&model_name);
296
297 match (app_model, migration_model) {
298 (Some(&app_model), None) => {
299 operations.push(Self::make_create_model_operation(app_model));
300 modified_models.push(app_model.clone());
301 }
302 (Some(&app_model), Some(&migration_model)) => {
303 if app_model.model != migration_model.model {
304 modified_models.push(app_model.clone());
305 operations
306 .extend(self.make_alter_model_operations(app_model, migration_model));
307 }
308 }
309 (None, Some(&migration_model)) => {
310 operations.push(self.make_remove_model_operation(migration_model));
311 }
312 (None, None) => unreachable!(),
313 }
314 }
315
316 (modified_models, operations)
317 }
318
319 #[must_use]
320 fn make_create_model_operation(app_model: &ModelInSource) -> DynOperation {
321 DynOperation::CreateModel {
322 table_name: app_model.model.table_name.clone(),
323 model_ty: app_model.model.resolved_ty.clone(),
324 fields: app_model.model.fields.clone(),
325 }
326 }
327
328 #[must_use]
329 fn make_alter_model_operations(
330 &self,
331 app_model: &ModelInSource,
332 migration_model: &ModelInSource,
333 ) -> Vec<DynOperation> {
334 let mut all_field_names = HashSet::new();
335 let mut app_model_fields = HashMap::new();
336 for field in &app_model.model.fields {
337 all_field_names.insert(field.column_name.clone());
338 app_model_fields.insert(field.column_name.clone(), field);
339 }
340 let mut migration_model_fields = HashMap::new();
341 for field in &migration_model.model.fields {
342 all_field_names.insert(field.column_name.clone());
343 migration_model_fields.insert(field.column_name.clone(), field);
344 }
345
346 let mut all_field_names: Vec<_> = all_field_names.into_iter().collect();
347 all_field_names.sort();
349
350 let mut operations = Vec::new();
351 for field_name in all_field_names {
352 let app_field = app_model_fields.get(&field_name);
353 let migration_field = migration_model_fields.get(&field_name);
354
355 match (app_field, migration_field) {
356 (Some(app_field), None) => {
357 operations.push(Self::make_add_field_operation(app_model, app_field));
358 }
359 (Some(app_field), Some(migration_field)) => {
360 let operation = self.make_alter_field_operation(
361 app_model,
362 app_field,
363 migration_model,
364 migration_field,
365 );
366 if let Some(operation) = operation {
367 operations.push(operation);
368 }
369 }
370 (None, Some(migration_field)) => {
371 operations
372 .push(self.make_remove_field_operation(migration_model, migration_field));
373 }
374 (None, None) => unreachable!(),
375 }
376 }
377
378 operations
379 }
380
381 #[must_use]
382 fn make_add_field_operation(app_model: &ModelInSource, field: &Field) -> DynOperation {
383 DynOperation::AddField {
384 table_name: app_model.model.table_name.clone(),
385 model_ty: app_model.model.resolved_ty.clone(),
386 field: field.clone(),
387 }
388 }
389
390 #[must_use]
391 fn make_alter_field_operation(
392 &self,
393 _app_model: &ModelInSource,
394 app_field: &Field,
395 _migration_model: &ModelInSource,
396 migration_field: &Field,
397 ) -> Option<DynOperation> {
398 if app_field == migration_field {
399 return None;
400 }
401 todo!()
402 }
403
404 #[must_use]
405 fn make_remove_field_operation(
406 &self,
407 _migration_model: &ModelInSource,
408 _migration_field: &Field,
409 ) -> DynOperation {
410 todo!()
411 }
412
413 #[must_use]
414 fn make_remove_model_operation(&self, _migration_model: &ModelInSource) -> DynOperation {
415 todo!()
416 }
417
418 fn generate_migration_file_content(&self, migration: GeneratedMigration) -> String {
419 let operations: Vec<_> = migration
420 .operations
421 .into_iter()
422 .map(|operation| operation.repr())
423 .collect();
424 let dependencies: Vec<_> = migration
425 .dependencies
426 .into_iter()
427 .map(|dependency| dependency.repr())
428 .collect();
429
430 let app_name = self.options.app_name.as_ref().unwrap_or(&self.crate_name);
431 let migration_name = &migration.migration_name;
432 let migration_def = quote! {
433 #[derive(Debug, Copy, Clone)]
434 pub(super) struct Migration;
435
436 impl ::cot::db::migrations::Migration for Migration {
437 const APP_NAME: &'static str = #app_name;
438 const MIGRATION_NAME: &'static str = #migration_name;
439 const DEPENDENCIES: &'static [::cot::db::migrations::MigrationDependency] = &[
440 #(#dependencies,)*
441 ];
442 const OPERATIONS: &'static [::cot::db::migrations::Operation] = &[
443 #(#operations,)*
444 ];
445 }
446 };
447
448 let models = migration
449 .modified_models
450 .iter()
451 .map(Self::model_to_migration_model)
452 .collect::<Vec<_>>();
453 let models_def = quote! {
454 #(#models)*
455 };
456
457 Self::generate_migration(migration_def, models_def)
458 }
459
460 fn write_migration(&self, migration: &MigrationAsSource) -> anyhow::Result<()> {
461 let src_path = self.get_src_path();
462 let migration_path = src_path.join(MIGRATIONS_MODULE_NAME);
463 let migration_file = migration_path.join(format!("{}.rs", migration.name));
464
465 std::fs::create_dir_all(&migration_path).with_context(|| {
466 format!(
467 "unable to create migrations directory: {}",
468 migration_path.display()
469 )
470 })?;
471
472 let mut file = File::create(&migration_file).with_context(|| {
473 format!(
474 "unable to create migration file: {}",
475 migration_file.display()
476 )
477 })?;
478 file.write_all(migration.content.as_bytes())
479 .with_context(|| "unable to write migration file")?;
480 info!("Generated migration: {}", migration_file.display());
481 Ok(())
482 }
483
484 #[must_use]
485 fn generate_migration(migration: TokenStream, modified_models: TokenStream) -> String {
486 let migration = Self::format_tokens(migration);
487 let modified_models = Self::format_tokens(modified_models);
488
489 let header = Self::migration_header();
490
491 format!("{header}\n\n{migration}\n{modified_models}")
492 }
493
494 fn migration_header() -> String {
495 let version = env!("CARGO_PKG_VERSION");
496 let date_time = chrono::offset::Utc::now().format("%Y-%m-%d %H:%M:%S%:z");
497 let header = format!("//! Generated by cot CLI {version} on {date_time}");
498 header
499 }
500
501 #[must_use]
502 fn format_tokens(tokens: TokenStream) -> String {
503 let parsed: syn::File = syn::parse2(tokens).unwrap();
504 prettyplease::unparse(&parsed)
505 }
506
507 #[must_use]
508 fn model_to_migration_model(model: &ModelInSource) -> TokenStream {
509 let mut model_source = model.model_item.clone();
510 model_source.vis = syn::Visibility::Inherited;
511 model_source.ident = format_ident!("_{}", model_source.ident);
512 model_source.attrs.clear();
513 model_source
514 .attrs
515 .push(syn::parse_quote! {#[derive(::core::fmt::Debug)]});
516 model_source
517 .attrs
518 .push(syn::parse_quote! {#[::cot::db::model(model_type = "migration")]});
519 quote! {
520 #model_source
521 }
522 }
523
524 pub fn write_migrations_module(&self) -> anyhow::Result<()> {
525 let src_path = self.get_src_path();
526 let migrations_dir = src_path.join(MIGRATIONS_MODULE_NAME);
527
528 let migration_list = Self::get_migration_list(&migrations_dir)?;
529 let contents = Self::get_migration_module_contents(&migration_list);
530 let contents_string = Self::format_tokens(contents);
531
532 let header = Self::migration_header();
533 let migration_header = "//! List of migrations for the current app.\n//!";
534 let contents_with_header = format!("{migration_header}\n{header}\n\n{contents_string}");
535
536 let mut file = File::create(src_path.join(format!("{MIGRATIONS_MODULE_NAME}.rs")))?;
537 file.write_all(contents_with_header.as_bytes())?;
538
539 Ok(())
540 }
541
542 fn get_migration_list(migrations_dir: &PathBuf) -> anyhow::Result<Vec<String>> {
543 Ok(std::fs::read_dir(migrations_dir)
544 .with_context(|| {
545 format!(
546 "unable to read migrations directory: {}",
547 migrations_dir.display()
548 )
549 })?
550 .filter_map(|entry| {
551 let entry = entry.ok()?;
552 let path = entry.path();
553 let stem = path.file_stem();
554
555 if path.is_file()
556 && stem
557 .unwrap_or_default()
558 .to_string_lossy()
559 .starts_with(MIGRATIONS_MODULE_PREFIX)
560 && path.extension() == Some("rs".as_ref())
561 {
562 stem.map(|stem| stem.to_string_lossy().to_string())
563 } else {
564 None
565 }
566 })
567 .collect())
568 }
569
570 #[must_use]
571 fn get_migration_module_contents(migration_list: &[String]) -> TokenStream {
572 let migration_mods = migration_list.iter().map(|migration| {
573 let migration = format_ident!("{}", migration);
574 quote! {
575 pub mod #migration;
576 }
577 });
578 let migration_refs = migration_list.iter().map(|migration| {
579 let migration = format_ident!("{}", migration);
580 quote! {
581 &#migration::Migration
582 }
583 });
584
585 quote! {
586 #(#migration_mods)*
587
588 pub const MIGRATIONS: &[&::cot::db::migrations::SyncDynMigration] = &[
590 #(#migration_refs),*
591 ];
592 }
593 }
594
595 fn get_src_path(&self) -> PathBuf {
596 self.options
597 .output_dir
598 .clone()
599 .unwrap_or(self.cargo_toml_path.parent().unwrap().join("src"))
600 }
601}
602
603#[derive(Debug, Clone)]
604pub struct SourceFile {
605 path: PathBuf,
606 content: syn::File,
607}
608
609impl SourceFile {
610 #[must_use]
611 fn new(path: PathBuf, content: syn::File) -> Self {
612 assert!(
613 path.is_relative(),
614 "path must be relative to the src directory"
615 );
616 Self { path, content }
617 }
618
619 pub fn parse(path: PathBuf, content: &str) -> anyhow::Result<Self> {
620 Ok(Self::new(
621 path,
622 syn::parse_file(content).with_context(|| "unable to parse file")?,
623 ))
624 }
625}
626
627#[derive(Debug, Clone)]
628struct AppState {
629 models: Vec<ModelInSource>,
631 migrations: Vec<Migration>,
633}
634
635impl AppState {
636 #[must_use]
637 fn new() -> Self {
638 Self {
639 models: Vec::new(),
640 migrations: Vec::new(),
641 }
642 }
643}
644
645#[derive(Debug, Clone)]
647struct MigrationProcessor {
648 migrations: Vec<Migration>,
649}
650
651impl MigrationProcessor {
652 fn new(mut migrations: Vec<Migration>) -> anyhow::Result<Self> {
653 MigrationEngine::sort_migrations(&mut migrations)?;
654 Ok(Self { migrations })
655 }
656
657 #[must_use]
666 fn latest_models(&self) -> Vec<ModelInSource> {
667 let mut migration_models: HashMap<String, &ModelInSource> = HashMap::new();
668 for migration in &self.migrations {
669 for model in &migration.models {
670 migration_models.insert(model.model.table_name.clone(), model);
671 }
672 }
673
674 migration_models.into_values().cloned().collect()
675 }
676
677 fn next_migration_name(&self) -> anyhow::Result<String> {
678 if self.migrations.is_empty() {
679 return Ok(format!("{MIGRATIONS_MODULE_PREFIX}0001_initial"));
680 }
681
682 let last_migration = self.migrations.last().unwrap();
683 let last_migration_number = last_migration
684 .name
685 .split('_')
686 .nth(1)
687 .with_context(|| format!("migration number not found: {}", last_migration.name))?
688 .parse::<u32>()
689 .with_context(|| {
690 format!("unable to parse migration number: {}", last_migration.name)
691 })?;
692
693 let migration_number = last_migration_number + 1;
694 let now = chrono::Utc::now();
695 let date_time = now.format("%Y%m%d_%H%M%S");
696
697 Ok(format!(
698 "{MIGRATIONS_MODULE_PREFIX}{migration_number:04}_auto_{date_time}"
699 ))
700 }
701
702 fn base_dependencies(&self) -> Vec<DynDependency> {
705 if self.migrations.is_empty() {
706 return Vec::new();
707 }
708
709 let last_migration = self.migrations.last().unwrap();
710 vec![DynDependency::Migration {
711 app: last_migration.app_name.clone(),
712 migration: last_migration.name.clone(),
713 }]
714 }
715}
716
717#[derive(Debug, Clone, PartialEq, Eq, Hash)]
718pub struct ModelInSource {
719 model_item: syn::ItemStruct,
720 model: Model,
721}
722
723impl ModelInSource {
724 fn from_item(
725 item: syn::ItemStruct,
726 args: &ModelArgs,
727 symbol_resolver: &SymbolResolver,
728 ) -> anyhow::Result<Self> {
729 let input: syn::DeriveInput = item.clone().into();
730 let opts = ModelOpts::new_from_derive_input(&input)
731 .map_err(|e| anyhow::anyhow!("cannot parse model: {}", e))?;
732 let model = opts.as_model(args, symbol_resolver)?;
733
734 Ok(Self {
735 model_item: item,
736 model,
737 })
738 }
739}
740
741#[derive(Debug, Clone)]
744pub struct GeneratedMigration {
745 pub migration_name: String,
746 pub modified_models: Vec<ModelInSource>,
747 pub dependencies: Vec<DynDependency>,
748 pub operations: Vec<DynOperation>,
749}
750
751impl GeneratedMigration {
752 #[must_use]
753 fn new(
754 migration_name: String,
755 modified_models: Vec<ModelInSource>,
756 mut dependencies: Vec<DynDependency>,
757 mut operations: Vec<DynOperation>,
758 ) -> Self {
759 Self::remove_cycles(&mut operations);
760 Self::toposort_operations(&mut operations);
761 dependencies.extend(Self::get_foreign_key_dependencies(&operations));
762
763 Self {
764 migration_name,
765 modified_models,
766 dependencies,
767 operations,
768 }
769 }
770
771 fn get_foreign_key_dependencies(operations: &[DynOperation]) -> Vec<DynDependency> {
774 let create_ops = Self::get_create_ops_map(operations);
775 let ops_adding_foreign_keys = Self::get_ops_adding_foreign_keys(operations);
776
777 let mut dependencies = Vec::new();
778 for (_index, dependency_ty) in &ops_adding_foreign_keys {
779 if !create_ops.contains_key(dependency_ty) {
780 dependencies.push(DynDependency::Model {
781 model_type: dependency_ty.clone(),
782 });
783 }
784 }
785
786 dependencies
787 }
788
789 fn remove_cycles(operations: &mut Vec<DynOperation>) {
800 let graph = Self::construct_dependency_graph(operations);
801
802 let cycle_edges = petgraph::algo::feedback_arc_set::greedy_feedback_arc_set(&graph);
803 for edge_id in cycle_edges {
804 let (from, to) = graph
805 .edge_endpoints(edge_id.id())
806 .expect("greedy_feedback_arc_set should always return valid edge refs");
807
808 let to_op = operations[to.index()].clone();
809 let from_op = &mut operations[from.index()];
810 debug!(
811 "Removing cycle by removing operation {:?} that depends on {:?}",
812 from_op, to_op
813 );
814
815 let to_add = Self::remove_dependency(from_op, &to_op);
816 operations.extend(to_add);
817 }
818 }
819
820 #[must_use]
826 fn remove_dependency(from: &mut DynOperation, to: &DynOperation) -> Vec<DynOperation> {
827 match from {
828 DynOperation::CreateModel {
829 table_name,
830 model_ty,
831 fields,
832 } => {
833 let to_type = match to {
834 DynOperation::CreateModel { model_ty, .. } => model_ty,
835 DynOperation::AddField { .. } => {
836 unreachable!(
837 "AddField operation shouldn't be a dependency of CreateModel \
838 because it doesn't create a new model"
839 )
840 }
841 };
842 trace!(
843 "Removing foreign keys from {} to {}",
844 model_ty.to_token_stream().to_string(),
845 to_type.into_token_stream().to_string()
846 );
847
848 let mut result = Vec::new();
849 let (fields_to_remove, fields_to_retain): (Vec<_>, Vec<_>) = std::mem::take(fields)
850 .into_iter()
851 .partition(|field| is_field_foreign_key_to(field, to_type));
852 *fields = fields_to_retain;
853
854 for field in fields_to_remove {
855 result.push(DynOperation::AddField {
856 table_name: table_name.clone(),
857 model_ty: model_ty.clone(),
858 field,
859 });
860 }
861
862 result
863 }
864 DynOperation::AddField { .. } => {
865 unreachable!("AddField operation should never create cycles")
868 }
869 }
870 }
871
872 fn toposort_operations(operations: &mut [DynOperation]) {
885 let graph = Self::construct_dependency_graph(operations);
886
887 let sorted = petgraph::algo::toposort(&graph, None)
888 .expect("cycles shouldn't exist after removing them");
889 let mut sorted = sorted
890 .into_iter()
891 .map(petgraph::graph::NodeIndex::index)
892 .collect::<Vec<_>>();
893 cot::__private::apply_permutation(operations, &mut sorted);
894 }
895
896 #[must_use]
903 fn construct_dependency_graph(operations: &[DynOperation]) -> DiGraph<usize, (), usize> {
904 let create_ops = Self::get_create_ops_map(operations);
905 let ops_adding_foreign_keys = Self::get_ops_adding_foreign_keys(operations);
906
907 let mut graph = DiGraph::with_capacity(operations.len(), 0);
908
909 for i in 0..operations.len() {
910 graph.add_node(i);
911 }
912 for (i, dependency_ty) in &ops_adding_foreign_keys {
913 if let Some(&dependency) = create_ops.get(dependency_ty) {
914 graph.add_edge(
915 petgraph::graph::NodeIndex::new(dependency),
916 petgraph::graph::NodeIndex::new(*i),
917 (),
918 );
919 }
920 }
921
922 graph
923 }
924
925 #[must_use]
928 fn get_create_ops_map(operations: &[DynOperation]) -> HashMap<syn::Type, usize> {
929 #[allow(clippy::match_wildcard_for_single_variants)] operations
931 .iter()
932 .enumerate()
933 .filter_map(|(i, op)| match op {
934 DynOperation::CreateModel { model_ty, .. } => Some((model_ty.clone(), i)),
935 _ => None,
936 })
937 .collect()
938 }
939
940 #[must_use]
943 fn get_ops_adding_foreign_keys(operations: &[DynOperation]) -> Vec<(usize, syn::Type)> {
944 operations
945 .iter()
946 .enumerate()
947 .flat_map(|(i, op)| match op {
948 DynOperation::CreateModel { fields, .. } => fields
949 .iter()
950 .filter_map(foreign_key_for_field)
951 .map(|to_model| (i, to_model))
952 .collect::<Vec<(usize, syn::Type)>>(),
953 DynOperation::AddField {
954 field, model_ty, ..
955 } => {
956 let mut ops = vec![(i, model_ty.clone())];
957
958 if let Some(to_type) = foreign_key_for_field(field) {
959 ops.push((i, to_type));
960 }
961
962 ops
963 }
964 })
965 .collect()
966 }
967}
968
/// A generated migration rendered as Rust source code.
#[derive(Debug, Clone)]
pub struct MigrationAsSource {
    /// Migration name; the file is written as `<name>.rs`.
    pub name: String,
    /// Complete contents of the migration file.
    pub content: String,
}

impl MigrationAsSource {
    /// Bundles a migration name with its rendered source code.
    #[must_use]
    pub(crate) fn new(name: String, content: String) -> Self {
        MigrationAsSource { content, name }
    }
}
982
983#[must_use]
984fn is_model_attr(attr: &syn::Attribute) -> bool {
985 let path = attr.path();
986
987 let model_path: syn::Path = parse_quote!(cot::db::model);
988 let model_path_prefixed: syn::Path = parse_quote!(::cot::db::model);
989
990 attr.style == syn::AttrStyle::Outer
991 && (path.is_ident("model") || path == &model_path || path == &model_path_prefixed)
992}
993
/// Renders a value as Rust code (tokens) that builds the corresponding
/// `cot` runtime value inside a generated migration file.
trait Repr {
    /// Returns the tokens of an expression constructing this value.
    fn repr(&self) -> TokenStream;
}
997
998impl Repr for Field {
999 fn repr(&self) -> TokenStream {
1000 let column_name = &self.column_name;
1001 let ty = &self.ty;
1002 let mut tokens = quote! {
1003 ::cot::db::migrations::Field::new(::cot::db::Identifier::new(#column_name), <#ty as ::cot::db::DatabaseField>::TYPE)
1004 };
1005 if self.auto_value {
1006 tokens = quote! { #tokens.auto() }
1007 }
1008 if self.primary_key {
1009 tokens = quote! { #tokens.primary_key() }
1010 }
1011 if let Some(fk_spec) = self.foreign_key.clone() {
1012 let to_model = &fk_spec.to_model;
1013
1014 tokens = quote! {
1015 #tokens.foreign_key(
1016 <#to_model as ::cot::db::Model>::TABLE_NAME,
1017 <#to_model as ::cot::db::Model>::PRIMARY_KEY_NAME,
1018 ::cot::db::ForeignKeyOnDeletePolicy::Restrict,
1019 ::cot::db::ForeignKeyOnUpdatePolicy::Restrict,
1020 )
1021 }
1022 }
1023 tokens = quote! { #tokens.set_null(<#ty as ::cot::db::DatabaseField>::NULLABLE) };
1024 if self.unique {
1025 tokens = quote! { #tokens.unique() }
1026 }
1027 tokens
1028 }
1029}
1030
1031#[derive(Debug, Clone, PartialEq, Eq)]
1032struct Migration {
1033 app_name: String,
1034 name: String,
1035 models: Vec<ModelInSource>,
1036}
1037
1038impl DynMigration for Migration {
1039 fn app_name(&self) -> &str {
1040 &self.app_name
1041 }
1042
1043 fn name(&self) -> &str {
1044 &self.name
1045 }
1046
1047 fn dependencies(&self) -> &[cot::db::migrations::MigrationDependency] {
1048 &[]
1049 }
1050
1051 fn operations(&self) -> &[cot::db::migrations::Operation] {
1052 &[]
1053 }
1054}
1055
1056#[derive(Debug, Clone, PartialEq, Eq, Hash)]
1061pub enum DynDependency {
1062 Migration { app: String, migration: String },
1063 Model { model_type: syn::Type },
1064}
1065
1066impl Repr for DynDependency {
1067 fn repr(&self) -> TokenStream {
1068 match self {
1069 Self::Migration { app, migration } => {
1070 quote! {
1071 ::cot::db::migrations::MigrationDependency::migration(#app, #migration)
1072 }
1073 }
1074 Self::Model { model_type } => {
1075 quote! {
1076 ::cot::db::migrations::MigrationDependency::model(
1077 <#model_type as ::cot::db::Model>::APP_NAME,
1078 <#model_type as ::cot::db::Model>::TABLE_NAME
1079 )
1080 }
1081 }
1082 }
1083 }
1084}
1085
/// A schema operation of a generated migration, before it is rendered to code.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum DynOperation {
    /// Creates the table of a model with the given fields.
    CreateModel {
        table_name: String,
        /// Resolved Rust type of the model; foreign keys are matched against it.
        model_ty: syn::Type,
        fields: Vec<Field>,
    },
    /// Adds a single field to a model's table.
    AddField {
        table_name: String,
        model_ty: syn::Type,
        field: Field,
    },
}
1103
1104fn is_field_foreign_key_to(field: &Field, ty: &syn::Type) -> bool {
1106 foreign_key_for_field(field).is_some_and(|to_model| &to_model == ty)
1107}
1108
1109fn foreign_key_for_field(field: &Field) -> Option<syn::Type> {
1112 match field.foreign_key.clone() {
1113 None => None,
1114 Some(foreign_key_spec) => Some(foreign_key_spec.to_model),
1115 }
1116}
1117
1118impl Repr for DynOperation {
1119 fn repr(&self) -> TokenStream {
1120 match self {
1121 Self::CreateModel {
1122 table_name, fields, ..
1123 } => {
1124 let fields = fields.iter().map(Repr::repr).collect::<Vec<_>>();
1125 quote! {
1126 ::cot::db::migrations::Operation::create_model()
1127 .table_name(::cot::db::Identifier::new(#table_name))
1128 .fields(&[
1129 #(#fields,)*
1130 ])
1131 .build()
1132 }
1133 }
1134 Self::AddField {
1135 table_name, field, ..
1136 } => {
1137 let field = field.repr();
1138 quote! {
1139 ::cot::db::migrations::Operation::add_field()
1140 .table_name(::cot::db::Identifier::new(#table_name))
1141 .field(#field)
1142 .build()
1143 }
1144 }
1145 }
1146 }
1147}
1148
1149#[derive(Debug)]
1150struct ParsingError {
1151 message: String,
1152 path: PathBuf,
1153 location: String,
1154 source: Option<String>,
1155}
1156
1157impl ParsingError {
1158 fn from_darling(message: &str, path: PathBuf, error: &darling::Error) -> Self {
1159 let message = format!("{message}: {error}");
1160 let span = error.span();
1161 let location = format!("{}:{}", span.start().line, span.start().column);
1162
1163 Self {
1164 message,
1165 path,
1166 location,
1167 source: span.source_text().clone(),
1168 }
1169 }
1170}
1171
1172impl Display for ParsingError {
1173 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1174 write!(f, "{}", self.message)?;
1175 if let Some(source) = &self.source {
1176 write!(f, "\n{source}")?;
1177 }
1178 write!(f, "\n at {}:{}", self.path.display(), self.location)?;
1179 Ok(())
1180 }
1181}
1182
1183impl Error for ParsingError {}
1184
#[cfg(test)]
mod tests {
    use cot_codegen::model::ForeignKeySpec;

    use super::*;

    #[test]
    fn migration_processor_next_migration_name_empty() {
        let migrations = vec![];
        let processor = MigrationProcessor::new(migrations).unwrap();

        let next_migration_name = processor.next_migration_name().unwrap();
        assert_eq!(next_migration_name, "m_0001_initial");
    }

    #[test]
    fn migration_processor_dependencies_empty() {
        let migrations = vec![];
        let processor = MigrationProcessor::new(migrations).unwrap();

        let dependencies = processor.base_dependencies();
        assert_eq!(dependencies, vec![]);
    }

    #[test]
    fn migration_processor_dependencies_previous() {
        let migrations = vec![Migration {
            app_name: "app1".to_string(),
            name: "m0001_initial".to_string(),
            models: vec![],
        }];
        let processor = MigrationProcessor::new(migrations).unwrap();

        // The only existing migration must become the dependency of the next one.
        let dependencies = processor.base_dependencies();
        assert_eq!(
            dependencies,
            vec![DynDependency::Migration {
                app: "app1".to_string(),
                migration: "m0001_initial".to_string(),
            }]
        );
    }

    #[test]
    fn toposort_operations() {
        // The `AddField` references `Table1` through a foreign key, so it has to
        // be moved after the `CreateModel` for `Table1` even though it is listed first.
        let mut operations = vec![
            DynOperation::AddField {
                table_name: "table2".to_string(),
                model_ty: parse_quote!(Table2),
                field: Field {
                    field_name: format_ident!("field1"),
                    column_name: "field1".to_string(),
                    ty: parse_quote!(i32),
                    auto_value: false,
                    primary_key: false,
                    unique: false,
                    foreign_key: Some(ForeignKeySpec {
                        to_model: parse_quote!(Table1),
                    }),
                },
            },
            DynOperation::CreateModel {
                table_name: "table1".to_string(),
                model_ty: parse_quote!(Table1),
                fields: vec![],
            },
        ];

        GeneratedMigration::toposort_operations(&mut operations);

        assert_eq!(operations.len(), 2);
        if let DynOperation::CreateModel { table_name, .. } = &operations[0] {
            assert_eq!(table_name, "table1");
        } else {
            panic!("Expected CreateModel operation");
        }
        if let DynOperation::AddField { table_name, .. } = &operations[1] {
            assert_eq!(table_name, "table2");
        } else {
            panic!("Expected AddField operation");
        }
    }

    #[test]
    fn remove_cycles() {
        // `Table1` and `Table2` reference each other, forming a cycle. The cycle
        // is expected to be broken by moving one foreign key into a separate
        // `AddField` operation appended at the end.
        let mut operations = vec![
            DynOperation::CreateModel {
                table_name: "table1".to_string(),
                model_ty: parse_quote!(Table1),
                fields: vec![Field {
                    field_name: format_ident!("field1"),
                    column_name: "field1".to_string(),
                    ty: parse_quote!(ForeignKey<Table2>),
                    auto_value: false,
                    primary_key: false,
                    unique: false,
                    foreign_key: Some(ForeignKeySpec {
                        to_model: parse_quote!(Table2),
                    }),
                }],
            },
            DynOperation::CreateModel {
                table_name: "table2".to_string(),
                model_ty: parse_quote!(Table2),
                fields: vec![Field {
                    field_name: format_ident!("field1"),
                    column_name: "field1".to_string(),
                    ty: parse_quote!(ForeignKey<Table1>),
                    auto_value: false,
                    primary_key: false,
                    unique: false,
                    foreign_key: Some(ForeignKeySpec {
                        to_model: parse_quote!(Table1),
                    }),
                }],
            },
        ];

        GeneratedMigration::remove_cycles(&mut operations);

        assert_eq!(operations.len(), 3);
        if let DynOperation::CreateModel {
            table_name, fields, ..
        } = &operations[0]
        {
            assert_eq!(table_name, "table1");
            assert!(!fields.is_empty());
        } else {
            panic!("Expected CreateModel operation");
        }
        if let DynOperation::CreateModel {
            table_name, fields, ..
        } = &operations[1]
        {
            assert_eq!(table_name, "table2");
            assert!(fields.is_empty());
        } else {
            panic!("Expected CreateModel operation");
        }
        if let DynOperation::AddField { table_name, .. } = &operations[2] {
            assert_eq!(table_name, "table2");
        } else {
            panic!("Expected AddField operation");
        }
    }

    #[test]
    fn remove_dependency() {
        let mut create_model_op = DynOperation::CreateModel {
            table_name: "table1".to_string(),
            model_ty: parse_quote!(Table1),
            fields: vec![Field {
                field_name: format_ident!("field1"),
                column_name: "field1".to_string(),
                ty: parse_quote!(ForeignKey<Table2>),
                auto_value: false,
                primary_key: false,
                unique: false,
                foreign_key: Some(ForeignKeySpec {
                    to_model: parse_quote!(Table2),
                }),
            }],
        };

        // The operation `create_model_op` depends on (through its foreign key).
        let dependency_op = DynOperation::CreateModel {
            table_name: "table2".to_string(),
            model_ty: parse_quote!(Table2),
            fields: vec![],
        };

        let additional_ops =
            GeneratedMigration::remove_dependency(&mut create_model_op, &dependency_op);

        // The foreign key field is stripped from the original operation...
        match create_model_op {
            DynOperation::CreateModel { fields, .. } => {
                assert_eq!(fields.len(), 0);
            }
            _ => {
                panic!("Expected from operation not to change type");
            }
        }
        // ...and re-added to the same table through a separate `AddField`.
        assert_eq!(additional_ops.len(), 1);
        if let DynOperation::AddField { table_name, .. } = &additional_ops[0] {
            assert_eq!(table_name, "table1");
        } else {
            panic!("Expected AddField operation");
        }
    }

    #[test]
    fn get_foreign_key_dependencies_no_foreign_keys() {
        let operations = vec![DynOperation::CreateModel {
            table_name: "table1".to_string(),
            model_ty: parse_quote!(Table1),
            fields: vec![],
        }];

        let external_dependencies = GeneratedMigration::get_foreign_key_dependencies(&operations);
        assert!(external_dependencies.is_empty());
    }

    #[test]
    fn get_foreign_key_dependencies_with_foreign_keys() {
        let operations = vec![DynOperation::CreateModel {
            table_name: "table1".to_string(),
            model_ty: parse_quote!(Table1),
            fields: vec![Field {
                field_name: format_ident!("field1"),
                column_name: "field1".to_string(),
                ty: parse_quote!(ForeignKey<Table2>),
                auto_value: false,
                primary_key: false,
                unique: false,
                foreign_key: Some(ForeignKeySpec {
                    to_model: parse_quote!(crate::Table2),
                }),
            }],
        }];

        let external_dependencies = GeneratedMigration::get_foreign_key_dependencies(&operations);
        assert_eq!(external_dependencies.len(), 1);
        assert_eq!(
            external_dependencies[0],
            DynDependency::Model {
                model_type: parse_quote!(crate::Table2),
            }
        );
    }

    #[test]
    fn get_foreign_key_dependencies_with_multiple_foreign_keys() {
        let operations = vec![
            DynOperation::CreateModel {
                table_name: "table1".to_string(),
                model_ty: parse_quote!(Table1),
                fields: vec![Field {
                    field_name: format_ident!("field1"),
                    column_name: "field1".to_string(),
                    ty: parse_quote!(ForeignKey<Table2>),
                    auto_value: false,
                    primary_key: false,
                    unique: false,
                    foreign_key: Some(ForeignKeySpec {
                        to_model: parse_quote!(my_crate::Table2),
                    }),
                }],
            },
            DynOperation::CreateModel {
                table_name: "table3".to_string(),
                model_ty: parse_quote!(Table3),
                fields: vec![Field {
                    field_name: format_ident!("field2"),
                    column_name: "field2".to_string(),
                    ty: parse_quote!(ForeignKey<Table4>),
                    auto_value: false,
                    primary_key: false,
                    unique: false,
                    foreign_key: Some(ForeignKeySpec {
                        to_model: parse_quote!(crate::Table4),
                    }),
                }],
            },
        ];

        let external_dependencies = GeneratedMigration::get_foreign_key_dependencies(&operations);
        assert_eq!(external_dependencies.len(), 2);
        assert!(external_dependencies.contains(&DynDependency::Model {
            model_type: parse_quote!(my_crate::Table2),
        }));
        assert!(external_dependencies.contains(&DynDependency::Model {
            model_type: parse_quote!(crate::Table4),
        }));
    }

    #[test]
    fn make_add_field_operation() {
        let app_model = ModelInSource {
            model_item: parse_quote! {
                struct TestModel {
                    #[model(primary_key)]
                    id: i32,
                    field1: i32,
                }
            },
            model: Model {
                name: format_ident!("TestModel"),
                vis: syn::Visibility::Inherited,
                original_name: "TestModel".to_string(),
                resolved_ty: parse_quote!(TestModel),
                model_type: Default::default(),
                table_name: "test_model".to_string(),
                pk_field: Field {
                    field_name: format_ident!("id"),
                    column_name: "id".to_string(),
                    ty: parse_quote!(i32),
                    auto_value: true,
                    primary_key: true,
                    unique: false,
                    foreign_key: None,
                },
                fields: vec![],
            },
        };

        let field = Field {
            field_name: format_ident!("new_field"),
            column_name: "new_field".to_string(),
            ty: parse_quote!(i32),
            auto_value: false,
            primary_key: false,
            unique: false,
            foreign_key: None,
        };

        let operation = MigrationGenerator::make_add_field_operation(&app_model, &field);

        match operation {
            DynOperation::AddField {
                table_name,
                model_ty,
                field: op_field,
            } => {
                assert_eq!(table_name, "test_model");
                assert_eq!(model_ty, parse_quote!(TestModel));
                assert_eq!(op_field.column_name, "new_field");
                assert_eq!(op_field.ty, parse_quote!(i32));
            }
            _ => panic!("Expected AddField operation"),
        }
    }

    #[test]
    fn get_migration_list() {
        let tempdir = tempfile::tempdir().unwrap();
        let migrations_dir = tempdir.path().join("migrations");
        std::fs::create_dir(&migrations_dir).unwrap();

        // Only the first two files are expected to be picked up: `dummy.rs`
        // and the `.txt` file must be ignored.
        File::create(migrations_dir.join("m_0001_initial.rs")).unwrap();
        File::create(migrations_dir.join("m_0002_auto.rs")).unwrap();
        File::create(migrations_dir.join("dummy.rs")).unwrap();
        File::create(migrations_dir.join("m_0003_not_rust_file.txt")).unwrap();

        let migration_list = MigrationGenerator::get_migration_list(&migrations_dir).unwrap();
        assert_eq!(
            migration_list.len(),
            2,
            "Migration list: {migration_list:?}"
        );
        assert!(migration_list.contains(&"m_0001_initial".to_string()));
        assert!(migration_list.contains(&"m_0002_auto".to_string()));
    }

    #[test]
    fn get_migration_module_contents() {
        let contents = MigrationGenerator::get_migration_module_contents(&[
            "m_0001_initial".to_string(),
            "m_0002_auto".to_string(),
        ]);

        let expected = quote! {
            pub mod m_0001_initial;
            pub mod m_0002_auto;

            pub const MIGRATIONS: &[&::cot::db::migrations::SyncDynMigration] = &[
                &m_0001_initial::Migration,
                &m_0002_auto::Migration
            ];
        };

        // Token streams are compared via their string form, so formatting
        // differences in the generated code do not matter.
        assert_eq!(contents.to_string(), expected.to_string());
    }
}