Skip to main content

assay_core/mcp/
tool_call_handler.rs

1//! Central tool call handler with mandate authorization.
2//!
3//! This module integrates policy evaluation, mandate authorization, and
4//! decision emission into a single handler that guarantees the always-emit
5//! invariant (I1).
6
7use super::decision::{reason_codes, DecisionEmitter, DecisionEmitterGuard, DecisionEvent};
8use super::identity::ToolIdentity;
9use super::jsonrpc::JsonRpcRequest;
10use super::lifecycle::{mandate_used_event, LifecycleEmitter};
11use super::policy::{McpPolicy, PolicyDecision, PolicyState};
12use crate::runtime::{Authorizer, AuthzReceipt, MandateData, OperationClass, ToolCallData};
13use serde_json::Value;
14use std::sync::Arc;
15use std::time::Instant;
16
17/// Result of tool call handling.
18#[derive(Debug)]
19pub enum HandleResult {
20    /// Tool call is allowed, forward to server
21    Allow {
22        receipt: Option<AuthzReceipt>,
23        decision_event: DecisionEvent,
24    },
25    /// Tool call is denied, return error response
26    Deny {
27        reason_code: String,
28        reason: String,
29        decision_event: DecisionEvent,
30    },
31    /// Internal error during handling
32    Error {
33        reason_code: String,
34        reason: String,
35        decision_event: DecisionEvent,
36    },
37}
38
/// Configuration for the tool call handler.
///
/// Entries in `commit_tools` and `write_tools` are either exact tool names or prefix
/// patterns ending in `*` (e.g. `"purchase_*"`); a bare `"*"` matches every tool.
#[derive(Clone, Debug)]
pub struct ToolCallHandlerConfig {
    /// Event source URI used for every decision event (I3: fixed, configured value).
    pub event_source: String,
    /// Whether tools matching `commit_tools` must present a mandate to be allowed.
    pub require_mandate_for_commit: bool,
    /// Tools classified as commit operations (glob: "prefix*" or exact).
    pub commit_tools: Vec<String>,
    /// Tools classified as write operations (non-commit; glob or exact). Used for the
    /// mandate operation class; a tool matching `commit_tools` is classified as commit
    /// even if it also matches here.
    pub write_tools: Vec<String>,
}

impl Default for ToolCallHandlerConfig {
    /// Placeholder event source, mandates required for commit tools, and no tools
    /// classified as commit or write (so by default no tool requires a mandate).
    fn default() -> Self {
        Self {
            event_source: "assay://unknown".to_string(),
            require_mandate_for_commit: true,
            commit_tools: vec![],
            write_tools: vec![],
        }
    }
}
62
/// Central tool call handler with integrated authorization.
///
/// For each tool call it runs policy evaluation, then the commit-tool mandate
/// requirement, then mandate authorization when both an authorizer and a mandate are
/// available, and reports the outcome through `emitter`.
pub struct ToolCallHandler {
    /// Tool policy evaluated first for every tool call.
    policy: McpPolicy,
    /// Mandate authorizer; when `None`, supplied mandates are never verified or consumed.
    authorizer: Option<Authorizer>,
    /// Sink for decision events (one per tool call attempt, invariant I1).
    emitter: Arc<dyn DecisionEmitter>,
    /// Emitter for mandate lifecycle events (audit log); optional, set via
    /// `with_lifecycle_emitter`.
    lifecycle_emitter: Option<Arc<dyn LifecycleEmitter>>,
    /// Source URI, commit/write tool classification and mandate requirement.
    config: ToolCallHandlerConfig,
}
72
73impl ToolCallHandler {
74    /// Create a new handler.
75    pub fn new(
76        policy: McpPolicy,
77        authorizer: Option<Authorizer>,
78        emitter: Arc<dyn DecisionEmitter>,
79        config: ToolCallHandlerConfig,
80    ) -> Self {
81        Self {
82            policy,
83            authorizer,
84            emitter,
85            lifecycle_emitter: None,
86            config,
87        }
88    }
89
90    /// Set the lifecycle emitter for mandate.used events (P0-B).
91    pub fn with_lifecycle_emitter(mut self, emitter: Arc<dyn LifecycleEmitter>) -> Self {
92        self.lifecycle_emitter = Some(emitter);
93        self
94    }
95
96    /// Handle a tool call with full authorization and always-emit guarantee.
97    ///
98    /// This is the main entry point that enforces invariant I1: exactly one
99    /// decision event is emitted for every tool call attempt.
100    pub fn handle_tool_call(
101        &self,
102        request: &JsonRpcRequest,
103        state: &mut PolicyState,
104        runtime_identity: Option<&ToolIdentity>,
105        mandate: Option<&MandateData>,
106        transaction_object: Option<&Value>,
107    ) -> HandleResult {
108        let params = match request.tool_params() {
109            Some(p) => p,
110            None => {
111                // Not a tool call - still must emit decision (I1 invariant)
112                let tool_call_id = self.extract_tool_call_id(request);
113                let guard = DecisionEmitterGuard::new(
114                    self.emitter.clone(),
115                    self.config.event_source.clone(),
116                    tool_call_id.clone(),
117                    "unknown".to_string(),
118                );
119                guard.emit_error(
120                    reason_codes::S_INTERNAL_ERROR,
121                    Some("Not a tool call".to_string()),
122                );
123
124                return HandleResult::Error {
125                    reason_code: reason_codes::S_INTERNAL_ERROR.to_string(),
126                    reason: "Not a tool call".to_string(),
127                    decision_event: DecisionEvent::new(
128                        self.config.event_source.clone(),
129                        tool_call_id,
130                        "unknown".to_string(),
131                    )
132                    .error(
133                        reason_codes::S_INTERNAL_ERROR,
134                        Some("Not a tool call".to_string()),
135                    ),
136                };
137            }
138        };
139
140        let tool_name = params.name.clone();
141        let tool_call_id = self.extract_tool_call_id(request);
142
143        // Create guard - ensures decision is ALWAYS emitted
144        let mut guard = DecisionEmitterGuard::new(
145            self.emitter.clone(),
146            self.config.event_source.clone(),
147            tool_call_id.clone(),
148            tool_name.clone(),
149        );
150        guard.set_request_id(request.id.clone());
151
152        let start = Instant::now();
153
154        // Step 1: Policy evaluation
155        let policy_eval = self.policy.evaluate_with_metadata(
156            &tool_name,
157            &params.arguments,
158            state,
159            runtime_identity,
160        );
161        let tool_classes = policy_eval.metadata.tool_classes.clone();
162        let matched_tool_classes = policy_eval.metadata.matched_tool_classes.clone();
163        let match_basis = policy_eval
164            .metadata
165            .match_basis
166            .as_str()
167            .map(ToString::to_string);
168        let matched_rule = policy_eval.metadata.matched_rule.clone();
169        guard.set_tool_match(
170            policy_eval.metadata.tool_classes.clone(),
171            policy_eval.metadata.matched_tool_classes.clone(),
172            match_basis.clone(),
173            matched_rule.clone(),
174        );
175
176        match policy_eval.decision {
177            PolicyDecision::Deny {
178                tool: _,
179                code,
180                reason,
181                contract: _,
182            } => {
183                let reason_code = self.map_policy_code_to_reason(&code);
184                guard.emit_deny(&reason_code, Some(reason.clone()));
185
186                return HandleResult::Deny {
187                    reason_code: reason_code.clone(),
188                    reason: reason.clone(),
189                    decision_event: DecisionEvent::new(
190                        self.config.event_source.clone(),
191                        tool_call_id,
192                        tool_name,
193                    )
194                    .deny(&reason_code, Some(reason))
195                    .with_tool_match(
196                        tool_classes.clone(),
197                        matched_tool_classes.clone(),
198                        match_basis.clone(),
199                        matched_rule.clone(),
200                    ),
201                };
202            }
203            PolicyDecision::AllowWithWarning { .. } | PolicyDecision::Allow => {
204                // Continue to mandate check
205            }
206        }
207
208        // Step 2: Check if mandate is required
209        let is_commit_tool = self.is_commit_tool(&tool_name);
210        if is_commit_tool && self.config.require_mandate_for_commit && mandate.is_none() {
211            guard.emit_deny(
212                reason_codes::P_MANDATE_REQUIRED,
213                Some("Commit tool requires mandate authorization".to_string()),
214            );
215
216            return HandleResult::Deny {
217                reason_code: reason_codes::P_MANDATE_REQUIRED.to_string(),
218                reason: "Commit tool requires mandate authorization".to_string(),
219                decision_event: DecisionEvent::new(
220                    self.config.event_source.clone(),
221                    tool_call_id,
222                    tool_name,
223                )
224                .deny(
225                    reason_codes::P_MANDATE_REQUIRED,
226                    Some("Commit tool requires mandate authorization".to_string()),
227                )
228                .with_tool_match(
229                    tool_classes.clone(),
230                    matched_tool_classes.clone(),
231                    match_basis.clone(),
232                    matched_rule.clone(),
233                ),
234            };
235        }
236
237        // Step 3: Mandate authorization (if mandate present)
238        if let (Some(authorizer), Some(mandate_data)) = (&self.authorizer, mandate) {
239            let operation_class = self.operation_class_for_tool(&tool_name);
240
241            let tool_call_data = ToolCallData {
242                tool_name: tool_name.clone(),
243                tool_call_id: tool_call_id.clone(),
244                operation_class,
245                transaction_object: transaction_object.cloned(),
246                source_run_id: None,
247            };
248
249            let authz_start = Instant::now();
250            match authorizer.authorize_and_consume(mandate_data, &tool_call_data) {
251                Ok(receipt) => {
252                    let authz_ms = authz_start.elapsed().as_millis() as u64;
253                    guard.set_mandate_info(
254                        Some(mandate_data.mandate_id.clone()),
255                        Some(receipt.use_id.clone()),
256                        Some(receipt.use_count),
257                    );
258                    guard.set_mandate_matches(
259                        Some(true),
260                        Some(true),
261                        transaction_object.map(|_| true),
262                    );
263                    guard.set_latencies(Some(authz_ms), None);
264                    guard.emit_allow(reason_codes::P_MANDATE_VALID);
265
266                    // Emit mandate.used lifecycle event (P0-B)
267                    // Only emit on first consumption, not on idempotent retries
268                    if receipt.was_new {
269                        if let Some(ref lifecycle) = self.lifecycle_emitter {
270                            let event = mandate_used_event(&self.config.event_source, &receipt);
271                            lifecycle.emit(&event);
272                        }
273                    }
274
275                    return HandleResult::Allow {
276                        receipt: Some(receipt),
277                        decision_event: DecisionEvent::new(
278                            self.config.event_source.clone(),
279                            tool_call_id,
280                            tool_name,
281                        )
282                        .allow(reason_codes::P_MANDATE_VALID)
283                        .with_tool_match(
284                            tool_classes.clone(),
285                            matched_tool_classes.clone(),
286                            match_basis.clone(),
287                            matched_rule.clone(),
288                        ),
289                    };
290                }
291                Err(e) => {
292                    let (reason_code, reason) = self.map_authz_error(&e);
293                    guard.set_mandate_info(Some(mandate_data.mandate_id.clone()), None, None);
294                    guard.emit_deny(&reason_code, Some(reason.clone()));
295
296                    return HandleResult::Deny {
297                        reason_code: reason_code.clone(),
298                        reason: reason.clone(),
299                        decision_event: DecisionEvent::new(
300                            self.config.event_source.clone(),
301                            tool_call_id,
302                            tool_name,
303                        )
304                        .deny(&reason_code, Some(reason))
305                        .with_tool_match(
306                            tool_classes.clone(),
307                            matched_tool_classes.clone(),
308                            match_basis.clone(),
309                            matched_rule.clone(),
310                        ),
311                    };
312                }
313            }
314        }
315
316        // Step 4: No mandate required, policy allows
317        let elapsed_ms = start.elapsed().as_millis() as u64;
318        guard.set_latencies(Some(elapsed_ms), None);
319        guard.emit_allow(reason_codes::P_POLICY_ALLOW);
320
321        HandleResult::Allow {
322            receipt: None,
323            decision_event: DecisionEvent::new(
324                self.config.event_source.clone(),
325                tool_call_id,
326                tool_name,
327            )
328            .allow(reason_codes::P_POLICY_ALLOW)
329            .with_tool_match(
330                tool_classes,
331                matched_tool_classes,
332                match_basis,
333                matched_rule,
334            ),
335        }
336    }
337
338    /// Extract tool_call_id from request (I4: idempotency key).
339    fn extract_tool_call_id(&self, request: &JsonRpcRequest) -> String {
340        // Try to get from params._meta.tool_call_id (MCP standard)
341        if let Some(params) = request.tool_params() {
342            if let Some(meta) = params.arguments.get("_meta") {
343                if let Some(id) = meta.get("tool_call_id").and_then(|v| v.as_str()) {
344                    return id.to_string();
345                }
346            }
347        }
348
349        // Fall back to request.id if present
350        if let Some(id) = &request.id {
351            if let Some(s) = id.as_str() {
352                return format!("req_{}", s);
353            }
354            if let Some(n) = id.as_i64() {
355                return format!("req_{}", n);
356            }
357        }
358
359        // Generate one if none found
360        format!("gen_{}", uuid::Uuid::new_v4())
361    }
362
363    /// Check if a tool is classified as a commit operation.
364    fn is_commit_tool(&self, tool_name: &str) -> bool {
365        self.config.commit_tools.iter().any(|pattern| {
366            if pattern == "*" {
367                return true;
368            }
369            if pattern.ends_with('*') {
370                let prefix = pattern.trim_end_matches('*');
371                tool_name.starts_with(prefix)
372            } else {
373                tool_name == pattern
374            }
375        })
376    }
377
378    /// Check if a tool is classified as a write operation (non-commit).
379    fn is_write_tool(&self, tool_name: &str) -> bool {
380        self.config.write_tools.iter().any(|pattern| {
381            if pattern == "*" {
382                return true;
383            }
384            if pattern.ends_with('*') {
385                let prefix = pattern.trim_end_matches('*');
386                tool_name.starts_with(prefix)
387            } else {
388                tool_name == pattern
389            }
390        })
391    }
392
393    /// Derive operation class from tool classification (commit_tools, write_tools, else Read).
394    fn operation_class_for_tool(&self, tool_name: &str) -> OperationClass {
395        if self.is_commit_tool(tool_name) {
396            OperationClass::Commit
397        } else if self.is_write_tool(tool_name) {
398            OperationClass::Write
399        } else {
400            OperationClass::Read
401        }
402    }
403
404    /// Map policy error code to reason code.
405    fn map_policy_code_to_reason(&self, code: &str) -> String {
406        match code {
407            "E_TOOL_DENIED" => reason_codes::P_TOOL_DENIED.to_string(),
408            "E_TOOL_NOT_ALLOWED" => reason_codes::P_TOOL_NOT_ALLOWED.to_string(),
409            "E_ARG_SCHEMA" => reason_codes::P_ARG_SCHEMA.to_string(),
410            "E_RATE_LIMIT" => reason_codes::P_RATE_LIMIT.to_string(),
411            "E_TOOL_DRIFT" => reason_codes::P_TOOL_DRIFT.to_string(),
412            _ => reason_codes::P_POLICY_DENY.to_string(),
413        }
414    }
415
416    /// Map authorization error to reason code and message.
417    fn map_authz_error(&self, error: &crate::runtime::AuthorizeError) -> (String, String) {
418        use crate::runtime::AuthorizeError;
419
420        match error {
421            AuthorizeError::Policy(pe) => {
422                use crate::runtime::PolicyError;
423                match pe {
424                    PolicyError::Expired { .. } => (
425                        reason_codes::M_EXPIRED.to_string(),
426                        "Mandate expired".to_string(),
427                    ),
428                    PolicyError::NotYetValid { .. } => (
429                        reason_codes::M_NOT_YET_VALID.to_string(),
430                        "Mandate not yet valid".to_string(),
431                    ),
432                    PolicyError::ToolNotInScope { tool } => (
433                        reason_codes::M_TOOL_NOT_IN_SCOPE.to_string(),
434                        format!("Tool '{}' not in mandate scope", tool),
435                    ),
436                    PolicyError::KindMismatch { kind, op_class } => (
437                        reason_codes::M_KIND_MISMATCH.to_string(),
438                        format!(
439                            "Mandate kind '{}' does not allow operation class '{}'",
440                            kind, op_class
441                        ),
442                    ),
443                    PolicyError::AudienceMismatch { expected, actual } => (
444                        reason_codes::M_AUDIENCE_MISMATCH.to_string(),
445                        format!(
446                            "Audience mismatch: expected '{}', got '{}'",
447                            expected, actual
448                        ),
449                    ),
450                    PolicyError::IssuerNotTrusted { issuer } => (
451                        reason_codes::M_ISSUER_NOT_TRUSTED.to_string(),
452                        format!("Issuer '{}' not in trusted list", issuer),
453                    ),
454                    PolicyError::MissingTransactionObject => (
455                        reason_codes::M_TRANSACTION_REF_MISMATCH.to_string(),
456                        "Transaction object required but not provided".to_string(),
457                    ),
458                    PolicyError::TransactionRefMismatch { expected, actual } => (
459                        reason_codes::M_TRANSACTION_REF_MISMATCH.to_string(),
460                        format!(
461                            "Transaction ref mismatch: expected '{}', computed '{}'",
462                            expected, actual
463                        ),
464                    ),
465                }
466            }
467            AuthorizeError::Store(se) => {
468                use crate::runtime::AuthzError;
469                match se {
470                    AuthzError::AlreadyUsed => (
471                        reason_codes::M_ALREADY_USED.to_string(),
472                        "Single-use mandate already consumed".to_string(),
473                    ),
474                    AuthzError::MaxUsesExceeded { max, current } => (
475                        reason_codes::M_MAX_USES_EXCEEDED.to_string(),
476                        format!("Max uses exceeded: {} of {} used", current, max),
477                    ),
478                    AuthzError::NonceReplay { nonce } => (
479                        reason_codes::M_NONCE_REPLAY.to_string(),
480                        format!("Nonce replay detected: {}", nonce),
481                    ),
482                    AuthzError::MandateNotFound { mandate_id } => (
483                        reason_codes::M_NOT_FOUND.to_string(),
484                        format!("Mandate not found: {}", mandate_id),
485                    ),
486                    AuthzError::Revoked { revoked_at } => (
487                        reason_codes::M_REVOKED.to_string(),
488                        format!("Mandate revoked at {}", revoked_at),
489                    ),
490                    AuthzError::MandateConflict { .. }
491                    | AuthzError::InvalidConstraints { .. }
492                    | AuthzError::Database(_) => (
493                        reason_codes::S_DB_ERROR.to_string(),
494                        format!("Database error: {}", se),
495                    ),
496                }
497            }
498            AuthorizeError::TransactionRef(msg) => (
499                reason_codes::M_TRANSACTION_REF_MISMATCH.to_string(),
500                format!("Transaction ref error: {}", msg),
501            ),
502        }
503    }
504}
505
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mcp::decision::NullDecisionEmitter;
    use crate::mcp::lifecycle::{LifecycleEmitter, LifecycleEvent};
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Decision emitter that only counts how many events it received (checks I1).
    struct CountingEmitter(AtomicUsize);

    impl DecisionEmitter for CountingEmitter {
        fn emit(&self, _event: &DecisionEvent) {
            self.0.fetch_add(1, Ordering::SeqCst);
        }
    }

    /// Builds a `tools/call` request with JSON-RPC id 1.
    fn make_tool_call_request(tool: &str, args: Value) -> JsonRpcRequest {
        JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            id: Some(Value::Number(1.into())),
            method: "tools/call".to_string(),
            params: serde_json::json!({
                "name": tool,
                "arguments": args
            }),
        }
    }

    #[test]
    fn test_handler_emits_decision_on_policy_deny() {
        let emitter = Arc::new(CountingEmitter(AtomicUsize::new(0)));
        let policy = McpPolicy {
            tools: super::super::policy::ToolPolicy {
                allow: None,
                deny: Some(vec!["dangerous_*".to_string()]),
                ..Default::default()
            },
            ..Default::default()
        };

        let handler = ToolCallHandler::new(
            policy,
            None,
            emitter.clone(),
            ToolCallHandlerConfig::default(),
        );

        let request = make_tool_call_request("dangerous_tool", serde_json::json!({}));
        let mut state = PolicyState::default();

        let result = handler.handle_tool_call(&request, &mut state, None, None, None);

        assert!(matches!(result, HandleResult::Deny { .. }));
        assert_eq!(emitter.0.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_handler_emits_decision_on_policy_allow() {
        let emitter = Arc::new(CountingEmitter(AtomicUsize::new(0)));
        let policy = McpPolicy::default();

        let handler = ToolCallHandler::new(
            policy,
            None,
            emitter.clone(),
            ToolCallHandlerConfig::default(),
        );

        let request = make_tool_call_request("safe_tool", serde_json::json!({}));
        let mut state = PolicyState::default();

        let result = handler.handle_tool_call(&request, &mut state, None, None, None);

        assert!(matches!(result, HandleResult::Allow { .. }));
        assert_eq!(emitter.0.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_handler_emits_decision_on_non_tool_call() {
        // I1: a request without tool call params still produces exactly one decision.
        let emitter = Arc::new(CountingEmitter(AtomicUsize::new(0)));
        let handler = ToolCallHandler::new(
            McpPolicy::default(),
            None,
            emitter.clone(),
            ToolCallHandlerConfig::default(),
        );

        let request = JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            id: Some(Value::Number(7.into())),
            method: "tools/list".to_string(),
            params: Value::Null,
        };
        let mut state = PolicyState::default();

        let result = handler.handle_tool_call(&request, &mut state, None, None, None);

        assert!(
            matches!(result, HandleResult::Error { reason_code, .. } if reason_code == reason_codes::S_INTERNAL_ERROR)
        );
        assert_eq!(emitter.0.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_commit_tool_without_mandate_denied() {
        let emitter = Arc::new(CountingEmitter(AtomicUsize::new(0)));
        let policy = McpPolicy::default();

        let config = ToolCallHandlerConfig {
            event_source: "assay://test".to_string(),
            require_mandate_for_commit: true,
            commit_tools: vec!["purchase_*".to_string()],
            write_tools: vec![],
        };

        let handler = ToolCallHandler::new(policy, None, emitter.clone(), config);

        let request = make_tool_call_request("purchase_item", serde_json::json!({}));
        let mut state = PolicyState::default();

        let result = handler.handle_tool_call(&request, &mut state, None, None, None);

        assert!(
            matches!(result, HandleResult::Deny { reason_code, .. } if reason_code == reason_codes::P_MANDATE_REQUIRED)
        );
        assert_eq!(emitter.0.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_extract_tool_call_id_prefers_meta_then_request_id() {
        let handler = ToolCallHandler::new(
            McpPolicy::default(),
            None,
            Arc::new(NullDecisionEmitter),
            ToolCallHandlerConfig::default(),
        );

        let with_meta = make_tool_call_request(
            "safe_tool",
            serde_json::json!({ "_meta": { "tool_call_id": "tc_42" } }),
        );
        assert_eq!(handler.extract_tool_call_id(&with_meta), "tc_42");

        let without_meta = make_tool_call_request("safe_tool", serde_json::json!({}));
        assert_eq!(handler.extract_tool_call_id(&without_meta), "req_1");
    }

    #[test]
    fn test_is_commit_tool_matching() {
        let config = ToolCallHandlerConfig {
            commit_tools: vec!["purchase_*".to_string(), "delete_account".to_string()],
            ..Default::default()
        };

        let handler = ToolCallHandler::new(
            McpPolicy::default(),
            None,
            Arc::new(NullDecisionEmitter),
            config,
        );

        assert!(handler.is_commit_tool("purchase_item"));
        assert!(handler.is_commit_tool("purchase_subscription"));
        assert!(handler.is_commit_tool("delete_account"));
        assert!(!handler.is_commit_tool("search_products"));
        assert!(!handler.is_commit_tool("purchase")); // Doesn't match purchase_*
    }

    #[test]
    fn test_operation_class_for_tool() {
        let config = ToolCallHandlerConfig {
            commit_tools: vec!["purchase_*".to_string()],
            write_tools: vec!["update_*".to_string(), "create_item".to_string()],
            ..Default::default()
        };
        let handler = ToolCallHandler::new(
            McpPolicy::default(),
            None,
            Arc::new(NullDecisionEmitter),
            config,
        );
        assert_eq!(
            handler.operation_class_for_tool("purchase_item"),
            OperationClass::Commit
        );
        assert_eq!(
            handler.operation_class_for_tool("update_profile"),
            OperationClass::Write
        );
        assert_eq!(
            handler.operation_class_for_tool("create_item"),
            OperationClass::Write
        );
        assert_eq!(
            handler.operation_class_for_tool("read_file"),
            OperationClass::Read
        );
    }

    // === P0-B: Lifecycle event emission tests ===

    #[allow(dead_code)] // Prepared for future tests with mandate authorization
    struct CountingLifecycleEmitter(AtomicUsize, std::sync::Mutex<Vec<LifecycleEvent>>);

    impl LifecycleEmitter for CountingLifecycleEmitter {
        fn emit(&self, event: &LifecycleEvent) {
            self.0.fetch_add(1, Ordering::SeqCst);
            if let Ok(mut events) = self.1.lock() {
                events.push(event.clone());
            }
        }
    }

    #[test]
    fn test_lifecycle_emitter_not_called_when_none() {
        // When no lifecycle emitter is set, handler should still work
        let emitter = Arc::new(CountingEmitter(AtomicUsize::new(0)));
        let policy = McpPolicy::default();

        let handler = ToolCallHandler::new(
            policy,
            None,
            emitter.clone(),
            ToolCallHandlerConfig::default(),
        );
        // No lifecycle emitter set

        let request = make_tool_call_request("safe_tool", serde_json::json!({}));
        let mut state = PolicyState::default();

        let result = handler.handle_tool_call(&request, &mut state, None, None, None);

        assert!(matches!(result, HandleResult::Allow { .. }));
        assert_eq!(emitter.0.load(Ordering::SeqCst), 1); // Decision emitted
    }
}