use std::path::PathBuf;
use std::sync::Arc;
use std::time::Instant;
use atd_protocol::{Request, Response};
use crate::audit::AuditSink;
use crate::capability::CapabilitySet;
use crate::context::CallContext;
use crate::error::ToolCallError;
use crate::middleware::Middleware;
use crate::registry::Registry;
use crate::secrets::TokenBroker;
use crate::tier::{TierPolicy, ToolTier};
use crate::tracker::ReadTracker;
/// Server-wide settings, stored in [`ServerState::config`] and read by the
/// request handlers in this file.
pub struct SharedServerConfig {
/// Working directory handed to the `CallContext` of every tool call.
pub cwd: PathBuf,
/// Output-size cap in bytes. Not read in this part of the file (per-call
/// limits here come from the tier policy) — presumably a global ceiling; confirm.
pub max_output_bytes: usize,
/// Default call timeout in milliseconds. Not read in this part of the file
/// (deadlines here come from the tier policy) — TODO confirm where it applies.
pub default_call_timeout_ms: u64,
/// Allow-list of capability strings; a `Hello` request is granted the
/// intersection of this list with the capabilities it asks for.
pub granted_capabilities: Vec<String>,
/// Optional sink that receives a `CallEvent` for audited tool calls.
/// When `None`, no audit events are emitted.
pub audit_sink: Option<Arc<dyn AuditSink>>,
/// Version string returned to the client in `HelloAck`.
pub server_version: String,
/// Optional broker asked (with the caller id) to resolve secrets before a
/// tool runs. When `None`, tools are called without secrets.
pub token_broker: Option<Arc<dyn TokenBroker>>,
/// Copied into the UCAN verifier's `max_chain_depth` during `Hello`.
pub max_ucan_chain_depth: u8,
/// Optional revocation store handed to the UCAN verifier during `Hello`.
pub ucan_revocation_store: Option<Arc<dyn crate::ucan::UcanRevocationStore>>,
/// Not read in this part of the file — presumably the per-frame read
/// deadline (ms) for an established connection; confirm against the transport.
pub frame_deadline_active_ms: u64,
/// Not read in this part of the file — presumably the per-frame read
/// deadline (ms) while waiting for the handshake; confirm against the transport.
pub frame_deadline_handshake_ms: u64,
/// 32-byte key; not read in this part of the file — presumably the key
/// behind `ServerState::cursor_issuer`; confirm at the construction site.
pub cursor_signing_key: [u8; 32],
/// Time-to-live, in seconds, passed to the cursor issuer when a
/// continuation cursor is verified.
pub cursor_ttl_seconds: u64,
}
impl SharedServerConfig {
pub fn for_test() -> Self {
Self {
cwd: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
max_output_bytes: 1_048_576,
default_call_timeout_ms: 60_000,
granted_capabilities: vec![],
audit_sink: None,
server_version: "atd-runtime-test 0.0.0".into(),
token_broker: None,
max_ucan_chain_depth: 5,
ucan_revocation_store: None,
frame_deadline_active_ms: 30_000,
frame_deadline_handshake_ms: 5_000,
cursor_signing_key: [0u8; 32],
cursor_ttl_seconds: 300,
}
}
}
/// State shared (behind an `Arc`) by the request handlers in this file.
pub struct ServerState {
/// Tool registry used for listing tools and looking them up by id.
pub registry: Registry,
/// Server-wide configuration.
pub config: SharedServerConfig,
/// Supplies the timeout and maximum output size applied to a call,
/// keyed by the tool's tier.
pub tier_policy: TierPolicy,
/// Middleware hooks, invoked in vector order on tool results and on
/// error messages before they are returned.
pub middleware: Vec<Arc<dyn Middleware>>,
/// Counters updated by the handlers (dispatched requests, error codes,
/// audit events).
pub metrics: Arc<crate::metrics::MetricsCounters>,
/// Verifies continuation cursors and is attached to the `CallContext`
/// of paginated tool calls.
pub cursor_issuer: Arc<crate::cursor::CursorIssuer>,
}
pub async fn dispatch_request(
state: &Arc<ServerState>,
tracker: &Arc<ReadTracker>,
caps: &mut Arc<CapabilitySet>,
caller_id: &mut Option<String>,
req: Request,
) -> Response {
state
.metrics
.dispatched_requests
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
let resp = dispatch_request_inner(state, tracker, caps, caller_id, req).await;
if let Response::Error { code: Some(c), .. } = &resp {
state.metrics.record_error(*c);
}
resp
}
/// Routes one request to its handler and returns the response.
///
/// * `Ping` answers `Pong`.
/// * `Hello` negotiates the session's capabilities (see below).
/// * `ToolList` returns the registry summaries, omitting tools whose
///   visibility is `Hidden`.
/// * `ToolSchema` returns the definition of a single tool, or an error when
///   the id is unknown.
/// * `RunTool` / `RunToolContinue` delegate to [`run_tool`] and
///   [`run_tool_continue`] with the session's current capabilities and caller id.
///
/// # Hello
///
/// Granted capabilities are the union of:
/// * the requested capability strings that appear in the server's allow-list, and
/// * capabilities proven by the supplied UCAN tokens, verified with the
///   client id as the expected audience.
///
/// UCAN tokens without a non-empty `client_id`, or tokens that fail
/// verification, produce an error response. The session state (`caller_id`
/// and `caps`) is written only after every check has passed, so a rejected
/// `Hello` leaves the previous identity and capabilities untouched.
async fn dispatch_request_inner(
    state: &Arc<ServerState>,
    tracker: &Arc<ReadTracker>,
    caps: &mut Arc<CapabilitySet>,
    caller_id: &mut Option<String>,
    req: Request,
) -> Response {
    match req {
        Request::Ping => Response::Pong,
        Request::Hello {
            client_id,
            requested_capabilities,
            ucan_tokens,
        } => {
            // Capabilities granted from the static allow-list.
            let allow = CapabilitySet::from_iter(state.config.granted_capabilities.iter().cloned());
            let (granted_strings_vec, _denied) = allow.intersect(&requested_capabilities);
            let string_provenance = granted_strings_vec
                .iter()
                .map(|c| crate::audit::CapProvenance {
                    cap: c.clone(),
                    source: crate::audit::ProvSource::StringAllowList,
                })
                .collect();
            let granted_strings =
                CapabilitySet::with_provenance(granted_strings_vec, string_provenance);

            // Capabilities proven by UCAN tokens, pinned to the client id.
            let granted_ucan = if ucan_tokens.is_empty() {
                CapabilitySet::empty()
            } else {
                let expected_aud = match client_id.as_deref() {
                    Some(id) if !id.is_empty() => id.to_owned(),
                    _ => {
                        return Response::Error {
                            message: "UCAN tokens require Hello.client_id (audience pin)"
                                .to_string(),
                            code: Some(atd_protocol::ERR_AUDIENCE_MISMATCH),
                            retryable: Some(false),
                            details: None,
                        };
                    }
                };
                let mut cfg = crate::ucan::VerifyConfig::new(expected_aud);
                cfg.max_chain_depth = state.config.max_ucan_chain_depth;
                cfg.revocation_store = state.config.ucan_revocation_store.clone();
                match crate::ucan::verify_tokens(&ucan_tokens, &cfg, std::time::SystemTime::now()) {
                    Ok(c) => c,
                    Err(e) => {
                        let code = crate::ucan::wire_code(&e);
                        return Response::Error {
                            message: e.to_string(),
                            code: Some(code),
                            retryable: Some(false),
                            details: None,
                        };
                    }
                }
            };

            let granted_caps = granted_strings.union(&granted_ucan);
            let granted_vec = granted_caps.granted();

            // Commit identity and capabilities together, and only on success.
            // Writing `caller_id` before verification would let a failed
            // Hello relabel the session while the capabilities (and their
            // provenance) granted to the previous identity stay in force.
            *caller_id = client_id;
            *caps = Arc::new(granted_caps);

            Response::HelloAck {
                granted_capabilities: granted_vec,
                server_version: state.config.server_version.clone(),
                supported_tiers: vec!["hot".into(), "warm".into(), "cold".into()],
            }
        }
        Request::ToolList => {
            let summaries: Vec<_> = state
                .registry
                .summaries()
                .into_iter()
                .filter(|s| !matches!(s.visibility, atd_protocol::ToolVisibility::Hidden))
                .collect();
            Response::ToolListResponse {
                // Best effort: a serialization failure yields an empty list.
                tools: serde_json::to_value(&summaries).unwrap_or_else(|_| serde_json::json!([])),
            }
        }
        Request::ToolSchema { tool_id } => match state.registry.get(&tool_id) {
            Some(entry) => Response::ToolSchemaResponse {
                // Best effort: a serialization failure yields an empty object.
                schema: serde_json::to_value(entry.definition())
                    .unwrap_or_else(|_| serde_json::json!({})),
            },
            None => Response::Error {
                message: format!("tool not found: {tool_id}"),
                code: None,
                retryable: Some(false),
                details: None,
            },
        },
        Request::RunTool {
            tool_id,
            args,
            dry_run,
        } => {
            run_tool(
                state,
                tracker,
                caps,
                caller_id.as_deref(),
                tool_id,
                args,
                dry_run,
            )
            .await
        }
        Request::RunToolContinue { tool_id, cursor } => {
            run_tool_continue(state, tracker, caps, caller_id.as_deref(), tool_id, cursor).await
        }
    }
}
#[allow(clippy::too_many_arguments)]
pub async fn run_tool_continue(
state: &Arc<ServerState>,
tracker: &Arc<ReadTracker>,
caps: &Arc<CapabilitySet>,
caller_id: Option<&str>,
tool_id: String,
cursor: String,
) -> Response {
use std::sync::atomic::Ordering;
let start = Instant::now();
let audit_call_id = ulid::Ulid::new();
let payload = match state
.cursor_issuer
.verify(&cursor, state.config.cursor_ttl_seconds)
{
Ok(p) => p,
Err(crate::cursor::CursorError::Expired) => {
return Response::Error {
message: "cursor expired; re-issue the original RunTool".into(),
code: Some(atd_protocol::ERR_CURSOR_EXPIRED),
retryable: Some(false),
details: None,
};
}
Err(_) => {
return Response::Error {
message: "cursor invalid".into(),
code: Some(atd_protocol::ERR_CURSOR_INVALID),
retryable: Some(false),
details: None,
};
}
};
if payload.tool_id != tool_id {
return Response::Error {
message: format!(
"cursor tool_id mismatch: cursor={} request={tool_id}",
payload.tool_id
),
code: Some(atd_protocol::ERR_CURSOR_INVALID),
retryable: Some(false),
details: None,
};
}
let entry = match state.registry.get(&tool_id) {
Some(e) => e.clone(),
None => {
return Response::Error {
message: format!("tool not found: {tool_id}"),
code: None,
retryable: Some(false),
details: None,
};
}
};
if !entry.tool.supports_pagination() {
return Response::Error {
message: format!("tool {tool_id} does not support pagination but received a cursor"),
code: Some(atd_protocol::ERR_CURSOR_INVALID),
retryable: Some(false),
details: None,
};
}
let tier = entry.definition().tier.unwrap_or(ToolTier::Warm);
let required = entry.definition().required_capabilities.clone();
let missing: Vec<String> = required
.iter()
.filter(|c| !caps.contains(c))
.cloned()
.collect();
if !missing.is_empty() {
return Response::Error {
message: format!("capability denied for {tool_id}: missing {missing:?}"),
code: Some(atd_protocol::ERR_CAPABILITY_DENIED),
retryable: Some(false),
details: None,
};
}
let _permit = match entry.semaphore.clone().try_acquire_owned() {
Ok(p) => p,
Err(_) => {
return Response::Error {
message: format!("rate limited for {tool_id} (continuation)"),
code: Some(atd_protocol::ERR_RATE_LIMITED),
retryable: Some(true),
details: None,
};
}
};
let tier_timeout = state.tier_policy.timeout(tier);
let tier_max_output = state.tier_policy.max_output(tier);
let ctx = CallContext::new(
state.config.cwd.clone(),
tier_max_output,
audit_call_id,
Some(Instant::now() + tier_timeout),
Some(tracker.clone()),
caps.clone(),
tier,
caller_id.map(|s| s.to_string()),
None, )
.with_cursor_issuer(state.cursor_issuer.clone());
let page_index = payload.page_index;
let result = entry
.tool
.call_paginated(serde_json::Value::Null, &ctx, Some(&cursor))
.await;
let response = match result {
Ok(crate::registry::PaginatedResult {
mut value,
next_cursor,
}) => {
run_result_middleware(state, &tool_id, entry.definition(), &mut value);
Response::ToolResultResponse {
tool_id: tool_id.clone(),
result: value,
success: true,
dry_run: false,
next_cursor,
}
}
Err(crate::error::ToolCallError::ExecutionFailed {
code,
message,
retryable,
}) => {
let mut value = serde_json::json!({
"code": code,
"message": message,
"retryable": retryable,
});
run_result_middleware(state, &tool_id, entry.definition(), &mut value);
Response::ToolResultResponse {
tool_id: tool_id.clone(),
result: value,
success: false,
dry_run: false,
next_cursor: None,
}
}
Err(e) => {
let mut message = format!("tool {tool_id} continuation failed: {e:?}");
let mut details = None;
run_error_middleware(
state,
&tool_id,
entry.definition(),
&mut message,
&mut details,
);
Response::Error {
message,
code: None,
retryable: Some(false),
details,
}
}
};
if let Some(sink) = state.config.audit_sink.as_ref() {
state
.metrics
.audit_events_total
.fetch_add(1, Ordering::Relaxed);
let outcome = match &response {
Response::ToolResultResponse { success: true, .. } => crate::audit::Outcome::Success,
_ => crate::audit::Outcome::ExecutionFailed {
code: "continuation_failed".into(),
retryable: false,
},
};
sink.on_call(&crate::audit::CallEvent {
ts: crate::audit::now_rfc3339(),
call_id: audit_call_id.to_string(),
tool_id: tool_id.clone(),
caller_id: caller_id.map(|s| s.to_string()),
granted_capabilities: caps.granted(),
duration_ms: start.elapsed().as_millis() as u64,
outcome,
tier: crate::tier::tier_as_str(tier).to_string(),
dry_run: false,
schema_version: crate::audit::SCHEMA_VERSION,
secrets_resolved: false,
cursor_page: Some(page_index),
capability_provenance: prov_for(caps),
});
}
response
}
/// Returns an owned copy of the capability set's provenance entries for an
/// audit event, or `None` when the set carries no provenance.
fn prov_for(caps: &CapabilitySet) -> Option<Vec<crate::audit::CapProvenance>> {
    let entries = caps.provenance();
    (!entries.is_empty()).then(|| entries.to_vec())
}
/// Passes an outgoing error through every registered middleware, in
/// registration order. Each hook may rewrite `message` and `details` in place.
fn run_error_middleware(
    state: &ServerState,
    tool_id: &str,
    def: &atd_protocol::ToolDefinition,
    message: &mut String,
    details: &mut Option<serde_json::Value>,
) {
    state.middleware.iter().for_each(|hook| {
        hook.on_error(tool_id, def, &mut *message, &mut *details);
    });
}
/// Passes a tool result through every registered middleware, in registration
/// order. Each hook may rewrite `value` in place.
fn run_result_middleware(
    state: &ServerState,
    tool_id: &str,
    def: &atd_protocol::ToolDefinition,
    value: &mut serde_json::Value,
) {
    state.middleware.iter().for_each(|hook| {
        hook.on_result(tool_id, def, &mut *value);
    });
}
#[allow(clippy::too_many_arguments)]
pub async fn run_tool(
state: &Arc<ServerState>,
tracker: &Arc<ReadTracker>,
caps: &Arc<CapabilitySet>,
caller_id: Option<&str>,
tool_id: String,
args: serde_json::Value,
dry_run: bool,
) -> Response {
let start = Instant::now();
let audit_call_id = ulid::Ulid::new();
let cursor_page: Option<u32> = None;
let emit = |outcome: crate::audit::Outcome, tier: ToolTier, secrets_resolved: bool| {
if let Some(sink) = state.config.audit_sink.as_ref() {
state
.metrics
.audit_events_total
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
sink.on_call(&crate::audit::CallEvent {
ts: crate::audit::now_rfc3339(),
call_id: audit_call_id.to_string(),
tool_id: tool_id.clone(),
caller_id: caller_id.map(|s| s.to_string()),
granted_capabilities: caps.granted(),
duration_ms: start.elapsed().as_millis() as u64,
outcome,
tier: crate::tier::tier_as_str(tier).to_string(),
dry_run,
schema_version: crate::audit::SCHEMA_VERSION,
secrets_resolved,
cursor_page,
capability_provenance: prov_for(caps),
});
}
};
if dry_run {
emit(crate::audit::Outcome::Success, ToolTier::Warm, false);
return Response::ToolResultResponse {
tool_id: tool_id.clone(),
result: serde_json::json!({
"dry_run": true,
"tool_id": tool_id,
"args_preview": args,
}),
success: true,
dry_run: true,
next_cursor: None,
};
}
let entry = match state.registry.get(&tool_id) {
Some(e) => e.clone(),
None => {
emit(crate::audit::Outcome::ToolNotFound, ToolTier::Warm, false);
return Response::Error {
message: format!("tool not found: {tool_id}"),
code: None,
retryable: Some(false),
details: None,
};
}
};
let tier = entry.definition().tier.unwrap_or(ToolTier::Warm);
let required = entry.definition().required_capabilities.clone();
let missing: Vec<String> = required
.iter()
.filter(|c| !caps.contains(c))
.cloned()
.collect();
if !missing.is_empty() {
let mut required_sorted = required.clone();
required_sorted.sort();
let mut missing_sorted = missing.clone();
missing_sorted.sort();
emit(
crate::audit::Outcome::CapabilityDenied {
missing: missing_sorted.clone(),
},
tier,
false,
);
let mut message = format!("capability denied for {tool_id}: missing {missing_sorted:?}");
let mut details = Some(serde_json::json!({
"required": required_sorted,
"granted": caps.granted(),
"missing": missing_sorted,
}));
run_error_middleware(
state,
&tool_id,
entry.definition(),
&mut message,
&mut details,
);
return Response::Error {
message,
code: Some(atd_protocol::ERR_CAPABILITY_DENIED),
retryable: Some(false),
details,
};
}
let _permit = match entry.semaphore.clone().try_acquire_owned() {
Ok(p) => p,
Err(_) => {
let max_conc = entry.tool.definition().resources.max_concurrent;
emit(
crate::audit::Outcome::RateLimited {
retry_after_ms: None,
},
tier,
false,
);
let mut message =
format!("rate limited for {tool_id}: max_concurrent={max_conc} in-flight");
let mut details = Some(serde_json::json!({
"tool_id": tool_id,
"limit": max_conc,
}));
run_error_middleware(
state,
&tool_id,
entry.definition(),
&mut message,
&mut details,
);
return Response::Error {
message,
code: Some(atd_protocol::ERR_RATE_LIMITED),
retryable: Some(true),
details,
};
}
};
let secrets = match state.config.token_broker.as_ref() {
None => None,
Some(broker) => match broker.resolve(caller_id).await {
Ok(bundle) => bundle,
Err(e) => {
emit(
crate::audit::Outcome::ExecutionFailed {
code: "broker_error".into(),
retryable: true,
},
tier,
false,
);
let mut message = format!("token broker error for {tool_id}: {e}");
let mut details = None;
run_error_middleware(
state,
&tool_id,
entry.definition(),
&mut message,
&mut details,
);
return Response::Error {
message,
code: Some(atd_protocol::ERR_BROKER_FAILED),
retryable: Some(true),
details,
};
}
},
};
let secrets_resolved = secrets.is_some();
let tier_timeout = state.tier_policy.timeout(tier);
let tier_max_output = state.tier_policy.max_output(tier);
let ctx = if entry.tool.supports_pagination() {
CallContext::new(
state.config.cwd.clone(),
tier_max_output,
audit_call_id,
Some(Instant::now() + tier_timeout),
Some(tracker.clone()),
caps.clone(),
tier,
caller_id.map(|s| s.to_string()),
secrets,
)
.with_cursor_issuer(state.cursor_issuer.clone())
} else {
CallContext::new(
state.config.cwd.clone(),
tier_max_output,
audit_call_id,
Some(Instant::now() + tier_timeout),
Some(tracker.clone()),
caps.clone(),
tier,
caller_id.map(|s| s.to_string()),
secrets,
)
};
let call_result: Result<(serde_json::Value, Option<String>), ToolCallError> =
if entry.tool.supports_pagination() {
entry
.tool
.call_paginated(args, &ctx, None)
.await
.map(|p| (p.value, p.next_cursor))
} else {
entry
.binding
.call(entry.definition(), args, &ctx)
.await
.map(|v| (v, None))
};
match call_result {
Ok((mut data, next_cursor)) => {
run_result_middleware(state, &tool_id, entry.definition(), &mut data);
emit(crate::audit::Outcome::Success, tier, secrets_resolved);
Response::ToolResultResponse {
tool_id,
result: data,
success: true,
dry_run: false,
next_cursor,
}
}
Err(ToolCallError::InvalidArgs(msg)) => {
emit(
crate::audit::Outcome::InvalidArgs {
message: msg.clone(),
},
tier,
secrets_resolved,
);
let mut message = format!("invalid args for {tool_id}: {msg}");
let mut details = None;
run_error_middleware(
state,
&tool_id,
entry.definition(),
&mut message,
&mut details,
);
Response::Error {
message,
code: None,
retryable: Some(false),
details,
}
}
Err(ToolCallError::ExecutionFailed {
code,
message,
retryable,
}) => {
emit(
crate::audit::Outcome::ExecutionFailed {
code: code.clone(),
retryable,
},
tier,
secrets_resolved,
);
let mut result = serde_json::json!({
"code": code,
"message": message,
"retryable": retryable,
});
run_result_middleware(state, &tool_id, entry.definition(), &mut result);
Response::ToolResultResponse {
tool_id,
result,
success: false,
dry_run: false,
next_cursor: None,
}
}
Err(ToolCallError::InternalError(msg)) => {
emit(
crate::audit::Outcome::ExecutionFailed {
code: "INTERNAL".into(),
retryable: false,
},
tier,
secrets_resolved,
);
let mut message = format!("internal error in {tool_id}: {msg}");
let mut details = None;
run_error_middleware(
state,
&tool_id,
entry.definition(),
&mut message,
&mut details,
);
Response::Error {
message,
code: None,
retryable: Some(false),
details,
}
}
Err(other) => {
emit(
crate::audit::Outcome::ExecutionFailed {
code: "UNHANDLED".into(),
retryable: false,
},
tier,
secrets_resolved,
);
let mut message = format!("unhandled tool error in {tool_id}: {other}");
let mut details = None;
run_error_middleware(
state,
&tool_id,
entry.definition(),
&mut message,
&mut details,
);
Response::Error {
message,
code: Some(1999),
retryable: Some(false),
details,
}
}
}
}