1use ruest_db_schema::{Attribute, FieldKind, ScalarType, Schema};
2
3use crate::naming::{
4 column_name, create_input_name, delegate_name, rust_module, rust_struct, table_name,
5 table_columns, update_input_name,
6};
7
/// Source code produced by [`generate_client`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GeneratedClient {
    /// Source of the root module: one `pub mod` declaration per model plus the
    /// `RuestDbClient` struct exposing one delegate per model.
    pub root: String,
    /// One `(module_name, module_source)` pair per model, in schema order.
    pub modules: Vec<(String, String)>,
}
12
13pub fn generate_client(schema: &Schema) -> GeneratedClient {
15 let mut modules = Vec::new();
16 let mut delegate_fields = String::new();
17 let mut mod_decls = String::new();
18
19 for model in &schema.models {
20 let mod_name = rust_module(&model.name);
21 let struct_name = rust_struct(&model.name);
22 let table = table_name(&model.name);
23 let delegate = delegate_name(&model.name);
24 let create_name = create_input_name(&model.name);
25 let update_name = update_input_name(&model.name);
26
27 mod_decls.push_str(&format!("pub mod {mod_name};\n"));
28 delegate_fields.push_str(&format!(
29 " pub {mod_name}: {mod_name}::{delegate},\n"
30 ));
31
32 let id = model.id_field().expect("@id required");
33 let id_name = &id.name;
34 let id_ty = scalar_rust_type(id);
35 let id_col = column_name(id_name);
36
37 let cols: Vec<_> = table_columns(model)
38 .iter()
39 .map(|f| column_name(&f.name))
40 .collect();
41 let select = cols
42 .iter()
43 .map(|c| format!("\"{c}\""))
44 .collect::<Vec<_>>()
45 .join(", ");
46 let mut entity_fields = String::new();
47 let mut create_fields = String::new();
48 let mut update_fields = String::new();
49 let mut map_row = String::new();
50 let mut insert_cols = Vec::new();
51 let mut insert_ph = Vec::new();
52 let mut insert_binds = String::new();
53 let mut insert_idx = 1i32;
54
55 for field in table_columns(model) {
56 let fname = &field.name;
57 let ty = scalar_rust_type(field);
58 entity_fields.push_str(&format!(" pub {fname}: {ty},\n"));
59 map_row.push_str(&format!(
60 " {fname}: row.try_get::<{ty}, _>(\"{fname}\")?,\n"
61 ));
62
63 if field.attributes.iter().any(|a| matches!(a, Attribute::Id)) {
64 continue;
65 }
66
67 let (create_ty, update_ty) = if field.optional {
68 (format!("Option<{ty}>"), format!("Option<{ty}>"))
69 } else {
70 (ty.clone(), format!("Option<{ty}>"))
71 };
72 create_fields.push_str(&format!(" pub {fname}: {create_ty},\n"));
73 update_fields.push_str(&format!(" pub {fname}: {update_ty},\n"));
74
75 insert_cols.push(format!("\"{}\"", column_name(fname)));
76 insert_ph.push(format!("${insert_idx}"));
77 insert_idx += 1;
78 if field.optional {
79 insert_binds.push_str(&format!(" .bind(&input.{fname})\n"));
80 } else {
81 insert_binds.push_str(&format!(" .bind(input.{fname})\n"));
82 }
83 }
84
85 let insert_cols_s = insert_cols.join(", ");
86 let insert_ph_s = insert_ph.join(", ");
87 let update_set = generate_update_set_sql(model);
88
89 let find_many_sql = rust_string_literal(&format!(
90 "SELECT {select} FROM \"{table}\" ORDER BY \"{id_col}\""
91 ));
92 let find_unique_sql = rust_string_literal(&format!(
93 "SELECT {select} FROM \"{table}\" WHERE \"{id_col}\" = $1"
94 ));
95 let insert_sql = rust_string_literal(&format!(
96 "INSERT INTO \"{table}\" ({insert_cols_s}) VALUES ({insert_ph_s}) RETURNING {select}"
97 ));
98 let update_sql = rust_string_literal(&format!(
99 "UPDATE \"{table}\" SET {update_set} WHERE \"{id_col}\" = $1 RETURNING {select}"
100 ));
101 let delete_sql = rust_string_literal(&format!(
102 "DELETE FROM \"{table}\" WHERE \"{id_col}\" = $1"
103 ));
104
105 let module_src = format!(
106 r##"//! Généré par RuestDB — ne pas modifier.
107
108use ruest_db_runtime::{{RuestDb, RuestDbError}};
109use ruest_db_runtime::serde::{{Deserialize, Serialize}};
110use ruest_db_runtime::Row;
111
112#[derive(Debug, Clone, Serialize, Deserialize)]
113pub struct {struct_name} {{
114{entity_fields}}}
115
116#[derive(Debug, Clone, Serialize, Deserialize)]
117pub struct {create_name} {{
118{create_fields}}}
119
120#[derive(Debug, Clone, Default, Serialize, Deserialize)]
121pub struct {update_name} {{
122{update_fields}}}
123
124pub struct {delegate} {{
125 db: RuestDb,
126}}
127
128impl {delegate} {{
129 pub(crate) fn new(db: RuestDb) -> Self {{
130 Self {{ db }}
131 }}
132
133 fn map_row(row: &ruest_db_runtime::sqlx::postgres::PgRow) -> Result<{struct_name}, RuestDbError> {{
134 Ok({struct_name} {{
135{map_row} }})
136 }}
137
138 pub async fn find_many(&self) -> Result<Vec<{struct_name}>, RuestDbError> {{
139 let sql = {find_many_sql};
140 let rows = ruest_db_runtime::sqlx::query(sql).fetch_all(self.db.pool()).await?;
141 rows.iter().map(Self::map_row).collect()
142 }}
143
144 pub async fn find_unique(&self, id: {id_ty}) -> Result<Option<{struct_name}>, RuestDbError> {{
145 let sql = {find_unique_sql};
146 let row = ruest_db_runtime::sqlx::query(&sql)
147 .bind(id)
148 .fetch_optional(self.db.pool())
149 .await?;
150 row.as_ref().map(Self::map_row).transpose()
151 }}
152
153 pub async fn create(&self, input: {create_name}) -> Result<{struct_name}, RuestDbError> {{
154 let sql = {insert_sql};
155 let row = ruest_db_runtime::sqlx::query(sql)
156{insert_binds} .fetch_one(self.db.pool())
157 .await?;
158 Self::map_row(&row)
159 }}
160
161 pub async fn update(
162 &self,
163 id: {id_ty},
164 input: {update_name},
165 ) -> Result<Option<{struct_name}>, RuestDbError> {{
166 let existing = self.find_unique(id.clone()).await?;
167 let Some(mut current) = existing else {{
168 return Ok(None);
169 }};
170{update_apply}
171 let sql = {update_sql};
172 let row = ruest_db_runtime::sqlx::query(sql)
173 .bind(id)
174{update_binds}
175 .fetch_optional(self.db.pool())
176 .await?;
177 row.as_ref().map(Self::map_row).transpose()
178 }}
179
180 pub async fn delete(&self, id: {id_ty}) -> Result<bool, RuestDbError> {{
181 let sql = {delete_sql};
182 let r = ruest_db_runtime::sqlx::query(sql).bind(id).execute(self.db.pool()).await?;
183 Ok(r.rows_affected() > 0)
184 }}
185}}
186"##,
187 update_apply = generate_update_apply(model),
188 update_binds = generate_update_binds(model),
189 find_many_sql = find_many_sql,
190 find_unique_sql = find_unique_sql,
191 insert_sql = insert_sql,
192 update_sql = update_sql,
193 delete_sql = delete_sql,
194 );
195
196 modules.push((mod_name, module_src));
197 }
198
199 let delegate_inits = schema
200 .models
201 .iter()
202 .map(|m| {
203 let mod_name = rust_module(&m.name);
204 let delegate = delegate_name(&m.name);
205 format!(" {mod_name}: {mod_name}::{delegate}::new(db.clone()),")
206 })
207 .collect::<Vec<_>>()
208 .join("\n");
209
210 let root = format!(
211 r#"//! Client RuestDB généré — `client.user.find_many().await?`
212
213{mod_decls}
214use ruest_db_runtime::RuestDb;
215
216pub struct RuestDbClient {{
217 inner: RuestDb,
218{delegate_fields}}}
219
220impl RuestDbClient {{
221 pub fn new(db: RuestDb) -> Self {{
222 Self {{
223 inner: db.clone(),
224{delegate_inits}
225 }}
226 }}
227
228 pub fn db(&self) -> &RuestDb {{
229 &self.inner
230 }}
231}}
232"#,
233 mod_decls = mod_decls,
234 delegate_fields = delegate_fields,
235 delegate_inits = delegate_inits,
236 );
237
238 GeneratedClient { root, modules }
239}
240
/// Quotes `content` as a Rust string literal for embedding in generated source.
///
/// Backslashes and double quotes are escaped; every other character is copied
/// through unchanged.
fn rust_string_literal(content: &str) -> String {
    let mut literal = String::with_capacity(content.len() + 2);
    literal.push('"');
    for ch in content.chars() {
        match ch {
            '\\' => literal.push_str("\\\\"),
            '"' => literal.push_str("\\\""),
            other => literal.push(other),
        }
    }
    literal.push('"');
    literal
}
247
248fn scalar_rust_type(field: &ruest_db_schema::Field) -> String {
249 match &field.kind {
250 FieldKind::Scalar(t) => match t {
251 ScalarType::String => "String".into(),
252 ScalarType::Int => "i32".into(),
253 ScalarType::Float => "f64".into(),
254 ScalarType::Boolean => "bool".into(),
255 ScalarType::DateTime => "chrono::DateTime<chrono::Utc>".into(),
256 ScalarType::Uuid => "uuid::Uuid".into(),
257 },
258 FieldKind::Model(_) => "String".into(),
259 }
260}
261
262fn generate_update_apply(model: &ruest_db_schema::Model) -> String {
263 let mut s = String::new();
264 for field in table_columns(model) {
265 if field.attributes.iter().any(|a| matches!(a, Attribute::Id)) {
266 continue;
267 }
268 let fname = &field.name;
269 s.push_str(&format!(
270 " if let Some(v) = input.{fname} {{ current.{fname} = v; }}\n"
271 ));
272 }
273 s
274}
275
276fn generate_update_set_sql(model: &ruest_db_schema::Model) -> String {
277 let mut parts = Vec::new();
278 let mut idx = 2i32;
279 for field in table_columns(model) {
280 if field.attributes.iter().any(|a| matches!(a, Attribute::Id)) {
281 continue;
282 }
283 parts.push(format!(
284 "\"{}\" = ${idx}",
285 column_name(&field.name),
286 ));
287 idx += 1;
288 }
289 parts.join(", ")
290}
291
292fn generate_update_binds(model: &ruest_db_schema::Model) -> String {
293 let mut s = String::new();
294 for field in table_columns(model) {
295 if field.attributes.iter().any(|a| matches!(a, Attribute::Id)) {
296 continue;
297 }
298 let fname = &field.name;
299 s.push_str(&format!(" .bind(current.{fname})\n"));
300 }
301 s
302}