gpkg_derive/
lib.rs

1#![allow(dead_code)]
2use proc_macro2::{Span, TokenStream};
3use quote::quote;
4use std::ops::Deref;
5use syn::{
6    parse2, Attribute, DeriveInput, Field, GenericArgument, GenericParam, Generics, Ident, Lit,
7    LitInt, Meta, MetaNameValue, Type, TypePath, TypeReference,
8};
9
/// Geometry type names accepted by `#[geom_field("...")]`, compared after upper-casing.
///
/// Each base type (`POLYGON`, `LINESTRING`, `POINT` and their `MULTI` forms) appears plain
/// and with the `M`, `Z` and `ZM` dimension suffixes.
const GEO_TYPES: &[&str] = &[
    "POLYGON",
    "LINESTRING",
    "POINT",
    "MULTIPOLYGON",
    "MULTILINESTRING",
    "MULTIPOINT",
    "POLYGONM",
    "LINESTRINGM",
    "POINTM",
    "MULTIPOLYGONM",
    "MULTILINESTRINGM",
    "MULTIPOINTM",
    "POLYGONZ",
    "LINESTRINGZ",
    "POINTZ",
    "MULTIPOLYGONZ",
    "MULTILINESTRINGZ",
    "MULTIPOINTZ",
    "POLYGONZM",
    "LINESTRINGZM",
    "POINTZM",
    "MULTIPOLYGONZM",
    "MULTILINESTRINGZM",
    "MULTIPOINTZM",
];
36
37/// A macro for deriving an implementation of GPKGModel for a struct
38///
39/// The layer_name attribute controls the name of the SQLite table that instances of this Struct will be read and written as
40///
41/// The geom_field attribute can only be used on one field, and the geometry type will be cast to uppercase
42/// the used as the geomtry type for the layer. If the letters Z and/or M are present in the geometry type,
43/// the corresponding flags will be set within the GeoPackage indicating that the geometry has M or Z values.
44///
45/// When this macro is used, an "object_id" primary key column will be created in order to comply with the specifcation,
46/// but will be transparent to you as a user of this crate
47///
48/// When using this macro for reading an existing GeoPackage layer, any unspecified columns will not be read.
49/// # Usage
50/// ```ignore
51/// # // would be great to get this test working, but I'm not sure how to do it without curculare dependency issues
52/// # use gpkg_derive::GPKGModel;
53/// # use gpkg::types::{GPKGPoint, GPKGPointZ};
54///
55/// #[derive(GPKGModel)]
56/// #[layer_name = "test_table"]
57/// struct TestTable {
58///     field1: i64,
59///     field2: i32,
60///     #[geom_field("Point")]
61///     shape: GPKGPoint,
62/// }
63///
64/// #[derive(GPKGModel)]
65/// #[layer_name = "test_tableZ"]
66/// struct TestTableZ {
67///     field1: i64,
68///     field2: i32,
69///     #[geom_field("PointZ")]
70///     shape: GPKGPointZ,
71/// }
72#[proc_macro_derive(GPKGModel, attributes(layer_name, geom_field))]
73pub fn derive_gpkg(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
74    let inner_input = proc_macro2::TokenStream::from(input);
75    proc_macro::TokenStream::from(derive_gpkg_inner(inner_input))
76}
77
78fn derive_gpkg_inner(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream {
79    let ast = parse2::<DeriveInput>(input).unwrap();
80
81    let tbl_name_meta = get_meta_attr(&ast.attrs, "layer_name");
82    let tbl_name = tbl_name_meta.and_then(|meta| match meta {
83        Meta::NameValue(MetaNameValue {
84            lit: Lit::Str(ls), ..
85        }) => Some(ls.value()),
86        _ => None,
87    });
88
89    // ge the name for our table name
90    let name = &ast.ident;
91
92    let fields = match &ast.data {
93        syn::Data::Struct(data) => match &data.fields {
94            syn::Fields::Named(fields) => fields.named.iter(),
95            _ => panic!("GPKGModel derive expected named fields"),
96        },
97        _ => panic!("GPKGModel derive expected a struct"),
98    }
99    .collect();
100
101    impl_model(&name.clone(), &fields, tbl_name, &ast.generics)
102}
103
104fn get_meta_attr<'a>(attrs: &Vec<Attribute>, name: &'a str) -> Option<Meta> {
105    let mut temp = attrs
106        .iter()
107        .filter_map(|attr| attr.parse_meta().ok())
108        .filter(|i| match i.path().get_ident() {
109            Some(i) => i.to_string() == name.to_owned(),
110            None => false,
111        })
112        .collect::<Vec<Meta>>();
113    temp.pop()
114}
115
/// Z/M dimension flag registered for a layer's geometry column.
///
/// The discriminants are the integer codes written into the generated SQL (via `as i32`), so
/// they must not be reordered or renumbered.
#[derive(Clone, Copy, Debug)]
enum MZOptions {
    /// Values for this dimension are prohibited (code 0).
    Prohibited = 0,
    /// Values for this dimension are mandatory (code 1).
    Mandatory = 1,
    /// Values for this dimension are optional (code 2).
    Optional = 2,
}
122
123#[derive(Debug, Clone)]
124struct GeomInfo {
125    geom_type: String,
126    // this is mostly for future proofing, we'll default to wgs84 for now
127    srs_id: i64,
128    m: MZOptions,
129    z: MZOptions,
130}
131
132#[derive(Debug)]
133struct FieldInfo {
134    name: String,
135    geom_info: Option<GeomInfo>,
136    optional: bool,
137    type_for_sql: String,
138}
139
140// only going to support &str and &[u8] for now
141fn get_reference_type_name(t: &TypeReference) -> String {
142    match t.elem.deref() {
143        syn::Type::Path(p) => {
144            assert!(p.path.segments.len() == 1);
145            match get_path_type_name(p).0.as_str() {
146                "str" => return String::from("str"),
147                _ => panic!("The only reference types supported are &str and &[u8]"),
148            }
149        }
150        syn::Type::Slice(s) => match s.elem.deref() {
151            Type::Path(p) => match get_path_type_name(p).0.as_str() {
152                "u8" => return String::from("buf"),
153                _ => panic!("The only reference types supported are &str and &[u8]"),
154            },
155            _ => panic!("The only reference types supported are &str and &[u8]"),
156        },
157        _ => panic!("The only reference types supported are &str and &[u8]"),
158    };
159}
160
161// return the field name and whether or not it's optional
162fn get_path_type_name(p: &TypePath) -> (String, bool) {
163    let mut optional = false;
164    assert!(p.path.segments.len() > 0);
165    let final_segment = p.path.segments.last().unwrap();
166    let id_string = final_segment.ident.to_string();
167    match id_string.as_str() {
168        // get the inner
169        "Option" => {
170            optional = true;
171            if let syn::PathArguments::AngleBracketed(a) = &final_segment.arguments {
172                assert!(a.args.len() == 1, "Only one argument allowed in an Option");
173                if let GenericArgument::Type(t) = &a.args[0] {
174                    match t {
175                        Type::Path(p) => {
176                            return (get_path_type_name(p).0, optional);
177                        }
178                        Type::Reference(r) => {
179                            return (get_reference_type_name(r), optional);
180                        }
181                        _ => panic!("Unsupported type within Option"),
182                    }
183                }
184            } else {
185                panic!("Unsupported use of the option type");
186            }
187        }
188        "Vec" => {
189            if let syn::PathArguments::AngleBracketed(a) = &final_segment.arguments {
190                assert!(a.args.len() == 1, "Only one argument allowed in a Vec");
191                if let GenericArgument::Type(t) = &a.args[0] {
192                    match t {
193                        Type::Path(p) => {
194                            let type_return = get_path_type_name(p).0;
195                            match type_return.as_str() {
196                                "u8" => return (String::from("buf"), optional),
197                                _ => panic!("Vec<u8> is the only allowed use of the Vec type"),
198                            };
199                        }
200                        _ => panic!("Vec<u8> is the only allowed use of the Vec type"),
201                    }
202                }
203            } else {
204                panic!("Vec<u8> is the only allowed use of the Vec type");
205            }
206        }
207        _ => {}
208    }
209
210    (final_segment.ident.to_string(), false)
211}
212
213fn impl_model(
214    name: &Ident,
215    fields: &Vec<&Field>,
216    tbl_name: Option<String>,
217    generics: &Generics,
218) -> TokenStream {
219    // overwrite the struct name with a provided table name if one is given
220    // TODO: add some level of validation here based on sqlite's rules
221    let layer_name_final = match tbl_name {
222        Some(n) => Ident::new(&n, name.span()),
223        None => name.to_owned(),
224    };
225
226    let geom_field_name: String;
227
228    // need to get this in order to make liftimes on the Impl work correctly
229    let mut final_generics = generics.clone();
230    if let Some(g) = final_generics.params.first_mut() {
231        match g {
232            GenericParam::Lifetime(l) => match l.lifetime.ident.to_string().as_str() {
233                "static" | "_" => {}
234                _ => l.lifetime.ident = Ident::new("_", Span::call_site()),
235            },
236            _ => {}
237        }
238    }
239
240    // the goal is to support everything here (https://www.geopackage.org/spec130/index.html#table_column_data_types)
241    // as well as allow the user change whether a field can have nulls or not with the option type
242    let field_infos: Vec<FieldInfo> = fields
243        .iter()
244        .map(|f| {
245            let mut optional = false;
246            let field_name = f.ident.as_ref().expect("Expected named field").to_string();
247            let type_name: String;
248            let geom_info = get_geom_field_info(&f);
249            match &f.ty {
250                syn::Type::Reference(r) => {
251                    type_name = get_reference_type_name(r);
252                }
253                syn::Type::Path(tp) => {
254                    (type_name, optional) = get_path_type_name(tp);
255                }
256                _ => panic!("Don't know how to map to GPKG type {:?}", f.ty),
257            }
258            let sql_type = match type_name.as_str() {
259                "bool" => quote!(INTEGER),
260                "String" | "str" => quote!(TEXT),
261                "i64" | "i32" | "i16" | "i8" => quote!(INTEGER),
262                "f64" | "f32" => quote!(REAL),
263                "buf" => quote!(BLOB),
264                "u128" | "u64" | "u32" | "u16" | "u8" => {
265                    panic!("SQLite doesn't support unsigned integers, use a signed integer value")
266                }
267                // all geometry types are a blob inside sqlite
268                _ if geom_info.is_some() => quote!(BLOB),
269                _ => panic!("Don't know how to map to SQL type {}", type_name),
270            };
271            FieldInfo {
272                name: field_name,
273                optional,
274                geom_info,
275                type_for_sql: sql_type.to_string(),
276            }
277        })
278        .collect();
279    let geom_fields: Vec<&FieldInfo> = field_infos
280        .iter()
281        .filter(|f| f.geom_info.is_some())
282        .collect();
283    assert!(
284        geom_fields.len() <= 1,
285        "Found {} geometry fields, 1 is the maximum allowed amount",
286        geom_fields.len()
287    );
288    let mut geom_column_sql: Option<String> = None;
289    let mut contents_sql = format!(
290        r#"INSERT INTO gpkg_contents (layer_name, data_type) VALUES ("{}", "{}");"#,
291        layer_name_final, "attributes"
292    );
293
294    if geom_fields.len() > 0 {
295        let geom_field = geom_fields[0];
296        let geom_info = geom_field.geom_info.clone().unwrap();
297        let geom_type_sql = geom_info.geom_type.clone();
298        geom_field_name = geom_field.name.clone();
299        geom_column_sql = Some(format!(
300            r#"INSERT INTO gpkg_geometry_columns VALUES("{}", "{}", "{}", {}, {}, {});"#,
301            layer_name_final,
302            geom_field_name,
303            geom_type_sql.to_uppercase(),
304            geom_info.srs_id,
305            geom_info.m as i32,
306            geom_info.z as i32
307        ));
308        contents_sql = format!(
309            r#"INSERT INTO gpkg_contents (layer_name, data_type, srs_id) VALUES ("{}", "{}", {});"#,
310            layer_name_final, "features", geom_info.srs_id
311        );
312    };
313    let contents_ts: TokenStream = contents_sql
314        .parse()
315        .expect("Unable to convert contents table insert statement into token stream");
316    let geom_column_ts: TokenStream = match geom_column_sql {
317        Some(s) => s
318            .parse()
319            .expect("Unable to convert contents table insert statement into token stream"),
320        None => TokenStream::new(),
321    };
322
323    let column_defs = field_infos
324        .iter()
325        .map(|f| {
326            let null_str = if f.optional { "" } else { " NOT NULL" };
327            format!("{} {}{}", f.name, f.type_for_sql, null_str)
328                .parse()
329                .unwrap()
330        })
331        .collect::<Vec<TokenStream>>();
332
333    let column_names: Vec<Ident> = field_infos
334        .iter()
335        .map(|i| Ident::new(i.name.as_str(), Span::call_site()))
336        .collect();
337
338    let params = vec![quote!(?); column_names.len()];
339
340    let column_nums = (0..column_defs.len())
341        .map(|i| LitInt::new(i.to_string().as_str(), Span::call_site()))
342        .collect::<Vec<LitInt>>();
343
344    // need to add some generic support like in here: https://github.com/diesel-rs/diesel/blob/master/diesel_derives/src/insertable.rs#L88
345    // this is so that lifetimes will work
346    let new = quote!(
347        impl GPKGModel <'_> for #name #final_generics {
348            fn get_gpkg_layer_name() -> &'static str {
349                std::stringify!(#layer_name_final)
350            }
351
352            fn get_create_sql() -> &'static str {
353                std::stringify!(
354                    BEGIN;
355                    CREATE TABLE #layer_name_final (
356                        object_id INTEGER PRIMARY KEY,
357                        #(#column_defs ),*
358                    );
359                    #geom_column_ts
360                    #contents_ts
361                    COMMIT;
362                )
363            }
364
365            fn get_insert_sql() -> &'static str {
366                std::stringify!(
367                    INSERT INTO #layer_name_final (
368                        #(#column_names),*
369                    ) VALUES (
370                        #(#params),*
371                    )
372                )
373            }
374
375            fn get_select_sql() -> &'static str {
376                std::stringify!(
377                    SELECT #(#column_names),* FROM #layer_name_final;
378                )
379            }
380
381            fn get_select_where(predicate: &str) -> String {
382                (std::stringify!(
383                    SELECT #(#column_names),* FROM #layer_name_final WHERE
384                ).to_owned() + " " + predicate + ";")
385            }
386
387            fn from_row(row: &rusqlite::Row) -> rusqlite::Result<Self> {
388                Ok(Self {
389                    #(#column_names: row.get((#column_nums))?,)*
390                })
391            }
392
393            fn as_params(&self) -> Vec<&(dyn rusqlite::ToSql + '_)> {
394                vec![
395                    #(&self.#column_names as &dyn rusqlite::ToSql),*
396                ]
397            }
398        }
399    );
400    new
401}
402
403fn get_geom_field_info(field: &Field) -> Option<GeomInfo> {
404    for attr in &field.attrs {
405        if let Some(ident) = attr.path.get_ident() {
406            if ident.to_string() == "geom_field" {
407                let geom_type_name =
408                    get_meta_attr(&field.attrs, "geom_field").and_then(|meta| match meta {
409                        Meta::List(l) => l.nested.first().and_then(|n| match n {
410                            syn::NestedMeta::Lit(Lit::Str(ls)) => Some(ls.value()),
411                            _ => panic!("You must specify a geometry type when using the geom_field attribute"),
412                        }),
413                        _ => panic!("You must specify a geometry type when using the geom_field attribute"),
414                    });
415                if let Some(name) = geom_type_name {
416                    let upper_name = name.to_uppercase();
417                    if GEO_TYPES.contains(&upper_name.as_str()) {
418                        let m = if upper_name.contains("M") {
419                            MZOptions::Optional
420                        } else {
421                            MZOptions::Prohibited
422                        };
423                        let z = if upper_name.contains("Z") {
424                            MZOptions::Optional
425                        } else {
426                            MZOptions::Prohibited
427                        };
428                        return Some(GeomInfo {
429                            geom_type: upper_name,
430                            srs_id: 4326,
431                            m,
432                            z,
433                        });
434                    } else {
435                        panic!("{} is not a supported geometry type", name);
436                    }
437                }
438            }
439        }
440    }
441    None
442}
443
#[cfg(test)]
mod test {
    use super::*;
    use quote::quote;

    #[test]
    fn basic_test() {
        let tstream = quote!(
            #[layer_name = "streetlights"]
            struct StreetLight {
                id: i64,
                height: f64,
                string_ref: Option<String>,
                buf_ref: &'a [u8],
                #[geom_field("LineStringZ")]
                geom: GPKGLineStringZ,
            }
        );
        let output = derive_gpkg_inner(tstream).to_string();
        // The trait impl, the struct, the layer-name override and the upper-cased geometry
        // type must all reach the generated code.
        assert!(output.contains("GPKGModel"), "{}", output);
        assert!(output.contains("StreetLight"), "{}", output);
        assert!(output.contains("streetlights"), "{}", output);
        assert!(output.contains("LINESTRINGZ"), "{}", output);
    }

    #[test]
    #[should_panic(expected = "GPKGModel derive expected a struct")]
    fn rejects_enums() {
        derive_gpkg_inner(quote!(
            enum NotAStruct {
                A,
                B,
            }
        ));
    }

    #[test]
    #[should_panic(expected = "GPKGModel derive expected named fields")]
    fn rejects_tuple_structs() {
        derive_gpkg_inner(quote!(
            struct Tuple(i64, f64);
        ));
    }

    #[test]
    #[should_panic(expected = "SQLite doesn't support unsigned integers")]
    fn rejects_unsigned_fields() {
        derive_gpkg_inner(quote!(
            struct Unsigned {
                id: u32,
            }
        ));
    }
}