Skip to main content

prax_mongodb/
document.rs

1//! Document mapping and conversion utilities.
2
3use bson::{Bson, Document, oid::ObjectId};
4use serde::{Serialize, de::DeserializeOwned};
5
6use crate::error::{MongoError, MongoResult};
7
8/// Extension trait for BSON documents.
9pub trait DocumentExt {
10    /// Get a string value from the document.
11    fn get_str(&self, key: &str) -> MongoResult<&str>;
12
13    /// Get an optional string value.
14    fn get_str_opt(&self, key: &str) -> Option<&str>;
15
16    /// Get an i32 value.
17    fn get_i32(&self, key: &str) -> MongoResult<i32>;
18
19    /// Get an optional i32 value.
20    fn get_i32_opt(&self, key: &str) -> Option<i32>;
21
22    /// Get an i64 value.
23    fn get_i64(&self, key: &str) -> MongoResult<i64>;
24
25    /// Get an optional i64 value.
26    fn get_i64_opt(&self, key: &str) -> Option<i64>;
27
28    /// Get a bool value.
29    fn get_bool(&self, key: &str) -> MongoResult<bool>;
30
31    /// Get an optional bool value.
32    fn get_bool_opt(&self, key: &str) -> Option<bool>;
33
34    /// Get an ObjectId value.
35    fn get_object_id(&self, key: &str) -> MongoResult<ObjectId>;
36
37    /// Get an optional ObjectId value.
38    fn get_object_id_opt(&self, key: &str) -> Option<ObjectId>;
39
40    /// Get a nested document.
41    fn get_document(&self, key: &str) -> MongoResult<&Document>;
42
43    /// Get an optional nested document.
44    fn get_document_opt(&self, key: &str) -> Option<&Document>;
45
46    /// Get an array value.
47    fn get_array(&self, key: &str) -> MongoResult<&Vec<Bson>>;
48
49    /// Get an optional array value.
50    fn get_array_opt(&self, key: &str) -> Option<&Vec<Bson>>;
51
52    /// Convert to a typed struct.
53    fn to_struct<T: DeserializeOwned>(&self) -> MongoResult<T>;
54
55    /// Get the `_id` field as ObjectId.
56    fn id(&self) -> MongoResult<ObjectId>;
57}
58
59impl DocumentExt for Document {
60    fn get_str(&self, key: &str) -> MongoResult<&str> {
61        self.get_str(key)
62            .map_err(|_| MongoError::query(format!("field '{}' is not a string", key)))
63    }
64
65    fn get_str_opt(&self, key: &str) -> Option<&str> {
66        self.get_str(key).ok()
67    }
68
69    fn get_i32(&self, key: &str) -> MongoResult<i32> {
70        self.get_i32(key)
71            .map_err(|_| MongoError::query(format!("field '{}' is not an i32", key)))
72    }
73
74    fn get_i32_opt(&self, key: &str) -> Option<i32> {
75        self.get_i32(key).ok()
76    }
77
78    fn get_i64(&self, key: &str) -> MongoResult<i64> {
79        self.get_i64(key)
80            .map_err(|_| MongoError::query(format!("field '{}' is not an i64", key)))
81    }
82
83    fn get_i64_opt(&self, key: &str) -> Option<i64> {
84        self.get_i64(key).ok()
85    }
86
87    fn get_bool(&self, key: &str) -> MongoResult<bool> {
88        self.get_bool(key)
89            .map_err(|_| MongoError::query(format!("field '{}' is not a bool", key)))
90    }
91
92    fn get_bool_opt(&self, key: &str) -> Option<bool> {
93        self.get_bool(key).ok()
94    }
95
96    fn get_object_id(&self, key: &str) -> MongoResult<ObjectId> {
97        self.get_object_id(key)
98            .map_err(|_| MongoError::query(format!("field '{}' is not an ObjectId", key)))
99    }
100
101    fn get_object_id_opt(&self, key: &str) -> Option<ObjectId> {
102        self.get_object_id(key).ok()
103    }
104
105    fn get_document(&self, key: &str) -> MongoResult<&Document> {
106        self.get_document(key)
107            .map_err(|_| MongoError::query(format!("field '{}' is not a document", key)))
108    }
109
110    fn get_document_opt(&self, key: &str) -> Option<&Document> {
111        self.get_document(key).ok()
112    }
113
114    fn get_array(&self, key: &str) -> MongoResult<&Vec<Bson>> {
115        self.get_array(key)
116            .map_err(|_| MongoError::query(format!("field '{}' is not an array", key)))
117    }
118
119    fn get_array_opt(&self, key: &str) -> Option<&Vec<Bson>> {
120        self.get_array(key).ok()
121    }
122
123    fn to_struct<T: DeserializeOwned>(&self) -> MongoResult<T> {
124        bson::from_document(self.clone()).map_err(|e| MongoError::serialization(e.to_string()))
125    }
126
127    fn id(&self) -> MongoResult<ObjectId> {
128        self.get_object_id("_id")
129            .map_err(|_| MongoError::query("field '_id' is not an ObjectId"))
130    }
131}
132
133/// Convert a struct to a BSON document.
134pub fn to_document<T: Serialize>(value: &T) -> MongoResult<Document> {
135    bson::to_document(value).map_err(|e| MongoError::serialization(e.to_string()))
136}
137
138/// Convert a BSON document to a struct.
139pub fn from_document<T: DeserializeOwned>(doc: Document) -> MongoResult<T> {
140    bson::from_document(doc).map_err(|e| MongoError::serialization(e.to_string()))
141}
142
143/// Parse an ObjectId from a string.
144pub fn parse_object_id(s: &str) -> MongoResult<ObjectId> {
145    ObjectId::parse_str(s).map_err(MongoError::from)
146}
147
148/// Create a new ObjectId.
149pub fn new_object_id() -> ObjectId {
150    ObjectId::new()
151}
152
153/// BSON type helpers.
154pub mod bson_types {
155    use super::*;
156    use chrono::{DateTime, Utc};
157    use uuid::Uuid;
158
159    /// Convert a UUID to BSON Binary.
160    pub fn uuid_to_bson(uuid: Uuid) -> Bson {
161        Bson::Binary(bson::Binary {
162            subtype: bson::spec::BinarySubtype::Uuid,
163            bytes: uuid.as_bytes().to_vec(),
164        })
165    }
166
167    /// Convert BSON Binary to UUID.
168    pub fn bson_to_uuid(bson: &Bson) -> MongoResult<Uuid> {
169        match bson {
170            Bson::Binary(binary) => {
171                let bytes: [u8; 16] = binary
172                    .bytes
173                    .as_slice()
174                    .try_into()
175                    .map_err(|_| MongoError::serialization("invalid UUID bytes"))?;
176                Ok(Uuid::from_bytes(bytes))
177            }
178            Bson::String(s) => Uuid::parse_str(s)
179                .map_err(|e| MongoError::serialization(format!("invalid UUID string: {}", e))),
180            _ => Err(MongoError::serialization(
181                "expected Binary or String for UUID",
182            )),
183        }
184    }
185
186    /// Convert a DateTime to BSON DateTime.
187    pub fn datetime_to_bson(dt: DateTime<Utc>) -> Bson {
188        Bson::DateTime(bson::DateTime::from_chrono(dt))
189    }
190
191    /// Convert BSON DateTime to chrono DateTime.
192    pub fn bson_to_datetime(bson: &Bson) -> MongoResult<DateTime<Utc>> {
193        match bson {
194            Bson::DateTime(dt) => Ok(dt.to_chrono()),
195            _ => Err(MongoError::serialization("expected DateTime")),
196        }
197    }
198
199    /// Get the Prax schema type for a BSON type.
200    pub fn bson_type_to_prax(bson: &Bson) -> &'static str {
201        match bson {
202            Bson::Double(_) => "Float",
203            Bson::String(_) => "String",
204            Bson::Array(_) => "List",
205            Bson::Document(_) => "Json",
206            Bson::Boolean(_) => "Boolean",
207            Bson::Null => "Null",
208            Bson::Int32(_) => "Int",
209            Bson::Int64(_) => "BigInt",
210            Bson::DateTime(_) => "DateTime",
211            Bson::Binary(b) if b.subtype == bson::spec::BinarySubtype::Uuid => "Uuid",
212            Bson::Binary(_) => "Bytes",
213            Bson::ObjectId(_) => "String", // ObjectId maps to String in Prax
214            Bson::Decimal128(_) => "Decimal",
215            _ => "Unknown",
216        }
217    }
218}
219
#[cfg(test)]
mod tests {
    use super::*;
    use bson::doc;
    use uuid::Uuid;

    #[test]
    fn test_document_ext_get_str() {
        let doc = doc! { "name": "Alice", "age": 30 };
        assert_eq!(DocumentExt::get_str(&doc, "name").unwrap(), "Alice");
        assert!(DocumentExt::get_str(&doc, "age").is_err());
        assert!(DocumentExt::get_str(&doc, "missing").is_err());
    }

    #[test]
    fn test_document_ext_get_i32() {
        let doc = doc! { "count": 42, "name": "test" };
        assert_eq!(DocumentExt::get_i32(&doc, "count").unwrap(), 42);
        assert!(DocumentExt::get_i32(&doc, "name").is_err());
    }

    #[test]
    fn test_document_ext_get_object_id() {
        let oid = ObjectId::new();
        let doc = doc! { "_id": oid };
        assert_eq!(DocumentExt::get_object_id(&doc, "_id").unwrap(), oid);
    }

    #[test]
    fn test_document_ext_optional_getters() {
        let doc = doc! {
            "flag": true,
            "big": 7i64,
            "nested": { "x": 1 },
            "items": [1, 2, 3]
        };
        assert_eq!(DocumentExt::get_bool_opt(&doc, "flag"), Some(true));
        assert_eq!(DocumentExt::get_i64_opt(&doc, "big"), Some(7));
        // An Int64 field is not readable as i32.
        assert_eq!(DocumentExt::get_i32_opt(&doc, "big"), None);
        assert!(DocumentExt::get_document_opt(&doc, "nested").is_some());
        assert_eq!(
            DocumentExt::get_array_opt(&doc, "items").map(|a| a.len()),
            Some(3)
        );
        assert_eq!(DocumentExt::get_str_opt(&doc, "missing"), None);
        assert_eq!(DocumentExt::get_object_id_opt(&doc, "flag"), None);
    }

    #[test]
    fn test_document_ext_id() {
        let oid = ObjectId::new();
        let with_id = doc! { "_id": oid };
        assert_eq!(DocumentExt::id(&with_id).unwrap(), oid);

        let without_id = doc! { "name": "x" };
        assert!(DocumentExt::id(&without_id).is_err());

        let wrong_type = doc! { "_id": "not-an-oid" };
        assert!(DocumentExt::id(&wrong_type).is_err());
    }

    #[test]
    fn test_document_ext_to_struct() {
        #[derive(Debug, PartialEq, serde::Deserialize)]
        struct Item {
            sku: String,
            qty: i32,
        }

        let doc = doc! { "sku": "A-1", "qty": 3 };
        let item: Item = doc.to_struct().unwrap();
        assert_eq!(
            item,
            Item {
                sku: "A-1".to_string(),
                qty: 3
            }
        );

        // A missing required field must fail to deserialize.
        let partial = doc! { "sku": "A-1" };
        assert!(partial.to_struct::<Item>().is_err());
    }

    #[test]
    fn test_to_document() {
        #[derive(Serialize)]
        struct User {
            name: String,
            age: i32,
        }

        let user = User {
            name: "Bob".to_string(),
            age: 25,
        };

        let doc = to_document(&user).unwrap();
        assert_eq!(doc.get_str("name").unwrap(), "Bob");
        assert_eq!(doc.get_i32("age").unwrap(), 25);
    }

    #[test]
    fn test_from_document() {
        #[derive(Debug, PartialEq, serde::Deserialize)]
        struct User {
            name: String,
            age: i32,
        }

        let doc = doc! { "name": "Carol", "age": 35 };
        let user: User = from_document(doc).unwrap();
        assert_eq!(
            user,
            User {
                name: "Carol".to_string(),
                age: 35
            }
        );
    }

    #[test]
    fn test_parse_object_id() {
        let oid = new_object_id();
        let parsed = parse_object_id(&oid.to_hex()).unwrap();
        assert_eq!(oid, parsed);

        assert!(parse_object_id("invalid").is_err());
    }

    #[test]
    fn test_uuid_conversion() {
        use bson_types::*;

        let uuid = Uuid::new_v4();
        let bson = uuid_to_bson(uuid);
        let parsed = bson_to_uuid(&bson).unwrap();
        assert_eq!(uuid, parsed);
    }

    #[test]
    fn test_uuid_from_string_and_invalid_inputs() {
        use bson_types::*;

        let uuid = Uuid::new_v4();
        assert_eq!(
            bson_to_uuid(&Bson::String(uuid.to_string())).unwrap(),
            uuid
        );
        assert!(bson_to_uuid(&Bson::String("not-a-uuid".to_string())).is_err());
        assert!(bson_to_uuid(&Bson::Int32(1)).is_err());

        let short = Bson::Binary(bson::Binary {
            subtype: bson::spec::BinarySubtype::Uuid,
            bytes: vec![0u8; 3],
        });
        assert!(bson_to_uuid(&short).is_err());
    }

    #[test]
    fn test_datetime_conversion() {
        use bson_types::*;

        let original = Bson::DateTime(bson::DateTime::from_millis(1_700_000_000_123));
        let dt = bson_to_datetime(&original).unwrap();
        assert_eq!(dt.timestamp_millis(), 1_700_000_000_123);
        assert_eq!(datetime_to_bson(dt), original);

        assert!(bson_to_datetime(&Bson::String("2024-01-01".to_string())).is_err());
    }

    #[test]
    fn test_bson_type_to_prax() {
        use bson_types::*;

        assert_eq!(bson_type_to_prax(&Bson::Double(1.5)), "Float");
        assert_eq!(bson_type_to_prax(&Bson::Int32(1)), "Int");
        assert_eq!(bson_type_to_prax(&Bson::Int64(1)), "BigInt");
        assert_eq!(bson_type_to_prax(&Bson::Null), "Null");
        assert_eq!(bson_type_to_prax(&Bson::ObjectId(ObjectId::new())), "String");
        assert_eq!(bson_type_to_prax(&uuid_to_bson(Uuid::new_v4())), "Uuid");

        let generic = Bson::Binary(bson::Binary {
            subtype: bson::spec::BinarySubtype::Generic,
            bytes: vec![1, 2],
        });
        assert_eq!(bson_type_to_prax(&generic), "Bytes");
        assert_eq!(bson_type_to_prax(&Bson::MinKey), "Unknown");
    }
}