use compact_str::CompactString;
use smallvec::SmallVec;
use crate::types::QualifiedIdentifier;
/// Metadata describing a database routine: where it lives (`schema`, `name`),
/// what it accepts (`params`), what it returns (`return_type`) and its declared
/// `volatility`.
#[derive(Debug, Clone)]
pub struct Routine {
/// Schema the routine belongs to; combined with `name` by `Routine::qi`.
pub schema: CompactString,
/// Unqualified routine name.
pub name: CompactString,
/// Optional free-text description.
/// NOTE(review): presumably the routine's SQL comment — confirm against the loader.
pub description: Option<String>,
/// Parameters, in the order the helper iterators yield them (presumably
/// declaration order — confirm). Up to four are stored inline before the
/// `SmallVec` spills to the heap.
pub params: SmallVec<[RoutineParam; 4]>,
/// What the routine returns: a single value or a set, of a scalar or composite type.
pub return_type: ReturnType,
/// Declared volatility class.
pub volatility: Volatility,
/// Routine-level variadic flag.
/// NOTE(review): the parameter helpers (`variadic_param`, `required_params`,
/// `optional_params`) consult `RoutineParam::is_variadic` instead; confirm the
/// two are kept consistent by whoever constructs this value.
pub is_variadic: bool,
/// NOTE(review): not read by any method in this file; presumably whether the
/// routine may be executed (e.g. by the current role) — confirm.
pub executable: bool,
}
impl Routine {
    /// Builds the schema-qualified identifier (`schema`, `name`) of this routine.
    ///
    /// Both components are cloned; callers that only need to borrow them can
    /// read the `schema` / `name` fields directly.
    pub fn qi(&self) -> QualifiedIdentifier {
        QualifiedIdentifier::new(self.schema.clone(), self.name.clone())
    }

    /// `true` when the routine returns a single scalar value.
    pub fn returns_scalar(&self) -> bool {
        matches!(self.return_type, ReturnType::Single(PgType::Scalar(_)))
    }

    /// `true` when the routine returns a set of scalar values.
    pub fn returns_set_of_scalar(&self) -> bool {
        matches!(self.return_type, ReturnType::SetOf(PgType::Scalar(_)))
    }

    /// `true` when the routine returns a single (non-set) value, scalar or composite.
    pub fn returns_single(&self) -> bool {
        matches!(self.return_type, ReturnType::Single(_))
    }

    /// `true` when the routine returns a set of values, scalar or composite.
    pub fn returns_set(&self) -> bool {
        self.return_type.is_set()
    }

    /// `true` when the routine returns a composite type, either as a single
    /// value or as a set.
    pub fn returns_composite(&self) -> bool {
        self.return_type.inner_type().is_composite()
    }

    /// Unqualified name of the composite type the routine returns, or `None`
    /// when it returns a scalar. This is [`Routine::table_qi`] reduced to its name.
    pub fn table_name(&self) -> Option<&str> {
        self.table_qi().map(|qi| qi.name.as_str())
    }

    /// Qualified identifier of the composite type the routine returns (single
    /// or set), or `None` when it returns a scalar.
    pub fn table_qi(&self) -> Option<&QualifiedIdentifier> {
        // Exhaustive match (no `_` arm) so that adding a `PgType` variant forces
        // a decision here instead of silently yielding `None`.
        match self.return_type.inner_type() {
            PgType::Composite(qi, _) => Some(qi),
            PgType::Scalar(_) => None,
        }
    }

    /// `true` when the routine returns a composite type whose alias flag is set
    /// (single or set); always `false` for scalar returns.
    pub fn is_return_type_alias(&self) -> bool {
        matches!(self.return_type.inner_type(), PgType::Composite(_, true))
    }

    /// Non-variadic parameters flagged as `required`, in stored order.
    pub fn required_params(&self) -> impl Iterator<Item = &RoutineParam> {
        self.params.iter().filter(|p| p.required && !p.is_variadic)
    }

    /// Non-variadic parameters not flagged as `required`, in stored order.
    pub fn optional_params(&self) -> impl Iterator<Item = &RoutineParam> {
        self.params.iter().filter(|p| !p.required && !p.is_variadic)
    }

    /// The first parameter marked as variadic, if any.
    pub fn variadic_param(&self) -> Option<&RoutineParam> {
        self.params.iter().find(|p| p.is_variadic)
    }

    /// Looks up a parameter by exact (case-sensitive) name; returns the first match.
    pub fn get_param(&self, name: &str) -> Option<&RoutineParam> {
        self.params.iter().find(|p| p.name.as_str() == name)
    }

    /// Total number of parameters, including optional and variadic ones.
    pub fn param_count(&self) -> usize {
        self.params.len()
    }

    /// Number of parameters yielded by [`Routine::required_params`].
    pub fn required_param_count(&self) -> usize {
        // Delegate so the "required" predicate lives in exactly one place.
        self.required_params().count()
    }

    /// `true` when the declared volatility is [`Volatility::Volatile`].
    pub fn is_volatile(&self) -> bool {
        matches!(self.volatility, Volatility::Volatile)
    }

    /// `true` when the declared volatility is [`Volatility::Stable`].
    pub fn is_stable(&self) -> bool {
        matches!(self.volatility, Volatility::Stable)
    }

    /// `true` when the declared volatility is [`Volatility::Immutable`].
    pub fn is_immutable(&self) -> bool {
        matches!(self.volatility, Volatility::Immutable)
    }
}
/// A single parameter of a [`Routine`].
#[derive(Debug, Clone)]
pub struct RoutineParam {
/// Parameter name; `Routine::get_param` matches it exactly (case-sensitive).
pub name: CompactString,
/// Type name as text (e.g. `"integer"`, `"text"`); the `is_*_type` helpers
/// compare it against fixed spellings.
pub pg_type: CompactString,
/// Maximum length, stored as text.
/// NOTE(review): format not visible here — presumably the type's length
/// modifier, possibly empty when absent; confirm against the loader.
pub type_max_length: CompactString,
/// Whether the caller must supply this parameter; splits parameters between
/// `Routine::required_params` and `Routine::optional_params`.
pub required: bool,
/// Marks a variadic parameter. Such parameters are excluded from both
/// `required_params` and `optional_params`, and `variadic_param` returns the first one.
pub is_variadic: bool,
}
impl RoutineParam {
    /// Returns `true` when `pg_type` names a character/string type.
    ///
    /// Both SQL spellings (`character varying`, `character`) and their catalog
    /// aliases (`varchar`, `bpchar`) are recognised, plus `text`, `char` and
    /// `name`. The comparison is exact: case-sensitive, no length modifiers.
    pub fn is_text_type(&self) -> bool {
        matches!(
            self.pg_type.as_str(),
            "text"
                | "character varying"
                | "character"
                | "varchar"
                | "char"
                // Catalog name of `character(n)`, the counterpart of `varchar`.
                | "bpchar"
                | "name"
        )
    }

    /// Returns `true` when `pg_type` names a built-in numeric type.
    ///
    /// Both SQL spellings (`integer`, `real`, `double precision`, ...) and
    /// their catalog aliases (`int2`/`int4`/`int8`, `float4`/`float8`) are
    /// recognised. The comparison is exact: case-sensitive, no modifiers.
    pub fn is_numeric_type(&self) -> bool {
        matches!(
            self.pg_type.as_str(),
            "integer"
                | "bigint"
                | "smallint"
                | "numeric"
                | "decimal"
                | "real"
                | "double precision"
                | "int"
                | "int4"
                | "int8"
                | "int2"
                // Catalog names of `real` / `double precision`, matching the
                // integer aliases above.
                | "float4"
                | "float8"
        )
    }

    /// Returns `true` when `pg_type` is exactly `json` or `jsonb`.
    pub fn is_json_type(&self) -> bool {
        matches!(self.pg_type.as_str(), "json" | "jsonb")
    }
}
/// Shape of a routine's result: a single value or a set of values, each of a [`PgType`].
#[derive(Debug, Clone)]
pub enum ReturnType {
/// A single (non-set) value of the inner type.
Single(PgType),
/// A set of values of the inner type (presumably SQL `RETURNS SETOF ...` — confirm).
SetOf(PgType),
}
impl ReturnType {
    /// Borrows the element type, ignoring whether one value or a set is returned.
    pub fn inner_type(&self) -> &PgType {
        match self {
            ReturnType::Single(ty) | ReturnType::SetOf(ty) => ty,
        }
    }

    /// `true` for the `SetOf` shape, `false` for `Single`.
    pub fn is_set(&self) -> bool {
        match self {
            ReturnType::SetOf(_) => true,
            ReturnType::Single(_) => false,
        }
    }
}
/// A type referenced by a routine's return type, identified by its qualified name.
#[derive(Debug, Clone)]
pub enum PgType {
/// A scalar (non-composite) type.
Scalar(QualifiedIdentifier),
/// A composite type. Its identifier is what `Routine::table_name` and
/// `Routine::table_qi` report; the `bool` is what `Routine::is_return_type_alias`
/// reports (presumably whether the composite is an alias rather than a
/// concrete table/type — confirm).
Composite(QualifiedIdentifier, bool),
}
impl PgType {
    /// `true` only for the `Scalar` variant.
    pub fn is_scalar(&self) -> bool {
        match self {
            PgType::Scalar(_) => true,
            PgType::Composite(..) => false,
        }
    }

    /// `true` only for the `Composite` variant (alias flag ignored).
    pub fn is_composite(&self) -> bool {
        match self {
            PgType::Composite(..) => true,
            PgType::Scalar(_) => false,
        }
    }

    /// Borrows the qualified type name, whichever variant holds it.
    pub fn qi(&self) -> &QualifiedIdentifier {
        match self {
            PgType::Scalar(ident) | PgType::Composite(ident, _) => ident,
        }
    }
}
/// Volatility category of a routine, rendered by [`Volatility::as_sql`] as the
/// `IMMUTABLE` / `STABLE` / `VOLATILE` SQL keywords.
///
/// Defaults to [`Volatility::Volatile`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum Volatility {
    Immutable,
    Stable,
    #[default]
    Volatile,
}

impl Volatility {
    /// Parses a volatility marker.
    ///
    /// Accepts either the single-letter code (`"i"`, `"s"`, `"v"` — presumably
    /// the `pg_proc.provolatile` codes) or the full keyword (`"immutable"`,
    /// `"stable"`, `"volatile"`), compared ASCII case-insensitively. Any other
    /// input, including one with surrounding whitespace, yields `None`.
    pub fn parse(s: &str) -> Option<Self> {
        // Compare in place rather than building a lowercased copy. Every
        // accepted spelling is pure ASCII, and no non-ASCII character lowercases
        // to one of these letters, so ASCII case folding accepts exactly the
        // same inputs as `to_lowercase()` without a heap allocation per call.
        let is = |code: &str, keyword: &str| {
            s.eq_ignore_ascii_case(code) || s.eq_ignore_ascii_case(keyword)
        };
        if is("i", "immutable") {
            Some(Volatility::Immutable)
        } else if is("s", "stable") {
            Some(Volatility::Stable)
        } else if is("v", "volatile") {
            Some(Volatility::Volatile)
        } else {
            None
        }
    }

    /// Returns the upper-case SQL keyword for this volatility category.
    pub fn as_sql(&self) -> &'static str {
        match self {
            Volatility::Immutable => "IMMUTABLE",
            Volatility::Stable => "STABLE",
            Volatility::Volatile => "VOLATILE",
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::*;

    #[test]
    fn test_routine_qi() {
        let routine = test_routine().schema("api").name("get_user").build();
        let qi = routine.qi();
        assert_eq!(qi.schema.as_str(), "api");
        assert_eq!(qi.name.as_str(), "get_user");
    }

    #[test]
    fn test_routine_returns_scalar() {
        let scalar_func = test_routine().returns_scalar("integer").build();
        assert!(scalar_func.returns_scalar());
        assert!(!scalar_func.returns_composite());
        let composite_func = test_routine().returns_composite("public", "users").build();
        assert!(!composite_func.returns_scalar());
        assert!(composite_func.returns_composite());
    }

    #[test]
    fn test_routine_returns_set() {
        let single_func = test_routine().returns_scalar("integer").build();
        assert!(single_func.returns_single());
        assert!(!single_func.returns_set());
        let set_func = test_routine().returns_setof_scalar("integer").build();
        assert!(!set_func.returns_single());
        assert!(set_func.returns_set());
    }

    #[test]
    fn test_routine_returns_set_of_scalar() {
        let func = test_routine().returns_setof_scalar("text").build();
        assert!(func.returns_set_of_scalar());
        let composite_func = test_routine()
            .returns_setof_composite("public", "users")
            .build();
        assert!(!composite_func.returns_set_of_scalar());
    }

    #[test]
    fn test_routine_table_name() {
        let scalar_func = test_routine().returns_scalar("integer").build();
        assert!(scalar_func.table_name().is_none());
        let composite_func = test_routine().returns_composite("api", "users").build();
        assert_eq!(composite_func.table_name(), Some("users"));
    }

    // `table_qi` must see through both the single and the set shape.
    #[test]
    fn test_routine_table_qi() {
        let scalar_set_func = test_routine().returns_setof_scalar("integer").build();
        assert!(scalar_set_func.table_qi().is_none());
        assert!(scalar_set_func.table_name().is_none());
        assert!(!scalar_set_func.returns_composite());

        let composite_set_func = test_routine()
            .returns_setof_composite("api", "users")
            .build();
        let qi = composite_set_func.table_qi().expect("composite return has a table qi");
        assert_eq!(qi.schema.as_str(), "api");
        assert_eq!(qi.name.as_str(), "users");
        assert_eq!(composite_set_func.table_name(), Some("users"));
        assert!(composite_set_func.returns_composite());
    }

    #[test]
    fn test_routine_required_params() {
        let p1 = test_param().name("id").required(true).build();
        let p2 = test_param().name("name").required(false).build();
        let p3 = test_param().name("extra").required(true).build();
        let routine = test_routine().params([p1, p2, p3]).build();
        let required: Vec<_> = routine.required_params().map(|p| p.name.as_str()).collect();
        assert_eq!(required, vec!["id", "extra"]);
    }

    #[test]
    fn test_routine_optional_params() {
        let p1 = test_param().name("id").required(true).build();
        let p2 = test_param().name("limit").required(false).build();
        let routine = test_routine().params([p1, p2]).build();
        let optional: Vec<_> = routine.optional_params().map(|p| p.name.as_str()).collect();
        assert_eq!(optional, vec!["limit"]);
    }

    #[test]
    fn test_routine_variadic_param() {
        let p1 = test_param().name("id").build();
        let p2 = test_param().name("args").is_variadic(true).build();
        let routine = test_routine().params([p1, p2]).build();
        let variadic = routine.variadic_param().unwrap();
        assert_eq!(variadic.name.as_str(), "args");
    }

    #[test]
    fn test_routine_without_variadic_param() {
        let routine = test_routine().param(test_param().name("id").build()).build();
        assert!(routine.variadic_param().is_none());
    }

    // A variadic parameter is neither "required" nor "optional", whatever its
    // `required` flag says; it is still counted by `param_count`.
    #[test]
    fn test_routine_variadic_param_excluded_from_required_and_optional() {
        let p1 = test_param().name("id").required(true).build();
        let p2 = test_param()
            .name("args")
            .required(true)
            .is_variadic(true)
            .build();
        let routine = test_routine().params([p1, p2]).build();
        let required: Vec<_> = routine.required_params().map(|p| p.name.as_str()).collect();
        assert_eq!(required, vec!["id"]);
        assert_eq!(routine.optional_params().count(), 0);
        assert_eq!(routine.required_param_count(), 1);
        assert_eq!(routine.param_count(), 2);
    }

    #[test]
    fn test_routine_get_param() {
        let p1 = test_param().name("user_id").build();
        let routine = test_routine().param(p1).build();
        assert!(routine.get_param("user_id").is_some());
        assert!(routine.get_param("nonexistent").is_none());
    }

    #[test]
    fn test_routine_param_counts() {
        let p1 = test_param().name("a").required(true).build();
        let p2 = test_param().name("b").required(true).build();
        let p3 = test_param().name("c").required(false).build();
        let routine = test_routine().params([p1, p2, p3]).build();
        assert_eq!(routine.param_count(), 3);
        assert_eq!(routine.required_param_count(), 2);
    }

    #[test]
    fn test_routine_volatility() {
        let volatile_func = test_routine().volatility(Volatility::Volatile).build();
        assert!(volatile_func.is_volatile());
        assert!(!volatile_func.is_stable());
        assert!(!volatile_func.is_immutable());
        let stable_func = test_routine().volatility(Volatility::Stable).build();
        assert!(!stable_func.is_volatile());
        assert!(stable_func.is_stable());
        let immutable_func = test_routine().volatility(Volatility::Immutable).build();
        assert!(immutable_func.is_immutable());
    }

    #[test]
    fn test_routine_param_is_text_type() {
        assert!(test_param().pg_type("text").build().is_text_type());
        assert!(
            test_param()
                .pg_type("character varying")
                .build()
                .is_text_type()
        );
        assert!(!test_param().pg_type("integer").build().is_text_type());
    }

    #[test]
    fn test_routine_param_is_numeric_type() {
        assert!(test_param().pg_type("integer").build().is_numeric_type());
        assert!(test_param().pg_type("bigint").build().is_numeric_type());
        assert!(!test_param().pg_type("text").build().is_numeric_type());
    }

    #[test]
    fn test_routine_param_is_json_type() {
        assert!(test_param().pg_type("json").build().is_json_type());
        assert!(test_param().pg_type("jsonb").build().is_json_type());
        assert!(!test_param().pg_type("text").build().is_json_type());
    }

    #[test]
    fn test_return_type_inner_type() {
        let single = ReturnType::Single(PgType::Scalar(QualifiedIdentifier::new(
            "pg_catalog",
            "int4",
        )));
        assert!(single.inner_type().is_scalar());
        let setof = ReturnType::SetOf(PgType::Composite(
            QualifiedIdentifier::new("public", "users"),
            false,
        ));
        assert!(setof.inner_type().is_composite());
    }

    #[test]
    fn test_return_type_is_set() {
        let single = ReturnType::Single(PgType::Scalar(QualifiedIdentifier::new(
            "pg_catalog",
            "int4",
        )));
        assert!(!single.is_set());
        let setof = ReturnType::SetOf(PgType::Scalar(QualifiedIdentifier::new(
            "pg_catalog",
            "int4",
        )));
        assert!(setof.is_set());
    }

    #[test]
    fn test_pg_type_is_scalar_composite() {
        let scalar = PgType::Scalar(QualifiedIdentifier::new("pg_catalog", "int4"));
        assert!(scalar.is_scalar());
        assert!(!scalar.is_composite());
        let composite = PgType::Composite(QualifiedIdentifier::new("public", "users"), false);
        assert!(!composite.is_scalar());
        assert!(composite.is_composite());
    }

    #[test]
    fn test_pg_type_qi() {
        let scalar = PgType::Scalar(QualifiedIdentifier::new("pg_catalog", "text"));
        assert_eq!(scalar.qi().name.as_str(), "text");
        let composite = PgType::Composite(QualifiedIdentifier::new("api", "users"), false);
        assert_eq!(composite.qi().schema.as_str(), "api");
        assert_eq!(composite.qi().name.as_str(), "users");
    }

    #[test]
    fn test_volatility_parse() {
        assert_eq!(Volatility::parse("i"), Some(Volatility::Immutable));
        assert_eq!(Volatility::parse("immutable"), Some(Volatility::Immutable));
        assert_eq!(Volatility::parse("s"), Some(Volatility::Stable));
        assert_eq!(Volatility::parse("stable"), Some(Volatility::Stable));
        assert_eq!(Volatility::parse("v"), Some(Volatility::Volatile));
        assert_eq!(Volatility::parse("volatile"), Some(Volatility::Volatile));
        assert_eq!(Volatility::parse("invalid"), None);
    }

    #[test]
    fn test_volatility_parse_is_case_insensitive() {
        assert_eq!(Volatility::parse("I"), Some(Volatility::Immutable));
        assert_eq!(Volatility::parse("IMMUTABLE"), Some(Volatility::Immutable));
        assert_eq!(Volatility::parse("Stable"), Some(Volatility::Stable));
        assert_eq!(Volatility::parse("VOLATILE"), Some(Volatility::Volatile));
        assert_eq!(Volatility::parse(""), None);
    }

    #[test]
    fn test_volatility_as_sql() {
        assert_eq!(Volatility::Immutable.as_sql(), "IMMUTABLE");
        assert_eq!(Volatility::Stable.as_sql(), "STABLE");
        assert_eq!(Volatility::Volatile.as_sql(), "VOLATILE");
    }

    #[test]
    fn test_volatility_as_sql_round_trips_through_parse() {
        for v in [Volatility::Immutable, Volatility::Stable, Volatility::Volatile] {
            assert_eq!(Volatility::parse(v.as_sql()), Some(v));
        }
    }

    #[test]
    fn test_volatility_default_is_volatile() {
        assert_eq!(Volatility::default(), Volatility::Volatile);
    }
}