1use crate::context::{EvaluationContext, FieldValue};
4use crate::decision::{ActionResult, Decision};
5use crate::error::EngineError;
6use crate::ir::ActionInstruction;
7use crue_dsl::compiler::{Bytecode, Constant, Opcode};
8
/// A value living on the condition VM's operand stack.
///
/// Values come either from compiled constants or from evaluation-context
/// fields; there is no float variant, so float fields cannot be loaded.
#[derive(Debug, Clone, PartialEq)]
enum VmValue {
    /// Boolean operand or the result of a comparison / logical instruction.
    Bool(bool),
    /// Integer operand; the only type accepted by ordered comparisons.
    Number(i64),
    /// String operand; only usable with equality / inequality.
    String(String),
}
15
16#[derive(Debug, Clone, PartialEq, Eq)]
18pub enum Instruction {
19 LoadField(u16),
20 LoadConst(u32),
21 LoadTrue,
22 LoadFalse,
23 Gt,
24 Lt,
25 Gte,
26 Lte,
27 Eq,
28 Neq,
29 And,
30 Or,
31 Not,
32 JumpIfFalse(usize),
33 Jump(usize),
34 Ret,
35 EmitDecision(Decision),
36}
37
38#[derive(Debug, Clone, Copy, PartialEq, Eq)]
40pub enum VmExit {
41 Bool(bool),
42 Decision(Decision),
43}
44
/// Stateless stack-machine evaluator for compiled rule conditions.
pub struct BytecodeVm;

/// Stateless interpreter that folds action instructions into an action result.
pub struct ActionVm;
47
48impl BytecodeVm {
49 pub fn eval(bytecode: &Bytecode, ctx: &EvaluationContext) -> Result<bool, EngineError> {
51 let program = Self::decode(bytecode)?;
52 match Self::eval_program(&program, bytecode, ctx)? {
53 VmExit::Bool(v) => Ok(v),
54 VmExit::Decision(_) => Err(EngineError::EvaluationError(
55 "VM emitted decision in boolean eval path".to_string(),
56 )),
57 }
58 }
59
60 pub fn decode(bytecode: &Bytecode) -> Result<Vec<Instruction>, EngineError> {
62 let mut pc = 0usize;
63 let code = &bytecode.instructions;
64 let mut program = Vec::new();
65
66 while pc < code.len() {
67 let op = decode_opcode(code[pc])?;
68 pc += 1;
69 match op {
70 Opcode::LoadField => {
71 program.push(Instruction::LoadField(read_u16(code, &mut pc)?));
72 }
73 Opcode::LoadConst => {
74 program.push(Instruction::LoadConst(read_u32(code, &mut pc)?));
75 }
76 Opcode::LoadTrue => program.push(Instruction::LoadTrue),
77 Opcode::LoadFalse => program.push(Instruction::LoadFalse),
78 Opcode::Gt => program.push(Instruction::Gt),
79 Opcode::Lt => program.push(Instruction::Lt),
80 Opcode::Gte => program.push(Instruction::Gte),
81 Opcode::Lte => program.push(Instruction::Lte),
82 Opcode::Eq => program.push(Instruction::Eq),
83 Opcode::Neq => program.push(Instruction::Neq),
84 Opcode::And => program.push(Instruction::And),
85 Opcode::Or => program.push(Instruction::Or),
86 Opcode::Not => program.push(Instruction::Not),
87 Opcode::Ret => program.push(Instruction::Ret),
88 Opcode::Jmp | Opcode::JmpF => {
89 return Err(EngineError::EvaluationError(
90 "Raw jump opcodes not supported in decoded VM yet".to_string(),
91 ));
92 }
93 }
94 }
95
96 Ok(program)
97 }
98
99 pub fn eval_program(
101 program: &[Instruction],
102 bytecode: &Bytecode,
103 ctx: &EvaluationContext,
104 ) -> Result<VmExit, EngineError> {
105 let mut pc = 0usize;
106 let mut stack: Vec<VmValue> = Vec::new();
107
108 while pc < program.len() {
109 match &program[pc] {
110 Instruction::LoadField(idx) => {
111 let field = bytecode.fields.get(*idx as usize).ok_or_else(|| {
112 EngineError::EvaluationError("Invalid field index".to_string())
113 })?;
114 let value = ctx
115 .get_field(field)
116 .ok_or_else(|| EngineError::FieldNotFound(field.clone()))?;
117 stack.push(field_to_vm(value)?);
118 pc += 1;
119 }
120 Instruction::LoadConst(idx) => {
121 let c = bytecode.constants.get(*idx as usize).ok_or_else(|| {
122 EngineError::EvaluationError("Invalid constant index".to_string())
123 })?;
124 stack.push(constant_to_vm(c));
125 pc += 1;
126 }
127 Instruction::LoadTrue => {
128 stack.push(VmValue::Bool(true));
129 pc += 1;
130 }
131 Instruction::LoadFalse => {
132 stack.push(VmValue::Bool(false));
133 pc += 1;
134 }
135 Instruction::Gt => {
136 binary_compare(&mut stack, |a, b| a > b)?;
137 pc += 1;
138 }
139 Instruction::Lt => {
140 binary_compare(&mut stack, |a, b| a < b)?;
141 pc += 1;
142 }
143 Instruction::Gte => {
144 binary_compare(&mut stack, |a, b| a >= b)?;
145 pc += 1;
146 }
147 Instruction::Lte => {
148 binary_compare(&mut stack, |a, b| a <= b)?;
149 pc += 1;
150 }
151 Instruction::Eq => {
152 binary_eq(&mut stack, true)?;
153 pc += 1;
154 }
155 Instruction::Neq => {
156 binary_eq(&mut stack, false)?;
157 pc += 1;
158 }
159 Instruction::And => {
160 binary_bool(&mut stack, |a, b| a && b)?;
161 pc += 1;
162 }
163 Instruction::Or => {
164 binary_bool(&mut stack, |a, b| a || b)?;
165 pc += 1;
166 }
167 Instruction::Not => {
168 let v = pop_bool(&mut stack)?;
169 stack.push(VmValue::Bool(!v));
170 pc += 1;
171 }
172 Instruction::JumpIfFalse(target) => {
173 let cond = pop_bool(&mut stack)?;
174 if !cond {
175 ensure_target(*target, program.len())?;
176 pc = *target;
177 } else {
178 pc += 1;
179 }
180 }
181 Instruction::Jump(target) => {
182 ensure_target(*target, program.len())?;
183 pc = *target;
184 }
185 Instruction::Ret => {
186 return Ok(VmExit::Bool(pop_bool(&mut stack)?));
187 }
188 Instruction::EmitDecision(decision) => {
189 return Ok(VmExit::Decision(*decision));
190 }
191 }
192 }
193
194 Err(EngineError::EvaluationError(
195 "VM program terminated without RET/EmitDecision".to_string(),
196 ))
197 }
198
199 pub fn eval_decision(
201 bytecode: &Bytecode,
202 ctx: &EvaluationContext,
203 on_true: Decision,
204 on_false: Decision,
205 ) -> Result<Decision, EngineError> {
206 let mut program = Self::decode(bytecode)?;
207 if !matches!(program.last(), Some(Instruction::Ret)) {
208 return Err(EngineError::EvaluationError(
209 "Bytecode terminated without RET".to_string(),
210 ));
211 }
212 program.pop();
213
214 let false_target = program.len() + 2;
215 program.push(Instruction::JumpIfFalse(false_target));
216 program.push(Instruction::EmitDecision(on_true));
217 program.push(Instruction::EmitDecision(on_false));
218
219 match Self::eval_program(&program, bytecode, ctx)? {
220 VmExit::Decision(d) => Ok(d),
221 VmExit::Bool(_) => Err(EngineError::EvaluationError(
222 "VM returned bool in decision eval path".to_string(),
223 )),
224 }
225 }
226
227 pub fn build_match_program(
230 bytecode: &Bytecode,
231 on_match: Decision,
232 ) -> Result<Vec<Instruction>, EngineError> {
233 let mut program = Self::decode(bytecode)?;
234 if !matches!(program.last(), Some(Instruction::Ret)) {
235 return Err(EngineError::EvaluationError(
236 "Bytecode terminated without RET".to_string(),
237 ));
238 }
239 program.pop();
240
241 let false_target = program.len() + 2;
245 program.push(Instruction::JumpIfFalse(false_target));
246 program.push(Instruction::EmitDecision(on_match));
247 program.push(Instruction::LoadFalse);
248 program.push(Instruction::Ret);
249 Ok(program)
250 }
251
252 pub fn eval_match_program(
256 program: &[Instruction],
257 bytecode: &Bytecode,
258 ctx: &EvaluationContext,
259 ) -> Result<Option<Decision>, EngineError> {
260 match Self::eval_program(program, bytecode, ctx)? {
261 VmExit::Decision(d) => Ok(Some(d)),
262 VmExit::Bool(false) => Ok(None),
263 VmExit::Bool(true) => Err(EngineError::EvaluationError(
264 "VM match program returned unexpected true boolean".to_string(),
265 )),
266 }
267 }
268}
269
270impl ActionVm {
271 pub fn execute(program: &[ActionInstruction]) -> Result<ActionResult, EngineError> {
273 let mut decision = Decision::Allow;
274 let mut error_code: Option<String> = None;
275 let mut message: Option<String> = None;
276 let mut approval_timeout: Option<u32> = None;
277 let mut alert_soc = false;
278
279 for insn in program {
280 match insn {
281 ActionInstruction::SetDecision(d) => decision = *d,
282 ActionInstruction::SetErrorCode(code) => error_code = Some(code.clone()),
283 ActionInstruction::SetMessage(msg) => message = Some(msg.clone()),
284 ActionInstruction::SetApprovalTimeout(timeout) => approval_timeout = Some(*timeout),
285 ActionInstruction::SetAlertSoc(v) => alert_soc = *v,
286 ActionInstruction::Halt => break,
287 }
288 }
289
290 let final_message = match decision {
291 Decision::ApprovalRequired => {
292 if let Some(m) = message {
293 Some(m)
294 } else {
295 Some(format!(
296 "Approval required within {} minutes",
297 approval_timeout.unwrap_or(30)
298 ))
299 }
300 }
301 _ => message,
302 };
303
304 Ok(ActionResult {
305 decision,
306 error_code,
307 message: final_message,
308 alert_soc,
309 })
310 }
311}
312
313fn ensure_target(target: usize, len: usize) -> Result<(), EngineError> {
314 if target >= len {
315 return Err(EngineError::EvaluationError(format!(
316 "Invalid jump target {} (program len {})",
317 target, len
318 )));
319 }
320 Ok(())
321}
322
323fn decode_opcode(byte: u8) -> Result<Opcode, EngineError> {
324 let op = match byte {
325 0x01 => Opcode::LoadField,
326 0x02 => Opcode::LoadConst,
327 0x03 => Opcode::LoadTrue,
328 0x04 => Opcode::LoadFalse,
329 0x10 => Opcode::Gt,
330 0x11 => Opcode::Lt,
331 0x12 => Opcode::Gte,
332 0x13 => Opcode::Lte,
333 0x14 => Opcode::Eq,
334 0x15 => Opcode::Neq,
335 0x20 => Opcode::And,
336 0x21 => Opcode::Or,
337 0x22 => Opcode::Not,
338 0x30 => Opcode::JmpF,
339 0x31 => Opcode::Jmp,
340 0xFF => Opcode::Ret,
341 _ => {
342 return Err(EngineError::EvaluationError(format!(
343 "Unknown opcode 0x{byte:02x}"
344 )))
345 }
346 };
347 Ok(op)
348}
349
350fn read_u16(code: &[u8], pc: &mut usize) -> Result<u16, EngineError> {
351 if *pc + 2 > code.len() {
352 return Err(EngineError::EvaluationError("Truncated u16 operand".to_string()));
353 }
354 let v = u16::from_be_bytes([code[*pc], code[*pc + 1]]);
355 *pc += 2;
356 Ok(v)
357}
358
359fn read_u32(code: &[u8], pc: &mut usize) -> Result<u32, EngineError> {
360 if *pc + 4 > code.len() {
361 return Err(EngineError::EvaluationError("Truncated u32 operand".to_string()));
362 }
363 let v = u32::from_be_bytes([code[*pc], code[*pc + 1], code[*pc + 2], code[*pc + 3]]);
364 *pc += 4;
365 Ok(v)
366}
367
368fn constant_to_vm(c: &Constant) -> VmValue {
369 match c {
370 Constant::Number(n) => VmValue::Number(*n),
371 Constant::String(s) => VmValue::String(s.clone()),
372 Constant::Boolean(b) => VmValue::Bool(*b),
373 }
374}
375
376fn field_to_vm(v: &FieldValue) -> Result<VmValue, EngineError> {
377 match v {
378 FieldValue::Number(n) => Ok(VmValue::Number(*n)),
379 FieldValue::String(s) => Ok(VmValue::String(s.clone())),
380 FieldValue::Boolean(b) => Ok(VmValue::Bool(*b)),
381 FieldValue::Float(_) => Err(EngineError::TypeMismatch("float field unsupported in VM".into())),
382 }
383}
384
385fn pop(stack: &mut Vec<VmValue>) -> Result<VmValue, EngineError> {
386 stack.pop()
387 .ok_or_else(|| EngineError::EvaluationError("VM stack underflow".to_string()))
388}
389
390fn pop_bool(stack: &mut Vec<VmValue>) -> Result<bool, EngineError> {
391 match pop(stack)? {
392 VmValue::Bool(v) => Ok(v),
393 _ => Err(EngineError::TypeMismatch("Expected bool".to_string())),
394 }
395}
396
397fn pop_number(stack: &mut Vec<VmValue>) -> Result<i64, EngineError> {
398 match pop(stack)? {
399 VmValue::Number(v) => Ok(v),
400 _ => Err(EngineError::TypeMismatch("Expected number".to_string())),
401 }
402}
403
404fn binary_compare(
405 stack: &mut Vec<VmValue>,
406 cmp: impl Fn(i64, i64) -> bool,
407) -> Result<(), EngineError> {
408 let right = pop_number(stack)?;
409 let left = pop_number(stack)?;
410 stack.push(VmValue::Bool(cmp(left, right)));
411 Ok(())
412}
413
414fn binary_bool(
415 stack: &mut Vec<VmValue>,
416 op: impl Fn(bool, bool) -> bool,
417) -> Result<(), EngineError> {
418 let right = pop_bool(stack)?;
419 let left = pop_bool(stack)?;
420 stack.push(VmValue::Bool(op(left, right)));
421 Ok(())
422}
423
424fn binary_eq(stack: &mut Vec<VmValue>, eq: bool) -> Result<(), EngineError> {
425 let right = pop(stack)?;
426 let left = pop(stack)?;
427 let result = match (left, right) {
428 (VmValue::Bool(a), VmValue::Bool(b)) => a == b,
429 (VmValue::Number(a), VmValue::Number(b)) => a == b,
430 (VmValue::String(a), VmValue::String(b)) => a == b,
431 _ => return Err(EngineError::TypeMismatch("Incompatible equality operands".to_string())),
432 };
433 stack.push(VmValue::Bool(if eq { result } else { !result }));
434 Ok(())
435}
436
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::ActionInstruction;
    use crate::EvaluationRequest;

    /// Rule used by the compiled-bytecode tests: matches when the agent issued
    /// at least 50 requests in the last hour.
    const VOLUME_RULE: &str = r#"
RULE CRUE_001 VERSION 1.0
WHEN
    agent.requests_last_hour >= 50
THEN
    BLOCK WITH CODE "VOLUME_EXCEEDED"
"#;

    /// Parses and compiles [`VOLUME_RULE`].
    fn compile_volume_rule() -> Bytecode {
        let ast = crue_dsl::parser::parse(VOLUME_RULE).unwrap();
        crue_dsl::compiler::Compiler::compile(&ast).unwrap()
    }

    /// Bytecode with no instructions, constants or fields, for hand-built programs.
    fn empty_bytecode() -> Bytecode {
        Bytecode {
            instructions: vec![],
            constants: vec![],
            fields: vec![],
            action_instructions: vec![],
        }
    }

    /// A request that trips the volume rule (`requests_last_hour` = 60).
    fn sample_request() -> EvaluationRequest {
        EvaluationRequest {
            request_id: "req".into(),
            agent_id: "a".into(),
            agent_org: "o".into(),
            agent_level: "standard".into(),
            mission_id: None,
            mission_type: None,
            query_type: None,
            justification: Some("demo justification".into()),
            export_format: None,
            result_limit: Some(1),
            requests_last_hour: 60,
            requests_last_24h: 100,
            results_last_query: 1,
            account_department: None,
            allowed_departments: vec![],
            request_hour: 10,
            is_within_mission_hours: true,
        }
    }

    #[test]
    fn test_vm_eval_compiled_rule() {
        let bytecode = compile_volume_rule();
        let mut req = sample_request();
        let ctx = EvaluationContext::from_request(&req);
        assert!(BytecodeVm::eval(&bytecode, &ctx).unwrap());

        req.requests_last_hour = 1;
        let ctx2 = EvaluationContext::from_request(&req);
        assert!(!BytecodeVm::eval(&bytecode, &ctx2).unwrap());
    }

    #[test]
    fn test_vm_eval_decision_emits_decision() {
        let bytecode = compile_volume_rule();
        let mut req = sample_request();
        let ctx = EvaluationContext::from_request(&req);
        assert_eq!(
            BytecodeVm::eval_decision(&bytecode, &ctx, Decision::Block, Decision::Allow).unwrap(),
            Decision::Block
        );

        req.requests_last_hour = 1;
        let ctx2 = EvaluationContext::from_request(&req);
        assert_eq!(
            BytecodeVm::eval_decision(&bytecode, &ctx2, Decision::Block, Decision::Allow).unwrap(),
            Decision::Allow
        );
    }

    #[test]
    fn test_vm_eval_decision_requires_trailing_ret() {
        let mut bytecode = empty_bytecode();
        bytecode.instructions = vec![0x03]; // LoadTrue without RET
        let req = sample_request();
        let ctx = EvaluationContext::from_request(&req);
        assert!(
            BytecodeVm::eval_decision(&bytecode, &ctx, Decision::Block, Decision::Allow).is_err()
        );
    }

    #[test]
    fn test_vm_match_program_emits_only_on_match() {
        let bytecode = compile_volume_rule();
        let program = BytecodeVm::build_match_program(&bytecode, Decision::Block).unwrap();

        let mut req = sample_request();
        let ctx = EvaluationContext::from_request(&req);
        assert_eq!(
            BytecodeVm::eval_match_program(&program, &bytecode, &ctx).unwrap(),
            Some(Decision::Block)
        );

        req.requests_last_hour = 1;
        let ctx2 = EvaluationContext::from_request(&req);
        assert_eq!(
            BytecodeVm::eval_match_program(&program, &bytecode, &ctx2).unwrap(),
            None
        );
    }

    #[test]
    fn test_vm_decode_rejects_unknown_and_truncated_bytecode() {
        let mut bytecode = empty_bytecode();
        bytecode.instructions = vec![0x99];
        assert!(BytecodeVm::decode(&bytecode).is_err());

        // LoadField needs a two-byte operand; only one is present.
        bytecode.instructions = vec![0x01, 0x00];
        assert!(BytecodeVm::decode(&bytecode).is_err());
    }

    #[test]
    fn test_vm_explicit_jump_and_emit_program() {
        let bytecode = empty_bytecode();
        let req = sample_request();
        let ctx = EvaluationContext::from_request(&req);
        let program = vec![
            Instruction::LoadFalse,
            Instruction::JumpIfFalse(3),
            Instruction::EmitDecision(Decision::Block),
            Instruction::EmitDecision(Decision::Allow),
        ];
        assert_eq!(
            BytecodeVm::eval_program(&program, &bytecode, &ctx).unwrap(),
            VmExit::Decision(Decision::Allow)
        );
    }

    #[test]
    fn test_vm_rejects_bad_jump_and_missing_terminator() {
        let bytecode = empty_bytecode();
        let req = sample_request();
        let ctx = EvaluationContext::from_request(&req);
        assert!(BytecodeVm::eval_program(&[Instruction::Jump(7)], &bytecode, &ctx).is_err());
        assert!(BytecodeVm::eval_program(&[Instruction::LoadTrue], &bytecode, &ctx).is_err());
    }

    #[test]
    fn test_action_vm_exec_block_with_soc_alert() {
        let program = vec![
            ActionInstruction::SetDecision(Decision::Block),
            ActionInstruction::SetErrorCode("VOLUME_EXCEEDED".into()),
            ActionInstruction::SetMessage("Demo policy matched".into()),
            ActionInstruction::SetAlertSoc(true),
            ActionInstruction::Halt,
        ];
        let result = ActionVm::execute(&program).unwrap();
        assert_eq!(result.decision, Decision::Block);
        assert_eq!(result.error_code.as_deref(), Some("VOLUME_EXCEEDED"));
        assert_eq!(result.message.as_deref(), Some("Demo policy matched"));
        assert!(result.alert_soc);
    }

    #[test]
    fn test_action_vm_halt_stops_execution() {
        let program = vec![
            ActionInstruction::SetDecision(Decision::Block),
            ActionInstruction::Halt,
            ActionInstruction::SetDecision(Decision::Allow),
        ];
        let result = ActionVm::execute(&program).unwrap();
        assert_eq!(result.decision, Decision::Block);
        assert_eq!(result.message, None);
    }

    #[test]
    fn test_action_vm_exec_approval_default_message() {
        let program = vec![
            ActionInstruction::SetDecision(Decision::ApprovalRequired),
            ActionInstruction::SetErrorCode("APPROVAL_REQUIRED".into()),
            ActionInstruction::SetApprovalTimeout(15),
            ActionInstruction::Halt,
        ];
        let result = ActionVm::execute(&program).unwrap();
        assert_eq!(result.decision, Decision::ApprovalRequired);
        assert_eq!(
            result.message.as_deref(),
            Some("Approval required within 15 minutes")
        );
    }

    #[test]
    fn test_action_vm_approval_default_timeout() {
        let program = vec![ActionInstruction::SetDecision(Decision::ApprovalRequired)];
        let result = ActionVm::execute(&program).unwrap();
        assert_eq!(
            result.message.as_deref(),
            Some("Approval required within 30 minutes")
        );
    }
}