1use crate::sql_entity_graph::{NameMacro, PositioningRef};
11use proc_macro2::{TokenStream, TokenTree};
12use quote::{format_ident, quote, ToTokens, TokenStreamExt};
13use std::collections::HashSet;
14use syn::{GenericArgument, PathArguments, Type, TypeParamBound};
15
16pub mod rewriter;
17pub mod sql_entity_graph;
18
19#[doc(hidden)]
20pub mod __reexports {
21 pub use eyre;
22 pub mod std {
24 pub mod collections {
25 pub use std::collections::HashSet;
26 }
27 }
28}
29
30#[derive(Debug, Hash, Eq, PartialEq, Clone, PartialOrd, Ord)]
31pub enum ExternArgs {
32 CreateOrReplace,
33 Immutable,
34 Strict,
35 Stable,
36 Volatile,
37 Raw,
38 NoGuard,
39 ParallelSafe,
40 ParallelUnsafe,
41 ParallelRestricted,
42 Error(String),
43 Schema(String),
44 Name(String),
45 Cost(String),
46 Requires(Vec<PositioningRef>),
47}
48
49impl core::fmt::Display for ExternArgs {
50 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
51 match self {
52 ExternArgs::CreateOrReplace => write!(f, "CREATE OR REPLACE"),
53 ExternArgs::Immutable => write!(f, "IMMUTABLE"),
54 ExternArgs::Strict => write!(f, "STRICT"),
55 ExternArgs::Stable => write!(f, "STABLE"),
56 ExternArgs::Volatile => write!(f, "VOLATILE"),
57 ExternArgs::Raw => Ok(()),
58 ExternArgs::ParallelSafe => write!(f, "PARALLEL SAFE"),
59 ExternArgs::ParallelUnsafe => write!(f, "PARALLEL UNSAFE"),
60 ExternArgs::ParallelRestricted => write!(f, "PARALLEL RESTRICTED"),
61 ExternArgs::Error(_) => Ok(()),
62 ExternArgs::NoGuard => Ok(()),
63 ExternArgs::Schema(_) => Ok(()),
64 ExternArgs::Name(_) => Ok(()),
65 ExternArgs::Cost(cost) => write!(f, "COST {}", cost),
66 ExternArgs::Requires(_) => Ok(()),
67 }
68 }
69}
70
71impl ToTokens for ExternArgs {
72 fn to_tokens(&self, tokens: &mut TokenStream) {
73 match self {
74 ExternArgs::CreateOrReplace => tokens.append(format_ident!("CreateOrReplace")),
75 ExternArgs::Immutable => tokens.append(format_ident!("Immutable")),
76 ExternArgs::Strict => tokens.append(format_ident!("Strict")),
77 ExternArgs::Stable => tokens.append(format_ident!("Stable")),
78 ExternArgs::Volatile => tokens.append(format_ident!("Volatile")),
79 ExternArgs::Raw => tokens.append(format_ident!("Raw")),
80 ExternArgs::NoGuard => tokens.append(format_ident!("NoGuard")),
81 ExternArgs::ParallelSafe => tokens.append(format_ident!("ParallelSafe")),
82 ExternArgs::ParallelUnsafe => tokens.append(format_ident!("ParallelUnsafe")),
83 ExternArgs::ParallelRestricted => tokens.append(format_ident!("ParallelRestricted")),
84 ExternArgs::Error(_s) => {
85 tokens.append_all(
86 quote! {
87 Error(String::from("#_s"))
88 }
89 .to_token_stream(),
90 );
91 }
92 ExternArgs::Schema(_s) => {
93 tokens.append_all(
94 quote! {
95 Schema(String::from("#_s"))
96 }
97 .to_token_stream(),
98 );
99 }
100 ExternArgs::Name(_s) => {
101 tokens.append_all(
102 quote! {
103 Name(String::from("#_s"))
104 }
105 .to_token_stream(),
106 );
107 }
108 ExternArgs::Cost(_s) => {
109 tokens.append_all(
110 quote! {
111 Cost(String::from("#_s"))
112 }
113 .to_token_stream(),
114 );
115 }
116 ExternArgs::Requires(items) => {
117 tokens.append_all(
118 quote! {
119 Requires(vec![#(#items),*])
120 }
121 .to_token_stream(),
122 );
123 }
124 }
125 }
126}
127
/// Function-level arguments.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum FunctionArgs {
    /// A search path given as a string.
    /// NOTE(review): presumably the function's `search_path` setting. Confirm at the use site.
    SearchPath(String),
}
132
/// Classification of a (return) type, produced by `categorize_type` and
/// `categorize_trait_bound`.
#[derive(Debug)]
pub enum CategorizedType {
    /// `Iterator<Item = T>`: `T` rendered as a string, or one string per element when `T` is
    /// a tuple.
    Iterator(Vec<String>),
    /// An `Iterator` wrapped in `Option<...>`, with the same payload as `Iterator`.
    OptionalIterator(Vec<String>),
    /// A non-empty tuple: each element type rendered as a string.
    Tuple(Vec<String>),
    /// Any other type, including `()`.
    Default,
}
140
141pub fn parse_extern_attributes(attr: TokenStream) -> HashSet<ExternArgs> {
142 let mut args = HashSet::<ExternArgs>::new();
143 let mut itr = attr.into_iter();
144 while let Some(t) = itr.next() {
145 match t {
146 TokenTree::Group(g) => {
147 for arg in parse_extern_attributes(g.stream()).into_iter() {
148 args.insert(arg);
149 }
150 }
151 TokenTree::Ident(i) => {
152 let name = i.to_string();
153 match name.as_str() {
154 "create_or_replace" => args.insert(ExternArgs::CreateOrReplace),
155 "immutable" => args.insert(ExternArgs::Immutable),
156 "strict" => args.insert(ExternArgs::Strict),
157 "stable" => args.insert(ExternArgs::Stable),
158 "volatile" => args.insert(ExternArgs::Volatile),
159 "raw" => args.insert(ExternArgs::Raw),
160 "no_guard" => args.insert(ExternArgs::NoGuard),
161 "parallel_safe" => args.insert(ExternArgs::ParallelSafe),
162 "parallel_unsafe" => args.insert(ExternArgs::ParallelUnsafe),
163 "parallel_restricted" => args.insert(ExternArgs::ParallelRestricted),
164 "error" => {
165 let _punc = itr.next().unwrap();
166 let literal = itr.next().unwrap();
167 let message = literal.to_string();
168 let message = unescape::unescape(&message).expect("failed to unescape");
169
170 let message = message[1..message.len() - 1].to_string();
172 args.insert(ExternArgs::Error(message.to_string()))
173 }
174 "schema" => {
175 let _punc = itr.next().unwrap();
176 let literal = itr.next().unwrap();
177 let schema = literal.to_string();
178 let schema = unescape::unescape(&schema).expect("failed to unescape");
179
180 let schema = schema[1..schema.len() - 1].to_string();
182 args.insert(ExternArgs::Schema(schema.to_string()))
183 }
184 "name" => {
185 let _punc = itr.next().unwrap();
186 let literal = itr.next().unwrap();
187 let name = literal.to_string();
188 let name = unescape::unescape(&name).expect("failed to unescape");
189
190 let name = name[1..name.len() - 1].to_string();
192 args.insert(ExternArgs::Name(name.to_string()))
193 }
194 "sql" => {
196 let _punc = itr.next().unwrap();
197 let _value = itr.next().unwrap();
198 false
199 }
200 _ => false,
201 };
202 }
203 TokenTree::Punct(_) => {}
204 TokenTree::Literal(_) => {}
205 }
206 }
207 args
208}
209
210pub fn categorize_type(ty: &Type) -> CategorizedType {
211 match ty {
212 Type::Path(ty) => {
213 let segments = &ty.path.segments;
214 for segment in segments {
215 let segment_ident = segment.ident.to_string();
216 if segment_ident == "Option" {
217 match &segment.arguments {
218 PathArguments::AngleBracketed(a) => match a.args.first().unwrap() {
219 GenericArgument::Type(ty) => {
220 let result = categorize_type(ty);
221
222 return match result {
223 CategorizedType::Iterator(i) => {
224 CategorizedType::OptionalIterator(i)
225 }
226
227 _ => result,
228 };
229 }
230 _ => {
231 break;
232 }
233 },
234 _ => {
235 break;
236 }
237 }
238 }
239 if segment_ident == "Box" {
240 match &segment.arguments {
241 PathArguments::AngleBracketed(a) => match a.args.first().unwrap() {
242 GenericArgument::Type(ty) => return categorize_type(ty),
243 _ => {
244 break;
245 }
246 },
247 _ => {
248 break;
249 }
250 }
251 }
252 }
253 CategorizedType::Default
254 }
255 Type::TraitObject(trait_object) => {
256 for bound in &trait_object.bounds {
257 return categorize_trait_bound(bound);
258 }
259
260 panic!("Unsupported trait return type");
261 }
262 Type::ImplTrait(ty) => {
263 for bound in &ty.bounds {
264 return categorize_trait_bound(bound);
265 }
266
267 panic!("Unsupported trait return type");
268 }
269 Type::Tuple(tuple) => {
270 if tuple.elems.len() == 0 {
271 CategorizedType::Default
272 } else {
273 let mut types = Vec::new();
274 for ty in &tuple.elems {
275 types.push(quote! {#ty}.to_string())
276 }
277 CategorizedType::Tuple(types)
278 }
279 }
280 _ => CategorizedType::Default,
281 }
282}
283
284pub fn categorize_trait_bound(bound: &TypeParamBound) -> CategorizedType {
285 match bound {
286 TypeParamBound::Trait(trait_bound) => {
287 let segments = &trait_bound.path.segments;
288
289 let mut ident = String::new();
290 for segment in segments {
291 if !ident.is_empty() {
292 ident.push_str("::")
293 }
294 ident.push_str(segment.ident.to_string().as_str());
295 }
296
297 match ident.as_str() {
298 "Iterator" | "std::iter::Iterator" => {
299 let segment = segments.last().unwrap();
300 match &segment.arguments {
301 PathArguments::None => {
302 panic!("Iterator must have at least one generic type")
303 }
304 PathArguments::Parenthesized(_) => {
305 panic!("Unsupported arguments to Iterator")
306 }
307 PathArguments::AngleBracketed(a) => {
308 let args = &a.args;
309 if args.len() > 1 {
310 panic!(
311 "Only one generic type is supported when returning an Iterator"
312 )
313 }
314
315 match args.first().unwrap() {
316 GenericArgument::Binding(b) => {
317 let mut types = Vec::new();
318 let ty = &b.ty;
319 match ty {
320 Type::Tuple(tuple) => {
321 for e in &tuple.elems {
322 types.push(quote! {#e}.to_string());
323 }
324 },
325 _ => {
326 types.push(quote! {#ty}.to_string())
327 }
328 }
329
330 return CategorizedType::Iterator(types);
331 }
332 _ => panic!("Only binding type arguments are supported when returning an Iterator")
333 }
334 }
335 }
336 }
337 _ => panic!("Unsupported trait return type"),
338 }
339 }
340 TypeParamBound::Lifetime(_) => {
341 panic!("Functions can't return traits with lifetime bounds")
342 }
343 }
344}
345
346pub fn staticize_lifetimes_in_type_path(value: syn::TypePath) -> syn::TypePath {
347 let mut ty = syn::Type::Path(value);
348 staticize_lifetimes(&mut ty);
349 match ty {
350 syn::Type::Path(type_path) => type_path,
351
352 _ => panic!("not a TypePath"),
354 }
355}
356
357pub fn staticize_lifetimes(value: &mut syn::Type) {
358 match value {
359 syn::Type::Path(type_path) => {
360 for segment in &mut type_path.path.segments {
361 match &mut segment.arguments {
362 syn::PathArguments::AngleBracketed(bracketed) => {
363 for arg in &mut bracketed.args {
364 match arg {
365 syn::GenericArgument::Lifetime(lifetime) => {
367 lifetime.ident =
368 syn::Ident::new("static", lifetime.ident.span());
369 }
370
371 syn::GenericArgument::Type(ty) => staticize_lifetimes(ty),
373 syn::GenericArgument::Binding(binding) => {
374 staticize_lifetimes(&mut binding.ty)
375 }
376 syn::GenericArgument::Constraint(constraint) => {
377 for bound in constraint.bounds.iter_mut() {
378 match bound {
379 syn::TypeParamBound::Lifetime(lifetime) => {
380 lifetime.ident =
381 syn::Ident::new("static", lifetime.ident.span())
382 }
383 _ => {}
384 }
385 }
386 }
387
388 _ => {}
390 }
391 }
392 }
393 _ => {}
394 }
395 }
396 }
397
398 syn::Type::Reference(type_ref) => match &mut type_ref.lifetime {
399 Some(ref mut lifetime) => {
400 lifetime.ident = syn::Ident::new("static", lifetime.ident.span());
401 }
402 this @ None => *this = Some(syn::parse_quote!('static)),
403 },
404
405 syn::Type::Tuple(type_tuple) => {
406 for elem in &mut type_tuple.elems {
407 staticize_lifetimes(elem);
408 }
409 }
410
411 syn::Type::Macro(type_macro) => {
412 let mac = &type_macro.mac;
413 if let Some(archetype) = mac.path.segments.last() {
414 match archetype.ident.to_string().as_str() {
415 "name" => {
416 if let Ok(out) = mac.parse_body::<NameMacro>() {
417 if let Ok(ident) = syn::parse_str::<TokenStream>(&out.ident) {
422 let mut ty = out.used_ty.resolved_ty;
423
424 staticize_lifetimes(&mut ty);
426 type_macro.mac = syn::parse_quote! {name!(#ident, #ty)};
427 }
428 }
429 }
430 _ => {}
431 }
432 }
433 }
434 _ => {}
435 }
436}
437
438pub fn anonymize_lifetimes_in_type_path(value: syn::TypePath) -> syn::TypePath {
439 let mut ty = syn::Type::Path(value);
440 anonymize_lifetimes(&mut ty);
441 match ty {
442 syn::Type::Path(type_path) => type_path,
443
444 _ => panic!("not a TypePath"),
446 }
447}
448
449pub fn anonymize_lifetimes(value: &mut syn::Type) {
450 match value {
451 syn::Type::Path(type_path) => {
452 for segment in &mut type_path.path.segments {
453 match &mut segment.arguments {
454 syn::PathArguments::AngleBracketed(bracketed) => {
455 for arg in &mut bracketed.args {
456 match arg {
457 syn::GenericArgument::Lifetime(lifetime) => {
459 lifetime.ident = syn::Ident::new("_", lifetime.ident.span());
460 }
461
462 syn::GenericArgument::Type(ty) => anonymize_lifetimes(ty),
464 syn::GenericArgument::Binding(binding) => {
465 anonymize_lifetimes(&mut binding.ty)
466 }
467 syn::GenericArgument::Constraint(constraint) => {
468 for bound in constraint.bounds.iter_mut() {
469 match bound {
470 syn::TypeParamBound::Lifetime(lifetime) => {
471 lifetime.ident =
472 syn::Ident::new("_", lifetime.ident.span())
473 }
474 _ => {}
475 }
476 }
477 }
478
479 _ => {}
481 }
482 }
483 }
484 _ => {}
485 }
486 }
487 }
488
489 syn::Type::Reference(type_ref) => {
490 if let Some(lifetime) = type_ref.lifetime.as_mut() {
491 lifetime.ident = syn::Ident::new("_", lifetime.ident.span());
492 }
493 }
494
495 syn::Type::Tuple(type_tuple) => {
496 for elem in &mut type_tuple.elems {
497 anonymize_lifetimes(elem);
498 }
499 }
500
501 _ => {}
502 }
503}
504
505const POSTGRES_IDENTIFIER_MAX_LEN: usize = 64;
509
510pub fn ident_is_acceptable_to_postgres(ident: &syn::Ident) -> Result<(), syn::Error> {
520 let ident_string = ident.to_string();
521 if ident_string.len() >= POSTGRES_IDENTIFIER_MAX_LEN {
522 return Err(syn::Error::new(
523 ident.span(),
524 &format!(
525 "Identifier `{}` was {} characters long, PostgreSQL will truncate identifiers with less than {POSTGRES_IDENTIFIER_MAX_LEN} characters, opt for an identifier which Postgres won't truncate",
526 ident,
527 ident_string.len(),
528 )
529 ));
530 }
531
532 Ok(())
533}
534
#[cfg(test)]
mod tests {
    use crate::{parse_extern_attributes, ExternArgs};
    use std::collections::HashSet;
    use std::str::FromStr;

    /// Tokenizes `s` and runs [`parse_extern_attributes`] on it.
    fn parse(s: &str) -> HashSet<ExternArgs> {
        parse_extern_attributes(proc_macro2::TokenStream::from_str(s).unwrap())
    }

    /// Collects the expected arguments into a set for comparison.
    fn set_of(args: Vec<ExternArgs>) -> HashSet<ExternArgs> {
        args.into_iter().collect()
    }

    #[test]
    fn parse_args() {
        let s = "error = \"syntax error at or near \\\"THIS\\\"\"";
        let ts = proc_macro2::TokenStream::from_str(s).unwrap();

        let args = parse_extern_attributes(ts);
        assert!(args.contains(&ExternArgs::Error("syntax error at or near \"THIS\"".to_string())));
    }

    #[test]
    fn parse_flags() {
        assert_eq!(
            parse("immutable, strict, parallel_safe, no_guard"),
            set_of(vec![
                ExternArgs::Immutable,
                ExternArgs::Strict,
                ExternArgs::ParallelSafe,
                ExternArgs::NoGuard,
            ])
        );
    }

    #[test]
    fn parse_nested_group() {
        assert_eq!(
            parse("(stable, create_or_replace)"),
            set_of(vec![ExternArgs::Stable, ExternArgs::CreateOrReplace])
        );
    }

    #[test]
    fn parse_schema_and_name() {
        assert_eq!(
            parse("schema = \"my_schema\", name = \"my_fn\""),
            set_of(vec![
                ExternArgs::Schema("my_schema".to_string()),
                ExternArgs::Name("my_fn".to_string()),
            ])
        );
    }

    #[test]
    fn sql_value_is_skipped() {
        // The token after `sql =` must not be picked up as a flag.
        assert_eq!(parse("sql = immutable, strict"), set_of(vec![ExternArgs::Strict]));
    }

    #[test]
    fn unknown_tokens_are_ignored() {
        assert!(parse("not_an_option, 42, +").is_empty());
    }
}