use std::any::Any;
use std::collections::{BTreeMap, BTreeSet};
use std::fmt;
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex, OnceLock};
use std::time::Duration;
use agentkit_capabilities::{
CapabilityContext, CapabilityError, CapabilityName, CapabilityProvider, Invocable,
InvocableOutput, InvocableRequest, InvocableResult, InvocableSpec, PromptProvider,
ResourceProvider,
};
use agentkit_core::{
ApprovalId, Item, ItemKind, MetadataMap, Part, SessionId, TaskId, ToolCallId, ToolOutput,
ToolResultPart, TurnCancellation, TurnId,
};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::{Value, json};
use thiserror::Error;
// Hidden re-export of the `async_trait` attribute macro under a stable module path.
// NOTE(review): presumably consumed by macros of this crate so downstream crates do
// not need their own `async_trait` dependency — confirm against the macro definitions.
#[doc(hidden)]
pub mod __private_async_trait {
pub use async_trait::async_trait;
}
/// Name identifying a tool; a newtype around the raw `String`.
#[derive(Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct ToolName(pub String);

impl ToolName {
    /// Builds a `ToolName` from anything convertible into a `String`.
    pub fn new(value: impl Into<String>) -> Self {
        let name: String = value.into();
        ToolName(name)
    }
}
impl fmt::Display for ToolName {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.0.fmt(f)
}
}
impl From<&str> for ToolName {
fn from(value: &str) -> Self {
Self::new(value)
}
}
/// Boolean hint flags describing a tool. The derived `Default` leaves every hint `false`.
/// NOTE(review): how consumers act on each hint is not visible in this file — treat them
/// as advisory until confirmed.
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct ToolAnnotations {
/// Hint that the tool only reads state.
pub read_only_hint: bool,
/// Hint that the tool may perform destructive changes.
pub destructive_hint: bool,
/// Hint that repeating the call is idempotent.
pub idempotent_hint: bool,
/// Hint that the tool should go through an approval step.
pub needs_approval_hint: bool,
/// Hint that the tool supports streaming.
pub supports_streaming_hint: bool,
}
impl ToolAnnotations {
    /// Returns annotations with every hint cleared (identical to `Default`).
    pub fn new() -> Self {
        Default::default()
    }

    /// Annotations with only `read_only_hint` set.
    pub fn read_only() -> Self {
        Self {
            read_only_hint: true,
            ..Self::default()
        }
    }

    /// Annotations with only `destructive_hint` set.
    pub fn destructive() -> Self {
        Self {
            destructive_hint: true,
            ..Self::default()
        }
    }

    /// Annotations with only `needs_approval_hint` set.
    pub fn needs_approval() -> Self {
        Self {
            needs_approval_hint: true,
            ..Self::default()
        }
    }

    /// Annotations with only `supports_streaming_hint` set.
    pub fn streaming() -> Self {
        Self {
            supports_streaming_hint: true,
            ..Self::default()
        }
    }

    /// Returns `self` with `read_only_hint` replaced.
    pub fn with_read_only(self, read_only_hint: bool) -> Self {
        Self {
            read_only_hint,
            ..self
        }
    }

    /// Returns `self` with `destructive_hint` replaced.
    pub fn with_destructive(self, destructive_hint: bool) -> Self {
        Self {
            destructive_hint,
            ..self
        }
    }

    /// Returns `self` with `idempotent_hint` replaced.
    pub fn with_idempotent(self, idempotent_hint: bool) -> Self {
        Self {
            idempotent_hint,
            ..self
        }
    }

    /// Returns `self` with `needs_approval_hint` replaced.
    pub fn with_needs_approval(self, needs_approval_hint: bool) -> Self {
        Self {
            needs_approval_hint,
            ..self
        }
    }

    /// Returns `self` with `supports_streaming_hint` replaced.
    pub fn with_supports_streaming(self, supports_streaming_hint: bool) -> Self {
        Self {
            supports_streaming_hint,
            ..self
        }
    }
}
/// Metadata key under which `ToolSpec::with_output_limit` stores a serialized
/// [`ToolOutputLimit`] and from which `ToolOutputLimit::from_metadata` reads it back.
pub const TOOL_OUTPUT_LIMIT_METADATA_KEY: &str = "agentkit.tool_output_limit";
/// What to do when a tool's output is larger than [`ToolOutputLimit::max_bytes`].
/// Variants are serialized in `snake_case`.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ToolOutputOverflowAction {
/// Reject the output; `ConfigurableToolOutputTruncationStrategy` reports an execution error.
Fail,
/// Shorten the output in place and mark it as truncated.
InlineClip,
/// Put the full output into an artifact store and return a pointer for later readback.
StoreForReadback,
}
/// Size limit for a tool's output together with the action taken when it is exceeded.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct ToolOutputLimit {
/// Largest output size, in bytes, that is passed through untouched.
pub max_bytes: usize,
/// Action applied when the output is larger than `max_bytes`.
pub action: ToolOutputOverflowAction,
}
impl ToolOutputLimit {
    /// Limit that fails the call when the output exceeds `max_bytes`.
    pub fn fail(max_bytes: usize) -> Self {
        let action = ToolOutputOverflowAction::Fail;
        Self { max_bytes, action }
    }

    /// Limit that clips the output inline when it exceeds `max_bytes`.
    pub fn inline_clip(max_bytes: usize) -> Self {
        let action = ToolOutputOverflowAction::InlineClip;
        Self { max_bytes, action }
    }

    /// Limit that stores the output for later readback when it exceeds `max_bytes`.
    pub fn store_for_readback(max_bytes: usize) -> Self {
        let action = ToolOutputOverflowAction::StoreForReadback;
        Self { max_bytes, action }
    }

    /// Serializes this limit into a JSON value suitable for a metadata map.
    ///
    /// Panics only if serialization fails.
    fn to_metadata_value(&self) -> Value {
        let serialized = serde_json::to_value(self);
        serialized.expect("ToolOutputLimit serializes")
    }

    /// Reads a limit from `metadata` under [`TOOL_OUTPUT_LIMIT_METADATA_KEY`].
    ///
    /// Returns `None` when the key is absent or its value does not deserialize.
    fn from_metadata(metadata: &MetadataMap) -> Option<Self> {
        let raw = metadata.get(TOOL_OUTPUT_LIMIT_METADATA_KEY)?;
        serde_json::from_value(raw.clone()).ok()
    }
}
/// Static description of a tool: its name, description, input schema, hints and metadata.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct ToolSpec {
/// Name of the tool.
pub name: ToolName,
/// Free-form description of the tool.
pub description: String,
/// JSON value describing the accepted input (a JSON-Schema object in `ToolResultReadTool`).
pub input_schema: Value,
/// Behavioural hints for the tool.
pub annotations: ToolAnnotations,
/// Additional metadata; may carry an output limit under `TOOL_OUTPUT_LIMIT_METADATA_KEY`.
pub metadata: MetadataMap,
}
/// Set of tool names that were added, removed or changed, tagged with a source label.
/// NOTE(review): who emits and consumes these events is not visible here.
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct ToolCatalogEvent {
/// Label of the origin of the change.
pub source: String,
/// Names of added tools.
pub added: Vec<String>,
/// Names of removed tools.
pub removed: Vec<String>,
/// Names of changed tools.
pub changed: Vec<String>,
}
impl ToolCatalogEvent {
    /// Creates an event for `source` with empty `added`, `removed` and `changed` lists.
    pub fn new(source: impl Into<String>) -> Self {
        Self {
            source: source.into(),
            ..Self::default()
        }
    }

    /// Calls `f` on every name, visiting `added`, then `removed`, then `changed`.
    pub fn for_each_name_mut(&mut self, mut f: impl FnMut(&mut String)) {
        let every_name = self
            .added
            .iter_mut()
            .chain(self.removed.iter_mut())
            .chain(self.changed.iter_mut());
        for name in every_name {
            f(name);
        }
    }

    /// Keeps only the names for which `predicate` returns `true`, in all three lists.
    pub fn retain_names(&mut self, mut predicate: impl FnMut(&str) -> bool) {
        for names in [&mut self.added, &mut self.removed, &mut self.changed] {
            names.retain(|name| predicate(name));
        }
    }
}
impl ToolSpec {
    /// Creates a spec with default annotations and empty metadata.
    pub fn new(
        name: impl Into<ToolName>,
        description: impl Into<String>,
        input_schema: Value,
    ) -> Self {
        let name = name.into();
        let description = description.into();
        Self {
            name,
            description,
            input_schema,
            annotations: ToolAnnotations::default(),
            metadata: MetadataMap::new(),
        }
    }

    /// Replaces the annotations.
    pub fn with_annotations(self, annotations: ToolAnnotations) -> Self {
        Self {
            annotations,
            ..self
        }
    }

    /// Replaces the whole metadata map.
    pub fn with_metadata(self, metadata: MetadataMap) -> Self {
        Self { metadata, ..self }
    }

    /// Stores `limit` in the metadata under [`TOOL_OUTPUT_LIMIT_METADATA_KEY`].
    pub fn with_output_limit(mut self, limit: ToolOutputLimit) -> Self {
        let key = TOOL_OUTPUT_LIMIT_METADATA_KEY.to_string();
        let value = limit.to_metadata_value();
        self.metadata.insert(key, value);
        self
    }
}
/// A single tool invocation request.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct ToolRequest {
/// Identifier of this tool call.
pub call_id: ToolCallId,
/// Name of the tool to invoke.
pub tool_name: ToolName,
/// JSON input passed to the tool.
pub input: Value,
/// Session the call belongs to.
pub session_id: SessionId,
/// Turn the call belongs to.
pub turn_id: TurnId,
/// Additional request metadata.
pub metadata: MetadataMap,
}
impl ToolRequest {
    /// Creates a request with empty metadata.
    pub fn new(
        call_id: impl Into<ToolCallId>,
        tool_name: impl Into<ToolName>,
        input: Value,
        session_id: impl Into<SessionId>,
        turn_id: impl Into<TurnId>,
    ) -> Self {
        let call_id = call_id.into();
        let tool_name = tool_name.into();
        let session_id = session_id.into();
        let turn_id = turn_id.into();
        Self {
            call_id,
            tool_name,
            input,
            session_id,
            turn_id,
            metadata: MetadataMap::new(),
        }
    }

    /// Replaces the whole metadata map.
    pub fn with_metadata(self, metadata: MetadataMap) -> Self {
        Self { metadata, ..self }
    }
}
/// Outcome of a tool invocation.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct ToolResult {
/// The result part produced by the tool.
pub result: ToolResultPart,
/// Optional duration of the invocation; `None` unless set via `with_duration`.
pub duration: Option<Duration>,
/// Additional result metadata.
pub metadata: MetadataMap,
}
impl ToolResult {
    /// Wraps `result` with no duration and empty metadata.
    pub fn new(result: ToolResultPart) -> Self {
        let duration = None;
        let metadata = MetadataMap::new();
        Self {
            result,
            duration,
            metadata,
        }
    }

    /// Records how long the invocation took.
    pub fn with_duration(self, duration: Duration) -> Self {
        Self {
            duration: Some(duration),
            ..self
        }
    }

    /// Replaces the whole metadata map.
    pub fn with_metadata(self, metadata: MetadataMap) -> Self {
        Self { metadata, ..self }
    }
}
/// Type-erased bag of resources handed to tools through [`ToolContext`].
/// Implementors expose themselves as `&dyn Any` so callers can downcast to a concrete type.
pub trait ToolResources: Send + Sync {
/// Returns `self` as `&dyn Any` for downcasting.
fn as_any(&self) -> &dyn Any;
}
/// The unit type acts as an empty resource set.
impl ToolResources for () {
fn as_any(&self) -> &dyn Any {
self
}
}
/// Borrowed execution context passed to a tool invocation.
pub struct ToolContext<'a> {
/// Capability-level context (session id, turn id and metadata, as filled in by
/// `OwnedToolContext::borrowed`).
pub capability: CapabilityContext<'a>,
/// Checker used to evaluate permission requests.
pub permissions: &'a dyn PermissionChecker,
/// Type-erased resources available to the tool.
pub resources: &'a dyn ToolResources,
/// Optional cancellation handle for the current turn.
pub cancellation: Option<TurnCancellation>,
}
/// Owned counterpart of [`ToolContext`]; cloneable and convertible via `borrowed()`.
#[derive(Clone)]
pub struct OwnedToolContext {
/// Session the context belongs to.
pub session_id: SessionId,
/// Turn the context belongs to.
pub turn_id: TurnId,
/// Metadata exposed through the capability context.
pub metadata: MetadataMap,
/// Shared permission checker.
pub permissions: Arc<dyn PermissionChecker>,
/// Shared type-erased resources.
pub resources: Arc<dyn ToolResources>,
/// Optional cancellation handle for the current turn.
pub cancellation: Option<TurnCancellation>,
}
impl OwnedToolContext {
pub fn borrowed(&self) -> ToolContext<'_> {
ToolContext {
capability: CapabilityContext {
session_id: Some(&self.session_id),
turn_id: Some(&self.turn_id),
metadata: &self.metadata,
},
permissions: self.permissions.as_ref(),
resources: self.resources.as_ref(),
cancellation: self.cancellation.clone(),
}
}
}
/// Identifies the tool call whose output is being considered for truncation.
#[derive(Clone, Debug)]
pub struct ToolOutputTruncationContext {
/// Name of the tool that produced the output.
pub tool_name: ToolName,
/// Identifier of the tool call.
pub call_id: ToolCallId,
/// Session the call belongs to.
pub session_id: SessionId,
/// Turn the call belongs to.
pub turn_id: TurnId,
/// Spec of the tool; its metadata may carry a per-tool output limit.
pub tool_spec: ToolSpec,
}
impl From<(&ToolRequest, ToolSpec)> for ToolOutputTruncationContext {
fn from((request, tool_spec): (&ToolRequest, ToolSpec)) -> Self {
Self {
tool_name: request.tool_name.clone(),
call_id: request.call_id.clone(),
session_id: request.session_id.clone(),
turn_id: request.turn_id.clone(),
tool_spec,
}
}
}
/// Strategy that may rewrite (or reject) a tool's output based on its size.
/// NOTE(review): the call site is not visible here; presumably run after a tool returns.
#[async_trait]
pub trait ToolOutputTruncationStrategy: Send + Sync {
/// Returns the output to use for the call described by `ctx`, or an error.
async fn apply(
&self,
ctx: ToolOutputTruncationContext,
output: ToolOutput,
) -> Result<ToolOutput, ToolError>;
}
/// Identifier of a stored tool-output artifact; a newtype around the raw `String`.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct ToolOutputArtifactId(pub String);
impl fmt::Display for ToolOutputArtifactId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.0.fmt(f)
}
}
/// A stored oversized tool output together with the call it came from.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct ToolOutputArtifact {
/// Identifier used to read the artifact back.
pub id: ToolOutputArtifactId,
/// Tool that produced the output.
pub tool_name: ToolName,
/// Tool call that produced the output.
pub call_id: ToolCallId,
/// Session the call belongs to.
pub session_id: SessionId,
/// Turn the call belongs to.
pub turn_id: TurnId,
/// Size reported by the caller of `put`; `ConfigurableToolOutputTruncationStrategy` passes
/// the model-facing size, which need not equal `body.len()`.
pub original_bytes: usize,
/// Stored text body.
pub body: String,
}
/// A byte-range slice of a stored artifact body.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct ToolOutputArtifactSlice {
/// Artifact the slice was read from.
pub id: ToolOutputArtifactId,
/// Byte offset at which the slice starts.
pub offset: usize,
/// Byte offset at which the next read should start.
pub next_offset: usize,
/// `original_bytes` recorded on the artifact.
pub original_bytes: usize,
/// Whether the slice reaches the end of the body.
pub eof: bool,
/// Slice content.
pub content: String,
}
/// Storage for oversized tool outputs that can be read back in bounded slices.
#[async_trait]
pub trait ToolOutputArtifactStore: Send + Sync {
/// Stores `body` for the call described by `ctx` and returns the stored artifact.
/// `original_bytes` is recorded on the artifact as given.
async fn put(
&self,
ctx: &ToolOutputTruncationContext,
body: String,
original_bytes: usize,
) -> Result<ToolOutputArtifact, ToolError>;
/// Reads up to `max_bytes` bytes of artifact `id` starting at byte `offset`.
async fn read(
&self,
id: &ToolOutputArtifactId,
offset: usize,
max_bytes: usize,
) -> Result<ToolOutputArtifactSlice, ToolError>;
}
/// Artifact store that keeps every artifact in process memory.
#[derive(Debug, Default)]
pub struct InMemoryToolOutputArtifactStore {
    /// Counter used to make artifact ids unique within this store.
    next_id: AtomicU64,
    /// Stored artifacts keyed by id.
    artifacts: Mutex<BTreeMap<ToolOutputArtifactId, ToolOutputArtifact>>,
}

impl InMemoryToolOutputArtifactStore {
    /// Creates an empty store whose id counter starts at zero.
    pub fn new() -> Self {
        Self {
            next_id: AtomicU64::new(0),
            artifacts: Mutex::new(BTreeMap::new()),
        }
    }
}
#[async_trait]
impl ToolOutputArtifactStore for InMemoryToolOutputArtifactStore {
    /// Stores `body` under a fresh id of the form `<session>:<call>:<counter>`, where the
    /// session and call ids are sanitized and the counter is unique within this store.
    async fn put(
        &self,
        ctx: &ToolOutputTruncationContext,
        body: String,
        original_bytes: usize,
    ) -> Result<ToolOutputArtifact, ToolError> {
        let n = self.next_id.fetch_add(1, Ordering::Relaxed);
        let id = ToolOutputArtifactId(format!(
            "{}:{}:{}",
            sanitize_artifact_id_component(ctx.session_id.0.as_str()),
            sanitize_artifact_id_component(ctx.call_id.0.as_str()),
            n
        ));
        let artifact = ToolOutputArtifact {
            id: id.clone(),
            tool_name: ctx.tool_name.clone(),
            call_id: ctx.call_id.clone(),
            session_id: ctx.session_id.clone(),
            turn_id: ctx.turn_id.clone(),
            original_bytes,
            body,
        };
        // A poisoned lock is recovered: the map itself is still usable.
        self.artifacts
            .lock()
            .unwrap_or_else(|e| e.into_inner())
            .insert(id, artifact.clone());
        Ok(artifact)
    }

    /// Reads at most `max_bytes` bytes of artifact `id`, starting at byte `offset`.
    ///
    /// The end of the slice is rounded down to a UTF-8 character boundary.
    ///
    /// # Errors
    ///
    /// Returns `ToolError::InvalidInput` when the artifact is unknown, when `offset` is
    /// past the end of the body or not on a character boundary, or when a non-zero
    /// `max_bytes` is too small to hold the next character (which would otherwise yield
    /// an empty slice that never advances).
    async fn read(
        &self,
        id: &ToolOutputArtifactId,
        offset: usize,
        max_bytes: usize,
    ) -> Result<ToolOutputArtifactSlice, ToolError> {
        // Slice while holding the lock instead of cloning the whole body per read.
        // No `.await` happens while the guard is alive.
        let artifacts = self.artifacts.lock().unwrap_or_else(|e| e.into_inner());
        let artifact = artifacts.get(id).ok_or_else(|| {
            ToolError::InvalidInput(format!("unknown tool result artifact: {id}"))
        })?;
        let body = artifact.body.as_str();
        if offset > body.len() {
            return Err(ToolError::InvalidInput(format!(
                "offset {offset} is past the end of tool result artifact {id} ({} bytes)",
                body.len()
            )));
        }
        if !body.is_char_boundary(offset) {
            return Err(ToolError::InvalidInput(format!(
                "offset {offset} is not a UTF-8 boundary in tool result artifact {id}"
            )));
        }
        let requested_end = offset.saturating_add(max_bytes).min(body.len());
        let end = body.floor_char_boundary(requested_end);
        // `end == offset` with data remaining means the next character does not fit in
        // `max_bytes`; returning an empty, non-EOF slice would stall the reader.
        if end == offset && offset < body.len() && max_bytes > 0 {
            return Err(ToolError::InvalidInput(format!(
                "max_bytes {max_bytes} is too small to hold the next UTF-8 character at offset {offset} in tool result artifact {id}"
            )));
        }
        Ok(ToolOutputArtifactSlice {
            id: id.clone(),
            offset,
            next_offset: end,
            original_bytes: artifact.original_bytes,
            eof: end == body.len(),
            content: body[offset..end].to_string(),
        })
    }
}
/// Makes `s` safe to embed in an artifact id.
///
/// Only the first 64 characters are considered; every character other than an ASCII
/// letter, ASCII digit, `-` or `_` becomes `_`. An empty input yields `"_"`.
fn sanitize_artifact_id_component(s: &str) -> String {
    let mut sanitized = String::new();
    for ch in s.chars().take(64) {
        let keep = ch.is_ascii_alphanumeric() || matches!(ch, '-' | '_');
        sanitized.push(if keep { ch } else { '_' });
    }
    if sanitized.is_empty() {
        sanitized.push('_');
    }
    sanitized
}
/// Truncation strategy driven by per-tool limits, tool-spec metadata and a default limit.
pub struct ConfigurableToolOutputTruncationStrategy {
/// Limit used when neither a per-tool limit nor tool metadata applies.
default_limit: Option<ToolOutputLimit>,
/// Explicit limits keyed by tool name; these take precedence.
per_tool_limits: BTreeMap<ToolName, ToolOutputLimit>,
/// Whether limits stored in the tool spec's metadata are honoured.
use_tool_metadata: bool,
/// Store used for the `StoreForReadback` action.
store: Arc<dyn ToolOutputArtifactStore>,
}
impl ConfigurableToolOutputTruncationStrategy {
    /// Creates a strategy with no limits configured; tool metadata is honoured by default.
    pub fn new(store: Arc<dyn ToolOutputArtifactStore>) -> Self {
        let per_tool_limits = BTreeMap::new();
        Self {
            default_limit: None,
            per_tool_limits,
            use_tool_metadata: true,
            store,
        }
    }

    /// Sets the limit used when nothing more specific applies.
    pub fn with_default_limit(self, limit: ToolOutputLimit) -> Self {
        Self {
            default_limit: Some(limit),
            ..self
        }
    }

    /// Sets (or replaces) the explicit limit for `tool_name`.
    pub fn with_tool_limit(
        mut self,
        tool_name: impl Into<ToolName>,
        limit: ToolOutputLimit,
    ) -> Self {
        let key: ToolName = tool_name.into();
        self.per_tool_limits.insert(key, limit);
        self
    }

    /// Enables or disables limits read from the tool spec's metadata.
    pub fn use_tool_metadata(self, value: bool) -> Self {
        Self {
            use_tool_metadata: value,
            ..self
        }
    }

    /// Picks the limit for `ctx`: explicit per-tool limit first, then the tool spec's
    /// metadata (when enabled), then the default limit.
    fn limit_for(&self, ctx: &ToolOutputTruncationContext) -> Option<ToolOutputLimit> {
        if let Some(explicit) = self.per_tool_limits.get(&ctx.tool_name) {
            return Some(explicit.clone());
        }
        if self.use_tool_metadata {
            if let Some(declared) = ToolOutputLimit::from_metadata(&ctx.tool_spec.metadata) {
                return Some(declared);
            }
        }
        self.default_limit.clone()
    }
}
#[async_trait]
impl ToolOutputTruncationStrategy for ConfigurableToolOutputTruncationStrategy {
    /// Applies the configured limit to `output`.
    ///
    /// Output is returned untouched when no limit applies or when it fits. Otherwise the
    /// limit's action decides: fail with an execution error, clip inline, or store the
    /// body and return a structured pointer describing how to read it back.
    async fn apply(
        &self,
        ctx: ToolOutputTruncationContext,
        output: ToolOutput,
    ) -> Result<ToolOutput, ToolError> {
        let limit = match self.limit_for(&ctx) {
            Some(limit) => limit,
            None => return Ok(output),
        };
        let model_bytes = tool_output_model_bytes(&output);
        if model_bytes <= limit.max_bytes {
            return Ok(output);
        }
        match limit.action {
            ToolOutputOverflowAction::Fail => {
                let message = format!(
                    "tool {} produced {model_bytes} bytes, exceeding configured limit of {} bytes",
                    ctx.tool_name, limit.max_bytes
                );
                Err(ToolError::ExecutionFailed(message))
            }
            ToolOutputOverflowAction::InlineClip => {
                let clipped = clip_tool_output_inline(output, limit.max_bytes, model_bytes);
                Ok(clipped)
            }
            ToolOutputOverflowAction::StoreForReadback => {
                let body = tool_output_readback_body(&output);
                let artifact = self.store.put(&ctx, body, model_bytes).await?;
                let pointer = json!({
                    "truncated": true,
                    "tool_result_id": artifact.id.0,
                    "read_tool": TOOL_RESULT_READ_TOOL_NAME,
                    "read_args": {
                        "id": artifact.id.0,
                        "offset": 0,
                        "limit": limit.max_bytes
                    },
                    "original_bytes": artifact.original_bytes,
                });
                Ok(fit_structured_tool_output(pointer, limit.max_bytes))
            }
        }
    }
}
/// Size in bytes used to compare `output` against a limit.
///
/// Text counts its UTF-8 length; every other variant counts the length of its JSON
/// serialization, or `usize::MAX` when serialization fails.
fn tool_output_model_bytes(output: &ToolOutput) -> usize {
    if let ToolOutput::Text(text) = output {
        return text.len();
    }
    match serde_json::to_string(output) {
        Ok(serialized) => serialized.len(),
        Err(_) => usize::MAX,
    }
}
/// Renders `output` as the text body stored for readback.
///
/// Text is copied as is; other variants are pretty-printed as JSON. A structured value
/// that fails to pretty-print falls back to its compact form; parts and files fall back
/// to an empty string.
fn tool_output_readback_body(output: &ToolOutput) -> String {
    match output {
        ToolOutput::Text(text) => text.to_owned(),
        ToolOutput::Structured(value) => match serde_json::to_string_pretty(value) {
            Ok(pretty) => pretty,
            Err(_) => value.to_string(),
        },
        ToolOutput::Parts(parts) => match serde_json::to_string_pretty(parts) {
            Ok(pretty) => pretty,
            Err(_) => String::new(),
        },
        ToolOutput::Files(files) => match serde_json::to_string_pretty(files) {
            Ok(pretty) => pretty,
            Err(_) => String::new(),
        },
    }
}
/// Shortens `output` so it fits `max_bytes`.
///
/// Text is clipped with a trailing truncation marker. Any other variant is rendered to
/// its readback body and wrapped in a structured envelope that is then shrunk to fit.
fn clip_tool_output_inline(
    output: ToolOutput,
    max_bytes: usize,
    original_bytes: usize,
) -> ToolOutput {
    if let ToolOutput::Text(text) = output {
        let clipped = clip_string_with_marker(&text, max_bytes, original_bytes);
        return ToolOutput::Text(clipped);
    }
    let body = tool_output_readback_body(&output);
    let envelope = json!({
        "truncated": true,
        "original_bytes": original_bytes,
        "content": body,
    });
    fit_structured_tool_output(envelope, max_bytes)
}
/// Clips `s` so that the result, including a trailing truncation marker, is at most
/// `max_bytes` bytes long.
///
/// When the marker alone does not fit, only the leading `max_bytes` bytes of the marker
/// are returned. Cuts always land on UTF-8 character boundaries.
fn clip_string_with_marker(s: &str, max_bytes: usize, original_bytes: usize) -> String {
    let marker = format!("\n[tool output truncated: original_bytes={original_bytes}]");
    let marker_len = marker.len();
    if max_bytes <= marker_len {
        // No room for any payload: emit as much of the marker as the budget allows.
        let mut clipped = marker;
        let cut = clipped.floor_char_boundary(max_bytes);
        clipped.truncate(cut);
        return clipped;
    }
    let payload_budget = max_bytes - marker_len;
    let cut = s.floor_char_boundary(payload_budget.min(s.len()));
    let mut clipped = String::with_capacity(cut + marker_len);
    clipped.push_str(&s[..cut]);
    clipped.push_str(&marker);
    clipped
}
/// Shrinks the `"content"` string of `value` until the structured output fits
/// `max_bytes`.
///
/// When the value has no non-empty string `"content"` left to shrink and still does not
/// fit, a fixed error object is returned instead (regardless of its own size).
fn fit_structured_tool_output(mut value: Value, max_bytes: usize) -> ToolOutput {
    let metadata_overflow = || {
        ToolOutput::Structured(json!({
            "truncated": true,
            "error": "tool output metadata exceeded configured max_bytes"
        }))
    };
    loop {
        let candidate = ToolOutput::Structured(value.clone());
        let candidate_bytes = tool_output_model_bytes(&candidate);
        if candidate_bytes <= max_bytes {
            return candidate;
        }
        let content = match value.get_mut("content") {
            Some(Value::String(content)) if !content.is_empty() => content,
            _ => return metadata_overflow(),
        };
        let current_len = content.len();
        // Remove the overshoot plus a 32-byte margin, but never more than what is there.
        let shrink_by = candidate_bytes
            .saturating_sub(max_bytes)
            .saturating_add(32)
            .min(current_len);
        let new_len = content.floor_char_boundary(current_len - shrink_by);
        content.truncate(new_len);
    }
}
/// Name of the tool that reads stored oversized tool results back in slices.
pub const TOOL_RESULT_READ_TOOL_NAME: &str = "tool_result_read";
// Fixed byte allowance added to the read tool's output limit in `ToolResultReadTool::new`.
// NOTE(review): presumably covers the JSON fields surrounding the content — confirm.
const TOOL_RESULT_READ_OUTPUT_ENVELOPE_BYTES: usize = 4096;
// Multiplier applied to `max_read_bytes` when sizing the read tool's output limit.
// NOTE(review): presumably the worst-case JSON escape expansion (`\uXXXX`) — confirm.
const TOOL_RESULT_READ_JSON_ESCAPE_BYTES_PER_INPUT_BYTE: usize = 6;
/// Tool that reads a bounded slice from a stored oversized tool result.
#[derive(Clone)]
pub struct ToolResultReadTool {
/// Spec advertised for this tool.
spec: ToolSpec,
/// Store the slices are read from.
store: Arc<dyn ToolOutputArtifactStore>,
/// Largest `limit` accepted per read.
max_read_bytes: usize,
}
impl ToolResultReadTool {
    /// Creates the read tool over `store`, accepting at most `max_read_bytes` per call.
    ///
    /// The spec is marked read-only and idempotent and carries a `Fail` output limit of
    /// `max_read_bytes * 6 + 4096` bytes (saturating).
    pub fn new(store: Arc<dyn ToolOutputArtifactStore>, max_read_bytes: usize) -> Self {
        let input_schema = json!({
            "type": "object",
            "properties": {
                "id": { "type": "string" },
                "offset": { "type": "integer", "minimum": 0 },
                "limit": { "type": "integer", "minimum": 1 }
            },
            "required": ["id", "offset", "limit"],
            "additionalProperties": false
        });
        let annotations = ToolAnnotations {
            read_only_hint: true,
            idempotent_hint: true,
            ..ToolAnnotations::default()
        };
        let output_budget = max_read_bytes
            .saturating_mul(TOOL_RESULT_READ_JSON_ESCAPE_BYTES_PER_INPUT_BYTE)
            .saturating_add(TOOL_RESULT_READ_OUTPUT_ENVELOPE_BYTES);
        let spec = ToolSpec::new(
            TOOL_RESULT_READ_TOOL_NAME,
            "Read a bounded UTF-8 byte slice from a stored oversized tool result.",
            input_schema,
        )
        .with_annotations(annotations)
        .with_output_limit(ToolOutputLimit::fail(output_budget));
        Self {
            spec,
            store,
            max_read_bytes,
        }
    }
}
// Deserialized input of the `tool_result_read` tool.
#[derive(Deserialize)]
struct ToolResultReadInput {
// Artifact id to read from.
id: String,
// Byte offset at which to start reading.
offset: usize,
// Maximum number of bytes to read.
limit: usize,
}
#[async_trait]
impl Tool for ToolResultReadTool {
fn spec(&self) -> &ToolSpec {
&self.spec
}
async fn invoke(
&self,
request: ToolRequest,
_ctx: &mut ToolContext<'_>,
) -> Result<ToolResult, ToolError> {
let input: ToolResultReadInput = serde_json::from_value(request.input.clone())
.map_err(|error| ToolError::InvalidInput(format!("invalid tool input: {error}")))?;
if input.limit == 0 {
return Err(ToolError::InvalidInput(
"limit must be greater than 0".to_string(),
));
}
if input.limit > self.max_read_bytes {
return Err(ToolError::InvalidInput(format!(
"limit {} exceeds maximum read size of {} bytes",
input.limit, self.max_read_bytes
)));
}
let slice = self
.store
.read(&ToolOutputArtifactId(input.id), input.offset, input.limit)
.await?;
Ok(ToolResult::new(ToolResultPart::success(
request.call_id,
ToolOutput::Structured(json!({
"id": slice.id.0,
"offset": slice.offset,
"next_offset": slice.next_offset,
"original_bytes": slice.original_bytes,
"eof": slice.eof,
"content": slice.content,
})),
)))
}
}
/// Builds a registry containing only a [`ToolResultReadTool`] over `store`.
pub fn tool_result_readback_registry(
    store: Arc<dyn ToolOutputArtifactStore>,
    max_read_bytes: usize,
) -> ToolRegistry {
    let read_tool = ToolResultReadTool::new(store, max_read_bytes);
    let registry = ToolRegistry::new();
    registry.with(read_tool)
}
/// A request for permission to perform some action, evaluated by permission policies.
pub trait PermissionRequest: Send + Sync {
/// Dotted kind label of the request (for example `"shell.command"`).
fn kind(&self) -> &'static str;
/// Human-readable one-line description of the requested action.
fn summary(&self) -> String;
/// Metadata attached to the request.
fn metadata(&self) -> &MetadataMap;
/// Returns `self` as `&dyn Any` so policies can downcast to a concrete request type.
fn as_any(&self) -> &dyn Any;
}
/// Machine-readable reason attached to a [`PermissionDenial`].
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum PermissionCode {
/// A filesystem path was rejected (used by `PathPolicy`).
PathNotAllowed,
/// A shell command was rejected (used by `CommandPolicy`).
CommandNotAllowed,
/// NOTE(review): not produced in the visible code; presumably network access denial.
NetworkNotAllowed,
/// NOTE(review): not produced in the visible code; presumably an untrusted MCP server.
ServerNotTrusted,
/// NOTE(review): not produced in the visible code; presumably a rejected auth scope.
AuthScopeNotAllowed,
/// A custom request kind was rejected (used by `CustomKindPolicy`).
CustomPolicyDenied,
/// NOTE(review): not produced in the visible code; presumably an unrecognised request.
UnknownRequest,
}
/// Details of a denied permission request.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct PermissionDenial {
/// Machine-readable reason.
pub code: PermissionCode,
/// Human-readable explanation.
pub message: String,
/// Metadata; the policies in this file copy the request's metadata here.
pub metadata: MetadataMap,
}
/// Why an approval is being requested.
/// NOTE(review): only `PolicyRequiresConfirmation`, `SensitivePath` and `SensitiveCommand`
/// are produced by the visible policies; the intent of the others is inferred from names.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum ApprovalReason {
/// A policy demands confirmation by default (used by `CustomKindPolicy`).
PolicyRequiresConfirmation,
EscalatedRisk,
UnknownTarget,
/// A path lies outside the allowed roots (used by `PathPolicy`).
SensitivePath,
/// A command needs confirmation (used by `CommandPolicy`).
SensitiveCommand,
SensitiveServer,
SensitiveAuthScope,
}
/// A request for approval of an action.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct ApprovalRequest {
/// Optional task the approval relates to.
pub task_id: Option<TaskId>,
/// Optional tool call the approval relates to.
pub call_id: Option<ToolCallId>,
/// Identifier of the approval request.
pub id: ApprovalId,
/// Kind label of the permission request that triggered the approval.
pub request_kind: String,
/// Why approval is needed.
pub reason: ApprovalReason,
/// Human-readable description of the action.
pub summary: String,
/// Additional metadata.
pub metadata: MetadataMap,
}
impl ApprovalRequest {
    /// Creates a request with no task id, no call id and empty metadata.
    pub fn new(
        id: impl Into<ApprovalId>,
        request_kind: impl Into<String>,
        reason: ApprovalReason,
        summary: impl Into<String>,
    ) -> Self {
        let id = id.into();
        let request_kind = request_kind.into();
        let summary = summary.into();
        Self {
            task_id: None,
            call_id: None,
            id,
            request_kind,
            reason,
            summary,
            metadata: MetadataMap::new(),
        }
    }

    /// Attaches a task id.
    pub fn with_task_id(self, task_id: impl Into<TaskId>) -> Self {
        Self {
            task_id: Some(task_id.into()),
            ..self
        }
    }

    /// Attaches a tool call id.
    pub fn with_call_id(self, call_id: impl Into<ToolCallId>) -> Self {
        Self {
            call_id: Some(call_id.into()),
            ..self
        }
    }

    /// Replaces the whole metadata map.
    pub fn with_metadata(self, metadata: MetadataMap) -> Self {
        Self { metadata, ..self }
    }
}
/// Answer to an [`ApprovalRequest`].
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum ApprovalDecision {
/// The action is approved.
Approve,
/// The action is denied, optionally with a reason.
Deny {
reason: Option<String>,
},
}
/// Reason a tool invocation was interrupted before completing.
/// NOTE(review): where interruptions are raised and handled is not visible here.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum ToolInterruption {
/// The invocation needs the given approval.
ApprovalRequired(ApprovalRequest),
}
/// Final verdict of a [`PermissionChecker`].
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum PermissionDecision {
/// The request is allowed.
Allow,
/// The request is denied with the given details.
Deny(PermissionDenial),
/// The request needs the given approval.
RequireApproval(ApprovalRequest),
}
/// Evaluates permission requests into a final decision.
pub trait PermissionChecker: Send + Sync {
/// Returns the decision for `request`.
fn evaluate(&self, request: &dyn PermissionRequest) -> PermissionDecision;
}
/// Permission checker that allows every request without inspecting it.
#[derive(Copy, Clone, Debug, Default)]
pub struct AllowAllPermissions;
impl PermissionChecker for AllowAllPermissions {
/// Always returns `PermissionDecision::Allow`.
fn evaluate(&self, _request: &dyn PermissionRequest) -> PermissionDecision {
PermissionDecision::Allow
}
}
/// Verdict of a single [`PermissionPolicy`] for one request.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum PolicyMatch {
/// The policy does not apply to the request.
NoOpinion,
/// The policy allows the request.
Allow,
/// The policy denies the request.
Deny(PermissionDenial),
/// The policy requires approval for the request.
RequireApproval(ApprovalRequest),
}
/// A single rule set that can vote on permission requests.
pub trait PermissionPolicy: Send + Sync {
/// Returns this policy's verdict for `request`.
fn evaluate(&self, request: &dyn PermissionRequest) -> PolicyMatch;
}
/// Permission checker that combines the verdicts of several policies.
pub struct CompositePermissionChecker {
/// Policies consulted in insertion order.
policies: Vec<Box<dyn PermissionPolicy>>,
/// Decision used when no policy allows, denies or requires approval.
fallback: PermissionDecision,
}
impl CompositePermissionChecker {
    /// Creates a checker with no policies that answers every request with `fallback`.
    pub fn new(fallback: PermissionDecision) -> Self {
        let policies = Vec::new();
        Self { policies, fallback }
    }

    /// Appends `policy`; policies are consulted in the order they were added.
    pub fn with_policy(mut self, policy: impl PermissionPolicy + 'static) -> Self {
        let boxed: Box<dyn PermissionPolicy> = Box::new(policy);
        self.policies.push(boxed);
        self
    }
}
impl PermissionChecker for CompositePermissionChecker {
fn evaluate(&self, request: &dyn PermissionRequest) -> PermissionDecision {
let mut saw_allow = false;
let mut approval = None;
for policy in &self.policies {
match policy.evaluate(request) {
PolicyMatch::NoOpinion => {}
PolicyMatch::Allow => saw_allow = true,
PolicyMatch::Deny(denial) => return PermissionDecision::Deny(denial),
PolicyMatch::RequireApproval(req) => approval = Some(req),
}
}
if let Some(req) = approval {
PermissionDecision::RequireApproval(req)
} else if saw_allow {
PermissionDecision::Allow
} else {
self.fallback.clone()
}
}
}
/// Permission request for running a shell command.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct ShellPermissionRequest {
/// Executable to run.
pub executable: String,
/// Arguments passed to the executable.
pub argv: Vec<String>,
/// Optional working directory.
pub cwd: Option<PathBuf>,
/// Names of the environment variables involved (values are not part of the request).
pub env_keys: Vec<String>,
/// Additional request metadata.
pub metadata: MetadataMap,
}
impl PermissionRequest for ShellPermissionRequest {
fn kind(&self) -> &'static str {
"shell.command"
}
fn summary(&self) -> String {
if self.argv.is_empty() {
self.executable.clone()
} else {
format!("{} {}", self.executable, self.argv.join(" "))
}
}
fn metadata(&self) -> &MetadataMap {
&self.metadata
}
fn as_any(&self) -> &dyn Any {
self
}
}
/// Permission request for a filesystem operation. Every variant carries request metadata.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum FileSystemPermissionRequest {
/// Read the file at `path`.
Read {
path: PathBuf,
metadata: MetadataMap,
},
/// Write the file at `path`.
Write {
path: PathBuf,
metadata: MetadataMap,
},
/// Edit the file at `path`.
Edit {
path: PathBuf,
metadata: MetadataMap,
},
/// Delete `path`.
Delete {
path: PathBuf,
metadata: MetadataMap,
},
/// Move `from` to `to`.
Move {
from: PathBuf,
to: PathBuf,
metadata: MetadataMap,
},
/// List the directory at `path`.
List {
path: PathBuf,
metadata: MetadataMap,
},
/// Create a directory at `path`.
CreateDir {
path: PathBuf,
metadata: MetadataMap,
},
}
impl FileSystemPermissionRequest {
    /// Returns the metadata carried by whichever variant this is.
    fn metadata_map(&self) -> &MetadataMap {
        // Every variant has a `metadata` field, so this or-pattern is irrefutable.
        let (Self::Read { metadata, .. }
        | Self::Write { metadata, .. }
        | Self::Edit { metadata, .. }
        | Self::Delete { metadata, .. }
        | Self::Move { metadata, .. }
        | Self::List { metadata, .. }
        | Self::CreateDir { metadata, .. }) = self;
        metadata
    }
}
impl PermissionRequest for FileSystemPermissionRequest {
fn kind(&self) -> &'static str {
match self {
Self::Read { .. } => "filesystem.read",
Self::Write { .. } => "filesystem.write",
Self::Edit { .. } => "filesystem.edit",
Self::Delete { .. } => "filesystem.delete",
Self::Move { .. } => "filesystem.move",
Self::List { .. } => "filesystem.list",
Self::CreateDir { .. } => "filesystem.mkdir",
}
}
fn summary(&self) -> String {
match self {
Self::Read { path, .. } => format!("Read {}", path.display()),
Self::Write { path, .. } => format!("Write {}", path.display()),
Self::Edit { path, .. } => format!("Edit {}", path.display()),
Self::Delete { path, .. } => format!("Delete {}", path.display()),
Self::Move { from, to, .. } => {
format!("Move {} to {}", from.display(), to.display())
}
Self::List { path, .. } => format!("List {}", path.display()),
Self::CreateDir { path, .. } => format!("Create directory {}", path.display()),
}
}
fn metadata(&self) -> &MetadataMap {
self.metadata_map()
}
fn as_any(&self) -> &dyn Any {
self
}
}
/// Permission request for an interaction with an MCP server. Every variant names the
/// server and carries request metadata.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum McpPermissionRequest {
/// Connect to the server.
Connect {
server_id: String,
metadata: MetadataMap,
},
/// Invoke a tool exposed by the server.
InvokeTool {
server_id: String,
tool_name: String,
metadata: MetadataMap,
},
/// Read a resource exposed by the server.
ReadResource {
server_id: String,
resource_id: String,
metadata: MetadataMap,
},
/// Fetch a prompt exposed by the server.
FetchPrompt {
server_id: String,
prompt_id: String,
metadata: MetadataMap,
},
/// Use an auth scope with the server.
UseAuthScope {
server_id: String,
scope: String,
metadata: MetadataMap,
},
}
impl McpPermissionRequest {
    /// Returns the metadata carried by whichever variant this is.
    fn metadata_map(&self) -> &MetadataMap {
        // Every variant has a `metadata` field, so this or-pattern is irrefutable.
        let (Self::Connect { metadata, .. }
        | Self::InvokeTool { metadata, .. }
        | Self::ReadResource { metadata, .. }
        | Self::FetchPrompt { metadata, .. }
        | Self::UseAuthScope { metadata, .. }) = self;
        metadata
    }
}
impl PermissionRequest for McpPermissionRequest {
fn kind(&self) -> &'static str {
match self {
Self::Connect { .. } => "mcp.connect",
Self::InvokeTool { .. } => "mcp.invoke_tool",
Self::ReadResource { .. } => "mcp.read_resource",
Self::FetchPrompt { .. } => "mcp.fetch_prompt",
Self::UseAuthScope { .. } => "mcp.use_auth_scope",
}
}
fn summary(&self) -> String {
match self {
Self::Connect { server_id, .. } => format!("Connect MCP server {server_id}"),
Self::InvokeTool {
server_id,
tool_name,
..
} => format!("Invoke MCP tool {server_id}.{tool_name}"),
Self::ReadResource {
server_id,
resource_id,
..
} => format!("Read MCP resource {server_id}:{resource_id}"),
Self::FetchPrompt {
server_id,
prompt_id,
..
} => format!("Fetch MCP prompt {server_id}:{prompt_id}"),
Self::UseAuthScope {
server_id, scope, ..
} => format!("Use MCP auth scope {server_id}:{scope}"),
}
}
fn metadata(&self) -> &MetadataMap {
self.metadata_map()
}
fn as_any(&self) -> &dyn Any {
self
}
}
/// Policy for request kinds that start with `custom.`.
pub struct CustomKindPolicy {
/// Kinds that are allowed outright.
allowed_kinds: BTreeSet<String>,
/// Kinds that are denied outright; checked before `allowed_kinds`.
denied_kinds: BTreeSet<String>,
/// Whether kinds in neither set require approval (otherwise the policy has no opinion).
require_approval_by_default: bool,
}
impl CustomKindPolicy {
    /// Creates a policy with empty allow and deny sets.
    pub fn new(require_approval_by_default: bool) -> Self {
        let allowed_kinds = BTreeSet::new();
        let denied_kinds = BTreeSet::new();
        Self {
            allowed_kinds,
            denied_kinds,
            require_approval_by_default,
        }
    }

    /// Adds `kind` to the allow set.
    pub fn allow_kind(mut self, kind: impl Into<String>) -> Self {
        let kind: String = kind.into();
        self.allowed_kinds.insert(kind);
        self
    }

    /// Adds `kind` to the deny set.
    pub fn deny_kind(mut self, kind: impl Into<String>) -> Self {
        let kind: String = kind.into();
        self.denied_kinds.insert(kind);
        self
    }
}
impl PermissionPolicy for CustomKindPolicy {
    /// Votes on requests whose kind starts with `custom.`; everything else gets
    /// `NoOpinion`.
    ///
    /// Denied kinds are rejected first, then allowed kinds are accepted. Remaining kinds
    /// require approval when `require_approval_by_default` is set and otherwise get
    /// `NoOpinion`.
    fn evaluate(&self, request: &dyn PermissionRequest) -> PolicyMatch {
        let kind = request.kind();
        if !kind.starts_with("custom.") {
            return PolicyMatch::NoOpinion;
        }
        if self.denied_kinds.contains(kind) {
            let denial = PermissionDenial {
                code: PermissionCode::CustomPolicyDenied,
                message: format!("custom permission kind {kind} is denied"),
                metadata: request.metadata().clone(),
            };
            PolicyMatch::Deny(denial)
        } else if self.allowed_kinds.contains(kind) {
            PolicyMatch::Allow
        } else if self.require_approval_by_default {
            let approval = ApprovalRequest {
                task_id: None,
                call_id: None,
                id: ApprovalId::new(format!("approval:{kind}")),
                request_kind: kind.to_string(),
                reason: ApprovalReason::PolicyRequiresConfirmation,
                summary: request.summary(),
                metadata: request.metadata().clone(),
            };
            PolicyMatch::RequireApproval(approval)
        } else {
            PolicyMatch::NoOpinion
        }
    }
}
/// Policy for filesystem requests based on path roots.
pub struct PathPolicy {
/// Roots under which access is allowed.
allowed_roots: Vec<CanonicalRoot>,
/// Roots under which mutating operations are denied.
read_only_roots: Vec<CanonicalRoot>,
/// Roots under which every operation is denied.
protected_roots: Vec<CanonicalRoot>,
/// Whether paths outside the allowed roots require approval (`true`) or are denied.
require_approval_outside_allowed: bool,
}
impl PathPolicy {
    /// Creates a policy with no roots that requires approval outside allowed roots.
    pub fn new() -> Self {
        Self {
            allowed_roots: Vec::new(),
            read_only_roots: Vec::new(),
            protected_roots: Vec::new(),
            require_approval_outside_allowed: true,
        }
    }

    /// Adds a root under which access is allowed.
    pub fn allow_root(mut self, root: impl Into<PathBuf>) -> Self {
        let root = CanonicalRoot::new(root.into());
        self.allowed_roots.push(root);
        self
    }

    /// Adds a root under which mutating operations are denied.
    pub fn read_only_root(mut self, root: impl Into<PathBuf>) -> Self {
        let root = CanonicalRoot::new(root.into());
        self.read_only_roots.push(root);
        self
    }

    /// Adds a root under which every operation is denied.
    pub fn protect_root(mut self, root: impl Into<PathBuf>) -> Self {
        let root = CanonicalRoot::new(root.into());
        self.protected_roots.push(root);
        self
    }

    /// Chooses between requiring approval (`true`) and denying (`false`) for paths
    /// outside the allowed roots.
    pub fn require_approval_outside_allowed(self, value: bool) -> Self {
        Self {
            require_approval_outside_allowed: value,
            ..self
        }
    }
}
impl Default for PathPolicy {
fn default() -> Self {
Self::new()
}
}
/// Resolves `path` to the most canonical form available.
///
/// The path is made absolute first (falling back to the path as given when that fails),
/// then canonicalized with the partial fallback; if even that yields nothing the absolute
/// path is returned unchanged.
fn resolve_canonical(path: &Path) -> PathBuf {
    let abs = match std::path::absolute(path) {
        Ok(absolute) => absolute,
        Err(_) => path.to_path_buf(),
    };
    match canonicalize_with_partial_fallback(&abs) {
        Some(canonical) => canonical,
        None => abs,
    }
}
/// Canonicalizes `abs`, tolerating a tail of components that does not exist yet.
///
/// When the whole path canonicalizes, that result is returned. Otherwise the deepest
/// leading part that does canonicalize is resolved through the filesystem and the
/// remaining components are applied on top of it: normal components are appended and
/// `..` components remove the last component, so the result never contains `..`.
///
/// Resolving `..` here matters to callers that compare the result with
/// `Path::starts_with`: a path such as `/allowed/missing/../../etc/x` must not be reported
/// as lying under `/allowed`.
///
/// Returns `None` when no leading part of the path can be canonicalized.
fn canonicalize_with_partial_fallback(abs: &Path) -> Option<PathBuf> {
    use std::path::Component;

    if let Ok(canonical) = std::fs::canonicalize(abs) {
        return Some(canonical);
    }

    let components: Vec<Component<'_>> = abs.components().collect();
    // Never try a bare drive prefix or an empty path as the existing part.
    let anchor_len = components
        .iter()
        .take_while(|c| matches!(c, Component::Prefix(_) | Component::RootDir))
        .count()
        .max(1);

    // Try ever shorter leading parts until one exists on disk.
    for split in (anchor_len..components.len()).rev() {
        let prefix: PathBuf = components[..split].iter().collect();
        let Ok(mut resolved) = std::fs::canonicalize(&prefix) else {
            continue;
        };
        let mut climbed = false;
        for component in &components[split..] {
            match component {
                Component::Normal(segment) => resolved.push(segment),
                Component::ParentDir => {
                    climbed = true;
                    // If what we have so far exists, resolve symlinks before stepping
                    // up so `..` is applied to the real directory.
                    if let Ok(real) = std::fs::canonicalize(&resolved) {
                        resolved = real;
                    }
                    resolved.pop();
                }
                Component::CurDir | Component::RootDir | Component::Prefix(_) => {}
            }
        }
        if climbed {
            // Stepping up may have led into existing directories that still contain
            // symlinks. `resolved` has no `..` left, so this recursion ends after one
            // level.
            return Some(canonicalize_with_partial_fallback(&resolved).unwrap_or(resolved));
        }
        return Some(resolved);
    }
    None
}
// A policy root path with a lazily cached canonical form.
struct CanonicalRoot {
// Root exactly as configured.
lexical: PathBuf,
// Canonical form, filled in by `resolve` the first time full canonicalization succeeds.
canonical: OnceLock<PathBuf>,
}
impl CanonicalRoot {
    /// Wraps `lexical` without touching the filesystem.
    fn new(lexical: PathBuf) -> Self {
        let canonical = OnceLock::new();
        Self { lexical, canonical }
    }

    /// Returns the canonical form of the root.
    ///
    /// A successful full canonicalization is cached and borrowed on later calls. While
    /// the root cannot be fully canonicalized, a best-effort path is computed on every
    /// call and returned as an owned value.
    fn resolve(&self) -> std::borrow::Cow<'_, Path> {
        use std::borrow::Cow;

        if let Some(cached) = self.canonical.get() {
            return Cow::Borrowed(cached.as_path());
        }
        let abs = match std::path::absolute(&self.lexical) {
            Ok(absolute) => absolute,
            Err(_) => self.lexical.clone(),
        };
        match std::fs::canonicalize(&abs) {
            Ok(canonical) => {
                // Losing a race is fine: whichever value was stored is returned.
                let _ = self.canonical.set(canonical);
                Cow::Borrowed(self.canonical.get().unwrap())
            }
            Err(_) => Cow::Owned(canonicalize_with_partial_fallback(&abs).unwrap_or(abs)),
        }
    }
}
impl PermissionPolicy for PathPolicy {
fn evaluate(&self, request: &dyn PermissionRequest) -> PolicyMatch {
let Some(fs) = request
.as_any()
.downcast_ref::<FileSystemPermissionRequest>()
else {
return PolicyMatch::NoOpinion;
};
let raw_paths: Vec<&Path> = match fs {
FileSystemPermissionRequest::Move { from, to, .. } => {
vec![from.as_path(), to.as_path()]
}
FileSystemPermissionRequest::Read { path, .. }
| FileSystemPermissionRequest::Write { path, .. }
| FileSystemPermissionRequest::Edit { path, .. }
| FileSystemPermissionRequest::Delete { path, .. }
| FileSystemPermissionRequest::List { path, .. }
| FileSystemPermissionRequest::CreateDir { path, .. } => vec![path.as_path()],
};
let candidate_paths: Vec<PathBuf> =
raw_paths.iter().map(|p| resolve_canonical(p)).collect();
let mutates = matches!(
fs,
FileSystemPermissionRequest::Write { .. }
| FileSystemPermissionRequest::Edit { .. }
| FileSystemPermissionRequest::Delete { .. }
| FileSystemPermissionRequest::Move { .. }
| FileSystemPermissionRequest::CreateDir { .. }
);
if candidate_paths.iter().any(|path| {
self.protected_roots
.iter()
.any(|root| path.starts_with(root.resolve().as_ref()))
}) {
return PolicyMatch::Deny(PermissionDenial {
code: PermissionCode::PathNotAllowed,
message: format!("path access denied for {}", fs.summary()),
metadata: fs.metadata().clone(),
});
}
if mutates
&& candidate_paths.iter().any(|path| {
self.read_only_roots
.iter()
.any(|root| path.starts_with(root.resolve().as_ref()))
})
{
return PolicyMatch::Deny(PermissionDenial {
code: PermissionCode::PathNotAllowed,
message: format!("path is read-only for {}", fs.summary()),
metadata: fs.metadata().clone(),
});
}
if self.allowed_roots.is_empty() {
return PolicyMatch::NoOpinion;
}
let all_allowed = candidate_paths.iter().all(|path| {
self.allowed_roots
.iter()
.any(|root| path.starts_with(root.resolve().as_ref()))
});
if all_allowed {
PolicyMatch::Allow
} else if self.require_approval_outside_allowed {
PolicyMatch::RequireApproval(ApprovalRequest {
task_id: None,
call_id: None,
id: ApprovalId::new(format!("approval:{}", fs.kind())),
request_kind: fs.kind().to_string(),
reason: ApprovalReason::SensitivePath,
summary: fs.summary(),
metadata: fs.metadata().clone(),
})
} else {
PolicyMatch::Deny(PermissionDenial {
code: PermissionCode::PathNotAllowed,
message: format!("path outside allowed roots for {}", fs.summary()),
metadata: fs.metadata().clone(),
})
}
}
}
/// Policy for shell command requests.
pub struct CommandPolicy {
/// Executables that are allowed; an empty set allows any executable that is not denied.
allowed_executables: BTreeSet<String>,
/// Executables that are always denied.
denied_executables: BTreeSet<String>,
/// Working-directory roots; when non-empty, a request whose `cwd` is outside all of them
/// requires approval.
allowed_cwds: Vec<PathBuf>,
/// Environment variable names whose presence in a request causes a denial.
denied_env_keys: BTreeSet<String>,
/// Whether executables outside `allowed_executables` require approval.
/// NOTE(review): the alternative outcome is outside the visible part of the file.
require_approval_for_unknown: bool,
}
impl CommandPolicy {
pub fn new() -> Self {
Self {
allowed_executables: BTreeSet::new(),
denied_executables: BTreeSet::new(),
allowed_cwds: Vec::new(),
denied_env_keys: BTreeSet::new(),
require_approval_for_unknown: true,
}
}
pub fn allow_executable(mut self, executable: impl Into<String>) -> Self {
self.allowed_executables.insert(executable.into());
self
}
pub fn deny_executable(mut self, executable: impl Into<String>) -> Self {
self.denied_executables.insert(executable.into());
self
}
pub fn allow_cwd(mut self, cwd: impl Into<PathBuf>) -> Self {
self.allowed_cwds.push(cwd.into());
self
}
pub fn deny_env_key(mut self, key: impl Into<String>) -> Self {
self.denied_env_keys.insert(key.into());
self
}
pub fn require_approval_for_unknown(mut self, value: bool) -> Self {
self.require_approval_for_unknown = value;
self
}
}
impl Default for CommandPolicy {
fn default() -> Self {
Self::new()
}
}
impl PermissionPolicy for CommandPolicy {
fn evaluate(&self, request: &dyn PermissionRequest) -> PolicyMatch {
let Some(shell) = request.as_any().downcast_ref::<ShellPermissionRequest>() else {
return PolicyMatch::NoOpinion;
};
if self.denied_executables.contains(&shell.executable)
|| shell
.env_keys
.iter()
.any(|key| self.denied_env_keys.contains(key))
{
return PolicyMatch::Deny(PermissionDenial {
code: PermissionCode::CommandNotAllowed,
message: format!("command denied for {}", shell.summary()),
metadata: shell.metadata().clone(),
});
}
if let Some(cwd) = &shell.cwd
&& !self.allowed_cwds.is_empty()
&& !self.allowed_cwds.iter().any(|root| cwd.starts_with(root))
{
return PolicyMatch::RequireApproval(ApprovalRequest {
task_id: None,
call_id: None,
id: ApprovalId::new("approval:shell.cwd"),
request_kind: shell.kind().to_string(),
reason: ApprovalReason::SensitiveCommand,
summary: shell.summary(),
metadata: shell.metadata().clone(),
});
}
if self.allowed_executables.is_empty()
|| self.allowed_executables.contains(&shell.executable)
{
PolicyMatch::Allow
} else if self.require_approval_for_unknown {
PolicyMatch::RequireApproval(ApprovalRequest {
task_id: None,
call_id: None,
id: ApprovalId::new("approval:shell.command"),
request_kind: shell.kind().to_string(),
reason: ApprovalReason::SensitiveCommand,
summary: shell.summary(),
metadata: shell.metadata().clone(),
})
} else {
PolicyMatch::Deny(PermissionDenial {
code: PermissionCode::CommandNotAllowed,
message: format!("executable {} is not allowed", shell.executable),
metadata: shell.metadata().clone(),
})
}
}
}
/// Permission policy for MCP server requests.
///
/// Holds a set of trusted server ids, a set of allowed auth scopes, and a flag
/// selecting between approval and denial for servers outside the trusted set.
pub struct McpServerPolicy {
    trusted_servers: BTreeSet<String>,
    allowed_auth_scopes: BTreeSet<String>,
    require_approval_for_untrusted: bool,
}

impl McpServerPolicy {
    /// Creates a policy with no trusted servers, no allowed auth scopes, and
    /// approval required for untrusted servers.
    pub fn new() -> Self {
        Self {
            trusted_servers: BTreeSet::new(),
            allowed_auth_scopes: BTreeSet::new(),
            require_approval_for_untrusted: true,
        }
    }

    /// Adds `server_id` to the set of trusted servers.
    pub fn trust_server(mut self, server_id: impl Into<String>) -> Self {
        self.trusted_servers.insert(server_id.into());
        self
    }

    /// Adds `scope` to the set of allowed auth scopes.
    pub fn allow_auth_scope(mut self, scope: impl Into<String>) -> Self {
        self.allowed_auth_scopes.insert(scope.into());
        self
    }

    /// Sets whether requests for servers outside the trusted set require
    /// approval (`true`, the default) or are denied outright (`false`).
    ///
    /// Without this setter the flag could never leave its default, which left
    /// the denial path of the policy unreachable. Mirrors
    /// `CommandPolicy::require_approval_for_unknown`.
    pub fn require_approval_for_untrusted(mut self, value: bool) -> Self {
        self.require_approval_for_untrusted = value;
        self
    }
}

impl Default for McpServerPolicy {
    /// Same as [`McpServerPolicy::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl PermissionPolicy for McpServerPolicy {
    /// Evaluates an MCP request against the trusted-server and auth-scope
    /// lists.
    ///
    /// Non-MCP requests yield `NoOpinion`. With a non-empty trusted set, a
    /// server outside it yields `RequireApproval` or `Deny` depending on
    /// `require_approval_for_untrusted`. A `UseAuthScope` request whose scope
    /// is missing from a non-empty scope allow list is denied. Everything else
    /// is allowed.
    fn evaluate(&self, request: &dyn PermissionRequest) -> PolicyMatch {
        let mcp = match request.as_any().downcast_ref::<McpPermissionRequest>() {
            Some(mcp) => mcp,
            None => return PolicyMatch::NoOpinion,
        };

        // Every variant carries the id of the server it targets.
        let (McpPermissionRequest::Connect { server_id, .. }
        | McpPermissionRequest::InvokeTool { server_id, .. }
        | McpPermissionRequest::ReadResource { server_id, .. }
        | McpPermissionRequest::FetchPrompt { server_id, .. }
        | McpPermissionRequest::UseAuthScope { server_id, .. }) = mcp;

        // An empty trusted set places no restriction on servers.
        let trust_list_active = !self.trusted_servers.is_empty();
        if trust_list_active && !self.trusted_servers.contains(server_id) {
            if !self.require_approval_for_untrusted {
                return PolicyMatch::Deny(PermissionDenial {
                    code: PermissionCode::ServerNotTrusted,
                    message: format!("MCP server {server_id} is not trusted"),
                    metadata: mcp.metadata().clone(),
                });
            }
            return PolicyMatch::RequireApproval(ApprovalRequest {
                task_id: None,
                call_id: None,
                id: ApprovalId::new(format!("approval:mcp:{server_id}")),
                request_kind: mcp.kind().to_string(),
                reason: ApprovalReason::SensitiveServer,
                summary: mcp.summary(),
                metadata: mcp.metadata().clone(),
            });
        }

        match mcp {
            // An empty scope allow list places no restriction on scopes.
            McpPermissionRequest::UseAuthScope { scope, .. }
                if !self.allowed_auth_scopes.is_empty()
                    && !self.allowed_auth_scopes.contains(scope) =>
            {
                PolicyMatch::Deny(PermissionDenial {
                    code: PermissionCode::AuthScopeNotAllowed,
                    message: format!("MCP auth scope {scope} is not allowed"),
                    metadata: mcp.metadata().clone(),
                })
            }
            _ => PolicyMatch::Allow,
        }
    }
}
/// A callable tool: a spec plus an async invocation entry point.
#[async_trait]
pub trait Tool: Send + Sync {
/// Returns the spec held by this tool.
fn spec(&self) -> &ToolSpec;
/// Returns the spec to advertise right now; the default clones [`Tool::spec`].
///
/// Spec listings in this file skip tools for which this returns `None`.
fn current_spec(&self) -> Option<ToolSpec> {
Some(self.spec().clone())
}
/// Returns the permission requests to be evaluated for `_request` before the
/// tool is invoked. The default proposes none.
fn proposed_requests(
&self,
_request: &ToolRequest,
) -> Result<Vec<Box<dyn PermissionRequest>>, ToolError> {
Ok(Vec::new())
}
/// Runs the tool for `request` with access to the tool context `ctx`.
async fn invoke(
&self,
request: ToolRequest,
ctx: &mut ToolContext<'_>,
) -> Result<ToolResult, ToolError>;
}
/// A name-keyed collection of tools.
///
/// Registering a tool under a name that is already present replaces the
/// earlier entry.
#[derive(Clone, Default)]
pub struct ToolRegistry {
    tools: BTreeMap<ToolName, Arc<dyn Tool>>,
}

impl ToolRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self {
            tools: BTreeMap::new(),
        }
    }

    /// Registers `tool` under the name from its spec and returns `self` for
    /// chaining.
    pub fn register<T>(&mut self, tool: T) -> &mut Self
    where
        T: Tool + 'static,
    {
        self.register_arc(Arc::new(tool))
    }

    /// By-value variant of [`ToolRegistry::register`] for builder-style use.
    pub fn with<T>(self, tool: T) -> Self
    where
        T: Tool + 'static,
    {
        let mut registry = self;
        registry.register(tool);
        registry
    }

    /// Registers an already shared tool under the name from its spec.
    pub fn register_arc(&mut self, tool: Arc<dyn Tool>) -> &mut Self {
        let key = tool.spec().name.clone();
        self.tools.insert(key, tool);
        self
    }

    /// Looks up a tool by name.
    pub fn get(&self, name: &ToolName) -> Option<Arc<dyn Tool>> {
        self.tools.get(name).map(Arc::clone)
    }

    /// Returns every registered tool, ordered by name.
    pub fn tools(&self) -> Vec<Arc<dyn Tool>> {
        self.tools.values().map(Arc::clone).collect()
    }

    /// Moves every tool from `other` into `self`; on a name clash the entry
    /// from `other` replaces the existing one.
    pub fn merge(mut self, other: Self) -> Self {
        for (name, tool) in other.tools {
            self.tools.insert(name, tool);
        }
        self
    }

    /// Returns the current spec of every tool that reports one.
    pub fn specs(&self) -> Vec<ToolSpec> {
        let mut specs = Vec::new();
        for tool in self.tools.values() {
            if let Some(spec) = tool.current_spec() {
                specs.push(spec);
            }
        }
        specs
    }
}
/// A provider of tools that can list specs and resolve tools by name.
pub trait ToolSource: Send + Sync {
/// Returns the specs this source currently offers.
fn specs(&self) -> Vec<ToolSpec>;
/// Resolves `name` to a tool, or `None` when this source does not offer it.
fn get(&self, name: &ToolName) -> Option<Arc<dyn Tool>>;
/// Returns catalog events pending for this source; the default has none.
fn drain_catalog_events(&self) -> Vec<ToolCatalogEvent> {
Vec::new()
}
/// Wraps this source in a [`Prefixed`] adapter using `prefix`.
fn prefixed(self, prefix: impl Into<String>) -> Prefixed<Self>
where
Self: Sized,
{
Prefixed::new(self, prefix)
}
/// Wraps this source in a [`Filtered`] adapter using `predicate`.
fn filtered<F>(self, predicate: F) -> Filtered<Self, F>
where
Self: Sized,
F: Fn(&ToolName) -> bool + Send + Sync + 'static,
{
Filtered::new(self, predicate)
}
/// Wraps this source in a [`Renamed`] adapter built from `mapping`.
fn renamed<I>(self, mapping: I) -> Renamed<Self>
where
Self: Sized,
I: IntoIterator<Item = (ToolName, ToolName)>,
{
Renamed::new(self, mapping)
}
}
/// Exposes a [`ToolRegistry`] as a [`ToolSource`] by delegating to its
/// inherent methods of the same names.
impl ToolSource for ToolRegistry {
fn specs(&self) -> Vec<ToolSpec> {
// Explicit path: names the registry's own `specs` rather than this trait method.
ToolRegistry::specs(self)
}
fn get(&self, name: &ToolName) -> Option<Arc<dyn Tool>> {
// Explicit path: names the registry's own `get` rather than this trait method.
ToolRegistry::get(self, name)
}
}
/// Lets a shared source (including `Arc<dyn ToolSource>`) be used wherever a
/// [`ToolSource`] is expected by forwarding every method to the pointee.
impl<S> ToolSource for Arc<S>
where
    S: ToolSource + ?Sized,
{
    fn specs(&self) -> Vec<ToolSpec> {
        let source: &S = self;
        source.specs()
    }

    fn get(&self, name: &ToolName) -> Option<Arc<dyn Tool>> {
        let source: &S = self;
        source.get(name)
    }

    fn drain_catalog_events(&self) -> Vec<ToolCatalogEvent> {
        let source: &S = self;
        source.drain_catalog_events()
    }
}
/// Source adapter that exposes every tool of `inner` under
/// `<prefix>_<original name>`.
pub struct Prefixed<S> {
    inner: S,
    prefix: String,
}

impl<S> Prefixed<S> {
    /// Wraps `inner`, prefixing its tool names with `prefix`.
    pub fn new(inner: S, prefix: impl Into<String>) -> Self {
        let prefix = prefix.into();
        Self { inner, prefix }
    }

    /// Builds the public name for an inner tool name.
    fn rewrite(&self, name: &str) -> String {
        let mut public = String::with_capacity(self.prefix.len() + 1 + name.len());
        public.push_str(&self.prefix);
        public.push('_');
        public.push_str(name);
        public
    }

    /// Recovers the inner name from a public one, or `None` when `name` does
    /// not start with the prefix followed by an underscore.
    fn strip<'a>(&self, name: &'a str) -> Option<&'a str> {
        let after_prefix = name.strip_prefix(self.prefix.as_str())?;
        after_prefix.strip_prefix('_')
    }
}
impl<S> ToolSource for Prefixed<S>
where
    S: ToolSource,
{
    /// Lists the inner specs with each name rewritten to its prefixed form.
    fn specs(&self) -> Vec<ToolSpec> {
        let mut specs = self.inner.specs();
        for spec in &mut specs {
            let public_name = self.rewrite(spec.name.0.as_str());
            spec.name = ToolName::new(public_name);
        }
        specs
    }

    /// Resolves a prefixed name by stripping the prefix, looking the tool up
    /// in the inner source, and wrapping it so it reports the public name
    /// while being invoked under its inner name.
    fn get(&self, name: &ToolName) -> Option<Arc<dyn Tool>> {
        let inner_name = ToolName::new(self.strip(name.0.as_str())?);
        let inner_tool = self.inner.get(&inner_name)?;
        let mut public_spec = inner_tool.spec().clone();
        public_spec.name = name.clone();
        let wrapped: Arc<dyn Tool> = Arc::new(RewrittenTool {
            inner: inner_tool,
            inner_name,
            public_spec,
        });
        Some(wrapped)
    }

    /// Drains the inner source's events and rewrites every tool name in them
    /// to its prefixed form.
    fn drain_catalog_events(&self) -> Vec<ToolCatalogEvent> {
        let mut events = self.inner.drain_catalog_events();
        for event in &mut events {
            event.for_each_name_mut(|name| {
                let public_name = self.rewrite(name.as_str());
                *name = public_name;
            });
        }
        events
    }
}
/// Source adapter that pairs an inner source with a name predicate.
pub struct Filtered<S, F> {
// The wrapped source.
inner: S,
// Predicate applied to tool names.
predicate: F,
}
impl<S, F> Filtered<S, F> {
/// Wraps `inner` together with `predicate`.
pub fn new(inner: S, predicate: F) -> Self {
Self { inner, predicate }
}
}
impl<S, F> ToolSource for Filtered<S, F>
where
    S: ToolSource,
    F: Fn(&ToolName) -> bool + Send + Sync + 'static,
{
    /// Lists only the inner specs whose name satisfies the predicate.
    fn specs(&self) -> Vec<ToolSpec> {
        let mut specs = self.inner.specs();
        specs.retain(|spec| (self.predicate)(&spec.name));
        specs
    }

    /// Resolves `name` in the inner source, but only if the predicate accepts
    /// it; rejected names never reach the inner source.
    fn get(&self, name: &ToolName) -> Option<Arc<dyn Tool>> {
        if (self.predicate)(name) {
            self.inner.get(name)
        } else {
            None
        }
    }

    /// Drains the inner source's events, keeping in each event only the names
    /// accepted by the predicate.
    fn drain_catalog_events(&self) -> Vec<ToolCatalogEvent> {
        let mut events = self.inner.drain_catalog_events();
        for event in &mut events {
            event.retain_names(|n| (self.predicate)(&ToolName::new(n)));
        }
        events
    }
}
/// Source adapter that exposes selected tools of `inner` under new names.
///
/// `forward` maps original names to public names; `backward` is the inverse
/// lookup derived from it.
pub struct Renamed<S> {
    inner: S,
    forward: BTreeMap<ToolName, ToolName>,
    backward: BTreeMap<ToolName, ToolName>,
}

impl<S> Renamed<S> {
    /// Wraps `inner` with the `(original, public)` pairs in `mapping`.
    pub fn new<I>(inner: S, mapping: I) -> Self
    where
        I: IntoIterator<Item = (ToolName, ToolName)>,
    {
        let forward = mapping.into_iter().collect::<BTreeMap<ToolName, ToolName>>();
        // Derived from the finished forward map so both directions agree.
        let mut backward = BTreeMap::new();
        for (original, public) in &forward {
            backward.insert(public.clone(), original.clone());
        }
        Self {
            inner,
            forward,
            backward,
        }
    }
}
impl<S> ToolSource for Renamed<S>
where
S: ToolSource,
{
fn specs(&self) -> Vec<ToolSpec> {
self.inner
.specs()
.into_iter()
.map(|mut spec| {
if let Some(new_name) = self.forward.get(&spec.name) {
spec.name = new_name.clone();
}
spec
})
.collect()
}
fn get(&self, name: &ToolName) -> Option<Arc<dyn Tool>> {
if let Some(original) = self.backward.get(name) {
let inner_tool = self.inner.get(original)?;
let mut public_spec = inner_tool.spec().clone();
public_spec.name = name.clone();
Some(Arc::new(RewrittenTool {
inner: inner_tool,
inner_name: original.clone(),
public_spec,
}))
} else if self.forward.contains_key(name) {
None
} else {
self.inner.get(name)
}
}
fn drain_catalog_events(&self) -> Vec<ToolCatalogEvent> {
self.inner
.drain_catalog_events()
.into_iter()
.map(|mut event| {
event.for_each_name_mut(|name| {
if let Some(new) = self.forward.get(&ToolName::new(name.as_str())) {
*name = new.0.clone();
}
});
event
})
.collect()
}
}
/// Generates the schema for `T` with `schemars` and returns it as a JSON
/// value.
///
/// # Panics
///
/// Panics if the generated schema cannot be converted to a JSON value.
#[cfg(feature = "schemars")]
pub fn schema_for<T: schemars::JsonSchema>() -> Value {
    serde_json::to_value(schemars::schema_for!(T))
        .expect("schemars produces valid JSON; this conversion is infallible")
}
/// Builds a [`ToolSpec`] whose input schema is generated from `T`.
#[cfg(feature = "schemars")]
pub fn tool_spec_for<T: schemars::JsonSchema>(
    name: impl Into<ToolName>,
    description: impl Into<String>,
) -> ToolSpec {
    let input_schema = schema_for::<T>();
    ToolSpec::new(name, description, input_schema)
}
struct RewrittenTool {
inner: Arc<dyn Tool>,
inner_name: ToolName,
public_spec: ToolSpec,
}
#[async_trait]
impl Tool for RewrittenTool {
fn spec(&self) -> &ToolSpec {
&self.public_spec
}
fn current_spec(&self) -> Option<ToolSpec> {
let inner_current = self.inner.current_spec()?;
Some(ToolSpec {
name: self.public_spec.name.clone(),
description: inner_current.description,
input_schema: inner_current.input_schema,
annotations: inner_current.annotations,
metadata: inner_current.metadata,
})
}
fn proposed_requests(
&self,
request: &ToolRequest,
) -> Result<Vec<Box<dyn PermissionRequest>>, ToolError> {
let mut inner_request = request.clone();
inner_request.tool_name = self.inner_name.clone();
self.inner.proposed_requests(&inner_request)
}
async fn invoke(
&self,
mut request: ToolRequest,
ctx: &mut ToolContext<'_>,
) -> Result<ToolResult, ToolError> {
request.tool_name = self.inner_name.clone();
self.inner.invoke(request, ctx).await
}
}
/// Name-keyed tool table behind a read/write lock.
///
/// Both accessors recover the guard from a poisoned lock instead of
/// panicking.
struct ToolMap {
    inner: std::sync::RwLock<BTreeMap<ToolName, Arc<dyn Tool>>>,
}

impl ToolMap {
    /// Creates an empty table.
    fn new() -> Self {
        let empty = BTreeMap::new();
        Self {
            inner: std::sync::RwLock::new(empty),
        }
    }

    /// Acquires the read guard, ignoring lock poisoning.
    fn read(&self) -> std::sync::RwLockReadGuard<'_, BTreeMap<ToolName, Arc<dyn Tool>>> {
        match self.inner.read() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        }
    }

    /// Acquires the write guard, ignoring lock poisoning.
    fn write(&self) -> std::sync::RwLockWriteGuard<'_, BTreeMap<ToolName, Arc<dyn Tool>>> {
        match self.inner.write() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        }
    }
}
/// State shared between a [`CatalogWriter`] and its [`CatalogReader`]s.
struct DynamicCatalogInner {
// Identifier stamped on every catalog event created by the writer.
source_id: String,
// The current set of tools, keyed by name.
tools: ToolMap,
// Sending half of the broadcast channel used for catalog events.
events_tx: tokio::sync::broadcast::Sender<ToolCatalogEvent>,
}
/// Creates an empty dynamic catalog and returns its writer and a first reader.
///
/// The reader holds the receiver created together with the channel, which has
/// a capacity of 128 events.
pub fn dynamic_catalog(source_id: impl Into<String>) -> (CatalogWriter, CatalogReader) {
    let (events_tx, events_rx) = tokio::sync::broadcast::channel(128);
    let shared = Arc::new(DynamicCatalogInner {
        source_id: source_id.into(),
        tools: ToolMap::new(),
        events_tx,
    });
    let writer = CatalogWriter {
        inner: Arc::clone(&shared),
    };
    let reader = CatalogReader {
        inner: shared,
        events_rx: std::sync::Mutex::new(events_rx),
    };
    (writer, reader)
}
/// Writing handle of a dynamic catalog: mutates the tool table and broadcasts
/// a [`ToolCatalogEvent`] describing each change.
pub struct CatalogWriter {
    inner: Arc<DynamicCatalogInner>,
}

impl CatalogWriter {
    /// Returns the source id stamped on this catalog's events.
    pub fn source_id(&self) -> &str {
        self.inner.source_id.as_str()
    }

    /// Creates a reader over the same catalog with a fresh event
    /// subscription.
    pub fn reader(&self) -> CatalogReader {
        let fresh_rx = self.inner.events_tx.subscribe();
        CatalogReader {
            inner: Arc::clone(&self.inner),
            events_rx: std::sync::Mutex::new(fresh_rx),
        }
    }

    /// Inserts or replaces `tool` under the name from its spec and broadcasts
    /// an `added` or `changed` event accordingly.
    pub fn upsert(&self, tool: Arc<dyn Tool>) {
        let name = tool.spec().name.clone();
        let mut tools = self.inner.tools.write();
        let existed = tools.insert(name.clone(), tool).is_some();
        // Release the lock before notifying subscribers.
        drop(tools);

        let mut event = ToolCatalogEvent::new(self.inner.source_id.clone());
        let bucket = if existed {
            &mut event.changed
        } else {
            &mut event.added
        };
        bucket.push(name.0);
        // A send error is ignored on purpose; the table is already updated.
        let _ = self.inner.events_tx.send(event);
    }

    /// Removes the tool called `name`, broadcasting a `removed` event if it
    /// was present. Returns whether a tool was removed.
    pub fn remove(&self, name: &ToolName) -> bool {
        let mut tools = self.inner.tools.write();
        let removed = tools.remove(name).is_some();
        drop(tools);

        if !removed {
            return false;
        }
        let mut event = ToolCatalogEvent::new(self.inner.source_id.clone());
        event.removed.push(name.0.clone());
        let _ = self.inner.events_tx.send(event);
        true
    }

    /// Replaces the whole tool table with `tools` and broadcasts one event
    /// describing the difference, if there is any.
    ///
    /// A name counts as `changed` only when the new tool is a different
    /// allocation from the old one and their current specs differ.
    pub fn replace_all(&self, tools: impl IntoIterator<Item = Arc<dyn Tool>>) {
        let mut new_map: BTreeMap<ToolName, Arc<dyn Tool>> = BTreeMap::new();
        for tool in tools {
            let name = tool.spec().name.clone();
            new_map.insert(name, tool);
        }

        let mut guard = self.inner.tools.write();
        let mut event = ToolCatalogEvent::new(self.inner.source_id.clone());
        for (name, new_tool) in &new_map {
            let Some(existing) = guard.get(name) else {
                event.added.push(name.0.clone());
                continue;
            };
            if !Arc::ptr_eq(existing, new_tool)
                && existing.current_spec() != new_tool.current_spec()
            {
                event.changed.push(name.0.clone());
            }
        }
        event.removed.extend(
            guard
                .keys()
                .filter(|name| !new_map.contains_key(*name))
                .map(|name| name.0.clone()),
        );
        *guard = new_map;
        drop(guard);

        let nothing_changed =
            event.added.is_empty() && event.removed.is_empty() && event.changed.is_empty();
        if !nothing_changed {
            let _ = self.inner.events_tx.send(event);
        }
    }

    /// Returns a new receiver for this catalog's events.
    pub fn subscribe(&self) -> tokio::sync::broadcast::Receiver<ToolCatalogEvent> {
        self.inner.events_tx.subscribe()
    }
}
/// Reading handle of a dynamic catalog with its own event subscription.
pub struct CatalogReader {
    inner: Arc<DynamicCatalogInner>,
    events_rx: std::sync::Mutex<tokio::sync::broadcast::Receiver<ToolCatalogEvent>>,
}

impl CatalogReader {
    /// Returns the source id of the catalog this reader belongs to.
    pub fn source_id(&self) -> &str {
        self.inner.source_id.as_str()
    }

    /// Returns a new receiver for this catalog's events.
    pub fn subscribe(&self) -> tokio::sync::broadcast::Receiver<ToolCatalogEvent> {
        self.inner.events_tx.subscribe()
    }
}

impl Clone for CatalogReader {
    /// Clones the handle; the clone gets a fresh subscription rather than
    /// sharing this reader's receiver.
    fn clone(&self) -> Self {
        let fresh_rx = self.subscribe();
        Self {
            inner: Arc::clone(&self.inner),
            events_rx: std::sync::Mutex::new(fresh_rx),
        }
    }
}
impl ToolSource for CatalogReader {
fn specs(&self) -> Vec<ToolSpec> {
self.inner
.tools
.read()
.values()
.filter_map(|tool| tool.current_spec())
.collect()
}
fn get(&self, name: &ToolName) -> Option<Arc<dyn Tool>> {
self.inner.tools.read().get(name).cloned()
}
fn drain_catalog_events(&self) -> Vec<ToolCatalogEvent> {
let mut rx = self.events_rx.lock().unwrap_or_else(|e| e.into_inner());
let mut out = Vec::new();
loop {
match rx.try_recv() {
Ok(event) => out.push(event),
Err(tokio::sync::broadcast::error::TryRecvError::Empty) => break,
Err(tokio::sync::broadcast::error::TryRecvError::Closed) => break,
Err(tokio::sync::broadcast::error::TryRecvError::Lagged(_)) => continue,
}
}
out
}
}
impl ToolSpec {
    /// Converts this spec into an [`InvocableSpec`], copying the name,
    /// description, input schema, and metadata.
    pub fn as_invocable_spec(&self) -> InvocableSpec {
        let capability_name = CapabilityName::new(self.name.0.clone());
        let base = InvocableSpec::new(
            capability_name,
            self.description.clone(),
            self.input_schema.clone(),
        );
        base.with_metadata(self.metadata.clone())
    }
}
/// Adapts a [`Tool`] to the [`Invocable`] capability interface.
pub struct ToolInvocableAdapter {
    spec: InvocableSpec,
    tool: Arc<dyn Tool>,
    permissions: Arc<dyn PermissionChecker>,
    resources: Arc<dyn ToolResources>,
    next_call_id: AtomicU64,
}

impl ToolInvocableAdapter {
    /// Builds an adapter for `tool`, or returns `None` when the tool reports
    /// no current spec.
    ///
    /// The call-id counter starts at 1.
    pub fn new(
        tool: Arc<dyn Tool>,
        permissions: Arc<dyn PermissionChecker>,
        resources: Arc<dyn ToolResources>,
    ) -> Option<Self> {
        tool.current_spec().map(|current| Self {
            spec: current.as_invocable_spec(),
            tool,
            permissions,
            resources,
            next_call_id: AtomicU64::new(1),
        })
    }
}
#[async_trait]
impl Invocable for ToolInvocableAdapter {
fn spec(&self) -> &InvocableSpec {
&self.spec
}
async fn invoke(
&self,
request: InvocableRequest,
ctx: &mut CapabilityContext<'_>,
) -> Result<InvocableResult, CapabilityError> {
let tool_request = ToolRequest {
call_id: ToolCallId::new(format!(
"tool-call-{}",
self.next_call_id.fetch_add(1, Ordering::Relaxed)
)),
tool_name: self.tool.spec().name.clone(),
input: request.input,
session_id: ctx
.session_id
.cloned()
.unwrap_or_else(|| SessionId::new("capability-session")),
turn_id: ctx
.turn_id
.cloned()
.unwrap_or_else(|| TurnId::new("capability-turn")),
metadata: request.metadata,
};
for permission_request in self
.tool
.proposed_requests(&tool_request)
.map_err(|error| CapabilityError::InvalidInput(error.to_string()))?
{
match self.permissions.evaluate(permission_request.as_ref()) {
PermissionDecision::Allow => {}
PermissionDecision::Deny(denial) => {
return Err(CapabilityError::ExecutionFailed(format!(
"tool permission denied: {denial:?}"
)));
}
PermissionDecision::RequireApproval(req) => {
return Err(CapabilityError::Unavailable(format!(
"tool invocation requires approval: {}",
req.summary
)));
}
}
}
let mut tool_ctx = ToolContext {
capability: CapabilityContext {
session_id: ctx.session_id,
turn_id: ctx.turn_id,
metadata: ctx.metadata,
},
permissions: self.permissions.as_ref(),
resources: self.resources.as_ref(),
cancellation: None,
};
let result = self
.tool
.invoke(tool_request, &mut tool_ctx)
.await
.map_err(|error| CapabilityError::ExecutionFailed(error.to_string()))?;
Ok(InvocableResult {
output: match result.result.output {
ToolOutput::Text(text) => InvocableOutput::Text(text),
ToolOutput::Structured(value) => InvocableOutput::Structured(value),
ToolOutput::Parts(parts) => InvocableOutput::Items(vec![Item {
id: None,
kind: ItemKind::Tool,
parts,
metadata: MetadataMap::new(),
usage: None,
finish_reason: None,
created_at: None,
}]),
ToolOutput::Files(files) => {
let parts = files.into_iter().map(Part::File).collect();
InvocableOutput::Items(vec![Item {
id: None,
kind: ItemKind::Tool,
parts,
metadata: MetadataMap::new(),
usage: None,
finish_reason: None,
created_at: None,
}])
}
},
metadata: result.metadata,
})
}
}
/// Capability provider exposing a fixed list of tool-backed invocables.
pub struct ToolCapabilityProvider {
    invocables: Vec<Arc<dyn Invocable>>,
}

impl ToolCapabilityProvider {
    /// Wraps every tool in `registry` in a [`ToolInvocableAdapter`] sharing
    /// `permissions` and `resources`.
    ///
    /// Tools for which no adapter can be built (no current spec) are left out.
    pub fn from_registry(
        registry: &ToolRegistry,
        permissions: Arc<dyn PermissionChecker>,
        resources: Arc<dyn ToolResources>,
    ) -> Self {
        let mut invocables: Vec<Arc<dyn Invocable>> = Vec::new();
        for tool in registry.tools() {
            let adapter =
                ToolInvocableAdapter::new(tool, permissions.clone(), resources.clone());
            if let Some(adapter) = adapter {
                invocables.push(Arc::new(adapter));
            }
        }
        Self { invocables }
    }
}
/// Exposes the tool-backed invocables; this provider contributes no resources
/// and no prompts.
impl CapabilityProvider for ToolCapabilityProvider {
fn invocables(&self) -> Vec<Arc<dyn Invocable>> {
// Cloning the vector clones the `Arc` handles, not the adapters.
self.invocables.clone()
}
fn resources(&self) -> Vec<Arc<dyn ResourceProvider>> {
Vec::new()
}
fn prompts(&self) -> Vec<Arc<dyn PromptProvider>> {
Vec::new()
}
}
/// Result of asking a [`ToolExecutor`] to run a tool.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub enum ToolExecutionOutcome {
/// The tool ran and produced a result.
Completed(ToolResult),
/// Execution stopped before completion and carries the interruption
/// (for example, an approval that is required first).
Interrupted(ToolInterruption),
/// Execution failed with the given error.
Failed(ToolError),
}
/// Executes tool requests and reports the outcome.
#[async_trait]
pub trait ToolExecutor: Send + Sync {
/// Returns the specs of the tools this executor can run.
fn specs(&self) -> Vec<ToolSpec>;
/// Returns pending catalog events; the default has none.
fn drain_catalog_events(&self) -> Vec<ToolCatalogEvent> {
Vec::new()
}
/// Executes `request` using the borrowed context `ctx`.
async fn execute(
&self,
request: ToolRequest,
ctx: &mut ToolContext<'_>,
) -> ToolExecutionOutcome;
/// Executes `request` with an owned context.
///
/// The default borrows from `ctx` and delegates to [`ToolExecutor::execute`].
async fn execute_owned(
&self,
request: ToolRequest,
ctx: OwnedToolContext,
) -> ToolExecutionOutcome {
let mut borrowed = ctx.borrowed();
self.execute(request, &mut borrowed).await
}
/// Executes `request` after `approved_request` has been approved.
///
/// The default ignores `approved_request` and delegates to
/// [`ToolExecutor::execute`].
async fn execute_approved(
&self,
request: ToolRequest,
approved_request: &ApprovalRequest,
ctx: &mut ToolContext<'_>,
) -> ToolExecutionOutcome {
let _ = approved_request;
self.execute(request, ctx).await
}
/// Owned-context variant of [`ToolExecutor::execute_approved`].
///
/// The default borrows from `ctx` and delegates to
/// [`ToolExecutor::execute_approved`].
async fn execute_approved_owned(
&self,
request: ToolRequest,
approved_request: &ApprovalRequest,
ctx: OwnedToolContext,
) -> ToolExecutionOutcome {
let mut borrowed = ctx.borrowed();
self.execute_approved(request, approved_request, &mut borrowed)
.await
}
}
#[async_trait]
impl<T> ToolExecutor for Arc<T>
where
T: ToolExecutor + ?Sized,
{
fn specs(&self) -> Vec<ToolSpec> {
(**self).specs()
}
fn drain_catalog_events(&self) -> Vec<ToolCatalogEvent> {
(**self).drain_catalog_events()
}
async fn execute(
&self,
request: ToolRequest,
ctx: &mut ToolContext<'_>,
) -> ToolExecutionOutcome {
(**self).execute(request, ctx).await
}
async fn execute_approved(
&self,
request: ToolRequest,
approved_request: &ApprovalRequest,
ctx: &mut ToolContext<'_>,
) -> ToolExecutionOutcome {
(**self)
.execute_approved(request, approved_request, ctx)
.await
}
}
/// Decides which source wins when several sources offer the same tool name.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub enum CollisionPolicy {
/// Sources are consulted in the order given; the first match wins (default).
#[default]
FirstWins,
/// Sources are consulted in reverse order; the last one given wins.
LastWins,
}
/// Executor that resolves tools from a list of sources, enforces the
/// permissions each tool proposes, and optionally truncates tool output.
pub struct BasicToolExecutor {
    sources: Vec<Arc<dyn ToolSource>>,
    collision: CollisionPolicy,
    output_truncation: Option<Arc<dyn ToolOutputTruncationStrategy>>,
}

impl BasicToolExecutor {
    /// Creates an executor over `sources` with the default collision policy
    /// and no output truncation.
    pub fn new(sources: impl IntoIterator<Item = Arc<dyn ToolSource>>) -> Self {
        let sources = sources.into_iter().collect();
        Self {
            sources,
            collision: CollisionPolicy::default(),
            output_truncation: None,
        }
    }

    /// Creates an executor whose only source is `registry`.
    pub fn from_registry(registry: ToolRegistry) -> Self {
        let source: Arc<dyn ToolSource> = Arc::new(registry);
        Self::new([source])
    }

    /// Sets how name collisions between sources are resolved.
    pub fn with_collision_policy(self, policy: CollisionPolicy) -> Self {
        Self {
            collision: policy,
            ..self
        }
    }

    /// Installs an output truncation strategy.
    pub fn with_output_truncation_strategy(
        self,
        strategy: impl ToolOutputTruncationStrategy + 'static,
    ) -> Self {
        self.with_output_truncation_strategy_arc(Arc::new(strategy))
    }

    /// Installs an already shared output truncation strategy.
    pub fn with_output_truncation_strategy_arc(
        self,
        strategy: Arc<dyn ToolOutputTruncationStrategy>,
    ) -> Self {
        Self {
            output_truncation: Some(strategy),
            ..self
        }
    }

    /// Lists the specs of all sources, keeping only the first spec seen for
    /// each name in the order dictated by the collision policy.
    pub fn specs(&self) -> Vec<ToolSpec> {
        let ordered: Vec<&Arc<dyn ToolSource>> = match self.collision {
            CollisionPolicy::FirstWins => self.sources.iter().collect(),
            CollisionPolicy::LastWins => self.sources.iter().rev().collect(),
        };
        let mut seen = BTreeSet::new();
        let mut unique = Vec::new();
        for source in ordered {
            for spec in source.specs() {
                let first_time = seen.insert(spec.name.clone());
                if first_time {
                    unique.push(spec);
                }
            }
        }
        unique
    }

    /// Resolves `name` from the sources in the order dictated by the
    /// collision policy.
    fn lookup(&self, name: &ToolName) -> Option<Arc<dyn Tool>> {
        let resolve = |source: &Arc<dyn ToolSource>| source.get(name);
        match self.collision {
            CollisionPolicy::FirstWins => self.sources.iter().find_map(resolve),
            CollisionPolicy::LastWins => self.sources.iter().rev().find_map(resolve),
        }
    }

    /// Shared implementation behind `execute` and `execute_approved`.
    ///
    /// Looks the tool up, evaluates each permission request it proposes,
    /// invokes it, and applies the truncation strategy if one is installed.
    /// A permission that requires approval interrupts execution unless its id
    /// equals `approved_request_id`.
    async fn execute_inner(
        &self,
        request: ToolRequest,
        approved_request_id: Option<&ApprovalId>,
        ctx: &mut ToolContext<'_>,
    ) -> ToolExecutionOutcome {
        let tool = match self.lookup(&request.tool_name) {
            Some(tool) => tool,
            None => return ToolExecutionOutcome::Failed(ToolError::NotFound(request.tool_name)),
        };

        match tool.proposed_requests(&request) {
            Err(error) => return ToolExecutionOutcome::Failed(error),
            Ok(requests) => {
                for permission_request in requests {
                    match ctx.permissions.evaluate(permission_request.as_ref()) {
                        PermissionDecision::Allow => {}
                        PermissionDecision::RequireApproval(mut req) => {
                            // Tie the approval request to this tool call.
                            req.call_id = Some(request.call_id.clone());
                            let already_approved = approved_request_id == Some(&req.id);
                            if !already_approved {
                                return ToolExecutionOutcome::Interrupted(
                                    ToolInterruption::ApprovalRequired(req),
                                );
                            }
                        }
                        PermissionDecision::Deny(denial) => {
                            return ToolExecutionOutcome::Failed(ToolError::PermissionDenied(
                                denial,
                            ));
                        }
                    }
                }
            }
        }

        // Built before `request` is moved into `invoke`.
        let truncation_ctx = ToolOutputTruncationContext::from((&request, tool.spec().clone()));
        let mut result = match tool.invoke(request, ctx).await {
            Ok(result) => result,
            Err(error) => return ToolExecutionOutcome::Failed(error),
        };
        if let Some(strategy) = &self.output_truncation {
            match strategy.apply(truncation_ctx, result.result.output).await {
                Ok(output) => result.result.output = output,
                Err(error) => return ToolExecutionOutcome::Failed(error),
            }
        }
        ToolExecutionOutcome::Completed(result)
    }
}
/// [`ToolExecutor`] implementation backed by `execute_inner`.
#[async_trait]
impl ToolExecutor for BasicToolExecutor {
fn specs(&self) -> Vec<ToolSpec> {
// Explicit path: names the inherent `specs` rather than this trait method.
BasicToolExecutor::specs(self)
}
// Concatenates the pending events of every source, in source order.
fn drain_catalog_events(&self) -> Vec<ToolCatalogEvent> {
self.sources
.iter()
.flat_map(|s| s.drain_catalog_events())
.collect()
}
// Runs the request with no approval on record.
async fn execute(
&self,
request: ToolRequest,
ctx: &mut ToolContext<'_>,
) -> ToolExecutionOutcome {
self.execute_inner(request, None, ctx).await
}
// Runs the request, treating the approval with `approved_request`'s id as granted.
async fn execute_approved(
&self,
request: ToolRequest,
approved_request: &ApprovalRequest,
ctx: &mut ToolContext<'_>,
) -> ToolExecutionOutcome {
self.execute_inner(request, Some(&approved_request.id), ctx)
.await
}
}
/// Errors produced while resolving or running a tool.
#[derive(Debug, Error, Clone, PartialEq, Serialize, Deserialize)]
pub enum ToolError {
/// No tool is available under the requested name.
#[error("tool not found: {0}")]
NotFound(ToolName),
/// The tool input was rejected; the payload is the message.
#[error("invalid tool input: {0}")]
InvalidInput(String),
/// A permission check denied the call; carries the denial.
#[error("tool permission denied: {0:?}")]
PermissionDenied(PermissionDenial),
/// The tool failed while executing; the payload is the message.
#[error("tool execution failed: {0}")]
ExecutionFailed(String),
/// The tool cannot be used; the payload is the message.
#[error("tool unavailable: {0}")]
Unavailable(String),
/// The tool execution was cancelled.
#[error("tool execution cancelled")]
Cancelled,
/// An internal error; the payload is the message.
#[error("internal tool error: {0}")]
Internal(String),
}
impl ToolError {
pub fn permission_denied(denial: PermissionDenial) -> Self {
Self::PermissionDenied(denial)
}
}
impl From<PermissionDenial> for ToolError {
fn from(value: PermissionDenial) -> Self {
Self::permission_denied(value)
}
}
#[cfg(test)]
mod tests {
use super::*;
use async_trait::async_trait;
use serde_json::json;
// With an allow list configured and approval for unknown executables turned
// off, an executable that is not on the list must be denied outright.
#[test]
fn command_policy_can_deny_unknown_executables_without_approval() {
let policy = CommandPolicy::new()
.allow_executable("pwd")
.require_approval_for_unknown(false);
let request = ShellPermissionRequest {
executable: "rm".into(),
argv: vec!["-rf".into(), "/tmp/demo".into()],
cwd: None,
env_keys: Vec::new(),
metadata: MetadataMap::new(),
};
match policy.evaluate(&request) {
PolicyMatch::Deny(denial) => {
assert_eq!(denial.code, PermissionCode::CommandNotAllowed);
}
other => panic!("unexpected policy match: {other:?}"),
}
}
// A read under a read-only root must not be blocked: either the policy has no
// opinion or it allows the request.
#[test]
fn path_policy_allows_reads_under_read_only_roots() {
let policy = PathPolicy::new().read_only_root("/workspace/vendor");
let request = FileSystemPermissionRequest::Read {
path: PathBuf::from("/workspace/vendor/lib.rs"),
metadata: MetadataMap::new(),
};
match policy.evaluate(&request) {
PolicyMatch::NoOpinion | PolicyMatch::Allow => {}
other => panic!("unexpected policy match: {other:?}"),
}
}
// An edit under a read-only root must be denied with a "read-only" message.
#[test]
fn path_policy_denies_mutations_under_read_only_roots() {
let policy = PathPolicy::new().read_only_root("/workspace/vendor");
let request = FileSystemPermissionRequest::Edit {
path: PathBuf::from("/workspace/vendor/lib.rs"),
metadata: MetadataMap::new(),
};
match policy.evaluate(&request) {
PolicyMatch::Deny(denial) => {
assert_eq!(denial.code, PermissionCode::PathNotAllowed);
assert!(denial.message.contains("read-only"));
}
other => panic!("unexpected policy match: {other:?}"),
}
}
// A move whose destination lies under a read-only root must be denied, even
// though its source lies outside that root.
#[test]
fn path_policy_denies_moves_into_read_only_roots() {
let policy = PathPolicy::new().read_only_root("/workspace/vendor");
let request = FileSystemPermissionRequest::Move {
from: PathBuf::from("/workspace/src/lib.rs"),
to: PathBuf::from("/workspace/vendor/lib.rs"),
metadata: MetadataMap::new(),
};
match policy.evaluate(&request) {
PolicyMatch::Deny(denial) => {
assert_eq!(denial.code, PermissionCode::PathNotAllowed);
assert!(denial.message.contains("read-only"));
}
other => panic!("unexpected policy match: {other:?}"),
}
}
// Temporary directory fixture for the symlink tests; the directory is removed
// again when the value is dropped.
#[cfg(unix)]
struct SymlinkTmpDir(PathBuf);
#[cfg(unix)]
impl SymlinkTmpDir {
// Creates a fresh directory under the system temp dir. The name combines the
// label, the process id, and the current time in nanoseconds; the stored path
// is the canonicalized one.
fn new(label: &str) -> Self {
use std::time::{SystemTime, UNIX_EPOCH};
let nanos = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_nanos();
let dir = std::env::temp_dir().join(format!(
"agentkit-pathpolicy-{}-{}-{}",
label,
std::process::id(),
nanos
));
std::fs::create_dir_all(&dir).unwrap();
Self(std::fs::canonicalize(&dir).unwrap())
}
// Returns the canonical path of the directory.
fn path(&self) -> &Path {
&self.0
}
}
#[cfg(unix)]
impl Drop for SymlinkTmpDir {
fn drop(&mut self) {
// Best-effort cleanup; a failure to remove the directory is ignored.
let _ = std::fs::remove_dir_all(&self.0);
}
}
// Evaluates `request` against `policy` and returns the denial, panicking on
// any outcome other than `Deny`.
#[cfg(unix)]
fn assert_path_denied(
policy: &PathPolicy,
request: FileSystemPermissionRequest,
) -> PermissionDenial {
match policy.evaluate(&request) {
PolicyMatch::Deny(denial) => denial,
other => panic!("expected deny, got: {other:?}"),
}
}
// A symlink inside the allowed root that points at a file outside of it must
// not grant read access to that file.
#[cfg(unix)]
#[test]
fn path_policy_blocks_symlink_escape_from_allowed_root() {
let tmp = SymlinkTmpDir::new("allow-escape");
let allowed = tmp.path().join("workspace");
let outside = tmp.path().join("outside");
std::fs::create_dir_all(&allowed).unwrap();
std::fs::create_dir_all(&outside).unwrap();
let secret = outside.join("secret.txt");
std::fs::write(&secret, b"top-secret").unwrap();
// `workspace/leak` -> `outside/secret.txt`
let escape = allowed.join("leak");
std::os::unix::fs::symlink(&secret, &escape).unwrap();
let policy = PathPolicy::new()
.allow_root(&allowed)
.require_approval_outside_allowed(false);
let denial = assert_path_denied(
&policy,
FileSystemPermissionRequest::Read {
path: escape,
metadata: MetadataMap::new(),
},
);
assert_eq!(denial.code, PermissionCode::PathNotAllowed);
}
// Reading a protected file through a symlink alias must still be denied.
#[cfg(unix)]
#[test]
fn path_policy_blocks_symlink_into_protected_root() {
let tmp = SymlinkTmpDir::new("protect-bypass");
let workspace = tmp.path().join("workspace");
std::fs::create_dir_all(&workspace).unwrap();
let secret = workspace.join(".env");
std::fs::write(&secret, b"API_KEY=xxx").unwrap();
// `workspace/config` -> `workspace/.env`
let alias = workspace.join("config");
std::os::unix::fs::symlink(&secret, &alias).unwrap();
let policy = PathPolicy::new()
.allow_root(&workspace)
.protect_root(&secret);
let denial = assert_path_denied(
&policy,
FileSystemPermissionRequest::Read {
path: alias,
metadata: MetadataMap::new(),
},
);
assert_eq!(denial.code, PermissionCode::PathNotAllowed);
assert!(denial.message.contains("denied"));
}
// Editing a file under a read-only root through a symlink that lives outside
// that root must still be denied as read-only.
#[cfg(unix)]
#[test]
fn path_policy_blocks_symlink_write_into_read_only_root() {
let tmp = SymlinkTmpDir::new("readonly-bypass");
let workspace = tmp.path().join("workspace");
let vendor = workspace.join("vendor");
std::fs::create_dir_all(&vendor).unwrap();
let target = vendor.join("lib.rs");
std::fs::write(&target, b"// vendored").unwrap();
// `workspace/writable` -> `workspace/vendor/lib.rs`
let writable_alias = workspace.join("writable");
std::os::unix::fs::symlink(&target, &writable_alias).unwrap();
let policy = PathPolicy::new()
.allow_root(&workspace)
.read_only_root(&vendor);
let denial = assert_path_denied(
&policy,
FileSystemPermissionRequest::Edit {
path: writable_alias,
metadata: MetadataMap::new(),
},
);
assert_eq!(denial.code, PermissionCode::PathNotAllowed);
assert!(denial.message.contains("read-only"));
}
// Writing a file that does not exist yet must be denied when its parent
// directory is a symlink leading outside the allowed root.
#[cfg(unix)]
#[test]
fn path_policy_resolves_symlink_parent_for_nonexistent_leaf() {
let tmp = SymlinkTmpDir::new("create-escape");
let allowed = tmp.path().join("workspace");
let outside = tmp.path().join("outside");
std::fs::create_dir_all(&allowed).unwrap();
std::fs::create_dir_all(&outside).unwrap();
// `workspace/escape` -> `outside`
let escape_dir = allowed.join("escape");
std::os::unix::fs::symlink(&outside, &escape_dir).unwrap();
// The leaf itself is never created.
let new_file = escape_dir.join("new.txt");
let policy = PathPolicy::new()
.allow_root(&allowed)
.require_approval_outside_allowed(false);
let denial = assert_path_denied(
&policy,
FileSystemPermissionRequest::Write {
path: new_file,
metadata: MetadataMap::new(),
},
);
assert_eq!(denial.code, PermissionCode::PathNotAllowed);
}
// Test tool that holds a spec but reports no current spec.
#[derive(Clone)]
struct HiddenTool {
spec: ToolSpec,
}
impl HiddenTool {
// Builds the tool with the name "hidden" and an object input schema.
fn new() -> Self {
Self {
spec: ToolSpec {
name: ToolName::new("hidden"),
description: "hidden".into(),
input_schema: json!({"type": "object"}),
annotations: ToolAnnotations::default(),
metadata: MetadataMap::new(),
},
}
}
}
#[async_trait]
impl Tool for HiddenTool {
fn spec(&self) -> &ToolSpec {
&self.spec
}
// Always `None`: this is what hides the tool from spec listings.
fn current_spec(&self) -> Option<ToolSpec> {
None
}
// Returns the fixed text output "hidden" for the request's call id.
async fn invoke(
&self,
request: ToolRequest,
_ctx: &mut ToolContext<'_>,
) -> Result<ToolResult, ToolError> {
Ok(ToolResult {
result: ToolResultPart {
call_id: request.call_id,
output: ToolOutput::Text("hidden".into()),
is_error: false,
metadata: MetadataMap::new(),
},
duration: None,
metadata: MetadataMap::new(),
})
}
}
/// A tool whose `current_spec` returns `None` must not appear in the
/// registry's spec listing nor among the capability provider's invocables.
#[test]
fn hidden_tools_are_omitted_from_specs_and_capabilities() {
    let registry = ToolRegistry::new().with(HiddenTool::new());
    assert!(registry.specs().is_empty());

    // The provider is built from a borrow of the registry; the first argument
    // had been corrupted to `®istry` (a decoded `&reg` HTML entity).
    let provider = ToolCapabilityProvider::from_registry(
        &registry,
        Arc::new(AllowAllPermissionChecker),
        Arc::new(()),
    );
    assert!(provider.invocables().is_empty());
}
/// Permission checker fixture that allows every request it is asked about.
struct AllowAllPermissionChecker;

impl PermissionChecker for AllowAllPermissionChecker {
    /// Ignores the request entirely and always answers `Allow`.
    fn evaluate(&self, _req: &dyn PermissionRequest) -> PermissionDecision {
        PermissionDecision::Allow
    }
}
/// Test fixture: a tool whose `Tool::current_spec` implementation panics.
#[derive(Clone)]
struct PanickingSpecTool {
    spec: ToolSpec,
}
impl PanickingSpecTool {
    /// Builds the fixture under the given tool name with default annotations
    /// and empty metadata.
    fn new(name: &str) -> Self {
        let spec = ToolSpec {
            name: ToolName::from(name),
            description: String::from("panics on current_spec"),
            input_schema: json!({"type": "object"}),
            annotations: ToolAnnotations::new(),
            metadata: MetadataMap::new(),
        };
        Self { spec }
    }
}
// `Tool` impl for the fixture: `spec` works normally, `current_spec` always
// panics, and `invoke` would return the text "never".
#[async_trait]
impl Tool for PanickingSpecTool {
    fn spec(&self) -> &ToolSpec {
        &self.spec
    }

    fn current_spec(&self) -> Option<ToolSpec> {
        panic!("PanickingSpecTool::current_spec");
    }

    async fn invoke(
        &self,
        request: ToolRequest,
        _ctx: &mut ToolContext<'_>,
    ) -> Result<ToolResult, ToolError> {
        let part = ToolResultPart {
            call_id: request.call_id,
            output: ToolOutput::Text("never".into()),
            is_error: false,
            metadata: MetadataMap::new(),
        };
        let outcome = ToolResult {
            result: part,
            duration: None,
            metadata: MetadataMap::new(),
        };
        Ok(outcome)
    }
}
/// After a `replace_all` call panics, the catalog must remain readable and
/// must keep accepting `remove` / `upsert` calls.
#[test]
fn catalog_recovers_from_panicked_writer() {
    let (writer, reader) = dynamic_catalog("test");
    writer.upsert(Arc::new(PanickingSpecTool::new("boom")));
    // Discard whatever events the initial upsert produced.
    let _ = reader.drain_catalog_events();

    // The panic is caught here so the rest of the test can inspect the catalog.
    let unwound = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
        let boom: Arc<dyn Tool> = Arc::new(PanickingSpecTool::new("boom"));
        writer.replace_all(vec![boom]);
    }));
    assert!(
        unwound.is_err(),
        "PanickingSpecTool::current_spec must propagate"
    );

    let boom_name = ToolName::new("boom");
    assert!(
        reader.get(&boom_name).is_some(),
        "catalog still readable after poisoning panic"
    );

    assert!(writer.remove(&boom_name));
    writer.upsert(Arc::new(HiddenTool::new()));
    let hidden_name = ToolName::new("hidden");
    assert!(
        reader.get(&hidden_name).is_some(),
        "catalog usable for further writes + reads"
    );
}
/// Test fixture: a tool whose `invoke` echoes back the requested tool name.
#[derive(Clone)]
struct EchoTool {
    spec: ToolSpec,
}
impl EchoTool {
    /// Builds the fixture under `name`, described as `"echo <name>"`.
    fn new(name: &str) -> Self {
        let spec = ToolSpec {
            name: ToolName::from(name),
            description: format!("echo {name}"),
            input_schema: json!({"type": "object"}),
            annotations: ToolAnnotations::new(),
            metadata: MetadataMap::new(),
        };
        Self { spec }
    }
}
// `Tool` impl for the fixture: succeeds with the `tool_name` carried on the
// incoming request as its text output.
#[async_trait]
impl Tool for EchoTool {
    fn spec(&self) -> &ToolSpec {
        &self.spec
    }

    async fn invoke(
        &self,
        request: ToolRequest,
        _ctx: &mut ToolContext<'_>,
    ) -> Result<ToolResult, ToolError> {
        let echoed = ToolOutput::text(request.tool_name.0.clone());
        let part = ToolResultPart::success(request.call_id, echoed);
        Ok(ToolResult::new(part))
    }
}
/// Builds a registry containing one `EchoTool` per entry of `names`, added in
/// slice order.
fn registry_with(names: &[&str]) -> ToolRegistry {
    let mut registry = ToolRegistry::new();
    for name in names {
        registry = registry.with(EchoTool::new(name));
    }
    registry
}
/// `prefixed("weather")` must expose every tool as `weather_<name>` in both
/// the spec listing and lookups, and stop resolving the unprefixed names.
#[test]
fn prefixed_rewrites_specs_and_resolves_lookups() {
    let source = registry_with(&["get_temp", "get_humidity"]).prefixed("weather");

    let mut names = Vec::new();
    for spec in source.specs() {
        names.push(spec.name.0);
    }
    assert_eq!(names, vec!["weather_get_humidity", "weather_get_temp"]);

    let prefixed = ToolName::new("weather_get_temp");
    assert!(source.get(&prefixed).is_some());

    let original = ToolName::new("get_temp");
    assert!(
        source.get(&original).is_none(),
        "original name must not resolve when prefixed"
    );

    let unknown = ToolName::new("unknown");
    assert!(source.get(&unknown).is_none());
}
#[tokio::test]
async fn prefixed_invoke_sees_inner_name_on_request() {
let source = registry_with(&["get_temp"]).prefixed("weather");
let tool = source.get(&ToolName::new("weather_get_temp")).unwrap();
assert_eq!(tool.spec().name.0, "weather_get_temp");
let owned = OwnedToolContext {
session_id: SessionId::new("s"),
turn_id: TurnId::new("t"),
metadata: MetadataMap::new(),
permissions: Arc::new(AllowAllPermissions),
resources: Arc::new(()),
cancellation: None,
};
let mut ctx = owned.borrowed();
let request = ToolRequest {
call_id: ToolCallId::new("c"),
tool_name: ToolName::new("weather_get_temp"),
input: json!({}),
session_id: SessionId::new("s"),
turn_id: TurnId::new("t"),
metadata: MetadataMap::new(),
};
let result = tool.invoke(request, &mut ctx).await.unwrap();
match result.result.output {
ToolOutput::Text(text) => assert_eq!(text, "get_temp"),
other => panic!("unexpected output: {other:?}"),
}
}
/// Test fixture: a tool that always returns a clone of a preconfigured output.
#[derive(Clone)]
struct StaticOutputTool {
    spec: ToolSpec,
    output: ToolOutput,
}
impl StaticOutputTool {
    /// Builds the fixture under `name`, described as `"static <name>"`, that
    /// will answer every invocation with `output`.
    fn new(name: &str, output: ToolOutput) -> Self {
        let description = format!("static {name}");
        let spec = ToolSpec::new(name, description, json!({"type": "object"}));
        Self { spec, output }
    }

    /// Returns the fixture with `limit` applied to its spec via
    /// `ToolSpec::with_output_limit`.
    fn with_output_limit(self, limit: ToolOutputLimit) -> Self {
        Self {
            spec: self.spec.with_output_limit(limit),
            output: self.output,
        }
    }
}
// `Tool` impl for the fixture: every invocation succeeds with a clone of the
// stored output, tagged with the caller's call id.
#[async_trait]
impl Tool for StaticOutputTool {
    fn spec(&self) -> &ToolSpec {
        &self.spec
    }

    async fn invoke(
        &self,
        request: ToolRequest,
        _ctx: &mut ToolContext<'_>,
    ) -> Result<ToolResult, ToolError> {
        let part = ToolResultPart::success(request.call_id, self.output.clone());
        Ok(ToolResult::new(part))
    }
}
/// Builds an owned tool context for session "s" / turn "t" with empty
/// metadata, `AllowAllPermissions`, unit resources and no cancellation handle.
fn test_context() -> OwnedToolContext {
    let session_id = SessionId::new("s");
    let turn_id = TurnId::new("t");
    OwnedToolContext {
        session_id,
        turn_id,
        metadata: MetadataMap::new(),
        permissions: Arc::new(AllowAllPermissions),
        resources: Arc::new(()),
        cancellation: None,
    }
}
/// A 500-byte text output from a tool whose spec carries a
/// `store_for_readback(300)` limit must come back as a truncation envelope
/// whose artifact can be read from the store.
#[tokio::test]
async fn executor_stores_oversized_output_using_tool_metadata_limit() {
    let store = Arc::new(InMemoryToolOutputArtifactStore::new());
    let strategy = ConfigurableToolOutputTruncationStrategy::new(store.clone());

    let tool = StaticOutputTool::new("big", ToolOutput::text("x".repeat(500)))
        .with_output_limit(ToolOutputLimit::store_for_readback(300));
    let registry = ToolRegistry::new().with(tool);
    let executor =
        BasicToolExecutor::from_registry(registry).with_output_truncation_strategy(strategy);

    let request = ToolRequest::new(
        "call",
        "big",
        json!({}),
        SessionId::new("s"),
        TurnId::new("t"),
    );
    let outcome = executor.execute_owned(request, test_context()).await;

    let result = match outcome {
        ToolExecutionOutcome::Completed(result) => result,
        other => panic!("expected completed outcome, got {other:?}"),
    };
    let envelope = match result.result.output {
        ToolOutput::Structured(envelope) => envelope,
        _ => panic!("expected truncation envelope"),
    };
    assert_eq!(envelope["truncated"], true);
    assert_eq!(envelope["read_tool"], TOOL_RESULT_READ_TOOL_NAME);

    // Read the first 50 bytes of the stored artifact back through the store.
    let id = envelope["tool_result_id"].as_str().expect("tool_result_id");
    let artifact_id = ToolOutputArtifactId(id.to_owned());
    let slice = store
        .read(&artifact_id, 0, 50)
        .await
        .expect("read artifact");
    assert_eq!(slice.content, "x".repeat(50));
    assert_eq!(slice.next_offset, 50);
    assert!(!slice.eof);
}
/// A read tool constructed with a maximum of 4 must reject a request whose
/// `limit` is 5 with an `InvalidInput` error mentioning "exceeds maximum".
#[tokio::test]
async fn tool_result_read_enforces_explicit_max_read_size() {
    let store = Arc::new(InMemoryToolOutputArtifactStore::new());

    // Seed the store with a 6-byte artifact attributed to a "big" tool call.
    let producer_request = ToolRequest::new(
        "call",
        "big",
        json!({}),
        SessionId::new("s"),
        TurnId::new("t"),
    );
    let producer_spec = ToolSpec::new("big", "big output", json!({"type": "object"}));
    let truncation_ctx = ToolOutputTruncationContext::from((&producer_request, producer_spec));
    let artifact = store
        .put(&truncation_ctx, String::from("abcdef"), 6)
        .await
        .expect("store artifact");

    let tool = ToolResultReadTool::new(store, 4);
    let owned_ctx = test_context();
    let mut tool_ctx = owned_ctx.borrowed();

    let read_request = ToolRequest::new(
        "read-call",
        TOOL_RESULT_READ_TOOL_NAME,
        json!({"id": artifact.id.0, "offset": 0, "limit": 5}),
        SessionId::new("s"),
        TurnId::new("t"),
    );
    let err = tool
        .invoke(read_request, &mut tool_ctx)
        .await
        .expect_err("read past max must fail");

    let ToolError::InvalidInput(message) = err else {
        panic!("expected InvalidInput, got {err:?}")
    };
    assert!(message.contains("exceeds maximum"));
}
/// A read request with `limit: 0` must be rejected with an `InvalidInput`
/// error mentioning "greater than 0".
#[tokio::test]
async fn tool_result_read_rejects_zero_limit() {
    let store = Arc::new(InMemoryToolOutputArtifactStore::new());

    // Seed the store with a 6-byte artifact attributed to a "big" tool call.
    let producer_request = ToolRequest::new(
        "call",
        "big",
        json!({}),
        SessionId::new("s"),
        TurnId::new("t"),
    );
    let producer_spec = ToolSpec::new("big", "big output", json!({"type": "object"}));
    let truncation_ctx = ToolOutputTruncationContext::from((&producer_request, producer_spec));
    let artifact = store
        .put(&truncation_ctx, String::from("abcdef"), 6)
        .await
        .expect("store artifact");

    let tool = ToolResultReadTool::new(store, 4);
    let owned_ctx = test_context();
    let mut tool_ctx = owned_ctx.borrowed();

    let read_request = ToolRequest::new(
        "read-call",
        TOOL_RESULT_READ_TOOL_NAME,
        json!({"id": artifact.id.0, "offset": 0, "limit": 0}),
        SessionId::new("s"),
        TurnId::new("t"),
    );
    let err = tool
        .invoke(read_request, &mut tool_ctx)
        .await
        .expect_err("zero limit must fail");

    let ToolError::InvalidInput(message) = err else {
        panic!("expected InvalidInput, got {err:?}")
    };
    assert!(message.contains("greater than 0"));
}
/// Reading a 4-byte artifact with `limit: 4` through an executor that has a
/// truncation strategy installed must return the full content with `eof` set.
#[tokio::test]
async fn tool_result_read_executor_allows_full_content_limit_with_envelope() {
    let store = Arc::new(InMemoryToolOutputArtifactStore::new());

    // Seed the store with a 4-byte artifact attributed to a "big" tool call.
    let producer_request = ToolRequest::new(
        "call",
        "big",
        json!({}),
        SessionId::new("s"),
        TurnId::new("t"),
    );
    let producer_spec = ToolSpec::new("big", "big output", json!({"type": "object"}));
    let truncation_ctx = ToolOutputTruncationContext::from((&producer_request, producer_spec));
    let artifact = store
        .put(&truncation_ctx, String::from("abcd"), 4)
        .await
        .expect("store artifact");

    // The read tool and the truncation strategy share the same store.
    let read_tool = ToolResultReadTool::new(store.clone(), 4);
    let registry = ToolRegistry::new().with(read_tool);
    let strategy = ConfigurableToolOutputTruncationStrategy::new(store);
    let executor =
        BasicToolExecutor::from_registry(registry).with_output_truncation_strategy(strategy);

    let read_request = ToolRequest::new(
        "read-call",
        TOOL_RESULT_READ_TOOL_NAME,
        json!({"id": artifact.id.0, "offset": 0, "limit": 4}),
        SessionId::new("s"),
        TurnId::new("t"),
    );
    let outcome = executor.execute_owned(read_request, test_context()).await;

    let result = match outcome {
        ToolExecutionOutcome::Completed(result) => result,
        other => panic!("expected completed outcome, got {other:?}"),
    };
    let output = match result.result.output {
        ToolOutput::Structured(output) => output,
        _ => panic!("expected structured readback output"),
    };
    assert_eq!(output["content"], "abcd");
    assert_eq!(output["eof"], true);
}
/// Same as the plain full-content readback test, but the artifact is four NUL
/// characters: the readback must still return all of them with `eof` set.
#[tokio::test]
async fn tool_result_read_executor_allows_json_escaped_full_content_limit() {
    let store = Arc::new(InMemoryToolOutputArtifactStore::new());

    let producer_request = ToolRequest::new(
        "call",
        "big",
        json!({}),
        SessionId::new("s"),
        TurnId::new("t"),
    );
    let producer_spec = ToolSpec::new("big", "big output", json!({"type": "object"}));
    let truncation_ctx = ToolOutputTruncationContext::from((&producer_request, producer_spec));

    // Four NUL characters: 4 bytes of content.
    let content = "\0".repeat(4);
    let artifact = store
        .put(&truncation_ctx, content.clone(), content.len())
        .await
        .expect("store artifact");

    // The read tool and the truncation strategy share the same store.
    let read_tool = ToolResultReadTool::new(store.clone(), 4);
    let registry = ToolRegistry::new().with(read_tool);
    let strategy = ConfigurableToolOutputTruncationStrategy::new(store);
    let executor =
        BasicToolExecutor::from_registry(registry).with_output_truncation_strategy(strategy);

    let read_request = ToolRequest::new(
        "read-call",
        TOOL_RESULT_READ_TOOL_NAME,
        json!({"id": artifact.id.0, "offset": 0, "limit": 4}),
        SessionId::new("s"),
        TurnId::new("t"),
    );
    let outcome = executor.execute_owned(read_request, test_context()).await;

    let result = match outcome {
        ToolExecutionOutcome::Completed(result) => result,
        other => panic!("expected completed outcome, got {other:?}"),
    };
    let output = match result.result.output {
        ToolOutput::Structured(output) => output,
        _ => panic!("expected structured readback output"),
    };
    assert_eq!(output["content"], content);
    assert_eq!(output["eof"], true);
}
/// The result of `clip_string_with_marker("abcdef", 8, 1000)` must be at most
/// 8 bytes long.
#[test]
fn inline_clip_respects_limit_when_marker_exceeds_budget() {
    let clipped = clip_string_with_marker("abcdef", 8, 1000);
    let end = clipped.len();
    assert!(end <= 8);
    // NOTE(review): if `clipped` is a `String`/`&str`, its own `len()` is always
    // a char boundary, so this assertion cannot fail — confirm the intent.
    assert!(clipped.is_char_boundary(end));
}
/// `filtered` must drop tools the predicate rejects from both the spec listing
/// and lookups, while keeping the accepted ones.
#[test]
fn filtered_hides_tools_rejected_by_predicate() {
    let source = registry_with(&["safe", "danger_drop", "danger_delete"])
        .filtered(|name| !name.0.starts_with("danger_"));

    let mut names = Vec::new();
    for spec in source.specs() {
        names.push(spec.name.0);
    }
    assert_eq!(names, vec!["safe"]);

    let kept = ToolName::new("safe");
    let dropped = ToolName::new("danger_drop");
    assert!(source.get(&kept).is_some());
    assert!(source.get(&dropped).is_none());
}
/// `renamed` must expose a mapped tool only under its new name and leave
/// unmapped tools reachable under their original names.
#[test]
fn renamed_remaps_specs_and_lookups() {
    let mapping = [(ToolName::new("legacy_name"), ToolName::new("modern_name"))];
    let source = registry_with(&["legacy_name", "passthrough"]).renamed(mapping);

    let mut names = Vec::new();
    for spec in source.specs() {
        names.push(spec.name.0);
    }
    names.sort();
    assert_eq!(names, vec!["modern_name", "passthrough"]);

    let modern = ToolName::new("modern_name");
    assert!(source.get(&modern).is_some());

    let legacy = ToolName::new("legacy_name");
    assert!(
        source.get(&legacy).is_none(),
        "original name is hidden after renaming"
    );

    let untouched = ToolName::new("passthrough");
    assert!(source.get(&untouched).is_some());
}
// Tests for the schema helpers, compiled only when the `schemars` feature is
// enabled.
#[cfg(feature = "schemars")]
mod schemars_helpers {
    use super::*;
    use schemars::JsonSchema;
    use serde::Deserialize;

    // Sample input type used to derive a schema. Plain `//` comments are used
    // on purpose: schemars copies `///` doc comments into the generated schema.
    #[derive(JsonSchema, Deserialize)]
    #[allow(dead_code)]
    struct WeatherInput {
        location: String,
        #[serde(default)]
        celsius: bool,
    }

    /// `schema_for` must yield a JSON object schema of type "object" that
    /// lists both struct fields under `properties`.
    #[test]
    fn schema_for_emits_object_schema_with_typed_fields() {
        let schema = schema_for::<WeatherInput>();
        let root = schema.as_object().expect("schema is a JSON object");

        let root_type = root.get("type").and_then(|value| value.as_str());
        assert_eq!(root_type, Some("object"), "root type should be object");

        let fields = root
            .get("properties")
            .and_then(|value| value.as_object())
            .expect("properties block");
        assert!(fields.contains_key("location"));
        assert!(fields.contains_key("celsius"));
    }

    /// `tool_spec_for` must carry the given name and description and attach
    /// an object-valued input schema.
    #[test]
    fn tool_spec_for_carries_schema_name_and_description() {
        let spec = tool_spec_for::<WeatherInput>("get_weather", "Fetch current weather");
        let ToolSpec {
            name,
            description,
            input_schema,
            ..
        } = spec;
        assert_eq!(name.0, "get_weather");
        assert_eq!(description, "Fetch current weather");
        assert!(input_schema.is_object());
    }
}
/// Chaining `filtered` then `prefixed` must apply both: the rejected tool is
/// gone and the survivors carry the `fs_` prefix.
#[test]
fn transforms_compose_via_chained_methods() {
    let source = registry_with(&["read_file", "write_file", "delete_file"])
        .filtered(|name| name.0 != "delete_file")
        .prefixed("fs");

    let mut names = Vec::new();
    for spec in source.specs() {
        names.push(spec.name.0);
    }
    names.sort();
    assert_eq!(names, vec!["fs_read_file", "fs_write_file"]);
}
}