1use std::collections::BTreeMap;
2use std::fmt::Formatter;
3
4use std::path::Path;
5use anyhow::Result;
6use sqlmo::{Schema, Table, schema::Column};
7use ormlitex_attr::{ColumnMetadata, Ident, InnerType, ModelMetadata, TType};
8use ormlitex_attr::{schema_from_filepaths};
9
/// Settings that tune how a project schema is assembled.
#[derive(Debug)]
pub struct Options {
    /// Request more detailed output.
    /// NOTE(review): not read anywhere in this file — confirm which caller consumes it.
    pub verbose: bool,
}
14
15pub trait TryFromormlitex: Sized {
16 fn try_from_ormlitex_project(path: &[&Path]) -> Result<Self>;
17}
18
19trait SqlDiffTableExt {
20 fn from_metadata(metadata: &ModelMetadata) -> Result<Self, TypeTranslationError> where Self: Sized;
21}
22
23impl SqlDiffTableExt for Table {
24 fn from_metadata(model: &ModelMetadata) -> Result<Self, TypeTranslationError> {
25 Ok(Self {
26 schema: None,
27 name: model.inner.table_name.clone(),
28 columns: model.inner.columns.iter().map(|c| {
29 if c.skip {
30 return Ok(None);
31 }
32 let Some(mut col) = Column::from_metadata(c)? else {
33 return Ok(None);
34 };
35 col.primary_key = model.pkey.column_name == col.name;
36 Ok(Some(col))
37 })
38 .filter_map(|c| c.transpose())
39 .collect::<Result<Vec<_>, _>>()?,
40 indexes: vec![],
41 })
42 }
43}
44
45trait SqlDiffColumnExt {
46 fn from_metadata(metadata: &ColumnMetadata) -> Result<Option<Column>, TypeTranslationError>;
47}
48
49impl SqlDiffColumnExt for Column {
50 fn from_metadata(metadata: &ColumnMetadata) -> Result<Option<Column>, TypeTranslationError> {
51 let Some(ty) = SqlType::from_type(&metadata.column_type) else {
52 return Ok(None);
53 };
54 Ok(Some(Self {
55 name: metadata.column_name.clone(),
56 typ: ty.ty,
57 default: None,
58 nullable: ty.nullable,
59 primary_key: metadata.marked_primary_key,
60 }))
61 }
62}
63
64struct SqlType {
65 pub ty: sqlmo::Type,
66 pub nullable: bool,
67}
68
69impl From<sqlmo::Type> for SqlType {
70 fn from(value: sqlmo::Type) -> Self {
71 Self {
72 ty: value,
73 nullable: false,
74 }
75 }
76}
77
/// Error produced when a Rust field type cannot be turned into a SQL type.
///
/// `Display` appends the wrapped string to a fixed "Could not translate type" prefix.
#[derive(Debug)]
pub struct TypeTranslationError(pub String);

impl std::error::Error for TypeTranslationError {}

impl std::fmt::Display for TypeTranslationError {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let TypeTranslationError(detail) = self;
        write!(f, "Could not translate type: {}", detail)
    }
}
88
89impl SqlType {
90 fn from_type(ty: &TType) -> Option<Self> {
91 use sqlmo::Type::*;
92 match ty {
93 TType::Vec(v) => {
94 if let TType::Inner(p) = v.as_ref() {
95 if p.ident.0 == "u8" {
96 return Some(SqlType {
97 ty: Bytes,
98 nullable: false,
99 });
100 }
101 }
102 let v = Self::from_type(v.as_ref())?;
103 Some(SqlType {
104 ty: Array(Box::new(v.ty)),
105 nullable: true,
106 })
107 }
108 TType::Inner(p) => {
109 let ident = p.ident.0.as_str();
110 let ty = match ident {
111 "i8" => I16,
113 "i16" => I16,
114 "i32" => I32,
115 "i64" => I64,
116 "i128" => Decimal,
117 "isize" => I64,
118 "u8" => I16,
120 "u16" => I32,
121 "u32" => I64,
122 "u64" => Decimal,
124 "u128" => Decimal,
125 "usize" => Decimal,
126 "f32" => F32,
128 "f64" => F64,
129 "bool" => Boolean,
131 "String" => Text,
133 "str" => Text,
134 "DateTime" => DateTime,
136 "NaiveDate" => Date,
137 "NaiveTime" => DateTime,
138 "NaiveDateTime" => DateTime,
139 "Decimal" => Decimal,
141 "Uuid" => Uuid,
143 "Json" => Jsonb,
145 z => Other(z.to_string()),
146 };
147 Some(SqlType {
148 ty,
149 nullable: false,
150 })
151 }
152 TType::Option(o) => {
153 let inner = Self::from_type(o)?;
154 Some(SqlType {
155 ty: inner.ty,
156 nullable: true,
157 })
158 }
159 TType::Join(_) => {
160 None
161 }
162 }
163 }
164}
165
166impl TryFromormlitex for Schema {
167 fn try_from_ormlitex_project(paths: &[&Path]) -> Result<Self> {
168 let mut schema = Self::default();
169 let mut fs_schema = schema_from_filepaths(paths)?;
170 let primary_key_type: BTreeMap<String, InnerType> = fs_schema.tables.iter().map(|t| {
171 let pkey_ty = t.pkey.column_type.inner_type().clone();
172 (t.inner.struct_name.to_string(), pkey_ty)
173 }).collect();
174 for t in &mut fs_schema.tables {
175 for c in &mut t.inner.columns {
176 let inner = c.column_type.inner_type_mut();
178 if let Some(f) = fs_schema.type_reprs.get(&inner.ident.0) {
179 inner.ident = Ident(f.clone());
180 }
181 if c.column_type.is_join() {
183 let model_name = c.column_type.inner_type_name();
184 let pkey = primary_key_type.get(&model_name).expect(&format!("Could not find model {} for join", model_name));
185 c.column_type = TType::Inner(pkey.clone());
186 }
187 }
188 }
189 for table in fs_schema.tables {
190 let table = Table::from_metadata(&table)?;
191 schema.tables.push(table);
192 }
193 Ok(schema)
194 }
195}
196
#[cfg(test)]
mod tests {
    use super::*;
    use assert_matches::assert_matches;
    use ormlitex_attr::TType;
    use sqlmo::Type;
    use syn::parse_str;

    /// Parses a Rust type path such as `"Option<String>"` into a [`TType`].
    fn ttype(src: &str) -> TType {
        TType::from(&parse_str::<syn::Path>(src).unwrap())
    }

    #[test]
    fn test_convert_type() {
        assert_matches!(SqlType::from_type(&ttype("String")).unwrap().ty, Type::Text);
        assert_matches!(SqlType::from_type(&ttype("u32")).unwrap().ty, Type::I64);
        let s = SqlType::from_type(&ttype("Option<String>")).unwrap();
        assert_matches!(s.ty, Type::Text);
        assert!(s.nullable);
    }

    #[test]
    fn test_support_vec() {
        let Type::Array(inner) = SqlType::from_type(&ttype("Vec<Uuid>")).unwrap().ty else {
            panic!("Expected array");
        };
        assert_eq!(*inner, Type::Uuid);
    }

    /// `Vec<u8>` is special-cased to a byte column rather than an array.
    #[test]
    fn test_vec_u8_is_bytes() {
        let s = SqlType::from_type(&ttype("Vec<u8>")).unwrap();
        assert_matches!(s.ty, Type::Bytes);
        assert!(!s.nullable);

        let s = SqlType::from_type(&ttype("Option<Vec<u8>>")).unwrap();
        assert_matches!(s.ty, Type::Bytes);
        assert!(s.nullable);
    }
}