1use bson::{Bson, Document, oid::ObjectId};
4use serde::{Serialize, de::DeserializeOwned};
5
6use crate::error::{MongoError, MongoResult};
7
8pub trait DocumentExt {
10 fn get_str(&self, key: &str) -> MongoResult<&str>;
12
13 fn get_str_opt(&self, key: &str) -> Option<&str>;
15
16 fn get_i32(&self, key: &str) -> MongoResult<i32>;
18
19 fn get_i32_opt(&self, key: &str) -> Option<i32>;
21
22 fn get_i64(&self, key: &str) -> MongoResult<i64>;
24
25 fn get_i64_opt(&self, key: &str) -> Option<i64>;
27
28 fn get_bool(&self, key: &str) -> MongoResult<bool>;
30
31 fn get_bool_opt(&self, key: &str) -> Option<bool>;
33
34 fn get_object_id(&self, key: &str) -> MongoResult<ObjectId>;
36
37 fn get_object_id_opt(&self, key: &str) -> Option<ObjectId>;
39
40 fn get_document(&self, key: &str) -> MongoResult<&Document>;
42
43 fn get_document_opt(&self, key: &str) -> Option<&Document>;
45
46 fn get_array(&self, key: &str) -> MongoResult<&Vec<Bson>>;
48
49 fn get_array_opt(&self, key: &str) -> Option<&Vec<Bson>>;
51
52 fn to_struct<T: DeserializeOwned>(&self) -> MongoResult<T>;
54
55 fn id(&self) -> MongoResult<ObjectId>;
57}
58
59impl DocumentExt for Document {
60 fn get_str(&self, key: &str) -> MongoResult<&str> {
61 self.get_str(key)
62 .map_err(|_| MongoError::query(format!("field '{}' is not a string", key)))
63 }
64
65 fn get_str_opt(&self, key: &str) -> Option<&str> {
66 self.get_str(key).ok()
67 }
68
69 fn get_i32(&self, key: &str) -> MongoResult<i32> {
70 self.get_i32(key)
71 .map_err(|_| MongoError::query(format!("field '{}' is not an i32", key)))
72 }
73
74 fn get_i32_opt(&self, key: &str) -> Option<i32> {
75 self.get_i32(key).ok()
76 }
77
78 fn get_i64(&self, key: &str) -> MongoResult<i64> {
79 self.get_i64(key)
80 .map_err(|_| MongoError::query(format!("field '{}' is not an i64", key)))
81 }
82
83 fn get_i64_opt(&self, key: &str) -> Option<i64> {
84 self.get_i64(key).ok()
85 }
86
87 fn get_bool(&self, key: &str) -> MongoResult<bool> {
88 self.get_bool(key)
89 .map_err(|_| MongoError::query(format!("field '{}' is not a bool", key)))
90 }
91
92 fn get_bool_opt(&self, key: &str) -> Option<bool> {
93 self.get_bool(key).ok()
94 }
95
96 fn get_object_id(&self, key: &str) -> MongoResult<ObjectId> {
97 self.get_object_id(key)
98 .map_err(|_| MongoError::query(format!("field '{}' is not an ObjectId", key)))
99 }
100
101 fn get_object_id_opt(&self, key: &str) -> Option<ObjectId> {
102 self.get_object_id(key).ok()
103 }
104
105 fn get_document(&self, key: &str) -> MongoResult<&Document> {
106 self.get_document(key)
107 .map_err(|_| MongoError::query(format!("field '{}' is not a document", key)))
108 }
109
110 fn get_document_opt(&self, key: &str) -> Option<&Document> {
111 self.get_document(key).ok()
112 }
113
114 fn get_array(&self, key: &str) -> MongoResult<&Vec<Bson>> {
115 self.get_array(key)
116 .map_err(|_| MongoError::query(format!("field '{}' is not an array", key)))
117 }
118
119 fn get_array_opt(&self, key: &str) -> Option<&Vec<Bson>> {
120 self.get_array(key).ok()
121 }
122
123 fn to_struct<T: DeserializeOwned>(&self) -> MongoResult<T> {
124 bson::from_document(self.clone()).map_err(|e| MongoError::serialization(e.to_string()))
125 }
126
127 fn id(&self) -> MongoResult<ObjectId> {
128 self.get_object_id("_id")
129 .map_err(|_| MongoError::query("field '_id' is not an ObjectId"))
130 }
131}
132
133pub fn to_document<T: Serialize>(value: &T) -> MongoResult<Document> {
135 bson::to_document(value).map_err(|e| MongoError::serialization(e.to_string()))
136}
137
138pub fn from_document<T: DeserializeOwned>(doc: Document) -> MongoResult<T> {
140 bson::from_document(doc).map_err(|e| MongoError::serialization(e.to_string()))
141}
142
143pub fn parse_object_id(s: &str) -> MongoResult<ObjectId> {
145 ObjectId::parse_str(s).map_err(MongoError::from)
146}
147
148pub fn new_object_id() -> ObjectId {
150 ObjectId::new()
151}
152
153pub mod bson_types {
155 use super::*;
156 use chrono::{DateTime, Utc};
157 use uuid::Uuid;
158
159 pub fn uuid_to_bson(uuid: Uuid) -> Bson {
161 Bson::Binary(bson::Binary {
162 subtype: bson::spec::BinarySubtype::Uuid,
163 bytes: uuid.as_bytes().to_vec(),
164 })
165 }
166
167 pub fn bson_to_uuid(bson: &Bson) -> MongoResult<Uuid> {
169 match bson {
170 Bson::Binary(binary) => {
171 let bytes: [u8; 16] = binary
172 .bytes
173 .as_slice()
174 .try_into()
175 .map_err(|_| MongoError::serialization("invalid UUID bytes"))?;
176 Ok(Uuid::from_bytes(bytes))
177 }
178 Bson::String(s) => Uuid::parse_str(s)
179 .map_err(|e| MongoError::serialization(format!("invalid UUID string: {}", e))),
180 _ => Err(MongoError::serialization(
181 "expected Binary or String for UUID",
182 )),
183 }
184 }
185
186 pub fn datetime_to_bson(dt: DateTime<Utc>) -> Bson {
188 Bson::DateTime(bson::DateTime::from_chrono(dt))
189 }
190
191 pub fn bson_to_datetime(bson: &Bson) -> MongoResult<DateTime<Utc>> {
193 match bson {
194 Bson::DateTime(dt) => Ok(dt.to_chrono()),
195 _ => Err(MongoError::serialization("expected DateTime")),
196 }
197 }
198
199 pub fn bson_type_to_prax(bson: &Bson) -> &'static str {
201 match bson {
202 Bson::Double(_) => "Float",
203 Bson::String(_) => "String",
204 Bson::Array(_) => "List",
205 Bson::Document(_) => "Json",
206 Bson::Boolean(_) => "Boolean",
207 Bson::Null => "Null",
208 Bson::Int32(_) => "Int",
209 Bson::Int64(_) => "BigInt",
210 Bson::DateTime(_) => "DateTime",
211 Bson::Binary(b) if b.subtype == bson::spec::BinarySubtype::Uuid => "Uuid",
212 Bson::Binary(_) => "Bytes",
213 Bson::ObjectId(_) => "String", Bson::Decimal128(_) => "Decimal",
215 _ => "Unknown",
216 }
217 }
218}
219
#[cfg(test)]
mod tests {
    use super::*;
    use bson::doc;
    use uuid::Uuid;

    #[test]
    fn test_document_ext_get_str() {
        let doc = doc! { "name": "Alice", "age": 30 };
        assert_eq!(DocumentExt::get_str(&doc, "name").unwrap(), "Alice");
        assert!(DocumentExt::get_str(&doc, "age").is_err());
        assert!(DocumentExt::get_str(&doc, "missing").is_err());
    }

    #[test]
    fn test_document_ext_get_i32() {
        let doc = doc! { "count": 42, "name": "test" };
        assert_eq!(DocumentExt::get_i32(&doc, "count").unwrap(), 42);
        assert!(DocumentExt::get_i32(&doc, "name").is_err());
    }

    #[test]
    fn test_document_ext_get_object_id() {
        let oid = ObjectId::new();
        let doc = doc! { "_id": oid };
        assert_eq!(DocumentExt::get_object_id(&doc, "_id").unwrap(), oid);
    }

    #[test]
    fn test_document_ext_strict_getters() {
        let doc = doc! {
            "big": 5_000_000_000i64,
            "flag": false,
            "inner": { "a": 1 },
            "tags": [1, 2, 3]
        };
        assert_eq!(DocumentExt::get_i64(&doc, "big").unwrap(), 5_000_000_000);
        assert!(!DocumentExt::get_bool(&doc, "flag").unwrap());
        assert_eq!(
            DocumentExt::get_document(&doc, "inner")
                .unwrap()
                .get_i32("a")
                .unwrap(),
            1
        );
        assert_eq!(DocumentExt::get_array(&doc, "tags").unwrap().len(), 3);
        assert!(DocumentExt::get_bool(&doc, "big").is_err());
        assert!(DocumentExt::get_document(&doc, "tags").is_err());
        assert!(DocumentExt::get_array(&doc, "missing").is_err());
    }

    #[test]
    fn test_document_ext_opt_getters() {
        let oid = ObjectId::new();
        let doc = doc! {
            "name": "Alice",
            "count": 7,
            "big": 5_000_000_000i64,
            "flag": true,
            "_id": oid,
            "inner": { "a": 1 },
            "tags": ["x", "y"]
        };
        assert_eq!(DocumentExt::get_str_opt(&doc, "name"), Some("Alice"));
        assert_eq!(DocumentExt::get_str_opt(&doc, "count"), None);
        assert_eq!(DocumentExt::get_i32_opt(&doc, "count"), Some(7));
        assert_eq!(DocumentExt::get_i32_opt(&doc, "missing"), None);
        assert_eq!(DocumentExt::get_i64_opt(&doc, "big"), Some(5_000_000_000));
        // i64 access is strict: an Int32 field is not widened.
        assert_eq!(DocumentExt::get_i64_opt(&doc, "count"), None);
        assert_eq!(DocumentExt::get_bool_opt(&doc, "flag"), Some(true));
        assert_eq!(DocumentExt::get_object_id_opt(&doc, "_id"), Some(oid));
        assert!(DocumentExt::get_document_opt(&doc, "inner").is_some());
        assert_eq!(
            DocumentExt::get_array_opt(&doc, "tags").map(|a| a.len()),
            Some(2)
        );
        assert!(DocumentExt::get_array_opt(&doc, "name").is_none());
    }

    #[test]
    fn test_document_ext_id_and_to_struct() {
        #[derive(Debug, PartialEq, serde::Deserialize)]
        struct Named {
            name: String,
        }

        let oid = ObjectId::new();
        let doc = doc! { "_id": oid, "name": "Dave" };
        assert_eq!(DocumentExt::id(&doc).unwrap(), oid);
        assert!(DocumentExt::id(&doc! { "_id": "not-an-oid" }).is_err());
        assert!(DocumentExt::id(&doc! {}).is_err());

        let named: Named = doc.to_struct().unwrap();
        assert_eq!(
            named,
            Named {
                name: "Dave".to_string()
            }
        );
    }

    #[test]
    fn test_to_document() {
        #[derive(Serialize)]
        struct User {
            name: String,
            age: i32,
        }

        let user = User {
            name: "Bob".to_string(),
            age: 25,
        };

        let doc = to_document(&user).unwrap();
        assert_eq!(doc.get_str("name").unwrap(), "Bob");
        assert_eq!(doc.get_i32("age").unwrap(), 25);
    }

    #[test]
    fn test_from_document() {
        #[derive(Debug, PartialEq, serde::Deserialize)]
        struct User {
            name: String,
            age: i32,
        }

        let doc = doc! { "name": "Carol", "age": 35 };
        let user: User = from_document(doc).unwrap();
        assert_eq!(
            user,
            User {
                name: "Carol".to_string(),
                age: 35
            }
        );
    }

    #[test]
    fn test_parse_object_id() {
        let oid = new_object_id();
        let parsed = parse_object_id(&oid.to_hex()).unwrap();
        assert_eq!(oid, parsed);

        assert!(parse_object_id("invalid").is_err());
    }

    #[test]
    fn test_uuid_conversion() {
        use bson_types::*;

        let uuid = Uuid::new_v4();
        let bson = uuid_to_bson(uuid);
        let parsed = bson_to_uuid(&bson).unwrap();
        assert_eq!(uuid, parsed);
    }

    #[test]
    fn test_uuid_from_string_and_invalid_inputs() {
        use bson_types::*;

        let uuid = Uuid::new_v4();
        assert_eq!(
            bson_to_uuid(&Bson::String(uuid.to_string())).unwrap(),
            uuid
        );
        assert!(bson_to_uuid(&Bson::String("not-a-uuid".to_string())).is_err());
        assert!(bson_to_uuid(&Bson::Boolean(true)).is_err());

        // A binary payload that is not 16 bytes long cannot be a UUID.
        let short = Bson::Binary(bson::Binary {
            subtype: bson::spec::BinarySubtype::Uuid,
            bytes: vec![1, 2, 3],
        });
        assert!(bson_to_uuid(&short).is_err());
    }

    #[test]
    fn test_datetime_conversion() {
        use bson_types::*;

        // Start from a millisecond-precision value so the round trip is exact.
        let dt = bson::DateTime::from_millis(1_700_000_000_123).to_chrono();
        assert_eq!(bson_to_datetime(&datetime_to_bson(dt)).unwrap(), dt);
        assert!(bson_to_datetime(&Bson::Null).is_err());
    }

    #[test]
    fn test_bson_type_to_prax() {
        use bson_types::*;

        assert_eq!(bson_type_to_prax(&Bson::Int32(1)), "Int");
        assert_eq!(bson_type_to_prax(&Bson::Int64(1)), "BigInt");
        assert_eq!(bson_type_to_prax(&Bson::Double(1.5)), "Float");
        assert_eq!(bson_type_to_prax(&Bson::String("s".to_string())), "String");
        assert_eq!(bson_type_to_prax(&Bson::ObjectId(ObjectId::new())), "String");
        assert_eq!(bson_type_to_prax(&Bson::Boolean(true)), "Boolean");
        assert_eq!(bson_type_to_prax(&Bson::Null), "Null");
        assert_eq!(bson_type_to_prax(&Bson::Array(vec![])), "List");
        assert_eq!(bson_type_to_prax(&Bson::Document(doc! {})), "Json");
        assert_eq!(bson_type_to_prax(&uuid_to_bson(Uuid::new_v4())), "Uuid");

        let generic = Bson::Binary(bson::Binary {
            subtype: bson::spec::BinarySubtype::Generic,
            bytes: vec![0xde, 0xad],
        });
        assert_eq!(bson_type_to_prax(&generic), "Bytes");
    }
}
303}