dbrest_core/schema_cache/
routine.rs1use compact_str::CompactString;
7use smallvec::SmallVec;
8
9use crate::types::QualifiedIdentifier;
10
11#[derive(Debug, Clone)]
15pub struct Routine {
16 pub schema: CompactString,
18 pub name: CompactString,
20 pub description: Option<String>,
22 pub params: SmallVec<[RoutineParam; 4]>,
24 pub return_type: ReturnType,
26 pub volatility: Volatility,
28 pub is_variadic: bool,
30 pub executable: bool,
32}
33
34impl Routine {
35 pub fn qi(&self) -> QualifiedIdentifier {
37 QualifiedIdentifier::new(self.schema.clone(), self.name.clone())
38 }
39
40 pub fn returns_scalar(&self) -> bool {
42 matches!(self.return_type, ReturnType::Single(PgType::Scalar(_)))
43 }
44
45 pub fn returns_set_of_scalar(&self) -> bool {
47 matches!(self.return_type, ReturnType::SetOf(PgType::Scalar(_)))
48 }
49
50 pub fn returns_single(&self) -> bool {
52 matches!(self.return_type, ReturnType::Single(_))
53 }
54
55 pub fn returns_set(&self) -> bool {
57 matches!(self.return_type, ReturnType::SetOf(_))
58 }
59
60 pub fn returns_composite(&self) -> bool {
62 matches!(
63 &self.return_type,
64 ReturnType::Single(PgType::Composite(_, _))
65 | ReturnType::SetOf(PgType::Composite(_, _))
66 )
67 }
68
69 pub fn table_name(&self) -> Option<&str> {
71 match &self.return_type {
72 ReturnType::Single(PgType::Composite(qi, _)) => Some(&qi.name),
73 ReturnType::SetOf(PgType::Composite(qi, _)) => Some(&qi.name),
74 _ => None,
75 }
76 }
77
78 pub fn table_qi(&self) -> Option<&QualifiedIdentifier> {
80 match &self.return_type {
81 ReturnType::Single(PgType::Composite(qi, _)) => Some(qi),
82 ReturnType::SetOf(PgType::Composite(qi, _)) => Some(qi),
83 _ => None,
84 }
85 }
86
87 pub fn is_return_type_alias(&self) -> bool {
89 match &self.return_type {
90 ReturnType::Single(PgType::Composite(_, is_alias)) => *is_alias,
91 ReturnType::SetOf(PgType::Composite(_, is_alias)) => *is_alias,
92 _ => false,
93 }
94 }
95
96 pub fn required_params(&self) -> impl Iterator<Item = &RoutineParam> {
98 self.params.iter().filter(|p| p.required && !p.is_variadic)
99 }
100
101 pub fn optional_params(&self) -> impl Iterator<Item = &RoutineParam> {
103 self.params.iter().filter(|p| !p.required && !p.is_variadic)
104 }
105
106 pub fn variadic_param(&self) -> Option<&RoutineParam> {
108 self.params.iter().find(|p| p.is_variadic)
109 }
110
111 pub fn get_param(&self, name: &str) -> Option<&RoutineParam> {
113 self.params.iter().find(|p| p.name.as_str() == name)
114 }
115
116 pub fn param_count(&self) -> usize {
118 self.params.len()
119 }
120
121 pub fn required_param_count(&self) -> usize {
123 self.params
124 .iter()
125 .filter(|p| p.required && !p.is_variadic)
126 .count()
127 }
128
129 pub fn is_volatile(&self) -> bool {
131 matches!(self.volatility, Volatility::Volatile)
132 }
133
134 pub fn is_stable(&self) -> bool {
136 matches!(self.volatility, Volatility::Stable)
137 }
138
139 pub fn is_immutable(&self) -> bool {
141 matches!(self.volatility, Volatility::Immutable)
142 }
143}
144
145#[derive(Debug, Clone)]
147pub struct RoutineParam {
148 pub name: CompactString,
150 pub pg_type: CompactString,
152 pub type_max_length: CompactString,
154 pub required: bool,
156 pub is_variadic: bool,
158}
159
160impl RoutineParam {
161 pub fn is_text_type(&self) -> bool {
163 matches!(
164 self.pg_type.as_str(),
165 "text" | "character varying" | "character" | "varchar" | "char" | "name"
166 )
167 }
168
169 pub fn is_numeric_type(&self) -> bool {
171 matches!(
172 self.pg_type.as_str(),
173 "integer"
174 | "bigint"
175 | "smallint"
176 | "numeric"
177 | "decimal"
178 | "real"
179 | "double precision"
180 | "int"
181 | "int4"
182 | "int8"
183 | "int2"
184 )
185 }
186
187 pub fn is_json_type(&self) -> bool {
189 matches!(self.pg_type.as_str(), "json" | "jsonb")
190 }
191}
192
193#[derive(Debug, Clone)]
195pub enum ReturnType {
196 Single(PgType),
198 SetOf(PgType),
200}
201
202impl ReturnType {
203 pub fn inner_type(&self) -> &PgType {
205 match self {
206 ReturnType::Single(t) => t,
207 ReturnType::SetOf(t) => t,
208 }
209 }
210
211 pub fn is_set(&self) -> bool {
213 matches!(self, ReturnType::SetOf(_))
214 }
215}
216
217#[derive(Debug, Clone)]
219pub enum PgType {
220 Scalar(QualifiedIdentifier),
222 Composite(QualifiedIdentifier, bool),
226}
227
228impl PgType {
229 pub fn is_scalar(&self) -> bool {
231 matches!(self, PgType::Scalar(_))
232 }
233
234 pub fn is_composite(&self) -> bool {
236 matches!(self, PgType::Composite(_, _))
237 }
238
239 pub fn qi(&self) -> &QualifiedIdentifier {
241 match self {
242 PgType::Scalar(qi) => qi,
243 PgType::Composite(qi, _) => qi,
244 }
245 }
246}
247
/// Volatility category of a routine: `IMMUTABLE`, `STABLE` or `VOLATILE`.
///
/// [`Default`] yields [`Volatility::Volatile`]. This is also what PostgreSQL
/// assumes when `CREATE FUNCTION` omits a volatility clause.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum Volatility {
    /// Rendered as `IMMUTABLE`; parsed from `i` or `immutable`.
    Immutable,
    /// Rendered as `STABLE`; parsed from `s` or `stable`.
    Stable,
    /// Rendered as `VOLATILE`; parsed from `v` or `volatile`. The default.
    #[default]
    Volatile,
}

impl Volatility {
    /// Parses a volatility from its one-letter code or its full keyword.
    ///
    /// The one-letter codes are `i`, `s` and `v`, the letters PostgreSQL
    /// stores in `pg_proc.provolatile`. Matching ignores ASCII case.
    ///
    /// Returns `None` for any other input, including the empty string and
    /// input with surrounding whitespace.
    pub fn parse(s: &str) -> Option<Self> {
        // Every accepted spelling is plain ASCII, so an in-place ASCII
        // case-insensitive comparison suffices; no lowercased copy of `s`
        // has to be allocated on each call.
        let spelled = |code: &str, keyword: &str| {
            s.eq_ignore_ascii_case(code) || s.eq_ignore_ascii_case(keyword)
        };

        if spelled("i", "immutable") {
            Some(Self::Immutable)
        } else if spelled("s", "stable") {
            Some(Self::Stable)
        } else if spelled("v", "volatile") {
            Some(Self::Volatile)
        } else {
            None
        }
    }

    /// Returns the upper-case SQL keyword for this volatility, e.g. `"STABLE"`.
    pub fn as_sql(&self) -> &'static str {
        match self {
            Volatility::Immutable => "IMMUTABLE",
            Volatility::Stable => "STABLE",
            Volatility::Volatile => "VOLATILE",
        }
    }
}
280
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::*;

    /// Builds a qualified identifier from two string literals.
    fn ident(schema: &'static str, name: &'static str) -> QualifiedIdentifier {
        QualifiedIdentifier::new(schema, name)
    }

    /// Builds a parameter whose only customised field is its type name.
    fn typed(pg_type: &'static str) -> RoutineParam {
        test_param().pg_type(pg_type).build()
    }

    /// Collects the names of the given parameters, preserving order.
    fn names<'a>(params: impl Iterator<Item = &'a RoutineParam>) -> Vec<&'a str> {
        params.map(|p| p.name.as_str()).collect()
    }

    // --- Routine: identity and return shape -------------------------------

    #[test]
    fn test_routine_qi() {
        let routine = test_routine().schema("api").name("get_user").build();

        let ident = routine.qi();
        assert_eq!(
            (ident.schema.as_str(), ident.name.as_str()),
            ("api", "get_user")
        );
    }

    #[test]
    fn test_routine_returns_scalar() {
        let scalar = test_routine().returns_scalar("integer").build();
        assert!(scalar.returns_scalar() && !scalar.returns_composite());

        let composite = test_routine().returns_composite("public", "users").build();
        assert!(!composite.returns_scalar() && composite.returns_composite());
    }

    #[test]
    fn test_routine_returns_set() {
        let single = test_routine().returns_scalar("integer").build();
        assert_eq!((single.returns_single(), single.returns_set()), (true, false));

        let set = test_routine().returns_setof_scalar("integer").build();
        assert_eq!((set.returns_single(), set.returns_set()), (false, true));
    }

    #[test]
    fn test_routine_returns_set_of_scalar() {
        let texts = test_routine().returns_setof_scalar("text").build();
        assert!(texts.returns_set_of_scalar());

        let rows = test_routine()
            .returns_setof_composite("public", "users")
            .build();
        assert!(!rows.returns_set_of_scalar());
    }

    #[test]
    fn test_routine_table_name() {
        let scalar = test_routine().returns_scalar("integer").build();
        assert_eq!(scalar.table_name(), None);

        let composite = test_routine().returns_composite("api", "users").build();
        assert_eq!(composite.table_name(), Some("users"));
    }

    // --- Routine: parameters ----------------------------------------------

    #[test]
    fn test_routine_required_params() {
        let params = [
            test_param().name("id").required(true).build(),
            test_param().name("name").required(false).build(),
            test_param().name("extra").required(true).build(),
        ];
        let routine = test_routine().params(params).build();

        assert_eq!(names(routine.required_params()), vec!["id", "extra"]);
    }

    #[test]
    fn test_routine_optional_params() {
        let params = [
            test_param().name("id").required(true).build(),
            test_param().name("limit").required(false).build(),
        ];
        let routine = test_routine().params(params).build();

        assert_eq!(names(routine.optional_params()), vec!["limit"]);
    }

    #[test]
    fn test_routine_variadic_param() {
        let params = [
            test_param().name("id").build(),
            test_param().name("args").is_variadic(true).build(),
        ];
        let routine = test_routine().params(params).build();

        let variadic = routine
            .variadic_param()
            .expect("routine should expose its variadic parameter");
        assert_eq!(variadic.name.as_str(), "args");
    }

    #[test]
    fn test_routine_get_param() {
        let only = test_param().name("user_id").build();
        let routine = test_routine().param(only).build();

        assert!(routine.get_param("user_id").is_some());
        assert!(routine.get_param("nonexistent").is_none());
    }

    #[test]
    fn test_routine_param_counts() {
        let params = [
            test_param().name("a").required(true).build(),
            test_param().name("b").required(true).build(),
            test_param().name("c").required(false).build(),
        ];
        let routine = test_routine().params(params).build();

        assert_eq!((routine.param_count(), routine.required_param_count()), (3, 2));
    }

    // --- Routine: volatility ----------------------------------------------

    #[test]
    fn test_routine_volatility() {
        // (is_volatile, is_stable, is_immutable) for a routine built with `v`.
        let flags = |v: Volatility| {
            let routine = test_routine().volatility(v).build();
            (routine.is_volatile(), routine.is_stable(), routine.is_immutable())
        };

        assert_eq!(flags(Volatility::Volatile), (true, false, false));

        let (volatile, stable, _) = flags(Volatility::Stable);
        assert!(!volatile && stable);

        let (_, _, immutable) = flags(Volatility::Immutable);
        assert!(immutable);
    }

    // --- RoutineParam type classification ---------------------------------

    #[test]
    fn test_routine_param_is_text_type() {
        for ty in ["text", "character varying"] {
            assert!(typed(ty).is_text_type(), "{ty} should count as text");
        }
        assert!(!typed("integer").is_text_type());
    }

    #[test]
    fn test_routine_param_is_numeric_type() {
        for ty in ["integer", "bigint"] {
            assert!(typed(ty).is_numeric_type(), "{ty} should count as numeric");
        }
        assert!(!typed("text").is_numeric_type());
    }

    #[test]
    fn test_routine_param_is_json_type() {
        for ty in ["json", "jsonb"] {
            assert!(typed(ty).is_json_type(), "{ty} should count as json");
        }
        assert!(!typed("text").is_json_type());
    }

    // --- ReturnType ---------------------------------------------------------

    #[test]
    fn test_return_type_inner_type() {
        let single = ReturnType::Single(PgType::Scalar(ident("pg_catalog", "int4")));
        assert!(single.inner_type().is_scalar());

        let setof = ReturnType::SetOf(PgType::Composite(ident("public", "users"), false));
        assert!(setof.inner_type().is_composite());
    }

    #[test]
    fn test_return_type_is_set() {
        let single = ReturnType::Single(PgType::Scalar(ident("pg_catalog", "int4")));
        assert!(!single.is_set());

        let setof = ReturnType::SetOf(PgType::Scalar(ident("pg_catalog", "int4")));
        assert!(setof.is_set());
    }

    // --- PgType -------------------------------------------------------------

    #[test]
    fn test_pg_type_is_scalar_composite() {
        let scalar = PgType::Scalar(ident("pg_catalog", "int4"));
        assert_eq!((scalar.is_scalar(), scalar.is_composite()), (true, false));

        let composite = PgType::Composite(ident("public", "users"), false);
        assert_eq!((composite.is_scalar(), composite.is_composite()), (false, true));
    }

    #[test]
    fn test_pg_type_qi() {
        let scalar = PgType::Scalar(ident("pg_catalog", "text"));
        assert_eq!(scalar.qi().name.as_str(), "text");

        let composite = PgType::Composite(ident("api", "users"), false);
        let qi = composite.qi();
        assert_eq!(qi.schema.as_str(), "api");
        assert_eq!(qi.name.as_str(), "users");
    }

    // --- Volatility -----------------------------------------------------------

    #[test]
    fn test_volatility_parse() {
        let cases = [
            ("i", Some(Volatility::Immutable)),
            ("immutable", Some(Volatility::Immutable)),
            ("s", Some(Volatility::Stable)),
            ("stable", Some(Volatility::Stable)),
            ("v", Some(Volatility::Volatile)),
            ("volatile", Some(Volatility::Volatile)),
            ("invalid", None),
        ];
        for (input, expected) in cases {
            assert_eq!(Volatility::parse(input), expected, "parsing {input:?}");
        }
    }

    #[test]
    fn test_volatility_as_sql() {
        let cases = [
            (Volatility::Immutable, "IMMUTABLE"),
            (Volatility::Stable, "STABLE"),
            (Volatility::Volatile, "VOLATILE"),
        ];
        for (volatility, sql) in cases {
            assert_eq!(volatility.as_sql(), sql);
        }
    }
}