// oag_react/emitters/hooks.rs

1use std::collections::HashSet;
2
3use minijinja::{Environment, context};
4use oag_core::ir::{HttpMethod, IrOperation, IrParameterLocation, IrReturnType, IrSpec, IrType};
5use oag_typescript::type_mapper::ir_type_to_ts;
6
7/// Emit `hooks.ts` — React hooks wrapping the API client.
8pub fn emit_hooks(ir: &IrSpec) -> String {
9    let mut env = Environment::new();
10    env.add_template("hooks.ts.j2", include_str!("../../templates/hooks.ts.j2"))
11        .expect("template should be valid");
12    let tmpl = env.get_template("hooks.ts.j2").unwrap();
13
14    let imported_types = collect_imported_types(ir);
15    let hooks: Vec<minijinja::Value> = ir.operations.iter().flat_map(build_hook_contexts).collect();
16
17    tmpl.render(context! {
18        imported_types => imported_types,
19        hooks => hooks,
20    })
21    .expect("render should succeed")
22}
23
24fn build_hook_contexts(op: &IrOperation) -> Vec<minijinja::Value> {
25    let mut results = Vec::new();
26
27    match (&op.method, &op.return_type) {
28        // GET → useSWR query hook
29        (HttpMethod::Get, IrReturnType::Standard(resp)) => {
30            let return_type = ir_type_to_ts(&resp.response_type);
31            let (params_sig, swr_key, call_args) = build_query_params(op);
32            results.push(context! {
33                kind => "query",
34                hook_name => format!("use{}", op.name.pascal_case),
35                method_name => op.name.camel_case.clone(),
36                params_signature => params_sig,
37                return_type => return_type,
38                swr_key => swr_key,
39                call_args => call_args,
40                description => op.summary.clone().or(op.description.clone()),
41            });
42        }
43        // POST/PUT/DELETE non-streaming → useSWRMutation hook
44        (_, IrReturnType::Standard(_)) | (_, IrReturnType::Void) => {
45            let return_type = match &op.return_type {
46                IrReturnType::Standard(r) => ir_type_to_ts(&r.response_type),
47                _ => "void".to_string(),
48            };
49            let body_type = op
50                .request_body
51                .as_ref()
52                .map(|b| ir_type_to_ts(&b.body_type))
53                .unwrap_or_else(|| "void".to_string());
54
55            let (path_params_sig, swr_key, call_args) = build_mutation_params(op);
56            results.push(context! {
57                kind => "mutation",
58                hook_name => format!("use{}", op.name.pascal_case),
59                method_name => op.name.camel_case.clone(),
60                path_params_signature => path_params_sig,
61                return_type => return_type,
62                body_type => body_type,
63                swr_key => swr_key,
64                call_args => call_args,
65                description => op.summary.clone().or(op.description.clone()),
66            });
67        }
68        // SSE → custom streaming hook
69        (_, IrReturnType::Sse(sse)) => {
70            let event_type = if let Some(ref name) = sse.event_type_name {
71                name.clone()
72            } else {
73                ir_type_to_ts(&sse.event_type)
74            };
75            let method_name = if sse.also_has_json {
76                format!("{}Stream", op.name.camel_case)
77            } else {
78                op.name.camel_case.clone()
79            };
80            let hook_name = if sse.also_has_json {
81                format!("use{}Stream", op.name.pascal_case)
82            } else {
83                format!("use{}", op.name.pascal_case)
84            };
85            let (path_params_sig, trigger_params, stream_call_args, deps) =
86                build_sse_hook_params(op);
87
88            results.push(context! {
89                kind => "sse",
90                hook_name => hook_name,
91                method_name => method_name,
92                path_params_signature => path_params_sig,
93                event_type => event_type,
94                trigger_params => trigger_params,
95                stream_call_args => stream_call_args,
96                deps => deps,
97                description => op.summary.clone().or(op.description.clone()),
98            });
99
100            // If dual endpoint, also generate the JSON query/mutation hook
101            if let Some(ref json_resp) = sse.json_response {
102                let return_type = ir_type_to_ts(&json_resp.response_type);
103                match op.method {
104                    HttpMethod::Get => {
105                        let (params_sig, swr_key, call_args) = build_query_params(op);
106                        results.push(context! {
107                            kind => "query",
108                            hook_name => format!("use{}", op.name.pascal_case),
109                            method_name => op.name.camel_case.clone(),
110                            params_signature => params_sig,
111                            return_type => return_type,
112                            swr_key => swr_key,
113                            call_args => call_args,
114                            description => op.summary.clone().or(op.description.clone()),
115                        });
116                    }
117                    _ => {
118                        let body_type = op
119                            .request_body
120                            .as_ref()
121                            .map(|b| ir_type_to_ts(&b.body_type))
122                            .unwrap_or_else(|| "void".to_string());
123                        let (path_params_sig, swr_key, call_args) = build_mutation_params(op);
124                        results.push(context! {
125                            kind => "mutation",
126                            hook_name => format!("use{}", op.name.pascal_case),
127                            method_name => op.name.camel_case.clone(),
128                            path_params_signature => path_params_sig,
129                            return_type => return_type,
130                            body_type => body_type,
131                            swr_key => swr_key,
132                            call_args => call_args,
133                            description => op.summary.clone().or(op.description.clone()),
134                        });
135                    }
136                }
137            }
138        }
139    }
140
141    results
142}
143
144fn build_query_params(op: &IrOperation) -> (String, String, String) {
145    let mut sig_parts = Vec::new();
146    let mut key_parts = Vec::new();
147    let mut call_parts = Vec::new();
148
149    for param in &op.parameters {
150        if param.location == IrParameterLocation::Path {
151            let ts = ir_type_to_ts(&param.param_type);
152            sig_parts.push(format!("{}: {}", param.name.camel_case, ts));
153            key_parts.push(param.name.camel_case.clone());
154            call_parts.push(param.name.camel_case.clone());
155        }
156    }
157    for param in &op.parameters {
158        if param.location == IrParameterLocation::Query {
159            let ts = ir_type_to_ts(&param.param_type);
160            if param.required {
161                sig_parts.push(format!("{}: {}", param.name.camel_case, ts));
162            } else {
163                sig_parts.push(format!("{}?: {}", param.name.camel_case, ts));
164            }
165            key_parts.push(param.name.camel_case.clone());
166            call_parts.push(param.name.camel_case.clone());
167        }
168    }
169
170    let swr_key = if key_parts.is_empty() {
171        format!("\"{}\"", op.path)
172    } else {
173        format!("[\"{}\" , {}] as const", op.path, key_parts.join(", "))
174    };
175
176    let params_sig = sig_parts.join(", ");
177    let call_args = call_parts.join(", ");
178
179    (params_sig, swr_key, call_args)
180}
181
182fn build_mutation_params(op: &IrOperation) -> (String, String, String) {
183    let mut sig_parts = Vec::new();
184    let mut key_parts = Vec::new();
185    let mut call_parts = Vec::new();
186
187    for param in &op.parameters {
188        if param.location == IrParameterLocation::Path {
189            let ts = ir_type_to_ts(&param.param_type);
190            sig_parts.push(format!("{}: {}, ", param.name.camel_case, ts));
191            key_parts.push(param.name.camel_case.clone());
192            call_parts.push(param.name.camel_case.clone());
193        }
194    }
195
196    // For mutation, the body comes from arg
197    if op.request_body.is_some() {
198        call_parts.push("arg".to_string());
199    }
200
201    let path_params_sig = sig_parts.join("");
202    let swr_key = if key_parts.is_empty() {
203        format!("\"{}\"", op.path)
204    } else {
205        format!("[\"{}\" , {}] as const", op.path, key_parts.join(", "))
206    };
207    let call_args = call_parts.join(", ");
208
209    (path_params_sig, swr_key, call_args)
210}
211
212fn build_sse_hook_params(op: &IrOperation) -> (String, String, String, String) {
213    let mut path_sig_parts = Vec::new();
214    let mut deps_parts = Vec::new();
215    let mut stream_call_parts = Vec::new();
216
217    for param in &op.parameters {
218        if param.location == IrParameterLocation::Path {
219            let ts = ir_type_to_ts(&param.param_type);
220            path_sig_parts.push(format!("{}: {}, ", param.name.camel_case, ts));
221            deps_parts.push(format!(", {}", param.name.camel_case));
222            stream_call_parts.push(param.name.camel_case.clone());
223        }
224    }
225
226    let trigger_params = if let Some(ref body) = op.request_body {
227        let ts = ir_type_to_ts(&body.body_type);
228        stream_call_parts.push("body".to_string());
229        format!("body: {}", ts)
230    } else {
231        String::new()
232    };
233
234    let path_params_sig = path_sig_parts.join("");
235    let stream_call_args = stream_call_parts.join(", ");
236    let deps = deps_parts.join("");
237
238    (path_params_sig, trigger_params, stream_call_args, deps)
239}
240
241fn collect_imported_types(ir: &IrSpec) -> Vec<String> {
242    let mut types = HashSet::new();
243
244    for op in &ir.operations {
245        match &op.return_type {
246            IrReturnType::Standard(resp) => {
247                collect_refs(&resp.response_type, &mut types);
248            }
249            IrReturnType::Sse(sse) => {
250                if let Some(ref name) = sse.event_type_name {
251                    types.insert(name.clone());
252                } else {
253                    collect_refs(&sse.event_type, &mut types);
254                }
255                if let Some(ref json) = sse.json_response {
256                    collect_refs(&json.response_type, &mut types);
257                }
258            }
259            IrReturnType::Void => {}
260        }
261        if let Some(ref body) = op.request_body {
262            collect_refs(&body.body_type, &mut types);
263        }
264    }
265
266    let mut sorted: Vec<String> = types.into_iter().collect();
267    sorted.sort();
268    sorted
269}
270
271fn collect_refs(ir_type: &IrType, types: &mut HashSet<String>) {
272    match ir_type {
273        IrType::Ref(name) => {
274            types.insert(name.clone());
275        }
276        IrType::Array(inner) | IrType::Map(inner) => collect_refs(inner, types),
277        IrType::Union(variants) => {
278            for v in variants {
279                collect_refs(v, types);
280            }
281        }
282        _ => {}
283    }
284}