edgedb_codegen_core/
lib.rs

1use std::sync::Arc;
2
3use check_keyword::CheckKeyword;
4use edgedb_protocol::codec::CAL_DATE_DURATION;
5use edgedb_protocol::codec::CAL_LOCAL_DATE;
6use edgedb_protocol::codec::CAL_LOCAL_DATETIME;
7use edgedb_protocol::codec::CAL_LOCAL_TIME;
8use edgedb_protocol::codec::CAL_RELATIVE_DURATION;
9use edgedb_protocol::codec::CFG_MEMORY;
10use edgedb_protocol::codec::PGVECTOR_VECTOR;
11use edgedb_protocol::codec::STD_BIGINT;
12use edgedb_protocol::codec::STD_BOOL;
13use edgedb_protocol::codec::STD_BYTES;
14use edgedb_protocol::codec::STD_DATETIME;
15use edgedb_protocol::codec::STD_DECIMAL;
16use edgedb_protocol::codec::STD_DURATION;
17use edgedb_protocol::codec::STD_FLOAT32;
18use edgedb_protocol::codec::STD_FLOAT64;
19use edgedb_protocol::codec::STD_INT16;
20use edgedb_protocol::codec::STD_INT32;
21use edgedb_protocol::codec::STD_INT64;
22use edgedb_protocol::codec::STD_JSON;
23use edgedb_protocol::codec::STD_STR;
24use edgedb_protocol::codec::STD_UUID;
25use edgedb_protocol::common::Capabilities;
26use edgedb_protocol::common::Cardinality;
27use edgedb_protocol::common::CompilationOptions;
28use edgedb_protocol::common::IoFormat;
29use edgedb_protocol::descriptors::Descriptor;
30use edgedb_protocol::descriptors::ShapeElement;
31use edgedb_protocol::descriptors::TupleElement;
32use edgedb_protocol::descriptors::TypePos;
33use edgedb_protocol::descriptors::Typedesc;
34use edgedb_protocol::model::Uuid;
35use edgedb_protocol::server_message::CommandDataDescription1;
36use edgedb_tokio::create_client;
37use edgedb_tokio::raw::Pool;
38use edgedb_tokio::raw::PoolState;
39use edgedb_tokio::Builder;
40use heck::ToPascalCase;
41use heck::ToSnakeCase;
42use proc_macro2::TokenStream;
43use quote::format_ident;
44use quote::quote;
45use syn::punctuated::Punctuated;
46use syn::Token;
47use tokio::runtime::Runtime;
48use typed_builder::TypedBuilder;
49
50pub use crate::constants::*;
51pub use crate::errors::*;
52pub use crate::utils::*;
53
54mod constants;
55mod errors;
56mod utils;
57
58/// Get the descriptor asynchronously.
59pub async fn get_descriptor(query: &str) -> Result<CommandDataDescription1> {
60	let builder = Builder::new().build_env().await?;
61	let state = Arc::new(PoolState::default());
62	let pool = Pool::new(&builder);
63	let mut pool_connection = pool.acquire().await?;
64	let connection = pool_connection.inner();
65	let allow_capabilities = Capabilities::MODIFICATIONS | Capabilities::DDL;
66	let flags = CompilationOptions {
67		implicit_limit: None,
68		implicit_typenames: false,
69		implicit_typeids: false,
70		explicit_objectids: true,
71		allow_capabilities,
72		io_format: IoFormat::Binary,
73		expected_cardinality: Cardinality::Many,
74	};
75
76	Ok(connection.parse(&flags, query, &state).await?)
77}
78
79/// Get the descriptor synchronously.
80pub fn get_descriptor_sync(query: &str) -> Result<CommandDataDescription1> {
81	let rt = Runtime::new()?;
82	let descriptor = rt.block_on(async { get_descriptor(query).await })?;
83
84	Ok(descriptor)
85}
86
87pub fn generate_rust_from_query(
88	descriptor: &CommandDataDescription1,
89	name: &str,
90	query: &str,
91) -> Result<TokenStream> {
92	let input_ident = format_ident!("{INPUT_NAME}");
93	let output_ident = format_ident!("{OUTPUT_NAME}");
94	let props_ident = format_ident!("{PROPS_NAME}");
95	let query_ident = format_ident!("{QUERY_NAME}");
96	let query_prop_ident = format_ident!("{QUERY_PROP_NAME}");
97	let transaction_ident = format_ident!("{TRANSACTION_NAME}");
98	let transaction_prop_ident = format_ident!("{TRANSACTION_PROP_NAME}");
99	let module_name: syn::Ident = format_ident!("{}", name.to_snake_case());
100	let input = descriptor.input.decode()?;
101	let output = descriptor.output.decode()?;
102	let mut tokens: TokenStream = TokenStream::new();
103
104	explore_descriptor(
105		ExploreDescriptorProps::builder()
106			.typedesc(&input)
107			.is_input()
108			.is_root()
109			.descriptor(input.root())
110			.root_name(INPUT_NAME)
111			.build(),
112		&mut tokens,
113	)?;
114	explore_descriptor(
115		ExploreDescriptorProps::builder()
116			.typedesc(&output)
117			.is_root()
118			.descriptor(output.root())
119			.root_name(OUTPUT_NAME)
120			.build(),
121		&mut tokens,
122	)?;
123
124	let query_method = match descriptor.result_cardinality {
125		Cardinality::NoResult => quote!(execute),
126		Cardinality::AtMostOne => quote!(query_single),
127		Cardinality::One => quote!(query_required_single),
128		Cardinality::Many | Cardinality::AtLeastOne => quote!(query),
129	};
130
131	let mut query_props = vec![quote!(#query_prop_ident: &#EXPORTS_IDENT::edgedb_tokio::Client)];
132	let mut transaction_props =
133		vec![quote!(#transaction_prop_ident: &mut #EXPORTS_IDENT::edgedb_tokio::Transaction)];
134	let args = vec![
135		quote!(#QUERY_CONSTANT),
136		input.root().map_or(quote!(&()), |_| quote!(#props_ident)),
137	];
138	let inner_return = output.root().map_or(quote!(()), |_| quote!(#output_ident));
139	let returns = wrap_token_with_cardinality(Some(descriptor.result_cardinality), inner_return);
140
141	if input.root().is_some() {
142		query_props.push(quote!(#props_ident: &#input_ident));
143		transaction_props.push(quote!(#props_ident: &#input_ident));
144	}
145
146	let token_stream = quote! {
147		pub mod #module_name {
148			use ::edgedb_codegen::exports as #EXPORTS_IDENT;
149
150			/// Execute the desired query.
151			#[cfg(feature = "query")]
152			pub async fn #query_ident(#(#query_props),*) -> core::result::Result<#returns, #EXPORTS_IDENT::edgedb_errors::Error> {
153				#query_prop_ident.#query_method(#(#args),*).await
154			}
155
156			/// Compose the query as part of a larger transaction.
157			#[cfg(feature = "query")]
158			pub async fn #transaction_ident(#(#transaction_props),*) -> core::result::Result<#returns, #EXPORTS_IDENT::edgedb_errors::Error> {
159				#transaction_prop_ident.#query_method(#(#args),*).await
160			}
161
162			#tokens
163
164			/// The original query string provided to the macro. Can be reused in your codebase.
165			pub const #QUERY_CONSTANT: &str = #query;
166		}
167	};
168
169	Ok(token_stream)
170}
171
172fn wrap_token_with_cardinality(
173	cardinality: Option<Cardinality>,
174	token: TokenStream,
175) -> TokenStream {
176	let Some(cardinality) = cardinality else {
177		return token;
178	};
179
180	match cardinality {
181		Cardinality::NoResult | Cardinality::AtMostOne => quote!(Option<#token>),
182		Cardinality::One => token,
183		Cardinality::Many | Cardinality::AtLeastOne => quote!(Vec<#token>),
184	}
185}
186
187#[derive(Debug, TypedBuilder)]
188struct ExploreDescriptorProps<'a> {
189	typedesc: &'a Typedesc,
190	#[builder(setter(strip_bool(fallback = is_input_bool)))]
191	is_input: bool,
192	#[builder(setter(strip_bool(fallback = is_root_bool)))]
193	is_root: bool,
194	descriptor: Option<&'a Descriptor>,
195	root_name: &'a str,
196}
197
198type PartialExploreDescriptorProps<'a> =
199	ExploreDescriptorPropsBuilder<'a, ((&'a Typedesc,), (bool,), (bool,), (), ())>;
200
201impl<'a> ExploreDescriptorProps<'a> {
202	fn into_props(self) -> PartialExploreDescriptorProps<'a> {
203		let Self {
204			typedesc, is_input, ..
205		} = self;
206
207		Self::builder()
208			.typedesc(typedesc)
209			.is_input_bool(is_input)
210			.is_root_bool(false)
211	}
212}
213
214fn explore_descriptor(
215	props @ ExploreDescriptorProps {
216		typedesc,
217		is_input,
218		is_root,
219		descriptor,
220		root_name,
221	}: ExploreDescriptorProps,
222	tokens: &mut TokenStream,
223) -> Result<Option<TokenStream>> {
224	let root_ident = format_ident!("{root_name}");
225
226	let Some(descriptor) = descriptor else {
227		if is_root {
228			tokens.extend(quote!(pub type #root_ident = ();));
229		}
230
231		return Ok(None);
232	};
233
234	match descriptor {
235		Descriptor::Set(set) => {
236			let set_descriptor = typedesc.get(set.type_pos).ok();
237			let sub_root_name = format!("{root_name}Set");
238			let props = props
239				.into_props()
240				.descriptor(set_descriptor)
241				.root_name(&sub_root_name)
242				.build();
243			let result = explore_descriptor(props, tokens)?.map(|result| quote!(Vec<#result>));
244
245			if is_root {
246				tokens.extend(quote!(pub type #root_ident = #result;));
247				Ok(Some(quote!(#root_ident)))
248			} else {
249				Ok(result)
250			}
251		}
252		Descriptor::ObjectShape(object) => {
253			let result = explore_object_shape_descriptor(
254				StructElement::from_shape(&object.elements),
255				typedesc,
256				root_name,
257				is_input,
258				tokens,
259			)?;
260
261			Ok(result)
262		}
263		Descriptor::BaseScalar(base_scalar) => {
264			let result = uuid_to_token_name(&base_scalar.id);
265
266			if is_root {
267				tokens.extend(quote!(pub type #root_ident = #result;));
268				Ok(Some(quote!(#root_ident)))
269			} else {
270				Ok(Some(result))
271			}
272		}
273		Descriptor::Scalar(scalar) => {
274			let props = props
275				.into_props()
276				.descriptor(typedesc.get(scalar.base_type_pos).ok())
277				.root_name(root_name)
278				.build();
279
280			explore_descriptor(props, tokens)
281		}
282		Descriptor::Tuple(tuple) => {
283			let mut tuple_tokens = Punctuated::<_, Token![,]>::new();
284
285			for (index, element) in tuple.element_types.iter().enumerate() {
286				let sub_root_name = format!("{root_name}{index}");
287				let result = explore_descriptor(
288					ExploreDescriptorProps::builder()
289						.typedesc(typedesc)
290						.is_input_bool(is_input)
291						.descriptor(typedesc.get(*element).ok())
292						.root_name(&sub_root_name)
293						.build(),
294					tokens,
295				)?;
296
297				tuple_tokens.push(result);
298			}
299
300			let result = quote!((#tuple_tokens));
301
302			if is_root {
303				tokens.extend(quote!(pub type #root_ident = #result;));
304				Ok(Some(quote!(#root_ident)))
305			} else {
306				Ok(Some(result))
307			}
308		}
309		Descriptor::NamedTuple(named_tuple) => {
310			let result = explore_object_shape_descriptor(
311				StructElement::from_named_tuple(&named_tuple.elements),
312				typedesc,
313				root_name,
314				is_input,
315				tokens,
316			)?;
317
318			Ok(result)
319		}
320		Descriptor::Array(array) => {
321			let array_descriptor = typedesc.get(array.type_pos).ok();
322			let sub_root_name = format!("{root_name}Array");
323			let props = props
324				.into_props()
325				.descriptor(array_descriptor)
326				.root_name(&sub_root_name)
327				.build();
328			let result = explore_descriptor(props, tokens)?.map(|result| quote!(Vec<#result>));
329
330			if is_root {
331				tokens.extend(quote!(pub type #root_ident = #result;));
332				Ok(Some(quote!(#root_ident)))
333			} else {
334				Ok(result)
335			}
336		}
337		// TODO once `edgedb_protocol` is updated to 2.0 it should be possible to get the enum name.
338		Descriptor::Enumeration(_) => {
339			let result = Some(quote!(String));
340
341			if is_root {
342				tokens.extend(quote!(pub type #root_ident = #result;));
343				Ok(Some(quote!(#root_ident)))
344			} else {
345				Ok(result)
346			}
347		}
348		Descriptor::InputShape(object) => {
349			let result = explore_object_shape_descriptor(
350				StructElement::from_shape(&object.elements),
351				typedesc,
352				root_name,
353				is_input,
354				tokens,
355			)?;
356
357			Ok(result)
358		}
359		Descriptor::Range(range) => {
360			let range_descriptor = typedesc.get(range.type_pos).ok();
361			let sub_root_name = format!("{root_name}Range");
362			let props = props
363				.into_props()
364				.descriptor(range_descriptor)
365				.root_name(&sub_root_name)
366				.build();
367			let result = explore_descriptor(props, tokens)?
368				.map(|result| quote!(#EXPORTS_IDENT::edgedb_protocol::model::Range<#result>));
369
370			if is_root {
371				tokens.extend(quote!(pub type #root_ident = #result;));
372				Ok(Some(quote!(#root_ident)))
373			} else {
374				Ok(result)
375			}
376		}
377		Descriptor::MultiRange(_) => todo!("`multirange` not in the `edgedb_protocol` crate"),
378		Descriptor::TypeAnnotation(_) => todo!("type annotations are not supported"),
379	}
380}
381
382pub fn explore_object_shape_descriptor(
383	elements: Vec<StructElement<'_>>,
384	typedesc: &Typedesc,
385	root_name: &str,
386	is_input: bool,
387	tokens: &mut TokenStream,
388) -> Result<Option<TokenStream>> {
389	let mut impl_named_args = vec![];
390	let mut struct_fields = vec![];
391	let root_ident: syn::Ident = syn::parse_str(root_name)?;
392
393	for element in elements {
394		let descriptor = typedesc.get(element.type_pos()).ok();
395		let name = &element.name();
396		let safe_name = name.to_snake_case().into_safe();
397		let safe_name_ident = format_ident!("{safe_name}");
398		let pascal_name = name.to_pascal_case();
399		let sub_root_name = format!("{root_name}{pascal_name}").into_safe();
400		let sub_props = ExploreDescriptorProps::builder()
401			.typedesc(typedesc)
402			.is_input_bool(is_input)
403			.descriptor(descriptor)
404			.root_name(&sub_root_name)
405			.build();
406		let output = explore_descriptor(sub_props, tokens)?;
407		let output_token = element.wrap(&output);
408		let serde_annotation = (&safe_name != name).then_some(quote!(
409			#[cfg_attr(feature = "serde", serde(rename = #name))]
410		));
411		let builder_fields = {
412			match element.cardinality() {
413				Cardinality::AtMostOne => {
414					let fallback_ident = format_ident!("{safe_name_ident}_opt");
415					Some(quote!(default, setter(into, strip_option(fallback = #fallback_ident))))
416				}
417				Cardinality::One => Some(quote!(setter(into))),
418				Cardinality::Many => Some(quote!(default)),
419				Cardinality::NoResult | Cardinality::AtLeastOne => None,
420			}
421		};
422		let builder_annotation = (is_input && builder_fields.is_some()).then_some(quote!(
423			#[cfg_attr(feature = "builder", builder(#builder_fields))]
424		));
425
426		struct_fields.push(quote! {
427			#serde_annotation
428			#builder_annotation
429			pub #safe_name_ident: #output_token,
430		});
431
432		if is_input {
433			impl_named_args.push(quote!(#name => self.#safe_name_ident.clone(),));
434		}
435	}
436
437	let impl_tokens = is_input.then_some(quote! {
438		impl #EXPORTS_IDENT::edgedb_protocol::query_arg::QueryArgs for #root_ident {
439			fn encode(&self, encoder: &mut #EXPORTS_IDENT::edgedb_protocol::query_arg::Encoder) -> core::result::Result<(), #EXPORTS_IDENT::edgedb_errors::Error> {
440				let map = #EXPORTS_IDENT::edgedb_protocol::named_args! {
441					#(#impl_named_args)*
442				};
443
444				map.encode(encoder)
445			}
446		}
447	});
448	let typed_builder_tokens = is_input.then_some(
449		quote!(#[cfg_attr(feature = "builder", derive(#EXPORTS_IDENT::typed_builder::TypedBuilder))]),
450	);
451	let struct_tokens = quote! {
452		#[derive(Clone, Debug)]
453		#typed_builder_tokens
454		#[cfg_attr(feature = "query", derive(#EXPORTS_IDENT::edgedb_derive::Queryable))]
455		#[cfg_attr(feature = "serde", derive(#EXPORTS_IDENT::serde::Serialize, #EXPORTS_IDENT::serde::Deserialize))]
456		pub struct #root_ident {
457			#(#struct_fields)*
458		}
459
460		#impl_tokens
461	};
462
463	tokens.extend(struct_tokens);
464
465	Ok(Some(quote!(#root_ident)))
466}
467
468pub enum StructElement<'a> {
469	Shape(&'a ShapeElement),
470	Tuple(&'a TupleElement),
471}
472
473impl<'a> StructElement<'a> {
474	pub fn from_shape(elements: &'a [ShapeElement]) -> Vec<StructElement<'a>> {
475		elements.iter().map(From::from).collect::<Vec<_>>()
476	}
477
478	pub fn from_named_tuple(elements: &'a [TupleElement]) -> Vec<StructElement<'a>> {
479		elements.iter().map(From::from).collect::<Vec<_>>()
480	}
481
482	pub fn name(&self) -> String {
483		match self {
484			StructElement::Shape(shape) => shape.name.clone(),
485			StructElement::Tuple(tuple) => tuple.name.clone(),
486		}
487	}
488
489	pub fn type_pos(&self) -> TypePos {
490		match self {
491			StructElement::Shape(shape) => shape.type_pos,
492			StructElement::Tuple(tuple) => tuple.type_pos,
493		}
494	}
495
496	pub fn wrap(&self, token: &Option<TokenStream>) -> TokenStream {
497		if let Cardinality::AtMostOne = self.cardinality() {
498			quote!(Option<#token>)
499		} else {
500			quote!(#token)
501		}
502	}
503
504	pub fn cardinality(&self) -> Cardinality {
505		match self {
506			StructElement::Shape(shape) => shape.cardinality.unwrap_or(Cardinality::NoResult),
507			StructElement::Tuple(_) => Cardinality::NoResult,
508		}
509	}
510}
511
512impl<'a> From<&'a ShapeElement> for StructElement<'a> {
513	fn from(value: &'a ShapeElement) -> Self {
514		StructElement::Shape(value)
515	}
516}
517
518impl<'a> From<&'a TupleElement> for StructElement<'a> {
519	fn from(value: &'a TupleElement) -> Self {
520		StructElement::Tuple(value)
521	}
522}
523
524fn uuid_to_token_name(uuid: &Uuid) -> TokenStream {
525	match *uuid {
526		STD_UUID => quote!(#EXPORTS_IDENT::uuid::Uuid),
527		STD_STR => quote!(String),
528		STD_BYTES => quote!(#EXPORTS_IDENT::bytes::Bytes),
529		STD_INT16 => quote!(i16),
530		STD_INT32 => quote!(i32),
531		STD_INT64 => quote!(i64),
532		STD_FLOAT32 => quote!(f32),
533		STD_FLOAT64 => quote!(f64),
534		#[cfg(not(feature = "with_bigdecimal"))]
535		STD_DECIMAL => quote!(#EXPORTS_IDENT::edgedb_protocol::model::Decimal),
536		#[cfg(feature = "with_bigdecimal")]
537		STD_DECIMAL => quote!(#EXPORTS_IDENT::bigdecimal::BigDecimal),
538		STD_BOOL => quote!(bool),
539		#[cfg(not(feature = "with_chrono"))]
540		STD_DATETIME => quote!(#EXPORTS_IDENT::edgedb_protocol::model::Datetime),
541		#[cfg(feature = "with_chrono")]
542		STD_DATETIME => quote!(#EXPORTS_IDENT::chrono::DateTime<#EXPORTS_IDENT::chrono::Utc>),
543		#[cfg(not(feature = "with_chrono"))]
544		CAL_LOCAL_DATETIME => quote!(#EXPORTS_IDENT::edgedb_protocol::model::LocalDatetime),
545		#[cfg(feature = "with_chrono")]
546		CAL_LOCAL_DATETIME => quote!(#EXPORTS_IDENT::chrono::NaiveDateTime),
547		#[cfg(not(feature = "with_chrono"))]
548		CAL_LOCAL_DATE => quote!(#EXPORTS_IDENT::edgedb_protocol::model::LocalDate),
549		#[cfg(feature = "with_chrono")]
550		CAL_LOCAL_DATE => quote!(#EXPORTS_IDENT::chrono::NaiveDate),
551		#[cfg(not(feature = "with_chrono"))]
552		CAL_LOCAL_TIME => quote!(#EXPORTS_IDENT::edgedb_protocol::model::LocalTime),
553		#[cfg(feature = "with_chrono")]
554		CAL_LOCAL_TIME => quote!(#EXPORTS_IDENT::chrono::NaiveTime),
555		STD_DURATION => quote!(#EXPORTS_IDENT::edgedb_protocol::model::Duration),
556		CAL_RELATIVE_DURATION => quote!(#EXPORTS_IDENT::edgedb_protocol::model::RelativeDuration),
557		CAL_DATE_DURATION => quote!(#EXPORTS_IDENT::edgedb_protocol::model::DateDuration),
558		STD_JSON => quote!(#EXPORTS_IDENT::edgedb_protocol::model::Json),
559		#[cfg(not(feature = "with_bigint"))]
560		STD_BIGINT => quote!(#EXPORTS_IDENT::edgedb_protocol::model::BigInt),
561		#[cfg(feature = "with_bigint")]
562		STD_BIGINT => quote!(#EXPORTS_IDENT::num_bigint::BigInt),
563		CFG_MEMORY => quote!(#EXPORTS_IDENT::edgedb_protocol::model::ConfigMemory),
564		PGVECTOR_VECTOR => quote!(#EXPORTS_IDENT::edgedb_protocol::model::Vector),
565		_ => quote!(()),
566	}
567}
568
569pub async fn get_types() -> Result<()> {
570	let client = create_client().await?;
571	let json = client.query_json(TYPES_QUERY, &()).await?;
572	log::debug!("{}", json.as_ref());
573
574	Ok(())
575}