Skip to main content

dbrest_core/schema_cache/routine.rs

1//! Routine (function/procedure) types for schema cache
2//!
3//! This module defines types for representing PostgreSQL functions and procedures
4//! in the schema cache.
5
6use compact_str::CompactString;
7use smallvec::SmallVec;
8
9use crate::types::QualifiedIdentifier;
10
11/// PostgreSQL function or procedure
12///
13/// Represents a callable routine with its parameters and return type.
14#[derive(Debug, Clone)]
15pub struct Routine {
16    /// Schema name
17    pub schema: CompactString,
18    /// Function/procedure name
19    pub name: CompactString,
20    /// Description from pg_description
21    pub description: Option<String>,
22    /// Function parameters
23    pub params: SmallVec<[RoutineParam; 4]>,
24    /// Return type
25    pub return_type: ReturnType,
26    /// Volatility (immutable, stable, volatile)
27    pub volatility: Volatility,
28    /// Whether the function has a variadic parameter
29    pub is_variadic: bool,
30    /// Whether EXECUTE is allowed (for current role)
31    pub executable: bool,
32}
33
34impl Routine {
35    /// Get the qualified identifier for this routine
36    pub fn qi(&self) -> QualifiedIdentifier {
37        QualifiedIdentifier::new(self.schema.clone(), self.name.clone())
38    }
39
40    /// Check if function returns a scalar value
41    pub fn returns_scalar(&self) -> bool {
42        matches!(self.return_type, ReturnType::Single(PgType::Scalar(_)))
43    }
44
45    /// Check if function returns a set of scalar values
46    pub fn returns_set_of_scalar(&self) -> bool {
47        matches!(self.return_type, ReturnType::SetOf(PgType::Scalar(_)))
48    }
49
50    /// Check if function returns a single row (not a set)
51    pub fn returns_single(&self) -> bool {
52        matches!(self.return_type, ReturnType::Single(_))
53    }
54
55    /// Check if function returns a set of rows
56    pub fn returns_set(&self) -> bool {
57        matches!(self.return_type, ReturnType::SetOf(_))
58    }
59
60    /// Check if function returns a composite type (table row)
61    pub fn returns_composite(&self) -> bool {
62        matches!(
63            &self.return_type,
64            ReturnType::Single(PgType::Composite(_, _))
65                | ReturnType::SetOf(PgType::Composite(_, _))
66        )
67    }
68
69    /// Get the table name if function returns a composite type
70    pub fn table_name(&self) -> Option<&str> {
71        match &self.return_type {
72            ReturnType::Single(PgType::Composite(qi, _)) => Some(&qi.name),
73            ReturnType::SetOf(PgType::Composite(qi, _)) => Some(&qi.name),
74            _ => None,
75        }
76    }
77
78    /// Get the table QI if function returns a composite type
79    pub fn table_qi(&self) -> Option<&QualifiedIdentifier> {
80        match &self.return_type {
81            ReturnType::Single(PgType::Composite(qi, _)) => Some(qi),
82            ReturnType::SetOf(PgType::Composite(qi, _)) => Some(qi),
83            _ => None,
84        }
85    }
86
87    /// Check if the return type is an alias (domain type)
88    pub fn is_return_type_alias(&self) -> bool {
89        match &self.return_type {
90            ReturnType::Single(PgType::Composite(_, is_alias)) => *is_alias,
91            ReturnType::SetOf(PgType::Composite(_, is_alias)) => *is_alias,
92            _ => false,
93        }
94    }
95
96    /// Get required parameters (non-variadic, no default)
97    pub fn required_params(&self) -> impl Iterator<Item = &RoutineParam> {
98        self.params.iter().filter(|p| p.required && !p.is_variadic)
99    }
100
101    /// Get optional parameters (has default)
102    pub fn optional_params(&self) -> impl Iterator<Item = &RoutineParam> {
103        self.params.iter().filter(|p| !p.required && !p.is_variadic)
104    }
105
106    /// Get the variadic parameter if present
107    pub fn variadic_param(&self) -> Option<&RoutineParam> {
108        self.params.iter().find(|p| p.is_variadic)
109    }
110
111    /// Get parameter by name
112    pub fn get_param(&self, name: &str) -> Option<&RoutineParam> {
113        self.params.iter().find(|p| p.name.as_str() == name)
114    }
115
116    /// Count of all parameters
117    pub fn param_count(&self) -> usize {
118        self.params.len()
119    }
120
121    /// Count of required parameters
122    pub fn required_param_count(&self) -> usize {
123        self.params
124            .iter()
125            .filter(|p| p.required && !p.is_variadic)
126            .count()
127    }
128
129    /// Check if this is a volatile function
130    pub fn is_volatile(&self) -> bool {
131        matches!(self.volatility, Volatility::Volatile)
132    }
133
134    /// Check if this is a stable function
135    pub fn is_stable(&self) -> bool {
136        matches!(self.volatility, Volatility::Stable)
137    }
138
139    /// Check if this is an immutable function
140    pub fn is_immutable(&self) -> bool {
141        matches!(self.volatility, Volatility::Immutable)
142    }
143}
144
145/// Function parameter
146#[derive(Debug, Clone)]
147pub struct RoutineParam {
148    /// Parameter name
149    pub name: CompactString,
150    /// PostgreSQL type name
151    pub pg_type: CompactString,
152    /// Type with max length info (e.g., "character varying(255)")
153    pub type_max_length: CompactString,
154    /// Whether this parameter is required (no default value)
155    pub required: bool,
156    /// Whether this is a variadic parameter
157    pub is_variadic: bool,
158}
159
160impl RoutineParam {
161    /// Check if this is a text-like parameter
162    pub fn is_text_type(&self) -> bool {
163        matches!(
164            self.pg_type.as_str(),
165            "text" | "character varying" | "character" | "varchar" | "char" | "name"
166        )
167    }
168
169    /// Check if this is a numeric parameter
170    pub fn is_numeric_type(&self) -> bool {
171        matches!(
172            self.pg_type.as_str(),
173            "integer"
174                | "bigint"
175                | "smallint"
176                | "numeric"
177                | "decimal"
178                | "real"
179                | "double precision"
180                | "int"
181                | "int4"
182                | "int8"
183                | "int2"
184        )
185    }
186
187    /// Check if this is a JSON parameter
188    pub fn is_json_type(&self) -> bool {
189        matches!(self.pg_type.as_str(), "json" | "jsonb")
190    }
191}
192
193/// Function return type
194#[derive(Debug, Clone)]
195pub enum ReturnType {
196    /// Returns a single value/row
197    Single(PgType),
198    /// Returns a set of values/rows (SETOF)
199    SetOf(PgType),
200}
201
202impl ReturnType {
203    /// Get the underlying type
204    pub fn inner_type(&self) -> &PgType {
205        match self {
206            ReturnType::Single(t) => t,
207            ReturnType::SetOf(t) => t,
208        }
209    }
210
211    /// Check if this is a set-returning type
212    pub fn is_set(&self) -> bool {
213        matches!(self, ReturnType::SetOf(_))
214    }
215}
216
217/// PostgreSQL type classification
218#[derive(Debug, Clone)]
219pub enum PgType {
220    /// Scalar type (integer, text, etc.)
221    Scalar(QualifiedIdentifier),
222    /// Composite type (table row type)
223    ///
224    /// The bool indicates whether this is an alias (domain type)
225    Composite(QualifiedIdentifier, bool),
226}
227
228impl PgType {
229    /// Check if this is a scalar type
230    pub fn is_scalar(&self) -> bool {
231        matches!(self, PgType::Scalar(_))
232    }
233
234    /// Check if this is a composite type
235    pub fn is_composite(&self) -> bool {
236        matches!(self, PgType::Composite(_, _))
237    }
238
239    /// Get the type's qualified identifier
240    pub fn qi(&self) -> &QualifiedIdentifier {
241        match self {
242            PgType::Scalar(qi) => qi,
243            PgType::Composite(qi, _) => qi,
244        }
245    }
246}
247
/// Function volatility category (`IMMUTABLE` / `STABLE` / `VOLATILE`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum Volatility {
    /// Function always returns same result for same arguments
    Immutable,
    /// Function returns same result within a single query
    Stable,
    /// Function may return different results even within same query
    #[default]
    Volatile,
}

impl Volatility {
    /// Parses a volatility marker from a PostgreSQL string.
    ///
    /// Accepts the one-letter codes `i` / `s` / `v` and the keywords
    /// `immutable` / `stable` / `volatile`, ignoring ASCII letter case.
    /// Anything else, including input with surrounding whitespace, yields `None`.
    pub fn parse(s: &str) -> Option<Self> {
        // (code, keyword, variant). Comparing case-insensitively in place avoids
        // the per-call `String` allocation of `to_lowercase()`. Every accepted
        // spelling is pure ASCII, so the accepted inputs are unchanged.
        const SPELLINGS: [(&str, &str, Volatility); 3] = [
            ("i", "immutable", Volatility::Immutable),
            ("s", "stable", Volatility::Stable),
            ("v", "volatile", Volatility::Volatile),
        ];

        SPELLINGS
            .iter()
            .find(|(code, keyword, _)| {
                s.eq_ignore_ascii_case(code) || s.eq_ignore_ascii_case(keyword)
            })
            .map(|&(_, _, volatility)| volatility)
    }

    /// Returns the upper-case SQL keyword for this volatility.
    pub fn as_sql(&self) -> &'static str {
        match self {
            Volatility::Immutable => "IMMUTABLE",
            Volatility::Stable => "STABLE",
            Volatility::Volatile => "VOLATILE",
        }
    }
}
280
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::*;

    // ========================================================================
    // Routine Tests
    // ========================================================================

    #[test]
    fn test_routine_qi() {
        let routine = test_routine().schema("api").name("get_user").build();

        let qi = routine.qi();
        assert_eq!(qi.schema.as_str(), "api");
        assert_eq!(qi.name.as_str(), "get_user");
    }

    #[test]
    fn test_routine_returns_scalar() {
        let scalar_func = test_routine().returns_scalar("integer").build();
        assert!(scalar_func.returns_scalar());
        assert!(!scalar_func.returns_composite());

        let composite_func = test_routine().returns_composite("public", "users").build();
        assert!(!composite_func.returns_scalar());
        assert!(composite_func.returns_composite());
    }

    #[test]
    fn test_routine_returns_set() {
        let single_func = test_routine().returns_scalar("integer").build();
        assert!(single_func.returns_single());
        assert!(!single_func.returns_set());

        let set_func = test_routine().returns_setof_scalar("integer").build();
        assert!(!set_func.returns_single());
        assert!(set_func.returns_set());
    }

    #[test]
    fn test_routine_returns_set_of_scalar() {
        let func = test_routine().returns_setof_scalar("text").build();
        assert!(func.returns_set_of_scalar());

        let composite_func = test_routine()
            .returns_setof_composite("public", "users")
            .build();
        assert!(!composite_func.returns_set_of_scalar());
    }

    #[test]
    fn test_routine_table_name() {
        let scalar_func = test_routine().returns_scalar("integer").build();
        assert!(scalar_func.table_name().is_none());

        let composite_func = test_routine().returns_composite("api", "users").build();
        assert_eq!(composite_func.table_name(), Some("users"));
    }

    #[test]
    fn test_routine_table_qi_and_alias() {
        let scalar_func = test_routine().returns_scalar("integer").build();
        assert!(scalar_func.table_qi().is_none());
        assert!(!scalar_func.is_return_type_alias());

        // Set the return type directly so the alias flag is under test control.
        let mut alias_func = test_routine().build();
        alias_func.return_type = ReturnType::SetOf(PgType::Composite(
            QualifiedIdentifier::new("api", "users"),
            true,
        ));
        let qi = alias_func.table_qi().unwrap();
        assert_eq!(qi.schema.as_str(), "api");
        assert_eq!(qi.name.as_str(), "users");
        assert_eq!(alias_func.table_name(), Some("users"));
        assert!(alias_func.is_return_type_alias());

        let mut plain_func = test_routine().build();
        plain_func.return_type = ReturnType::Single(PgType::Composite(
            QualifiedIdentifier::new("api", "users"),
            false,
        ));
        assert!(plain_func.table_qi().is_some());
        assert!(!plain_func.is_return_type_alias());
    }

    #[test]
    fn test_routine_required_params() {
        let p1 = test_param().name("id").required(true).build();
        let p2 = test_param().name("name").required(false).build();
        let p3 = test_param().name("extra").required(true).build();

        let routine = test_routine().params([p1, p2, p3]).build();

        let required: Vec<_> = routine.required_params().map(|p| p.name.as_str()).collect();
        assert_eq!(required, vec!["id", "extra"]);
    }

    #[test]
    fn test_routine_optional_params() {
        let p1 = test_param().name("id").required(true).build();
        let p2 = test_param().name("limit").required(false).build();

        let routine = test_routine().params([p1, p2]).build();

        let optional: Vec<_> = routine.optional_params().map(|p| p.name.as_str()).collect();
        assert_eq!(optional, vec!["limit"]);
    }

    #[test]
    fn test_routine_params_exclude_variadic() {
        // A required variadic parameter is not reported as required.
        let id = test_param().name("id").required(true).build();
        let rest = test_param()
            .name("rest")
            .required(true)
            .is_variadic(true)
            .build();
        let routine = test_routine().params([id, rest]).build();

        let required: Vec<_> = routine.required_params().map(|p| p.name.as_str()).collect();
        assert_eq!(required, vec!["id"]);
        assert_eq!(routine.required_param_count(), 1);
        assert_eq!(routine.param_count(), 2);

        // A non-required variadic parameter is not reported as optional.
        let limit = test_param().name("limit").required(false).build();
        let rest = test_param()
            .name("rest")
            .required(false)
            .is_variadic(true)
            .build();
        let routine = test_routine().params([limit, rest]).build();

        let optional: Vec<_> = routine.optional_params().map(|p| p.name.as_str()).collect();
        assert_eq!(optional, vec!["limit"]);
    }

    #[test]
    fn test_routine_variadic_param() {
        let p1 = test_param().name("id").build();
        let p2 = test_param().name("args").is_variadic(true).build();

        let routine = test_routine().params([p1, p2]).build();

        let variadic = routine.variadic_param().unwrap();
        assert_eq!(variadic.name.as_str(), "args");
    }

    #[test]
    fn test_routine_without_variadic_param() {
        let routine = test_routine().param(test_param().name("id").build()).build();
        assert!(routine.variadic_param().is_none());
    }

    #[test]
    fn test_routine_get_param() {
        let p1 = test_param().name("user_id").build();

        let routine = test_routine().param(p1).build();

        assert!(routine.get_param("user_id").is_some());
        assert!(routine.get_param("nonexistent").is_none());
    }

    #[test]
    fn test_routine_param_counts() {
        let p1 = test_param().name("a").required(true).build();
        let p2 = test_param().name("b").required(true).build();
        let p3 = test_param().name("c").required(false).build();

        let routine = test_routine().params([p1, p2, p3]).build();

        assert_eq!(routine.param_count(), 3);
        assert_eq!(routine.required_param_count(), 2);
    }

    #[test]
    fn test_routine_volatility() {
        let volatile_func = test_routine().volatility(Volatility::Volatile).build();
        assert!(volatile_func.is_volatile());
        assert!(!volatile_func.is_stable());
        assert!(!volatile_func.is_immutable());

        let stable_func = test_routine().volatility(Volatility::Stable).build();
        assert!(!stable_func.is_volatile());
        assert!(stable_func.is_stable());

        let immutable_func = test_routine().volatility(Volatility::Immutable).build();
        assert!(immutable_func.is_immutable());
    }

    // ========================================================================
    // RoutineParam Tests
    // ========================================================================

    #[test]
    fn test_routine_param_is_text_type() {
        assert!(test_param().pg_type("text").build().is_text_type());
        assert!(
            test_param()
                .pg_type("character varying")
                .build()
                .is_text_type()
        );
        assert!(!test_param().pg_type("integer").build().is_text_type());
    }

    #[test]
    fn test_routine_param_is_numeric_type() {
        assert!(test_param().pg_type("integer").build().is_numeric_type());
        assert!(test_param().pg_type("bigint").build().is_numeric_type());
        assert!(!test_param().pg_type("text").build().is_numeric_type());
    }

    #[test]
    fn test_routine_param_type_aliases() {
        assert!(test_param().pg_type("varchar").build().is_text_type());
        assert!(test_param().pg_type("name").build().is_text_type());
        assert!(test_param().pg_type("int4").build().is_numeric_type());
        assert!(
            test_param()
                .pg_type("double precision")
                .build()
                .is_numeric_type()
        );
        assert!(!test_param().pg_type("json").build().is_text_type());
    }

    #[test]
    fn test_routine_param_is_json_type() {
        assert!(test_param().pg_type("json").build().is_json_type());
        assert!(test_param().pg_type("jsonb").build().is_json_type());
        assert!(!test_param().pg_type("text").build().is_json_type());
    }

    // ========================================================================
    // ReturnType Tests
    // ========================================================================

    #[test]
    fn test_return_type_inner_type() {
        let single = ReturnType::Single(PgType::Scalar(QualifiedIdentifier::new(
            "pg_catalog",
            "int4",
        )));
        assert!(single.inner_type().is_scalar());

        let setof = ReturnType::SetOf(PgType::Composite(
            QualifiedIdentifier::new("public", "users"),
            false,
        ));
        assert!(setof.inner_type().is_composite());
    }

    #[test]
    fn test_return_type_is_set() {
        let single = ReturnType::Single(PgType::Scalar(QualifiedIdentifier::new(
            "pg_catalog",
            "int4",
        )));
        assert!(!single.is_set());

        let setof = ReturnType::SetOf(PgType::Scalar(QualifiedIdentifier::new(
            "pg_catalog",
            "int4",
        )));
        assert!(setof.is_set());
    }

    // ========================================================================
    // PgType Tests
    // ========================================================================

    #[test]
    fn test_pg_type_is_scalar_composite() {
        let scalar = PgType::Scalar(QualifiedIdentifier::new("pg_catalog", "int4"));
        assert!(scalar.is_scalar());
        assert!(!scalar.is_composite());

        let composite = PgType::Composite(QualifiedIdentifier::new("public", "users"), false);
        assert!(!composite.is_scalar());
        assert!(composite.is_composite());
    }

    #[test]
    fn test_pg_type_qi() {
        let scalar = PgType::Scalar(QualifiedIdentifier::new("pg_catalog", "text"));
        assert_eq!(scalar.qi().name.as_str(), "text");

        let composite = PgType::Composite(QualifiedIdentifier::new("api", "users"), false);
        assert_eq!(composite.qi().schema.as_str(), "api");
        assert_eq!(composite.qi().name.as_str(), "users");
    }

    // ========================================================================
    // Volatility Tests
    // ========================================================================

    #[test]
    fn test_volatility_parse() {
        assert_eq!(Volatility::parse("i"), Some(Volatility::Immutable));
        assert_eq!(Volatility::parse("immutable"), Some(Volatility::Immutable));
        assert_eq!(Volatility::parse("s"), Some(Volatility::Stable));
        assert_eq!(Volatility::parse("stable"), Some(Volatility::Stable));
        assert_eq!(Volatility::parse("v"), Some(Volatility::Volatile));
        assert_eq!(Volatility::parse("volatile"), Some(Volatility::Volatile));
        assert_eq!(Volatility::parse("invalid"), None);
    }

    #[test]
    fn test_volatility_parse_is_case_insensitive() {
        assert_eq!(Volatility::parse("I"), Some(Volatility::Immutable));
        assert_eq!(Volatility::parse("Stable"), Some(Volatility::Stable));
        assert_eq!(Volatility::parse("VOLATILE"), Some(Volatility::Volatile));
        assert_eq!(Volatility::parse(""), None);
    }

    #[test]
    fn test_volatility_as_sql() {
        assert_eq!(Volatility::Immutable.as_sql(), "IMMUTABLE");
        assert_eq!(Volatility::Stable.as_sql(), "STABLE");
        assert_eq!(Volatility::Volatile.as_sql(), "VOLATILE");
    }

    #[test]
    fn test_volatility_default_and_round_trip() {
        assert_eq!(Volatility::default(), Volatility::Volatile);
        for v in [
            Volatility::Immutable,
            Volatility::Stable,
            Volatility::Volatile,
        ] {
            assert_eq!(Volatility::parse(v.as_sql()), Some(v));
        }
    }
}