Skip to main content

structured_proxy/
openapi.rs

1//! OpenAPI 3.0 spec generation from proto descriptors.
2//!
3//! Reads `google.api.http` annotations and proto message definitions
4//! to produce a complete OpenAPI 3.0 JSON spec at runtime.
5//! No codegen, no build step — same descriptor pool used for transcoding.
6
7use prost_reflect::{DescriptorPool, FieldDescriptor, Kind, MessageDescriptor, MethodDescriptor};
8use serde_json::{json, Map, Value};
9
10use crate::config::{AliasConfig, OpenApiConfig};
11
12/// Generate OpenAPI 3.0 JSON spec from a descriptor pool.
13pub fn generate(pool: &DescriptorPool, config: &OpenApiConfig, aliases: &[AliasConfig]) -> Value {
14    let title = config.title.as_deref().unwrap_or("API");
15    let version = config.version.as_deref().unwrap_or("1.0.0");
16
17    let mut paths = Map::new();
18    let mut schemas = Map::new();
19    let mut tags = Vec::new();
20
21    for service in pool.services() {
22        let service_name = service.name().to_string();
23        let service_full = service.full_name().to_string();
24
25        // Proto comments as tag description.
26        let tag_desc = get_comments(&service_full, pool);
27        let mut tag = json!({ "name": service_name });
28        if let Some(desc) = &tag_desc {
29            tag["description"] = json!(desc);
30        }
31        tags.push(tag);
32
33        for method in service.methods() {
34            if method.is_client_streaming() {
35                continue; // No REST mapping for client-streaming.
36            }
37
38            if let Some((http_method, http_path)) = extract_http_rule(&method, pool) {
39                let operation = build_operation(
40                    &method,
41                    &service_name,
42                    &http_method,
43                    &http_path,
44                    pool,
45                    &mut schemas,
46                );
47
48                // Main path.
49                add_path_operation(&mut paths, &http_path, &http_method, operation.clone());
50
51                // Aliases.
52                for alias in aliases {
53                    if let Some(suffix) = http_path.strip_prefix(&alias.to) {
54                        if alias.from.ends_with("/{path}") {
55                            let prefix = alias.from.trim_end_matches("/{path}");
56                            let alias_path = format!("{}{}", prefix, suffix);
57                            add_path_operation(
58                                &mut paths,
59                                &alias_path,
60                                &http_method,
61                                operation.clone(),
62                            );
63                        }
64                    }
65                }
66            }
67        }
68    }
69
70    let mut spec = json!({
71        "openapi": "3.0.3",
72        "info": {
73            "title": title,
74            "version": version,
75        },
76        "paths": paths,
77        "tags": tags,
78    });
79
80    if !schemas.is_empty() {
81        spec["components"] = json!({
82            "schemas": schemas,
83        });
84    }
85
86    // Security scheme for Bearer auth (cookie auth works implicitly via same-origin).
87    spec["components"]["securitySchemes"] = json!({
88        "bearerAuth": {
89            "type": "http",
90            "scheme": "bearer",
91            "bearerFormat": "JWT",
92        },
93        "cookieAuth": {
94            "type": "apiKey",
95            "in": "cookie",
96            "name": "session",
97            "description": "Browser session cookie (same-origin, set by BFF login flow)",
98        },
99    });
100
101    spec
102}
103
/// Generate the Scalar API reference HTML page.
///
/// `openapi_path` is the URL the page loads the spec from (placed in a `data-url`
/// attribute). `title` is shown in the browser tab. Both are HTML-escaped before
/// interpolation, so characters such as `<`, `&` or `"` in configuration cannot
/// break the markup.
///
/// The Scalar bundle is fetched from the jsDelivr CDN at an unpinned version, so
/// the page needs network access to that host to render.
pub fn docs_html(openapi_path: &str, title: &str) -> String {
    format!(
        r#"<!DOCTYPE html>
<html>
<head>
    <title>{title} — API Docs</title>
    <meta charset="utf-8" />
    <meta name="viewport" content="width=device-width, initial-scale=1" />
</head>
<body>
    <script id="api-reference" data-url="{openapi_path}"></script>
    <script src="https://cdn.jsdelivr.net/npm/@scalar/api-reference"></script>
</body>
</html>"#,
        title = escape_html(title),
        openapi_path = escape_html(openapi_path),
    )
}

/// Escape the five HTML-significant characters.
///
/// The result is safe inside element text and inside single- or double-quoted
/// attribute values.
fn escape_html(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for ch in s.chars() {
        match ch {
            '&' => out.push_str("&amp;"),
            '<' => out.push_str("&lt;"),
            '>' => out.push_str("&gt;"),
            '"' => out.push_str("&quot;"),
            '\'' => out.push_str("&#39;"),
            _ => out.push(ch),
        }
    }
    out
}
123
124fn add_path_operation(paths: &mut Map<String, Value>, path: &str, method: &str, operation: Value) {
125    let path_item = paths.entry(path.to_string()).or_insert_with(|| json!({}));
126    if let Some(obj) = path_item.as_object_mut() {
127        obj.insert(method.to_string(), operation);
128    }
129}
130
131fn build_operation(
132    method: &MethodDescriptor,
133    service_name: &str,
134    http_method: &str,
135    http_path: &str,
136    pool: &DescriptorPool,
137    schemas: &mut Map<String, Value>,
138) -> Value {
139    let method_name = method.name().to_string();
140    let full_name = method.full_name().to_string();
141    let input = method.input();
142    let output = method.output();
143
144    let is_streaming = method.is_server_streaming();
145
146    // Description from proto comments.
147    let description = get_comments(&full_name, pool).unwrap_or_default();
148
149    let operation_id = format!("{}.{}", service_name, method_name);
150
151    let mut op = json!({
152        "operationId": operation_id,
153        "tags": [service_name],
154        "summary": method_name,
155    });
156
157    if !description.is_empty() {
158        op["description"] = json!(description);
159    }
160
161    // Path parameters.
162    let path_params = extract_path_params(http_path);
163    if !path_params.is_empty() {
164        let params: Vec<Value> = path_params
165            .iter()
166            .map(|name| {
167                let mut param = json!({
168                    "name": name,
169                    "in": "path",
170                    "required": true,
171                    "schema": { "type": "string" },
172                });
173
174                // Try to get type from input message field.
175                if let Some(field) = input.get_field_by_name(name) {
176                    param["schema"] = field_to_schema(&field);
177                }
178
179                param
180            })
181            .collect();
182        op["parameters"] = json!(params);
183    }
184
185    // Request body (for POST/PUT/PATCH/DELETE with body fields).
186    if http_method != "get" {
187        let has_body_fields = input
188            .fields()
189            .any(|f| !path_params.contains(&f.name().to_string()));
190
191        if has_body_fields {
192            let schema_name = input.name().to_string();
193            let body_schema = message_to_schema(&input, &path_params, schemas);
194
195            schemas.insert(schema_name.clone(), body_schema);
196
197            op["requestBody"] = json!({
198                "required": true,
199                "content": {
200                    "application/json": {
201                        "schema": {
202                            "$ref": format!("#/components/schemas/{}", schema_name),
203                        },
204                    },
205                },
206            });
207        }
208    } else {
209        // GET: non-path fields become query parameters.
210        let query_params: Vec<Value> = input
211            .fields()
212            .filter(|f| !path_params.contains(&f.name().to_string()))
213            .map(|field| {
214                json!({
215                    "name": field.name(),
216                    "in": "query",
217                    "required": false,
218                    "schema": field_to_schema(&field),
219                })
220            })
221            .collect();
222
223        if !query_params.is_empty() {
224            let existing = op
225                .get("parameters")
226                .and_then(|v| v.as_array())
227                .cloned()
228                .unwrap_or_default();
229            let mut all_params = existing;
230            all_params.extend(query_params);
231            op["parameters"] = json!(all_params);
232        }
233    }
234
235    // Response.
236    if is_streaming {
237        op["responses"] = json!({
238            "200": {
239                "description": "Server-streaming response (NDJSON)",
240                "content": {
241                    "application/x-ndjson": {
242                        "schema": message_ref_or_inline(&output, schemas),
243                    },
244                },
245            },
246        });
247    } else if output.full_name() == "google.protobuf.Empty" {
248        op["responses"] = json!({
249            "200": {
250                "description": "Success (empty response)",
251            },
252        });
253    } else {
254        let schema_name = output.name().to_string();
255        let response_schema = message_to_schema(&output, &[], schemas);
256        schemas.insert(schema_name.clone(), response_schema);
257
258        op["responses"] = json!({
259            "200": {
260                "description": "Success",
261                "content": {
262                    "application/json": {
263                        "schema": {
264                            "$ref": format!("#/components/schemas/{}", schema_name),
265                        },
266                    },
267                },
268            },
269        });
270    }
271
272    // Common error responses.
273    if let Some(responses) = op.get_mut("responses").and_then(|r| r.as_object_mut()) {
274        responses.insert(
275            "400".to_string(),
276            json!({ "description": "Invalid argument" }),
277        );
278        responses.insert(
279            "401".to_string(),
280            json!({ "description": "Unauthenticated" }),
281        );
282        responses.insert(
283            "403".to_string(),
284            json!({ "description": "Permission denied" }),
285        );
286        responses.insert("404".to_string(), json!({ "description": "Not found" }));
287        responses.insert(
288            "503".to_string(),
289            json!({ "description": "Service unavailable" }),
290        );
291    }
292
293    op
294}
295
296/// Generate a JSON Schema for a protobuf message, excluding path parameter fields.
297fn message_to_schema(
298    msg: &MessageDescriptor,
299    exclude_fields: &[String],
300    schemas: &mut Map<String, Value>,
301) -> Value {
302    let mut properties = Map::new();
303    let required: Vec<String> = Vec::new();
304
305    for field in msg.fields() {
306        let name = field.name().to_string();
307        if exclude_fields.contains(&name) {
308            continue;
309        }
310
311        let schema = field_to_schema(&field);
312        properties.insert(name, schema);
313    }
314
315    let mut schema = json!({
316        "type": "object",
317        "properties": properties,
318    });
319
320    if !required.is_empty() {
321        schema["required"] = json!(required);
322    }
323
324    // Nested messages: register as separate schemas.
325    for field in msg.fields() {
326        if exclude_fields.contains(&field.name().to_string()) {
327            continue;
328        }
329        if let Kind::Message(nested) = field.kind() {
330            if !is_well_known(&nested) && !schemas.contains_key(nested.name()) {
331                let nested_schema = message_to_schema(&nested, &[], schemas);
332                schemas.insert(nested.name().to_string(), nested_schema);
333            }
334        }
335    }
336
337    schema
338}
339
340fn message_ref_or_inline(msg: &MessageDescriptor, schemas: &mut Map<String, Value>) -> Value {
341    let name = msg.name().to_string();
342    if !schemas.contains_key(&name) {
343        let schema = message_to_schema(msg, &[], schemas);
344        schemas.insert(name.clone(), schema);
345    }
346    json!({ "$ref": format!("#/components/schemas/{}", name) })
347}
348
349fn field_to_schema(field: &FieldDescriptor) -> Value {
350    let base = match field.kind() {
351        Kind::Double | Kind::Float => json!({ "type": "number", "format": "double" }),
352        Kind::Int32 | Kind::Sint32 | Kind::Sfixed32 => {
353            json!({ "type": "integer", "format": "int32" })
354        }
355        Kind::Int64 | Kind::Sint64 | Kind::Sfixed64 => {
356            json!({ "type": "string", "format": "int64", "description": "64-bit integer (string-encoded)" })
357        }
358        Kind::Uint32 | Kind::Fixed32 => {
359            json!({ "type": "integer", "format": "uint32" })
360        }
361        Kind::Uint64 | Kind::Fixed64 => {
362            json!({ "type": "string", "format": "uint64", "description": "64-bit unsigned integer (string-encoded)" })
363        }
364        Kind::Bool => json!({ "type": "boolean" }),
365        Kind::String => json!({ "type": "string" }),
366        Kind::Bytes => json!({ "type": "string", "format": "byte" }),
367        Kind::Enum(e) => {
368            let values: Vec<Value> = e.values().map(|v| json!(v.name())).collect();
369            json!({ "type": "string", "enum": values })
370        }
371        Kind::Message(msg) => {
372            if is_well_known(&msg) {
373                well_known_schema(&msg)
374            } else {
375                json!({ "$ref": format!("#/components/schemas/{}", msg.name()) })
376            }
377        }
378    };
379
380    if field.is_list() {
381        json!({ "type": "array", "items": base })
382    } else if field.is_map() {
383        // Map<K, V> → object with additionalProperties.
384        if let Kind::Message(entry) = field.kind() {
385            let value_field = entry.get_field_by_name("value");
386            let value_schema = value_field
387                .map(|f| field_to_schema(&f))
388                .unwrap_or_else(|| json!({}));
389            json!({ "type": "object", "additionalProperties": value_schema })
390        } else {
391            json!({ "type": "object" })
392        }
393    } else {
394        base
395    }
396}
397
398fn is_well_known(msg: &MessageDescriptor) -> bool {
399    msg.full_name().starts_with("google.protobuf.")
400}
401
402fn well_known_schema(msg: &MessageDescriptor) -> Value {
403    match msg.full_name() {
404        "google.protobuf.Timestamp" => {
405            json!({ "type": "string", "format": "date-time" })
406        }
407        "google.protobuf.Duration" => {
408            json!({ "type": "string", "format": "duration", "example": "3.5s" })
409        }
410        "google.protobuf.Empty" => json!({ "type": "object" }),
411        "google.protobuf.Struct" => json!({ "type": "object" }),
412        "google.protobuf.Value" => json!({}),
413        "google.protobuf.ListValue" => json!({ "type": "array", "items": {} }),
414        "google.protobuf.StringValue" | "google.protobuf.BytesValue" => {
415            json!({ "type": "string" })
416        }
417        "google.protobuf.BoolValue" => json!({ "type": "boolean" }),
418        "google.protobuf.Int32Value" | "google.protobuf.UInt32Value" => {
419            json!({ "type": "integer" })
420        }
421        "google.protobuf.Int64Value" | "google.protobuf.UInt64Value" => {
422            json!({ "type": "string", "format": "int64" })
423        }
424        "google.protobuf.FloatValue" | "google.protobuf.DoubleValue" => {
425            json!({ "type": "number" })
426        }
427        "google.protobuf.FieldMask" => {
428            json!({ "type": "string", "description": "Comma-separated field paths" })
429        }
430        "google.protobuf.Any" => {
431            json!({ "type": "object", "properties": { "@type": { "type": "string" } }, "additionalProperties": true })
432        }
433        _ => json!({ "type": "object" }),
434    }
435}
436
/// Extract `{param}` names from a path like `/v1/profiles/{profile_id}/devices`.
///
/// Returns the text between each matched `{` … `}` pair, in order.
/// * Empty braces are skipped.
/// * An unterminated `{` at the end is ignored.
/// * A `{` that opens before the previous one closed discards the unfinished name.
/// * A `}` with no matching `{` is ignored, so it no longer re-emits the previous
///   parameter.
/// * Templates with a pattern (`{name=projects/*}`) or a field path (`{book.id}`)
///   are returned verbatim, including the `=…` part.
fn extract_path_params(path: &str) -> Vec<String> {
    let mut params = Vec::new();
    // `Some(name)` while inside braces, `None` otherwise.
    let mut current: Option<String> = None;

    for ch in path.chars() {
        match ch {
            '{' => current = Some(String::new()),
            '}' => {
                if let Some(name) = current.take() {
                    if !name.is_empty() {
                        params.push(name);
                    }
                }
            }
            _ => {
                if let Some(name) = current.as_mut() {
                    name.push(ch);
                }
            }
        }
    }

    params
}
462
463/// Extract HTTP method and path from google.api.http annotation.
464fn extract_http_rule(method: &MethodDescriptor, pool: &DescriptorPool) -> Option<(String, String)> {
465    let http_ext = pool.get_extension_by_name("google.api.http")?;
466    let options = method.options();
467
468    if !options.has_extension(&http_ext) {
469        return None;
470    }
471
472    let http_rule = options.get_extension(&http_ext);
473    if let prost_reflect::Value::Message(rule_msg) = http_rule.into_owned() {
474        for (method_name, _) in [
475            ("get", "get"),
476            ("post", "post"),
477            ("put", "put"),
478            ("delete", "delete"),
479            ("patch", "patch"),
480        ] {
481            if let Some(val) = rule_msg.get_field_by_name(method_name) {
482                if let prost_reflect::Value::String(path) = val.into_owned() {
483                    if !path.is_empty() {
484                        return Some((method_name.to_string(), path));
485                    }
486                }
487            }
488        }
489    }
490
491    None
492}
493
494/// Get proto source comments for a given fully-qualified name.
495fn get_comments(_full_name: &str, _pool: &DescriptorPool) -> Option<String> {
496    // prost-reflect doesn't expose source code info comments easily.
497    // For now, return None. Can be enhanced with protoc-gen-doc or
498    // manual SourceCodeInfo parsing.
499    None
500}
501
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_path_params() {
        assert_eq!(
            extract_path_params("/v1/profiles/{profile_id}"),
            vec!["profile_id"]
        );
        assert_eq!(
            extract_path_params("/v1/profiles/{profile_id}/devices/{device_id}"),
            vec!["profile_id", "device_id"]
        );
        assert!(extract_path_params("/v1/auth/login").is_empty());
    }

    #[test]
    fn test_docs_html_contains_scalar() {
        let html = docs_html("/openapi.json", "Test API");
        assert!(html.contains("@scalar/api-reference"));
        assert!(html.contains("/openapi.json"));
        assert!(html.contains("Test API"));
    }

    #[test]
    fn test_well_known_schemas() {
        // Verify well-known type mappings are correct. The lookup is guarded
        // because it relies on the global pool bundling the well-known types.
        let pool = DescriptorPool::global();
        if let Some(ts) = pool.get_message_by_name("google.protobuf.Timestamp") {
            assert!(is_well_known(&ts));
            let schema = well_known_schema(&ts);
            assert_eq!(schema["type"], "string");
            assert_eq!(schema["format"], "date-time");
        }
    }

    #[test]
    fn test_generate_empty_pool() {
        let pool = DescriptorPool::new();
        let config = OpenApiConfig {
            enabled: true,
            path: "/openapi.json".into(),
            docs_path: "/docs".into(),
            title: Some("Test API".into()),
            version: Some("0.1.0".into()),
        };
        let spec = generate(&pool, &config, &[]);

        assert_eq!(spec["openapi"], "3.0.3");
        assert_eq!(spec["info"]["title"], "Test API");
        assert_eq!(spec["info"]["version"], "0.1.0");
        assert!(spec["paths"].as_object().unwrap().is_empty());
        // No messages → no schemas, but the security schemes are always declared.
        assert!(spec["components"].get("schemas").is_none());
        assert_eq!(
            spec["components"]["securitySchemes"]["bearerAuth"]["scheme"],
            "bearer"
        );
    }

    #[test]
    fn test_field_to_schema_primitives() {
        // Exercise the real mapping through fields of the well-known types. Each
        // lookup is guarded, like test_well_known_schemas, in case the global pool
        // lacks them.
        let pool = DescriptorPool::global();

        if let Some(duration) = pool.get_message_by_name("google.protobuf.Duration") {
            // `seconds` is int64 → string-encoded.
            let seconds = duration.get_field_by_name("seconds").unwrap();
            let schema = field_to_schema(&seconds);
            assert_eq!(schema["type"], "string");
            assert_eq!(schema["format"], "int64");

            // `nanos` is int32 → integer.
            let nanos = duration.get_field_by_name("nanos").unwrap();
            let schema = field_to_schema(&nanos);
            assert_eq!(schema["type"], "integer");
            assert_eq!(schema["format"], "int32");
        }

        if let Some(mask) = pool.get_message_by_name("google.protobuf.FieldMask") {
            // `paths` is repeated string → array of strings.
            let paths = mask.get_field_by_name("paths").unwrap();
            let schema = field_to_schema(&paths);
            assert_eq!(schema["type"], "array");
            assert_eq!(schema["items"]["type"], "string");
        }

        if let Some(st) = pool.get_message_by_name("google.protobuf.Struct") {
            // `fields` is map<string, Value> → object with additionalProperties.
            let fields = st.get_field_by_name("fields").unwrap();
            let schema = field_to_schema(&fields);
            assert_eq!(schema["type"], "object");
            assert!(schema["additionalProperties"].is_object());
        }
    }
}