Skip to main content

dbrest_core/openapi/
generator.rs

1//! OpenAPI 3.0 specification generator
2//!
3//! Generates OpenAPI 3.0 specifications from the schema cache.
4
5use std::collections::HashMap;
6use std::sync::Arc;
7
8use crate::auth::AuthResult;
9use crate::config::{AppConfig, OpenApiMode};
10use crate::error::Error;
11use crate::schema_cache::{Column, Routine, SchemaCache, Table};
12
13use super::types::*;
14
15/// OpenAPI specification generator
16pub struct OpenApiGenerator {
17    config: Arc<AppConfig>,
18    cache: Arc<SchemaCache>,
19    #[allow(dead_code)] // reserved for future role-scoped OpenAPI
20    auth: Option<AuthResult>,
21}
22
23impl OpenApiGenerator {
24    /// Create a new OpenAPI generator
25    pub fn new(config: Arc<AppConfig>, cache: Arc<SchemaCache>, auth: Option<AuthResult>) -> Self {
26        Self {
27            config,
28            cache,
29            auth,
30        }
31    }
32
33    /// Generate the full OpenAPI 3.0 specification
34    pub fn generate(&self) -> Result<OpenApiSpec, Error> {
35        // Check if OpenAPI is disabled
36        if self.config.openapi_mode == OpenApiMode::Disabled {
37            return Err(Error::OpenApiDisabled);
38        }
39
40        let spec = OpenApiSpec {
41            openapi: "3.0.0".to_string(),
42            info: self.generate_info(),
43            servers: self.generate_servers(),
44            paths: self.generate_paths()?,
45            components: Some(self.generate_components()?),
46            security: self.generate_security(),
47        };
48
49        Ok(spec)
50    }
51
52    fn generate_info(&self) -> Info {
53        Info {
54            title: "PostgREST API".to_string(),
55            description: Some("REST API for PostgreSQL database".to_string()),
56            version: "1.0.0".to_string(),
57        }
58    }
59
60    fn generate_servers(&self) -> Vec<Server> {
61        if let Some(ref proxy_uri) = self.config.openapi_server_proxy_uri {
62            vec![Server {
63                url: proxy_uri.clone(),
64                description: Some("Proxy server".to_string()),
65            }]
66        } else {
67            vec![Server {
68                url: "/".to_string(),
69                description: None,
70            }]
71        }
72    }
73
74    fn generate_paths(&self) -> Result<Paths, Error> {
75        let mut paths = HashMap::new();
76
77        // Generate paths for each table
78        for schema in &self.config.db_schemas {
79            for table in self.cache.tables_in_schema(schema) {
80                // Check privileges if mode is FollowPrivileges
81                if self.config.openapi_mode == OpenApiMode::FollowPrivileges
82                    && !self.can_read_table(table)?
83                {
84                    continue;
85                }
86
87                let path = format!("/{}.{}", table.schema, table.name);
88                let path_item = self.generate_table_path_item(table)?;
89                paths.insert(path, path_item);
90            }
91        }
92
93        // Generate paths for RPC functions
94        for schema in &self.config.db_schemas {
95            if let Some(routines) = self.cache.get_routines_by_name(schema, "") {
96                // Get all routines in schema
97                for routine in routines {
98                    if self.config.openapi_mode == OpenApiMode::FollowPrivileges
99                        && !self.can_execute_routine(routine)?
100                    {
101                        continue;
102                    }
103
104                    let path = format!("/rpc/{}", routine.name);
105                    let path_item = self.generate_rpc_path_item(routine)?;
106                    paths.insert(path, path_item);
107                }
108            }
109        }
110
111        // Also iterate through all routines in the cache
112        for (qi, routines) in self.cache.routines.iter() {
113            if !self.config.db_schemas.contains(&qi.schema.to_string()) {
114                continue;
115            }
116
117            for routine in routines {
118                if self.config.openapi_mode == OpenApiMode::FollowPrivileges
119                    && !self.can_execute_routine(routine)?
120                {
121                    continue;
122                }
123
124                let path = format!("/rpc/{}", routine.name);
125                let path_item = self.generate_rpc_path_item(routine)?;
126                paths.insert(path, path_item);
127            }
128        }
129
130        Ok(Paths { paths })
131    }
132
133    fn generate_table_path_item(&self, table: &Table) -> Result<PathItem, Error> {
134        // GET operation
135        let get_op = self.generate_get_operation(table)?;
136
137        // POST operation (if insertable)
138        let post_op = if table.insertable {
139            Some(self.generate_post_operation(table)?)
140        } else {
141            None
142        };
143
144        // PATCH operation (if updatable)
145        let patch_op = if table.updatable {
146            Some(self.generate_patch_operation(table)?)
147        } else {
148            None
149        };
150
151        // PUT operation (if insertable and updatable)
152        let put_op = if table.insertable && table.updatable {
153            Some(self.generate_put_operation(table)?)
154        } else {
155            None
156        };
157
158        // DELETE operation (if deletable)
159        let delete_op = if table.deletable {
160            Some(self.generate_delete_operation(table)?)
161        } else {
162            None
163        };
164
165        // OPTIONS operation
166        let options_op = Some(self.generate_options_operation());
167
168        // HEAD operation (same as GET but no body)
169        let head_op = Some(self.generate_head_operation(table)?);
170
171        Ok(PathItem {
172            get: Some(get_op),
173            post: post_op,
174            patch: patch_op,
175            put: put_op,
176            delete: delete_op,
177            options: options_op,
178            head: head_op,
179        })
180    }
181
182    fn generate_get_operation(&self, table: &Table) -> Result<Operation, Error> {
183        let mut parameters = vec![
184            // select parameter
185            Parameter {
186                name: "select".to_string(),
187                location: ParameterLocation::Query,
188                description: Some("Columns to select (comma-separated)".to_string()),
189                required: Some(false),
190                schema: Some(Schema::string()),
191                style: None,
192                explode: None,
193            },
194            // order parameter
195            Parameter {
196                name: "order".to_string(),
197                location: ParameterLocation::Query,
198                description: Some("Order by column(s)".to_string()),
199                required: Some(false),
200                schema: Some(Schema::string()),
201                style: None,
202                explode: None,
203            },
204            // limit parameter
205            Parameter {
206                name: "limit".to_string(),
207                location: ParameterLocation::Query,
208                description: Some("Limit number of results".to_string()),
209                required: Some(false),
210                schema: Some(Schema::integer().with_format("int64".to_string())),
211                style: None,
212                explode: None,
213            },
214            // offset parameter
215            Parameter {
216                name: "offset".to_string(),
217                location: ParameterLocation::Query,
218                description: Some("Skip number of results".to_string()),
219                required: Some(false),
220                schema: Some(Schema::integer().with_format("int64".to_string())),
221                style: None,
222                explode: None,
223            },
224        ];
225
226        // Add filter parameters for each column
227        for col in table.columns_list() {
228            parameters.push(Parameter {
229                name: col.name.to_string(),
230                location: ParameterLocation::Query,
231                description: col.description.clone(),
232                required: Some(false),
233                schema: Some(self.column_to_schema(col)?),
234                style: None,
235                explode: None,
236            });
237        }
238
239        let responses = self.generate_read_responses(table)?;
240
241        Ok(Operation {
242            summary: Some(format!("List {} records", table.name)),
243            description: table.description.clone(),
244            operation_id: format!("get_{}_{}", table.schema, table.name),
245            tags: vec![table.schema.to_string()],
246            parameters,
247            request_body: None,
248            responses,
249            security: self.generate_operation_security(),
250        })
251    }
252
253    fn generate_post_operation(&self, table: &Table) -> Result<Operation, Error> {
254        let schema_name = format!("{}_{}", table.schema, table.name);
255        let schema_ref = format!("#/components/schemas/{}", schema_name);
256
257        let mut content = HashMap::new();
258        content.insert(
259            "application/json".to_string(),
260            MediaTypeObject {
261                schema: Some(Schema::ref_(&schema_ref)),
262                example: None,
263            },
264        );
265
266        let request_body = RequestBody {
267            description: Some(format!("{} record to insert", table.name)),
268            required: Some(true),
269            content,
270        };
271
272        let mut responses = HashMap::new();
273        responses.insert(
274            "201".to_string(),
275            Response {
276                description: "Created".to_string(),
277                content: Some({
278                    let mut c = HashMap::new();
279                    c.insert(
280                        "application/json".to_string(),
281                        MediaTypeObject {
282                            schema: Some(Schema::ref_(&schema_ref)),
283                            example: None,
284                        },
285                    );
286                    c
287                }),
288                headers: Some({
289                    let mut h = HashMap::new();
290                    h.insert(
291                        "Location".to_string(),
292                        Header {
293                            description: Some("URL of created resource".to_string()),
294                            required: Some(false),
295                            schema: Schema::string(),
296                        },
297                    );
298                    h
299                }),
300            },
301        );
302        responses.insert(
303            "400".to_string(),
304            Response {
305                description: "Bad Request".to_string(),
306                content: None,
307                headers: None,
308            },
309        );
310
311        Ok(Operation {
312            summary: Some(format!("Create {} record", table.name)),
313            description: table.description.clone(),
314            operation_id: format!("post_{}_{}", table.schema, table.name),
315            tags: vec![table.schema.to_string()],
316            parameters: vec![],
317            request_body: Some(request_body),
318            responses: Responses { responses },
319            security: self.generate_operation_security(),
320        })
321    }
322
323    fn generate_patch_operation(&self, table: &Table) -> Result<Operation, Error> {
324        let schema_name = format!("{}_{}", table.schema, table.name);
325        let schema_ref = format!("#/components/schemas/{}", schema_name);
326
327        let mut content = HashMap::new();
328        content.insert(
329            "application/json".to_string(),
330            MediaTypeObject {
331                schema: Some(Schema::ref_(&schema_ref)),
332                example: None,
333            },
334        );
335
336        let request_body = RequestBody {
337            description: Some(format!("{} record to update", table.name)),
338            required: Some(true),
339            content,
340        };
341
342        let mut responses = HashMap::new();
343        responses.insert(
344            "200".to_string(),
345            Response {
346                description: "OK".to_string(),
347                content: Some({
348                    let mut c = HashMap::new();
349                    c.insert(
350                        "application/json".to_string(),
351                        MediaTypeObject {
352                            schema: Some(Schema::array(Schema::ref_(&schema_ref))),
353                            example: None,
354                        },
355                    );
356                    c
357                }),
358                headers: None,
359            },
360        );
361
362        Ok(Operation {
363            summary: Some(format!("Update {} records", table.name)),
364            description: table.description.clone(),
365            operation_id: format!("patch_{}_{}", table.schema, table.name),
366            tags: vec![table.schema.to_string()],
367            parameters: vec![],
368            request_body: Some(request_body),
369            responses: Responses { responses },
370            security: self.generate_operation_security(),
371        })
372    }
373
374    fn generate_put_operation(&self, table: &Table) -> Result<Operation, Error> {
375        // PUT is similar to POST but for upsert
376        self.generate_post_operation(table)
377    }
378
379    fn generate_delete_operation(&self, table: &Table) -> Result<Operation, Error> {
380        let mut responses = HashMap::new();
381        responses.insert(
382            "204".to_string(),
383            Response {
384                description: "No Content".to_string(),
385                content: None,
386                headers: None,
387            },
388        );
389        responses.insert(
390            "200".to_string(),
391            Response {
392                description: "OK".to_string(),
393                content: None,
394                headers: None,
395            },
396        );
397
398        Ok(Operation {
399            summary: Some(format!("Delete {} records", table.name)),
400            description: table.description.clone(),
401            operation_id: format!("delete_{}_{}", table.schema, table.name),
402            tags: vec![table.schema.to_string()],
403            parameters: vec![],
404            request_body: None,
405            responses: Responses { responses },
406            security: self.generate_operation_security(),
407        })
408    }
409
410    fn generate_options_operation(&self) -> Operation {
411        let mut responses = HashMap::new();
412        responses.insert(
413            "200".to_string(),
414            Response {
415                description: "OK".to_string(),
416                content: None,
417                headers: None,
418            },
419        );
420
421        Operation {
422            summary: None,
423            description: None,
424            operation_id: "options".to_string(),
425            tags: vec![],
426            parameters: vec![],
427            request_body: None,
428            responses: Responses { responses },
429            security: None,
430        }
431    }
432
433    fn generate_head_operation(&self, table: &Table) -> Result<Operation, Error> {
434        // HEAD is same as GET but no body
435        let mut op = self.generate_get_operation(table)?;
436        op.operation_id = format!("head_{}_{}", table.schema, table.name);
437        // Remove content from responses
438        for response in op.responses.responses.values_mut() {
439            response.content = None;
440        }
441        Ok(op)
442    }
443
444    fn generate_read_responses(&self, table: &Table) -> Result<Responses, Error> {
445        let schema_name = format!("{}_{}", table.schema, table.name);
446        let schema_ref = format!("#/components/schemas/{}", schema_name);
447
448        let mut responses = HashMap::new();
449        responses.insert(
450            "200".to_string(),
451            Response {
452                description: "OK".to_string(),
453                content: Some({
454                    let mut c = HashMap::new();
455                    c.insert(
456                        "application/json".to_string(),
457                        MediaTypeObject {
458                            schema: Some(Schema::array(Schema::ref_(&schema_ref))),
459                            example: None,
460                        },
461                    );
462                    c
463                }),
464                headers: Some({
465                    let mut h = HashMap::new();
466                    h.insert(
467                        "Content-Range".to_string(),
468                        Header {
469                            description: Some("Range of results".to_string()),
470                            required: Some(false),
471                            schema: Schema::string(),
472                        },
473                    );
474                    h
475                }),
476            },
477        );
478        responses.insert(
479            "406".to_string(),
480            Response {
481                description: "Not Acceptable".to_string(),
482                content: None,
483                headers: None,
484            },
485        );
486
487        Ok(Responses { responses })
488    }
489
490    fn column_to_schema(&self, col: &Column) -> Result<Schema, Error> {
491        let mut schema = if col.is_array_type() {
492            // For arrays, get the base type
493            let base_type = col.data_type.trim_end_matches("[]");
494            let item_schema = self.pg_type_to_schema(base_type)?;
495            Schema::array(item_schema)
496        } else {
497            self.pg_type_to_schema(&col.data_type)?
498        };
499
500        // Add nullable if column allows null
501        if col.nullable {
502            schema = schema.nullable();
503        }
504
505        // Add description
506        if let Some(ref desc) = col.description {
507            schema = schema.with_description(desc.clone());
508        }
509
510        // Add enum values if it's an enum
511        if col.is_enum()
512            && let Schema::Object { enum_values, .. } = &mut schema
513        {
514            *enum_values = Some(
515                col.enum_values
516                    .iter()
517                    .map(|v| serde_json::Value::String(v.clone()))
518                    .collect(),
519            );
520        }
521
522        Ok(schema)
523    }
524
525    fn pg_type_to_schema(&self, pg_type: &str) -> Result<Schema, Error> {
526        let schema = match pg_type {
527            // Integer types
528            "integer" | "int" | "int4" => Schema::integer().with_format("int32".to_string()),
529            "bigint" | "int8" => Schema::integer().with_format("int64".to_string()),
530            "smallint" | "int2" => Schema::integer().with_format("int32".to_string()),
531            "serial" | "serial4" => Schema::integer().with_format("int32".to_string()),
532            "bigserial" | "serial8" => Schema::integer().with_format("int64".to_string()),
533
534            // Numeric types
535            "numeric" | "decimal" => Schema::number().with_format("double".to_string()),
536            "real" | "float4" => Schema::number().with_format("float".to_string()),
537            "double precision" | "float8" => Schema::number().with_format("double".to_string()),
538
539            // String types
540            "text" | "character varying" | "varchar" | "character" | "char" | "name" => {
541                Schema::string()
542            }
543
544            // Boolean
545            "boolean" | "bool" => Schema::boolean(),
546
547            // Date/time types
548            "date" => Schema::string().with_format("date".to_string()),
549            "time without time zone" | "time" => Schema::string().with_format("time".to_string()),
550            "time with time zone" | "timetz" => Schema::string().with_format("time".to_string()),
551            "timestamp without time zone" | "timestamp" => {
552                Schema::string().with_format("date-time".to_string())
553            }
554            "timestamp with time zone" | "timestamptz" => {
555                Schema::string().with_format("date-time".to_string())
556            }
557            "interval" => Schema::string(),
558
559            // UUID
560            "uuid" => Schema::string().with_format("uuid".to_string()),
561
562            // JSON types
563            "json" | "jsonb" => Schema::object(HashMap::new(), vec![]),
564
565            // Binary
566            "bytea" => Schema::string().with_format("byte".to_string()),
567
568            // Default: treat as string
569            _ => Schema::string(),
570        };
571
572        Ok(schema)
573    }
574
575    fn generate_components(&self) -> Result<Components, Error> {
576        let mut schemas = HashMap::new();
577
578        // Generate schemas for each table
579        for schema in &self.config.db_schemas {
580            for table in self.cache.tables_in_schema(schema) {
581                if self.config.openapi_mode == OpenApiMode::FollowPrivileges
582                    && !self.can_read_table(table)?
583                {
584                    continue;
585                }
586
587                let schema_name = format!("{}_{}", table.schema, table.name);
588                let table_schema = self.table_to_schema(table)?;
589                schemas.insert(schema_name, table_schema);
590            }
591        }
592
593        let security_schemes = if self.config.openapi_security_active {
594            Some(self.generate_security_schemes())
595        } else {
596            None
597        };
598
599        Ok(Components {
600            schemas,
601            security_schemes,
602        })
603    }
604
605    fn table_to_schema(&self, table: &Table) -> Result<Schema, Error> {
606        let mut properties = HashMap::new();
607        let mut required = Vec::new();
608
609        for col in table.columns_list() {
610            let col_schema = self.column_to_schema(col)?;
611            properties.insert(col.name.to_string(), col_schema);
612
613            if !col.nullable && !col.has_default() && !col.is_generated() {
614                required.push(col.name.to_string());
615            }
616        }
617
618        let mut schema = Schema::object(properties, required);
619
620        // Add description
621        if let Some(ref desc) = table.description
622            && let Schema::Object { description, .. } = &mut schema
623        {
624            *description = Some(desc.clone());
625        }
626
627        Ok(schema)
628    }
629
630    fn generate_security_schemes(&self) -> HashMap<String, SecurityScheme> {
631        let mut schemes = HashMap::new();
632        schemes.insert(
633            "bearer".to_string(),
634            SecurityScheme {
635                type_: "http".to_string(),
636                scheme: Some("bearer".to_string()),
637                bearer_format: Some("JWT".to_string()),
638                description: Some("JWT authentication".to_string()),
639            },
640        );
641        schemes
642    }
643
644    fn generate_security(&self) -> Option<Vec<SecurityRequirement>> {
645        if self.config.openapi_security_active {
646            let mut req = HashMap::new();
647            req.insert("bearer".to_string(), vec![]);
648            Some(vec![SecurityRequirement { requirements: req }])
649        } else {
650            None
651        }
652    }
653
654    fn generate_operation_security(&self) -> Option<Vec<SecurityRequirement>> {
655        self.generate_security()
656    }
657
658    fn generate_rpc_path_item(&self, routine: &Routine) -> Result<PathItem, Error> {
659        let get_op = self.generate_rpc_get_operation(routine)?;
660        let post_op = self.generate_rpc_post_operation(routine)?;
661
662        Ok(PathItem {
663            get: Some(get_op),
664            post: Some(post_op),
665            patch: None,
666            put: None,
667            delete: None,
668            options: Some(self.generate_options_operation()),
669            head: None,
670        })
671    }
672
673    fn generate_rpc_get_operation(&self, routine: &Routine) -> Result<Operation, Error> {
674        let mut parameters = vec![];
675
676        // Add parameters for each function parameter
677        for param in &routine.params {
678            parameters.push(Parameter {
679                name: param.name.to_string(),
680                location: ParameterLocation::Query,
681                description: None, // RoutineParam doesn't have description field
682                required: Some(param.required),
683                schema: Some(self.routine_param_to_schema(param)?),
684                style: None,
685                explode: None,
686            });
687        }
688
689        let responses = self.generate_rpc_responses(routine)?;
690
691        Ok(Operation {
692            summary: Some(format!("Call {} function", routine.name)),
693            description: routine.description.clone(),
694            operation_id: format!("rpc_get_{}", routine.name),
695            tags: vec![routine.schema.to_string()],
696            parameters,
697            request_body: None,
698            responses,
699            security: self.generate_operation_security(),
700        })
701    }
702
703    fn generate_rpc_post_operation(&self, routine: &Routine) -> Result<Operation, Error> {
704        // For POST, parameters go in the request body
705        let mut properties = HashMap::new();
706        let mut required = Vec::new();
707
708        for param in &routine.params {
709            let param_schema = self.routine_param_to_schema(param)?;
710            properties.insert(param.name.to_string(), param_schema);
711            if param.required {
712                required.push(param.name.to_string());
713            }
714        }
715
716        let request_body = if properties.is_empty() {
717            None
718        } else {
719            let mut content = HashMap::new();
720            content.insert(
721                "application/json".to_string(),
722                MediaTypeObject {
723                    schema: Some(Schema::object(properties, required)),
724                    example: None,
725                },
726            );
727
728            Some(RequestBody {
729                description: routine.description.clone(),
730                required: Some(true),
731                content,
732            })
733        };
734
735        let responses = self.generate_rpc_responses(routine)?;
736
737        Ok(Operation {
738            summary: Some(format!("Call {} function", routine.name)),
739            description: routine.description.clone(),
740            operation_id: format!("rpc_post_{}", routine.name),
741            tags: vec![routine.schema.to_string()],
742            parameters: vec![],
743            request_body,
744            responses,
745            security: self.generate_operation_security(),
746        })
747    }
748
749    fn generate_rpc_responses(&self, routine: &Routine) -> Result<Responses, Error> {
750        let mut responses = HashMap::new();
751
752        // Determine response schema based on return type
753        let response_schema = self.routine_return_type_to_schema(routine)?;
754
755        responses.insert(
756            "200".to_string(),
757            Response {
758                description: "OK".to_string(),
759                content: Some({
760                    let mut c = HashMap::new();
761                    c.insert(
762                        "application/json".to_string(),
763                        MediaTypeObject {
764                            schema: Some(response_schema),
765                            example: None,
766                        },
767                    );
768                    c
769                }),
770                headers: None,
771            },
772        );
773
774        Ok(Responses { responses })
775    }
776
777    fn routine_param_to_schema(
778        &self,
779        param: &crate::schema_cache::RoutineParam,
780    ) -> Result<Schema, Error> {
781        // Use type_max_length which includes length info, or fall back to pg_type
782        let type_str = if param.type_max_length != param.pg_type {
783            &param.type_max_length
784        } else {
785            &param.pg_type
786        };
787        self.pg_type_to_schema(type_str.as_str())
788    }
789
790    fn routine_return_type_to_schema(&self, routine: &Routine) -> Result<Schema, Error> {
791        use crate::schema_cache::{PgType, ReturnType};
792
793        match &routine.return_type {
794            ReturnType::Single(PgType::Scalar(_)) => {
795                // Scalar return type - return as object with value field
796                let mut props = HashMap::new();
797                props.insert("value".to_string(), Schema::string());
798                Ok(Schema::object(props, vec![]))
799            }
800            ReturnType::SetOf(PgType::Scalar(_)) => {
801                // Array of scalars
802                Ok(Schema::array(Schema::string()))
803            }
804            ReturnType::Single(PgType::Composite(qi, _)) => {
805                // Single composite (table row)
806                let schema_name = format!("{}_{}", qi.schema, qi.name);
807                Ok(Schema::ref_(&format!(
808                    "#/components/schemas/{}",
809                    schema_name
810                )))
811            }
812            ReturnType::SetOf(PgType::Composite(qi, _)) => {
813                // Array of composites
814                let schema_name = format!("{}_{}", qi.schema, qi.name);
815                Ok(Schema::array(Schema::ref_(&format!(
816                    "#/components/schemas/{}",
817                    schema_name
818                ))))
819            }
820        }
821    }
822
823    fn can_read_table(&self, table: &Table) -> Result<bool, Error> {
824        // Check if OpenAPI mode ignores privileges
825        if self.config.openapi_mode == OpenApiMode::IgnorePrivileges {
826            return Ok(true);
827        }
828
829        // Check actual PostgreSQL SELECT privilege
830        Ok(table.readable)
831    }
832
833    fn can_execute_routine(&self, routine: &Routine) -> Result<bool, Error> {
834        // Check if OpenAPI mode ignores privileges
835        if self.config.openapi_mode == OpenApiMode::IgnorePrivileges {
836            return Ok(true);
837        }
838
839        // Check actual PostgreSQL EXECUTE privilege
840        Ok(routine.executable)
841    }
842}