Skip to main content

nfw_core/api/
generator.rs

1use std::path::Path;
2
3use crate::routing::Route;
4
5#[derive(Debug, Clone)]
6pub struct TypeSpec {
7    pub name: String,
8    pub fields: Vec<FieldSpec>,
9    pub is_enum: bool,
10    pub enum_variants: Vec<String>,
11}
12
/// One member of a generated record type, or one path/query parameter of an endpoint.
#[derive(Clone, Debug)]
pub struct FieldSpec {
    /// Field (or parameter) identifier as emitted in generated code.
    pub name: String,
    /// Target-language type name, emitted verbatim.
    pub type_name: String,
    /// Marks the field optional (`?` suffix in generated TypeScript interfaces).
    pub is_optional: bool,
    /// Marks the field as a list of `type_name` (`[]` suffix in generated TypeScript).
    pub is_array: bool,
}
20
21#[derive(Debug, Clone)]
22pub struct EndpointSpec {
23    pub path: String,
24    pub method: String,
25    pub request_type: Option<TypeSpec>,
26    pub response_type: Option<TypeSpec>,
27    pub params: Vec<FieldSpec>,
28    pub query_params: Vec<FieldSpec>,
29}
30
31pub struct ApiGenerator {
32    routes: Vec<Route>,
33    types: Vec<TypeSpec>,
34}
35
36impl ApiGenerator {
37    pub fn new(routes: Vec<Route>) -> Self {
38        Self {
39            routes,
40            types: Vec::new(),
41        }
42    }
43
44    pub fn add_type(&mut self, spec: TypeSpec) {
45        self.types.push(spec);
46    }
47
48    pub fn generate_typescript_types(&self) -> String {
49        let mut output = String::from("// Auto-generated TypeScript types\n\n");
50
51        for typ in &self.types {
52            if typ.is_enum {
53                output.push_str(&format!("export enum {} {{\n", typ.name));
54                for variant in &typ.enum_variants {
55                    output.push_str(&format!("  {} = '{}',\n", variant, variant));
56                }
57                output.push_str("}\n\n");
58            } else {
59                output.push_str(&format!("export interface {} {{\n", typ.name));
60                for field in &typ.fields {
61                    let optional = if field.is_optional { "?" } else { "" };
62                    let array = if field.is_array { "[]" } else { "" };
63                    output.push_str(&format!(
64                        "  {}{}: {}{};\n",
65                        field.name, optional, field.type_name, array
66                    ));
67                }
68                output.push_str("}\n\n");
69            }
70        }
71
72        output
73    }
74
75    pub fn generate_typescript_client(&self) -> String {
76        let mut output = String::from("// Auto-generated API client\n\n");
77        output.push_str(
78            r#"
79export class ApiClient {
80    private baseUrl: string;
81    private defaultHeaders: Record<string, string>;
82
83    constructor(baseUrl: string) {
84        this.baseUrl = baseUrl;
85        this.defaultHeaders = {
86            'Content-Type': 'application/json',
87        };
88    }
89
90    private async request<T>(
91        method: string,
92        path: string,
93        options: {
94            params?: Record<string, string>;
95            query?: Record<string, string | string[]>;
96            body?: unknown;
97        } = {}
98    ): Promise<T> {
99        let url = new URL(path, this.baseUrl);
100
101        if (options.params) {
102            for (const [key, value] of Object.entries(options.params)) {
103                url.pathname = url.pathname.replace(`{${key}}`, value);
104            }
105        }
106
107        if (options.query) {
108            for (const [key, value] of Object.entries(options.query)) {
109                if (Array.isArray(value)) {
110                    value.forEach(v => url.searchParams.append(key, v));
111                } else {
112                    url.searchParams.set(key, value);
113                }
114            }
115        }
116
117        const response = await fetch(url.toString(), {
118            method,
119            headers: this.defaultHeaders,
120            body: options.body ? JSON.stringify(options.body) : undefined,
121        });
122
123        if (!response.ok) {
124            throw new ApiError(response.status, await response.text());
125        }
126
127        return response.json();
128    }
129"#,
130        );
131
132        for endpoint in self.get_endpoint_specs() {
133            let method_name = to_camel_case(&endpoint.path.replace("/", "_"));
134            let method_lower = endpoint.method.to_lowercase();
135
136            output.push_str(&format!("\n    async {}(options?: {{", method_name));
137
138            if !endpoint.params.is_empty() {
139                output.push_str("params: { ");
140                let params: Vec<String> = endpoint
141                    .params
142                    .iter()
143                    .map(|p| format!("{}: {}", p.name, p.type_name))
144                    .collect();
145                output.push_str(&params.join(", "));
146                output.push_str(" }, ");
147            }
148
149            if !endpoint.query_params.is_empty() {
150                output.push_str("query?: { ");
151                let query: Vec<String> = endpoint
152                    .query_params
153                    .iter()
154                    .map(|p| format!("{}?: {}", p.name, p.type_name))
155                    .collect();
156                output.push_str(&query.join(", "));
157                output.push_str(" }, ");
158            }
159
160            if let Some(ref req_type) = endpoint.request_type {
161                output.push_str("body: ");
162                output.push_str(&req_type.name);
163                output.push_str(", ");
164            }
165
166            output.push_str("}): Promise<");
167            if let Some(ref resp) = endpoint.response_type {
168                output.push_str(&resp.name);
169            } else {
170                output.push_str("void");
171            }
172            output.push_str("> {\n");
173
174            output.push_str(&format!(
175                "        return this.request<{}>('{}', '{}', {{",
176                endpoint
177                    .response_type
178                    .as_ref()
179                    .map(|t| t.name.clone())
180                    .unwrap_or_else(|| "void".to_string()),
181                method_lower,
182                endpoint.path
183            ));
184
185            if !endpoint.params.is_empty() {
186                output.push_str(" params: options?.params,");
187            }
188            if !endpoint.query_params.is_empty() {
189                output.push_str(" query: options?.query,");
190            }
191            if endpoint.request_type.is_some() {
192                output.push_str(" body: options?.body,");
193            }
194
195            output.push_str(" });\n    }\n");
196        }
197
198        output.push_str("}\n\n");
199
200        output.push_str(
201            r#"
202export class ApiError extends Error {
203    constructor(public status: number, message: string) {
204        super(message);
205        this.name = 'ApiError';
206    }
207}
208"#,
209        );
210
211        output
212    }
213
214    pub fn generate_rust_client(&self) -> String {
215        let mut output = String::from("// Auto-generated Rust API client\n\n");
216
217        output.push_str("use serde::{Deserialize, Serialize};\n");
218        output.push_str("use std::collections::HashMap;\n\n");
219
220        output.push_str("pub struct ApiClient {\n");
221        output.push_str("    base_url: String,\n");
222        output.push_str("    client: reqwest::Client,\n");
223        output.push_str("    headers: HashMap<String, String>,\n");
224        output.push_str("}\n\n");
225
226        output.push_str("impl ApiClient {\n");
227        output.push_str("    pub fn new(base_url: &str) -> Self {\n");
228        output.push_str("        Self {\n");
229        output.push_str("            base_url: base_url.to_string(),\n");
230        output.push_str("            client: reqwest::Client::new(),\n");
231        output.push_str("            headers: HashMap::new(),\n");
232        output.push_str("        }\n");
233        output.push_str("    }\n\n");
234
235        output.push_str("    pub fn with_header(mut self, key: &str, value: &str) -> Self {\n");
236        output.push_str("        self.headers.insert(key.to_string(), value.to_string());\n");
237        output.push_str("        self\n");
238        output.push_str("    }\n\n");
239
240        for endpoint in self.get_endpoint_specs() {
241            let method_name = to_snake_case(&endpoint.path.replace("/", "_"));
242            let http_method = endpoint.method.to_lowercase();
243            let return_type = endpoint
244                .response_type
245                .as_ref()
246                .map(|_t| "serde_json::Value".to_string())
247                .unwrap_or_else(|| "()".to_string());
248
249            output.push_str(&format!(
250                "    pub async fn {}(&self, path: &str{} {}) -> anyhow::Result<{}> {{\n",
251                method_name,
252                if !endpoint.params.is_empty() {
253                    ", params: HashMap<String, String>"
254                } else {
255                    ""
256                },
257                if endpoint.request_type.is_some() {
258                    ", body: serde_json::Value"
259                } else {
260                    ""
261                },
262                return_type
263            ));
264
265            output.push_str("        let url = format!(\"{}{}\", self.base_url, path);\n");
266            output.push_str(&format!(
267                "        let mut req = self.client.{}(&url);\n",
268                http_method
269            ));
270
271            if endpoint.request_type.is_some() {
272                output.push_str("        req = req.json(&body);\n");
273            }
274
275            output.push_str("        for (k, v) in &self.headers {\n");
276            output.push_str("            req = req.header(k, v);\n");
277            output.push_str("        }\n");
278            output.push_str("        let resp = req.send().await?;\n");
279            output.push_str("        let text = resp.text().await?;\n");
280
281            if endpoint.response_type.is_some() {
282                output.push_str("        Ok(serde_json::from_str(&text)?)\n");
283            } else {
284                output.push_str("        Ok(())\n");
285            }
286
287            output.push_str("    }\n\n");
288        }
289
290        output.push_str("}\n");
291
292        output
293    }
294
295    pub fn get_endpoint_specs(&self) -> Vec<EndpointSpec> {
296        self.routes
297            .iter()
298            .map(|r| EndpointSpec {
299                path: r.path.clone(),
300                method: r.method.as_str().to_string(),
301                request_type: None,
302                response_type: None,
303                params: Vec::new(),
304                query_params: Vec::new(),
305            })
306            .collect()
307    }
308
309    pub fn write_typescript(&self, output_dir: &Path) -> anyhow::Result<()> {
310        let types_path = output_dir.join("api.types.ts");
311        let client_path = output_dir.join("api.client.ts");
312
313        std::fs::write(&types_path, self.generate_typescript_types())?;
314        std::fs::write(&client_path, self.generate_typescript_client())?;
315
316        tracing::info!("Generated TypeScript types at {}", types_path.display());
317        tracing::info!("Generated TypeScript client at {}", client_path.display());
318
319        Ok(())
320    }
321
322    pub fn write_rust(&self, output_dir: &Path) -> anyhow::Result<()> {
323        let client_path = output_dir.join("api_client.rs");
324        std::fs::write(&client_path, self.generate_rust_client())?;
325        tracing::info!("Generated Rust client at {}", client_path.display());
326        Ok(())
327    }
328}
329
/// Converts `s` into a lowerCamelCase identifier.
///
/// - Every non-alphanumeric character (`_`, `-`, `/`, `{`, `}`, …) acts as a word separator
///   and is dropped, so route paths such as `/users/{id}` yield valid identifiers.
/// - The first character after a separator is upper-cased.
/// - The first character of the result is lower-cased, so a leading separator does not
///   produce PascalCase.
/// - A result that would start with a digit is prefixed with `_` to remain a valid identifier.
fn to_camel_case(s: &str) -> String {
    let mut result = String::with_capacity(s.len());
    let mut capitalize = false;

    for c in s.chars() {
        if !c.is_alphanumeric() {
            capitalize = true;
        } else if result.is_empty() {
            result.extend(c.to_lowercase());
            capitalize = false;
        } else if capitalize {
            result.extend(c.to_uppercase());
            capitalize = false;
        } else {
            result.push(c);
        }
    }

    if result.starts_with(char::is_numeric) {
        result.insert(0, '_');
    }
    result
}
347
/// Converts `s` into a snake_case identifier.
///
/// - Every non-alphanumeric character (`_`, `-`, `/`, `{`, `}`, …) is a word boundary, and
///   so is each upper-case letter after the first output character.
/// - Runs of boundaries collapse to a single `_`, and leading/trailing boundaries are dropped.
/// - Letters are lower-cased with full Unicode case mapping.
/// - A result that would start with a digit is prefixed with `_` to remain a valid identifier.
fn to_snake_case(s: &str) -> String {
    let mut result = String::with_capacity(s.len() + s.len() / 2);
    let mut pending_separator = false;

    for c in s.chars() {
        if !c.is_alphanumeric() {
            pending_separator = true;
            continue;
        }
        if c.is_uppercase() {
            pending_separator = true;
        }
        // Only emit the separator between words, never at the start of the identifier.
        if pending_separator && !result.is_empty() {
            result.push('_');
        }
        pending_separator = false;
        result.extend(c.to_lowercase());
    }

    if result.starts_with(char::is_numeric) {
        result.insert(0, '_');
    }
    result
}