1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
//LICENSE Portions Copyright 2019-2021 ZomboDB, LLC.
//LICENSE
//LICENSE Portions Copyright 2021-2023 Technology Concepts & Design, Inc.
//LICENSE
//LICENSE Portions Copyright 2023-2023 PgCentral Foundation, Inc. <contact@pgcentral.org>
//LICENSE
//LICENSE All rights reserved.
//LICENSE
//LICENSE Use of this source code is governed by the MIT license that can be found in the LICENSE file.
use pgrx_sql_entity_graph::{PostgresHash, PostgresOrd};

use crate::{parse_postgres_type_args, PostgresTypeAttribute};
use proc_macro2::Ident;
use quote::{quote, ToTokens};
use syn::DeriveInput;

/// Returns the deriving type's identifier together with the Rust type used for the
/// arguments of every generated operator/support function.
///
/// When the attributes parsed by `parse_postgres_type_args` include
/// `PgVarlenaInOutFuncs`, the generated functions take `::pgrx::datum::PgVarlena<T>`;
/// otherwise they take the bare type `T`.
fn ident_and_path(ast: &DeriveInput) -> (&Ident, proc_macro2::TokenStream) {
    let ident = &ast.ident;
    let wraps_in_varlena = parse_postgres_type_args(&ast.attrs)
        .contains(&PostgresTypeAttribute::PgVarlenaInOutFuncs);
    if wraps_in_varlena {
        (ident, quote! { ::pgrx::datum::PgVarlena<#ident> })
    } else {
        (ident, quote! { #ident })
    }
}

pub(crate) fn deriving_postgres_eq(ast: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
    let mut stream = proc_macro2::TokenStream::new();
    let (ident, path) = ident_and_path(&ast);
    stream.extend(derive_pg_eq(ident, &path));
    stream.extend(derive_pg_ne(ident, &path));

    Ok(stream)
}

/// Expands `#[derive(PostgresOrd)]`.
///
/// Emits the `<`, `>`, `<=` and `>=` operator functions and the `_cmp` support function,
/// followed by the `PostgresOrd` SQL entity-graph item built from the same derive input.
///
/// # Errors
///
/// Propagates any error returned by `PostgresOrd::from_derive_input`.
pub(crate) fn deriving_postgres_ord(ast: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
    let (ident, path) = ident_and_path(&ast);

    let mut output = proc_macro2::TokenStream::new();
    output.extend([
        derive_pg_lt(ident, &path),
        derive_pg_gt(ident, &path),
        derive_pg_le(ident, &path),
        derive_pg_ge(ident, &path),
        derive_pg_cmp(ident, &path),
    ]);

    // `ast` is consumed here, so `ident` must not be used past this point.
    PostgresOrd::from_derive_input(ast)?.to_tokens(&mut output);

    Ok(output)
}

/// Expands `#[derive(PostgresHash)]`.
///
/// Emits the `_hash` support function followed by the `PostgresHash` SQL entity-graph item
/// built from the same derive input.
///
/// # Errors
///
/// Propagates any error returned by `PostgresHash::from_derive_input`.
pub(crate) fn deriving_postgres_hash(ast: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
    let (ident, path) = ident_and_path(&ast);
    let mut output = derive_pg_hash(ident, &path);

    let graph_entity = PostgresHash::from_derive_input(ast)?;
    graph_entity.to_tokens(&mut output);

    Ok(output)
}

/// Generate a Postgres `=` operator backed by Rust's `==`.
///
/// The emitted `<name>_eq` function requires the argument type to implement `Eq`. It is
/// declared `immutable` and `parallel_safe`, is its own commutator, and is negated by `<>`.
/// It uses `eqsel`/`eqjoinsel` for selectivity estimation and is marked `MERGES` and
/// `HASHES`. Together these declarations assume:
/// - PartialEq::eq is referentially transparent (immutable and parallel-safe)
/// - PartialEq::ne is the exact inverse of PartialEq::eq (negator)
/// - PartialEq::eq is commutative
///
/// Postgres insists these are merely ["optimization hints"], and that they may be declared
/// for operators defined with ordinary SQL or PL/pgSQL functions that yield spurious results.
///
/// Even so, these assumptions could prove venomous. Auditing the millions of lines of C in
/// Postgres is not realistic. We cannot confirm that it never relies on these assumptions in
/// a way that would cause UB, or other unacceptable behavior from PGRX, when `Eq` is
/// implemented incorrectly, and we have no practical means of guaranteeing it.
///
/// Postgres also attaches this warning to the same "optimization hints":
///
/// ```text
/// But if you provide them, you must be sure that they are right!
/// Incorrect use of an optimization clause can result in
/// slow queries, subtly wrong output, or other Bad Things.
/// ```
///
/// In practice most `Eq` impls are correct, referentially transparent, and commutative, so
/// this caveat may never matter. It is left as a signpost for anyone whose debugging session
/// eventually leads here.
///
/// ["optimization hints"]: https://www.postgresql.org/docs/current/xoper-optimization.html
pub fn derive_pg_eq(name: &Ident, path: &proc_macro2::TokenStream) -> proc_macro2::TokenStream {
    let fn_name = format!("{name}_eq").to_lowercase();
    let pg_name = Ident::new(&fn_name, name.span());
    quote! {
        #[allow(non_snake_case)]
        #[::pgrx::pgrx_macros::pg_operator(immutable, parallel_safe)]
        #[::pgrx::pgrx_macros::opname(=)]
        #[::pgrx::pgrx_macros::commutator(=)]
        #[::pgrx::pgrx_macros::negator(<>)]
        #[::pgrx::pgrx_macros::restrict(eqsel)]
        #[::pgrx::pgrx_macros::join(eqjoinsel)]
        #[::pgrx::pgrx_macros::merges]
        #[::pgrx::pgrx_macros::hashes]
        fn #pg_name(left: #path, right: #path) -> bool
        where
            #path: ::core::cmp::Eq,
        {
            left == right
        }
    }
}

/// Generate a Postgres `<>` operator backed by Rust's `!=`.
///
/// The emitted `<name>_ne` function is declared `immutable` and `parallel_safe`, is its own
/// commutator, is negated by `=`, and uses `neqsel`/`neqjoinsel` for selectivity estimation.
/// These declarations assume, without being able to verify it, that:
/// - PartialEq::ne is referentially transparent (immutable and parallel-safe)
/// - PartialEq::eq is the exact inverse of PartialEq::ne (negator)
/// - PartialEq::ne is commutative
///
/// See `derive_pg_eq` for the implications of these assumptions.
pub fn derive_pg_ne(name: &Ident, path: &proc_macro2::TokenStream) -> proc_macro2::TokenStream {
    let fn_name = format!("{name}_ne").to_lowercase();
    let pg_name = Ident::new(&fn_name, name.span());
    quote! {
        #[allow(non_snake_case)]
        #[::pgrx::pgrx_macros::pg_operator(immutable, parallel_safe)]
        #[::pgrx::pgrx_macros::opname(<>)]
        #[::pgrx::pgrx_macros::commutator(<>)]
        #[::pgrx::pgrx_macros::negator(=)]
        #[::pgrx::pgrx_macros::restrict(neqsel)]
        #[::pgrx::pgrx_macros::join(neqjoinsel)]
        fn #pg_name(left: #path, right: #path) -> bool {
            left != right
        }
    }
}

/// Generate a Postgres `<` operator backed by Rust's `<` (`PartialOrd::lt`).
///
/// The emitted `<name>_lt` function is declared `immutable` and `parallel_safe`, is negated
/// by `>=`, commutes with `>`, and uses `scalarltsel`/`scalarltjoinsel` for selectivity
/// estimation. As with `derive_pg_eq`, these declarations assume the Rust comparison is
/// referentially transparent and consistent with the other comparison operators.
pub fn derive_pg_lt(name: &Ident, path: &proc_macro2::TokenStream) -> proc_macro2::TokenStream {
    let fn_name = format!("{name}_lt").to_lowercase();
    let pg_name = Ident::new(&fn_name, name.span());
    quote! {
        #[allow(non_snake_case)]
        #[::pgrx::pgrx_macros::pg_operator(immutable, parallel_safe)]
        #[::pgrx::pgrx_macros::opname(<)]
        #[::pgrx::pgrx_macros::negator(>=)]
        #[::pgrx::pgrx_macros::commutator(>)]
        #[::pgrx::pgrx_macros::restrict(scalarltsel)]
        #[::pgrx::pgrx_macros::join(scalarltjoinsel)]
        fn #pg_name(left: #path, right: #path) -> bool {
            left < right
        }
    }
}

/// Generate a Postgres `>` operator backed by Rust's `>` (`PartialOrd::gt`).
///
/// The emitted `<name>_gt` function is declared `immutable` and `parallel_safe`, is negated
/// by `<=`, commutes with `<`, and uses `scalargtsel`/`scalargtjoinsel` for selectivity
/// estimation. As with `derive_pg_eq`, these declarations assume the Rust comparison is
/// referentially transparent and consistent with the other comparison operators.
pub fn derive_pg_gt(name: &Ident, path: &proc_macro2::TokenStream) -> proc_macro2::TokenStream {
    let fn_name = format!("{name}_gt").to_lowercase();
    let pg_name = Ident::new(&fn_name, name.span());
    quote! {
        #[allow(non_snake_case)]
        #[::pgrx::pgrx_macros::pg_operator(immutable, parallel_safe)]
        #[::pgrx::pgrx_macros::opname(>)]
        #[::pgrx::pgrx_macros::negator(<=)]
        #[::pgrx::pgrx_macros::commutator(<)]
        #[::pgrx::pgrx_macros::restrict(scalargtsel)]
        #[::pgrx::pgrx_macros::join(scalargtjoinsel)]
        fn #pg_name(left: #path, right: #path) -> bool {
            left > right
        }
    }
}

/// Generate a Postgres `<=` operator backed by Rust's `<=` (`PartialOrd::le`).
///
/// The emitted `<name>_le` function is declared `immutable` and `parallel_safe`, is negated
/// by `>`, commutes with `>=`, and uses `scalarlesel`/`scalarlejoinsel` for selectivity
/// estimation. As with `derive_pg_eq`, these declarations assume the Rust comparison is
/// referentially transparent and consistent with the other comparison operators.
pub fn derive_pg_le(name: &Ident, path: &proc_macro2::TokenStream) -> proc_macro2::TokenStream {
    let fn_name = format!("{name}_le").to_lowercase();
    let pg_name = Ident::new(&fn_name, name.span());
    quote! {
        #[allow(non_snake_case)]
        #[::pgrx::pgrx_macros::pg_operator(immutable, parallel_safe)]
        #[::pgrx::pgrx_macros::opname(<=)]
        #[::pgrx::pgrx_macros::negator(>)]
        #[::pgrx::pgrx_macros::commutator(>=)]
        #[::pgrx::pgrx_macros::restrict(scalarlesel)]
        #[::pgrx::pgrx_macros::join(scalarlejoinsel)]
        fn #pg_name(left: #path, right: #path) -> bool {
            left <= right
        }
    }
}

/// Generate a Postgres `>=` operator backed by Rust's `>=` (`PartialOrd::ge`).
///
/// The emitted `<name>_ge` function is declared `immutable` and `parallel_safe`, is negated
/// by `<`, commutes with `<=`, and uses `scalargesel`/`scalargejoinsel` for selectivity
/// estimation. As with `derive_pg_eq`, these declarations assume the Rust comparison is
/// referentially transparent and consistent with the other comparison operators.
pub fn derive_pg_ge(name: &Ident, path: &proc_macro2::TokenStream) -> proc_macro2::TokenStream {
    let fn_name = format!("{name}_ge").to_lowercase();
    let pg_name = Ident::new(&fn_name, name.span());
    quote! {
        #[allow(non_snake_case)]
        #[::pgrx::pgrx_macros::pg_operator(immutable, parallel_safe)]
        #[::pgrx::pgrx_macros::opname(>=)]
        #[::pgrx::pgrx_macros::negator(<)]
        #[::pgrx::pgrx_macros::commutator(<=)]
        #[::pgrx::pgrx_macros::restrict(scalargesel)]
        #[::pgrx::pgrx_macros::join(scalargejoinsel)]
        fn #pg_name(left: #path, right: #path) -> bool {
            left >= right
        }
    }
}

/// Generate a `<name>_cmp` comparison function backed by Rust's `Ord::cmp`.
///
/// The emitted function is an `immutable`, `parallel_safe` `pg_extern` (not an operator). It
/// returns the `Ordering` cast to `i32`, i.e. `-1` (less), `0` (equal) or `1` (greater).
/// It is presumably consumed by the btree operator class that `PostgresOrd` describes;
/// confirm against the entity-graph code.
pub fn derive_pg_cmp(name: &Ident, path: &proc_macro2::TokenStream) -> proc_macro2::TokenStream {
    let fn_name = format!("{name}_cmp").to_lowercase();
    let pg_name = Ident::new(&fn_name, name.span());
    quote! {
        #[allow(non_snake_case)]
        #[::pgrx::pgrx_macros::pg_extern(immutable, parallel_safe)]
        fn #pg_name(left: #path, right: #path) -> i32 {
            ::core::cmp::Ord::cmp(&left, &right) as i32
        }
    }
}

/// Derive a Postgres hash support function from Rust's `Hash` implementation.
///
/// The emitted `<name>_hash` function is an `immutable`, `parallel_safe` `pg_extern` (not an
/// operator). It requires the argument type to implement both `Hash` and `Eq`. It hashes the
/// value with `::pgrx::misc::pgrx_seahash` and converts the result to the `i32` it returns
/// with an `as` cast. That cast never fails: if `pgrx_seahash` yields a wider integer
/// (presumably `u64`), the high bits are simply discarded, which is harmless for a hash.
///
/// # Hash must agree with Eq
///
/// To quote the std documentation:
///
/// "When implementing both Hash and Eq, it is important that the following property holds:
/// ```text
/// k1 == k2 -> hash(k1) == hash(k2)
/// ```
/// In other words, if two keys are equal, their hashes must also be equal. HashMap and HashSet both rely on this behavior."
///
/// Postgres relies on the same property. A hash join only ever compares two values with `=`
/// when they hash to the same code. Two equal values that hash differently are therefore
/// silently never matched.
///
/// The converse does not hold: equal hashes say nothing about equality. That is why `=` must
/// still be evaluated for every pair whose hashes collide.
pub fn derive_pg_hash(name: &Ident, path: &proc_macro2::TokenStream) -> proc_macro2::TokenStream {
    let pg_name = Ident::new(&format!("{}_hash", name).to_lowercase(), name.span());
    quote! {
        #[allow(non_snake_case)]
        #[::pgrx::pgrx_macros::pg_extern(immutable, parallel_safe)]
        fn #pg_name(value: #path) -> i32
        where
            #path: ::core::hash::Hash + ::core::cmp::Eq,
        {
            ::pgrx::misc::pgrx_seahash(&value) as i32
        }
    }
}