Skip to main content

rullst_orm/
audit.rs

1use crate::Orm;
2use serde::{Deserialize, Serialize};
3
/// A single audit record, mirroring the columns of the `rullst_audits` table that
/// [`create_audit_table`] creates.
///
/// Field declaration order is also the order serde (and `Debug`) uses when rendering,
/// so reordering fields changes the serialized JSON.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AuditLog {
    /// Auto-generated primary key of the audit row (`SERIAL` / `AUTO_INCREMENT` /
    /// `AUTOINCREMENT` depending on the driver).
    pub id: i32,
    /// Name of the audited model (e.g. `"User"`); length-checked by [`log_audit`].
    pub model_type: String,
    /// Id of the audited record, as passed to [`log_audit`].
    pub model_id: i32,
    /// Short label for what happened (e.g. `"created"`); length-checked by [`log_audit`].
    pub event: String,
    /// Values before the change, stored as text: a JSON diff when written by
    /// [`log_audit_diff`], or an `{"error":...}` marker when the payload was too large.
    /// `None` if nothing was recorded.
    pub old_values: Option<String>,
    /// Values after the change; same format as `old_values`.
    pub new_values: Option<String>,
    /// Row creation time, filled in by the column's `CURRENT_TIMESTAMP` default.
    /// NOTE(review): the textual format depends on the database driver — confirm how
    /// callers populate/parse this field.
    pub created_at: Option<String>,
}
14
15pub async fn log_audit(
16    model_type: &str,
17    model_id: i32,
18    event: &str,
19    mut old_values: Option<String>,
20    mut new_values: Option<String>,
21) -> Result<(), crate::Error> {
22    const MAX_PAYLOAD_LEN: usize = 5 * 1024 * 1024; // 5 MB
23
24    if model_type.len() > 255 || event.len() > 50 {
25        return Err(crate::Error::Validation(
26            "Audit model_type or event string too long".to_string(),
27        ));
28    }
29
30    if let Some(val) = &old_values
31        && val.len() > MAX_PAYLOAD_LEN
32    {
33        old_values = Some(r#"{"error":"payload_too_large"}"#.to_string());
34    }
35
36    if let Some(val) = &new_values
37        && val.len() > MAX_PAYLOAD_LEN
38    {
39        new_values = Some(r#"{"error":"payload_too_large"}"#.to_string());
40    }
41
42    let pool = Orm::pool();
43    let driver = Orm::driver();
44
45    if driver == "postgres" {
46        sqlx::query(
47            "INSERT INTO rullst_audits (model_type, model_id, event, old_values, new_values) VALUES ($1, $2, $3, $4, $5)"
48        )
49        .bind(model_type)
50        .bind(model_id)
51        .bind(event)
52        .bind(old_values)
53        .bind(new_values)
54        .execute(pool)
55        .await?;
56    } else {
57        sqlx::query(
58            "INSERT INTO rullst_audits (model_type, model_id, event, old_values, new_values) VALUES (?, ?, ?, ?, ?)"
59        )
60        .bind(model_type)
61        .bind(model_id)
62        .bind(event)
63        .bind(old_values)
64        .bind(new_values)
65        .execute(pool)
66        .await?;
67    }
68
69    Ok(())
70}
71
72pub fn compute_diff(old_json: &str, new_json: &str) -> (Option<String>, Option<String>) {
73    if old_json == new_json {
74        return (None, None);
75    }
76
77    let old_val: serde_json::Value =
78        serde_json::from_str(old_json).unwrap_or(serde_json::Value::Null);
79    let new_val: serde_json::Value =
80        serde_json::from_str(new_json).unwrap_or(serde_json::Value::Null);
81
82    let mut diff_old = serde_json::Map::new();
83    let mut diff_new = serde_json::Map::new();
84
85    fn is_sensitive(key: &str) -> bool {
86        let k = key.to_lowercase();
87        k.contains("password")
88            || k.contains("token")
89            || k.contains("secret")
90            || k.contains("senha")
91            || k.contains("api_key")
92    }
93
94    fn mask_if_sensitive(key: &str, value: serde_json::Value) -> serde_json::Value {
95        if is_sensitive(key) {
96            serde_json::Value::String("***".to_string())
97        } else {
98            value
99        }
100    }
101
102    if let (serde_json::Value::Object(old_obj), serde_json::Value::Object(mut new_obj)) =
103        (old_val, new_val)
104    {
105        for (k, v) in old_obj {
106            if let Some(new_v) = new_obj.remove(&k) {
107                #[allow(clippy::collapsible_if)]
108                if v != new_v {
109                    diff_new.insert(k.clone(), mask_if_sensitive(&k, new_v));
110                    diff_old.insert(k.clone(), mask_if_sensitive(&k, v));
111                }
112            } else {
113                diff_new.insert(k.clone(), serde_json::Value::Null);
114                diff_old.insert(k.clone(), mask_if_sensitive(&k, v));
115            }
116        }
117        for (k, new_v) in new_obj {
118            diff_new.insert(k.clone(), mask_if_sensitive(&k, new_v));
119            diff_old.insert(k, serde_json::Value::Null);
120        }
121    }
122
123    if diff_old.is_empty() && diff_new.is_empty() {
124        return (None, None); // Nothing changed
125    }
126
127    let final_old = serde_json::to_string(&diff_old).ok();
128    let final_new = serde_json::to_string(&diff_new).ok();
129
130    (final_old, final_new)
131}
132
133pub async fn log_audit_diff(
134    model_type: &str,
135    model_id: i32,
136    event: &str,
137    old_json: &str,
138    new_json: &str,
139) -> Result<(), crate::Error> {
140    const MAX_PAYLOAD_LEN: usize = 5 * 1024 * 1024; // 5 MB
141
142    if old_json.len() > MAX_PAYLOAD_LEN || new_json.len() > MAX_PAYLOAD_LEN {
143        return log_audit(
144            model_type,
145            model_id,
146            event,
147            Some(r#"{"error":"payload_too_large_for_diff"}"#.to_string()),
148            Some(r#"{"error":"payload_too_large_for_diff"}"#.to_string()),
149        )
150        .await;
151    }
152
153    let (final_old, final_new) = compute_diff(old_json, new_json);
154    if final_old.is_none() && final_new.is_none() {
155        return Ok(()); // Nothing changed
156    }
157    log_audit(model_type, model_id, event, final_old, final_new).await
158}
159
160pub async fn create_audit_table() -> Result<(), crate::Error> {
161    let pool = Orm::pool();
162    let driver = Orm::driver();
163
164    let query = if driver == "postgres" {
165        r#"
166        CREATE TABLE IF NOT EXISTS rullst_audits (
167            id SERIAL PRIMARY KEY,
168            model_type VARCHAR(255) NOT NULL,
169            model_id INT NOT NULL,
170            event VARCHAR(50) NOT NULL,
171            old_values TEXT,
172            new_values TEXT,
173            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
174        )
175        "#
176    } else if driver == "mysql" {
177        r#"
178        CREATE TABLE IF NOT EXISTS rullst_audits (
179            id INT AUTO_INCREMENT PRIMARY KEY,
180            model_type VARCHAR(255) NOT NULL,
181            model_id INT NOT NULL,
182            event VARCHAR(50) NOT NULL,
183            old_values TEXT,
184            new_values TEXT,
185            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
186        )
187        "#
188    } else {
189        r#"
190        CREATE TABLE IF NOT EXISTS rullst_audits (
191            id INTEGER PRIMARY KEY AUTOINCREMENT,
192            model_type TEXT NOT NULL,
193            model_id INTEGER NOT NULL,
194            event TEXT NOT NULL,
195            old_values TEXT,
196            new_values TEXT,
197            created_at DATETIME DEFAULT CURRENT_TIMESTAMP
198        )
199        "#
200    };
201
202    sqlx::query(query).execute(pool).await?;
203    Ok(())
204}
205
#[cfg(test)]
mod tests {
    use super::{compute_diff, log_audit_diff, AuditLog};

    /// Asserts that diffing `old` against `new` yields exactly the expected diff strings.
    fn assert_diff(old: &str, new: &str, expected_old: &str, expected_new: &str) {
        let (got_old, got_new) = compute_diff(old, new);
        assert_eq!(got_old.unwrap(), expected_old);
        assert_eq!(got_new.unwrap(), expected_new);
    }

    /// Asserts that diffing `old` against `new` reports no change at all.
    fn assert_no_diff(old: &str, new: &str) {
        let (got_old, got_new) = compute_diff(old, new);
        assert!(got_old.is_none());
        assert!(got_new.is_none());
    }

    #[test]
    fn test_audit_log_serialization_round_trip() {
        let original = AuditLog {
            id: 1,
            model_type: "User".to_string(),
            model_id: 42,
            event: "created".to_string(),
            old_values: None,
            new_values: Some(r#"{"name":"Alice"}"#.to_string()),
            created_at: Some("2024-01-01T00:00:00Z".to_string()),
        };

        let encoded = serde_json::to_string(&original).expect("serialize");
        for fragment in ["\"model_type\":\"User\"", "\"event\":\"created\""] {
            assert!(encoded.contains(fragment));
        }

        let decoded: AuditLog = serde_json::from_str(&encoded).expect("deserialize");
        assert_eq!((decoded.id, decoded.model_id), (1, 42));
        assert_eq!(decoded.event, "created");
        assert!(decoded.old_values.is_none());
    }

    #[test]
    fn test_audit_log_clone_debug() {
        let post = AuditLog {
            id: 5,
            model_type: "Post".to_string(),
            model_id: 99,
            event: "updated".to_string(),
            old_values: Some(r#"{"title":"Old"}"#.to_string()),
            new_values: Some(r#"{"title":"New"}"#.to_string()),
            created_at: None,
        };
        let copy = post.clone();
        assert_eq!(copy.model_type, "Post");
        // Formatting through Debug must not panic.
        let _rendered = format!("{copy:?}");
    }

    #[test]
    fn test_compute_diff_changes() {
        assert_diff(
            r#"{"name":"Alice","age":30}"#,
            r#"{"name":"Alice","age":31}"#,
            r#"{"age":30}"#,
            r#"{"age":31}"#,
        );
    }

    #[test]
    fn test_compute_diff_no_changes() {
        let snapshot = r#"{"name":"Alice","age":30}"#;
        assert_no_diff(snapshot, snapshot);
    }

    #[test]
    fn test_compute_diff_invalid_json() {
        assert_no_diff("not json", "{invalid}");
    }

    #[tokio::test]
    async fn test_log_audit_diff_bypass() {
        // Identical snapshots must return Ok without panicking or touching the database.
        let snapshot = r#"{"name":"Alice"}"#;
        let outcome = log_audit_diff("User", 1, "update", snapshot, snapshot).await;
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_compute_diff_explicit_null_vs_omitted() {
        // Key removed: the new side records it as null.
        assert_diff(
            r#"{"name":"Alice","age":30}"#,
            r#"{"name":"Alice"}"#,
            r#"{"age":30}"#,
            r#"{"age":null}"#,
        );
        // Key added with an explicit null: both sides show null, yet it is still reported.
        assert_diff(
            r#"{"name":"Alice"}"#,
            r#"{"name":"Alice","age":null}"#,
            r#"{"age":null}"#,
            r#"{"age":null}"#,
        );
        // Key added with a value: the old side records it as null.
        assert_diff(
            r#"{"name":"Alice"}"#,
            r#"{"name":"Alice","age":30}"#,
            r#"{"age":null}"#,
            r#"{"age":30}"#,
        );
    }
}