1use std::sync::Arc;
2
3use check_keyword::CheckKeyword;
4use edgedb_protocol::codec::CAL_DATE_DURATION;
5use edgedb_protocol::codec::CAL_LOCAL_DATE;
6use edgedb_protocol::codec::CAL_LOCAL_DATETIME;
7use edgedb_protocol::codec::CAL_LOCAL_TIME;
8use edgedb_protocol::codec::CAL_RELATIVE_DURATION;
9use edgedb_protocol::codec::CFG_MEMORY;
10use edgedb_protocol::codec::PGVECTOR_VECTOR;
11use edgedb_protocol::codec::STD_BIGINT;
12use edgedb_protocol::codec::STD_BOOL;
13use edgedb_protocol::codec::STD_BYTES;
14use edgedb_protocol::codec::STD_DATETIME;
15use edgedb_protocol::codec::STD_DECIMAL;
16use edgedb_protocol::codec::STD_DURATION;
17use edgedb_protocol::codec::STD_FLOAT32;
18use edgedb_protocol::codec::STD_FLOAT64;
19use edgedb_protocol::codec::STD_INT16;
20use edgedb_protocol::codec::STD_INT32;
21use edgedb_protocol::codec::STD_INT64;
22use edgedb_protocol::codec::STD_JSON;
23use edgedb_protocol::codec::STD_STR;
24use edgedb_protocol::codec::STD_UUID;
25use edgedb_protocol::common::Capabilities;
26use edgedb_protocol::common::Cardinality;
27use edgedb_protocol::common::CompilationOptions;
28use edgedb_protocol::common::IoFormat;
29use edgedb_protocol::descriptors::Descriptor;
30use edgedb_protocol::descriptors::ShapeElement;
31use edgedb_protocol::descriptors::TupleElement;
32use edgedb_protocol::descriptors::TypePos;
33use edgedb_protocol::descriptors::Typedesc;
34use edgedb_protocol::model::Uuid;
35use edgedb_protocol::server_message::CommandDataDescription1;
36use edgedb_tokio::create_client;
37use edgedb_tokio::raw::Pool;
38use edgedb_tokio::raw::PoolState;
39use edgedb_tokio::Builder;
40use heck::ToPascalCase;
41use heck::ToSnakeCase;
42use proc_macro2::TokenStream;
43use quote::format_ident;
44use quote::quote;
45use syn::punctuated::Punctuated;
46use syn::Token;
47use tokio::runtime::Runtime;
48use typed_builder::TypedBuilder;
49
50pub use crate::constants::*;
51pub use crate::errors::*;
52pub use crate::utils::*;
53
54mod constants;
55mod errors;
56mod utils;
57
58pub async fn get_descriptor(query: &str) -> Result<CommandDataDescription1> {
60 let builder = Builder::new().build_env().await?;
61 let state = Arc::new(PoolState::default());
62 let pool = Pool::new(&builder);
63 let mut pool_connection = pool.acquire().await?;
64 let connection = pool_connection.inner();
65 let allow_capabilities = Capabilities::MODIFICATIONS | Capabilities::DDL;
66 let flags = CompilationOptions {
67 implicit_limit: None,
68 implicit_typenames: false,
69 implicit_typeids: false,
70 explicit_objectids: true,
71 allow_capabilities,
72 io_format: IoFormat::Binary,
73 expected_cardinality: Cardinality::Many,
74 };
75
76 Ok(connection.parse(&flags, query, &state).await?)
77}
78
79pub fn get_descriptor_sync(query: &str) -> Result<CommandDataDescription1> {
81 let rt = Runtime::new()?;
82 let descriptor = rt.block_on(async { get_descriptor(query).await })?;
83
84 Ok(descriptor)
85}
86
87pub fn generate_rust_from_query(
88 descriptor: &CommandDataDescription1,
89 name: &str,
90 query: &str,
91) -> Result<TokenStream> {
92 let input_ident = format_ident!("{INPUT_NAME}");
93 let output_ident = format_ident!("{OUTPUT_NAME}");
94 let props_ident = format_ident!("{PROPS_NAME}");
95 let query_ident = format_ident!("{QUERY_NAME}");
96 let query_prop_ident = format_ident!("{QUERY_PROP_NAME}");
97 let transaction_ident = format_ident!("{TRANSACTION_NAME}");
98 let transaction_prop_ident = format_ident!("{TRANSACTION_PROP_NAME}");
99 let module_name: syn::Ident = format_ident!("{}", name.to_snake_case());
100 let input = descriptor.input.decode()?;
101 let output = descriptor.output.decode()?;
102 let mut tokens: TokenStream = TokenStream::new();
103
104 explore_descriptor(
105 ExploreDescriptorProps::builder()
106 .typedesc(&input)
107 .is_input()
108 .is_root()
109 .descriptor(input.root())
110 .root_name(INPUT_NAME)
111 .build(),
112 &mut tokens,
113 )?;
114 explore_descriptor(
115 ExploreDescriptorProps::builder()
116 .typedesc(&output)
117 .is_root()
118 .descriptor(output.root())
119 .root_name(OUTPUT_NAME)
120 .build(),
121 &mut tokens,
122 )?;
123
124 let query_method = match descriptor.result_cardinality {
125 Cardinality::NoResult => quote!(execute),
126 Cardinality::AtMostOne => quote!(query_single),
127 Cardinality::One => quote!(query_required_single),
128 Cardinality::Many | Cardinality::AtLeastOne => quote!(query),
129 };
130
131 let mut query_props = vec![quote!(#query_prop_ident: &#EXPORTS_IDENT::edgedb_tokio::Client)];
132 let mut transaction_props =
133 vec![quote!(#transaction_prop_ident: &mut #EXPORTS_IDENT::edgedb_tokio::Transaction)];
134 let args = vec![
135 quote!(#QUERY_CONSTANT),
136 input.root().map_or(quote!(&()), |_| quote!(#props_ident)),
137 ];
138 let inner_return = output.root().map_or(quote!(()), |_| quote!(#output_ident));
139 let returns = wrap_token_with_cardinality(Some(descriptor.result_cardinality), inner_return);
140
141 if input.root().is_some() {
142 query_props.push(quote!(#props_ident: &#input_ident));
143 transaction_props.push(quote!(#props_ident: &#input_ident));
144 }
145
146 let token_stream = quote! {
147 pub mod #module_name {
148 use ::edgedb_codegen::exports as #EXPORTS_IDENT;
149
150 #[cfg(feature = "query")]
152 pub async fn #query_ident(#(#query_props),*) -> core::result::Result<#returns, #EXPORTS_IDENT::edgedb_errors::Error> {
153 #query_prop_ident.#query_method(#(#args),*).await
154 }
155
156 #[cfg(feature = "query")]
158 pub async fn #transaction_ident(#(#transaction_props),*) -> core::result::Result<#returns, #EXPORTS_IDENT::edgedb_errors::Error> {
159 #transaction_prop_ident.#query_method(#(#args),*).await
160 }
161
162 #tokens
163
164 pub const #QUERY_CONSTANT: &str = #query;
166 }
167 };
168
169 Ok(token_stream)
170}
171
172fn wrap_token_with_cardinality(
173 cardinality: Option<Cardinality>,
174 token: TokenStream,
175) -> TokenStream {
176 let Some(cardinality) = cardinality else {
177 return token;
178 };
179
180 match cardinality {
181 Cardinality::NoResult | Cardinality::AtMostOne => quote!(Option<#token>),
182 Cardinality::One => token,
183 Cardinality::Many | Cardinality::AtLeastOne => quote!(Vec<#token>),
184 }
185}
186
187#[derive(Debug, TypedBuilder)]
188struct ExploreDescriptorProps<'a> {
189 typedesc: &'a Typedesc,
190 #[builder(setter(strip_bool(fallback = is_input_bool)))]
191 is_input: bool,
192 #[builder(setter(strip_bool(fallback = is_root_bool)))]
193 is_root: bool,
194 descriptor: Option<&'a Descriptor>,
195 root_name: &'a str,
196}
197
198type PartialExploreDescriptorProps<'a> =
199 ExploreDescriptorPropsBuilder<'a, ((&'a Typedesc,), (bool,), (bool,), (), ())>;
200
201impl<'a> ExploreDescriptorProps<'a> {
202 fn into_props(self) -> PartialExploreDescriptorProps<'a> {
203 let Self {
204 typedesc, is_input, ..
205 } = self;
206
207 Self::builder()
208 .typedesc(typedesc)
209 .is_input_bool(is_input)
210 .is_root_bool(false)
211 }
212}
213
214fn explore_descriptor(
215 props @ ExploreDescriptorProps {
216 typedesc,
217 is_input,
218 is_root,
219 descriptor,
220 root_name,
221 }: ExploreDescriptorProps,
222 tokens: &mut TokenStream,
223) -> Result<Option<TokenStream>> {
224 let root_ident = format_ident!("{root_name}");
225
226 let Some(descriptor) = descriptor else {
227 if is_root {
228 tokens.extend(quote!(pub type #root_ident = ();));
229 }
230
231 return Ok(None);
232 };
233
234 match descriptor {
235 Descriptor::Set(set) => {
236 let set_descriptor = typedesc.get(set.type_pos).ok();
237 let sub_root_name = format!("{root_name}Set");
238 let props = props
239 .into_props()
240 .descriptor(set_descriptor)
241 .root_name(&sub_root_name)
242 .build();
243 let result = explore_descriptor(props, tokens)?.map(|result| quote!(Vec<#result>));
244
245 if is_root {
246 tokens.extend(quote!(pub type #root_ident = #result;));
247 Ok(Some(quote!(#root_ident)))
248 } else {
249 Ok(result)
250 }
251 }
252 Descriptor::ObjectShape(object) => {
253 let result = explore_object_shape_descriptor(
254 StructElement::from_shape(&object.elements),
255 typedesc,
256 root_name,
257 is_input,
258 tokens,
259 )?;
260
261 Ok(result)
262 }
263 Descriptor::BaseScalar(base_scalar) => {
264 let result = uuid_to_token_name(&base_scalar.id);
265
266 if is_root {
267 tokens.extend(quote!(pub type #root_ident = #result;));
268 Ok(Some(quote!(#root_ident)))
269 } else {
270 Ok(Some(result))
271 }
272 }
273 Descriptor::Scalar(scalar) => {
274 let props = props
275 .into_props()
276 .descriptor(typedesc.get(scalar.base_type_pos).ok())
277 .root_name(root_name)
278 .build();
279
280 explore_descriptor(props, tokens)
281 }
282 Descriptor::Tuple(tuple) => {
283 let mut tuple_tokens = Punctuated::<_, Token![,]>::new();
284
285 for (index, element) in tuple.element_types.iter().enumerate() {
286 let sub_root_name = format!("{root_name}{index}");
287 let result = explore_descriptor(
288 ExploreDescriptorProps::builder()
289 .typedesc(typedesc)
290 .is_input_bool(is_input)
291 .descriptor(typedesc.get(*element).ok())
292 .root_name(&sub_root_name)
293 .build(),
294 tokens,
295 )?;
296
297 tuple_tokens.push(result);
298 }
299
300 let result = quote!((#tuple_tokens));
301
302 if is_root {
303 tokens.extend(quote!(pub type #root_ident = #result;));
304 Ok(Some(quote!(#root_ident)))
305 } else {
306 Ok(Some(result))
307 }
308 }
309 Descriptor::NamedTuple(named_tuple) => {
310 let result = explore_object_shape_descriptor(
311 StructElement::from_named_tuple(&named_tuple.elements),
312 typedesc,
313 root_name,
314 is_input,
315 tokens,
316 )?;
317
318 Ok(result)
319 }
320 Descriptor::Array(array) => {
321 let array_descriptor = typedesc.get(array.type_pos).ok();
322 let sub_root_name = format!("{root_name}Array");
323 let props = props
324 .into_props()
325 .descriptor(array_descriptor)
326 .root_name(&sub_root_name)
327 .build();
328 let result = explore_descriptor(props, tokens)?.map(|result| quote!(Vec<#result>));
329
330 if is_root {
331 tokens.extend(quote!(pub type #root_ident = #result;));
332 Ok(Some(quote!(#root_ident)))
333 } else {
334 Ok(result)
335 }
336 }
337 Descriptor::Enumeration(_) => {
339 let result = Some(quote!(String));
340
341 if is_root {
342 tokens.extend(quote!(pub type #root_ident = #result;));
343 Ok(Some(quote!(#root_ident)))
344 } else {
345 Ok(result)
346 }
347 }
348 Descriptor::InputShape(object) => {
349 let result = explore_object_shape_descriptor(
350 StructElement::from_shape(&object.elements),
351 typedesc,
352 root_name,
353 is_input,
354 tokens,
355 )?;
356
357 Ok(result)
358 }
359 Descriptor::Range(range) => {
360 let range_descriptor = typedesc.get(range.type_pos).ok();
361 let sub_root_name = format!("{root_name}Range");
362 let props = props
363 .into_props()
364 .descriptor(range_descriptor)
365 .root_name(&sub_root_name)
366 .build();
367 let result = explore_descriptor(props, tokens)?
368 .map(|result| quote!(#EXPORTS_IDENT::edgedb_protocol::model::Range<#result>));
369
370 if is_root {
371 tokens.extend(quote!(pub type #root_ident = #result;));
372 Ok(Some(quote!(#root_ident)))
373 } else {
374 Ok(result)
375 }
376 }
377 Descriptor::MultiRange(_) => todo!("`multirange` not in the `edgedb_protocol` crate"),
378 Descriptor::TypeAnnotation(_) => todo!("type annotations are not supported"),
379 }
380}
381
382pub fn explore_object_shape_descriptor(
383 elements: Vec<StructElement<'_>>,
384 typedesc: &Typedesc,
385 root_name: &str,
386 is_input: bool,
387 tokens: &mut TokenStream,
388) -> Result<Option<TokenStream>> {
389 let mut impl_named_args = vec![];
390 let mut struct_fields = vec![];
391 let root_ident: syn::Ident = syn::parse_str(root_name)?;
392
393 for element in elements {
394 let descriptor = typedesc.get(element.type_pos()).ok();
395 let name = &element.name();
396 let safe_name = name.to_snake_case().into_safe();
397 let safe_name_ident = format_ident!("{safe_name}");
398 let pascal_name = name.to_pascal_case();
399 let sub_root_name = format!("{root_name}{pascal_name}").into_safe();
400 let sub_props = ExploreDescriptorProps::builder()
401 .typedesc(typedesc)
402 .is_input_bool(is_input)
403 .descriptor(descriptor)
404 .root_name(&sub_root_name)
405 .build();
406 let output = explore_descriptor(sub_props, tokens)?;
407 let output_token = element.wrap(&output);
408 let serde_annotation = (&safe_name != name).then_some(quote!(
409 #[cfg_attr(feature = "serde", serde(rename = #name))]
410 ));
411 let builder_fields = {
412 match element.cardinality() {
413 Cardinality::AtMostOne => {
414 let fallback_ident = format_ident!("{safe_name_ident}_opt");
415 Some(quote!(default, setter(into, strip_option(fallback = #fallback_ident))))
416 }
417 Cardinality::One => Some(quote!(setter(into))),
418 Cardinality::Many => Some(quote!(default)),
419 Cardinality::NoResult | Cardinality::AtLeastOne => None,
420 }
421 };
422 let builder_annotation = (is_input && builder_fields.is_some()).then_some(quote!(
423 #[cfg_attr(feature = "builder", builder(#builder_fields))]
424 ));
425
426 struct_fields.push(quote! {
427 #serde_annotation
428 #builder_annotation
429 pub #safe_name_ident: #output_token,
430 });
431
432 if is_input {
433 impl_named_args.push(quote!(#name => self.#safe_name_ident.clone(),));
434 }
435 }
436
437 let impl_tokens = is_input.then_some(quote! {
438 impl #EXPORTS_IDENT::edgedb_protocol::query_arg::QueryArgs for #root_ident {
439 fn encode(&self, encoder: &mut #EXPORTS_IDENT::edgedb_protocol::query_arg::Encoder) -> core::result::Result<(), #EXPORTS_IDENT::edgedb_errors::Error> {
440 let map = #EXPORTS_IDENT::edgedb_protocol::named_args! {
441 #(#impl_named_args)*
442 };
443
444 map.encode(encoder)
445 }
446 }
447 });
448 let typed_builder_tokens = is_input.then_some(
449 quote!(#[cfg_attr(feature = "builder", derive(#EXPORTS_IDENT::typed_builder::TypedBuilder))]),
450 );
451 let struct_tokens = quote! {
452 #[derive(Clone, Debug)]
453 #typed_builder_tokens
454 #[cfg_attr(feature = "query", derive(#EXPORTS_IDENT::edgedb_derive::Queryable))]
455 #[cfg_attr(feature = "serde", derive(#EXPORTS_IDENT::serde::Serialize, #EXPORTS_IDENT::serde::Deserialize))]
456 pub struct #root_ident {
457 #(#struct_fields)*
458 }
459
460 #impl_tokens
461 };
462
463 tokens.extend(struct_tokens);
464
465 Ok(Some(quote!(#root_ident)))
466}
467
468pub enum StructElement<'a> {
469 Shape(&'a ShapeElement),
470 Tuple(&'a TupleElement),
471}
472
473impl<'a> StructElement<'a> {
474 pub fn from_shape(elements: &'a [ShapeElement]) -> Vec<StructElement<'a>> {
475 elements.iter().map(From::from).collect::<Vec<_>>()
476 }
477
478 pub fn from_named_tuple(elements: &'a [TupleElement]) -> Vec<StructElement<'a>> {
479 elements.iter().map(From::from).collect::<Vec<_>>()
480 }
481
482 pub fn name(&self) -> String {
483 match self {
484 StructElement::Shape(shape) => shape.name.clone(),
485 StructElement::Tuple(tuple) => tuple.name.clone(),
486 }
487 }
488
489 pub fn type_pos(&self) -> TypePos {
490 match self {
491 StructElement::Shape(shape) => shape.type_pos,
492 StructElement::Tuple(tuple) => tuple.type_pos,
493 }
494 }
495
496 pub fn wrap(&self, token: &Option<TokenStream>) -> TokenStream {
497 if let Cardinality::AtMostOne = self.cardinality() {
498 quote!(Option<#token>)
499 } else {
500 quote!(#token)
501 }
502 }
503
504 pub fn cardinality(&self) -> Cardinality {
505 match self {
506 StructElement::Shape(shape) => shape.cardinality.unwrap_or(Cardinality::NoResult),
507 StructElement::Tuple(_) => Cardinality::NoResult,
508 }
509 }
510}
511
512impl<'a> From<&'a ShapeElement> for StructElement<'a> {
513 fn from(value: &'a ShapeElement) -> Self {
514 StructElement::Shape(value)
515 }
516}
517
518impl<'a> From<&'a TupleElement> for StructElement<'a> {
519 fn from(value: &'a TupleElement) -> Self {
520 StructElement::Tuple(value)
521 }
522}
523
524fn uuid_to_token_name(uuid: &Uuid) -> TokenStream {
525 match *uuid {
526 STD_UUID => quote!(#EXPORTS_IDENT::uuid::Uuid),
527 STD_STR => quote!(String),
528 STD_BYTES => quote!(#EXPORTS_IDENT::bytes::Bytes),
529 STD_INT16 => quote!(i16),
530 STD_INT32 => quote!(i32),
531 STD_INT64 => quote!(i64),
532 STD_FLOAT32 => quote!(f32),
533 STD_FLOAT64 => quote!(f64),
534 #[cfg(not(feature = "with_bigdecimal"))]
535 STD_DECIMAL => quote!(#EXPORTS_IDENT::edgedb_protocol::model::Decimal),
536 #[cfg(feature = "with_bigdecimal")]
537 STD_DECIMAL => quote!(#EXPORTS_IDENT::bigdecimal::BigDecimal),
538 STD_BOOL => quote!(bool),
539 #[cfg(not(feature = "with_chrono"))]
540 STD_DATETIME => quote!(#EXPORTS_IDENT::edgedb_protocol::model::Datetime),
541 #[cfg(feature = "with_chrono")]
542 STD_DATETIME => quote!(#EXPORTS_IDENT::chrono::DateTime<#EXPORTS_IDENT::chrono::Utc>),
543 #[cfg(not(feature = "with_chrono"))]
544 CAL_LOCAL_DATETIME => quote!(#EXPORTS_IDENT::edgedb_protocol::model::LocalDatetime),
545 #[cfg(feature = "with_chrono")]
546 CAL_LOCAL_DATETIME => quote!(#EXPORTS_IDENT::chrono::NaiveDateTime),
547 #[cfg(not(feature = "with_chrono"))]
548 CAL_LOCAL_DATE => quote!(#EXPORTS_IDENT::edgedb_protocol::model::LocalDate),
549 #[cfg(feature = "with_chrono")]
550 CAL_LOCAL_DATE => quote!(#EXPORTS_IDENT::chrono::NaiveDate),
551 #[cfg(not(feature = "with_chrono"))]
552 CAL_LOCAL_TIME => quote!(#EXPORTS_IDENT::edgedb_protocol::model::LocalTime),
553 #[cfg(feature = "with_chrono")]
554 CAL_LOCAL_TIME => quote!(#EXPORTS_IDENT::chrono::NaiveTime),
555 STD_DURATION => quote!(#EXPORTS_IDENT::edgedb_protocol::model::Duration),
556 CAL_RELATIVE_DURATION => quote!(#EXPORTS_IDENT::edgedb_protocol::model::RelativeDuration),
557 CAL_DATE_DURATION => quote!(#EXPORTS_IDENT::edgedb_protocol::model::DateDuration),
558 STD_JSON => quote!(#EXPORTS_IDENT::edgedb_protocol::model::Json),
559 #[cfg(not(feature = "with_bigint"))]
560 STD_BIGINT => quote!(#EXPORTS_IDENT::edgedb_protocol::model::BigInt),
561 #[cfg(feature = "with_bigint")]
562 STD_BIGINT => quote!(#EXPORTS_IDENT::num_bigint::BigInt),
563 CFG_MEMORY => quote!(#EXPORTS_IDENT::edgedb_protocol::model::ConfigMemory),
564 PGVECTOR_VECTOR => quote!(#EXPORTS_IDENT::edgedb_protocol::model::Vector),
565 _ => quote!(()),
566 }
567}
568
569pub async fn get_types() -> Result<()> {
570 let client = create_client().await?;
571 let json = client.query_json(TYPES_QUERY, &()).await?;
572 log::debug!("{}", json.as_ref());
573
574 Ok(())
575}