use super::Parameter;
use bytes::{BufMut, BytesMut};
use num::cast::ToPrimitive;
use pbjson_types::{value::Kind, ListValue, Struct};
use postgres_array::Array;
use tokio_postgres::types::{to_sql_checked, Format, IsNull, ToSql, Type};
impl ToSql for Parameter {
    /// Writes this parameter into `out` as a value of `type_`.
    ///
    /// `json`/`jsonb` parameters are always sent as a serialized JSON document. That check
    /// runs before any null handling, so a null value also goes through the JSON serializer
    /// (presumably yielding a JSON `null` document rather than SQL NULL — NOTE(review):
    /// confirm this is intended). Every other type is encoded in whichever wire format
    /// [`ToSql::encode_format`] selects for this value.
    fn to_sql(
        &self,
        type_: &Type,
        out: &mut BytesMut,
    ) -> Result<IsNull, Box<dyn std::error::Error + Sync + Send>>
    where
        Self: Sized,
    {
        let kind = &self.0.kind;
        if !matches!(*type_, Type::JSON | Type::JSONB) {
            return match self.encode_format(type_) {
                Format::Text => to_sql_text(kind, type_, out),
                Format::Binary => to_sql_binary(kind, type_, out),
            };
        }
        let document = serde_json::to_value(&self.0)?;
        document.to_sql(type_, out)
    }

    /// Accepts every type; unsupported value/type combinations are reported as errors
    /// while encoding instead.
    fn accepts(_: &Type) -> bool {
        true
    }

    /// Picks text format when the server has to parse the value itself (see
    /// `should_infer`) and binary format otherwise.
    fn encode_format(&self, type_: &Type) -> Format {
        let needs_inference = should_infer(&self.0.kind, type_);
        if needs_inference {
            Format::Text
        } else {
            Format::Binary
        }
    }

    to_sql_checked!();
}
impl From<pbjson_types::Value> for Parameter {
fn from(value: pbjson_types::Value) -> Self {
Self(value)
}
}
/// Encodes `kind` in PostgreSQL's binary wire format for `type_`.
///
/// Only exact matches are handled:
/// - booleans as `bool` and strings as `text`/`varchar`;
/// - numbers as `oid`/`int2`/`int4`/`int8` (integral, in-range values only) or
///   `float4`/`float8`;
/// - lists as arrays (built by [`generate_array`]) and structs as composites.
///
/// Values that need server-side parsing are sent through [`to_sql_text`] instead, as
/// selected by `Parameter::encode_format`.
///
/// # Errors
/// Returns an error when the value cannot be represented as `type_`:
/// - a kind/type mismatch;
/// - a fractional or out-of-range number for an integer type;
/// - a list bound to a non-array type;
/// - a struct bound to a non-composite type, or missing one of its fields;
/// - a composite field whose encoding exceeds `i32::MAX` bytes.
fn to_sql_binary(
    kind: &Option<Kind>,
    type_: &Type,
    out: &mut BytesMut,
) -> Result<IsNull, Box<dyn std::error::Error + Sync + Send>> {
    match kind {
        Some(Kind::NullValue(..)) | None => Ok(IsNull::Yes),
        Some(Kind::BoolValue(boolean)) => match *type_ {
            Type::BOOL => boolean.to_sql(type_, out),
            _ => Err(format!("Cannot encode boolean as type {type_}").into()),
        },
        Some(Kind::StringValue(text)) => match *type_ {
            Type::TEXT | Type::VARCHAR => text.to_sql(type_, out),
            _ => Err(format!("Cannot encode text '{text}' as type {type_}").into()),
        },
        Some(Kind::NumberValue(number)) => {
            let number = *number;
            // Integer targets only take integral values: the `to_*` conversions below would
            // otherwise silently truncate the fraction. NaN/inf have a NaN fraction and are
            // therefore rejected too.
            let is_integral = number.fract() == 0.0;
            // Built lazily so the successful path does not allocate an error message.
            let out_of_range = || format!("Cannot encode {number} as {type_}");
            match *type_ {
                Type::OID if is_integral => number
                    .to_u32()
                    .ok_or_else(out_of_range)?
                    .to_sql(type_, out),
                Type::INT2 if is_integral => number
                    .to_i16()
                    .ok_or_else(out_of_range)?
                    .to_sql(type_, out),
                Type::INT4 if is_integral => number
                    .to_i32()
                    .ok_or_else(out_of_range)?
                    .to_sql(type_, out),
                Type::INT8 if is_integral => number
                    .to_i64()
                    .ok_or_else(out_of_range)?
                    .to_sql(type_, out),
                Type::FLOAT4 => number
                    .to_f32()
                    .ok_or_else(out_of_range)?
                    .to_sql(type_, out),
                Type::FLOAT8 => number.to_sql(type_, out),
                _ => Err(format!("Cannot encode {number} as type {type_}").into()),
            }
        }
        Some(Kind::ListValue(ListValue { values })) => match type_.kind() {
            tokio_postgres::types::Kind::Array(array_type) => {
                generate_array(array_type, values.to_owned())?.to_sql(type_, out)
            }
            _ => Err(format!(
                "Cannot encode {} as an array of type {type_}",
                serde_json::to_value(values)?
            )
            .into()),
        },
        Some(Kind::StructValue(Struct { fields })) => match type_.kind() {
            tokio_postgres::types::Kind::Composite(composite_fields) => {
                // Binary record layout: the field count, then for each field its type OID,
                // its byte length (-1 for NULL) and the encoded bytes.
                out.put_i32(composite_fields.len() as i32);
                for field in composite_fields {
                    let name = field.name();
                    // Only serialize the whole struct for the error message when a field
                    // is actually missing.
                    let value = match fields.get(name) {
                        Some(value) => value,
                        None => {
                            return Err(format!(
                                "Field '{name}' of composite type '{type_}' missing from {}",
                                serde_json::to_value(fields)?
                            )
                            .into())
                        }
                    };
                    out.put_u32(field.type_().oid());
                    let length_at = out.len();
                    // Placeholder length, patched once the field value has been written.
                    out.put_i32(0);
                    let length = match Parameter::from(value.clone()).to_sql(field.type_(), out)? {
                        IsNull::Yes => -1,
                        IsNull::No => {
                            let written = out.len() - length_at - 4;
                            if written > i32::MAX as usize {
                                return Err("value too large to transmit".into());
                            }
                            written as i32
                        }
                    };
                    out[length_at..length_at + 4].copy_from_slice(&length.to_be_bytes());
                }
                Ok(IsNull::No)
            }
            _ => Err(format!(
                "Cannot encode struct {} as type {type_}",
                serde_json::to_value(fields)?,
            )
            .into()),
        },
    }
}
fn to_sql_text(
kind: &Option<Kind>,
type_: &Type,
out: &mut BytesMut,
) -> Result<IsNull, Box<dyn std::error::Error + Sync + Send>> {
match kind {
Some(Kind::NullValue(..)) | None => Ok(IsNull::Yes),
Some(Kind::BoolValue(boolean)) => boolean.to_string().to_sql(type_, out),
Some(Kind::StringValue(text)) => text.to_sql(type_, out),
Some(Kind::NumberValue(number)) => number.to_string().to_sql(type_, out),
Some(Kind::ListValue(ListValue { values })) => {
match type_.kind() {
tokio_postgres::types::Kind::Array(..) => {
let mut values = values.iter().peekable();
out.put_slice(b"{");
while let Some(value) = values.next() {
if let Some(Kind::NullValue(..)) | None = value.kind {
out.put_slice(b"null");
} else {
to_sql_text(&value.kind, type_, out)?;
}
if values.peek().is_some() {
out.put_slice(b",");
}
}
out.put_slice(b"}");
Ok(IsNull::No)
}
_ => Err(format!(
"Cannot encode {} as type {type_}",
serde_json::to_value(values)?,
)
.into()),
}
}
Some(Kind::StructValue(Struct { fields })) => match type_.kind() {
tokio_postgres::types::Kind::Composite(composite_fields) => {
let mut composite_fields = composite_fields.iter().peekable();
out.put_slice(b"(");
while let Some(field) = composite_fields.next() {
let name = field.name();
match fields.get(name) {
Some(value) => {
if let Some(Kind::NullValue(..)) | None = value.kind {
out.put_slice(b"null");
} else {
to_sql_text(&value.kind, field.type_(), out)?;
}
if composite_fields.peek().is_some() {
out.put_slice(b",");
}
}
None => {
return Err(format!(
"Field '{name}' of composite type '{type_}' missing from {}",
serde_json::to_value(fields)?
)
.into())
}
}
}
out.put_slice(b")");
Ok(IsNull::No)
}
_ => Err(format!(
"Cannot encode struct {} as type {type_}",
serde_json::to_value(fields)?,
)
.into()),
},
}
}
/// Decides whether a value must be sent in text format so that the server parses it.
///
/// The binary encoder only covers exact matches: `bool`, `text`/`varchar`, the
/// integer/float numerics, and arrays/composites built from those. Anything else, such as
/// a string bound to a non-text type, has to go as text. JSON targets (including JSON
/// arrays) never need inference because they are serialized as JSON documents. For lists
/// and structs the answer is `true` as soon as any element, or any field that is present,
/// needs inference.
fn should_infer(kind: &Option<Kind>, type_: &Type) -> bool {
    let json_target = matches!(
        *type_,
        Type::JSON | Type::JSONB | Type::JSON_ARRAY | Type::JSONB_ARRAY
    );
    if json_target {
        return false;
    }
    let kind = match kind {
        Some(kind) => kind,
        None => return false,
    };
    match kind {
        Kind::NullValue(..) | Kind::BoolValue(..) => false,
        Kind::StringValue(..) => {
            let sent_as_binary = matches!(*type_, Type::TEXT | Type::VARCHAR);
            !sent_as_binary
        }
        Kind::NumberValue(..) => {
            let sent_as_binary = matches!(
                *type_,
                Type::OID | Type::INT2 | Type::INT4 | Type::INT8 | Type::FLOAT4 | Type::FLOAT8
            );
            !sent_as_binary
        }
        Kind::ListValue(ListValue { values }) => {
            let element_type = match type_.kind() {
                tokio_postgres::types::Kind::Array(element_type) => element_type,
                _ => return false,
            };
            // A leading nested list means the elements are further dimensions of this same
            // array type rather than members of the element type.
            let target = match values.first().and_then(|value| value.kind.as_ref()) {
                Some(Kind::ListValue(..)) => type_,
                _ => element_type,
            };
            values.iter().any(|value| should_infer(&value.kind, target))
        }
        Kind::StructValue(Struct { fields }) => {
            let composite_fields = match type_.kind() {
                tokio_postgres::types::Kind::Composite(composite_fields) => composite_fields,
                _ => return false,
            };
            composite_fields
                .iter()
                .filter_map(|field| fields.get(field.name()).map(|value| (field, value)))
                .any(|(field, value)| should_infer(&value.kind, field.type_()))
        }
    }
}
/// Converts a JSON list into a (possibly multi-dimensional) [`Array`] of [`Parameter`]s whose
/// elements will later be encoded as `array_type`.
///
/// If the first element is itself a list, every element must be a list with the same length
/// and the same nested shape; each becomes one slice along a new outer dimension. Otherwise
/// the list becomes a one-dimensional array of all its values, in order. Values without a
/// `kind` are kept in place rather than dropped.
///
/// Every dimension uses lower bound 1, PostgreSQL's default. Postgres keeps the lower bound
/// it receives over the binary protocol, and array equality compares lower bounds. A
/// 0-based array would therefore not compare equal to ordinary `'{...}'` values, nor to the
/// arrays produced by the text encoding path.
///
/// # Errors
/// Returns an error when a multi-dimensional input is ragged (a row is not a list, has the
/// wrong length, or has a different nested shape), or when a nested row fails to convert.
fn generate_array(
    array_type: &Type,
    values: Vec<pbjson_types::Value>,
) -> Result<Array<Parameter>, Box<dyn std::error::Error + Sync + Send>> {
    const LOWER_BOUND: i32 = 1;

    let mut values = values.into_iter();
    let first = match values.next() {
        Some(first) => first,
        None => return Ok(Array::from_vec(vec![], LOWER_BOUND)),
    };
    let first_row = match first.kind {
        Some(Kind::ListValue(ListValue { values: first_row })) => first_row,
        // Scalar (or null) first element: a one-dimensional array of every value, in order.
        kind => {
            let elements = std::iter::once(pbjson_types::Value { kind })
                .chain(values)
                .map(Parameter::from)
                .collect::<Vec<_>>();
            return Ok(Array::from_vec(elements, LOWER_BOUND));
        }
    };

    let dimension = first_row.len();
    let mut array = generate_array(array_type, first_row)?;
    array.wrap(LOWER_BOUND);
    for value in values {
        let row = match value.kind {
            Some(Kind::ListValue(ListValue { values: row })) if row.len() == dimension => row,
            kind => {
                return Err(format!(
                    "Cannot encode {} as an element of {array_type}[{dimension}]",
                    serde_json::to_value(pbjson_types::Value { kind })?,
                )
                .into())
            }
        };
        let nested_array = generate_array(array_type, row)?;
        // `Array::push` panics when shapes differ, so reject ragged input with an error first.
        if nested_array.dimensions()[..] != array.dimensions()[1..] {
            return Err(format!(
                "Cannot encode a ragged multidimensional array of {array_type}: nested arrays must all have the same dimensions"
            )
            .into());
        }
        array.push(nested_array);
    }
    Ok(array)
}