Skip to main content

openauth_core/api/
additional_fields.rs

1use std::collections::BTreeMap;
2
3use serde_json::Value;
4
5use crate::db::{DbAdapter, DbFieldType, DbRecord, DbValue, FindOne, User, Where};
6use crate::error::OpenAuthError;
7use crate::options::{SessionAdditionalField, UserAdditionalField};
8
9pub trait AdditionalField {
10    fn field_type(&self) -> &DbFieldType;
11    fn required(&self) -> bool;
12    fn input(&self) -> bool;
13    fn returned(&self) -> bool;
14    fn default_value(&self) -> Option<&DbValue>;
15    fn db_name(&self) -> Option<&str>;
16}
17
18impl AdditionalField for UserAdditionalField {
19    fn field_type(&self) -> &DbFieldType {
20        &self.field_type
21    }
22
23    fn required(&self) -> bool {
24        self.required
25    }
26
27    fn input(&self) -> bool {
28        self.input
29    }
30
31    fn returned(&self) -> bool {
32        self.returned
33    }
34
35    fn default_value(&self) -> Option<&DbValue> {
36        self.default_value.as_ref()
37    }
38
39    fn db_name(&self) -> Option<&str> {
40        self.db_name.as_deref()
41    }
42}
43
44impl AdditionalField for SessionAdditionalField {
45    fn field_type(&self) -> &DbFieldType {
46        &self.field_type
47    }
48
49    fn required(&self) -> bool {
50        self.required
51    }
52
53    fn input(&self) -> bool {
54        self.input
55    }
56
57    fn returned(&self) -> bool {
58        self.returned
59    }
60
61    fn default_value(&self) -> Option<&DbValue> {
62        self.default_value.as_ref()
63    }
64
65    fn db_name(&self) -> Option<&str> {
66        self.db_name.as_deref()
67    }
68}
69
70pub fn create_values<F>(
71    fields: &BTreeMap<String, F>,
72    body: &serde_json::Map<String, Value>,
73) -> Result<DbRecord, AdditionalFieldError>
74where
75    F: AdditionalField,
76{
77    let mut values = DbRecord::new();
78    for (name, field) in fields {
79        match body.get(name) {
80            Some(value) => {
81                if !field.input() {
82                    return Err(AdditionalFieldError::NotInput(name.clone()));
83                }
84                values.insert(
85                    storage_name(name, field),
86                    json_to_db_value(name, field.field_type(), value)
87                        .map_err(AdditionalFieldError::InvalidType)?,
88                );
89            }
90            None => {
91                if let Some(value) = field.default_value() {
92                    values.insert(storage_name(name, field), value.clone());
93                } else if field.required() {
94                    return Err(AdditionalFieldError::MissingRequired(name.clone()));
95                } else {
96                    values.insert(storage_name(name, field), DbValue::Null);
97                }
98            }
99        }
100    }
101    Ok(values)
102}
103
104pub fn update_values<F>(
105    fields: &BTreeMap<String, F>,
106    body: &serde_json::Map<String, Value>,
107) -> Result<DbRecord, AdditionalFieldError>
108where
109    F: AdditionalField,
110{
111    let mut values = DbRecord::new();
112    for (name, value) in body {
113        let Some(field) = fields.get(name) else {
114            continue;
115        };
116        if !field.input() {
117            return Err(AdditionalFieldError::NotInput(name.clone()));
118        }
119        values.insert(
120            storage_name(name, field),
121            json_to_db_value(name, field.field_type(), value)
122                .map_err(AdditionalFieldError::InvalidType)?,
123        );
124    }
125    Ok(values)
126}
127
128pub fn insert_returned_fields<F>(
129    object: &mut serde_json::Map<String, Value>,
130    fields: &BTreeMap<String, F>,
131    record: &DbRecord,
132) -> Result<(), OpenAuthError>
133where
134    F: AdditionalField,
135{
136    for (name, field) in fields {
137        if !field.returned() {
138            continue;
139        }
140        let value = record
141            .get(name)
142            .or_else(|| field.db_name().and_then(|db_name| record.get(db_name)))
143            .or_else(|| field.default_value())
144            .unwrap_or(&DbValue::Null);
145        object.insert(name.clone(), db_value_to_json(value)?);
146    }
147    Ok(())
148}
149
150fn storage_name<F>(logical_name: &str, field: &F) -> String
151where
152    F: AdditionalField,
153{
154    field
155        .db_name()
156        .map(str::to_owned)
157        .unwrap_or_else(|| logical_name.to_owned())
158}
159
160pub fn db_value_to_json(value: &DbValue) -> Result<Value, OpenAuthError> {
161    match value {
162        DbValue::String(value) => Ok(Value::String(value.clone())),
163        DbValue::Number(value) => Ok(Value::Number((*value).into())),
164        DbValue::Boolean(value) => Ok(Value::Bool(*value)),
165        DbValue::Timestamp(value) => {
166            serde_json::to_value(value).map_err(|error| OpenAuthError::Serialization {
167                context: "serializing additional field timestamp",
168                message: error.to_string(),
169            })
170        }
171        DbValue::Json(value) => Ok(value.clone()),
172        DbValue::StringArray(values) => Ok(Value::Array(
173            values.iter().cloned().map(Value::String).collect(),
174        )),
175        DbValue::NumberArray(values) => Ok(Value::Array(
176            values
177                .iter()
178                .map(|value| Value::Number((*value).into()))
179                .collect(),
180        )),
181        DbValue::Record(record) => db_record_to_json(record),
182        DbValue::RecordArray(records) => records
183            .iter()
184            .map(db_record_to_json)
185            .collect::<Result<Vec<_>, _>>()
186            .map(Value::Array),
187        DbValue::Null => Ok(Value::Null),
188    }
189}
190
191pub fn json_to_db_value(
192    name: &str,
193    field_type: &DbFieldType,
194    value: &Value,
195) -> Result<DbValue, String> {
196    if value.is_null() {
197        return Ok(DbValue::Null);
198    }
199    match field_type {
200        DbFieldType::String => value
201            .as_str()
202            .map(|value| DbValue::String(value.to_owned())),
203        DbFieldType::Number => value.as_i64().map(DbValue::Number),
204        DbFieldType::Boolean => value.as_bool().map(DbValue::Boolean),
205        DbFieldType::Json => Some(DbValue::Json(value.clone())),
206        DbFieldType::StringArray => value.as_array().and_then(|values| {
207            values
208                .iter()
209                .map(|value| value.as_str().map(str::to_owned))
210                .collect::<Option<Vec<_>>>()
211                .map(DbValue::StringArray)
212        }),
213        DbFieldType::NumberArray => value.as_array().and_then(|values| {
214            values
215                .iter()
216                .map(Value::as_i64)
217                .collect::<Option<Vec<_>>>()
218                .map(DbValue::NumberArray)
219        }),
220        DbFieldType::Timestamp => None,
221    }
222    .ok_or_else(|| format!("invalid value for additional field `{name}`"))
223}
224
/// Validation failure produced while converting a request body into additional-field values.
///
/// [`MissingRequired`](Self::MissingRequired) and [`NotInput`](Self::NotInput) carry the
/// logical field name. [`InvalidType`](Self::InvalidType) carries the complete
/// human-readable message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AdditionalFieldError {
    /// A required field was absent from a create request and has no default value.
    MissingRequired(String),
    /// The request supplied a value for a field that does not accept input.
    NotInput(String),
    /// A supplied value could not be converted to the field's type (full message).
    InvalidType(String),
}

impl AdditionalFieldError {
    /// Human-readable description of the error. The `Display` impl produces the same text.
    pub fn message(&self) -> String {
        match self {
            Self::MissingRequired(name) => format!("missing required additional field `{name}`"),
            Self::NotInput(name) => format!("additional field `{name}` is not accepted as input"),
            Self::InvalidType(message) => message.clone(),
        }
    }
}

impl std::fmt::Display for AdditionalFieldError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.message())
    }
}

// `Debug` + `Display` are all `Error` requires. Implementing it lets callers box this type
// as `dyn Error` or convert it with `?` like any other library error.
impl std::error::Error for AdditionalFieldError {}
241
242pub async fn user_response_value(
243    adapter: &dyn DbAdapter,
244    fields: &BTreeMap<String, UserAdditionalField>,
245    user: &User,
246) -> Result<Value, OpenAuthError> {
247    if fields.is_empty() {
248        return serde_json::to_value(user).map_err(|error| OpenAuthError::Serialization {
249            context: "serializing user output",
250            message: error.to_string(),
251        });
252    }
253    let record = adapter
254        .find_one(
255            FindOne::new("user").where_clause(Where::new("id", DbValue::String(user.id.clone()))),
256        )
257        .await?;
258    let mut value = serde_json::to_value(user).map_err(|error| OpenAuthError::Serialization {
259        context: "serializing user output",
260        message: error.to_string(),
261    })?;
262    let Some(object) = value.as_object_mut() else {
263        return Err(OpenAuthError::Serialization {
264            context: "serializing user output",
265            message: "expected JSON object".to_owned(),
266        });
267    };
268    if let Some(record) = record {
269        insert_returned_fields(object, fields, &record)?;
270    }
271    Ok(value)
272}
273
274fn db_record_to_json(record: &DbRecord) -> Result<Value, OpenAuthError> {
275    record
276        .iter()
277        .map(|(field, value)| db_value_to_json(value).map(|value| (field.clone(), value)))
278        .collect::<Result<serde_json::Map<_, _>, _>>()
279        .map(Value::Object)
280}