Skip to main content

rullst_orm/
audit.rs

1use crate::Orm;
2use serde::{Deserialize, Serialize};
3
4#[derive(Clone, Debug, Serialize, Deserialize)]
5pub struct AuditLog {
6    pub id: i32,
7    pub model_type: String,
8    pub model_id: i32,
9    pub event: String,
10    pub old_values: Option<String>,
11    pub new_values: Option<String>,
12    pub created_at: Option<String>,
13}
14
15pub async fn log_audit(
16    model_type: &str,
17    model_id: i32,
18    event: &str,
19    old_values: Option<String>,
20    new_values: Option<String>,
21) -> Result<(), crate::Error> {
22    let pool = Orm::pool();
23    let driver = Orm::driver();
24
25    if driver == "postgres" {
26        sqlx::query(
27            "INSERT INTO rullst_audits (model_type, model_id, event, old_values, new_values) VALUES ($1, $2, $3, $4, $5)"
28        )
29        .bind(model_type)
30        .bind(model_id)
31        .bind(event)
32        .bind(old_values)
33        .bind(new_values)
34        .execute(pool)
35        .await?;
36    } else {
37        sqlx::query(
38            "INSERT INTO rullst_audits (model_type, model_id, event, old_values, new_values) VALUES (?, ?, ?, ?, ?)"
39        )
40        .bind(model_type)
41        .bind(model_id)
42        .bind(event)
43        .bind(old_values)
44        .bind(new_values)
45        .execute(pool)
46        .await?;
47    }
48
49    Ok(())
50}
51
52pub fn compute_diff(old_json: &str, new_json: &str) -> (Option<String>, Option<String>) {
53    let old_val: serde_json::Value =
54        serde_json::from_str(old_json).unwrap_or(serde_json::Value::Null);
55    let new_val: serde_json::Value =
56        serde_json::from_str(new_json).unwrap_or(serde_json::Value::Null);
57
58    let mut diff_old = serde_json::Map::new();
59    let mut diff_new = serde_json::Map::new();
60
61    if let (serde_json::Value::Object(old_obj), serde_json::Value::Object(mut new_obj)) =
62        (old_val, new_val)
63    {
64        for (k, v) in old_obj {
65            if let Some(new_v) = new_obj.remove(&k) {
66                #[allow(clippy::collapsible_if)]
67                if v != new_v {
68                    diff_new.insert(k.clone(), new_v);
69                    diff_old.insert(k, v);
70                }
71            }
72        }
73    }
74
75    if diff_old.is_empty() && diff_new.is_empty() {
76        return (None, None); // Nothing changed
77    }
78
79    let final_old = serde_json::to_string(&diff_old).ok();
80    let final_new = serde_json::to_string(&diff_new).ok();
81
82    (final_old, final_new)
83}
84
85pub async fn log_audit_diff(
86    model_type: &str,
87    model_id: i32,
88    event: &str,
89    old_json: &str,
90    new_json: &str,
91) -> Result<(), crate::Error> {
92    let (final_old, final_new) = compute_diff(old_json, new_json);
93    if final_old.is_none() && final_new.is_none() {
94        return Ok(()); // Nothing changed
95    }
96    log_audit(model_type, model_id, event, final_old, final_new).await
97}
98
99pub async fn create_audit_table() -> Result<(), crate::Error> {
100    let pool = Orm::pool();
101    let driver = Orm::driver();
102
103    let query = if driver == "postgres" {
104        r#"
105        CREATE TABLE IF NOT EXISTS rullst_audits (
106            id SERIAL PRIMARY KEY,
107            model_type VARCHAR(255) NOT NULL,
108            model_id INT NOT NULL,
109            event VARCHAR(50) NOT NULL,
110            old_values TEXT,
111            new_values TEXT,
112            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
113        )
114        "#
115    } else if driver == "mysql" {
116        r#"
117        CREATE TABLE IF NOT EXISTS rullst_audits (
118            id INT AUTO_INCREMENT PRIMARY KEY,
119            model_type VARCHAR(255) NOT NULL,
120            model_id INT NOT NULL,
121            event VARCHAR(50) NOT NULL,
122            old_values TEXT,
123            new_values TEXT,
124            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
125        )
126        "#
127    } else {
128        r#"
129        CREATE TABLE IF NOT EXISTS rullst_audits (
130            id INTEGER PRIMARY KEY AUTOINCREMENT,
131            model_type TEXT NOT NULL,
132            model_id INTEGER NOT NULL,
133            event TEXT NOT NULL,
134            old_values TEXT,
135            new_values TEXT,
136            created_at DATETIME DEFAULT CURRENT_TIMESTAMP
137        )
138        "#
139    };
140
141    sqlx::query(query).execute(pool).await?;
142    Ok(())
143}
144
#[cfg(test)]
mod tests {
    use super::{compute_diff, log_audit_diff, AuditLog};

    #[test]
    fn test_audit_log_serialization_round_trip() {
        let original = AuditLog {
            id: 1,
            model_type: "User".to_string(),
            model_id: 42,
            event: "created".to_string(),
            old_values: None,
            new_values: Some(r#"{"name":"Alice"}"#.to_string()),
            created_at: Some("2024-01-01T00:00:00Z".to_string()),
        };

        // Serialize and spot-check a couple of fields in the JSON text.
        let encoded = serde_json::to_string(&original).expect("serialize");
        assert!(encoded.contains("\"model_type\":\"User\""));
        assert!(encoded.contains("\"event\":\"created\""));

        // Decode back and verify the values survived the round trip.
        let decoded: AuditLog = serde_json::from_str(&encoded).expect("deserialize");
        assert_eq!(decoded.id, 1);
        assert_eq!(decoded.model_id, 42);
        assert_eq!(decoded.event, "created");
        assert!(decoded.old_values.is_none());
    }

    #[test]
    fn test_audit_log_clone_debug() {
        let source = AuditLog {
            id: 5,
            model_type: "Post".to_string(),
            model_id: 99,
            event: "updated".to_string(),
            old_values: Some(r#"{"title":"Old"}"#.to_string()),
            new_values: Some(r#"{"title":"New"}"#.to_string()),
            created_at: None,
        };
        let duplicate = source.clone();
        assert_eq!(duplicate.model_type, "Post");
        // Formatting with Debug must not panic.
        let _ = format!("{duplicate:?}");
    }

    #[test]
    fn test_compute_diff_changes() {
        let before = r#"{"name":"Alice","age":30}"#;
        let after = r#"{"name":"Alice","age":31}"#;
        let (old_side, new_side) = compute_diff(before, after);
        assert_eq!(old_side.unwrap(), r#"{"age":30}"#);
        assert_eq!(new_side.unwrap(), r#"{"age":31}"#);
    }

    #[test]
    fn test_compute_diff_no_changes() {
        let snapshot = r#"{"name":"Alice","age":30}"#;
        let (old_side, new_side) = compute_diff(snapshot, snapshot);
        assert!(old_side.is_none());
        assert!(new_side.is_none());
    }

    #[test]
    fn test_compute_diff_invalid_json() {
        let (old_side, new_side) = compute_diff("not json", "{invalid}");
        assert!(old_side.is_none());
        assert!(new_side.is_none());
    }

    #[tokio::test]
    async fn test_log_audit_diff_bypass() {
        // Identical snapshots produce an empty diff, so `log_audit_diff` must
        // return early without panicking or touching the database.
        let snapshot = r#"{"name":"Alice"}"#;
        let outcome = log_audit_diff("User", 1, "update", snapshot, snapshot).await;
        assert!(outcome.is_ok());
    }
}