Skip to main content

legalis_diff/
scripting.rs

1//! Scripting support for custom diff analysis using Rhai.
2//!
3//! This module provides a scripting interface for customizing diff analysis behavior
4//! using the Rhai scripting language. Scripts can define custom analyzers, validators,
5//! and transformations.
6//!
7//! # Example
8//!
9//! ```
10//! use legalis_diff::scripting::ScriptEngine;
11//!
12//! let mut engine = ScriptEngine::new();
13//!
14//! // Load a script
15//! let script = r#"
16//!     fn analyze_diff(diff) {
17//!         let change_count = diff.changes.len();
18//!         if change_count > 10 {
19//!             return #{
20//!                 severity: "high",
21//!                 message: "Too many changes"
22//!             };
23//!         }
24//!         return #{
25//!             severity: "low",
26//!             message: "Normal change count"
27//!         };
28//!     }
29//! "#;
30//!
31//! engine.load_script("analyzer", script).unwrap();
32//! ```
33
34use crate::StatuteDiff;
35use crate::plugins::{AnalysisResult, Finding, FindingSeverity, PluginError};
36use rhai::{AST, Dynamic, Engine, Map, Scope};
37use std::collections::HashMap;
38use std::sync::{Arc, RwLock};
39
40/// A scripting engine for custom diff analysis.
41pub struct ScriptEngine {
42    engine: Engine,
43    scripts: Arc<RwLock<HashMap<String, AST>>>,
44}
45
46impl ScriptEngine {
47    /// Creates a new script engine.
48    #[must_use]
49    pub fn new() -> Self {
50        let mut engine = Engine::new();
51
52        // Register custom types
53        engine.register_type::<ScriptDiff>();
54        engine.register_type::<ScriptChange>();
55        engine.register_type::<ScriptImpact>();
56
57        // Register getters for ScriptDiff
58        engine.register_get("change_count", |diff: &mut ScriptDiff| diff.change_count);
59        engine.register_get("statute_id", |diff: &mut ScriptDiff| {
60            diff.statute_id.clone()
61        });
62        engine.register_get("changes", |diff: &mut ScriptDiff| diff.changes.clone());
63        engine.register_get("impact", |diff: &mut ScriptDiff| diff.impact.clone());
64
65        // Register getters for ScriptChange
66        engine.register_get("change_type", |change: &mut ScriptChange| {
67            change.change_type.clone()
68        });
69        engine.register_get("target", |change: &mut ScriptChange| change.target.clone());
70        engine.register_get("description", |change: &mut ScriptChange| {
71            change.description.clone()
72        });
73
74        // Register getters for ScriptImpact
75        engine.register_get("severity", |impact: &mut ScriptImpact| {
76            impact.severity.clone()
77        });
78        engine.register_get("affects_eligibility", |impact: &mut ScriptImpact| {
79            impact.affects_eligibility
80        });
81        engine.register_get("affects_outcome", |impact: &mut ScriptImpact| {
82            impact.affects_outcome
83        });
84        engine.register_get("discretion_changed", |impact: &mut ScriptImpact| {
85            impact.discretion_changed
86        });
87
88        // Register helper functions
89        engine.register_fn("create_finding", create_finding);
90        engine.register_fn("create_result", create_analysis_result);
91
92        Self {
93            engine,
94            scripts: Arc::new(RwLock::new(HashMap::new())),
95        }
96    }
97
98    /// Loads a script into the engine.
99    pub fn load_script(&mut self, name: &str, script: &str) -> Result<(), PluginError> {
100        let ast = self.engine.compile(script).map_err(|e| {
101            PluginError::InitializationFailed(format!("Script compilation failed: {e}"))
102        })?;
103
104        let mut scripts = self.scripts.write().unwrap();
105        scripts.insert(name.to_string(), ast);
106
107        Ok(())
108    }
109
110    /// Executes a script function with a diff.
111    pub fn execute(
112        &self,
113        script_name: &str,
114        function_name: &str,
115        diff: &StatuteDiff,
116    ) -> Result<AnalysisResult, PluginError> {
117        let scripts = self.scripts.read().unwrap();
118        let ast = scripts
119            .get(script_name)
120            .ok_or_else(|| PluginError::NotFound(script_name.to_string()))?;
121
122        let script_diff = ScriptDiff::from_statute_diff(diff);
123
124        let mut scope = Scope::new();
125
126        let result: Dynamic = self
127            .engine
128            .call_fn(&mut scope, ast, function_name, (script_diff,))
129            .map_err(|e| PluginError::ExecutionFailed(format!("Script execution failed: {e}")))?;
130
131        // Convert result to AnalysisResult
132        self.convert_result(result, script_name)
133    }
134
135    /// Evaluates an expression with a diff.
136    pub fn evaluate(&self, script_name: &str, diff: &StatuteDiff) -> Result<Dynamic, PluginError> {
137        let scripts = self.scripts.read().unwrap();
138        let ast = scripts
139            .get(script_name)
140            .ok_or_else(|| PluginError::NotFound(script_name.to_string()))?;
141
142        let script_diff = ScriptDiff::from_statute_diff(diff);
143
144        let mut scope = Scope::new();
145        scope.push("diff", script_diff);
146
147        self.engine
148            .eval_ast_with_scope(&mut scope, ast)
149            .map_err(|e| PluginError::ExecutionFailed(format!("Evaluation failed: {e}")))
150    }
151
152    /// Converts a dynamic result to an AnalysisResult.
153    fn convert_result(
154        &self,
155        result: Dynamic,
156        script_name: &str,
157    ) -> Result<AnalysisResult, PluginError> {
158        if let Some(map) = result.try_cast::<Map>() {
159            let mut findings = Vec::new();
160            let mut metadata = HashMap::new();
161
162            if let Some(severity) = map
163                .get("severity")
164                .and_then(|v| v.clone().try_cast::<String>())
165                && let Some(message) = map
166                    .get("message")
167                    .and_then(|v| v.clone().try_cast::<String>())
168            {
169                let finding_severity = match severity.to_lowercase().as_str() {
170                    "critical" => FindingSeverity::Critical,
171                    "high" => FindingSeverity::High,
172                    "medium" => FindingSeverity::Medium,
173                    "low" => FindingSeverity::Low,
174                    _ => FindingSeverity::Info,
175                };
176
177                findings.push(Finding {
178                    severity: finding_severity,
179                    category: script_name.to_string(),
180                    message,
181                    location: None,
182                    suggestion: None,
183                });
184            }
185
186            if let Some(meta) = map
187                .get("metadata")
188                .and_then(|v| v.clone().try_cast::<Map>())
189            {
190                for (key, value) in meta {
191                    metadata.insert(key.to_string(), value.to_string());
192                }
193            }
194
195            Ok(AnalysisResult {
196                plugin_name: format!("script:{script_name}"),
197                findings,
198                confidence: 0.85,
199                metadata,
200            })
201        } else {
202            Err(PluginError::ExecutionFailed(
203                "Script must return a map with severity and message".to_string(),
204            ))
205        }
206    }
207}
208
209impl Default for ScriptEngine {
210    fn default() -> Self {
211        Self::new()
212    }
213}
214
215/// Script-friendly representation of a diff.
216#[derive(Debug, Clone)]
217pub struct ScriptDiff {
218    pub statute_id: String,
219    pub changes: Vec<ScriptChange>,
220    pub impact: ScriptImpact,
221    pub change_count: i64,
222}
223
224impl ScriptDiff {
225    fn from_statute_diff(diff: &StatuteDiff) -> Self {
226        Self {
227            statute_id: diff.statute_id.clone(),
228            changes: diff.changes.iter().map(ScriptChange::from).collect(),
229            impact: ScriptImpact::from(&diff.impact),
230            change_count: diff.changes.len() as i64,
231        }
232    }
233}
234
235/// Script-friendly representation of a change.
236#[derive(Debug, Clone)]
237pub struct ScriptChange {
238    pub change_type: String,
239    pub target: String,
240    pub description: String,
241}
242
243impl From<&crate::Change> for ScriptChange {
244    fn from(change: &crate::Change) -> Self {
245        Self {
246            change_type: format!("{:?}", change.change_type),
247            target: change.target.to_string(),
248            description: change.description.clone(),
249        }
250    }
251}
252
253/// Script-friendly representation of impact.
254#[derive(Debug, Clone)]
255pub struct ScriptImpact {
256    pub severity: String,
257    pub affects_eligibility: bool,
258    pub affects_outcome: bool,
259    pub discretion_changed: bool,
260}
261
262impl From<&crate::ImpactAssessment> for ScriptImpact {
263    fn from(impact: &crate::ImpactAssessment) -> Self {
264        Self {
265            severity: format!("{:?}", impact.severity),
266            affects_eligibility: impact.affects_eligibility,
267            affects_outcome: impact.affects_outcome,
268            discretion_changed: impact.discretion_changed,
269        }
270    }
271}
272
273/// Creates a finding from script.
274#[allow(dead_code)]
275fn create_finding(severity: String, category: String, message: String) -> Map {
276    let mut map = Map::new();
277    map.insert("severity".into(), severity.into());
278    map.insert("category".into(), category.into());
279    map.insert("message".into(), message.into());
280    map
281}
282
283/// Creates an analysis result from script.
284#[allow(dead_code)]
285fn create_analysis_result(findings: Vec<Map>, metadata: Map) -> Map {
286    let mut result = Map::new();
287    result.insert("findings".into(), findings.into());
288    result.insert("metadata".into(), metadata.into());
289    result
290}
291
292/// A script-based diff analyzer plugin.
293pub struct ScriptAnalyzer {
294    engine: ScriptEngine,
295    script_name: String,
296    function_name: String,
297}
298
299impl ScriptAnalyzer {
300    /// Creates a new script analyzer.
301    pub fn new(script: &str, function_name: &str) -> Result<Self, PluginError> {
302        let mut engine = ScriptEngine::new();
303        let script_name = "analyzer".to_string();
304        engine.load_script(&script_name, script)?;
305
306        Ok(Self {
307            engine,
308            script_name,
309            function_name: function_name.to_string(),
310        })
311    }
312
313    /// Analyzes a diff using the script.
314    pub fn analyze(&self, diff: &StatuteDiff) -> Result<AnalysisResult, PluginError> {
315        self.engine
316            .execute(&self.script_name, &self.function_name, diff)
317    }
318}
319
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Change, ChangeTarget, ChangeType, ImpactAssessment, Severity};

    /// Builds a two-change diff: one modified precondition and one added one.
    fn create_test_diff() -> StatuteDiff {
        let lowered_age = Change {
            change_type: ChangeType::Modified,
            target: ChangeTarget::Precondition { index: 0 },
            description: "Age changed".to_string(),
            old_value: Some("65".to_string()),
            new_value: Some("60".to_string()),
        };
        let income_requirement = Change {
            change_type: ChangeType::Added,
            target: ChangeTarget::Precondition { index: 1 },
            description: "Income requirement added".to_string(),
            old_value: None,
            new_value: Some("50000".to_string()),
        };
        let impact = ImpactAssessment {
            severity: Severity::Moderate,
            affects_eligibility: true,
            affects_outcome: false,
            discretion_changed: false,
            notes: vec!["Test impact".to_string()],
        };

        StatuteDiff {
            statute_id: "test-123".to_string(),
            version_info: None,
            changes: vec![lowered_age, income_requirement],
            impact,
        }
    }

    #[test]
    fn test_script_engine_basic() {
        let script = r#"
            fn analyze(diff) {
                #{
                    severity: "high",
                    message: "Test message"
                }
            }
        "#;

        let mut engine = ScriptEngine::new();
        engine.load_script("test", script).unwrap();

        let outcome = engine
            .execute("test", "analyze", &create_test_diff())
            .unwrap();
        let findings = &outcome.findings;

        assert_eq!(findings.len(), 1);
        assert_eq!(findings[0].message, "Test message");
    }

    #[test]
    fn test_script_analyzer() {
        let script = r#"
            fn analyze_changes(diff) {
                let count = diff.change_count;
                if count > 5 {
                    #{
                        severity: "high",
                        message: "Too many changes"
                    }
                } else {
                    #{
                        severity: "low",
                        message: "Normal change count"
                    }
                }
            }
        "#;

        let outcome = ScriptAnalyzer::new(script, "analyze_changes")
            .unwrap()
            .analyze(&create_test_diff())
            .unwrap();

        assert_eq!(outcome.findings.len(), 1);
    }
}