Skip to main content

crue_engine/
vm.rs

1//! Bytecode VM for compiled CRUE DSL rules (Phase 2 bootstrap).
2
3use crate::context::{EvaluationContext, FieldValue};
4use crate::decision::{ActionResult, Decision};
5use crate::error::EngineError;
6use crate::ir::ActionInstruction;
7use crue_dsl::compiler::{Bytecode, Constant, Opcode};
8
/// Runtime value held on the VM operand stack.
///
/// Constants and context fields are converted into this representation
/// before they are pushed; comparison and logic instructions pop these
/// values and push a `Bool` result.
#[derive(Debug, Clone, PartialEq)]
enum VmValue {
    /// Boolean operand or the result of a comparison/logic instruction.
    Bool(bool),
    /// Signed 64-bit integer operand.
    Number(i64),
    /// Owned string operand.
    String(String),
}
15
16/// Explicit VM instruction set used by the decoded execution path.
17#[derive(Debug, Clone, PartialEq, Eq)]
18pub enum Instruction {
19    LoadField(u16),
20    LoadConst(u32),
21    LoadTrue,
22    LoadFalse,
23    Gt,
24    Lt,
25    Gte,
26    Lte,
27    Eq,
28    Neq,
29    And,
30    Or,
31    Not,
32    JumpIfFalse(usize),
33    Jump(usize),
34    Ret,
35    EmitDecision(Decision),
36}
37
38/// VM exit value, allowing either a boolean gate or an emitted decision.
39#[derive(Debug, Clone, Copy, PartialEq, Eq)]
40pub enum VmExit {
41    Bool(bool),
42    Decision(Decision),
43}
44
/// Stateless entry point for decoding and evaluating rule-condition bytecode.
///
/// All functionality is exposed through associated functions; the type
/// carries no data.
pub struct BytecodeVm;
/// Stateless entry point for executing compiled action programs.
///
/// All functionality is exposed through associated functions; the type
/// carries no data.
pub struct ActionVm;
47
48impl BytecodeVm {
49    /// Evaluate a compiled CRUE bytecode condition to a boolean decision gate.
50    pub fn eval(bytecode: &Bytecode, ctx: &EvaluationContext) -> Result<bool, EngineError> {
51        let program = Self::decode(bytecode)?;
52        match Self::eval_program(&program, bytecode, ctx)? {
53            VmExit::Bool(v) => Ok(v),
54            VmExit::Decision(_) => Err(EngineError::EvaluationError(
55                "VM emitted decision in boolean eval path".to_string(),
56            )),
57        }
58    }
59
60    /// Decode raw CRUE bytecode into an explicit instruction sequence.
61    pub fn decode(bytecode: &Bytecode) -> Result<Vec<Instruction>, EngineError> {
62        let mut pc = 0usize;
63        let code = &bytecode.instructions;
64        let mut program = Vec::new();
65
66        while pc < code.len() {
67            let op = decode_opcode(code[pc])?;
68            pc += 1;
69            match op {
70                Opcode::LoadField => {
71                    program.push(Instruction::LoadField(read_u16(code, &mut pc)?));
72                }
73                Opcode::LoadConst => {
74                    program.push(Instruction::LoadConst(read_u32(code, &mut pc)?));
75                }
76                Opcode::LoadTrue => program.push(Instruction::LoadTrue),
77                Opcode::LoadFalse => program.push(Instruction::LoadFalse),
78                Opcode::Gt => program.push(Instruction::Gt),
79                Opcode::Lt => program.push(Instruction::Lt),
80                Opcode::Gte => program.push(Instruction::Gte),
81                Opcode::Lte => program.push(Instruction::Lte),
82                Opcode::Eq => program.push(Instruction::Eq),
83                Opcode::Neq => program.push(Instruction::Neq),
84                Opcode::And => program.push(Instruction::And),
85                Opcode::Or => program.push(Instruction::Or),
86                Opcode::Not => program.push(Instruction::Not),
87                Opcode::Ret => program.push(Instruction::Ret),
88                Opcode::Jmp | Opcode::JmpF => {
89                    return Err(EngineError::EvaluationError(
90                        "Raw jump opcodes not supported in decoded VM yet".to_string(),
91                    ));
92                }
93            }
94        }
95
96        Ok(program)
97    }
98
99    /// Evaluate a decoded VM program against bytecode metadata/context.
100    pub fn eval_program(
101        program: &[Instruction],
102        bytecode: &Bytecode,
103        ctx: &EvaluationContext,
104    ) -> Result<VmExit, EngineError> {
105        let mut pc = 0usize;
106        let mut stack: Vec<VmValue> = Vec::new();
107
108        while pc < program.len() {
109            match &program[pc] {
110                Instruction::LoadField(idx) => {
111                    let field = bytecode.fields.get(*idx as usize).ok_or_else(|| {
112                        EngineError::EvaluationError("Invalid field index".to_string())
113                    })?;
114                    let value = ctx
115                        .get_field(field)
116                        .ok_or_else(|| EngineError::FieldNotFound(field.clone()))?;
117                    stack.push(field_to_vm(value)?);
118                    pc += 1;
119                }
120                Instruction::LoadConst(idx) => {
121                    let c = bytecode.constants.get(*idx as usize).ok_or_else(|| {
122                        EngineError::EvaluationError("Invalid constant index".to_string())
123                    })?;
124                    stack.push(constant_to_vm(c));
125                    pc += 1;
126                }
127                Instruction::LoadTrue => {
128                    stack.push(VmValue::Bool(true));
129                    pc += 1;
130                }
131                Instruction::LoadFalse => {
132                    stack.push(VmValue::Bool(false));
133                    pc += 1;
134                }
135                Instruction::Gt => {
136                    binary_compare(&mut stack, |a, b| a > b)?;
137                    pc += 1;
138                }
139                Instruction::Lt => {
140                    binary_compare(&mut stack, |a, b| a < b)?;
141                    pc += 1;
142                }
143                Instruction::Gte => {
144                    binary_compare(&mut stack, |a, b| a >= b)?;
145                    pc += 1;
146                }
147                Instruction::Lte => {
148                    binary_compare(&mut stack, |a, b| a <= b)?;
149                    pc += 1;
150                }
151                Instruction::Eq => {
152                    binary_eq(&mut stack, true)?;
153                    pc += 1;
154                }
155                Instruction::Neq => {
156                    binary_eq(&mut stack, false)?;
157                    pc += 1;
158                }
159                Instruction::And => {
160                    binary_bool(&mut stack, |a, b| a && b)?;
161                    pc += 1;
162                }
163                Instruction::Or => {
164                    binary_bool(&mut stack, |a, b| a || b)?;
165                    pc += 1;
166                }
167                Instruction::Not => {
168                    let v = pop_bool(&mut stack)?;
169                    stack.push(VmValue::Bool(!v));
170                    pc += 1;
171                }
172                Instruction::JumpIfFalse(target) => {
173                    let cond = pop_bool(&mut stack)?;
174                    if !cond {
175                        ensure_target(*target, program.len())?;
176                        pc = *target;
177                    } else {
178                        pc += 1;
179                    }
180                }
181                Instruction::Jump(target) => {
182                    ensure_target(*target, program.len())?;
183                    pc = *target;
184                }
185                Instruction::Ret => {
186                    return Ok(VmExit::Bool(pop_bool(&mut stack)?));
187                }
188                Instruction::EmitDecision(decision) => {
189                    return Ok(VmExit::Decision(*decision));
190                }
191            }
192        }
193
194        Err(EngineError::EvaluationError(
195            "VM program terminated without RET/EmitDecision".to_string(),
196        ))
197    }
198
199    /// Evaluate a compiled rule condition and emit a typed decision via VM instructions.
200    pub fn eval_decision(
201        bytecode: &Bytecode,
202        ctx: &EvaluationContext,
203        on_true: Decision,
204        on_false: Decision,
205    ) -> Result<Decision, EngineError> {
206        let mut program = Self::decode(bytecode)?;
207        if !matches!(program.last(), Some(Instruction::Ret)) {
208            return Err(EngineError::EvaluationError(
209                "Bytecode terminated without RET".to_string(),
210            ));
211        }
212        program.pop();
213
214        let false_target = program.len() + 2;
215        program.push(Instruction::JumpIfFalse(false_target));
216        program.push(Instruction::EmitDecision(on_true));
217        program.push(Instruction::EmitDecision(on_false));
218
219        match Self::eval_program(&program, bytecode, ctx)? {
220            VmExit::Decision(d) => Ok(d),
221            VmExit::Bool(_) => Err(EngineError::EvaluationError(
222                "VM returned bool in decision eval path".to_string(),
223            )),
224        }
225    }
226
227    /// Build a decoded VM program that emits `on_match` when the bytecode condition matches,
228    /// and returns `false` (boolean) when it does not match.
229    pub fn build_match_program(
230        bytecode: &Bytecode,
231        on_match: Decision,
232    ) -> Result<Vec<Instruction>, EngineError> {
233        let mut program = Self::decode(bytecode)?;
234        if !matches!(program.last(), Some(Instruction::Ret)) {
235            return Err(EngineError::EvaluationError(
236                "Bytecode terminated without RET".to_string(),
237            ));
238        }
239        program.pop();
240
241        // Stack holds the condition boolean at this point.
242        // false -> push false + RET (signals "no match")
243        // true  -> EmitDecision(on_match)
244        let false_target = program.len() + 2;
245        program.push(Instruction::JumpIfFalse(false_target));
246        program.push(Instruction::EmitDecision(on_match));
247        program.push(Instruction::LoadFalse);
248        program.push(Instruction::Ret);
249        Ok(program)
250    }
251
252    /// Evaluate a prebuilt match program and return:
253    /// - `Some(decision)` when rule matched and emitted a decision
254    /// - `None` when rule condition evaluated to false
255    pub fn eval_match_program(
256        program: &[Instruction],
257        bytecode: &Bytecode,
258        ctx: &EvaluationContext,
259    ) -> Result<Option<Decision>, EngineError> {
260        match Self::eval_program(program, bytecode, ctx)? {
261            VmExit::Decision(d) => Ok(Some(d)),
262            VmExit::Bool(false) => Ok(None),
263            VmExit::Bool(true) => Err(EngineError::EvaluationError(
264                "VM match program returned unexpected true boolean".to_string(),
265            )),
266        }
267    }
268}
269
270impl ActionVm {
271    /// Execute a compiled action program into a deterministic `ActionResult`.
272    pub fn execute(program: &[ActionInstruction]) -> Result<ActionResult, EngineError> {
273        let mut decision = Decision::Allow;
274        let mut error_code: Option<String> = None;
275        let mut message: Option<String> = None;
276        let mut approval_timeout: Option<u32> = None;
277        let mut alert_soc = false;
278
279        for insn in program {
280            match insn {
281                ActionInstruction::SetDecision(d) => decision = *d,
282                ActionInstruction::SetErrorCode(code) => error_code = Some(code.clone()),
283                ActionInstruction::SetMessage(msg) => message = Some(msg.clone()),
284                ActionInstruction::SetApprovalTimeout(timeout) => approval_timeout = Some(*timeout),
285                ActionInstruction::SetAlertSoc(v) => alert_soc = *v,
286                ActionInstruction::Halt => break,
287            }
288        }
289
290        let final_message = match decision {
291            Decision::ApprovalRequired => {
292                if let Some(m) = message {
293                    Some(m)
294                } else {
295                    Some(format!(
296                        "Approval required within {} minutes",
297                        approval_timeout.unwrap_or(30)
298                    ))
299                }
300            }
301            _ => message,
302        };
303
304        Ok(ActionResult {
305            decision,
306            error_code,
307            message: final_message,
308            alert_soc,
309        })
310    }
311}
312
313fn ensure_target(target: usize, len: usize) -> Result<(), EngineError> {
314    if target >= len {
315        return Err(EngineError::EvaluationError(format!(
316            "Invalid jump target {} (program len {})",
317            target, len
318        )));
319    }
320    Ok(())
321}
322
323fn decode_opcode(byte: u8) -> Result<Opcode, EngineError> {
324    let op = match byte {
325        0x01 => Opcode::LoadField,
326        0x02 => Opcode::LoadConst,
327        0x03 => Opcode::LoadTrue,
328        0x04 => Opcode::LoadFalse,
329        0x10 => Opcode::Gt,
330        0x11 => Opcode::Lt,
331        0x12 => Opcode::Gte,
332        0x13 => Opcode::Lte,
333        0x14 => Opcode::Eq,
334        0x15 => Opcode::Neq,
335        0x20 => Opcode::And,
336        0x21 => Opcode::Or,
337        0x22 => Opcode::Not,
338        0x30 => Opcode::JmpF,
339        0x31 => Opcode::Jmp,
340        0xFF => Opcode::Ret,
341        _ => {
342            return Err(EngineError::EvaluationError(format!(
343                "Unknown opcode 0x{byte:02x}"
344            )))
345        }
346    };
347    Ok(op)
348}
349
350fn read_u16(code: &[u8], pc: &mut usize) -> Result<u16, EngineError> {
351    if *pc + 2 > code.len() {
352        return Err(EngineError::EvaluationError("Truncated u16 operand".to_string()));
353    }
354    let v = u16::from_be_bytes([code[*pc], code[*pc + 1]]);
355    *pc += 2;
356    Ok(v)
357}
358
359fn read_u32(code: &[u8], pc: &mut usize) -> Result<u32, EngineError> {
360    if *pc + 4 > code.len() {
361        return Err(EngineError::EvaluationError("Truncated u32 operand".to_string()));
362    }
363    let v = u32::from_be_bytes([code[*pc], code[*pc + 1], code[*pc + 2], code[*pc + 3]]);
364    *pc += 4;
365    Ok(v)
366}
367
368fn constant_to_vm(c: &Constant) -> VmValue {
369    match c {
370        Constant::Number(n) => VmValue::Number(*n),
371        Constant::String(s) => VmValue::String(s.clone()),
372        Constant::Boolean(b) => VmValue::Bool(*b),
373    }
374}
375
376fn field_to_vm(v: &FieldValue) -> Result<VmValue, EngineError> {
377    match v {
378        FieldValue::Number(n) => Ok(VmValue::Number(*n)),
379        FieldValue::String(s) => Ok(VmValue::String(s.clone())),
380        FieldValue::Boolean(b) => Ok(VmValue::Bool(*b)),
381        FieldValue::Float(_) => Err(EngineError::TypeMismatch("float field unsupported in VM".into())),
382    }
383}
384
385fn pop(stack: &mut Vec<VmValue>) -> Result<VmValue, EngineError> {
386    stack.pop()
387        .ok_or_else(|| EngineError::EvaluationError("VM stack underflow".to_string()))
388}
389
390fn pop_bool(stack: &mut Vec<VmValue>) -> Result<bool, EngineError> {
391    match pop(stack)? {
392        VmValue::Bool(v) => Ok(v),
393        _ => Err(EngineError::TypeMismatch("Expected bool".to_string())),
394    }
395}
396
397fn pop_number(stack: &mut Vec<VmValue>) -> Result<i64, EngineError> {
398    match pop(stack)? {
399        VmValue::Number(v) => Ok(v),
400        _ => Err(EngineError::TypeMismatch("Expected number".to_string())),
401    }
402}
403
404fn binary_compare(
405    stack: &mut Vec<VmValue>,
406    cmp: impl Fn(i64, i64) -> bool,
407) -> Result<(), EngineError> {
408    let right = pop_number(stack)?;
409    let left = pop_number(stack)?;
410    stack.push(VmValue::Bool(cmp(left, right)));
411    Ok(())
412}
413
414fn binary_bool(
415    stack: &mut Vec<VmValue>,
416    op: impl Fn(bool, bool) -> bool,
417) -> Result<(), EngineError> {
418    let right = pop_bool(stack)?;
419    let left = pop_bool(stack)?;
420    stack.push(VmValue::Bool(op(left, right)));
421    Ok(())
422}
423
424fn binary_eq(stack: &mut Vec<VmValue>, eq: bool) -> Result<(), EngineError> {
425    let right = pop(stack)?;
426    let left = pop(stack)?;
427    let result = match (left, right) {
428        (VmValue::Bool(a), VmValue::Bool(b)) => a == b,
429        (VmValue::Number(a), VmValue::Number(b)) => a == b,
430        (VmValue::String(a), VmValue::String(b)) => a == b,
431        _ => return Err(EngineError::TypeMismatch("Incompatible equality operands".to_string())),
432    };
433    stack.push(VmValue::Bool(if eq { result } else { !result }));
434    Ok(())
435}
436
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::ActionInstruction;
    use crate::EvaluationRequest;

    /// Rule source shared by the compiled-rule tests.
    const VOLUME_RULE: &str = r#"
RULE CRUE_001 VERSION 1.0
WHEN
    agent.requests_last_hour >= 50
THEN
    BLOCK WITH CODE "VOLUME_EXCEEDED"
"#;

    /// Parse and compile the shared volume rule.
    fn compile_volume_rule() -> Bytecode {
        let ast = crue_dsl::parser::parse(VOLUME_RULE).unwrap();
        crue_dsl::compiler::Compiler::compile(&ast).unwrap()
    }

    /// Request whose hourly volume (60) exceeds the rule threshold.
    fn busy_request() -> EvaluationRequest {
        EvaluationRequest {
            request_id: "req".into(),
            agent_id: "a".into(),
            agent_org: "o".into(),
            agent_level: "standard".into(),
            mission_id: None,
            mission_type: None,
            query_type: None,
            justification: Some("demo justification".into()),
            export_format: None,
            result_limit: Some(1),
            requests_last_hour: 60,
            requests_last_24h: 100,
            results_last_query: 1,
            account_department: None,
            allowed_departments: vec![],
            request_hour: 10,
            is_within_mission_hours: true,
        }
    }

    #[test]
    fn test_vm_eval_compiled_rule() {
        let bytecode = compile_volume_rule();
        let ctx = EvaluationContext::from_request(&busy_request());
        assert!(BytecodeVm::eval(&bytecode, &ctx).unwrap());
    }

    #[test]
    fn test_vm_eval_decision_emits_decision() {
        let bytecode = compile_volume_rule();

        let busy = busy_request();
        let busy_ctx = EvaluationContext::from_request(&busy);
        assert_eq!(
            BytecodeVm::eval_decision(&bytecode, &busy_ctx, Decision::Block, Decision::Allow)
                .unwrap(),
            Decision::Block
        );

        let quiet = EvaluationRequest {
            requests_last_hour: 1,
            ..busy_request()
        };
        let quiet_ctx = EvaluationContext::from_request(&quiet);
        assert_eq!(
            BytecodeVm::eval_decision(&bytecode, &quiet_ctx, Decision::Block, Decision::Allow)
                .unwrap(),
            Decision::Allow
        );
    }

    #[test]
    fn test_vm_explicit_jump_and_emit_program() {
        let bytecode = Bytecode {
            instructions: vec![],
            constants: vec![],
            fields: vec![],
            action_instructions: vec![],
        };
        let idle = EvaluationRequest {
            justification: None,
            result_limit: None,
            requests_last_hour: 0,
            requests_last_24h: 0,
            results_last_query: 0,
            request_hour: 0,
            ..busy_request()
        };
        let ctx = EvaluationContext::from_request(&idle);

        // LoadFalse makes the conditional jump skip the Block emit.
        let program = vec![
            Instruction::LoadFalse,
            Instruction::JumpIfFalse(3),
            Instruction::EmitDecision(Decision::Block),
            Instruction::EmitDecision(Decision::Allow),
        ];
        let exit = BytecodeVm::eval_program(&program, &bytecode, &ctx).unwrap();
        assert_eq!(exit, VmExit::Decision(Decision::Allow));
    }

    #[test]
    fn test_action_vm_exec_block_with_soc_alert() {
        let program = vec![
            ActionInstruction::SetDecision(Decision::Block),
            ActionInstruction::SetErrorCode("VOLUME_EXCEEDED".into()),
            ActionInstruction::SetMessage("Demo policy matched".into()),
            ActionInstruction::SetAlertSoc(true),
            ActionInstruction::Halt,
        ];
        let outcome = ActionVm::execute(&program).unwrap();
        assert_eq!(outcome.decision, Decision::Block);
        assert_eq!(outcome.error_code.as_deref(), Some("VOLUME_EXCEEDED"));
        assert_eq!(outcome.message.as_deref(), Some("Demo policy matched"));
        assert!(outcome.alert_soc);
    }

    #[test]
    fn test_action_vm_exec_approval_default_message() {
        let program = vec![
            ActionInstruction::SetDecision(Decision::ApprovalRequired),
            ActionInstruction::SetErrorCode("APPROVAL_REQUIRED".into()),
            ActionInstruction::SetApprovalTimeout(15),
            ActionInstruction::Halt,
        ];
        let outcome = ActionVm::execute(&program).unwrap();
        assert_eq!(outcome.decision, Decision::ApprovalRequired);
        assert_eq!(
            outcome.message.as_deref(),
            Some("Approval required within 15 minutes")
        );
    }
}