puff-rs 0.1.8

Puff - Deep Stack Python Runtime and GraphQL library.
Documentation
use crate::graphql::scalar::{AggroScalarValue, AggroValue};

use crate::errors::PuffResult;
use crate::python::postgres::column_to_python;
use crate::types::{Bytes, Text, UtcDateTime};
use anyhow::{anyhow, bail, Result};
use juniper::Object;
use pyo3::exceptions::PyKeyError;
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyList, PyString};
use std::collections::HashMap;
use tokio_postgres::{Column, Row, Statement};

/// Converts an arbitrary Python value into a Juniper [`AggroValue`].
///
/// Conversions are tried in this order:
/// - `None` → null
/// - `str` → `String`
/// - `bool` → `Boolean` (checked before `int`, since Python `bool` is an `int` subclass)
/// - `int` → `Int` when it fits that scalar's range, otherwise `Long`
/// - `float` → `Float`
/// - `list` → list, converting each element recursively
/// - `dict` → object, stringifying keys with `to_string()` and converting values recursively
///
/// Any other object becomes a GraphQL object built from the entries of `dir()` that do not
/// start with `__` and whose attribute values are not callable.
///
/// # Panics
///
/// Panics if a `dir()` entry cannot be read as UTF-8 or if reading a listed attribute raises
/// a Python exception.
pub fn convert_pyany_to_jupiter(attribute_val: &PyAny) -> AggroValue {
    if attribute_val.is_none() {
        return AggroValue::Null;
    }
    if let Ok(s) = attribute_val.extract() {
        return AggroValue::Scalar(AggroScalarValue::String(s));
    }
    if let Ok(b) = attribute_val.extract() {
        return AggroValue::Scalar(AggroScalarValue::Boolean(b));
    }
    if let Ok(r) = attribute_val.extract::<i64>() {
        // Prefer the narrower `Int` scalar; fall back to `Long` only when it would not fit.
        return if let Ok(i) = r.try_into() {
            AggroValue::Scalar(AggroScalarValue::Int(i))
        } else {
            AggroValue::Scalar(AggroScalarValue::Long(r))
        };
    }
    if let Ok(f) = attribute_val.extract() {
        return AggroValue::Scalar(AggroScalarValue::Float(f));
    }
    if let Ok(l) = attribute_val.extract::<&PyList>() {
        return AggroValue::List(l.iter().map(convert_pyany_to_jupiter).collect());
    }
    if let Ok(d) = attribute_val.extract::<&PyDict>() {
        return AggroValue::Object(
            d.iter()
                .map(|(k, v)| (k.to_string(), convert_pyany_to_jupiter(v)))
                .collect(),
        );
    }

    // `dir()` builds a fresh Python list on every call, so fetch it once.
    let attr_names = attribute_val.dir();
    let mut obj = Object::with_capacity(attr_names.len());

    for name in attr_names.iter() {
        if let Ok(s) = name.downcast::<PyString>() {
            let key = s.to_str().expect("Python dir could not unwrap key to str.");
            // Skip dunder attributes (`__class__`, `__dict__`, ...).
            if key.starts_with("__") {
                continue;
            }
            // Only format the panic message on failure (clippy `expect_fun_call`).
            let py_val = attribute_val
                .getattr(key)
                .unwrap_or_else(|e| panic!("Could not get {}: {:?}", key, e));
            // Callables (methods, bound functions) are not data; only plain values become fields.
            if !py_val.is_callable() {
                obj.add_field(key, convert_pyany_to_jupiter(py_val));
            }
        }
    }
    AggroValue::Object(obj)
}

pub fn convert_postgres_to_juniper(
    r: &Row,
    column_index: usize,
    t: &tokio_postgres::types::Type,
) -> Result<AggroValue> {
    match t {
        &tokio_postgres::types::Type::BOOL => {
            if let Some(r) = r.get::<_, Option<bool>>(column_index) {
                Ok(AggroValue::scalar(r))
            } else {
                Ok(AggroValue::null())
            }
        }
        &tokio_postgres::types::Type::INT2 | &tokio_postgres::types::Type::INT4 => {
            if let Some(r) = r.get::<_, Option<i32>>(column_index) {
                Ok(AggroValue::scalar(r))
            } else {
                Ok(AggroValue::null())
            }
        }
        &tokio_postgres::types::Type::INT8 => {
            if let Some(r) = r.get::<_, Option<i64>>(column_index) {
                Ok(if let Ok(s) = r.try_into() {
                    AggroValue::Scalar(AggroScalarValue::Int(s))
                } else {
                    AggroValue::Scalar(AggroScalarValue::Long(r))
                })
            } else {
                Ok(AggroValue::null())
            }
        }
        &tokio_postgres::types::Type::BYTEA => {
            if let Some(r) = r.get::<_, Option<&[u8]>>(column_index) {
                Ok(AggroValue::Scalar(AggroScalarValue::Binary(
                    Bytes::copy_from_slice(r),
                )))
            } else {
                Ok(AggroValue::null())
            }
        }
        &tokio_postgres::types::Type::TIMESTAMP => {
            if let Some(r) = r.get::<_, Option<_>>(column_index) {
                Ok(AggroValue::Scalar(AggroScalarValue::Datetime(
                    UtcDateTime::new(r),
                )))
            } else {
                Ok(AggroValue::null())
            }
        }
        &tokio_postgres::types::Type::UUID => {
            if let Some(r) = r.get::<_, Option<_>>(column_index) {
                Ok(AggroValue::Scalar(AggroScalarValue::Uuid(r)))
            } else {
                Ok(AggroValue::null())
            }
        }
        &tokio_postgres::types::Type::FLOAT4 | &tokio_postgres::types::Type::FLOAT8 => {
            if let Some(r) = r.get::<_, Option<f64>>(column_index) {
                Ok(AggroValue::scalar(r))
            } else {
                Ok(AggroValue::null())
            }
        }
        &tokio_postgres::types::Type::TEXT | &tokio_postgres::types::Type::VARCHAR => {
            if let Some(r) = r.get::<_, Option<&str>>(column_index) {
                Ok(AggroValue::scalar(r))
            } else {
                Ok(AggroValue::null())
            }
        }
        t => {
            panic!("Unsupported postgres type {}", t)
        }
    }
}

/// A set of parent rows from which GraphQL field values can be read.
///
/// Implemented in this module for Postgres query results (`PostgresResultRows`), Python
/// lists (`PythonResultRows`) and the query root (`ExtractorRootNode`).
pub trait ExtractValues {
    /// Number of rows held by this source.
    fn len(&self) -> usize;
    /// Reads the fields `names` from every row.
    ///
    /// The outer vector has one entry per row. An entry is `None` when the row itself is
    /// absent (e.g. a Python `None` element); otherwise it holds one value per name, in the
    /// order of `names`.
    fn extract_values(&self, names: &[Text]) -> Result<Vec<Option<Vec<AggroValue>>>>;
    /// Like `extract_values`, but builds Python objects instead of `AggroValue`s.
    ///
    /// The implementations in this module that support it return a Python list holding one
    /// list of field values per row.
    fn extract_py_values(&self, py: Python, names: &[&PyString]) -> PuffResult<PyObject>;
    /// Returns one value per row. The Postgres implementation reads the first column, the
    /// Python implementation converts the row object itself, and the root yields a single
    /// null.
    fn extract_first(&self) -> Result<Vec<AggroValue>>;
}

/// Rows returned by a Postgres query, paired with a statement describing their columns.
pub struct PostgresResultRows {
    /// Statement whose column metadata is used to resolve requested field names to column
    /// indices. Presumably the statement that produced `rows`; the indices it yields are
    /// used to read from each row.
    pub statement: Statement,
    /// The result rows; each one is converted independently.
    pub rows: Vec<Row>,
}

impl ExtractValues for PostgresResultRows {
    fn len(&self) -> usize {
        self.rows.len()
    }

    fn extract_values(&self, names: &[Text]) -> Result<Vec<Option<Vec<AggroValue>>>> {
        let field_mapping: HashMap<&str, (usize, &Column)> = names
            .iter()
            .flat_map(|field_name| {
                self.statement
                    .columns()
                    .iter()
                    .enumerate()
                    .find_map(|(ix, c)| {
                        if c.name() == field_name.as_str() {
                            Some((c.name(), (ix, c)))
                        } else {
                            None
                        }
                    })
            })
            .collect();

        let mut ret_vec = Vec::with_capacity(self.rows.len());
        for row in &self.rows {
            let mut row_vec = Vec::with_capacity(names.len());
            for name in names {
                let (column_ix, c) = if let Some((column_ix, c)) = field_mapping.get(name.as_str())
                {
                    (column_ix, *c)
                } else {
                    bail!("Could not find {} in Postgres row", name)
                };
                let field_val = convert_postgres_to_juniper(row, *column_ix, c.type_())?;
                row_vec.push(field_val);
            }
            ret_vec.push(Some(row_vec));
        }
        Ok(ret_vec)
    }

    fn extract_py_values(&self, py: Python, names: &[&PyString]) -> PuffResult<PyObject> {
        let field_mapping: HashMap<&str, (usize, &Column)> = names
            .iter()
            .flat_map(|field_name| {
                self.statement
                    .columns()
                    .iter()
                    .enumerate()
                    .find_map(|(ix, c)| {
                        if c.name() == field_name.to_str().expect("Expected string") {
                            Some((c.name(), (ix, c)))
                        } else {
                            None
                        }
                    })
            })
            .collect();

        let ret_vec = PyList::empty(py);
        for row in &self.rows {
            let row_vec = PyList::empty(py);
            for name in names {
                let (column_ix, c) = if let Some((column_ix, c)) = field_mapping.get(name.to_str()?)
                {
                    (column_ix, *c)
                } else {
                    bail!("Could not find {} in Postgres row", name)
                };
                let field_val = column_to_python(py, *column_ix, c, row)?;
                row_vec.append(field_val)?;
            }
            ret_vec.append(row_vec)?;
        }
        Ok(ret_vec.into_py(py))
    }

    fn extract_first(&self) -> Result<Vec<AggroValue>> {
        let mut ret_vec = Vec::with_capacity(self.rows.len());
        for row in &self.rows {
            let columns = row.columns();
            let c = columns
                .get(0)
                .ok_or(anyhow!("Expected at least one column in query."))?;
            let field_val = convert_postgres_to_juniper(row, 0, c.type_())?;
            ret_vec.push(field_val);
        }
        Ok(ret_vec)
    }
}

pub struct ExtractorRootNode;

impl ExtractValues for ExtractorRootNode {
    fn len(&self) -> usize {
        1
    }
    fn extract_values(&self, _names: &[Text]) -> Result<Vec<Option<Vec<AggroValue>>>> {
        bail!("Cannot extract values from the Root")
    }
    fn extract_py_values(&self, _py: Python, _names: &[&PyString]) -> Result<PyObject> {
        bail!("Cannot extract values from the Root")
    }
    fn extract_first(&self) -> Result<Vec<AggroValue>> {
        Ok(vec![AggroValue::Null])
    }
}

/// Rows held in a Python list. Each element is either `None`, a `dict` (fields are read by
/// key), or any other object (fields are read as attributes).
pub struct PythonResultRows {
    /// The Python list of row objects.
    pub py_list: Py<PyList>,
}

impl ExtractValues for PythonResultRows {
    /// Number of elements in the wrapped Python list (acquires the GIL).
    fn len(&self) -> usize {
        Python::with_gil(|py| self.py_list.as_ref(py).len())
    }

    /// For every list element, yields `None` if the element is Python `None`. Otherwise it
    /// yields `Some` of the requested fields, converted with `convert_pyany_to_jupiter`, in
    /// the order of `names`. Dict rows are read by key; any other row is read by attribute.
    ///
    /// # Errors
    ///
    /// Fails with a Python `KeyError` if a dict row lacks a requested key. Fails with the
    /// exception raised by attribute access for non-dict rows.
    fn extract_values(&self, names: &[Text]) -> Result<Vec<Option<Vec<AggroValue>>>> {
        Python::with_gil(|py| {
            let l = self.py_list.as_ref(py);
            let mut ret_vec = Vec::with_capacity(l.len());

            for row in l {
                if row.is_none() {
                    ret_vec.push(None);
                    continue;
                }
                // A row's kind does not change between fields, so classify it once.
                let row_dict = row.downcast::<PyDict>().ok();
                let mut row_vec = Vec::with_capacity(names.len());
                for name in names {
                    let val = match row_dict {
                        // Build the KeyError lazily; the old `ok_or` allocated it on every hit.
                        Some(d) => d.get_item(name.as_str()).ok_or_else(|| {
                            PyKeyError::new_err(format!("Could not find {} in parent", name))
                        })?,
                        None => row.getattr(name.as_str())?,
                    };
                    row_vec.push(convert_pyany_to_jupiter(val));
                }
                ret_vec.push(Some(row_vec));
            }
            Ok(ret_vec)
        })
    }

    /// Builds a Python list holding one list of raw field values per row, in the order of
    /// `names`. A `None` row contributes a list of `None`s, one per name.
    ///
    /// # Errors
    ///
    /// Fails if a name is not valid UTF-8, or with a Python `KeyError` if a dict row lacks a
    /// requested key. Also fails with the exception raised by attribute access for non-dict
    /// rows.
    fn extract_py_values(&self, py: Python, names: &[&PyString]) -> PuffResult<PyObject> {
        let l = self.py_list.as_ref(py);
        let final_list = PyList::empty(py);
        for row in l {
            let row_vec = PyList::empty(py);
            if row.is_none() {
                // A missing parent still yields one slot per requested field.
                for _ in names {
                    row_vec.append(py.None())?;
                }
            } else {
                let row_dict = row.downcast::<PyDict>().ok();
                for name in names {
                    let key = name.to_str()?;
                    let val = match row_dict {
                        Some(d) => d.get_item(key).ok_or_else(|| {
                            PyKeyError::new_err(format!("Could not find {} in parent", name))
                        })?,
                        None => row.getattr(key)?,
                    };
                    row_vec.append(val)?;
                }
            }
            final_list.append(row_vec)?;
        }
        Ok(final_list.into_py(py))
    }

    /// Converts each list element as a whole with `convert_pyany_to_jupiter`. Python `None`
    /// elements become null.
    fn extract_first(&self) -> Result<Vec<AggroValue>> {
        Python::with_gil(|py| {
            let l = self.py_list.as_ref(py);
            let mut ret_vec = Vec::with_capacity(l.len());
            for row in l {
                if row.is_none() {
                    ret_vec.push(AggroValue::null());
                    continue;
                }
                ret_vec.push(convert_pyany_to_jupiter(row));
            }
            Ok(ret_vec)
        })
    }
}