1use std::collections::BTreeMap;
2
3use serde_json::Value;
4
5use crate::db::{DbAdapter, DbFieldType, DbRecord, DbValue, FindOne, User, Where};
6use crate::error::OpenAuthError;
7use crate::options::{SessionAdditionalField, UserAdditionalField};
8
9pub trait AdditionalField {
10 fn field_type(&self) -> &DbFieldType;
11 fn required(&self) -> bool;
12 fn input(&self) -> bool;
13 fn returned(&self) -> bool;
14 fn default_value(&self) -> Option<&DbValue>;
15 fn db_name(&self) -> Option<&str>;
16}
17
18impl AdditionalField for UserAdditionalField {
19 fn field_type(&self) -> &DbFieldType {
20 &self.field_type
21 }
22
23 fn required(&self) -> bool {
24 self.required
25 }
26
27 fn input(&self) -> bool {
28 self.input
29 }
30
31 fn returned(&self) -> bool {
32 self.returned
33 }
34
35 fn default_value(&self) -> Option<&DbValue> {
36 self.default_value.as_ref()
37 }
38
39 fn db_name(&self) -> Option<&str> {
40 self.db_name.as_deref()
41 }
42}
43
44impl AdditionalField for SessionAdditionalField {
45 fn field_type(&self) -> &DbFieldType {
46 &self.field_type
47 }
48
49 fn required(&self) -> bool {
50 self.required
51 }
52
53 fn input(&self) -> bool {
54 self.input
55 }
56
57 fn returned(&self) -> bool {
58 self.returned
59 }
60
61 fn default_value(&self) -> Option<&DbValue> {
62 self.default_value.as_ref()
63 }
64
65 fn db_name(&self) -> Option<&str> {
66 self.db_name.as_deref()
67 }
68}
69
70pub fn create_values<F>(
71 fields: &BTreeMap<String, F>,
72 body: &serde_json::Map<String, Value>,
73) -> Result<DbRecord, AdditionalFieldError>
74where
75 F: AdditionalField,
76{
77 let mut values = DbRecord::new();
78 for (name, field) in fields {
79 match body.get(name) {
80 Some(value) => {
81 if !field.input() {
82 return Err(AdditionalFieldError::NotInput(name.clone()));
83 }
84 values.insert(
85 storage_name(name, field),
86 json_to_db_value(name, field.field_type(), value)
87 .map_err(AdditionalFieldError::InvalidType)?,
88 );
89 }
90 None => {
91 if let Some(value) = field.default_value() {
92 values.insert(storage_name(name, field), value.clone());
93 } else if field.required() {
94 return Err(AdditionalFieldError::MissingRequired(name.clone()));
95 } else {
96 values.insert(storage_name(name, field), DbValue::Null);
97 }
98 }
99 }
100 }
101 Ok(values)
102}
103
104pub fn update_values<F>(
105 fields: &BTreeMap<String, F>,
106 body: &serde_json::Map<String, Value>,
107) -> Result<DbRecord, AdditionalFieldError>
108where
109 F: AdditionalField,
110{
111 let mut values = DbRecord::new();
112 for (name, value) in body {
113 let Some(field) = fields.get(name) else {
114 continue;
115 };
116 if !field.input() {
117 return Err(AdditionalFieldError::NotInput(name.clone()));
118 }
119 values.insert(
120 storage_name(name, field),
121 json_to_db_value(name, field.field_type(), value)
122 .map_err(AdditionalFieldError::InvalidType)?,
123 );
124 }
125 Ok(values)
126}
127
128pub fn insert_returned_fields<F>(
129 object: &mut serde_json::Map<String, Value>,
130 fields: &BTreeMap<String, F>,
131 record: &DbRecord,
132) -> Result<(), OpenAuthError>
133where
134 F: AdditionalField,
135{
136 for (name, field) in fields {
137 if !field.returned() {
138 continue;
139 }
140 let value = record
141 .get(name)
142 .or_else(|| field.db_name().and_then(|db_name| record.get(db_name)))
143 .or_else(|| field.default_value())
144 .unwrap_or(&DbValue::Null);
145 object.insert(name.clone(), db_value_to_json(value)?);
146 }
147 Ok(())
148}
149
150fn storage_name<F>(logical_name: &str, field: &F) -> String
151where
152 F: AdditionalField,
153{
154 field
155 .db_name()
156 .map(str::to_owned)
157 .unwrap_or_else(|| logical_name.to_owned())
158}
159
160pub fn db_value_to_json(value: &DbValue) -> Result<Value, OpenAuthError> {
161 match value {
162 DbValue::String(value) => Ok(Value::String(value.clone())),
163 DbValue::Number(value) => Ok(Value::Number((*value).into())),
164 DbValue::Boolean(value) => Ok(Value::Bool(*value)),
165 DbValue::Timestamp(value) => {
166 serde_json::to_value(value).map_err(|error| OpenAuthError::Serialization {
167 context: "serializing additional field timestamp",
168 message: error.to_string(),
169 })
170 }
171 DbValue::Json(value) => Ok(value.clone()),
172 DbValue::StringArray(values) => Ok(Value::Array(
173 values.iter().cloned().map(Value::String).collect(),
174 )),
175 DbValue::NumberArray(values) => Ok(Value::Array(
176 values
177 .iter()
178 .map(|value| Value::Number((*value).into()))
179 .collect(),
180 )),
181 DbValue::Record(record) => db_record_to_json(record),
182 DbValue::RecordArray(records) => records
183 .iter()
184 .map(db_record_to_json)
185 .collect::<Result<Vec<_>, _>>()
186 .map(Value::Array),
187 DbValue::Null => Ok(Value::Null),
188 }
189}
190
191pub fn json_to_db_value(
192 name: &str,
193 field_type: &DbFieldType,
194 value: &Value,
195) -> Result<DbValue, String> {
196 if value.is_null() {
197 return Ok(DbValue::Null);
198 }
199 match field_type {
200 DbFieldType::String => value
201 .as_str()
202 .map(|value| DbValue::String(value.to_owned())),
203 DbFieldType::Number => value.as_i64().map(DbValue::Number),
204 DbFieldType::Boolean => value.as_bool().map(DbValue::Boolean),
205 DbFieldType::Json => Some(DbValue::Json(value.clone())),
206 DbFieldType::StringArray => value.as_array().and_then(|values| {
207 values
208 .iter()
209 .map(|value| value.as_str().map(str::to_owned))
210 .collect::<Option<Vec<_>>>()
211 .map(DbValue::StringArray)
212 }),
213 DbFieldType::NumberArray => value.as_array().and_then(|values| {
214 values
215 .iter()
216 .map(Value::as_i64)
217 .collect::<Option<Vec<_>>>()
218 .map(DbValue::NumberArray)
219 }),
220 DbFieldType::Timestamp => None,
221 }
222 .ok_or_else(|| format!("invalid value for additional field `{name}`"))
223}
224
/// Validation failure raised while mapping a JSON request body onto additional fields.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AdditionalFieldError {
    /// A required field had no value and no default to fall back on; holds the field name.
    MissingRequired(String),
    /// The request supplied a field that does not accept input; holds the field name.
    NotInput(String),
    /// A supplied value did not match the field's type; holds the complete message.
    InvalidType(String),
}

impl AdditionalFieldError {
    /// Human-readable description of the error (identical to the `Display` output).
    pub fn message(&self) -> String {
        self.to_string()
    }
}

impl std::fmt::Display for AdditionalFieldError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::MissingRequired(name) => {
                write!(f, "missing required additional field `{name}`")
            }
            Self::NotInput(name) => {
                write!(f, "additional field `{name}` is not accepted as input")
            }
            Self::InvalidType(message) => f.write_str(message),
        }
    }
}

// Lets the error flow through `?` into `Box<dyn Error>`-style callers.
impl std::error::Error for AdditionalFieldError {}
241
242pub async fn user_response_value(
243 adapter: &dyn DbAdapter,
244 fields: &BTreeMap<String, UserAdditionalField>,
245 user: &User,
246) -> Result<Value, OpenAuthError> {
247 if fields.is_empty() {
248 return serde_json::to_value(user).map_err(|error| OpenAuthError::Serialization {
249 context: "serializing user output",
250 message: error.to_string(),
251 });
252 }
253 let record = adapter
254 .find_one(
255 FindOne::new("user").where_clause(Where::new("id", DbValue::String(user.id.clone()))),
256 )
257 .await?;
258 let mut value = serde_json::to_value(user).map_err(|error| OpenAuthError::Serialization {
259 context: "serializing user output",
260 message: error.to_string(),
261 })?;
262 let Some(object) = value.as_object_mut() else {
263 return Err(OpenAuthError::Serialization {
264 context: "serializing user output",
265 message: "expected JSON object".to_owned(),
266 });
267 };
268 if let Some(record) = record {
269 insert_returned_fields(object, fields, &record)?;
270 }
271 Ok(value)
272}
273
274fn db_record_to_json(record: &DbRecord) -> Result<Value, OpenAuthError> {
275 record
276 .iter()
277 .map(|(field, value)| db_value_to_json(value).map(|value| (field.clone(), value)))
278 .collect::<Result<serde_json::Map<_, _>, _>>()
279 .map(Value::Object)
280}