inline_sql_macros/
util.rs

1use syn::spanned::Spanned;
2
3pub fn type_strip_paren(typ: &syn::Type) -> &syn::Type {
4	let mut typ = typ;
5	loop {
6		match typ {
7			syn::Type::Paren(inner) => typ = &inner.elem,
8			syn::Type::Group(inner) => typ = &inner.elem,
9			x => return x,
10		}
11	}
12}
13
14pub fn type_as_path(typ: &syn::Type) -> Option<&syn::Path> {
15	match type_strip_paren(typ) {
16		syn::Type::Path(x) => Some(&x.path),
17		_ => None,
18	}
19}
20
21pub fn type_strip_result(typ: &syn::Type) -> Result<&syn::Type, syn::Error> {
22	fn short_error<S: syn::spanned::Spanned + quote::ToTokens>(spanned: S) -> syn::Error {
23		syn::Error::new_spanned(spanned, "The function must return a `Result<_, _>`")
24	}
25
26	fn long_error<S: syn::spanned::Spanned + quote::ToTokens>(spanned: S, message: &'static str) -> syn::Error {
27		let note = concat!(
28			"Note: The function must return a `Result<_, _>`.\n",
29			"If you are using a type alias, the macro can not resolve it.\n",
30			"Replace the alias with the actual type name.",
31		);
32		syn::Error::new_spanned(spanned, format!("{message}\n\n{note}"))
33	}
34
35	let path = type_as_path(typ)
36		.ok_or_else(|| short_error(typ))?;
37	let segment = path.segments.last()
38		.ok_or_else(|| short_error(path))?;
39
40	if segment.ident != "Result" {
41		return Err(long_error(segment, "Expected `Result<_, _>`"));
42	}
43
44	let arguments = match &segment.arguments {
45		syn::PathArguments::AngleBracketed(arguments) => arguments,
46		_ => return Err(long_error(segment, "Expected `Result<_, _>`")),
47	};
48
49	if arguments.args.is_empty() || arguments.args.len() > 2 {
50		return Err(long_error(segment, "Expected `Result<_, _>`"));
51	}
52
53	if let syn::GenericArgument::Type(typ) = &arguments.args[0] {
54		Ok(typ)
55	} else {
56		Err(long_error(&arguments.args[0], "Expected a type argument"))
57	}
58}
59
60pub fn type_result_args(typ: &syn::Type) -> Option<&syn::AngleBracketedGenericArguments> {
61	let path = type_as_path(typ)?;
62	let segment = path.segments.last()?;
63	if segment.ident != "Result" {
64		return None;
65	}
66
67	match &segment.arguments {
68		syn::PathArguments::AngleBracketed(arguments) => Some(arguments),
69		_ => None,
70	}
71}
72
73pub fn type_strip_vec(typ: &syn::Type) -> Option<&syn::Type> {
74	let path = type_as_path(typ)?;
75
76	let candidates = &[
77		["Vec"].as_slice(),
78		["std", "vec", "Vec"].as_slice(),
79		["alloc", "vec", "Vec"].as_slice(),
80		["", "std", "vec", "Vec"].as_slice(),
81		["", "alloc", "vec", "Vec"].as_slice(),
82	];
83
84	if !path_is_one_of(path, candidates) {
85		return None;
86	}
87
88	let last = path.segments.last()?;
89	let arguments = match &last.arguments {
90		syn::PathArguments::AngleBracketed(args) => args,
91		_ => return None,
92	};
93
94	match &arguments.args[0] {
95		syn::GenericArgument::Type(x) => Some(x),
96		_ => None,
97	}
98}
99
100pub fn type_strip_option(typ: &syn::Type) -> Option<&syn::Type> {
101	let path = type_as_path(typ)?;
102
103	let candidates = &[
104		["Option"].as_slice(),
105		["std", "option", "Option"].as_slice(),
106		["core", "option", "Option"].as_slice(),
107		["", "std", "option", "Option"].as_slice(),
108		["", "core", "option", "Option"].as_slice(),
109	];
110
111	if !path_is_one_of(path, candidates) {
112		return None;
113	}
114
115	let last = path.segments.last()?;
116	let arguments = match &last.arguments {
117		syn::PathArguments::AngleBracketed(args) => args,
118		_ => return None,
119	};
120
121	match &arguments.args[0] {
122		syn::GenericArgument::Type(x) => Some(x),
123		_ => None,
124	}
125}
126
127pub fn type_is_row_stream(typ: &syn::Type) -> bool {
128	let candidates = &[
129		["RowStream"].as_slice(),
130		["tokio_postgres", "RowStream"].as_slice(),
131		["", "tokio_postgres", "RowStream"].as_slice(),
132		["RowIter"].as_slice(),
133		["postgres", "RowIter"].as_slice(),
134		["", "postgres", "RowIter"].as_slice(),
135	];
136
137	if let Some(path) = type_as_path(typ) {
138		path_is_one_of(path, candidates)
139	} else {
140		false
141	}
142}
143
144pub fn type_is_unit(typ: &syn::Type) -> bool {
145	if let syn::Type::Tuple(tuple) = type_strip_paren(typ) {
146		tuple.elems.is_empty()
147	} else {
148		false
149	}
150}
151
152pub fn type_is_u64(typ: &syn::Type) -> bool {
153	match type_as_path(typ) {
154		None => false,
155		Some(path) => path.is_ident("u64"),
156	}
157}
158
159pub fn path_is(path: &syn::Path, components: &[&str]) -> bool {
160	if path.segments.len() != components.len() {
161		return false;
162	}
163
164	for (segment, component) in path.segments.iter().zip(components.iter()) {
165		if segment.ident != component {
166			return false;
167		}
168	}
169
170	true
171}
172
173pub fn path_is_one_of(path: &syn::Path, candidates: &[&[&str]]) -> bool {
174	candidates.iter().any(|candidate| path_is(path, candidate))
175}
176
177pub fn type_span(typ: &syn::Type) -> proc_macro2::Span {
178	if let Some(path) = type_as_path(typ) {
179		path.segments.last().span()
180	} else {
181		typ.span()
182	}
183}
184
185pub fn type_result_ok(typ: &syn::Type) -> Option<&syn::Type> {
186	let args = type_result_args(typ)?;
187	if args.args.is_empty() {
188		return None;
189	}
190
191	if let syn::GenericArgument::Type(typ) = &args.args[0] {
192		Some(typ)
193	} else {
194		None
195	}
196}
197
198pub fn type_result_err(typ: &syn::Type) -> Option<&syn::Type> {
199	let args = type_result_args(typ)?;
200	if args.args.len() != 2 {
201		return None;
202	}
203
204	if let syn::GenericArgument::Type(typ) = &args.args[1] {
205		Some(typ)
206	} else {
207		None
208	}
209}
210
211pub fn return_type_ok_span(signature: &syn::Signature) -> proc_macro2::Span {
212	match &signature.output {
213		syn::ReturnType::Default => signature.ident.span(),
214		syn::ReturnType::Type(_, typ) => {
215			if let Some(typ) = type_result_ok(typ) {
216				type_span(typ)
217			} else {
218				type_span(typ)
219			}
220		}
221	}
222}
223
224pub fn return_type_err_span(signature: &syn::Signature) -> proc_macro2::Span {
225	match &signature.output {
226		syn::ReturnType::Default => signature.ident.span(),
227		syn::ReturnType::Type(_, typ) => {
228			if let Some(typ) = type_result_err(typ) {
229				type_span(typ)
230			} else {
231				type_span(typ)
232			}
233		}
234	}
235}