1use std::collections::HashMap;
2use proc_macro2::TokenStream;
3use quote::{format_ident, quote};
4use crate::rustgen::{capitalize_first, rust_type_from_schema, to_snake_case};
5use crate::{Modifier, Schema};
6
7fn pk_args(model: &crate::Model) -> (Vec<proc_macro2::Ident>, Vec<proc_macro2::TokenStream>, Vec<String>, Vec<String>, Vec<proc_macro2::TokenStream>) {
8 let pk_fields: Vec<_> = model.fields.iter()
9 .filter(|f| f.modifiers.iter().any(|m| matches!(m, Modifier::PrimaryKey)))
10 .collect();
11 let pk_names = pk_fields.iter().map(|pk| format_ident!("{}", to_snake_case(&pk.name))).collect();
12 let pk_types = pk_fields.iter().map(|pk| {
13 let is_nullable = pk.modifiers.iter().any(|m| matches!(m, Modifier::Nullable));
14 rust_type_from_schema(&pk.type_name, is_nullable)
15 }).collect();
16 let pk_cols: Vec<_> = pk_fields.iter().map(|pk| to_snake_case(&pk.name)).collect();
17 let pk_placeholders: Vec<_> = (1..=pk_fields.len()).map(|i| format!("${}", i)).collect();
18 let pk_arg_refs = pk_fields.iter().map(|pk| {
19 let name = format_ident!("{}", to_snake_case(&pk.name));
20 quote! { &#name }
21 }).collect();
22 (pk_names, pk_types, pk_cols, pk_placeholders, pk_arg_refs)
23}
24
25fn generate_find_unique(model_name: &proc_macro2::Ident, model: &crate::Model) -> TokenStream {
26 let pk_fields: Vec<_> = model.fields.iter()
27 .filter(|f| f.modifiers.iter().any(|m| matches!(m, Modifier::PrimaryKey)))
28 .collect();
29
30 if pk_fields.is_empty() {
31 quote! {}
32 } else if pk_fields.len() == 1 {
33 let pk = &pk_fields[0];
34 let is_nullable = pk.modifiers.iter().any(|m| matches!(m, Modifier::Nullable));
35 let pk_type = rust_type_from_schema(&pk.type_name, is_nullable);
36 quote! {
37 pub async fn find_unique(&self, id: #pk_type)
38 -> Result<Option<#model_name>, Box<dyn std::error::Error + Send + Sync>>
39 {
40 #model_name::find_by_id(&self.client, id).await
41 }
42 }
43 } else {
44 let pk_params = pk_fields.iter().map(|pk| {
45 let name = format_ident!("{}", to_snake_case(&pk.name));
46 let is_nullable = pk.modifiers.iter().any(|m| matches!(m, Modifier::Nullable));
47 let pk_type = rust_type_from_schema(&pk.type_name, is_nullable);
48 quote! { #name: #pk_type }
49 });
50
51 let pk_args = pk_fields.iter().map(|pk| {
52 let name = format_ident!("{}", to_snake_case(&pk.name));
53 quote! { #name }
54 });
55
56 quote! {
57 pub async fn find_unique(&self, #(#pk_params),*)
58 -> Result<Option<#model_name>, Box<dyn std::error::Error + Send + Sync>>
59 {
60 #model_name::find_by_composite_pk(&self.client, #(#pk_args),*).await
61 }
62 }
63 }
64}
65
66fn generate_find_or_create(model_name: &proc_macro2::Ident, model: &crate::Model, table_name: &str) -> TokenStream {
67 let pk_fields: Vec<_> = model.fields.iter()
68 .filter(|f| f.modifiers.iter().any(|m| matches!(m, Modifier::PrimaryKey)))
69 .collect();
70
71 if pk_fields.is_empty() {
72 quote! {}
73 } else if pk_fields.len() == 1 {
74 let pk = &pk_fields[0];
75 let is_nullable = pk.modifiers.iter().any(|m| matches!(m, Modifier::Nullable));
76 let pk_type = rust_type_from_schema(&pk.type_name, is_nullable);
77 let pk_col = to_snake_case(&pk.name);
78 quote! {
79 pub async fn find_or_create(&self, id: #pk_type)
80 -> Result<#model_name, Box<dyn std::error::Error + Send + Sync>>
81 {
82 self.client.execute(
83 &format!("INSERT INTO {} ({}) VALUES ($1) ON CONFLICT DO NOTHING", #table_name, #pk_col),
84 &[&id]
85 ).await?;
86 self.find_unique(id).await?.ok_or("Record should exist after find_or_create".into())
87 }
88 }
89 } else {
90 let pk_params = pk_fields.iter().map(|pk| {
91 let name = format_ident!("{}", to_snake_case(&pk.name));
92 let is_nullable = pk.modifiers.iter().any(|m| matches!(m, Modifier::Nullable));
93 let pk_type = rust_type_from_schema(&pk.type_name, is_nullable);
94 quote! { #name: #pk_type }
95 });
96 let pk_cols: Vec<_> = pk_fields.iter().map(|pk| to_snake_case(&pk.name)).collect();
97 let pk_cols_str = pk_cols.join(", ");
98 let pk_placeholders: Vec<_> = (1..=pk_fields.len()).map(|i| format!("${}", i)).collect();
99 let pk_placeholders_str = pk_placeholders.join(", ");
100 let pk_conflict = pk_cols.join(", ");
101 let pk_args = pk_fields.iter().map(|pk| {
102 let name = format_ident!("{}", to_snake_case(&pk.name));
103 quote! { &#name }
104 });
105 let pk_args_call = pk_fields.iter().map(|pk| {
106 let name = format_ident!("{}", to_snake_case(&pk.name));
107 quote! { #name }
108 });
109 quote! {
110 pub async fn find_or_create(&self, #(#pk_params),*)
111 -> Result<#model_name, Box<dyn std::error::Error + Send + Sync>>
112 {
113 let sql = format!(
114 "INSERT INTO {} ({}) VALUES ({}) ON CONFLICT ({}) DO NOTHING",
115 #table_name, #pk_cols_str, #pk_placeholders_str, #pk_conflict
116 );
117 self.client.execute(&sql, &[#(#pk_args),*]).await?;
118 self.find_unique(#(#pk_args_call),*).await?.ok_or("Record should exist after find_or_create".into())
119 }
120 }
121 }
122}
123
124fn generate_jsonb_sub_accessors(model: &crate::Model, jsonb_defaults: &HashMap<(String, String), String>) -> Vec<TokenStream> {
125 let model_name = &model.name;
126 let query_builder = format_ident!("{}Query", model.name);
127 let table_name = model.name.to_lowercase();
128
129 let pk_fields: Vec<_> = model.fields.iter()
130 .filter(|f| f.modifiers.iter().any(|m| matches!(m, Modifier::PrimaryKey)))
131 .collect();
132
133 let jsonb_fields: Vec<_> = model.fields.iter()
134 .filter(|f| f.type_name == "JsonB")
135 .collect();
136
137 jsonb_fields.into_iter().map(|jsonb| {
138 let jsonb_name = &jsonb.name;
139 let jsonb_snake = to_snake_case(jsonb_name);
140 let jsonb_field_ident = format_ident!("{}", jsonb_name);
141 let sub_accessor_struct = format_ident!("{}{}Accessor", model.name, capitalize_first(jsonb_name));
142 let defaults_const = format_ident!("{}_DEFAULTS", jsonb_snake.to_uppercase());
143
144 let default_json_init = if let Some(json_content) = jsonb_defaults.get(&(model.name.clone(), jsonb.name.clone())) {
145 quote! {
146 static #defaults_const: Lazy<serde_json::Value> = Lazy::new(|| {
147 serde_json::from_str(#json_content)
148 .expect(&format!("Failed to parse default JSON for {}.{}", stringify!(#model_name), #jsonb_name))
149 });
150 }
151 } else {
152 quote! {
153 static #defaults_const: Lazy<serde_json::Value> = Lazy::new(|| {
154 serde_json::json!({})
155 });
156 }
157 };
158
159 let (pk_params, pk_where_methods, pk_args_for_set) = if pk_fields.len() == 1 {
160 let pk = &pk_fields[0];
161 let is_nullable = pk.modifiers.iter().any(|m| matches!(m, Modifier::Nullable));
162 let pk_type = rust_type_from_schema(&pk.type_name, is_nullable);
163 let pk_field_name = format_ident!("where_{}", to_snake_case(&pk.name));
164 (
165 quote! { id: #pk_type },
166 quote! { .#pk_field_name(id) },
167 vec![quote! { &id }],
168 )
169 } else {
170 let params = pk_fields.iter().map(|pk| {
171 let param_name = format_ident!("{}", to_snake_case(&pk.name));
172 let is_nullable = pk.modifiers.iter().any(|m| matches!(m, Modifier::Nullable));
173 let pk_type = rust_type_from_schema(&pk.type_name, is_nullable);
174 quote! { #param_name: #pk_type }
175 });
176
177 let where_methods = pk_fields.iter().map(|pk| {
178 let method_name = format_ident!("where_{}", to_snake_case(&pk.name));
179 let param_name = format_ident!("{}", to_snake_case(&pk.name));
180 quote! { .#method_name(#param_name) }
181 });
182
183 let set_args = pk_fields.iter().map(|pk| {
184 let param_name = format_ident!("{}", to_snake_case(&pk.name));
185 quote! { &#param_name }
186 });
187
188 (
189 quote! { #(#params),* },
190 quote! { #(#where_methods)* },
191 set_args.collect::<Vec<_>>(),
192 )
193 };
194
195 let pk_args_clone = if pk_fields.len() == 1 {
196 quote! { id }
197 } else {
198 let args = pk_fields.iter().map(|pk| {
199 let param_name = format_ident!("{}", to_snake_case(&pk.name));
200 quote! { #param_name }
201 });
202 quote! { #(#args),* }
203 };
204
205 let pk_columns: Vec<_> = pk_fields.iter().map(|pk| to_snake_case(&pk.name)).collect();
206 let pk_placeholders: Vec<_> = (1..=pk_fields.len()).map(|i| format!("${}", i)).collect();
207 let insert_pk_part = pk_columns.join(", ");
208 let insert_values_part = pk_placeholders.join(", ");
209 let conflict_clause = pk_columns.join(", ");
210
211 quote! {
212 #default_json_init
213
214 #[derive(Clone)]
215 pub struct #sub_accessor_struct {
216 client: Arc<PgClient>,
217 }
218
219 impl std::fmt::Debug for #sub_accessor_struct {
220 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
221 f.debug_struct(stringify!(#sub_accessor_struct))
222 .field("client", &"<PgClient>")
223 .finish()
224 }
225 }
226
227 impl #sub_accessor_struct {
228 pub fn new(client: Arc<PgClient>) -> Self {
229 Self { client }
230 }
231
232 pub async fn get(&self, #pk_params, key: &str)
233 -> Result<String, Box<dyn std::error::Error + Send + Sync>>
234 {
235 match #query_builder::new()
236 #pk_where_methods
237 .first(&self.client)
238 .await
239 {
240 Ok(Some(record)) => {
241 record.#jsonb_field_ident.get_string(key)
242 .or_else(|_| #defaults_const.get_string(key))
243 },
244 Ok(None) => #defaults_const.get_string(key),
245 Err(e) => Err(e),
246 }
247 }
248
249 pub async fn get_as<T>(&self, #pk_params, key: &str)
250 -> Result<T, Box<dyn std::error::Error + Send + Sync>>
251 where
252 T: serde::de::DeserializeOwned,
253 {
254 match #query_builder::new()
255 #pk_where_methods
256 .first(&self.client)
257 .await
258 {
259 Ok(Some(record)) => {
260 record.#jsonb_field_ident.get_value(key)
261 .or_else(|_| #defaults_const.get_value(key))
262 },
263 Ok(None) => #defaults_const.get_value(key),
264 Err(e) => Err(e),
265 }
266 }
267
268 pub async fn get_or<T>(&self, #pk_params, key: &str, default: T)
269 -> Result<T, Box<dyn std::error::Error + Send + Sync>>
270 where
271 T: serde::de::DeserializeOwned,
272 {
273 match self.get_as(#pk_args_clone, key).await {
274 Ok(value) => Ok(value),
275 Err(_) => Ok(default),
276 }
277 }
278
279 pub async fn has(&self, #pk_params, key: &str)
280 -> Result<bool, Box<dyn std::error::Error + Send + Sync>>
281 {
282 match #query_builder::new()
283 #pk_where_methods
284 .first(&self.client)
285 .await
286 {
287 Ok(Some(record)) => Ok(record.#jsonb_field_ident.has_key(key) || #defaults_const.has_key(key)),
288 Ok(None) => Ok(#defaults_const.has_key(key)),
289 Err(e) => Err(e),
290 }
291 }
292
293 pub async fn set<T>(&self, #pk_params, key: &str, value: T)
294 -> Result<(), Box<dyn std::error::Error + Send + Sync>>
295 where
296 T: serde::Serialize + Send + Sync,
297 {
298 let value_json = serde_json::to_value(&value)?;
299 let value_str = value_json.to_string();
300
301 let sql = format!(
302 "INSERT INTO {} ({}, {}, updated_at) VALUES ({}, jsonb_build_object($1, $2), NOW()) \
303 ON CONFLICT ({}) DO UPDATE SET {} = jsonb_set(COALESCE({}.{}, '{{}}'::jsonb), $3, $4, true), updated_at = NOW()",
304 #table_name,
305 #insert_pk_part,
306 #jsonb_snake,
307 #insert_values_part,
308 #conflict_clause,
309 #jsonb_snake,
310 #table_name,
311 #jsonb_snake,
312 );
313
314 let key_path = format!("{{{}}}", key);
315 self.client.execute(
316 &sql,
317 &[#(#pk_args_for_set),*, &key, &value_str, &key_path, &value_str]
318 ).await?;
319
320 Ok(())
321 }
322
323 pub async fn get_many(
324 &self, #pk_params, keys: &[&str]
325 ) -> Result<HashMap<String, serde_json::Value>, Box<dyn std::error::Error + Send + Sync>>
326 {
327 let opt = #query_builder::new()
328 #pk_where_methods
329 .first(&self.client).await?;
330
331 let mut out = HashMap::new();
332
333 if let Some(record) = opt {
334 for &key in keys {
335 if let Some(v) = record.#jsonb_field_ident.get(key) {
336 out.insert(key.to_string(), v.clone());
337 }
338 }
339 } else {
340 for &key in keys {
341 if let Some(v) = #defaults_const.get(key) {
342 out.insert(key.to_string(), v.clone());
343 }
344 }
345 }
346 Ok(out)
347 }
348
349 pub async fn get_many_as<T>(
350 &self, #pk_params, keys: &[&str]
351 ) -> Result<HashMap<String, T>, Box<dyn std::error::Error + Send + Sync>>
352 where T: serde::de::DeserializeOwned
353 {
354 let values = self.get_many(#pk_args_clone, keys).await?;
355 let mut map = HashMap::new();
356 for (k, v) in values {
357 if let Ok(x) = serde_json::from_value::<T>(v) {
358 map.insert(k, x);
359 }
360 }
361 Ok(map)
362 }
363 }
364 }
365 }).collect()
366}
367
368pub fn generate_client_struct(schema: &Schema, jsonb_defaults: &HashMap<(String, String), String>) -> TokenStream {
369 let model_accessors = schema.models.iter().map(|model| {
370 let accessor_name = format_ident!("{}", to_snake_case(&model.name));
371 let accessor_struct = format_ident!("{}Accessor", model.name);
372 quote! { pub #accessor_name: #accessor_struct }
373 });
374
375 let accessor_structs = schema.models.iter().map(|model| {
376 let model_name = format_ident!("{}", model.name);
377 let accessor_struct = format_ident!("{}Accessor", model.name);
378 let query_builder = format_ident!("{}Query", model.name);
379 let update_builder = format_ident!("{}Update", model.name);
380 let upsert_builder = format_ident!("{}Upsert", model.name);
381 let table_name = model.name.to_lowercase();
382
383 let find_unique = generate_find_unique(&model_name, model);
384 let find_or_create = generate_find_or_create(&model_name, model, &table_name);
385
386 let jsonb_fields: Vec<_> = model.fields.iter()
387 .filter(|f| f.type_name == "JsonB")
388 .collect();
389
390 let jsonb_accessor_fields = jsonb_fields.iter().map(|jsonb| {
391 let jsonb_snake = to_snake_case(&jsonb.name);
392 let sub_accessor_struct = format_ident!("{}{}Accessor", model.name, capitalize_first(&jsonb.name));
393 let sub_accessor_field = format_ident!("{}", jsonb_snake);
394 quote! { pub #sub_accessor_field: #sub_accessor_struct }
395 });
396
397 let jsonb_accessor_inits = jsonb_fields.iter().map(|jsonb| {
398 let jsonb_snake = to_snake_case(&jsonb.name);
399 let sub_accessor_struct = format_ident!("{}{}Accessor", model.name, capitalize_first(&jsonb.name));
400 let sub_accessor_field = format_ident!("{}", jsonb_snake);
401 quote! { #sub_accessor_field: #sub_accessor_struct::new(client.clone()) }
402 });
403
404 let jsonb_debug_fields = jsonb_fields.iter().map(|jsonb| {
405 let jsonb_snake = to_snake_case(&jsonb.name);
406 let sub_accessor_field = format_ident!("{}", jsonb_snake);
407 quote! { .field(stringify!(#sub_accessor_field), &self.#sub_accessor_field) }
408 });
409
410 let jsonb_sub_accessors = generate_jsonb_sub_accessors(model, jsonb_defaults);
411
412 quote! {
413 #(#jsonb_sub_accessors)*
414
415 #[derive(Clone)]
416 pub struct #accessor_struct {
417 client: Arc<PgClient>,
418 #(#jsonb_accessor_fields),*
419 }
420 impl std::fmt::Debug for #accessor_struct {
421 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
422 f.debug_struct(stringify!(#accessor_struct))
423 .field("client", &"<PgClient>")
424 #(#jsonb_debug_fields)*
425 .finish()
426 }
427 }
428 impl #accessor_struct {
429 pub fn new(client: Arc<PgClient>) -> Self {
430 Self {
431 client: client.clone(),
432 #(#jsonb_accessor_inits),*
433 }
434 }
435 pub fn find_many(&self) -> #query_builder { #query_builder::new() }
436 pub fn update(&self) -> #update_builder { #update_builder::new(self.client.clone()) }
437 pub fn upsert(&self) -> #upsert_builder { #upsert_builder::new(self.client.clone()) }
438 #find_unique
439 #find_or_create
440 pub async fn find_first(&self) -> Result<Option<#model_name>, Box<dyn std::error::Error + Send + Sync>> {
441 #query_builder::new().first(&self.client).await
442 }
443 pub async fn count(&self) -> Result<i64, Box<dyn std::error::Error + Send + Sync>> {
444 #query_builder::new().count(&self.client).await
445 }
446 pub fn client(&self) -> &PgClient { &self.client }
447 }
448 }
449 });
450
451 let accessor_inits = schema.models.iter().map(|model| {
452 let accessor_name = format_ident!("{}", to_snake_case(&model.name));
453 let accessor_struct = format_ident!("{}Accessor", model.name);
454 quote! { #accessor_name: #accessor_struct::new(client.clone()) }
455 });
456
457 let debug_accessor_fields = schema.models.iter().map(|model| {
458 let accessor_name = to_snake_case(&model.name);
459 let accessor_name_ident = format_ident!("{}", accessor_name);
460 quote! { .field(#accessor_name, &self.#accessor_name_ident) }
461 });
462
463 quote! {
464 #(#accessor_structs)*
465
466 pub struct Client {
467 client: Arc<PgClient>,
468 #(#model_accessors),*
469 }
470 impl std::fmt::Debug for Client {
471 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
472 f.debug_struct("Client")
473 .field("client", &"<PgClient>")
474 #(#debug_accessor_fields)*
475 .finish()
476 }
477 }
478 impl Client {
479 pub async fn new(connection_string: &str) -> Result<Self, Error> {
480 let (client, connection) = tokio_postgres::connect(connection_string, NoTls).await?;
481 tokio::spawn(async move {
482 if let Err(e) = connection.await {
483 eprintln!("connection error: {}", e);
484 }
485 });
486 let client = Arc::new(client);
487 Ok(Self {
488 client: client.clone(),
489 #(#accessor_inits),*
490 })
491 }
492 pub fn client(&self) -> &PgClient { &self.client }
493 }
494 }
495}