Skip to main content

shaperail_codegen/
service_client.rs

1use shaperail_core::{EndpointSpec, HttpMethod, ResourceDefinition};
2
3/// Generate a typed inter-service client module for a set of resources.
4///
5/// The generated client provides type-safe methods for calling endpoints
6/// on a remote service, using the resource definitions as the contract.
7/// Type mismatches between services become compile errors.
8pub fn generate_service_client(service_name: &str, resources: &[ResourceDefinition]) -> String {
9    let mod_name = service_name.replace('-', "_");
10    let mut out = String::new();
11
12    out.push_str(&format!(
13        "//! Auto-generated typed client for service `{service_name}`.\n"
14    ));
15    out.push_str("//! DO NOT EDIT — regenerated by `shaperail generate`.\n\n");
16    out.push_str("use serde::{{Deserialize, Serialize}};\n\n");
17
18    // Generate the client struct
19    out.push_str(&format!(
20        "/// Typed HTTP client for the `{service_name}` service.\n"
21    ));
22    out.push_str("#[derive(Debug, Clone)]\n");
23    out.push_str(&format!(
24        "pub struct {client_type} {{\n",
25        client_type = client_type_name(&mod_name)
26    ));
27    out.push_str("    base_url: String,\n");
28    out.push_str("    client: reqwest::Client,\n");
29    out.push_str("    auth_token: Option<String>,\n");
30    out.push_str("}\n\n");
31
32    // Generate constructor
33    out.push_str(&format!(
34        "impl {client_type} {{\n",
35        client_type = client_type_name(&mod_name)
36    ));
37    out.push_str("    /// Create a new client pointing at the given base URL.\n");
38    out.push_str("    pub fn new(base_url: impl Into<String>) -> Self {\n");
39    out.push_str("        Self {\n");
40    out.push_str("            base_url: base_url.into(),\n");
41    out.push_str("            client: reqwest::Client::builder()\n");
42    out.push_str("                .timeout(std::time::Duration::from_secs(10))\n");
43    out.push_str("                .build()\n");
44    out.push_str("                .unwrap_or_default(),\n");
45    out.push_str("            auth_token: None,\n");
46    out.push_str("        }\n");
47    out.push_str("    }\n\n");
48    out.push_str("    /// Set the Bearer token for authenticated requests.\n");
49    out.push_str("    pub fn with_auth(mut self, token: impl Into<String>) -> Self {\n");
50    out.push_str("        self.auth_token = Some(token.into());\n");
51    out.push_str("        self\n");
52    out.push_str("    }\n\n");
53
54    // Generate typed methods per resource endpoint
55    for resource in resources {
56        let endpoints = match &resource.endpoints {
57            Some(ep) => ep,
58            None => continue,
59        };
60
61        for (ep_name, endpoint) in endpoints {
62            let method_name = format!("{}_{}", resource.resource, ep_name);
63            let method_code = generate_endpoint_method(resource, ep_name, endpoint);
64            out.push_str(&method_code);
65            let _ = method_name; // used in generate_endpoint_method
66        }
67    }
68
69    out.push_str("}\n\n");
70
71    // Generate request/response types per resource
72    for resource in resources {
73        out.push_str(&generate_resource_types(resource));
74    }
75
76    out
77}
78
/// Derive the generated client's type name from a snake_case module name.
///
/// Each `_`-separated segment has its first character upper-cased (only the
/// first character of the upper-case mapping is kept), the segments are
/// concatenated, and `Client` is appended. Empty segments contribute nothing.
fn client_type_name(mod_name: &str) -> String {
    let mut name = String::with_capacity(mod_name.len() + "Client".len());
    for segment in mod_name.split('_') {
        let mut letters = segment.chars();
        if let Some(first) = letters.next() {
            name.push(first.to_uppercase().next().unwrap_or(first));
            // Remainder of the segment is copied through untouched.
            name.push_str(letters.as_str());
        }
    }
    name.push_str("Client");
    name
}
95
96fn generate_endpoint_method(
97    resource: &ResourceDefinition,
98    ep_name: &str,
99    endpoint: &EndpointSpec,
100) -> String {
101    let method_name = format!("{}_{}", resource.resource, ep_name);
102    let resource_type = pascal_case(&resource.resource);
103    let version = resource.version;
104
105    let has_id_param = endpoint.path().contains(":id");
106    let has_input = endpoint.input.is_some() && !endpoint.input.as_ref().is_none_or(Vec::is_empty);
107
108    let mut out = String::new();
109
110    // Method signature
111    out.push_str(&format!(
112        "    /// {method} {path}\n",
113        method = endpoint.method(),
114        path = endpoint.path()
115    ));
116    out.push_str(&format!("    pub async fn {method_name}(\n"));
117    out.push_str("        &self,\n");
118    if has_id_param {
119        out.push_str("        id: &str,\n");
120    }
121    if has_input {
122        out.push_str(&format!("        input: &{resource_type}Input,\n"));
123    }
124    out.push_str(&format!(
125        "    ) -> Result<{return_type}, ClientError> {{\n",
126        return_type = match *endpoint.method() {
127            HttpMethod::Delete => "()".to_string(),
128            HttpMethod::Get if ep_name == "list" => format!("Vec<{resource_type}>"),
129            _ => resource_type.clone(),
130        }
131    ));
132
133    // Build URL — generate a Rust format!() call as source code
134    let versioned_path = format!("/v{version}{}", endpoint.path());
135    if has_id_param {
136        // Output: let url = format!("{}/v1/users/{}", self.base_url, id);
137        let path_with_placeholder = versioned_path.replace(":id", "{}");
138        out.push_str("        let url = format!(\"{}");
139        out.push_str(&path_with_placeholder);
140        out.push_str("\", self.base_url, id);\n");
141    } else {
142        // Output: let url = format!("{}/v1/users", self.base_url);
143        out.push_str("        let url = format!(\"{}");
144        out.push_str(&versioned_path);
145        out.push_str("\", self.base_url);\n");
146    }
147
148    // Build request
149    let http_method = match *endpoint.method() {
150        HttpMethod::Get => "get",
151        HttpMethod::Post => "post",
152        HttpMethod::Patch => "patch",
153        HttpMethod::Put => "put",
154        HttpMethod::Delete => "delete",
155    };
156    out.push_str(&format!(
157        "        let mut req = self.client.{http_method}(&url);\n"
158    ));
159    out.push_str("        if let Some(ref token) = self.auth_token {\n");
160    out.push_str("            req = req.bearer_auth(token);\n");
161    out.push_str("        }\n");
162
163    if has_input {
164        out.push_str("        req = req.json(input);\n");
165    }
166
167    // Send and parse
168    out.push_str("        let resp = req.send().await.map_err(ClientError::Request)?;\n");
169    out.push_str("        if !resp.status().is_success() {\n");
170    out.push_str("            let status = resp.status().as_u16();\n");
171    out.push_str("            let body = resp.text().await.unwrap_or_default();\n");
172    out.push_str("            return Err(ClientError::Api { status, body });\n");
173    out.push_str("        }\n");
174
175    match *endpoint.method() {
176        HttpMethod::Delete => {
177            out.push_str("        Ok(())\n");
178        }
179        _ => {
180            out.push_str("        let body = resp.json().await.map_err(ClientError::Request)?;\n");
181            out.push_str("        Ok(body)\n");
182        }
183    }
184
185    out.push_str("    }\n\n");
186    out
187}
188
189fn generate_resource_types(resource: &ResourceDefinition) -> String {
190    let type_name = pascal_case(&resource.resource);
191    let mut out = String::new();
192
193    // Main record type
194    out.push_str(&format!("/// Record type for `{}`.\n", resource.resource));
195    out.push_str("#[derive(Debug, Clone, Serialize, Deserialize)]\n");
196    out.push_str(&format!("pub struct {type_name} {{\n"));
197    for (field_name, _field_schema) in &resource.schema {
198        out.push_str(&format!("    pub {field_name}: serde_json::Value,\n"));
199    }
200    out.push_str("}\n\n");
201
202    // Input type (fields that appear in any endpoint's input list)
203    let mut input_fields = std::collections::HashSet::new();
204    if let Some(endpoints) = &resource.endpoints {
205        for (_, ep) in endpoints {
206            if let Some(ref inputs) = ep.input {
207                for f in inputs {
208                    input_fields.insert(f.clone());
209                }
210            }
211        }
212    }
213
214    if !input_fields.is_empty() {
215        out.push_str(&format!("/// Input type for `{}`.\n", resource.resource));
216        out.push_str("#[derive(Debug, Clone, Serialize, Deserialize)]\n");
217        out.push_str(&format!("pub struct {type_name}Input {{\n"));
218        for field_name in &input_fields {
219            out.push_str("    #[serde(skip_serializing_if = \"Option::is_none\")]\n");
220            out.push_str(&format!(
221                "    pub {field_name}: Option<serde_json::Value>,\n"
222            ));
223        }
224        out.push_str("}\n\n");
225    }
226
227    out
228}
229
/// Convert a `snake_case` or `kebab-case` identifier to `PascalCase`.
///
/// The input is split on `_` and `-`; each non-empty piece has its first
/// character upper-cased (keeping only the first character of the upper-case
/// mapping) and the remainder copied unchanged.
fn pascal_case(s: &str) -> String {
    s.split(|sep: char| sep == '_' || sep == '-')
        .fold(String::with_capacity(s.len()), |mut acc, word| {
            let mut rest = word.chars();
            if let Some(head) = rest.next() {
                acc.push(head.to_uppercase().next().unwrap_or(head));
                acc.push_str(rest.as_str());
            }
            acc
        })
}
245
/// Source text of the `ClientError` type used by generated clients.
///
/// This constant is not compiled as part of this crate; it is emitted
/// verbatim at the top of a generated client module (see
/// [`generate_client_module`]). The generated methods map transport and
/// decoding failures to `ClientError::Request` and non-success HTTP statuses
/// to `ClientError::Api`.
///
/// The generated `Error` impl exposes the wrapped `reqwest::Error` through
/// `source()` so callers can walk or downcast the error chain.
pub const CLIENT_ERROR_TYPE: &str = r#"/// Error from an inter-service client call.
#[derive(Debug)]
pub enum ClientError {
    /// HTTP request failed (network, timeout, etc).
    Request(reqwest::Error),
    /// Remote service returned an error status.
    Api { status: u16, body: String },
}

impl std::fmt::Display for ClientError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Request(e) => write!(f, "request error: {e}"),
            Self::Api { status, body } => write!(f, "API error {status}: {body}"),
        }
    }
}

impl std::error::Error for ClientError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Request(e) => Some(e),
            Self::Api { .. } => None,
        }
    }
}
"#;
268
269/// Generate a complete inter-service client module file.
270pub fn generate_client_module(service_name: &str, resources: &[ResourceDefinition]) -> String {
271    let mut out = String::new();
272    out.push_str(CLIENT_ERROR_TYPE);
273    out.push('\n');
274    out.push_str(&generate_service_client(service_name, resources));
275    out
276}
277
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture: a `users` resource with list/create/get/delete endpoints,
    /// of which only `create` declares an input list.
    const USERS_YAML: &str = r#"
resource: users
version: 1
schema:
  id: { type: uuid, primary: true, generated: true }
  name: { type: string, required: true }
endpoints:
  list:
    method: GET
    path: /users
  create:
    method: POST
    path: /users
    input: [name]
  get:
    method: GET
    path: /users/:id
  delete:
    method: DELETE
    path: /users/:id
"#;

    /// Parse the fixture into a resource definition.
    fn make_resource() -> ResourceDefinition {
        crate::parser::parse_resource(USERS_YAML).unwrap()
    }

    /// Client source generated for the fixture under the `users-api` service name.
    fn users_api_client() -> String {
        generate_service_client("users-api", &[make_resource()])
    }

    #[test]
    fn client_type_name_conversion() {
        assert_eq!(client_type_name("users_api"), "UsersApiClient");
        assert_eq!(client_type_name("orders"), "OrdersClient");
    }

    #[test]
    fn pascal_case_conversion() {
        assert_eq!(pascal_case("users"), "Users");
        assert_eq!(pascal_case("order_items"), "OrderItems");
        assert_eq!(pascal_case("my-service"), "MyService");
    }

    #[test]
    fn generate_client_contains_struct() {
        let generated = users_api_client();
        assert!(generated.contains("pub struct UsersApiClient"));
        assert!(generated.contains("pub fn new("));
        assert!(generated.contains("pub fn with_auth("));
    }

    #[test]
    fn generate_client_contains_methods() {
        let generated = users_api_client();
        assert!(generated.contains("pub async fn users_list("));
        assert!(generated.contains("pub async fn users_create("));
        assert!(generated.contains("pub async fn users_get("));
        assert!(generated.contains("pub async fn users_delete("));
    }

    #[test]
    fn generate_client_contains_types() {
        let generated = users_api_client();
        assert!(generated.contains("pub struct Users {"));
        assert!(generated.contains("pub struct UsersInput {"));
    }

    #[test]
    fn generate_client_module_includes_error_type() {
        let generated = generate_client_module("users-api", &[make_resource()]);
        assert!(generated.contains("pub enum ClientError"));
        assert!(generated.contains("pub struct UsersApiClient"));
    }

    #[test]
    fn generate_client_empty_resources() {
        let generated = generate_service_client("empty-svc", &[]);
        assert!(generated.contains("pub struct EmptySvcClient"));
        // With no resources there must be no endpoint methods at all.
        assert!(!generated.contains("pub async fn"));
    }
}
362}