ormlitex_core/
schema.rs

1use std::collections::BTreeMap;
2use std::fmt::Formatter;
3
4use std::path::Path;
5use anyhow::Result;
6use sqlmo::{Schema, Table, schema::Column};
7use ormlitex_attr::{ColumnMetadata, Ident, InnerType, ModelMetadata, TType};
8use ormlitex_attr::{schema_from_filepaths};
9
/// Configuration options for this module's schema tooling.
///
/// Derives `Default` (all flags off) and `Clone` so callers can build and pass the
/// options around freely.
#[derive(Debug, Clone, Default)]
pub struct Options {
    /// Verbosity flag. NOTE(review): not read anywhere in this file; presumably enables
    /// more detailed output in callers — confirm.
    pub verbose: bool,
}
14
15pub trait TryFromormlitex: Sized {
16    fn try_from_ormlitex_project(path: &[&Path]) -> Result<Self>;
17}
18
19trait SqlDiffTableExt {
20    fn from_metadata(metadata: &ModelMetadata) -> Result<Self, TypeTranslationError> where Self: Sized;
21}
22
23impl SqlDiffTableExt for Table {
24    fn from_metadata(model: &ModelMetadata) -> Result<Self, TypeTranslationError> {
25        Ok(Self {
26            schema: None,
27            name: model.inner.table_name.clone(),
28            columns: model.inner.columns.iter().map(|c| {
29                if c.skip {
30                    return Ok(None);
31                }
32                let Some(mut col) = Column::from_metadata(c)? else {
33                    return Ok(None);
34                };
35                col.primary_key = model.pkey.column_name == col.name;
36                Ok(Some(col))
37            })
38                .filter_map(|c| c.transpose())
39                .collect::<Result<Vec<_>, _>>()?,
40            indexes: vec![],
41        })
42    }
43}
44
45trait SqlDiffColumnExt {
46    fn from_metadata(metadata: &ColumnMetadata) -> Result<Option<Column>, TypeTranslationError>;
47}
48
49impl SqlDiffColumnExt for Column {
50    fn from_metadata(metadata: &ColumnMetadata) -> Result<Option<Column>, TypeTranslationError> {
51        let Some(ty) = SqlType::from_type(&metadata.column_type) else {
52            return Ok(None);
53        };
54        Ok(Some(Self {
55            name: metadata.column_name.clone(),
56            typ: ty.ty,
57            default: None,
58            nullable: ty.nullable,
59            primary_key: metadata.marked_primary_key,
60        }))
61    }
62}
63
64struct SqlType {
65    pub ty: sqlmo::Type,
66    pub nullable: bool,
67}
68
69impl From<sqlmo::Type> for SqlType {
70    fn from(value: sqlmo::Type) -> Self {
71        Self {
72            ty: value,
73            nullable: false,
74        }
75    }
76}
77
/// Error describing a Rust type that could not be mapped to a SQL column type.
///
/// The wrapped string is included verbatim in the `Display` message.
#[derive(Debug)]
pub struct TypeTranslationError(pub String);

impl std::fmt::Display for TypeTranslationError {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let Self(detail) = self;
        write!(f, "Could not translate type: {}", detail)
    }
}

// Marker impl: there is no underlying `source()`.
impl std::error::Error for TypeTranslationError {}
88
89impl SqlType {
90    fn from_type(ty: &TType) -> Option<Self> {
91        use sqlmo::Type::*;
92        match ty {
93            TType::Vec(v) => {
94                if let TType::Inner(p) = v.as_ref() {
95                    if p.ident.0 == "u8" {
96                        return Some(SqlType {
97                            ty: Bytes,
98                            nullable: false,
99                        });
100                    }
101                }
102                let v = Self::from_type(v.as_ref())?;
103                Some(SqlType {
104                    ty: Array(Box::new(v.ty)),
105                    nullable: true,
106                })
107            }
108            TType::Inner(p) => {
109                let ident = p.ident.0.as_str();
110                let ty = match ident {
111                    // signed
112                    "i8" => I16,
113                    "i16" => I16,
114                    "i32" => I32,
115                    "i64" => I64,
116                    "i128" => Decimal,
117                    "isize" => I64,
118                    // unsigned
119                    "u8" => I16,
120                    "u16" => I32,
121                    "u32" => I64,
122                    // Turns out postgres doesn't support u64.
123                    "u64" => Decimal,
124                    "u128" => Decimal,
125                    "usize" => Decimal,
126                    // float
127                    "f32" => F32,
128                    "f64" => F64,
129                    // bool
130                    "bool" => Boolean,
131                    // string
132                    "String" => Text,
133                    "str" => Text,
134                    // date
135                    "DateTime" => DateTime,
136                    "NaiveDate" => Date,
137                    "NaiveTime" => DateTime,
138                    "NaiveDateTime" => DateTime,
139                    // decimal
140                    "Decimal" => Decimal,
141                    // uuid
142                    "Uuid" => Uuid,
143                    // json
144                    "Json" => Jsonb,
145                    z => Other(z.to_string()),
146                };
147                Some(SqlType {
148                    ty,
149                    nullable: false,
150                })
151            }
152            TType::Option(o) => {
153                let inner = Self::from_type(o)?;
154                Some(SqlType {
155                    ty: inner.ty,
156                    nullable: true,
157                })
158            }
159            TType::Join(_) => {
160                None
161            }
162        }
163    }
164}
165
166impl TryFromormlitex for Schema {
167    fn try_from_ormlitex_project(paths: &[&Path]) -> Result<Self> {
168        let mut schema = Self::default();
169        let mut fs_schema = schema_from_filepaths(paths)?;
170        let primary_key_type: BTreeMap<String, InnerType> = fs_schema.tables.iter().map(|t|  {
171            let pkey_ty = t.pkey.column_type.inner_type().clone();
172            (t.inner.struct_name.to_string(), pkey_ty)
173        }).collect();
174        for t in &mut fs_schema.tables {
175            for c in &mut t.inner.columns {
176                // replace alias types with the real type.
177                let inner = c.column_type.inner_type_mut();
178                if let Some(f) = fs_schema.type_reprs.get(&inner.ident.0) {
179                    inner.ident = Ident(f.clone());
180                }
181                // replace join types with the primary key type.
182                if c.column_type.is_join() {
183                    let model_name = c.column_type.inner_type_name();
184                    let pkey = primary_key_type.get(&model_name).expect(&format!("Could not find model {} for join", model_name));
185                    c.column_type = TType::Inner(pkey.clone());
186                }
187            }
188        }
189        for table in fs_schema.tables {
190            let table = Table::from_metadata(&table)?;
191            schema.tables.push(table);
192        }
193        Ok(schema)
194    }
195}
196
#[cfg(test)]
mod tests {
    use super::*;
    use assert_matches::assert_matches;
    use ormlitex_attr::TType;
    use sqlmo::Type;
    use syn::parse_str;

    /// Parses a Rust type path such as `"Option<String>"` into a [`TType`].
    fn ttype(src: &str) -> TType {
        TType::from(&parse_str::<syn::Path>(src).unwrap())
    }

    #[test]
    fn test_convert_type() -> Result<()> {
        let s = SqlType::from_type(&ttype("String")).unwrap();
        assert_matches!(s.ty, Type::Text);
        assert!(!s.nullable);

        assert_matches!(SqlType::from_type(&ttype("u32")).unwrap().ty, Type::I64);

        let s = SqlType::from_type(&ttype("Option<String>")).unwrap();
        assert_matches!(s.ty, Type::Text);
        assert!(s.nullable);
        Ok(())
    }

    #[test]
    fn test_unsigned_widening() {
        assert_matches!(SqlType::from_type(&ttype("u8")).unwrap().ty, Type::I16);
        assert_matches!(SqlType::from_type(&ttype("u16")).unwrap().ty, Type::I32);
        assert_matches!(SqlType::from_type(&ttype("u64")).unwrap().ty, Type::Decimal);
    }

    #[test]
    fn test_support_vec() {
        let Type::Array(inner) = SqlType::from_type(&ttype("Vec<Uuid>")).unwrap().ty else {
            panic!("Expected array");
        };
        assert_eq!(*inner, Type::Uuid);
    }

    #[test]
    fn test_vec_u8_is_bytes() {
        let s = SqlType::from_type(&ttype("Vec<u8>")).unwrap();
        assert_matches!(s.ty, Type::Bytes);
        assert!(!s.nullable);
    }
}