Skip to main content

ruest_db_codegen/
rust_client.rs

1use ruest_db_schema::{Attribute, FieldKind, ScalarType, Schema};
2
3use crate::naming::{
4    column_name, create_input_name, delegate_name, rust_module, rust_struct, table_name,
5    table_columns, update_input_name,
6};
7
/// Sources produced by [`generate_client`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GeneratedClient {
    /// Source of the root module: one `pub mod` declaration per model plus `RuestDbClient`.
    pub root: String,
    /// One `(module_name, module_source)` pair per model, in schema order.
    pub modules: Vec<(String, String)>,
}
12
13/// Génère le client Rust (`generated/ruestdb/`).
14pub fn generate_client(schema: &Schema) -> GeneratedClient {
15    let mut modules = Vec::new();
16    let mut delegate_fields = String::new();
17    let mut mod_decls = String::new();
18
19    for model in &schema.models {
20        let mod_name = rust_module(&model.name);
21        let struct_name = rust_struct(&model.name);
22        let table = table_name(&model.name);
23        let delegate = delegate_name(&model.name);
24        let create_name = create_input_name(&model.name);
25        let update_name = update_input_name(&model.name);
26
27        mod_decls.push_str(&format!("pub mod {mod_name};\n"));
28        delegate_fields.push_str(&format!(
29            "    pub {mod_name}: {mod_name}::{delegate},\n"
30        ));
31
32        let id = model.id_field().expect("@id required");
33        let id_name = &id.name;
34        let id_ty = scalar_rust_type(id);
35        let id_col = column_name(id_name);
36
37        let cols: Vec<_> = table_columns(model)
38            .iter()
39            .map(|f| column_name(&f.name))
40            .collect();
41        let select = cols
42            .iter()
43            .map(|c| format!("\"{c}\""))
44            .collect::<Vec<_>>()
45            .join(", ");
46        let mut entity_fields = String::new();
47        let mut create_fields = String::new();
48        let mut update_fields = String::new();
49        let mut map_row = String::new();
50        let mut insert_cols = Vec::new();
51        let mut insert_ph = Vec::new();
52        let mut insert_binds = String::new();
53        let mut insert_idx = 1i32;
54
55        for field in table_columns(model) {
56            let fname = &field.name;
57            let ty = scalar_rust_type(field);
58            entity_fields.push_str(&format!("    pub {fname}: {ty},\n"));
59            map_row.push_str(&format!(
60                "            {fname}: row.try_get::<{ty}, _>(\"{fname}\")?,\n"
61            ));
62
63            if field.attributes.iter().any(|a| matches!(a, Attribute::Id)) {
64                continue;
65            }
66
67            let (create_ty, update_ty) = if field.optional {
68                (format!("Option<{ty}>"), format!("Option<{ty}>"))
69            } else {
70                (ty.clone(), format!("Option<{ty}>"))
71            };
72            create_fields.push_str(&format!("    pub {fname}: {create_ty},\n"));
73            update_fields.push_str(&format!("    pub {fname}: {update_ty},\n"));
74
75            insert_cols.push(format!("\"{}\"", column_name(fname)));
76            insert_ph.push(format!("${insert_idx}"));
77            insert_idx += 1;
78            if field.optional {
79                insert_binds.push_str(&format!("            .bind(&input.{fname})\n"));
80            } else {
81                insert_binds.push_str(&format!("            .bind(input.{fname})\n"));
82            }
83        }
84
85        let insert_cols_s = insert_cols.join(", ");
86        let insert_ph_s = insert_ph.join(", ");
87        let update_set = generate_update_set_sql(model);
88
89        let find_many_sql = rust_string_literal(&format!(
90            "SELECT {select} FROM \"{table}\" ORDER BY \"{id_col}\""
91        ));
92        let find_unique_sql = rust_string_literal(&format!(
93            "SELECT {select} FROM \"{table}\" WHERE \"{id_col}\" = $1"
94        ));
95        let insert_sql = rust_string_literal(&format!(
96            "INSERT INTO \"{table}\" ({insert_cols_s}) VALUES ({insert_ph_s}) RETURNING {select}"
97        ));
98        let update_sql = rust_string_literal(&format!(
99            "UPDATE \"{table}\" SET {update_set} WHERE \"{id_col}\" = $1 RETURNING {select}"
100        ));
101        let delete_sql = rust_string_literal(&format!(
102            "DELETE FROM \"{table}\" WHERE \"{id_col}\" = $1"
103        ));
104
105        let module_src = format!(
106            r##"//! Généré par RuestDB — ne pas modifier.
107
108use ruest_db_runtime::{{RuestDb, RuestDbError}};
109use ruest_db_runtime::serde::{{Deserialize, Serialize}};
110use ruest_db_runtime::Row;
111
112#[derive(Debug, Clone, Serialize, Deserialize)]
113pub struct {struct_name} {{
114{entity_fields}}}
115
116#[derive(Debug, Clone, Serialize, Deserialize)]
117pub struct {create_name} {{
118{create_fields}}}
119
120#[derive(Debug, Clone, Default, Serialize, Deserialize)]
121pub struct {update_name} {{
122{update_fields}}}
123
124pub struct {delegate} {{
125    db: RuestDb,
126}}
127
128impl {delegate} {{
129    pub(crate) fn new(db: RuestDb) -> Self {{
130        Self {{ db }}
131    }}
132
133    fn map_row(row: &ruest_db_runtime::sqlx::postgres::PgRow) -> Result<{struct_name}, RuestDbError> {{
134        Ok({struct_name} {{
135{map_row}        }})
136    }}
137
138    pub async fn find_many(&self) -> Result<Vec<{struct_name}>, RuestDbError> {{
139        let sql = {find_many_sql};
140        let rows = ruest_db_runtime::sqlx::query(sql).fetch_all(self.db.pool()).await?;
141        rows.iter().map(Self::map_row).collect()
142    }}
143
144    pub async fn find_unique(&self, id: {id_ty}) -> Result<Option<{struct_name}>, RuestDbError> {{
145        let sql = {find_unique_sql};
146        let row = ruest_db_runtime::sqlx::query(&sql)
147            .bind(id)
148            .fetch_optional(self.db.pool())
149            .await?;
150        row.as_ref().map(Self::map_row).transpose()
151    }}
152
153    pub async fn create(&self, input: {create_name}) -> Result<{struct_name}, RuestDbError> {{
154        let sql = {insert_sql};
155        let row = ruest_db_runtime::sqlx::query(sql)
156{insert_binds}            .fetch_one(self.db.pool())
157            .await?;
158        Self::map_row(&row)
159    }}
160
161    pub async fn update(
162        &self,
163        id: {id_ty},
164        input: {update_name},
165    ) -> Result<Option<{struct_name}>, RuestDbError> {{
166        let existing = self.find_unique(id.clone()).await?;
167        let Some(mut current) = existing else {{
168            return Ok(None);
169        }};
170{update_apply}
171        let sql = {update_sql};
172        let row = ruest_db_runtime::sqlx::query(sql)
173            .bind(id)
174{update_binds}
175            .fetch_optional(self.db.pool())
176            .await?;
177        row.as_ref().map(Self::map_row).transpose()
178    }}
179
180    pub async fn delete(&self, id: {id_ty}) -> Result<bool, RuestDbError> {{
181        let sql = {delete_sql};
182        let r = ruest_db_runtime::sqlx::query(sql).bind(id).execute(self.db.pool()).await?;
183        Ok(r.rows_affected() > 0)
184    }}
185}}
186"##,
187            update_apply = generate_update_apply(model),
188            update_binds = generate_update_binds(model),
189            find_many_sql = find_many_sql,
190            find_unique_sql = find_unique_sql,
191            insert_sql = insert_sql,
192            update_sql = update_sql,
193            delete_sql = delete_sql,
194        );
195
196        modules.push((mod_name, module_src));
197    }
198
199    let delegate_inits = schema
200        .models
201        .iter()
202        .map(|m| {
203            let mod_name = rust_module(&m.name);
204            let delegate = delegate_name(&m.name);
205            format!("            {mod_name}: {mod_name}::{delegate}::new(db.clone()),")
206        })
207        .collect::<Vec<_>>()
208        .join("\n");
209
210    let root = format!(
211        r#"//! Client RuestDB généré — `client.user.find_many().await?`
212
213{mod_decls}
214use ruest_db_runtime::RuestDb;
215
216pub struct RuestDbClient {{
217    inner: RuestDb,
218{delegate_fields}}}
219
220impl RuestDbClient {{
221    pub fn new(db: RuestDb) -> Self {{
222        Self {{
223            inner: db.clone(),
224{delegate_inits}
225        }}
226    }}
227
228    pub fn db(&self) -> &RuestDb {{
229        &self.inner
230    }}
231}}
232"#,
233        mod_decls = mod_decls,
234        delegate_fields = delegate_fields,
235        delegate_inits = delegate_inits,
236    );
237
238    GeneratedClient { root, modules }
239}
240
241fn rust_string_literal(content: &str) -> String {
242    format!(
243        "\"{}\"",
244        content.replace('\\', "\\\\").replace('\"', "\\\"")
245    )
246}
247
248fn scalar_rust_type(field: &ruest_db_schema::Field) -> String {
249    match &field.kind {
250        FieldKind::Scalar(t) => match t {
251            ScalarType::String => "String".into(),
252            ScalarType::Int => "i32".into(),
253            ScalarType::Float => "f64".into(),
254            ScalarType::Boolean => "bool".into(),
255            ScalarType::DateTime => "chrono::DateTime<chrono::Utc>".into(),
256            ScalarType::Uuid => "uuid::Uuid".into(),
257        },
258        FieldKind::Model(_) => "String".into(),
259    }
260}
261
262fn generate_update_apply(model: &ruest_db_schema::Model) -> String {
263    let mut s = String::new();
264    for field in table_columns(model) {
265        if field.attributes.iter().any(|a| matches!(a, Attribute::Id)) {
266            continue;
267        }
268        let fname = &field.name;
269        s.push_str(&format!(
270            "        if let Some(v) = input.{fname} {{ current.{fname} = v; }}\n"
271        ));
272    }
273    s
274}
275
276fn generate_update_set_sql(model: &ruest_db_schema::Model) -> String {
277    let mut parts = Vec::new();
278    let mut idx = 2i32;
279    for field in table_columns(model) {
280        if field.attributes.iter().any(|a| matches!(a, Attribute::Id)) {
281            continue;
282        }
283        parts.push(format!(
284            "\"{}\" = ${idx}",
285            column_name(&field.name),
286        ));
287        idx += 1;
288    }
289    parts.join(", ")
290}
291
292fn generate_update_binds(model: &ruest_db_schema::Model) -> String {
293    let mut s = String::new();
294    for field in table_columns(model) {
295        if field.attributes.iter().any(|a| matches!(a, Attribute::Id)) {
296            continue;
297        }
298        let fname = &field.name;
299        s.push_str(&format!("            .bind(current.{fname})\n"));
300    }
301    s
302}