use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use async_trait::async_trait;
use wasmtime::Store;
use wasmtime::component::Linker;
use wasmtime_wasi::{ResourceTable, WasiCtx, WasiCtxBuilder, WasiView};
use crate::context::JobContext;
use crate::llm::recording::{HttpExchangeRequest, HttpExchangeResponse, HttpInterceptor};
use crate::safety::LeakDetector;
use crate::secrets::{DecryptedSecret, SecretsStore};
use crate::tools::tool::{Tool, ToolError, ToolOutput};
use crate::tools::wasm::capabilities::Capabilities;
use crate::tools::wasm::credential_injector::{
InjectedCredentials, host_matches_pattern, inject_credential,
};
use crate::tools::wasm::error::WasmError;
use crate::tools::wasm::host::{HostState, LogLevel};
use crate::tools::wasm::limits::{ResourceLimits, WasmResourceLimiter};
use crate::tools::wasm::runtime::{EPOCH_TICK_INTERVAL, PreparedModule, WasmToolRuntime};
wasmtime::component::bindgen!({
path: "wit/tool.wit",
world: "sandboxed-tool",
async: false,
with: {},
});
use crate::cli::oauth_defaults;
use exports::near::agent::tool as wit_tool;
#[derive(Debug, Clone)]
pub struct OAuthRefreshConfig {
pub token_url: String,
pub client_id: String,
pub client_secret: Option<String>,
pub exchange_proxy_url: Option<String>,
pub gateway_token: Option<String>,
pub secret_name: String,
pub provider: Option<String>,
}
impl OAuthRefreshConfig {
fn oauth_proxy_auth_token(&self) -> Option<&str> {
self.gateway_token.as_deref()
}
}
struct ResolvedHostCredential {
host_patterns: Vec<String>,
headers: HashMap<String, String>,
query_params: HashMap<String, String>,
secret_value: String,
}
struct StoreData {
limiter: WasmResourceLimiter,
host_state: HostState,
wasi: WasiCtx,
table: ResourceTable,
credentials: HashMap<String, String>,
host_credentials: Vec<ResolvedHostCredential>,
http_runtime: Option<tokio::runtime::Runtime>,
http_interceptor: Option<Arc<dyn HttpInterceptor>>,
}
impl StoreData {
fn new(
memory_limit: u64,
capabilities: Capabilities,
credentials: HashMap<String, String>,
host_credentials: Vec<ResolvedHostCredential>,
) -> Self {
let wasi = WasiCtxBuilder::new().build();
Self {
limiter: WasmResourceLimiter::new(memory_limit),
host_state: HostState::new(capabilities),
wasi,
table: ResourceTable::new(),
credentials,
host_credentials,
http_runtime: None,
http_interceptor: None,
}
}
fn inject_credentials(&self, input: &str, context: &str) -> String {
let mut result = input.to_string();
for (name, value) in &self.credentials {
let placeholder = format!("{{{}}}", name);
if result.contains(&placeholder) {
tracing::debug!(
placeholder = %placeholder,
context = %context,
"Replacing credential placeholder in tool request"
);
result = result.replace(&placeholder, value);
}
}
result
}
fn redact_credentials(&self, text: &str) -> String {
let mut result = text.to_string();
for (name, value) in &self.credentials {
if !value.is_empty() {
result = result.replace(value, &format!("[REDACTED:{}]", name));
}
}
for cred in &self.host_credentials {
if !cred.secret_value.is_empty() {
result = result.replace(&cred.secret_value, "[REDACTED:host_credential]");
}
}
result
}
fn inject_host_credentials(
&self,
url_host: &str,
headers: &mut HashMap<String, String>,
url: &mut String,
) {
for cred in &self.host_credentials {
let matches = cred
.host_patterns
.iter()
.any(|pattern| host_matches_pattern(url_host, pattern));
if !matches {
continue;
}
for (key, value) in &cred.headers {
headers.insert(key.clone(), value.clone());
}
if !cred.query_params.is_empty() {
let (base, fragment) = match url.find('#') {
Some(i) => (url[..i].to_string(), Some(url[i..].to_string())),
None => (url.clone(), None),
};
*url = base;
let separator = if url.contains('?') { '&' } else { '?' };
for (i, (name, value)) in cred.query_params.iter().enumerate() {
if i == 0 {
url.push(separator);
} else {
url.push('&');
}
url.push_str(&urlencoding::encode(name));
url.push('=');
url.push_str(&urlencoding::encode(value));
}
if let Some(frag) = fragment {
url.push_str(&frag);
}
}
}
}
}
impl WasiView for StoreData {
fn ctx(&mut self) -> &mut WasiCtx {
&mut self.wasi
}
fn table(&mut self) -> &mut ResourceTable {
&mut self.table
}
}
impl near::agent::host::Host for StoreData {
fn log(&mut self, level: near::agent::host::LogLevel, message: String) {
let log_level = match level {
near::agent::host::LogLevel::Trace => LogLevel::Trace,
near::agent::host::LogLevel::Debug => LogLevel::Debug,
near::agent::host::LogLevel::Info => LogLevel::Info,
near::agent::host::LogLevel::Warn => LogLevel::Warn,
near::agent::host::LogLevel::Error => LogLevel::Error,
};
let _ = self.host_state.log(log_level, message);
}
fn now_millis(&mut self) -> u64 {
self.host_state.now_millis()
}
fn workspace_read(&mut self, path: String) -> Option<String> {
self.host_state.workspace_read(&path).ok().flatten()
}
fn http_request(
&mut self,
method: String,
url: String,
headers_json: String,
body: Option<Vec<u8>>,
timeout_ms: Option<u32>,
) -> Result<near::agent::host::HttpResponse, String> {
let injected_url = self.inject_credentials(&url, "url");
self.host_state
.check_http_allowed(&injected_url, &method)
.map_err(|e| format!("HTTP not allowed: {}", e))?;
self.host_state
.record_http_request()
.map_err(|e| format!("Rate limit exceeded: {}", e))?;
let raw_headers: HashMap<String, String> =
serde_json::from_str(&headers_json).unwrap_or_default();
let leak_detector = LeakDetector::new();
leak_detector
.scan_and_clean(&injected_url)
.map_err(|e| format!("Potential secret leak in URL blocked: {}", e))?;
for (name, value) in &raw_headers {
leak_detector.scan_and_clean(value).map_err(|e| {
format!("Potential secret leak in header '{}' blocked: {}", name, e)
})?;
}
if let Some(body_bytes) = body.as_deref() {
let body_str = String::from_utf8_lossy(body_bytes);
leak_detector
.scan_and_clean(&body_str)
.map_err(|e| format!("Potential secret leak in body blocked: {}", e))?;
}
let mut headers: HashMap<String, String> = raw_headers
.into_iter()
.map(|(k, v)| {
(
k.clone(),
self.inject_credentials(&v, &format!("header:{}", k)),
)
})
.collect();
let mut url = injected_url;
if let Some(host) = extract_host_from_url(&url) {
self.inject_host_credentials(&host, &mut headers, &mut url);
}
let max_response_bytes = self
.host_state
.capabilities()
.http
.as_ref()
.map(|h| h.max_response_bytes)
.unwrap_or(10 * 1024 * 1024);
reject_private_ip(&url)?;
if self.http_runtime.is_none() {
self.http_runtime = Some(
tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.map_err(|e| format!("Failed to create HTTP runtime: {e}"))?,
);
}
let rt = self.http_runtime.as_ref().expect("just initialized");
if let Some(interceptor) = &self.http_interceptor {
let interceptor = Arc::clone(interceptor);
let intercept_url = url.clone();
let intercept_method = method.clone();
let mut intercept_headers: Vec<(String, String)> = headers
.iter()
.map(|(k, v)| (k.clone(), v.clone()))
.collect();
intercept_headers.sort_by(|a, b| a.0.cmp(&b.0));
let intercept_body = body
.as_ref()
.map(|b| String::from_utf8_lossy(b).to_string());
let intercepted = rt.block_on(async {
let req = HttpExchangeRequest {
method: intercept_method,
url: intercept_url,
headers: intercept_headers,
body: intercept_body,
};
interceptor.before_request(&req).await
});
if let Some(resp) = intercepted {
let resp_headers: HashMap<String, String> = resp
.headers
.iter()
.map(|(k, v)| (k.clone(), v.clone()))
.collect();
let resp_headers_json =
serde_json::to_string(&resp_headers).unwrap_or_else(|_| "{}".to_string());
return Ok(near::agent::host::HttpResponse {
status: resp.status,
headers_json: resp_headers_json,
body: resp.body.into_bytes(),
});
}
}
let interceptor_req = self.http_interceptor.as_ref().map(|_| HttpExchangeRequest {
method: method.clone(),
url: url.clone(),
headers: headers
.iter()
.map(|(k, v)| (k.clone(), v.clone()))
.collect(),
body: body
.as_ref()
.map(|b| String::from_utf8_lossy(b).to_string()),
});
let result = rt.block_on(async {
let client = reqwest::Client::builder()
.connect_timeout(Duration::from_secs(10))
.redirect(reqwest::redirect::Policy::none())
.build()
.map_err(|e| format!("Failed to build HTTP client: {e}"))?;
let mut request = match method.to_uppercase().as_str() {
"GET" => client.get(&url),
"POST" => client.post(&url),
"PUT" => client.put(&url),
"DELETE" => client.delete(&url),
"PATCH" => client.patch(&url),
"HEAD" => client.head(&url),
_ => return Err(format!("Unsupported HTTP method: {}", method)),
};
for (key, value) in &headers {
request = request.header(key, value);
}
if let Some(body_bytes) = body {
request = request.body(body_bytes);
} else if needs_content_length_zero(&method, &headers) {
request = request.header("content-length", "0");
}
let timeout_ms = timeout_ms.unwrap_or(30_000).min(300_000) as u64;
let timeout = Duration::from_millis(timeout_ms);
let response = request.timeout(timeout).send().await.map_err(|e| {
let mut chain = format!("HTTP request failed: {}", e);
let mut source = std::error::Error::source(&e);
while let Some(cause) = source {
chain.push_str(&format!(" -> {}", cause));
source = cause.source();
}
chain
})?;
let status = response.status().as_u16();
let response_headers: HashMap<String, String> = response
.headers()
.iter()
.filter_map(|(k, v)| {
v.to_str()
.ok()
.map(|v| (k.as_str().to_string(), v.to_string()))
})
.collect();
let headers_json = serde_json::to_string(&response_headers).unwrap_or_default();
let max_response = max_response_bytes;
if let Some(cl) = response.content_length()
&& cl as usize > max_response
{
return Err(format!(
"Response body too large: {} bytes exceeds limit of {} bytes",
cl, max_response
));
}
let body = response
.bytes()
.await
.map_err(|e| format!("Failed to read response body: {}", e))?;
if body.len() > max_response {
return Err(format!(
"Response body too large: {} bytes exceeds limit of {} bytes",
body.len(),
max_response
));
}
let body = body.to_vec();
if let Ok(body_str) = std::str::from_utf8(&body) {
leak_detector
.scan_and_clean(body_str)
.map_err(|e| format!("Potential secret leak in response: {}", e))?;
}
Ok(near::agent::host::HttpResponse {
status,
headers_json,
body,
})
});
if let (Some(interceptor), Some(req), Ok(resp)) =
(&self.http_interceptor, &interceptor_req, &result)
{
let interceptor = Arc::clone(interceptor);
let mut redacted_req = req.clone();
redacted_req.url = self.redact_credentials(&redacted_req.url);
redacted_req.headers = redacted_req
.headers
.into_iter()
.map(|(k, v)| (k, self.redact_credentials(&v)))
.collect();
redacted_req.body = redacted_req.body.map(|b| self.redact_credentials(&b));
let resp_headers: Vec<(String, String)> =
serde_json::from_str::<HashMap<String, String>>(&resp.headers_json)
.unwrap_or_default()
.into_iter()
.collect();
let resp_body = String::from_utf8_lossy(&resp.body).to_string();
let redacted_headers: Vec<(String, String)> = resp_headers
.into_iter()
.map(|(k, v)| (k, self.redact_credentials(&v)))
.collect();
let redacted_body = self.redact_credentials(&resp_body);
let exchange_resp = HttpExchangeResponse {
status: resp.status,
headers: redacted_headers,
body: redacted_body,
};
rt.block_on(async {
interceptor
.after_response(&redacted_req, &exchange_resp)
.await;
});
}
result.map_err(|e| self.redact_credentials(&e))
}
fn tool_invoke(&mut self, alias: String, _params_json: String) -> Result<String, String> {
let _real_name = self.host_state.check_tool_invoke_allowed(&alias)?;
self.host_state.record_tool_invoke()?;
Err("Tool invocation from WASM tools is not yet supported".to_string())
}
fn secret_exists(&mut self, name: String) -> bool {
self.host_state.secret_exists(&name)
}
}
pub struct WasmToolWrapper {
runtime: Arc<WasmToolRuntime>,
prepared: Arc<PreparedModule>,
capabilities: Capabilities,
description: String,
schemas: WasmToolSchemas,
credentials: HashMap<String, String>,
secrets_store: Option<Arc<dyn SecretsStore + Send + Sync>>,
oauth_refresh: Option<OAuthRefreshConfig>,
http_interceptor: Option<Arc<dyn HttpInterceptor>>,
}
#[derive(Debug, Clone)]
struct WasmToolSchemas {
advertised: serde_json::Value,
discovery: serde_json::Value,
}
impl WasmToolSchemas {
fn permissive_schema() -> serde_json::Value {
serde_json::json!({
"type": "object",
"properties": {},
"additionalProperties": true
})
}
fn is_permissive_schema(schema: &serde_json::Value) -> bool {
if schema
.get("properties")
.and_then(|p| p.as_object())
.is_some_and(|p| !p.is_empty())
{
return false;
}
for key in ["oneOf", "anyOf", "allOf"] {
if let Some(variants) = schema.get(key).and_then(|v| v.as_array())
&& variants.iter().any(|v| {
v.get("properties")
.and_then(|p| p.as_object())
.is_some_and(|p| !p.is_empty())
})
{
return false;
}
}
true
}
fn typed_property_count(schema: &serde_json::Value) -> usize {
let mut all_props = serde_json::Map::new();
if let Some(props) = schema.get("properties").and_then(|p| p.as_object()) {
all_props.extend(props.iter().map(|(k, v)| (k.clone(), v.clone())));
}
for key in ["allOf", "oneOf", "anyOf"] {
if let Some(variants) = schema.get(key).and_then(|v| v.as_array()) {
for variant in variants {
if let Some(props) = variant.get("properties").and_then(|p| p.as_object()) {
all_props.extend(props.iter().map(|(k, v)| (k.clone(), v.clone())));
}
}
}
}
all_props
.values()
.filter(|prop| schema_is_typed_property(prop))
.count()
}
fn new(discovery: serde_json::Value) -> Self {
let advertised = Self::compact_schema(&discovery);
Self {
advertised,
discovery,
}
}
fn compact_schema(discovery: &serde_json::Value) -> serde_json::Value {
const MAX_COMPACT_PROPERTIES: usize = 100;
let required: std::collections::HashSet<String> = discovery
.get("required")
.and_then(|r| r.as_array())
.map(|arr| {
arr.iter()
.filter_map(|v| v.as_str().map(String::from))
.collect()
})
.unwrap_or_default();
let mut all_properties = serde_json::Map::new();
let mut const_values: std::collections::HashMap<String, Vec<serde_json::Value>> =
std::collections::HashMap::new();
if let Some(props) = discovery.get("properties").and_then(|p| p.as_object()) {
for (k, v) in props {
if all_properties.len() >= MAX_COMPACT_PROPERTIES {
break;
}
all_properties.insert(k.clone(), v.clone());
}
}
for key in ["oneOf", "anyOf", "allOf"] {
if let Some(variants) = discovery.get(key).and_then(|v| v.as_array()) {
for variant in variants {
if let Some(props) = variant.get("properties").and_then(|p| p.as_object()) {
for (k, v) in props {
if all_properties.len() >= MAX_COMPACT_PROPERTIES
&& !all_properties.contains_key(k)
{
continue;
}
if let Some(c) = v.get("const") {
const_values.entry(k.clone()).or_default().push(c.clone());
}
all_properties.entry(k.clone()).or_insert_with(|| v.clone());
}
}
}
}
}
for (name, values) in &const_values {
if values.len() > 1
&& let Some(prop) = all_properties.get_mut(name)
{
let mut merged = prop.clone();
if let Some(obj) = merged.as_object_mut() {
obj.remove("const");
obj.insert("enum".to_string(), serde_json::Value::Array(values.clone()));
}
*prop = merged;
}
}
if all_properties.is_empty() {
return Self::permissive_schema();
}
let kept: serde_json::Map<String, serde_json::Value> = all_properties
.iter()
.filter(|(name, prop)| {
required.contains(name.as_str())
|| prop.get("enum").is_some()
|| prop.get("const").is_some()
})
.map(|(k, v)| (k.clone(), v.clone()))
.collect();
if kept.is_empty() {
let typed: serde_json::Map<String, serde_json::Value> = all_properties
.into_iter()
.filter(|(_, prop)| schema_is_typed_property(prop))
.collect();
if typed.is_empty() {
return Self::permissive_schema();
}
return serde_json::json!({
"type": "object",
"properties": typed,
"additionalProperties": true,
});
}
let kept_required: Vec<serde_json::Value> = required
.iter()
.filter(|name| kept.contains_key(name.as_str()))
.map(|name| serde_json::Value::String(name.clone()))
.collect();
let mut result = serde_json::json!({
"type": "object",
"properties": kept,
"additionalProperties": true,
});
if !kept_required.is_empty() {
result["required"] = serde_json::Value::Array(kept_required);
}
result
}
fn with_override(&self, schema: serde_json::Value) -> Self {
Self {
advertised: schema.clone(),
discovery: schema,
}
}
fn is_advertised_permissive(&self) -> bool {
Self::is_permissive_schema(&self.advertised)
}
fn advertised(&self) -> serde_json::Value {
self.advertised.clone()
}
fn discovery(&self) -> serde_json::Value {
self.discovery.clone()
}
}
impl WasmToolWrapper {
pub fn new(
runtime: Arc<WasmToolRuntime>,
prepared: Arc<PreparedModule>,
capabilities: Capabilities,
) -> Self {
Self {
description: prepared.description.clone(),
schemas: WasmToolSchemas::new(prepared.schema.clone()),
runtime,
prepared,
capabilities,
credentials: HashMap::new(),
secrets_store: None,
oauth_refresh: None,
http_interceptor: None,
}
}
pub fn with_http_interceptor(mut self, interceptor: Arc<dyn HttpInterceptor>) -> Self {
self.http_interceptor = Some(interceptor);
self
}
pub fn with_description(mut self, description: impl Into<String>) -> Self {
self.description = description.into();
self
}
pub fn with_schema(mut self, schema: serde_json::Value) -> Self {
let override_typed = WasmToolSchemas::typed_property_count(&schema);
let prepared_typed = WasmToolSchemas::typed_property_count(&self.prepared.schema);
if override_typed == 0 && prepared_typed > 0 {
tracing::warn!(
tool = %self.prepared.name,
"Ignoring untyped schema override for discovery/runtime preparation and preserving extracted WASM schema"
);
self.schemas = WasmToolSchemas {
advertised: schema,
discovery: self.prepared.schema.clone(),
};
} else {
self.schemas = self.schemas.with_override(schema);
}
self
}
pub fn with_credentials(mut self, credentials: HashMap<String, String>) -> Self {
self.credentials = credentials;
self
}
pub fn with_secrets_store(mut self, store: Arc<dyn SecretsStore + Send + Sync>) -> Self {
self.secrets_store = Some(store);
self
}
pub fn with_oauth_refresh(mut self, config: OAuthRefreshConfig) -> Self {
self.oauth_refresh = Some(config);
self
}
pub fn limits(&self) -> &ResourceLimits {
&self.prepared.limits
}
fn add_host_functions(linker: &mut Linker<StoreData>) -> Result<(), WasmError> {
wasmtime_wasi::add_to_linker_sync(linker)
.map_err(|e| WasmError::ConfigError(format!("Failed to add WASI functions: {}", e)))?;
near::agent::host::add_to_linker(linker, |state| state)
.map_err(|e| WasmError::ConfigError(format!("Failed to add host functions: {}", e)))?;
Ok(())
}
fn execute_sync(
&self,
params: serde_json::Value,
context_json: Option<String>,
host_credentials: Vec<ResolvedHostCredential>,
) -> Result<(String, Vec<crate::tools::wasm::host::LogEntry>), WasmError> {
let engine = self.runtime.engine();
let limits = &self.prepared.limits;
let mut store_data = StoreData::new(
limits.memory_bytes,
self.capabilities.clone(),
self.credentials.clone(),
host_credentials,
);
store_data.http_interceptor = self.http_interceptor.clone();
let mut store = Store::new(engine, store_data);
if self.runtime.config().fuel_config.enabled {
store
.set_fuel(limits.fuel)
.map_err(|e| WasmError::ConfigError(format!("Failed to set fuel: {}", e)))?;
}
store.epoch_deadline_trap();
let ticks = (limits.timeout.as_millis() / EPOCH_TICK_INTERVAL.as_millis()).max(1) as u64;
store.set_epoch_deadline(ticks);
store.limiter(|data| &mut data.limiter);
let component = self.prepared.component().clone();
let mut linker = Linker::new(engine);
Self::add_host_functions(&mut linker)?;
let instance =
SandboxedTool::instantiate(&mut store, &component, &linker).map_err(|e| {
let msg = e.to_string();
if msg.contains("near:agent") || msg.contains("import") {
WasmError::InstantiationFailed(format!(
"{msg}. This usually means the extension was compiled against \
a different WIT version than the host supports. \
Rebuild the extension against the current WIT (host: {}).",
crate::tools::wasm::WIT_TOOL_VERSION
))
} else {
WasmError::InstantiationFailed(msg)
}
})?;
let tool_iface = instance.near_agent_tool();
let params_json = serde_json::to_string(¶ms)
.map_err(|e| WasmError::InvalidResponseJson(e.to_string()))?;
let request = wit_tool::Request {
params: params_json,
context: context_json,
};
let response = tool_iface.call_execute(&mut store, &request).map_err(|e| {
let error_str = e.to_string();
if error_str.contains("out of fuel") {
WasmError::FuelExhausted { limit: limits.fuel }
} else if error_str.contains("unreachable") {
WasmError::Trapped("unreachable code executed".to_string())
} else {
WasmError::Trapped(error_str)
}
})?;
let logs = store.data_mut().host_state.take_logs();
if let Some(err) = response.error {
let hint = build_tool_usage_hint(&self.prepared.name, &self.schemas.discovery());
return Err(WasmError::ToolReturnedError { message: err, hint });
}
Ok((response.output.unwrap_or_default(), logs))
}
}
pub(super) fn extract_wasm_metadata(
engine: &wasmtime::Engine,
component: &wasmtime::component::Component,
limits: &ResourceLimits,
) -> Result<(String, serde_json::Value), WasmError> {
let store_data = StoreData::new(
limits.memory_bytes,
Capabilities::default(),
HashMap::new(),
vec![],
);
let mut store = Store::new(engine, store_data);
if let Err(e) = store.set_fuel(limits.fuel) {
tracing::debug!("Fuel not enabled for metadata extraction: {e}");
}
store.epoch_deadline_trap();
let ticks = (limits.timeout.as_millis() / EPOCH_TICK_INTERVAL.as_millis()).max(1) as u64;
store.set_epoch_deadline(ticks);
store.limiter(|data| &mut data.limiter);
let mut linker = Linker::new(engine);
WasmToolWrapper::add_host_functions(&mut linker)?;
let instance = SandboxedTool::instantiate(&mut store, component, &linker)
.map_err(|e| WasmError::InstantiationFailed(e.to_string()))?;
let tool_iface = instance.near_agent_tool();
let description = tool_iface
.call_description(&mut store)
.unwrap_or_else(|_| "WASM sandboxed tool".to_string());
let schema = tool_iface
.call_schema(&mut store)
.ok()
.and_then(|s| serde_json::from_str::<serde_json::Value>(&s).ok())
.unwrap_or_else(|| {
serde_json::json!({"type": "object", "properties": {}, "additionalProperties": true})
});
Ok((description, schema))
}
#[async_trait]
impl Tool for WasmToolWrapper {
fn name(&self) -> &str {
&self.prepared.name
}
fn description(&self) -> &str {
&self.description
}
fn parameters_schema(&self) -> serde_json::Value {
self.schemas.advertised()
}
fn discovery_schema(&self) -> serde_json::Value {
self.schemas.discovery()
}
fn schema(&self) -> crate::tools::tool::ToolSchema {
let description = if self.schemas.is_advertised_permissive() {
format!(
"{} (call tool_info(name: \"{}\", include_schema: true) for parameter schema)",
self.description, self.prepared.name
)
} else {
self.description.clone()
};
crate::tools::tool::ToolSchema {
name: self.prepared.name.clone(),
description,
parameters: self.schemas.advertised(),
}
}
async fn execute(
&self,
params: serde_json::Value,
ctx: &JobContext,
) -> Result<ToolOutput, ToolError> {
let start = Instant::now();
let timeout = self.prepared.limits.timeout;
let credential_user_id = &ctx.user_id;
let host_credentials = resolve_host_credentials(
&self.capabilities,
self.secrets_store.as_deref(),
credential_user_id,
self.oauth_refresh.as_ref(),
)
.await;
let context_json = serde_json::to_string(ctx).ok();
let runtime = Arc::clone(&self.runtime);
let prepared = Arc::clone(&self.prepared);
let capabilities = self.capabilities.clone();
let description = self.description.clone();
let schemas = self.schemas.clone();
let credentials = self.credentials.clone();
let result = tokio::time::timeout(timeout, async move {
let wrapper = WasmToolWrapper {
runtime,
prepared,
capabilities,
description,
schemas,
credentials,
secrets_store: None, oauth_refresh: None, http_interceptor: self.http_interceptor.clone(),
};
tokio::task::spawn_blocking(move || {
wrapper.execute_sync(params, context_json, host_credentials)
})
.await
.map_err(|e| WasmError::ExecutionPanicked(e.to_string()))?
})
.await;
let duration = start.elapsed();
match result {
Ok(Ok((result_json, logs))) => {
for log in logs {
match log.level {
LogLevel::Trace => tracing::trace!(target: "wasm_tool", "{}", log.message),
LogLevel::Debug => tracing::debug!(target: "wasm_tool", "{}", log.message),
LogLevel::Info => tracing::info!(target: "wasm_tool", "{}", log.message),
LogLevel::Warn => tracing::warn!(target: "wasm_tool", "{}", log.message),
LogLevel::Error => tracing::error!(target: "wasm_tool", "{}", log.message),
}
}
let result: serde_json::Value = serde_json::from_str(&result_json)
.unwrap_or(serde_json::Value::String(result_json));
Ok(ToolOutput::success(result, duration))
}
Ok(Err(wasm_err)) => Err(wasm_err.into()),
Err(_) => Err(WasmError::Timeout(timeout).into()),
}
}
fn requires_sanitization(&self) -> bool {
true
}
fn estimated_duration(&self, _params: &serde_json::Value) -> Option<Duration> {
Some(self.prepared.limits.timeout)
}
fn webhook_capability(&self) -> Option<crate::tools::wasm::WebhookCapability> {
self.capabilities.webhook.clone()
}
}
impl std::fmt::Debug for WasmToolWrapper {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("WasmToolWrapper")
.field("name", &self.prepared.name)
.field("description", &self.description)
.field("limits", &self.prepared.limits)
.finish()
}
}
async fn refresh_oauth_token(
store: &(dyn SecretsStore + Send + Sync),
user_id: &str,
config: &OAuthRefreshConfig,
) -> bool {
let refresh_name = format!("{}_refresh_token", config.secret_name);
if let Some(proxy_url) = config.exchange_proxy_url.as_deref() {
let Some(oauth_proxy_auth_token) = config.oauth_proxy_auth_token() else {
tracing::warn!(
"OAuth refresh proxy is configured, but no OAuth proxy auth token is available"
);
return false;
};
let refresh_secret = match load_oauth_refresh_secret(store, user_id, &refresh_name).await {
Some(secret) => secret,
None => return false,
};
let token_response = match oauth_defaults::refresh_token_via_proxy(
oauth_defaults::ProxyRefreshTokenRequest {
proxy_url,
gateway_token: oauth_proxy_auth_token,
token_url: &config.token_url,
client_id: &config.client_id,
client_secret: config.client_secret.as_deref(),
refresh_token: refresh_secret.expose(),
resource: None,
provider: config.provider.as_deref(),
},
)
.await
{
Ok(response) => response,
Err(error) => {
tracing::warn!(error = %error, "OAuth token refresh via proxy failed");
return false;
}
};
return persist_refreshed_oauth_tokens(
store,
user_id,
config,
&refresh_name,
token_response,
)
.await;
}
if !config.token_url.starts_with("https://") {
tracing::warn!(
token_url = %config.token_url,
"OAuth token_url must use HTTPS, refusing token refresh"
);
return false;
}
if let Err(reason) = reject_private_ip(&config.token_url) {
tracing::warn!(
token_url = %config.token_url,
reason = %reason,
"OAuth token_url points to a private/internal IP, refusing token refresh"
);
return false;
}
let client = match reqwest::Client::builder()
.timeout(Duration::from_secs(15))
.redirect(reqwest::redirect::Policy::none())
.build()
{
Ok(c) => c,
Err(e) => {
tracing::warn!(error = %e, "Failed to build HTTP client for token refresh");
return false;
}
};
let refresh_secret = match load_oauth_refresh_secret(store, user_id, &refresh_name).await {
Some(secret) => secret,
None => return false,
};
let mut params = vec![
("grant_type", "refresh_token".to_string()),
("refresh_token", refresh_secret.expose().to_string()),
("client_id", config.client_id.clone()),
];
if let Some(ref secret) = config.client_secret {
params.push(("client_secret", secret.clone()));
}
let response = match client.post(&config.token_url).form(¶ms).send().await {
Ok(r) => r,
Err(e) => {
tracing::warn!(error = %e, "OAuth token refresh request failed");
return false;
}
};
if !response.status().is_success() {
let status = response.status();
let body = response.text().await.unwrap_or_default();
tracing::warn!(
status = %status,
body = %body,
"OAuth token refresh returned non-success status"
);
return false;
}
let token_data: serde_json::Value = match response.json().await {
Ok(v) => v,
Err(e) => {
tracing::warn!(error = %e, "Failed to parse token refresh response");
return false;
}
};
let token_response = match token_data.get("access_token").and_then(|v| v.as_str()) {
Some(access_token) => oauth_defaults::OAuthTokenResponse {
access_token: access_token.to_string(),
refresh_token: token_data
.get("refresh_token")
.and_then(|v| v.as_str())
.map(str::to_string),
expires_in: token_data.get("expires_in").and_then(|v| v.as_u64()),
token_type: token_data
.get("token_type")
.and_then(|v| v.as_str())
.map(str::to_string),
scope: token_data
.get("scope")
.and_then(|v| v.as_str())
.map(str::to_string),
},
None => {
tracing::warn!("Token refresh response missing access_token field");
return false;
}
};
persist_refreshed_oauth_tokens(store, user_id, config, &refresh_name, token_response).await
}
async fn load_oauth_refresh_secret(
store: &(dyn SecretsStore + Send + Sync),
user_id: &str,
refresh_name: &str,
) -> Option<DecryptedSecret> {
match store.get_decrypted(user_id, refresh_name).await {
Ok(secret) => Some(secret),
Err(error) => {
tracing::debug!(
secret_name = %refresh_name,
error = %error,
"No refresh token available, skipping token refresh"
);
None
}
}
}
async fn persist_refreshed_oauth_tokens(
store: &(dyn SecretsStore + Send + Sync),
user_id: &str,
config: &OAuthRefreshConfig,
refresh_name: &str,
token_response: oauth_defaults::OAuthTokenResponse,
) -> bool {
let mut access_params =
crate::secrets::CreateSecretParams::new(&config.secret_name, &token_response.access_token);
if let Some(ref provider) = config.provider {
access_params = access_params.with_provider(provider);
}
if let Some(expires_in) = token_response.expires_in {
let expires_at = chrono::Utc::now() + chrono::Duration::seconds(expires_in as i64);
access_params = access_params.with_expiry(expires_at);
}
if let Err(e) = store.create(user_id, access_params).await {
tracing::warn!(error = %e, "Failed to store refreshed access token");
return false;
}
if let Some(new_refresh) = token_response.refresh_token.as_deref() {
let mut refresh_params = crate::secrets::CreateSecretParams::new(refresh_name, new_refresh);
if let Some(ref provider) = config.provider {
refresh_params = refresh_params.with_provider(provider);
}
if let Err(e) = store.create(user_id, refresh_params).await {
tracing::warn!(error = %e, "Failed to store rotated refresh token");
}
}
tracing::info!(
secret_name = %config.secret_name,
"OAuth access token refreshed successfully"
);
true
}
async fn resolve_host_credentials(
capabilities: &Capabilities,
store: Option<&(dyn SecretsStore + Send + Sync)>,
user_id: &str,
oauth_refresh: Option<&OAuthRefreshConfig>,
) -> Vec<ResolvedHostCredential> {
let store = match store {
Some(s) => s,
None => {
if let Some(http_cap) = &capabilities.http
&& !http_cap.credentials.is_empty()
{
tracing::warn!(
user_id = %user_id,
"WASM tool requires credentials but secrets_store is not configured - authentication will fail"
);
}
return Vec::new();
}
};
if let Some(config) = oauth_refresh {
let needs_refresh = match store.get(user_id, &config.secret_name).await {
Ok(secret) => match secret.expires_at {
Some(expires_at) => {
let buffer = chrono::Duration::minutes(5);
expires_at - buffer < chrono::Utc::now()
}
None => false,
},
Err(crate::secrets::SecretError::Expired) => true,
Err(_) => false,
};
if needs_refresh {
tracing::debug!(
secret_name = %config.secret_name,
"Access token expired or near expiry, attempting refresh"
);
refresh_oauth_token(store, user_id, config).await;
}
}
let http_cap = match &capabilities.http {
Some(cap) => cap,
None => return Vec::new(),
};
if http_cap.credentials.is_empty() {
return Vec::new();
}
let mut resolved = Vec::new();
for mapping in http_cap.credentials.values() {
if matches!(
mapping.location,
crate::secrets::CredentialLocation::UrlPath { .. }
) {
continue;
}
let secret = match store.get_decrypted(user_id, &mapping.secret_name).await {
Ok(s) => Some(s),
Err(e) => {
tracing::trace!(
user_id = %user_id,
secret_name = %mapping.secret_name,
error = %e,
"No matching host credential resolved for WASM tool in the requested scope"
);
if user_id != "default" {
tracing::debug!(
secret_name = %mapping.secret_name,
user_id = %user_id,
error = %e,
"Credential not found for user, trying default global credentials"
);
store
.get_decrypted("default", &mapping.secret_name)
.await
.ok()
} else {
None
}
}
};
let secret = match secret {
Some(s) => s,
None => {
tracing::warn!(
secret_name = %mapping.secret_name,
user_id = %user_id,
"Could not resolve credential for WASM tool (not found in user context or default)"
);
continue;
}
};
let mut injected = InjectedCredentials::empty();
inject_credential(&mut injected, &mapping.location, &secret);
if injected.is_empty() {
continue;
}
resolved.push(ResolvedHostCredential {
host_patterns: mapping.host_patterns.clone(),
headers: injected.headers,
query_params: injected.query_params,
secret_value: secret.expose().to_string(),
});
}
if !resolved.is_empty() {
tracing::debug!(
count = resolved.len(),
"Pre-resolved host credentials for WASM tool execution"
);
}
resolved
}
fn extract_host_from_url(url: &str) -> Option<String> {
let parsed = url::Url::parse(url).ok()?;
if !matches!(parsed.scheme(), "http" | "https") {
return None;
}
parsed.host_str().map(|h| {
h.strip_prefix('[')
.and_then(|v| v.strip_suffix(']'))
.unwrap_or(h)
.to_lowercase()
})
}
fn reject_private_ip(url: &str) -> Result<(), String> {
let parsed = url::Url::parse(url).map_err(|e| format!("Failed to parse URL: {e}"))?;
if !matches!(parsed.scheme(), "http" | "https") {
return Err(format!("Unsupported URL scheme: {}", parsed.scheme()));
}
if !parsed.username().is_empty() || parsed.password().is_some() {
return Err("URL contains userinfo (@) which is not allowed".to_string());
}
let host = parsed
.host_str()
.map(|h| {
h.strip_prefix('[')
.and_then(|v| v.strip_suffix(']'))
.unwrap_or(h)
})
.ok_or_else(|| "Failed to parse host from URL".to_string())?;
if let Ok(ip) = host.parse::<std::net::IpAddr>() {
return if is_private_ip(ip) {
Err(format!(
"HTTP request to private/internal IP {} is not allowed",
ip
))
} else {
Ok(())
};
}
use std::net::ToSocketAddrs;
let addrs: Vec<_> = format!("{}:0", host)
.to_socket_addrs()
.map_err(|e| format!("DNS resolution failed for {}: {}", host, e))?
.collect();
if addrs.is_empty() {
return Err(format!("DNS resolution returned no addresses for {}", host));
}
for addr in &addrs {
if is_private_ip(addr.ip()) {
return Err(format!(
"DNS rebinding detected: {} resolved to private IP {}",
host,
addr.ip()
));
}
}
Ok(())
}
fn is_private_ip(ip: std::net::IpAddr) -> bool {
match ip {
std::net::IpAddr::V4(v4) => {
v4.is_loopback() || v4.is_private() || v4.is_link_local() || v4.is_unspecified() || v4.octets()[0] == 100 && (v4.octets()[1] & 0xC0) == 64 }
std::net::IpAddr::V6(v6) => {
v6.is_loopback() || v6.is_unspecified() || (v6.segments()[0] & 0xFE00) == 0xFC00
|| (v6.segments()[0] & 0xFFC0) == 0xFE80
}
}
}
fn schema_contains_container_properties(schema: &serde_json::Value) -> bool {
let has_container = |props: &serde_json::Map<String, serde_json::Value>| {
props
.values()
.any(|prop| schema_declares_type(prop, "array") || schema_declares_type(prop, "object"))
};
if schema
.get("properties")
.and_then(|p| p.as_object())
.is_some_and(has_container)
{
return true;
}
for key in ["allOf", "oneOf", "anyOf"] {
if let Some(variants) = schema.get(key).and_then(|v| v.as_array())
&& variants.iter().any(|v| {
v.get("properties")
.and_then(|p| p.as_object())
.is_some_and(has_container)
})
{
return true;
}
}
false
}
fn schema_declares_type(schema: &serde_json::Value, expected: &str) -> bool {
match schema.get("type") {
Some(serde_json::Value::String(t)) => t == expected,
Some(serde_json::Value::Array(types)) => types.iter().any(|t| t.as_str() == Some(expected)),
_ => match expected {
"object" => {
schema
.get("properties")
.and_then(|p| p.as_object())
.is_some()
|| schema
.get("additionalProperties")
.is_some_and(serde_json::Value::is_object)
}
"array" => schema.get("items").is_some(),
_ => false,
},
}
}
fn schema_is_typed_property(schema: &serde_json::Value) -> bool {
matches!(
schema.get("type"),
Some(serde_json::Value::String(_)) | Some(serde_json::Value::Array(_))
) || schema.get("$ref").is_some()
|| schema.get("anyOf").is_some()
|| schema.get("oneOf").is_some()
|| schema.get("allOf").is_some()
|| schema.get("items").is_some()
|| schema
.get("properties")
.and_then(|p| p.as_object())
.is_some()
|| schema
.get("additionalProperties")
.is_some_and(serde_json::Value::is_object)
}
fn build_tool_usage_hint(tool_name: &str, schema: &serde_json::Value) -> String {
let mut hint = format!(
"Tip: call tool_info(name: \"{}\", include_schema: true) for the full parameter schema.",
tool_name
);
if schema_contains_container_properties(schema) {
hint.push_str(
" For array/object fields, pass native JSON arrays/objects, not quoted JSON strings.",
);
}
hint
}
fn needs_content_length_zero(method: &str, headers: &HashMap<String, String>) -> bool {
let mutating = method.eq_ignore_ascii_case("POST")
|| method.eq_ignore_ascii_case("PUT")
|| method.eq_ignore_ascii_case("PATCH")
|| method.eq_ignore_ascii_case("DELETE");
mutating
&& !headers
.iter()
.any(|(k, _)| k.eq_ignore_ascii_case("content-length"))
}
#[cfg(test)]
mod tests {
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::{Arc, Mutex};
use async_trait::async_trait;
use axum::extract::{Form, State};
use axum::http::HeaderMap;
use axum::routing::post;
use axum::{Json, Router};
use serde_json::json;
use tokio::net::TcpListener;
use tokio::sync::{Mutex as AsyncMutex, oneshot};
use uuid::Uuid;
use crate::context::JobContext;
use crate::secrets::{
CreateSecretParams, DecryptedSecret, InMemorySecretsStore, Secret, SecretError, SecretRef,
SecretsStore,
};
use crate::testing::credentials::{
TEST_BEARER_TOKEN_123, TEST_GOOGLE_OAUTH_FRESH, TEST_GOOGLE_OAUTH_LEGACY,
TEST_GOOGLE_OAUTH_TOKEN, TEST_OAUTH_CLIENT_ID, TEST_OAUTH_CLIENT_SECRET,
test_secrets_store,
};
use crate::tools::tool::Tool;
use crate::tools::wasm::capabilities::Capabilities;
use crate::tools::wasm::runtime::{WasmRuntimeConfig, WasmToolRuntime};
struct RecordingSecretsStore {
inner: InMemorySecretsStore,
get_decrypted_lookups: Mutex<Vec<(String, String)>>,
}
impl RecordingSecretsStore {
fn new() -> Self {
Self {
inner: test_secrets_store(),
get_decrypted_lookups: Mutex::new(Vec::new()),
}
}
fn decrypted_lookups(&self) -> Vec<(String, String)> {
self.get_decrypted_lookups.lock().unwrap().clone()
}
}
#[async_trait]
impl SecretsStore for RecordingSecretsStore {
async fn create(
&self,
user_id: &str,
params: CreateSecretParams,
) -> Result<Secret, SecretError> {
self.inner.create(user_id, params).await
}
async fn get(&self, user_id: &str, name: &str) -> Result<Secret, SecretError> {
self.inner.get(user_id, name).await
}
async fn get_decrypted(
&self,
user_id: &str,
name: &str,
) -> Result<DecryptedSecret, SecretError> {
self.get_decrypted_lookups
.lock()
.unwrap()
.push((user_id.to_string(), name.to_string()));
self.inner.get_decrypted(user_id, name).await
}
async fn exists(&self, user_id: &str, name: &str) -> Result<bool, SecretError> {
self.inner.exists(user_id, name).await
}
async fn list(&self, user_id: &str) -> Result<Vec<SecretRef>, SecretError> {
self.inner.list(user_id).await
}
async fn delete(&self, user_id: &str, name: &str) -> Result<bool, SecretError> {
self.inner.delete(user_id, name).await
}
async fn record_usage(&self, secret_id: Uuid) -> Result<(), SecretError> {
self.inner.record_usage(secret_id).await
}
async fn is_accessible(
&self,
user_id: &str,
secret_name: &str,
allowed_secrets: &[String],
) -> Result<bool, SecretError> {
self.inner
.is_accessible(user_id, secret_name, allowed_secrets)
.await
}
}
#[derive(Clone, Debug, PartialEq, Eq)]
struct RecordedProxyRequest {
authorization: Option<String>,
form: HashMap<String, String>,
}
struct MockProxyServer {
addr: SocketAddr,
requests: Arc<AsyncMutex<Vec<RecordedProxyRequest>>>,
shutdown_tx: Option<oneshot::Sender<()>>,
server_task: Option<tokio::task::JoinHandle<()>>,
}
impl MockProxyServer {
async fn start() -> Self {
async fn refresh_handler(
State(requests): State<Arc<AsyncMutex<Vec<RecordedProxyRequest>>>>,
headers: HeaderMap,
Form(form): Form<HashMap<String, String>>,
) -> Json<serde_json::Value> {
requests.lock().await.push(RecordedProxyRequest {
authorization: headers
.get(axum::http::header::AUTHORIZATION)
.and_then(|value| value.to_str().ok())
.map(str::to_string),
form,
});
Json(json!({
"access_token": "mock-refreshed-access-token",
"refresh_token": "mock-rotated-refresh-token",
"expires_in": 3600
}))
}
let requests = Arc::new(AsyncMutex::new(Vec::new()));
let app = Router::new()
.route("/oauth/refresh", post(refresh_handler))
.with_state(Arc::clone(&requests));
let listener = TcpListener::bind("127.0.0.1:0")
.await
.expect("bind mock proxy");
let addr = listener.local_addr().expect("read mock proxy addr");
let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();
let server_task = tokio::spawn(async move {
let _ = axum::serve(listener, app)
.with_graceful_shutdown(async {
let _ = shutdown_rx.await;
})
.await;
});
Self {
addr,
requests,
shutdown_tx: Some(shutdown_tx),
server_task: Some(server_task),
}
}
fn base_url(&self) -> String {
format!("http://{}", self.addr)
}
async fn requests(&self) -> Vec<RecordedProxyRequest> {
self.requests.lock().await.clone()
}
async fn shutdown(mut self) {
if let Some(tx) = self.shutdown_tx.take() {
let _ = tx.send(());
}
if let Some(task) = self.server_task.take() {
let _ = task.await;
}
}
}
impl Drop for MockProxyServer {
fn drop(&mut self) {
if let Some(tx) = self.shutdown_tx.take() {
let _ = tx.send(());
}
if let Some(task) = self.server_task.take() {
task.abort();
}
}
}
#[test]
fn test_wrapper_creation() {
let config = WasmRuntimeConfig::for_testing();
let runtime = Arc::new(WasmToolRuntime::new(config).unwrap());
assert!(runtime.config().fuel_config.enabled);
}
#[tokio::test]
async fn test_advertised_schema_auto_compacted_from_discovery() {
let discovery_schema = serde_json::json!({
"type": "object",
"properties": {
"query": { "type": "string" },
"limit": { "type": "integer" }
},
"required": ["query"]
});
let runtime = Arc::new(WasmToolRuntime::new(WasmRuntimeConfig::for_testing()).unwrap());
let prepared = runtime
.prepare("search", b"\0asm\x0d\0\x01\0", None)
.await
.unwrap();
let mut wrapper =
super::WasmToolWrapper::new(Arc::clone(&runtime), prepared, Capabilities::default());
wrapper.schemas = super::WasmToolSchemas::new(discovery_schema.clone());
wrapper.description = "Search documents".to_string();
assert_eq!(
wrapper.parameters_schema(),
serde_json::json!({
"type": "object",
"properties": {
"query": { "type": "string" }
},
"required": ["query"],
"additionalProperties": true
})
);
assert_eq!(wrapper.discovery_schema(), discovery_schema);
let schema = wrapper.schema();
assert!(
!schema.description.contains("tool_info"),
"schema().description should not contain tool_info hint when auto-compacted: {}",
schema.description
);
}
#[tokio::test]
async fn test_typed_schema_without_required_is_advertised() {
let discovery_schema = serde_json::json!({
"type": "object",
"properties": {
"query": { "type": "string" },
"limit": { "type": "integer" }
}
});
let runtime = Arc::new(WasmToolRuntime::new(WasmRuntimeConfig::for_testing()).unwrap());
let prepared = runtime
.prepare("typed_search", b"\0asm\x0d\0\x01\0", None)
.await
.unwrap();
let mut wrapper =
super::WasmToolWrapper::new(Arc::clone(&runtime), prepared, Capabilities::default());
wrapper.schemas = super::WasmToolSchemas::new(discovery_schema.clone());
wrapper.description = "Typed search tool".to_string();
let advertised = wrapper.parameters_schema();
let props = advertised["properties"].as_object().unwrap();
assert!(
props.contains_key("query"),
"advertised schema should contain 'query' property"
);
assert!(
props.contains_key("limit"),
"advertised schema should contain 'limit' property"
);
assert_eq!(props.len(), 2);
assert!(
!super::WasmToolSchemas::is_permissive_schema(&advertised),
"advertised schema should not be permissive when typed properties exist"
);
let schema = wrapper.schema();
assert!(
!schema.description.contains("tool_info"),
"description should not contain tool_info hint: {}",
schema.description
);
}
#[test]
fn test_compact_schema_keeps_required_and_enum_properties() {
let schema = serde_json::json!({
"type": "object",
"properties": {
"action": {
"type": "string",
"enum": ["list", "get", "create"],
"description": "The operation"
},
"query": { "type": "string" },
"limit": { "type": "integer" },
"format": {
"type": "string",
"enum": ["json", "csv"]
}
},
"required": ["action"]
});
let compacted = super::WasmToolSchemas::compact_schema(&schema);
let props = compacted["properties"].as_object().unwrap();
assert!(props.contains_key("action"));
assert!(props.contains_key("format"));
assert!(!props.contains_key("query"));
assert!(!props.contains_key("limit"));
assert_eq!(compacted["additionalProperties"], true);
assert_eq!(compacted["required"], serde_json::json!(["action"]));
}
#[test]
fn test_compact_schema_preserves_typed_properties_when_no_required() {
let schema = serde_json::json!({
"type": "object",
"properties": {
"query": { "type": "string" },
"limit": { "type": "integer" }
}
});
let compacted = super::WasmToolSchemas::compact_schema(&schema);
let props = compacted["properties"].as_object().unwrap();
assert_eq!(props.len(), 2);
assert!(props.contains_key("query"));
assert!(props.contains_key("limit"));
assert_eq!(compacted["additionalProperties"], true);
}
#[test]
fn test_compact_schema_falls_back_to_permissive_when_no_typed_properties() {
let schema = serde_json::json!({
"type": "object",
"properties": {
"data": {}
}
});
let compacted = super::WasmToolSchemas::compact_schema(&schema);
assert!(compacted["properties"].as_object().unwrap().is_empty());
}
#[test]
fn test_compact_schema_handles_no_properties() {
let schema = serde_json::json!({ "type": "object" });
let compacted = super::WasmToolSchemas::compact_schema(&schema);
assert!(compacted["properties"].as_object().unwrap().is_empty());
}
#[test]
fn test_compact_schema_handles_oneof_variants() {
let schema = serde_json::json!({
"type": "object",
"required": ["action"],
"oneOf": [
{
"properties": {
"action": { "const": "get_repo" },
"owner": { "type": "string" },
"repo": { "type": "string" }
},
"required": ["action", "owner", "repo"]
},
{
"properties": {
"action": { "const": "list_issues" },
"owner": { "type": "string" },
"repo": { "type": "string" },
"state": { "type": "string", "enum": ["open", "closed", "all"] }
},
"required": ["action", "owner", "repo"]
}
]
});
let compacted = super::WasmToolSchemas::compact_schema(&schema);
let props = compacted["properties"].as_object().unwrap();
let action = &props["action"];
assert!(
action.get("enum").is_some(),
"action const values should be merged into enum: {action}"
);
let action_enum = action["enum"].as_array().unwrap();
assert!(
action_enum.contains(&serde_json::json!("get_repo")),
"enum should contain get_repo"
);
assert!(
action_enum.contains(&serde_json::json!("list_issues")),
"enum should contain list_issues"
);
assert!(
action.get("const").is_none(),
"const should be removed after merging into enum"
);
assert!(
props.contains_key("state"),
"state should be kept (has enum)"
);
assert!(!props.contains_key("owner"), "owner should be dropped");
assert!(!props.contains_key("repo"), "repo should be dropped");
assert_eq!(compacted["additionalProperties"], true);
assert_eq!(compacted["required"], serde_json::json!(["action"]));
}
#[test]
fn test_capabilities_default() {
let caps = Capabilities::default();
assert!(caps.workspace_read.is_none());
assert!(caps.http.is_none());
assert!(caps.tool_invoke.is_none());
assert!(caps.secrets.is_none());
}
#[test]
fn test_extract_host_from_url() {
use crate::tools::wasm::wrapper::extract_host_from_url;
assert_eq!(
extract_host_from_url("https://www.googleapis.com/calendar/v3/events"),
Some("www.googleapis.com".to_string())
);
assert_eq!(
extract_host_from_url("https://api.example.com:443/v1/foo"),
Some("api.example.com".to_string())
);
assert_eq!(
extract_host_from_url("http://localhost:8080/test?q=1"),
Some("localhost".to_string())
);
assert_eq!(
extract_host_from_url("https://user:pass@host.com/path"),
Some("host.com".to_string())
);
assert_eq!(extract_host_from_url("ftp://bad.com"), None);
assert_eq!(extract_host_from_url("not a url"), None);
assert_eq!(
extract_host_from_url("http://[::1]:8080/test"),
Some("::1".to_string())
);
assert_eq!(
extract_host_from_url("https://[2001:db8::1]/path"),
Some("2001:db8::1".to_string())
);
}
#[test]
fn test_inject_host_credentials_bearer() {
use crate::tools::wasm::wrapper::{ResolvedHostCredential, StoreData};
use std::collections::HashMap;
let host_credentials = vec![ResolvedHostCredential {
host_patterns: vec!["www.googleapis.com".to_string()],
headers: {
let mut h = HashMap::new();
h.insert(
"Authorization".to_string(),
format!("Bearer {TEST_BEARER_TOKEN_123}"),
);
h
},
query_params: HashMap::new(),
secret_value: TEST_BEARER_TOKEN_123.to_string(),
}];
let store_data = StoreData::new(
1024 * 1024,
Capabilities::default(),
HashMap::new(),
host_credentials,
);
let mut headers = HashMap::new();
let mut url = "https://www.googleapis.com/calendar/v3/events".to_string();
store_data.inject_host_credentials("www.googleapis.com", &mut headers, &mut url);
assert_eq!(
headers.get("Authorization"),
Some(&format!("Bearer {TEST_BEARER_TOKEN_123}"))
);
let mut headers2 = HashMap::new();
let mut url2 = "https://other.com/api".to_string();
store_data.inject_host_credentials("other.com", &mut headers2, &mut url2);
assert!(!headers2.contains_key("Authorization"));
}
#[test]
fn test_inject_host_credentials_query_params() {
use crate::tools::wasm::wrapper::{ResolvedHostCredential, StoreData};
use std::collections::HashMap;
let host_credentials = vec![ResolvedHostCredential {
host_patterns: vec!["api.example.com".to_string()],
headers: HashMap::new(),
query_params: {
let mut q = HashMap::new();
q.insert("api_key".to_string(), "secret123".to_string());
q
},
secret_value: "secret123".to_string(),
}];
let store_data = StoreData::new(
1024 * 1024,
Capabilities::default(),
HashMap::new(),
host_credentials,
);
let mut headers = HashMap::new();
let mut url = "https://api.example.com/v1/data".to_string();
store_data.inject_host_credentials("api.example.com", &mut headers, &mut url);
assert!(url.contains("api_key=secret123"));
assert!(url.contains('?'));
}
#[test]
fn test_redact_credentials_includes_host_credentials() {
use crate::tools::wasm::wrapper::{ResolvedHostCredential, StoreData};
use std::collections::HashMap;
let host_credentials = vec![ResolvedHostCredential {
host_patterns: vec!["api.example.com".to_string()],
headers: HashMap::new(),
query_params: HashMap::new(),
secret_value: "super-secret-token".to_string(),
}];
let store_data = StoreData::new(
1024 * 1024,
Capabilities::default(),
HashMap::new(),
host_credentials,
);
let text = "Error: request to https://api.example.com?key=super-secret-token failed";
let redacted = store_data.redact_credentials(text);
assert!(!redacted.contains("super-secret-token"));
assert!(redacted.contains("[REDACTED:host_credential]"));
}
#[tokio::test]
async fn test_resolve_host_credentials_no_store() {
use crate::tools::wasm::wrapper::resolve_host_credentials;
let caps = Capabilities::default();
let result = resolve_host_credentials(&caps, None, "user1", None).await;
assert!(result.is_empty());
}
#[tokio::test]
async fn test_resolve_host_credentials_no_http_cap() {
use crate::tools::wasm::wrapper::resolve_host_credentials;
let store = test_secrets_store();
let caps = Capabilities::default();
let result = resolve_host_credentials(&caps, Some(&store), "user1", None).await;
assert!(result.is_empty());
}
#[tokio::test]
async fn test_resolve_host_credentials_bearer() {
use crate::secrets::{
CreateSecretParams, CredentialLocation, CredentialMapping, SecretsStore,
};
use crate::tools::wasm::capabilities::HttpCapability;
use crate::tools::wasm::wrapper::resolve_host_credentials;
let store = test_secrets_store();
store
.create(
"user1",
CreateSecretParams::new("google_oauth_token", TEST_GOOGLE_OAUTH_TOKEN),
)
.await
.unwrap();
let mut credentials = HashMap::new();
credentials.insert(
"google_oauth_token".to_string(),
CredentialMapping {
secret_name: "google_oauth_token".to_string(),
location: CredentialLocation::AuthorizationBearer,
host_patterns: vec!["www.googleapis.com".to_string()],
},
);
let caps = Capabilities {
http: Some(HttpCapability {
credentials,
..Default::default()
}),
..Default::default()
};
let result = resolve_host_credentials(&caps, Some(&store), "user1", None).await;
assert_eq!(result.len(), 1);
assert_eq!(result[0].host_patterns, vec!["www.googleapis.com"]);
assert_eq!(
result[0].headers.get("Authorization"),
Some(&format!("Bearer {TEST_GOOGLE_OAUTH_TOKEN}"))
);
}
#[tokio::test]
async fn test_resolve_host_credentials_owner_scope_bearer() {
use crate::secrets::{
CreateSecretParams, CredentialLocation, CredentialMapping, SecretsStore,
};
use crate::tools::wasm::capabilities::HttpCapability;
use crate::tools::wasm::wrapper::resolve_host_credentials;
let store = test_secrets_store();
let ctx = JobContext::with_user("owner-scope", "owner-scope test", "owner-scope test");
store
.create(
&ctx.user_id,
CreateSecretParams::new("google_oauth_token", TEST_GOOGLE_OAUTH_TOKEN),
)
.await
.unwrap();
let mut credentials = HashMap::new();
credentials.insert(
"google_oauth_token".to_string(),
CredentialMapping {
secret_name: "google_oauth_token".to_string(),
location: CredentialLocation::AuthorizationBearer,
host_patterns: vec!["www.googleapis.com".to_string()],
},
);
let caps = Capabilities {
http: Some(HttpCapability {
credentials,
..Default::default()
}),
..Default::default()
};
let result = resolve_host_credentials(&caps, Some(&store), &ctx.user_id, None).await;
assert_eq!(result.len(), 1);
assert_eq!(
result[0].headers.get("Authorization"),
Some(&format!("Bearer {TEST_GOOGLE_OAUTH_TOKEN}"))
);
}
#[tokio::test]
async fn test_execute_resolves_host_credentials_from_owner_scope_context() {
use crate::secrets::{CredentialLocation, CredentialMapping};
use crate::tools::wasm::capabilities::HttpCapability;
let runtime = Arc::new(WasmToolRuntime::new(WasmRuntimeConfig::for_testing()).unwrap());
let prepared = runtime
.prepare("search", b"\0asm\x0d\0\x01\0", None)
.await
.unwrap();
let store = Arc::new(RecordingSecretsStore::new());
let ctx = JobContext::with_user("owner-scope", "owner-scope test", "owner-scope test");
store
.create(
&ctx.user_id,
CreateSecretParams::new("google_oauth_token", TEST_GOOGLE_OAUTH_TOKEN),
)
.await
.unwrap();
let mut credentials = HashMap::new();
credentials.insert(
"google_oauth_token".to_string(),
CredentialMapping {
secret_name: "google_oauth_token".to_string(),
location: CredentialLocation::AuthorizationBearer,
host_patterns: vec!["www.googleapis.com".to_string()],
},
);
let caps = Capabilities {
http: Some(HttpCapability {
credentials,
..Default::default()
}),
..Default::default()
};
let wrapper = super::WasmToolWrapper::new(Arc::clone(&runtime), prepared, caps)
.with_secrets_store(store.clone());
let result = wrapper.execute(serde_json::json!({}), &ctx).await;
assert!(result.is_err());
let lookups = store.decrypted_lookups();
assert!(lookups.contains(&("owner-scope".to_string(), "google_oauth_token".to_string())));
assert!(!lookups.contains(&("default".to_string(), "google_oauth_token".to_string())));
}
#[tokio::test]
async fn test_resolve_host_credentials_missing_secret() {
use crate::secrets::{CredentialLocation, CredentialMapping};
use crate::tools::wasm::capabilities::HttpCapability;
use crate::tools::wasm::wrapper::resolve_host_credentials;
let store = test_secrets_store();
let mut credentials = HashMap::new();
credentials.insert(
"missing_token".to_string(),
CredentialMapping {
secret_name: "missing_token".to_string(),
location: CredentialLocation::AuthorizationBearer,
host_patterns: vec!["api.example.com".to_string()],
},
);
let caps = Capabilities {
http: Some(HttpCapability {
credentials,
..Default::default()
}),
..Default::default()
};
let result = resolve_host_credentials(&caps, Some(&store), "user1", None).await;
assert!(result.is_empty());
}
#[tokio::test]
async fn test_resolve_host_credentials_skips_refresh_when_not_expired() {
use crate::secrets::{
CreateSecretParams, CredentialLocation, CredentialMapping, SecretsStore,
};
use crate::tools::wasm::capabilities::HttpCapability;
use crate::tools::wasm::wrapper::{OAuthRefreshConfig, resolve_host_credentials};
let store = test_secrets_store();
let expires_at = chrono::Utc::now() + chrono::Duration::hours(2);
store
.create(
"user1",
CreateSecretParams::new("google_oauth_token", TEST_GOOGLE_OAUTH_FRESH)
.with_expiry(expires_at),
)
.await
.unwrap();
let mut credentials = HashMap::new();
credentials.insert(
"google_oauth_token".to_string(),
CredentialMapping {
secret_name: "google_oauth_token".to_string(),
location: CredentialLocation::AuthorizationBearer,
host_patterns: vec!["www.googleapis.com".to_string()],
},
);
let caps = Capabilities {
http: Some(HttpCapability {
credentials,
..Default::default()
}),
..Default::default()
};
let oauth_config = OAuthRefreshConfig {
token_url: "https://oauth2.googleapis.com/token".to_string(),
client_id: TEST_OAUTH_CLIENT_ID.to_string(),
client_secret: Some(TEST_OAUTH_CLIENT_SECRET.to_string()),
exchange_proxy_url: None,
gateway_token: None,
secret_name: "google_oauth_token".to_string(),
provider: Some("google".to_string()),
};
let result =
resolve_host_credentials(&caps, Some(&store), "user1", Some(&oauth_config)).await;
assert_eq!(result.len(), 1);
assert_eq!(
result[0].headers.get("Authorization"),
Some(&format!("Bearer {TEST_GOOGLE_OAUTH_FRESH}"))
);
}
#[tokio::test]
async fn test_resolve_host_credentials_skips_refresh_no_config() {
use crate::secrets::{
CreateSecretParams, CredentialLocation, CredentialMapping, SecretsStore,
};
use crate::tools::wasm::capabilities::HttpCapability;
use crate::tools::wasm::wrapper::resolve_host_credentials;
let store = test_secrets_store();
let expires_at = chrono::Utc::now() - chrono::Duration::hours(1);
store
.create(
"user1",
CreateSecretParams::new("my_token", "expired-value").with_expiry(expires_at),
)
.await
.unwrap();
let mut credentials = HashMap::new();
credentials.insert(
"my_token".to_string(),
CredentialMapping {
secret_name: "my_token".to_string(),
location: CredentialLocation::AuthorizationBearer,
host_patterns: vec!["api.example.com".to_string()],
},
);
let caps = Capabilities {
http: Some(HttpCapability {
credentials,
..Default::default()
}),
..Default::default()
};
let result = resolve_host_credentials(&caps, Some(&store), "user1", None).await;
assert!(result.is_empty());
}
#[tokio::test]
async fn test_resolve_host_credentials_skips_refresh_no_expires_at() {
use crate::secrets::{
CreateSecretParams, CredentialLocation, CredentialMapping, SecretsStore,
};
use crate::tools::wasm::capabilities::HttpCapability;
use crate::tools::wasm::wrapper::{OAuthRefreshConfig, resolve_host_credentials};
let store = test_secrets_store();
store
.create(
"user1",
CreateSecretParams::new("google_oauth_token", TEST_GOOGLE_OAUTH_LEGACY),
)
.await
.unwrap();
let mut credentials = HashMap::new();
credentials.insert(
"google_oauth_token".to_string(),
CredentialMapping {
secret_name: "google_oauth_token".to_string(),
location: CredentialLocation::AuthorizationBearer,
host_patterns: vec!["www.googleapis.com".to_string()],
},
);
let caps = Capabilities {
http: Some(HttpCapability {
credentials,
..Default::default()
}),
..Default::default()
};
let oauth_config = OAuthRefreshConfig {
token_url: "https://oauth2.googleapis.com/token".to_string(),
client_id: TEST_OAUTH_CLIENT_ID.to_string(),
client_secret: Some(TEST_OAUTH_CLIENT_SECRET.to_string()),
exchange_proxy_url: None,
gateway_token: None,
secret_name: "google_oauth_token".to_string(),
provider: Some("google".to_string()),
};
let result =
resolve_host_credentials(&caps, Some(&store), "user1", Some(&oauth_config)).await;
assert_eq!(result.len(), 1);
assert_eq!(
result[0].headers.get("Authorization"),
Some(&format!("Bearer {TEST_GOOGLE_OAUTH_LEGACY}"))
);
}
#[tokio::test]
async fn test_resolve_host_credentials_refreshes_via_proxy_without_direct_token_url_validation()
{
use crate::secrets::{
CreateSecretParams, CredentialLocation, CredentialMapping, SecretsStore,
};
use crate::tools::wasm::capabilities::HttpCapability;
use crate::tools::wasm::wrapper::{OAuthRefreshConfig, resolve_host_credentials};
let proxy = MockProxyServer::start().await;
let store = test_secrets_store();
store
.create(
"user1",
CreateSecretParams::new("google_oauth_token", "expired-access-token")
.with_expiry(chrono::Utc::now() - chrono::Duration::hours(1)),
)
.await
.unwrap();
store
.create(
"user1",
CreateSecretParams::new("google_oauth_token_refresh_token", "stored-refresh-token"),
)
.await
.unwrap();
let mut credentials = HashMap::new();
credentials.insert(
"google_oauth_token".to_string(),
CredentialMapping {
secret_name: "google_oauth_token".to_string(),
location: CredentialLocation::AuthorizationBearer,
host_patterns: vec!["www.googleapis.com".to_string()],
},
);
let caps = Capabilities {
http: Some(HttpCapability {
credentials,
..Default::default()
}),
..Default::default()
};
let oauth_config = OAuthRefreshConfig {
token_url: "http://127.0.0.1:9/provider-token-endpoint".to_string(),
client_id: "hosted-google-client-id".to_string(),
client_secret: None,
exchange_proxy_url: Some(proxy.base_url()),
gateway_token: Some("gateway-test-token".to_string()),
secret_name: "google_oauth_token".to_string(),
provider: Some("google".to_string()),
};
let resolved =
resolve_host_credentials(&caps, Some(&store), "user1", Some(&oauth_config)).await;
assert_eq!(resolved.len(), 1);
assert_eq!(
resolved[0].headers.get("Authorization"),
Some(&"Bearer mock-refreshed-access-token".to_string())
);
let access_secret = store.get("user1", "google_oauth_token").await.unwrap();
assert!(
access_secret
.expires_at
.expect("refreshed access token expiry")
> chrono::Utc::now()
);
let access_value = store
.get_decrypted("user1", "google_oauth_token")
.await
.unwrap();
assert_eq!(access_value.expose(), "mock-refreshed-access-token");
let refresh_value = store
.get_decrypted("user1", "google_oauth_token_refresh_token")
.await
.unwrap();
assert_eq!(refresh_value.expose(), "mock-rotated-refresh-token");
let requests = proxy.requests().await;
assert_eq!(requests.len(), 1);
assert_eq!(
requests[0].authorization.as_deref(),
Some("Bearer gateway-test-token")
);
assert_eq!(
requests[0].form.get("client_id").map(String::as_str),
Some("hosted-google-client-id")
);
assert_eq!(
requests[0].form.get("token_url").map(String::as_str),
Some("http://127.0.0.1:9/provider-token-endpoint")
);
assert_eq!(
requests[0].form.get("refresh_token").map(String::as_str),
Some("stored-refresh-token")
);
assert_eq!(
requests[0].form.get("provider").map(String::as_str),
Some("google")
);
assert!(!requests[0].form.contains_key("client_secret"));
proxy.shutdown().await;
}
#[tokio::test]
async fn test_resolve_host_credentials_skips_refresh_token_lookup_without_oauth_proxy_auth_token()
{
use crate::secrets::{
CreateSecretParams, CredentialLocation, CredentialMapping, SecretsStore,
};
use crate::tools::wasm::capabilities::HttpCapability;
use crate::tools::wasm::wrapper::{OAuthRefreshConfig, resolve_host_credentials};
let store = RecordingSecretsStore::new();
store
.create(
"user1",
CreateSecretParams::new("google_oauth_token", "expired-access-token")
.with_expiry(chrono::Utc::now() - chrono::Duration::hours(1)),
)
.await
.unwrap();
store
.create(
"user1",
CreateSecretParams::new("google_oauth_token_refresh_token", "stored-refresh-token"),
)
.await
.unwrap();
let mut credentials = HashMap::new();
credentials.insert(
"google_oauth_token".to_string(),
CredentialMapping {
secret_name: "google_oauth_token".to_string(),
location: CredentialLocation::AuthorizationBearer,
host_patterns: vec!["www.googleapis.com".to_string()],
},
);
let caps = Capabilities {
http: Some(HttpCapability {
credentials,
..Default::default()
}),
..Default::default()
};
let oauth_config = OAuthRefreshConfig {
token_url: "https://oauth2.googleapis.com/token".to_string(),
client_id: "hosted-google-client-id".to_string(),
client_secret: None,
exchange_proxy_url: Some("https://compose-api.example.com".to_string()),
gateway_token: None,
secret_name: "google_oauth_token".to_string(),
provider: Some("google".to_string()),
};
let resolved =
resolve_host_credentials(&caps, Some(&store), "user1", Some(&oauth_config)).await;
assert!(resolved.is_empty());
let lookups = store.decrypted_lookups();
assert!(lookups.contains(&("user1".to_string(), "google_oauth_token".to_string())));
assert!(!lookups.contains(&(
"user1".to_string(),
"google_oauth_token_refresh_token".to_string(),
)));
}
#[tokio::test]
async fn test_resolve_host_credentials_skips_refresh_token_lookup_for_invalid_direct_token_url()
{
use crate::secrets::{
CreateSecretParams, CredentialLocation, CredentialMapping, SecretsStore,
};
use crate::tools::wasm::capabilities::HttpCapability;
use crate::tools::wasm::wrapper::{OAuthRefreshConfig, resolve_host_credentials};
let store = RecordingSecretsStore::new();
store
.create(
"user1",
CreateSecretParams::new("google_oauth_token", "expired-access-token")
.with_expiry(chrono::Utc::now() - chrono::Duration::hours(1)),
)
.await
.unwrap();
store
.create(
"user1",
CreateSecretParams::new("google_oauth_token_refresh_token", "stored-refresh-token"),
)
.await
.unwrap();
let mut credentials = HashMap::new();
credentials.insert(
"google_oauth_token".to_string(),
CredentialMapping {
secret_name: "google_oauth_token".to_string(),
location: CredentialLocation::AuthorizationBearer,
host_patterns: vec!["www.googleapis.com".to_string()],
},
);
let caps = Capabilities {
http: Some(HttpCapability {
credentials,
..Default::default()
}),
..Default::default()
};
let oauth_config = OAuthRefreshConfig {
token_url: "http://127.0.0.1:9/provider-token-endpoint".to_string(),
client_id: TEST_OAUTH_CLIENT_ID.to_string(),
client_secret: Some(TEST_OAUTH_CLIENT_SECRET.to_string()),
exchange_proxy_url: None,
gateway_token: None,
secret_name: "google_oauth_token".to_string(),
provider: Some("google".to_string()),
};
let resolved =
resolve_host_credentials(&caps, Some(&store), "user1", Some(&oauth_config)).await;
assert!(resolved.is_empty());
let lookups = store.decrypted_lookups();
assert!(lookups.contains(&("user1".to_string(), "google_oauth_token".to_string())));
assert!(!lookups.contains(&(
"user1".to_string(),
"google_oauth_token_refresh_token".to_string(),
)));
}
#[test]
fn test_is_private_ip_v4() {
use std::net::IpAddr;
assert!(super::is_private_ip("127.0.0.1".parse::<IpAddr>().unwrap()));
assert!(super::is_private_ip("10.0.0.1".parse::<IpAddr>().unwrap()));
assert!(super::is_private_ip(
"172.16.0.1".parse::<IpAddr>().unwrap()
));
assert!(super::is_private_ip(
"192.168.1.1".parse::<IpAddr>().unwrap()
));
assert!(super::is_private_ip(
"169.254.1.1".parse::<IpAddr>().unwrap()
));
assert!(super::is_private_ip("0.0.0.0".parse::<IpAddr>().unwrap()));
assert!(super::is_private_ip(
"100.64.0.1".parse::<IpAddr>().unwrap()
));
assert!(!super::is_private_ip("8.8.8.8".parse::<IpAddr>().unwrap()));
assert!(!super::is_private_ip("1.1.1.1".parse::<IpAddr>().unwrap()));
assert!(!super::is_private_ip(
"93.184.216.34".parse::<IpAddr>().unwrap()
));
}
#[test]
fn test_is_private_ip_v6() {
use std::net::IpAddr;
assert!(super::is_private_ip("::1".parse::<IpAddr>().unwrap()));
assert!(super::is_private_ip("::".parse::<IpAddr>().unwrap()));
assert!(super::is_private_ip("fc00::1".parse::<IpAddr>().unwrap()));
assert!(super::is_private_ip("fe80::1".parse::<IpAddr>().unwrap()));
assert!(!super::is_private_ip(
"2606:4700::1111".parse::<IpAddr>().unwrap()
));
}
#[test]
fn test_reject_private_ip_loopback() {
let result = super::reject_private_ip("https://127.0.0.1:8080/api");
assert!(result.is_err());
assert!(result.unwrap_err().contains("private/internal IP"));
}
#[test]
fn test_reject_private_ip_internal() {
let result = super::reject_private_ip("https://192.168.1.1/admin");
assert!(result.is_err());
}
#[test]
fn test_reject_private_ip_public_ok() {
let result = super::reject_private_ip("https://8.8.8.8/dns-query");
assert!(result.is_ok());
}
#[tokio::test]
async fn test_untyped_override_preserves_extracted_discovery_schema() {
let typed_schema = serde_json::json!({
"type": "object",
"properties": {
"values": {
"type": ["array", "null"],
"items": { "type": "array" }
}
}
});
let runtime = Arc::new(WasmToolRuntime::new(WasmRuntimeConfig::for_testing()).unwrap()); let mut prepared = runtime
.prepare("sheets", b"\0asm\x0d\0\x01\0", None)
.await
.unwrap(); Arc::get_mut(&mut prepared).unwrap().schema = typed_schema.clone();
let wrapper =
super::WasmToolWrapper::new(Arc::clone(&runtime), prepared, Capabilities::default())
.with_schema(serde_json::json!({
"type": "object",
"properties": {},
"additionalProperties": true
}));
#[rustfmt::skip]
assert_eq!( wrapper.parameters_schema(),
serde_json::json!({
"type": "object",
"properties": {},
"additionalProperties": true
})
);
assert_eq!(wrapper.discovery_schema(), typed_schema); }
#[test]
fn test_build_tool_usage_hint_detects_nullable_container_properties() {
let schema = serde_json::json!({
"type": "object",
"properties": {
"requests": {
"type": ["array", "null"],
"items": { "type": "object" }
}
}
});
let hint = super::build_tool_usage_hint("google_docs", &schema);
assert!(hint.contains("native JSON arrays/objects")); }
#[test]
fn test_leak_scan_runs_before_credential_injection() {
use crate::safety::LeakDetector;
let raw_headers: Vec<(String, String)> = vec![
(
"Authorization".to_string(),
"Bearer {SLACK_BOT_TOKEN}".to_string(),
),
("Content-Type".to_string(), "application/json".to_string()),
];
let detector = LeakDetector::new();
let pre_result = detector.scan_http_request(
"https://slack.com/api/chat.postMessage",
&raw_headers,
None,
);
assert!(
pre_result.is_ok(),
"Leak scan on pre-injection headers should pass, but got: {:?}",
pre_result
);
let post_injection_headers: Vec<(String, String)> = vec![
(
"Authorization".to_string(),
"Bearer xoxb-1234567890-abcdefghij".to_string(),
),
("Content-Type".to_string(), "application/json".to_string()),
];
let post_result = detector.scan_http_request(
"https://slack.com/api/chat.postMessage",
&post_injection_headers,
None,
);
assert!(
post_result.is_err(),
"Leak scan on post-injection headers should block the Slack token"
);
}
#[tokio::test]
async fn test_resolve_host_credentials_fallback_to_default_user() {
use crate::secrets::{CredentialLocation, CredentialMapping, SecretsStore};
use crate::tools::wasm::capabilities::HttpCapability;
use crate::tools::wasm::wrapper::resolve_host_credentials;
let store = test_secrets_store();
store
.create(
"default",
crate::secrets::CreateSecretParams::new("google_oauth_token", "global_token_value"),
)
.await
.expect("Failed to store global token");
let mut creds = std::collections::HashMap::new();
creds.insert(
"google_oauth_token".to_string(),
CredentialMapping {
secret_name: "google_oauth_token".to_string(),
location: CredentialLocation::AuthorizationBearer,
host_patterns: vec!["sheets.googleapis.com".to_string()],
},
);
let caps = Capabilities {
http: Some(HttpCapability {
allowlist: vec![],
credentials: creds,
rate_limit: crate::tools::wasm::capabilities::RateLimitConfig::default(),
max_request_bytes: 1024 * 1024,
max_response_bytes: 10 * 1024 * 1024,
timeout: std::time::Duration::from_secs(30),
}),
..Default::default()
};
let result = resolve_host_credentials(&caps, Some(&store), "routine_user_123", None).await;
assert!(!result.is_empty(), "fallback to default"); assert_eq!(result[0].secret_value, "global_token_value"); }
fn test_capabilities_with_google_oauth() -> Capabilities {
use crate::secrets::{CredentialLocation, CredentialMapping};
use crate::tools::wasm::capabilities::HttpCapability;
let mut creds = std::collections::HashMap::new();
creds.insert(
"google_oauth_token".to_string(),
CredentialMapping {
secret_name: "google_oauth_token".to_string(),
location: CredentialLocation::AuthorizationBearer,
host_patterns: vec!["sheets.googleapis.com".to_string()],
},
);
Capabilities {
http: Some(HttpCapability {
allowlist: vec![],
credentials: creds,
rate_limit: crate::tools::wasm::capabilities::RateLimitConfig::default(),
max_request_bytes: 1024 * 1024,
max_response_bytes: 10 * 1024 * 1024,
timeout: std::time::Duration::from_secs(30),
}),
..Default::default()
}
}
#[tokio::test]
async fn test_resolve_host_credentials_prefers_user_specific_over_default() {
use crate::secrets::SecretsStore;
use crate::tools::wasm::wrapper::resolve_host_credentials;
let store = test_secrets_store();
store
.create(
"default",
crate::secrets::CreateSecretParams::new("google_oauth_token", "global_token"),
)
.await
.expect("Failed to store global token");
store
.create(
"user_123",
crate::secrets::CreateSecretParams::new(
"google_oauth_token",
"user_specific_token",
),
)
.await
.expect("Failed to store user token");
let caps = test_capabilities_with_google_oauth();
let result = resolve_host_credentials(&caps, Some(&store), "user_123", None).await;
assert!(!result.is_empty(), "has user credentials"); assert_eq!(result[0].secret_value, "user_specific_token", "user token"); }
#[tokio::test]
async fn test_resolve_host_credentials_no_fallback_when_already_default() {
use crate::secrets::SecretsStore;
use crate::tools::wasm::wrapper::resolve_host_credentials;
let store = test_secrets_store();
store
.create(
"default",
crate::secrets::CreateSecretParams::new("google_oauth_token", "default_token"),
)
.await
.expect("Failed to store default token");
let caps = test_capabilities_with_google_oauth();
let result = resolve_host_credentials(&caps, Some(&store), "default", None).await;
assert!(!result.is_empty(), "Should find default token"); assert_eq!(result[0].secret_value, "default_token"); }
#[tokio::test]
async fn test_resolve_host_credentials_missing_secret_warns() {
use crate::tools::wasm::wrapper::resolve_host_credentials;
let store = test_secrets_store();
let caps = test_capabilities_with_google_oauth();
let result = resolve_host_credentials(&caps, Some(&store), "user_456", None).await;
assert!(result.is_empty(), "no credentials found"); }
#[test]
fn post_no_body_needs_content_length() {
let headers = HashMap::new();
assert!(
super::needs_content_length_zero("POST", &headers),
"POST with no body must get Content-Length: 0 to avoid 411"
);
}
#[test]
fn put_no_body_needs_content_length() {
assert!(super::needs_content_length_zero("PUT", &HashMap::new()));
}
#[test]
fn delete_no_body_needs_content_length() {
assert!(super::needs_content_length_zero("DELETE", &HashMap::new()));
}
#[test]
fn patch_no_body_needs_content_length() {
assert!(super::needs_content_length_zero("PATCH", &HashMap::new()));
}
#[test]
fn get_no_body_skips_content_length() {
assert!(!super::needs_content_length_zero("GET", &HashMap::new()));
}
#[test]
fn head_no_body_skips_content_length() {
assert!(!super::needs_content_length_zero("HEAD", &HashMap::new()));
}
#[test]
fn post_no_body_respects_explicit_content_length() {
let mut headers = HashMap::new();
headers.insert("Content-Length".to_string(), "0".to_string());
assert!(
!super::needs_content_length_zero("POST", &headers),
"should not double-add when tool already sets Content-Length"
);
}
#[test]
fn content_length_check_is_case_insensitive() {
let mut headers = HashMap::new();
headers.insert("content-length".to_string(), "0".to_string());
assert!(!super::needs_content_length_zero("POST", &headers));
}
}