ormlite_core/
schema.rs

1use crate::config::Config;
2use anyhow::Result as AnyResult;
3use ormlite_attr::ModelMeta;
4use ormlite_attr::Type;
5use ormlite_attr::{schema_from_filepaths, ColumnMeta, Ident, InnerType};
6use sql::{schema::Column, Constraint, Schema, Table};
7use std::collections::HashMap;
8use std::path::Path;
9
10pub fn schema_from_ormlite_project(paths: &[&Path], c: &Config) -> AnyResult<Schema> {
11    let mut schema = Schema::default();
12    let mut fs_schema = schema_from_filepaths(paths)?;
13    let primary_key_type: HashMap<String, InnerType> = fs_schema
14        .tables
15        .iter()
16        .map(|t| {
17            let pkey_ty = t.pkey.ty.inner_type().clone();
18            (t.ident.to_string(), pkey_ty)
19        })
20        .collect();
21    for t in &mut fs_schema.tables {
22        for c in &mut t.table.columns {
23            // replace alias types with the real type.
24            let inner = c.ty.inner_type_mut();
25            if let Some(f) = fs_schema.type_reprs.get(&inner.ident.to_string()) {
26                inner.ident = Ident::from(f);
27            }
28            // replace join types with the primary key type.
29            if c.ty.is_join() {
30                let model_name = c.ty.inner_type_name();
31                let pkey = primary_key_type
32                    .get(&model_name)
33                    .expect(&format!("Could not find model {} for join", model_name));
34                c.ty = Type::Inner(pkey.clone());
35            }
36        }
37    }
38    for table in fs_schema.tables {
39        let table = Table::from_meta(&table);
40        schema.tables.push(table);
41    }
42    let mut table_names: HashMap<String, (String, String)> = schema
43        .tables
44        .iter()
45        .map(|t| (t.name.clone(), (t.name.clone(), t.primary_key().unwrap().name.clone())))
46        .collect();
47    for (alias, real) in &c.table.aliases {
48        let Some(real) = table_names.get(real) else {
49            continue;
50        };
51        table_names.insert(alias.clone(), real.clone());
52    }
53    for table in &mut schema.tables {
54        for column in &mut table.columns {
55            if column.primary_key {
56                continue;
57            }
58            if column.name.ends_with("_id") || column.name.ends_with("_uuid") {
59                let Some((model_name, _)) = column.name.rsplit_once('_') else {
60                    continue;
61                };
62                if let Some((t, pkey)) = table_names.get(model_name) {
63                    let constraint = Constraint::foreign_key(t.to_string(), vec![pkey.clone()]);
64                    column.constraint = Some(constraint);
65                }
66            }
67        }
68    }
69    Ok(schema)
70}
71
/// Option flags; currently only a verbosity switch.
#[derive(Debug)]
pub struct Options {
    /// Verbosity switch.
    /// NOTE(review): not read anywhere in this file — confirm its effect at the call sites.
    pub verbose: bool,
}
76
/// Conversion from a borrowed metadata value (`Self::Input`) into `Self`.
///
/// Unlike [`From`], the source is taken by reference. The source type is an associated
/// type, so each implementor is built from exactly one kind of metadata.
pub trait FromMeta
where
    Self: Sized,
{
    /// The metadata type this value is built from.
    type Input;

    /// Builds `Self` from the borrowed `meta`.
    fn from_meta(meta: &Self::Input) -> Self;
}
81
82impl FromMeta for Table {
83    type Input = ModelMeta;
84    fn from_meta(model: &ModelMeta) -> Self {
85        let columns = model
86            .columns
87            .iter()
88            .flat_map(|c| {
89                if c.skip {
90                    return None;
91                }
92                let mut col = Option::<Column>::from_meta(c)?;
93                col.primary_key = model.pkey.name == col.name;
94                Some(col)
95            })
96            .collect();
97        Self {
98            schema: None,
99            name: model.name.clone(),
100            columns,
101        }
102    }
103}
104
105impl FromMeta for Option<Column> {
106    type Input = ColumnMeta;
107    fn from_meta(meta: &Self::Input) -> Self {
108        let mut ty = Nullable::from_type(&meta.ty)?;
109        if meta.json {
110            ty.ty = sql::Type::Jsonb;
111        }
112        Some(Column {
113            name: meta.name.clone(),
114            typ: ty.ty,
115            default: None,
116            nullable: ty.nullable,
117            primary_key: meta.marked_primary_key,
118            constraint: None,
119            generated: None,
120        })
121    }
122}
123
124struct Nullable {
125    pub ty: sql::Type,
126    pub nullable: bool,
127}
128
129impl From<sql::Type> for Nullable {
130    fn from(value: sql::Type) -> Self {
131        Self {
132            ty: value,
133            nullable: false,
134        }
135    }
136}
137
138impl Nullable {
139    fn from_type(ty: &Type) -> Option<Self> {
140        use sql::Type::*;
141        match ty {
142            Type::Vec(v) => {
143                if let Type::Inner(p) = v.as_ref() {
144                    if p.ident == "u8" {
145                        return Some(Nullable {
146                            ty: Bytes,
147                            nullable: false,
148                        });
149                    }
150                }
151                let v = Self::from_type(v.as_ref())?;
152                Some(Nullable {
153                    ty: Array(Box::new(v.ty)),
154                    nullable: false,
155                })
156            }
157            Type::Inner(p) => {
158                let ident = p.ident.to_string();
159                let ty = match ident.as_str() {
160                    // signed
161                    "i8" => I16,
162                    "i16" => I16,
163                    "i32" => I32,
164                    "i64" => I64,
165                    "i128" => Decimal,
166                    "isize" => I64,
167                    // unsigned
168                    "u8" => I16,
169                    "u16" => I32,
170                    "u32" => I64,
171                    // Turns out postgres doesn't support u64.
172                    "u64" => Decimal,
173                    "u128" => Decimal,
174                    "usize" => Decimal,
175                    // float
176                    "f32" => F32,
177                    "f64" => F64,
178                    // bool
179                    "bool" => Boolean,
180                    // string
181                    "String" => Text,
182                    "str" => Text,
183                    // date
184                    "DateTime" => DateTime,
185                    "NaiveDate" => Date,
186                    "NaiveTime" => DateTime,
187                    "NaiveDateTime" => DateTime,
188                    // decimal
189                    "Decimal" => Decimal,
190                    // uuid
191                    "Uuid" => Uuid,
192                    // json
193                    "Json" => Jsonb,
194                    z => Other(z.to_string()),
195                };
196                Some(Nullable { ty, nullable: false })
197            }
198            Type::Option(o) => {
199                let inner = Self::from_type(o)?;
200                Some(Nullable {
201                    ty: inner.ty,
202                    nullable: true,
203                })
204            }
205            Type::Join(_) => None,
206        }
207    }
208}
209
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result;
    use assert_matches::assert_matches;
    use ormlite_attr::Type;
    use syn::parse_str;

    /// Parses Rust type syntax such as `Option<String>` into an ormlite [`Type`].
    fn rust_type(src: &str) -> Type {
        let path = parse_str::<syn::Path>(src).unwrap();
        Type::from(&path)
    }

    /// Scalar and `Option`-wrapped types map to the expected SQL types.
    #[test]
    fn test_convert_type() -> Result<()> {
        use sql::Type as SqlType;

        let text = Nullable::from_type(&rust_type("String")).unwrap();
        assert_matches!(text.ty, SqlType::Text);

        let unsigned = Nullable::from_type(&rust_type("u32")).unwrap();
        assert_matches!(unsigned.ty, SqlType::I64);

        let optional = Nullable::from_type(&rust_type("Option<String>")).unwrap();
        assert_matches!(optional.ty, SqlType::Text);
        assert!(optional.nullable);
        Ok(())
    }

    /// `Vec<T>` maps to an array of `T`'s SQL type.
    #[test]
    fn test_support_vec() {
        use sql::Type as SqlType;

        let converted = Nullable::from_type(&rust_type("Vec<Uuid>")).unwrap();
        let SqlType::Array(element) = converted.ty else {
            panic!("Expected array");
        };
        assert_eq!(*element, SqlType::Uuid);
    }
}