//! pbbson 0.1.0
//!
//! Utilities for pbjson to BSON conversion.
use bson::{Bson, DateTime, Decimal128, Document, Timestamp};
use serde::{Deserialize, Serialize};
use std::borrow::Borrow;
use std::str::FromStr;
use tonic::Status;

/// Newtype wrapper around a BSON [`Document`].
///
/// Deserialization is derived, so a `Model` can be read from any serde
/// format that can produce the wrapped document.
#[derive(Clone, Debug, Default, Deserialize)]
pub struct Model(Document);

impl Model {
    /// Maps an error reported by a typed `Document` accessor to an
    /// `invalid_argument` [`Status`] carrying the error's text.
    fn accessor_error(err: impl ToString) -> Status {
        Status::invalid_argument(err.to_string())
    }

    /// Returns `true` when the wrapped document contains `field`.
    pub fn contains_field(&self, field: impl AsRef<str>) -> bool {
        self.0.contains_key(field.as_ref())
    }

    /// Returns a reference to the raw BSON value stored under `field`, if any.
    pub fn get(&self, field: impl AsRef<str>) -> Option<&Bson> {
        self.0.get(field)
    }

    /// Returns a mutable reference to the raw BSON value stored under `field`, if any.
    pub fn get_mut(&mut self, field: impl AsRef<str>) -> Option<&mut Bson> {
        self.0.get_mut(field.as_ref())
    }

    /// Reads `field` as a `bool`.
    ///
    /// # Errors
    ///
    /// Returns an `invalid_argument` status when the underlying document
    /// accessor reports an error.
    pub fn get_bool(&self, field: impl AsRef<str>) -> Result<bool, Status> {
        self.0.get_bool(field).map_err(Self::accessor_error)
    }

    /// Reads `field` as a BSON [`DateTime`].
    ///
    /// # Errors
    ///
    /// Returns an `invalid_argument` status when the underlying document
    /// accessor reports an error.
    pub fn get_datetime(&self, field: impl AsRef<str>) -> Result<&DateTime, Status> {
        self.0.get_datetime(field).map_err(Self::accessor_error)
    }

    /// Reads `field` as a BSON [`Decimal128`].
    ///
    /// # Errors
    ///
    /// Returns an `invalid_argument` status when the underlying document
    /// accessor reports an error.
    pub fn get_decimal128(&self, field: impl AsRef<str>) -> Result<&Decimal128, Status> {
        self.0.get_decimal128(field).map_err(Self::accessor_error)
    }

    /// Reads `field` as an `f64`.
    ///
    /// # Errors
    ///
    /// Returns an `invalid_argument` status when the underlying document
    /// accessor reports an error.
    pub fn get_f64(&self, field: impl AsRef<str>) -> Result<f64, Status> {
        self.0.get_f64(field).map_err(Self::accessor_error)
    }

    /// Reads `field` as an `i32`.
    ///
    /// # Errors
    ///
    /// Returns an `invalid_argument` status when the underlying document
    /// accessor reports an error.
    pub fn get_i32(&self, field: impl AsRef<str>) -> Result<i32, Status> {
        self.0.get_i32(field).map_err(Self::accessor_error)
    }

    /// Reads `field` as an `i64`.
    ///
    /// # Errors
    ///
    /// Returns an `invalid_argument` status when the underlying document
    /// accessor reports an error.
    pub fn get_i64(&self, field: impl AsRef<str>) -> Result<i64, Status> {
        self.0.get_i64(field).map_err(Self::accessor_error)
    }

    /// Reads `field` as a string slice.
    ///
    /// # Errors
    ///
    /// Returns an `invalid_argument` status when the underlying document
    /// accessor reports an error.
    pub fn get_str(&self, field: impl AsRef<str>) -> Result<&str, Status> {
        self.0.get_str(field).map_err(Self::accessor_error)
    }

    /// Reads `field` as a BSON [`Timestamp`].
    ///
    /// # Errors
    ///
    /// Returns an `invalid_argument` status when the underlying document
    /// accessor reports an error.
    pub fn get_timestamp(&self, field: impl AsRef<str>) -> Result<Timestamp, Status> {
        self.0.get_timestamp(field).map_err(Self::accessor_error)
    }

    /// Returns the value of the `id` field as a string.
    ///
    /// An `ObjectId` value is rendered through `try_from_object_id`; a string
    /// value is returned as a copy.
    ///
    /// # Errors
    ///
    /// Returns a `not_found` status when `id` is absent or holds any other
    /// BSON type, and propagates the error of `try_from_object_id`.
    // NOTE(review): only the `id` key is consulted here, while `try_into`
    // renames `_id` to `id` — confirm whether `_id` should be looked up too.
    pub fn id(&self) -> Result<String, Status> {
        match self.0.get("id") {
            Some(Bson::ObjectId(object_id)) => try_from_object_id(object_id),
            Some(Bson::String(text)) => Ok(text.clone()),
            _ => Err(Status::not_found("No such field")),
        }
    }

    /// Stores `value` under `field`, returning the value it replaced, if any.
    pub fn insert<KT: Into<String>, BT: Into<Bson>>(&mut self, field: KT, value: BT) -> Option<Bson> {
        self.0.insert(field, value)
    }

    /// Forwards to `Document::is_null` for `field`.
    pub fn is_null(&self, field: impl AsRef<str>) -> bool {
        self.0.is_null(field)
    }

    /// Removes `field` from the document, returning its value if it was present.
    pub fn remove(&mut self, field: impl AsRef<str>) -> Option<Bson> {
        self.0.remove(field.as_ref())
    }

    /// Stores `value` under `field` as a BSON date-time, discarding any
    /// previous value.
    pub fn set_datetime(&mut self, field: &str, value: DateTime) {
        self.0.insert(field, value);
    }

    /// Builds a model from a serializable protobuf message.
    ///
    /// The message is serialized to JSON and read back as a document. Two
    /// top-level translations are then applied:
    ///
    /// * `id`: an empty string is removed; any other string is converted to
    ///   an `ObjectId` via `try_to_object_id`.
    /// * every key ending in `At`: the RFC 3339 string is replaced by a BSON
    ///   [`DateTime`].
    ///
    /// # Errors
    ///
    /// Returns an `internal` status when the JSON round trip fails, and an
    /// `invalid_argument` status when `id` or an `...At` field is not a
    /// string, the id is not a valid XID, or a timestamp is not valid
    /// RFC 3339.
    pub fn try_from<T: prost::Message + Serialize>(other: &T) -> Result<Self, Status> {
        let buf = serde_json::to_vec(&other).map_err(|e| Status::internal(e.to_string()))?;
        let mut model: Model =
            serde_json::from_slice(&buf).map_err(|e| Status::internal(e.to_string()))?;

        // Snapshot only the keys so the document can be mutated while walking them.
        let keys: Vec<String> = model.0.keys().cloned().collect();
        for key in &keys {
            if key == "id" {
                let value = model
                    .0
                    .get_str(key)
                    .map_err(|e| Status::invalid_argument(format!("field `{key}`: {e}")))?;
                if value.is_empty() {
                    model.0.remove(key.as_str());
                } else {
                    let object_id = try_to_object_id(value)?;
                    model.0.insert(key.as_str(), object_id);
                }
            } else if key.ends_with("At") {
                let value = model
                    .0
                    .get_str(key)
                    .map_err(|e| Status::invalid_argument(format!("field `{key}`: {e}")))?;
                let date_time = DateTime::parse_rfc3339_str(value)
                    .map_err(|e| Status::invalid_argument(format!("field `{key}`: {e}")))?;
                model.0.insert(key.as_str(), date_time);
            }
        }
        Ok(model)
    }

    /// Converts the model into the protobuf message `T`.
    ///
    /// Every top-level value is passed through [`to_pbjson`], the `_id` key
    /// is renamed to `id`, and the resulting document is serialized to JSON
    /// and deserialized as `T`.
    ///
    /// The `Clone` and `Default` bounds are not needed by the body; they are
    /// kept so the signature stays unchanged for existing callers.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`to_pbjson`] and returns an `internal`
    /// status when the JSON round trip fails.
    pub fn try_into<T: prost::Message + Clone + Default + for<'a> Deserialize<'a>>(self) -> Result<T, Status> {
        let mut doc = self.0;

        // Snapshot only the keys so the document can be mutated while walking them.
        let keys: Vec<String> = doc.keys().cloned().collect();
        for key in keys {
            if key == "_id" {
                if let Some(raw) = doc.remove(key.as_str()) {
                    let converted = to_pbjson(raw)?;
                    doc.insert("id", converted);
                }
            } else if let Some(slot) = doc.get_mut(key.as_str()) {
                // Move the value out instead of cloning it; on error the
                // whole document is dropped, so the placeholder is never seen.
                let raw = std::mem::replace(slot, Bson::Null);
                *slot = to_pbjson(raw)?;
            }
        }

        let buf = serde_json::to_vec(&doc).map_err(|e| Status::internal(e.to_string()))?;
        serde_json::from_slice::<T>(&buf).map_err(|e| Status::internal(e.to_string()))
    }
}

/// Rewrites a BSON value into the form used on the pbjson side.
///
/// * `ObjectId` becomes a string produced by `try_from_object_id`.
/// * `DateTime` becomes its RFC 3339 string.
/// * `Array` is converted element by element, recursively.
/// * Every other variant, including nested documents, is returned unchanged.
///
/// # Errors
///
/// Propagates the error of `try_from_object_id`, and returns an `internal`
/// status when a date-time cannot be formatted as RFC 3339. For arrays, the
/// first failing element aborts the conversion.
pub fn to_pbjson(value: Bson) -> Result<Bson, Status> {
    match value {
        Bson::ObjectId(object_id) => try_from_object_id(&object_id).map(Bson::String),
        Bson::DateTime(date_time) => date_time
            .try_to_rfc3339_string()
            .map(Bson::String)
            .map_err(|e| Status::internal(e.to_string())),
        // Elements are owned here, so they are moved into the recursive call.
        Bson::Array(array) => array
            .into_iter()
            .map(to_pbjson)
            .collect::<Result<Vec<_>, _>>()
            .map(Bson::Array),
        other => Ok(other),
    }
}

impl From<Document> for Model {
    fn from(other: Document) -> Self {
        Model(other)
    }
}

/// Renders the bytes of an `ObjectId` as the string form of an XID.
///
/// Returns an `internal` status when the bytes are rejected by `xid`.
fn try_from_object_id(id: &bson::oid::ObjectId) -> Result<String, Status> {
    match xid::Id::from_bytes(&id.bytes()) {
        Ok(as_xid) => Ok(as_xid.to_string()),
        Err(e) => Err(Status::internal(e.to_string())),
    }
}

/// Parses the string form of an XID and reinterprets its bytes as an `ObjectId`.
///
/// The length is checked first by `require_xid_base32hex`; both that check
/// and a parse failure yield an `invalid_argument` status.
fn try_to_object_id(id: &str) -> Result<bson::oid::ObjectId, Status> {
    const WHAT: &str = "IDs";
    require_xid_base32hex(WHAT, id)?;
    let parsed = xid::Id::from_str(id).map_err(|e| {
        Status::invalid_argument(format!(
            "{WHAT} must be 12-byte XID's in 20-byte base32hex encoded form ({e:?})"
        ))
    })?;
    Ok(bson::oid::ObjectId::from_bytes(*parsed.as_bytes()))
}

/// Checks that `id` is exactly 20 bytes long.
///
/// `what` names the kind of value being checked and is used as the start of
/// the `invalid_argument` message returned on a length mismatch.
fn require_xid_base32hex(what: &str, id: &str) -> Result<(), Status> {
    const ENCODED_LEN: usize = 20;
    match id.len() {
        ENCODED_LEN => Ok(()),
        _ => Err(Status::invalid_argument(format!(
            "{what} must be 12-byte XID's in 20-byte base32hex encoded form"
        ))),
    }
}

/// Exposes the wrapped document by reference.
impl Borrow<Document> for Model {
    fn borrow(&self) -> &Document {
        &self.0
    }
}

#[cfg(test)]
mod tests {
    use super::Model;
    use tonic::Code;

    /// An empty model has no `id`, so the lookup fails with `NotFound`.
    #[test]
    fn can_access_id_when_none() {
        let empty = Model::default();
        let outcome = empty.id();
        assert!(outcome.is_err());
        assert_eq!(outcome.err().unwrap().code(), Code::NotFound);
    }

    /// A model read from JSON exposes its fields through the typed getters.
    #[test]
    fn can_deser() {
        let payload = r#"{"id":"d05vu4r71n0pfhlalahg","firstName":"Joe","age":23}"#;
        let parsed = serde_json::from_str::<Model>(payload).unwrap();
        assert_eq!(parsed.id().unwrap(), "d05vu4r71n0pfhlalahg");
        assert_eq!(parsed.get_str("firstName").unwrap(), "Joe");
        assert_eq!(parsed.get_i32("age").unwrap(), 23);
    }
}