1use super::decision::{reason_codes, DecisionEmitter, DecisionEmitterGuard, DecisionEvent};
8use super::identity::ToolIdentity;
9use super::jsonrpc::JsonRpcRequest;
10use super::lifecycle::{mandate_used_event, LifecycleEmitter};
11use super::policy::{McpPolicy, PolicyDecision, PolicyState};
12use crate::runtime::{Authorizer, AuthzReceipt, MandateData, OperationClass, ToolCallData};
13use serde_json::Value;
14use std::sync::Arc;
15use std::time::Instant;
16
/// Outcome of [`ToolCallHandler::handle_tool_call`].
///
/// Every variant carries the [`DecisionEvent`] describing the decision that was made.
#[derive(Debug)]
pub enum HandleResult {
    /// The tool call may proceed.
    Allow {
        /// Receipt of the mandate use: `Some` when the call was authorized against a mandate,
        /// `None` when it was allowed by policy alone.
        receipt: Option<AuthzReceipt>,
        /// Decision event describing the allow.
        decision_event: DecisionEvent,
    },
    /// The tool call was refused: by policy, because a commit tool lacked a mandate, or because
    /// mandate authorization failed.
    Deny {
        /// Machine-readable reason code (a `reason_codes` constant).
        reason_code: String,
        /// Human-readable explanation of the denial.
        reason: String,
        /// Decision event describing the denial.
        decision_event: DecisionEvent,
    },
    /// The request could not be evaluated; `handle_tool_call` returns this when the request
    /// carries no tool-call parameters.
    Error {
        /// Machine-readable reason code (a `reason_codes` constant).
        reason_code: String,
        /// Human-readable explanation of the error.
        reason: String,
        /// Decision event describing the error.
        decision_event: DecisionEvent,
    },
}
38
/// Settings for [`ToolCallHandler`].
///
/// Entries in `commit_tools` and `write_tools` are tool-name patterns. A pattern ending in `*`
/// matches by prefix (so `"*"` alone matches every tool); any other pattern must equal the
/// tool name exactly.
#[derive(Clone, Debug)]
pub struct ToolCallHandlerConfig {
    /// Source identifier stamped on every decision event.
    pub event_source: String,
    /// When `true`, calls to commit tools that arrive without a mandate are denied.
    pub require_mandate_for_commit: bool,
    /// Patterns for tools classified as commit operations.
    pub commit_tools: Vec<String>,
    /// Patterns for tools classified as write operations (commit patterns win on overlap).
    pub write_tools: Vec<String>,
}

impl Default for ToolCallHandlerConfig {
    /// Unknown event source, mandates required for commit tools, no commit or write tools.
    fn default() -> Self {
        Self {
            event_source: "assay://unknown".to_string(),
            require_mandate_for_commit: true,
            commit_tools: Vec::new(),
            write_tools: Vec::new(),
        }
    }
}
62
/// Evaluates MCP tool calls against a policy and, optionally, mandate authorization, emitting
/// a decision event for each call.
pub struct ToolCallHandler {
    /// Policy evaluated first for every tool call.
    policy: McpPolicy,
    /// Mandate authorizer; a supplied mandate is only checked when this is `Some`.
    authorizer: Option<Authorizer>,
    /// Sink for decision events.
    emitter: Arc<dyn DecisionEmitter>,
    /// Optional sink for mandate lifecycle events (set via `with_lifecycle_emitter`).
    lifecycle_emitter: Option<Arc<dyn LifecycleEmitter>>,
    /// Event source and commit/write tool classification settings.
    config: ToolCallHandlerConfig,
}
72
73impl ToolCallHandler {
74 pub fn new(
76 policy: McpPolicy,
77 authorizer: Option<Authorizer>,
78 emitter: Arc<dyn DecisionEmitter>,
79 config: ToolCallHandlerConfig,
80 ) -> Self {
81 Self {
82 policy,
83 authorizer,
84 emitter,
85 lifecycle_emitter: None,
86 config,
87 }
88 }
89
90 pub fn with_lifecycle_emitter(mut self, emitter: Arc<dyn LifecycleEmitter>) -> Self {
92 self.lifecycle_emitter = Some(emitter);
93 self
94 }
95
96 pub fn handle_tool_call(
101 &self,
102 request: &JsonRpcRequest,
103 state: &mut PolicyState,
104 runtime_identity: Option<&ToolIdentity>,
105 mandate: Option<&MandateData>,
106 transaction_object: Option<&Value>,
107 ) -> HandleResult {
108 let params = match request.tool_params() {
109 Some(p) => p,
110 None => {
111 let tool_call_id = self.extract_tool_call_id(request);
113 let guard = DecisionEmitterGuard::new(
114 self.emitter.clone(),
115 self.config.event_source.clone(),
116 tool_call_id.clone(),
117 "unknown".to_string(),
118 );
119 guard.emit_error(
120 reason_codes::S_INTERNAL_ERROR,
121 Some("Not a tool call".to_string()),
122 );
123
124 return HandleResult::Error {
125 reason_code: reason_codes::S_INTERNAL_ERROR.to_string(),
126 reason: "Not a tool call".to_string(),
127 decision_event: DecisionEvent::new(
128 self.config.event_source.clone(),
129 tool_call_id,
130 "unknown".to_string(),
131 )
132 .error(
133 reason_codes::S_INTERNAL_ERROR,
134 Some("Not a tool call".to_string()),
135 ),
136 };
137 }
138 };
139
140 let tool_name = params.name.clone();
141 let tool_call_id = self.extract_tool_call_id(request);
142
143 let mut guard = DecisionEmitterGuard::new(
145 self.emitter.clone(),
146 self.config.event_source.clone(),
147 tool_call_id.clone(),
148 tool_name.clone(),
149 );
150 guard.set_request_id(request.id.clone());
151
152 let start = Instant::now();
153
154 let policy_eval = self.policy.evaluate_with_metadata(
156 &tool_name,
157 ¶ms.arguments,
158 state,
159 runtime_identity,
160 );
161 let tool_classes = policy_eval.metadata.tool_classes.clone();
162 let matched_tool_classes = policy_eval.metadata.matched_tool_classes.clone();
163 let match_basis = policy_eval
164 .metadata
165 .match_basis
166 .as_str()
167 .map(ToString::to_string);
168 let matched_rule = policy_eval.metadata.matched_rule.clone();
169 guard.set_tool_match(
170 policy_eval.metadata.tool_classes.clone(),
171 policy_eval.metadata.matched_tool_classes.clone(),
172 match_basis.clone(),
173 matched_rule.clone(),
174 );
175
176 match policy_eval.decision {
177 PolicyDecision::Deny {
178 tool: _,
179 code,
180 reason,
181 contract: _,
182 } => {
183 let reason_code = self.map_policy_code_to_reason(&code);
184 guard.emit_deny(&reason_code, Some(reason.clone()));
185
186 return HandleResult::Deny {
187 reason_code: reason_code.clone(),
188 reason: reason.clone(),
189 decision_event: DecisionEvent::new(
190 self.config.event_source.clone(),
191 tool_call_id,
192 tool_name,
193 )
194 .deny(&reason_code, Some(reason))
195 .with_tool_match(
196 tool_classes.clone(),
197 matched_tool_classes.clone(),
198 match_basis.clone(),
199 matched_rule.clone(),
200 ),
201 };
202 }
203 PolicyDecision::AllowWithWarning { .. } | PolicyDecision::Allow => {
204 }
206 }
207
208 let is_commit_tool = self.is_commit_tool(&tool_name);
210 if is_commit_tool && self.config.require_mandate_for_commit && mandate.is_none() {
211 guard.emit_deny(
212 reason_codes::P_MANDATE_REQUIRED,
213 Some("Commit tool requires mandate authorization".to_string()),
214 );
215
216 return HandleResult::Deny {
217 reason_code: reason_codes::P_MANDATE_REQUIRED.to_string(),
218 reason: "Commit tool requires mandate authorization".to_string(),
219 decision_event: DecisionEvent::new(
220 self.config.event_source.clone(),
221 tool_call_id,
222 tool_name,
223 )
224 .deny(
225 reason_codes::P_MANDATE_REQUIRED,
226 Some("Commit tool requires mandate authorization".to_string()),
227 )
228 .with_tool_match(
229 tool_classes.clone(),
230 matched_tool_classes.clone(),
231 match_basis.clone(),
232 matched_rule.clone(),
233 ),
234 };
235 }
236
237 if let (Some(authorizer), Some(mandate_data)) = (&self.authorizer, mandate) {
239 let operation_class = self.operation_class_for_tool(&tool_name);
240
241 let tool_call_data = ToolCallData {
242 tool_name: tool_name.clone(),
243 tool_call_id: tool_call_id.clone(),
244 operation_class,
245 transaction_object: transaction_object.cloned(),
246 source_run_id: None,
247 };
248
249 let authz_start = Instant::now();
250 match authorizer.authorize_and_consume(mandate_data, &tool_call_data) {
251 Ok(receipt) => {
252 let authz_ms = authz_start.elapsed().as_millis() as u64;
253 guard.set_mandate_info(
254 Some(mandate_data.mandate_id.clone()),
255 Some(receipt.use_id.clone()),
256 Some(receipt.use_count),
257 );
258 guard.set_mandate_matches(
259 Some(true),
260 Some(true),
261 transaction_object.map(|_| true),
262 );
263 guard.set_latencies(Some(authz_ms), None);
264 guard.emit_allow(reason_codes::P_MANDATE_VALID);
265
266 if receipt.was_new {
269 if let Some(ref lifecycle) = self.lifecycle_emitter {
270 let event = mandate_used_event(&self.config.event_source, &receipt);
271 lifecycle.emit(&event);
272 }
273 }
274
275 return HandleResult::Allow {
276 receipt: Some(receipt),
277 decision_event: DecisionEvent::new(
278 self.config.event_source.clone(),
279 tool_call_id,
280 tool_name,
281 )
282 .allow(reason_codes::P_MANDATE_VALID)
283 .with_tool_match(
284 tool_classes.clone(),
285 matched_tool_classes.clone(),
286 match_basis.clone(),
287 matched_rule.clone(),
288 ),
289 };
290 }
291 Err(e) => {
292 let (reason_code, reason) = self.map_authz_error(&e);
293 guard.set_mandate_info(Some(mandate_data.mandate_id.clone()), None, None);
294 guard.emit_deny(&reason_code, Some(reason.clone()));
295
296 return HandleResult::Deny {
297 reason_code: reason_code.clone(),
298 reason: reason.clone(),
299 decision_event: DecisionEvent::new(
300 self.config.event_source.clone(),
301 tool_call_id,
302 tool_name,
303 )
304 .deny(&reason_code, Some(reason))
305 .with_tool_match(
306 tool_classes.clone(),
307 matched_tool_classes.clone(),
308 match_basis.clone(),
309 matched_rule.clone(),
310 ),
311 };
312 }
313 }
314 }
315
316 let elapsed_ms = start.elapsed().as_millis() as u64;
318 guard.set_latencies(Some(elapsed_ms), None);
319 guard.emit_allow(reason_codes::P_POLICY_ALLOW);
320
321 HandleResult::Allow {
322 receipt: None,
323 decision_event: DecisionEvent::new(
324 self.config.event_source.clone(),
325 tool_call_id,
326 tool_name,
327 )
328 .allow(reason_codes::P_POLICY_ALLOW)
329 .with_tool_match(
330 tool_classes,
331 matched_tool_classes,
332 match_basis,
333 matched_rule,
334 ),
335 }
336 }
337
338 fn extract_tool_call_id(&self, request: &JsonRpcRequest) -> String {
340 if let Some(params) = request.tool_params() {
342 if let Some(meta) = params.arguments.get("_meta") {
343 if let Some(id) = meta.get("tool_call_id").and_then(|v| v.as_str()) {
344 return id.to_string();
345 }
346 }
347 }
348
349 if let Some(id) = &request.id {
351 if let Some(s) = id.as_str() {
352 return format!("req_{}", s);
353 }
354 if let Some(n) = id.as_i64() {
355 return format!("req_{}", n);
356 }
357 }
358
359 format!("gen_{}", uuid::Uuid::new_v4())
361 }
362
363 fn is_commit_tool(&self, tool_name: &str) -> bool {
365 self.config.commit_tools.iter().any(|pattern| {
366 if pattern == "*" {
367 return true;
368 }
369 if pattern.ends_with('*') {
370 let prefix = pattern.trim_end_matches('*');
371 tool_name.starts_with(prefix)
372 } else {
373 tool_name == pattern
374 }
375 })
376 }
377
378 fn is_write_tool(&self, tool_name: &str) -> bool {
380 self.config.write_tools.iter().any(|pattern| {
381 if pattern == "*" {
382 return true;
383 }
384 if pattern.ends_with('*') {
385 let prefix = pattern.trim_end_matches('*');
386 tool_name.starts_with(prefix)
387 } else {
388 tool_name == pattern
389 }
390 })
391 }
392
393 fn operation_class_for_tool(&self, tool_name: &str) -> OperationClass {
395 if self.is_commit_tool(tool_name) {
396 OperationClass::Commit
397 } else if self.is_write_tool(tool_name) {
398 OperationClass::Write
399 } else {
400 OperationClass::Read
401 }
402 }
403
404 fn map_policy_code_to_reason(&self, code: &str) -> String {
406 match code {
407 "E_TOOL_DENIED" => reason_codes::P_TOOL_DENIED.to_string(),
408 "E_TOOL_NOT_ALLOWED" => reason_codes::P_TOOL_NOT_ALLOWED.to_string(),
409 "E_ARG_SCHEMA" => reason_codes::P_ARG_SCHEMA.to_string(),
410 "E_RATE_LIMIT" => reason_codes::P_RATE_LIMIT.to_string(),
411 "E_TOOL_DRIFT" => reason_codes::P_TOOL_DRIFT.to_string(),
412 _ => reason_codes::P_POLICY_DENY.to_string(),
413 }
414 }
415
416 fn map_authz_error(&self, error: &crate::runtime::AuthorizeError) -> (String, String) {
418 use crate::runtime::AuthorizeError;
419
420 match error {
421 AuthorizeError::Policy(pe) => {
422 use crate::runtime::PolicyError;
423 match pe {
424 PolicyError::Expired { .. } => (
425 reason_codes::M_EXPIRED.to_string(),
426 "Mandate expired".to_string(),
427 ),
428 PolicyError::NotYetValid { .. } => (
429 reason_codes::M_NOT_YET_VALID.to_string(),
430 "Mandate not yet valid".to_string(),
431 ),
432 PolicyError::ToolNotInScope { tool } => (
433 reason_codes::M_TOOL_NOT_IN_SCOPE.to_string(),
434 format!("Tool '{}' not in mandate scope", tool),
435 ),
436 PolicyError::KindMismatch { kind, op_class } => (
437 reason_codes::M_KIND_MISMATCH.to_string(),
438 format!(
439 "Mandate kind '{}' does not allow operation class '{}'",
440 kind, op_class
441 ),
442 ),
443 PolicyError::AudienceMismatch { expected, actual } => (
444 reason_codes::M_AUDIENCE_MISMATCH.to_string(),
445 format!(
446 "Audience mismatch: expected '{}', got '{}'",
447 expected, actual
448 ),
449 ),
450 PolicyError::IssuerNotTrusted { issuer } => (
451 reason_codes::M_ISSUER_NOT_TRUSTED.to_string(),
452 format!("Issuer '{}' not in trusted list", issuer),
453 ),
454 PolicyError::MissingTransactionObject => (
455 reason_codes::M_TRANSACTION_REF_MISMATCH.to_string(),
456 "Transaction object required but not provided".to_string(),
457 ),
458 PolicyError::TransactionRefMismatch { expected, actual } => (
459 reason_codes::M_TRANSACTION_REF_MISMATCH.to_string(),
460 format!(
461 "Transaction ref mismatch: expected '{}', computed '{}'",
462 expected, actual
463 ),
464 ),
465 }
466 }
467 AuthorizeError::Store(se) => {
468 use crate::runtime::AuthzError;
469 match se {
470 AuthzError::AlreadyUsed => (
471 reason_codes::M_ALREADY_USED.to_string(),
472 "Single-use mandate already consumed".to_string(),
473 ),
474 AuthzError::MaxUsesExceeded { max, current } => (
475 reason_codes::M_MAX_USES_EXCEEDED.to_string(),
476 format!("Max uses exceeded: {} of {} used", current, max),
477 ),
478 AuthzError::NonceReplay { nonce } => (
479 reason_codes::M_NONCE_REPLAY.to_string(),
480 format!("Nonce replay detected: {}", nonce),
481 ),
482 AuthzError::MandateNotFound { mandate_id } => (
483 reason_codes::M_NOT_FOUND.to_string(),
484 format!("Mandate not found: {}", mandate_id),
485 ),
486 AuthzError::Revoked { revoked_at } => (
487 reason_codes::M_REVOKED.to_string(),
488 format!("Mandate revoked at {}", revoked_at),
489 ),
490 AuthzError::MandateConflict { .. }
491 | AuthzError::InvalidConstraints { .. }
492 | AuthzError::Database(_) => (
493 reason_codes::S_DB_ERROR.to_string(),
494 format!("Database error: {}", se),
495 ),
496 }
497 }
498 AuthorizeError::TransactionRef(msg) => (
499 reason_codes::M_TRANSACTION_REF_MISMATCH.to_string(),
500 format!("Transaction ref error: {}", msg),
501 ),
502 }
503 }
504}
505
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mcp::decision::NullDecisionEmitter;
    use crate::mcp::lifecycle::{LifecycleEmitter, LifecycleEvent};
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Decision emitter that only counts how many events it receives.
    struct CountingEmitter(AtomicUsize);

    impl DecisionEmitter for CountingEmitter {
        fn emit(&self, _event: &DecisionEvent) {
            self.0.fetch_add(1, Ordering::SeqCst);
        }
    }

    /// Lifecycle emitter that counts events and keeps a copy of each one.
    struct CountingLifecycleEmitter(AtomicUsize, std::sync::Mutex<Vec<LifecycleEvent>>);

    impl CountingLifecycleEmitter {
        fn new() -> Self {
            Self(AtomicUsize::new(0), std::sync::Mutex::new(Vec::new()))
        }
    }

    impl LifecycleEmitter for CountingLifecycleEmitter {
        fn emit(&self, event: &LifecycleEvent) {
            self.0.fetch_add(1, Ordering::SeqCst);
            if let Ok(mut events) = self.1.lock() {
                events.push(event.clone());
            }
        }
    }

    /// Builds a `tools/call` request with JSON-RPC id 1.
    fn make_tool_call_request(tool: &str, args: Value) -> JsonRpcRequest {
        JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            id: Some(Value::Number(1.into())),
            method: "tools/call".to_string(),
            params: serde_json::json!({
                "name": tool,
                "arguments": args
            }),
        }
    }

    #[test]
    fn test_handler_emits_decision_on_policy_deny() {
        let emitter = Arc::new(CountingEmitter(AtomicUsize::new(0)));
        let policy = McpPolicy {
            tools: super::super::policy::ToolPolicy {
                allow: None,
                deny: Some(vec!["dangerous_*".to_string()]),
                ..Default::default()
            },
            ..Default::default()
        };

        let handler = ToolCallHandler::new(
            policy,
            None,
            emitter.clone(),
            ToolCallHandlerConfig::default(),
        );

        let request = make_tool_call_request("dangerous_tool", serde_json::json!({}));
        let mut state = PolicyState::default();

        let result = handler.handle_tool_call(&request, &mut state, None, None, None);

        assert!(matches!(result, HandleResult::Deny { .. }));
        assert_eq!(emitter.0.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_handler_emits_decision_on_policy_allow() {
        let emitter = Arc::new(CountingEmitter(AtomicUsize::new(0)));
        let policy = McpPolicy::default();

        let handler = ToolCallHandler::new(
            policy,
            None,
            emitter.clone(),
            ToolCallHandlerConfig::default(),
        );

        let request = make_tool_call_request("safe_tool", serde_json::json!({}));
        let mut state = PolicyState::default();

        let result = handler.handle_tool_call(&request, &mut state, None, None, None);

        assert!(matches!(result, HandleResult::Allow { .. }));
        assert_eq!(emitter.0.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_commit_tool_without_mandate_denied() {
        let emitter = Arc::new(CountingEmitter(AtomicUsize::new(0)));
        let policy = McpPolicy::default();

        let config = ToolCallHandlerConfig {
            event_source: "assay://test".to_string(),
            require_mandate_for_commit: true,
            commit_tools: vec!["purchase_*".to_string()],
            write_tools: vec![],
        };

        let handler = ToolCallHandler::new(policy, None, emitter.clone(), config);

        let request = make_tool_call_request("purchase_item", serde_json::json!({}));
        let mut state = PolicyState::default();

        let result = handler.handle_tool_call(&request, &mut state, None, None, None);

        assert!(
            matches!(result, HandleResult::Deny { reason_code, .. } if reason_code == reason_codes::P_MANDATE_REQUIRED)
        );
        assert_eq!(emitter.0.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_non_tool_call_request_is_error() {
        let emitter = Arc::new(CountingEmitter(AtomicUsize::new(0)));
        let handler = ToolCallHandler::new(
            McpPolicy::default(),
            None,
            emitter.clone(),
            ToolCallHandlerConfig::default(),
        );

        let request = JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            id: Some(Value::Number(7.into())),
            method: "ping".to_string(),
            params: Value::Null,
        };
        let mut state = PolicyState::default();

        let result = handler.handle_tool_call(&request, &mut state, None, None, None);

        assert!(
            matches!(result, HandleResult::Error { reason_code, .. } if reason_code == reason_codes::S_INTERNAL_ERROR)
        );
        assert_eq!(emitter.0.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_is_commit_tool_matching() {
        let config = ToolCallHandlerConfig {
            commit_tools: vec!["purchase_*".to_string(), "delete_account".to_string()],
            ..Default::default()
        };

        let handler = ToolCallHandler::new(
            McpPolicy::default(),
            None,
            Arc::new(NullDecisionEmitter),
            config,
        );

        assert!(handler.is_commit_tool("purchase_item"));
        assert!(handler.is_commit_tool("purchase_subscription"));
        assert!(handler.is_commit_tool("delete_account"));
        assert!(!handler.is_commit_tool("search_products"));
        assert!(!handler.is_commit_tool("purchase"));
    }

    #[test]
    fn test_wildcard_commit_pattern_matches_everything() {
        let config = ToolCallHandlerConfig {
            commit_tools: vec!["*".to_string()],
            ..Default::default()
        };
        let handler = ToolCallHandler::new(
            McpPolicy::default(),
            None,
            Arc::new(NullDecisionEmitter),
            config,
        );

        assert!(handler.is_commit_tool("anything"));
        assert!(handler.is_commit_tool(""));
        assert_eq!(
            handler.operation_class_for_tool("read_file"),
            OperationClass::Commit
        );
    }

    #[test]
    fn test_operation_class_for_tool() {
        use crate::runtime::OperationClass;
        let config = ToolCallHandlerConfig {
            commit_tools: vec!["purchase_*".to_string()],
            write_tools: vec!["update_*".to_string(), "create_item".to_string()],
            ..Default::default()
        };
        let handler = ToolCallHandler::new(
            McpPolicy::default(),
            None,
            Arc::new(NullDecisionEmitter),
            config,
        );
        assert_eq!(
            handler.operation_class_for_tool("purchase_item"),
            OperationClass::Commit
        );
        assert_eq!(
            handler.operation_class_for_tool("update_profile"),
            OperationClass::Write
        );
        assert_eq!(
            handler.operation_class_for_tool("create_item"),
            OperationClass::Write
        );
        assert_eq!(
            handler.operation_class_for_tool("read_file"),
            OperationClass::Read
        );
    }

    #[test]
    fn test_extract_tool_call_id_prefers_meta_then_request_id() {
        let handler = ToolCallHandler::new(
            McpPolicy::default(),
            None,
            Arc::new(NullDecisionEmitter),
            ToolCallHandlerConfig::default(),
        );

        let with_meta = make_tool_call_request(
            "safe_tool",
            serde_json::json!({ "_meta": { "tool_call_id": "tc_42" } }),
        );
        assert_eq!(handler.extract_tool_call_id(&with_meta), "tc_42");

        let without_meta = make_tool_call_request("safe_tool", serde_json::json!({}));
        assert_eq!(handler.extract_tool_call_id(&without_meta), "req_1");
    }

    #[test]
    fn test_lifecycle_emitter_not_called_when_none() {
        let emitter = Arc::new(CountingEmitter(AtomicUsize::new(0)));
        let policy = McpPolicy::default();

        let handler = ToolCallHandler::new(
            policy,
            None,
            emitter.clone(),
            ToolCallHandlerConfig::default(),
        );
        let request = make_tool_call_request("safe_tool", serde_json::json!({}));
        let mut state = PolicyState::default();

        let result = handler.handle_tool_call(&request, &mut state, None, None, None);

        assert!(matches!(result, HandleResult::Allow { .. }));
        assert_eq!(emitter.0.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_lifecycle_emitter_not_called_without_mandate() {
        let emitter = Arc::new(CountingEmitter(AtomicUsize::new(0)));
        let lifecycle = Arc::new(CountingLifecycleEmitter::new());

        let handler = ToolCallHandler::new(
            McpPolicy::default(),
            None,
            emitter.clone(),
            ToolCallHandlerConfig::default(),
        )
        .with_lifecycle_emitter(lifecycle.clone());

        let request = make_tool_call_request("safe_tool", serde_json::json!({}));
        let mut state = PolicyState::default();

        let result = handler.handle_tool_call(&request, &mut state, None, None, None);

        // Without an authorizer and a mandate no receipt is produced, so no lifecycle event.
        assert!(matches!(result, HandleResult::Allow { receipt: None, .. }));
        assert_eq!(emitter.0.load(Ordering::SeqCst), 1);
        assert_eq!(lifecycle.0.load(Ordering::SeqCst), 0);
        assert!(lifecycle.1.lock().unwrap().is_empty());
    }
}