postrust_core/plan/
call_plan.rs1use crate::api_request::{ApiRequest, Payload, QualifiedIdentifier};
4use crate::error::{Error, Result};
5use crate::schema_cache::Routine;
6use serde::{Deserialize, Serialize};
7
8#[derive(Clone, Debug, Serialize, Deserialize)]
10pub struct CallPlan {
11 pub function: QualifiedIdentifier,
13 pub params: CallParams,
15 pub returns_scalar: bool,
17 pub returns_set: bool,
19 pub volatility: String,
21}
22
23#[derive(Clone, Debug, Serialize, Deserialize)]
25pub enum CallParams {
26 Named(Vec<(String, String)>),
28 Positional(Vec<String>),
30 SingleObject(bytes::Bytes),
32 None,
34}
35
36impl CallPlan {
37 pub fn from_request(request: &ApiRequest, routine: &Routine) -> Result<Self> {
39 let qi = routine.qualified_identifier();
40
41 let params = extract_call_params(request, routine)?;
42
43 let returns_scalar = !routine.return_type.is_set_returning()
44 && routine.return_type.type_name().map(|t| !t.contains("record")).unwrap_or(true);
45
46 Ok(Self {
47 function: qi,
48 params,
49 returns_scalar,
50 returns_set: routine.return_type.is_set_returning(),
51 volatility: format!("{:?}", routine.volatility),
52 })
53 }
54
55 pub fn has_params(&self) -> bool {
57 !matches!(self.params, CallParams::None)
58 }
59}
60
61fn extract_call_params(request: &ApiRequest, _routine: &Routine) -> Result<CallParams> {
63 if let Some(payload) = &request.payload {
65 match payload {
66 Payload::ProcessedJson { raw, .. } => {
67 let value: serde_json::Value = serde_json::from_slice(raw)
69 .map_err(|e| Error::InvalidBody(e.to_string()))?;
70
71 match value {
72 serde_json::Value::Object(map) => {
73 let params: Vec<(String, String)> = map
75 .into_iter()
76 .map(|(k, v)| (k, v.to_string()))
77 .collect();
78 return Ok(CallParams::Named(params));
79 }
80 serde_json::Value::Array(_) => {
81 return Ok(CallParams::SingleObject(raw.clone()));
83 }
84 _ => {
85 return Ok(CallParams::SingleObject(raw.clone()));
87 }
88 }
89 }
90 Payload::ProcessedUrlEncoded { data, .. } => {
91 return Ok(CallParams::Named(data.clone()));
93 }
94 Payload::RawJson(raw) | Payload::RawPayload(raw) => {
95 return Ok(CallParams::SingleObject(raw.clone()));
96 }
97 }
98 }
99
100 if !request.query_params.params.is_empty() {
102 return Ok(CallParams::Named(request.query_params.params.clone()));
103 }
104
105 Ok(CallParams::None)
107}
108
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema_cache::{FuncVolatility, RetType};

    /// Fixture: a stable, parameterless `public.get_users` routine returning `SetOf("users")`.
    fn set_returning_routine() -> Routine {
        Routine {
            schema: "public".into(),
            name: "get_users".into(),
            description: None,
            params: vec![],
            return_type: RetType::SetOf("users".into()),
            volatility: FuncVolatility::Stable,
            has_variadic: false,
            isolation_level: None,
            settings: vec![],
            is_procedure: false,
        }
    }

    /// Plans a call for a default `ApiRequest` against the fixture routine.
    fn plan_for_default_request() -> CallPlan {
        let request = ApiRequest::default();
        let routine = set_returning_routine();
        CallPlan::from_request(&request, &routine).expect("planning a default request failed")
    }

    #[test]
    fn test_call_plan_basic() {
        let plan = plan_for_default_request();

        assert_eq!(plan.function.name, "get_users");
        assert!(plan.returns_set);
        assert!(!plan.returns_scalar);
    }

    #[test]
    fn test_call_params_none() {
        let plan = plan_for_default_request();
        assert!(!plan.has_params());
    }
}