postrust_core/plan/
call_plan.rs

1//! RPC (stored function) call planning.
2
3use crate::api_request::{ApiRequest, Payload, QualifiedIdentifier};
4use crate::error::{Error, Result};
5use crate::schema_cache::Routine;
6use serde::{Deserialize, Serialize};
7
/// A plan for calling a stored function.
///
/// Built by [`CallPlan::from_request`] from an [`ApiRequest`] and the target [`Routine`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct CallPlan {
    /// Function identifier, as returned by `Routine::qualified_identifier`.
    pub function: QualifiedIdentifier,
    /// Call parameters extracted from the request body or query string.
    pub params: CallParams,
    /// Whether to return a scalar result.
    ///
    /// Set by `from_request` when the function is not set-returning and its return type
    /// name (if it has one) does not contain `record`.
    pub returns_scalar: bool,
    /// Whether the function is set-returning.
    pub returns_set: bool,
    /// Function volatility (for transaction handling).
    ///
    /// Stored as the `Debug` rendering of the routine's volatility value.
    pub volatility: String,
}
22
/// How parameters are passed to the function.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum CallParams {
    /// Named parameters as `(name, value)` pairs, taken from a JSON object body,
    /// URL-encoded form data, or the URL query string.
    Named(Vec<(String, String)>),
    /// Positional parameters.
    ///
    /// NOTE(review): `extract_call_params` in this file never constructs this variant;
    /// JSON array bodies are passed as `SingleObject` instead. Confirm whether anything
    /// else produces it.
    Positional(Vec<String>),
    /// The request body passed through unchanged as a single argument.
    ///
    /// Used for JSON array or scalar bodies and for raw (unprocessed) payloads.
    SingleObject(bytes::Bytes),
    /// No parameters.
    None,
}
35
36impl CallPlan {
37    /// Create a call plan from an API request.
38    pub fn from_request(request: &ApiRequest, routine: &Routine) -> Result<Self> {
39        let qi = routine.qualified_identifier();
40
41        let params = extract_call_params(request, routine)?;
42
43        let returns_scalar = !routine.return_type.is_set_returning()
44            && routine.return_type.type_name().map(|t| !t.contains("record")).unwrap_or(true);
45
46        Ok(Self {
47            function: qi,
48            params,
49            returns_scalar,
50            returns_set: routine.return_type.is_set_returning(),
51            volatility: format!("{:?}", routine.volatility),
52        })
53    }
54
55    /// Check if this call has parameters.
56    pub fn has_params(&self) -> bool {
57        !matches!(self.params, CallParams::None)
58    }
59}
60
61/// Extract call parameters from request.
62fn extract_call_params(request: &ApiRequest, _routine: &Routine) -> Result<CallParams> {
63    // Check for JSON body first
64    if let Some(payload) = &request.payload {
65        match payload {
66            Payload::ProcessedJson { raw, .. } => {
67                // Check if it's an object or array
68                let value: serde_json::Value = serde_json::from_slice(raw)
69                    .map_err(|e| Error::InvalidBody(e.to_string()))?;
70
71                match value {
72                    serde_json::Value::Object(map) => {
73                        // Named parameters from JSON object
74                        let params: Vec<(String, String)> = map
75                            .into_iter()
76                            .map(|(k, v)| {
77                                // Extract string values without JSON quotes
78                                let value = match v {
79                                    serde_json::Value::String(s) => s,
80                                    serde_json::Value::Null => String::new(),
81                                    other => other.to_string(),
82                                };
83                                (k, value)
84                            })
85                            .collect();
86                        return Ok(CallParams::Named(params));
87                    }
88                    serde_json::Value::Array(_) => {
89                        // Pass entire JSON as single argument
90                        return Ok(CallParams::SingleObject(raw.clone()));
91                    }
92                    _ => {
93                        // Scalar value - pass as single argument
94                        return Ok(CallParams::SingleObject(raw.clone()));
95                    }
96                }
97            }
98            Payload::ProcessedUrlEncoded { data, .. } => {
99                // Named parameters from form data
100                return Ok(CallParams::Named(data.clone()));
101            }
102            Payload::RawJson(raw) | Payload::RawPayload(raw) => {
103                return Ok(CallParams::SingleObject(raw.clone()));
104            }
105        }
106    }
107
108    // Fall back to query parameters
109    if !request.query_params.params.is_empty() {
110        return Ok(CallParams::Named(request.query_params.params.clone()));
111    }
112
113    // No parameters
114    Ok(CallParams::None)
115}
116
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema_cache::{FuncVolatility, RetType};

    /// A parameterless, set-returning `public.get_users` routine shared by the tests.
    fn make_routine() -> Routine {
        Routine {
            schema: "public".into(),
            name: "get_users".into(),
            description: None,
            params: vec![],
            return_type: RetType::SetOf("users".into()),
            volatility: FuncVolatility::Stable,
            has_variadic: false,
            isolation_level: None,
            settings: vec![],
            is_procedure: false,
        }
    }

    #[test]
    fn test_call_plan_basic() {
        let request = ApiRequest::default();
        let routine = make_routine();

        let plan = CallPlan::from_request(&request, &routine).unwrap();

        assert_eq!(plan.function.name, "get_users");
        assert!(plan.returns_set);
        assert!(!plan.returns_scalar);
    }

    #[test]
    fn test_call_params_none() {
        let request = ApiRequest::default();
        let routine = make_routine();

        let plan = CallPlan::from_request(&request, &routine).unwrap();
        assert!(!plan.has_params());
    }

    /// Without a body, non-empty query parameters become named parameters as-is.
    #[test]
    fn test_call_params_from_query_string() {
        let mut request = ApiRequest::default();
        request.query_params.params = vec![("id".to_string(), "1".to_string())];

        let plan = CallPlan::from_request(&request, &make_routine()).unwrap();

        assert!(plan.has_params());
        match plan.params {
            CallParams::Named(params) => {
                assert_eq!(params, vec![("id".to_string(), "1".to_string())]);
            }
            other => panic!("expected named params, got {:?}", other),
        }
    }

    /// A raw JSON body is forwarded byte-for-byte as a single argument.
    #[test]
    fn test_call_params_raw_json_body_passed_through() {
        let mut request = ApiRequest::default();
        request.payload = Some(Payload::RawJson(bytes::Bytes::from_static(b"[1, 2, 3]")));

        let plan = CallPlan::from_request(&request, &make_routine()).unwrap();

        assert!(plan.has_params());
        match plan.params {
            CallParams::SingleObject(body) => assert_eq!(&body[..], b"[1, 2, 3]"),
            other => panic!("expected a single raw argument, got {:?}", other),
        }
    }
}