postrust_core/plan/
call_plan.rs1use crate::api_request::{ApiRequest, Payload, QualifiedIdentifier};
4use crate::error::{Error, Result};
5use crate::schema_cache::Routine;
6use serde::{Deserialize, Serialize};
7
8#[derive(Clone, Debug, Serialize, Deserialize)]
10pub struct CallPlan {
11 pub function: QualifiedIdentifier,
13 pub params: CallParams,
15 pub returns_scalar: bool,
17 pub returns_set: bool,
19 pub volatility: String,
21}
22
23#[derive(Clone, Debug, Serialize, Deserialize)]
25pub enum CallParams {
26 Named(Vec<(String, String)>),
28 Positional(Vec<String>),
30 SingleObject(bytes::Bytes),
32 None,
34}
35
36impl CallPlan {
37 pub fn from_request(request: &ApiRequest, routine: &Routine) -> Result<Self> {
39 let qi = routine.qualified_identifier();
40
41 let params = extract_call_params(request, routine)?;
42
43 let returns_scalar = !routine.return_type.is_set_returning()
44 && routine.return_type.type_name().map(|t| !t.contains("record")).unwrap_or(true);
45
46 Ok(Self {
47 function: qi,
48 params,
49 returns_scalar,
50 returns_set: routine.return_type.is_set_returning(),
51 volatility: format!("{:?}", routine.volatility),
52 })
53 }
54
55 pub fn has_params(&self) -> bool {
57 !matches!(self.params, CallParams::None)
58 }
59}
60
61fn extract_call_params(request: &ApiRequest, _routine: &Routine) -> Result<CallParams> {
63 if let Some(payload) = &request.payload {
65 match payload {
66 Payload::ProcessedJson { raw, .. } => {
67 let value: serde_json::Value = serde_json::from_slice(raw)
69 .map_err(|e| Error::InvalidBody(e.to_string()))?;
70
71 match value {
72 serde_json::Value::Object(map) => {
73 let params: Vec<(String, String)> = map
75 .into_iter()
76 .map(|(k, v)| {
77 let value = match v {
79 serde_json::Value::String(s) => s,
80 serde_json::Value::Null => String::new(),
81 other => other.to_string(),
82 };
83 (k, value)
84 })
85 .collect();
86 return Ok(CallParams::Named(params));
87 }
88 serde_json::Value::Array(_) => {
89 return Ok(CallParams::SingleObject(raw.clone()));
91 }
92 _ => {
93 return Ok(CallParams::SingleObject(raw.clone()));
95 }
96 }
97 }
98 Payload::ProcessedUrlEncoded { data, .. } => {
99 return Ok(CallParams::Named(data.clone()));
101 }
102 Payload::RawJson(raw) | Payload::RawPayload(raw) => {
103 return Ok(CallParams::SingleObject(raw.clone()));
104 }
105 }
106 }
107
108 if !request.query_params.params.is_empty() {
110 return Ok(CallParams::Named(request.query_params.params.clone()));
111 }
112
113 Ok(CallParams::None)
115}
116
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema_cache::{FuncVolatility, RetType};

    /// Fixture: stable `public.get_users` returning `SetOf("users")`, no parameters.
    fn set_returning_routine() -> Routine {
        Routine {
            schema: "public".into(),
            name: "get_users".into(),
            description: None,
            params: vec![],
            return_type: RetType::SetOf("users".into()),
            volatility: FuncVolatility::Stable,
            has_variadic: false,
            isolation_level: None,
            settings: vec![],
            is_procedure: false,
        }
    }

    /// Plans a default (empty) request against the fixture routine.
    fn plan_for_empty_request() -> CallPlan {
        CallPlan::from_request(&ApiRequest::default(), &set_returning_routine())
            .expect("an empty request should produce a plan")
    }

    #[test]
    fn test_call_plan_basic() {
        let plan = plan_for_empty_request();

        assert_eq!(plan.function.name, "get_users");
        assert!(plan.returns_set);
        assert!(!plan.returns_scalar);
    }

    #[test]
    fn test_call_params_none() {
        assert!(!plan_for_empty_request().has_params());
    }
}