use crate::api_request::{ApiRequest, Payload, QualifiedIdentifier};
use crate::error::{Error, Result};
use crate::schema_cache::Routine;
use serde::{Deserialize, Serialize};
/// Everything needed to invoke a schema-cache [`Routine`] for one request.
///
/// Normally built with [`CallPlan::from_request`], which derives the fields from the
/// routine's metadata and the request's body / query string.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct CallPlan {
/// Schema-qualified name of the routine to call (from `Routine::qualified_identifier`).
pub function: QualifiedIdentifier,
/// Arguments to pass to the routine.
pub params: CallParams,
/// `true` when the routine is not set-returning and its return type name does not
/// contain `"record"` (as computed by `from_request`).
pub returns_scalar: bool,
/// `true` when the routine's return type is set-returning.
pub returns_set: bool,
/// `Debug` rendering of the routine's volatility — presumably e.g. `"Stable"`; confirm
/// against the `FuncVolatility` Debug impl before matching on it.
pub volatility: String,
}
/// The arguments a [`CallPlan`] passes to its routine.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum CallParams {
/// `(name, value)` pairs, with every value rendered as a string. Produced from a JSON
/// object body, a URL-encoded body, or (when there is no body) the query string.
Named(Vec<(String, String)>),
/// Ordered argument values. Not constructed by the extraction code in this module;
/// NOTE(review): presumably positional-call support — confirm with its producer.
Positional(Vec<String>),
/// The raw request body, forwarded unparsed as a single argument (used for JSON arrays
/// and scalars, and for raw JSON / raw payload bodies).
SingleObject(bytes::Bytes),
/// The request supplied no arguments.
None,
}
impl CallPlan {
pub fn from_request(request: &ApiRequest, routine: &Routine) -> Result<Self> {
let qi = routine.qualified_identifier();
let params = extract_call_params(request, routine)?;
let returns_scalar = !routine.return_type.is_set_returning()
&& routine.return_type.type_name().map(|t| !t.contains("record")).unwrap_or(true);
Ok(Self {
function: qi,
params,
returns_scalar,
returns_set: routine.return_type.is_set_returning(),
volatility: format!("{:?}", routine.volatility),
})
}
pub fn has_params(&self) -> bool {
!matches!(self.params, CallParams::None)
}
}
fn extract_call_params(request: &ApiRequest, _routine: &Routine) -> Result<CallParams> {
if let Some(payload) = &request.payload {
match payload {
Payload::ProcessedJson { raw, .. } => {
let value: serde_json::Value = serde_json::from_slice(raw)
.map_err(|e| Error::InvalidBody(e.to_string()))?;
match value {
serde_json::Value::Object(map) => {
let params: Vec<(String, String)> = map
.into_iter()
.map(|(k, v)| {
let value = match v {
serde_json::Value::String(s) => s,
serde_json::Value::Null => String::new(),
other => other.to_string(),
};
(k, value)
})
.collect();
return Ok(CallParams::Named(params));
}
serde_json::Value::Array(_) => {
return Ok(CallParams::SingleObject(raw.clone()));
}
_ => {
return Ok(CallParams::SingleObject(raw.clone()));
}
}
}
Payload::ProcessedUrlEncoded { data, .. } => {
return Ok(CallParams::Named(data.clone()));
}
Payload::RawJson(raw) | Payload::RawPayload(raw) => {
return Ok(CallParams::SingleObject(raw.clone()));
}
}
}
if !request.query_params.params.is_empty() {
return Ok(CallParams::Named(request.query_params.params.clone()));
}
Ok(CallParams::None)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema_cache::{FuncVolatility, RetType};

    /// A set-returning, stable `public.get_users` routine with no declared parameters.
    fn make_routine() -> Routine {
        Routine {
            schema: "public".into(),
            name: "get_users".into(),
            description: None,
            params: vec![],
            return_type: RetType::SetOf("users".into()),
            volatility: FuncVolatility::Stable,
            has_variadic: false,
            isolation_level: None,
            settings: vec![],
            is_procedure: false,
        }
    }

    #[test]
    fn test_call_plan_basic() {
        let request = ApiRequest::default();
        let routine = make_routine();
        let plan = CallPlan::from_request(&request, &routine).unwrap();
        assert_eq!(plan.function.name, "get_users");
        assert!(plan.returns_set);
        assert!(!plan.returns_scalar);
    }

    #[test]
    fn test_call_params_none() {
        let request = ApiRequest::default();
        let routine = make_routine();
        let plan = CallPlan::from_request(&request, &routine).unwrap();
        assert!(!plan.has_params());
    }

    #[test]
    fn test_call_params_from_query_string() {
        let expected = vec![("id".to_string(), "1".to_string())];
        let mut request = ApiRequest::default();
        request.query_params.params = expected.clone();
        let plan = CallPlan::from_request(&request, &make_routine()).unwrap();
        assert!(plan.has_params());
        match plan.params {
            CallParams::Named(params) => assert_eq!(params, expected),
            other => panic!("expected Named params, got {:?}", other),
        }
    }

    #[test]
    fn test_body_takes_precedence_over_query_params() {
        let body = bytes::Bytes::from_static(b"[1, 2, 3]");
        let mut request = ApiRequest::default();
        request.query_params.params = vec![("id".to_string(), "1".to_string())];
        request.payload = Some(Payload::RawJson(body.clone()));
        let plan = CallPlan::from_request(&request, &make_routine()).unwrap();
        match plan.params {
            CallParams::SingleObject(raw) => assert_eq!(raw, body),
            other => panic!("expected SingleObject params, got {:?}", other),
        }
    }
}