Skip to main content

assay_core/mcp/
tool_call_handler.rs

1//! Central tool call handler with mandate authorization.
2//!
3//! This module integrates policy evaluation, mandate authorization, and
4//! decision emission into a single handler that guarantees the always-emit
5//! invariant (I1).
6
7use super::decision::{reason_codes, DecisionEmitter, DecisionEmitterGuard, DecisionEvent};
8use super::identity::ToolIdentity;
9use super::jsonrpc::JsonRpcRequest;
10use super::lifecycle::{mandate_used_event, LifecycleEmitter};
11use super::policy::{McpPolicy, PolicyDecision, PolicyState};
12use crate::runtime::{Authorizer, AuthzReceipt, MandateData, OperationClass, ToolCallData};
13use serde_json::Value;
14use std::sync::Arc;
15use std::time::Instant;
16
17/// Result of tool call handling.
18#[derive(Debug)]
19pub enum HandleResult {
20    /// Tool call is allowed, forward to server
21    Allow {
22        receipt: Option<AuthzReceipt>,
23        decision_event: DecisionEvent,
24    },
25    /// Tool call is denied, return error response
26    Deny {
27        reason_code: String,
28        reason: String,
29        decision_event: DecisionEvent,
30    },
31    /// Internal error during handling
32    Error {
33        reason_code: String,
34        reason: String,
35        decision_event: DecisionEvent,
36    },
37}
38
/// Configuration for the tool call handler.
///
/// Derives `Debug` so the active configuration can be logged, and
/// `PartialEq`/`Eq` so configurations can be compared (e.g. in tests or on reload).
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ToolCallHandlerConfig {
    /// Event source URI (I3: fixed, configured value)
    pub event_source: String,
    /// Whether commit tools require mandates
    pub require_mandate_for_commit: bool,
    /// Tools classified as commit operations.
    ///
    /// Each entry is an exact tool name, `"*"` (every tool), or a prefix
    /// pattern ending in `*` (e.g. `"purchase_*"`).
    pub commit_tools: Vec<String>,
}

impl Default for ToolCallHandlerConfig {
    /// Defaults: placeholder source `assay://unknown`, mandates required for
    /// commit tools, and no tools classified as commit operations.
    fn default() -> Self {
        Self {
            event_source: "assay://unknown".to_string(),
            require_mandate_for_commit: true,
            commit_tools: vec![],
        }
    }
}
59
60/// Central tool call handler with integrated authorization.
61pub struct ToolCallHandler {
62    policy: McpPolicy,
63    authorizer: Option<Authorizer>,
64    emitter: Arc<dyn DecisionEmitter>,
65    /// Emitter for mandate lifecycle events (audit log)
66    lifecycle_emitter: Option<Arc<dyn LifecycleEmitter>>,
67    config: ToolCallHandlerConfig,
68}
69
70impl ToolCallHandler {
71    /// Create a new handler.
72    pub fn new(
73        policy: McpPolicy,
74        authorizer: Option<Authorizer>,
75        emitter: Arc<dyn DecisionEmitter>,
76        config: ToolCallHandlerConfig,
77    ) -> Self {
78        Self {
79            policy,
80            authorizer,
81            emitter,
82            lifecycle_emitter: None,
83            config,
84        }
85    }
86
87    /// Set the lifecycle emitter for mandate.used events (P0-B).
88    pub fn with_lifecycle_emitter(mut self, emitter: Arc<dyn LifecycleEmitter>) -> Self {
89        self.lifecycle_emitter = Some(emitter);
90        self
91    }
92
93    /// Handle a tool call with full authorization and always-emit guarantee.
94    ///
95    /// This is the main entry point that enforces invariant I1: exactly one
96    /// decision event is emitted for every tool call attempt.
97    pub fn handle_tool_call(
98        &self,
99        request: &JsonRpcRequest,
100        state: &mut PolicyState,
101        runtime_identity: Option<&ToolIdentity>,
102        mandate: Option<&MandateData>,
103        transaction_object: Option<&Value>,
104    ) -> HandleResult {
105        let params = match request.tool_params() {
106            Some(p) => p,
107            None => {
108                // Not a tool call - still must emit decision (I1 invariant)
109                let tool_call_id = self.extract_tool_call_id(request);
110                let guard = DecisionEmitterGuard::new(
111                    self.emitter.clone(),
112                    self.config.event_source.clone(),
113                    tool_call_id.clone(),
114                    "unknown".to_string(),
115                );
116                guard.emit_error(
117                    reason_codes::S_INTERNAL_ERROR,
118                    Some("Not a tool call".to_string()),
119                );
120
121                return HandleResult::Error {
122                    reason_code: reason_codes::S_INTERNAL_ERROR.to_string(),
123                    reason: "Not a tool call".to_string(),
124                    decision_event: DecisionEvent::new(
125                        self.config.event_source.clone(),
126                        tool_call_id,
127                        "unknown".to_string(),
128                    )
129                    .error(
130                        reason_codes::S_INTERNAL_ERROR,
131                        Some("Not a tool call".to_string()),
132                    ),
133                };
134            }
135        };
136
137        let tool_name = params.name.clone();
138        let tool_call_id = self.extract_tool_call_id(request);
139
140        // Create guard - ensures decision is ALWAYS emitted
141        let mut guard = DecisionEmitterGuard::new(
142            self.emitter.clone(),
143            self.config.event_source.clone(),
144            tool_call_id.clone(),
145            tool_name.clone(),
146        );
147        guard.set_request_id(request.id.clone());
148
149        let start = Instant::now();
150
151        // Step 1: Policy evaluation
152        let policy_decision =
153            self.policy
154                .evaluate(&tool_name, &params.arguments, state, runtime_identity);
155
156        match policy_decision {
157            PolicyDecision::Deny {
158                tool: _,
159                code,
160                reason,
161                contract: _,
162            } => {
163                let reason_code = self.map_policy_code_to_reason(&code);
164                guard.emit_deny(&reason_code, Some(reason.clone()));
165
166                return HandleResult::Deny {
167                    reason_code: reason_code.clone(),
168                    reason: reason.clone(),
169                    decision_event: DecisionEvent::new(
170                        self.config.event_source.clone(),
171                        tool_call_id,
172                        tool_name,
173                    )
174                    .deny(&reason_code, Some(reason)),
175                };
176            }
177            PolicyDecision::AllowWithWarning { .. } | PolicyDecision::Allow => {
178                // Continue to mandate check
179            }
180        }
181
182        // Step 2: Check if mandate is required
183        let is_commit_tool = self.is_commit_tool(&tool_name);
184        if is_commit_tool && self.config.require_mandate_for_commit && mandate.is_none() {
185            guard.emit_deny(
186                reason_codes::P_MANDATE_REQUIRED,
187                Some("Commit tool requires mandate authorization".to_string()),
188            );
189
190            return HandleResult::Deny {
191                reason_code: reason_codes::P_MANDATE_REQUIRED.to_string(),
192                reason: "Commit tool requires mandate authorization".to_string(),
193                decision_event: DecisionEvent::new(
194                    self.config.event_source.clone(),
195                    tool_call_id,
196                    tool_name,
197                )
198                .deny(
199                    reason_codes::P_MANDATE_REQUIRED,
200                    Some("Commit tool requires mandate authorization".to_string()),
201                ),
202            };
203        }
204
205        // Step 3: Mandate authorization (if mandate present)
206        if let (Some(authorizer), Some(mandate_data)) = (&self.authorizer, mandate) {
207            let operation_class = if is_commit_tool {
208                OperationClass::Commit
209            } else {
210                OperationClass::Read // TODO: Determine from tool classification
211            };
212
213            let tool_call_data = ToolCallData {
214                tool_name: tool_name.clone(),
215                tool_call_id: tool_call_id.clone(),
216                operation_class,
217                transaction_object: transaction_object.cloned(),
218                source_run_id: None,
219            };
220
221            let authz_start = Instant::now();
222            match authorizer.authorize_and_consume(mandate_data, &tool_call_data) {
223                Ok(receipt) => {
224                    let authz_ms = authz_start.elapsed().as_millis() as u64;
225                    guard.set_mandate_info(
226                        Some(mandate_data.mandate_id.clone()),
227                        Some(receipt.use_id.clone()),
228                        Some(receipt.use_count),
229                    );
230                    guard.set_mandate_matches(
231                        Some(true),
232                        Some(true),
233                        transaction_object.map(|_| true),
234                    );
235                    guard.set_latencies(Some(authz_ms), None);
236                    guard.emit_allow(reason_codes::P_MANDATE_VALID);
237
238                    // Emit mandate.used lifecycle event (P0-B)
239                    // Only emit on first consumption, not on idempotent retries
240                    if receipt.was_new {
241                        if let Some(ref lifecycle) = self.lifecycle_emitter {
242                            let event = mandate_used_event(&self.config.event_source, &receipt);
243                            lifecycle.emit(&event);
244                        }
245                    }
246
247                    return HandleResult::Allow {
248                        receipt: Some(receipt),
249                        decision_event: DecisionEvent::new(
250                            self.config.event_source.clone(),
251                            tool_call_id,
252                            tool_name,
253                        )
254                        .allow(reason_codes::P_MANDATE_VALID),
255                    };
256                }
257                Err(e) => {
258                    let (reason_code, reason) = self.map_authz_error(&e);
259                    guard.set_mandate_info(Some(mandate_data.mandate_id.clone()), None, None);
260                    guard.emit_deny(&reason_code, Some(reason.clone()));
261
262                    return HandleResult::Deny {
263                        reason_code,
264                        reason,
265                        decision_event: DecisionEvent::new(
266                            self.config.event_source.clone(),
267                            tool_call_id,
268                            tool_name,
269                        ),
270                    };
271                }
272            }
273        }
274
275        // Step 4: No mandate required, policy allows
276        let elapsed_ms = start.elapsed().as_millis() as u64;
277        guard.set_latencies(Some(elapsed_ms), None);
278        guard.emit_allow(reason_codes::P_POLICY_DENY); // Actually P_POLICY_ALLOW but we use P_POLICY_DENY as catch-all
279
280        HandleResult::Allow {
281            receipt: None,
282            decision_event: DecisionEvent::new(
283                self.config.event_source.clone(),
284                tool_call_id,
285                tool_name,
286            )
287            .allow(reason_codes::P_POLICY_DENY),
288        }
289    }
290
291    /// Extract tool_call_id from request (I4: idempotency key).
292    fn extract_tool_call_id(&self, request: &JsonRpcRequest) -> String {
293        // Try to get from params._meta.tool_call_id (MCP standard)
294        if let Some(params) = request.tool_params() {
295            if let Some(meta) = params.arguments.get("_meta") {
296                if let Some(id) = meta.get("tool_call_id").and_then(|v| v.as_str()) {
297                    return id.to_string();
298                }
299            }
300        }
301
302        // Fall back to request.id if present
303        if let Some(id) = &request.id {
304            if let Some(s) = id.as_str() {
305                return format!("req_{}", s);
306            }
307            if let Some(n) = id.as_i64() {
308                return format!("req_{}", n);
309            }
310        }
311
312        // Generate one if none found
313        format!("gen_{}", uuid::Uuid::new_v4())
314    }
315
316    /// Check if a tool is classified as a commit operation.
317    fn is_commit_tool(&self, tool_name: &str) -> bool {
318        self.config.commit_tools.iter().any(|pattern| {
319            if pattern == "*" {
320                return true;
321            }
322            if pattern.ends_with('*') {
323                let prefix = pattern.trim_end_matches('*');
324                tool_name.starts_with(prefix)
325            } else {
326                tool_name == pattern
327            }
328        })
329    }
330
331    /// Map policy error code to reason code.
332    fn map_policy_code_to_reason(&self, code: &str) -> String {
333        match code {
334            "E_TOOL_DENIED" => reason_codes::P_TOOL_DENIED.to_string(),
335            "E_TOOL_NOT_ALLOWED" => reason_codes::P_TOOL_NOT_ALLOWED.to_string(),
336            "E_ARG_SCHEMA" => reason_codes::P_ARG_SCHEMA.to_string(),
337            "E_RATE_LIMIT" => reason_codes::P_RATE_LIMIT.to_string(),
338            "E_TOOL_DRIFT" => reason_codes::P_TOOL_DRIFT.to_string(),
339            _ => reason_codes::P_POLICY_DENY.to_string(),
340        }
341    }
342
343    /// Map authorization error to reason code and message.
344    fn map_authz_error(&self, error: &crate::runtime::AuthorizeError) -> (String, String) {
345        use crate::runtime::AuthorizeError;
346
347        match error {
348            AuthorizeError::Policy(pe) => {
349                use crate::runtime::PolicyError;
350                match pe {
351                    PolicyError::Expired { .. } => (
352                        reason_codes::M_EXPIRED.to_string(),
353                        "Mandate expired".to_string(),
354                    ),
355                    PolicyError::NotYetValid { .. } => (
356                        reason_codes::M_NOT_YET_VALID.to_string(),
357                        "Mandate not yet valid".to_string(),
358                    ),
359                    PolicyError::ToolNotInScope { tool } => (
360                        reason_codes::M_TOOL_NOT_IN_SCOPE.to_string(),
361                        format!("Tool '{}' not in mandate scope", tool),
362                    ),
363                    PolicyError::KindMismatch { kind, op_class } => (
364                        reason_codes::M_KIND_MISMATCH.to_string(),
365                        format!(
366                            "Mandate kind '{}' does not allow operation class '{}'",
367                            kind, op_class
368                        ),
369                    ),
370                    PolicyError::AudienceMismatch { expected, actual } => (
371                        reason_codes::M_AUDIENCE_MISMATCH.to_string(),
372                        format!(
373                            "Audience mismatch: expected '{}', got '{}'",
374                            expected, actual
375                        ),
376                    ),
377                    PolicyError::IssuerNotTrusted { issuer } => (
378                        reason_codes::M_ISSUER_NOT_TRUSTED.to_string(),
379                        format!("Issuer '{}' not in trusted list", issuer),
380                    ),
381                    PolicyError::MissingTransactionObject => (
382                        reason_codes::M_TRANSACTION_REF_MISMATCH.to_string(),
383                        "Transaction object required but not provided".to_string(),
384                    ),
385                    PolicyError::TransactionRefMismatch { expected, actual } => (
386                        reason_codes::M_TRANSACTION_REF_MISMATCH.to_string(),
387                        format!(
388                            "Transaction ref mismatch: expected '{}', computed '{}'",
389                            expected, actual
390                        ),
391                    ),
392                }
393            }
394            AuthorizeError::Store(se) => {
395                use crate::runtime::AuthzError;
396                match se {
397                    AuthzError::AlreadyUsed => (
398                        reason_codes::M_ALREADY_USED.to_string(),
399                        "Single-use mandate already consumed".to_string(),
400                    ),
401                    AuthzError::MaxUsesExceeded { max, current } => (
402                        reason_codes::M_MAX_USES_EXCEEDED.to_string(),
403                        format!("Max uses exceeded: {} of {} used", current, max),
404                    ),
405                    AuthzError::NonceReplay { nonce } => (
406                        reason_codes::M_NONCE_REPLAY.to_string(),
407                        format!("Nonce replay detected: {}", nonce),
408                    ),
409                    AuthzError::MandateNotFound { mandate_id } => (
410                        reason_codes::M_NOT_FOUND.to_string(),
411                        format!("Mandate not found: {}", mandate_id),
412                    ),
413                    AuthzError::Revoked { revoked_at } => (
414                        reason_codes::M_REVOKED.to_string(),
415                        format!("Mandate revoked at {}", revoked_at),
416                    ),
417                    AuthzError::MandateConflict { .. }
418                    | AuthzError::InvalidConstraints { .. }
419                    | AuthzError::Database(_) => (
420                        reason_codes::S_DB_ERROR.to_string(),
421                        format!("Database error: {}", se),
422                    ),
423                }
424            }
425            AuthorizeError::TransactionRef(msg) => (
426                reason_codes::M_TRANSACTION_REF_MISMATCH.to_string(),
427                format!("Transaction ref error: {}", msg),
428            ),
429        }
430    }
431}
432
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mcp::decision::NullDecisionEmitter;
    use crate::mcp::lifecycle::{LifecycleEmitter, LifecycleEvent};
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Decision emitter that counts received events (checks invariant I1).
    struct CountingEmitter(AtomicUsize);

    impl DecisionEmitter for CountingEmitter {
        fn emit(&self, _event: &DecisionEvent) {
            self.0.fetch_add(1, Ordering::SeqCst);
        }
    }

    /// Build a `tools/call` JSON-RPC request (id 1) for `tool` with `args`.
    fn make_tool_call_request(tool: &str, args: Value) -> JsonRpcRequest {
        JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            id: Some(Value::Number(1.into())),
            method: "tools/call".to_string(),
            params: serde_json::json!({
                "name": tool,
                "arguments": args
            }),
        }
    }

    #[test]
    fn test_handler_emits_decision_on_policy_deny() {
        let emitter = Arc::new(CountingEmitter(AtomicUsize::new(0)));
        let policy = McpPolicy {
            tools: super::super::policy::ToolPolicy {
                allow: None,
                deny: Some(vec!["dangerous_*".to_string()]),
            },
            ..Default::default()
        };

        let handler = ToolCallHandler::new(
            policy,
            None,
            emitter.clone(),
            ToolCallHandlerConfig::default(),
        );

        let request = make_tool_call_request("dangerous_tool", serde_json::json!({}));
        let mut state = PolicyState::default();

        let result = handler.handle_tool_call(&request, &mut state, None, None, None);

        assert!(matches!(result, HandleResult::Deny { .. }));
        assert_eq!(emitter.0.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_handler_emits_decision_on_policy_allow() {
        let emitter = Arc::new(CountingEmitter(AtomicUsize::new(0)));
        let policy = McpPolicy::default();

        let handler = ToolCallHandler::new(
            policy,
            None,
            emitter.clone(),
            ToolCallHandlerConfig::default(),
        );

        let request = make_tool_call_request("safe_tool", serde_json::json!({}));
        let mut state = PolicyState::default();

        let result = handler.handle_tool_call(&request, &mut state, None, None, None);

        assert!(matches!(result, HandleResult::Allow { .. }));
        assert_eq!(emitter.0.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_handler_emits_decision_on_non_tool_call() {
        // I1 must also hold when the request cannot be parsed as a tool call.
        let emitter = Arc::new(CountingEmitter(AtomicUsize::new(0)));
        let handler = ToolCallHandler::new(
            McpPolicy::default(),
            None,
            emitter.clone(),
            ToolCallHandlerConfig::default(),
        );

        let request = JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            id: Some(Value::Number(7.into())),
            method: "tools/list".to_string(),
            params: Value::Null,
        };
        let mut state = PolicyState::default();

        let result = handler.handle_tool_call(&request, &mut state, None, None, None);

        assert!(
            matches!(result, HandleResult::Error { reason_code, .. } if reason_code == reason_codes::S_INTERNAL_ERROR)
        );
        assert_eq!(emitter.0.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_commit_tool_without_mandate_denied() {
        let emitter = Arc::new(CountingEmitter(AtomicUsize::new(0)));
        let policy = McpPolicy::default();

        let config = ToolCallHandlerConfig {
            event_source: "assay://test".to_string(),
            require_mandate_for_commit: true,
            commit_tools: vec!["purchase_*".to_string()],
        };

        let handler = ToolCallHandler::new(policy, None, emitter.clone(), config);

        let request = make_tool_call_request("purchase_item", serde_json::json!({}));
        let mut state = PolicyState::default();

        let result = handler.handle_tool_call(&request, &mut state, None, None, None);

        assert!(
            matches!(result, HandleResult::Deny { reason_code, .. } if reason_code == reason_codes::P_MANDATE_REQUIRED)
        );
        assert_eq!(emitter.0.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_is_commit_tool_matching() {
        let config = ToolCallHandlerConfig {
            commit_tools: vec!["purchase_*".to_string(), "delete_account".to_string()],
            ..Default::default()
        };

        let handler = ToolCallHandler::new(
            McpPolicy::default(),
            None,
            Arc::new(NullDecisionEmitter),
            config,
        );

        assert!(handler.is_commit_tool("purchase_item"));
        assert!(handler.is_commit_tool("purchase_subscription"));
        assert!(handler.is_commit_tool("delete_account"));
        assert!(!handler.is_commit_tool("search_products"));
        assert!(!handler.is_commit_tool("purchase")); // Doesn't match purchase_*
    }

    // === P0-B: Lifecycle event emission tests ===

    /// Lifecycle emitter that counts and records every event it receives.
    struct CountingLifecycleEmitter(AtomicUsize, std::sync::Mutex<Vec<LifecycleEvent>>);

    impl LifecycleEmitter for CountingLifecycleEmitter {
        fn emit(&self, event: &LifecycleEvent) {
            self.0.fetch_add(1, Ordering::SeqCst);
            if let Ok(mut events) = self.1.lock() {
                events.push(event.clone());
            }
        }
    }

    #[test]
    fn test_lifecycle_emitter_not_called_when_none() {
        // When no lifecycle emitter is set, handler should still work
        let emitter = Arc::new(CountingEmitter(AtomicUsize::new(0)));
        let policy = McpPolicy::default();

        let handler = ToolCallHandler::new(
            policy,
            None,
            emitter.clone(),
            ToolCallHandlerConfig::default(),
        );
        // No lifecycle emitter set

        let request = make_tool_call_request("safe_tool", serde_json::json!({}));
        let mut state = PolicyState::default();

        let result = handler.handle_tool_call(&request, &mut state, None, None, None);

        assert!(matches!(result, HandleResult::Allow { .. }));
        assert_eq!(emitter.0.load(Ordering::SeqCst), 1); // Decision emitted
    }

    #[test]
    fn test_lifecycle_emitter_not_called_without_mandate() {
        // mandate.used must only fire after a mandate is consumed; a plain
        // policy allow must not produce lifecycle events.
        let emitter = Arc::new(CountingEmitter(AtomicUsize::new(0)));
        let lifecycle = Arc::new(CountingLifecycleEmitter(
            AtomicUsize::new(0),
            std::sync::Mutex::new(Vec::new()),
        ));

        let handler = ToolCallHandler::new(
            McpPolicy::default(),
            None,
            emitter.clone(),
            ToolCallHandlerConfig::default(),
        )
        .with_lifecycle_emitter(lifecycle.clone());

        let request = make_tool_call_request("safe_tool", serde_json::json!({}));
        let mut state = PolicyState::default();

        let result = handler.handle_tool_call(&request, &mut state, None, None, None);

        assert!(matches!(result, HandleResult::Allow { receipt: None, .. }));
        assert_eq!(emitter.0.load(Ordering::SeqCst), 1);
        assert_eq!(lifecycle.0.load(Ordering::SeqCst), 0);
        assert!(lifecycle.1.lock().map(|e| e.is_empty()).unwrap_or(false));
    }
}