1use crate::{AccessPolicy, SharedAccessPolicy, shared_access_policy};
4use async_trait::async_trait;
5use rain_engine_core::{
6 NativeSkill, SkillExecutionError, SkillFailureKind, SkillInvocation, SkillManifest,
7};
8use serde_json::{Value, json};
9use std::time::Duration;
10use tracing::warn;
11
12pub struct HttpFetchSkill {
13 client: reqwest::Client,
14 policy: SharedAccessPolicy,
15 timeout: Duration,
16}
17
18impl HttpFetchSkill {
19 pub fn new(allowed_hosts: std::collections::HashSet<String>, timeout: Duration) -> Self {
21 Self {
22 client: reqwest::Client::new(),
23 policy: shared_access_policy(allowed_hosts, false),
24 timeout,
25 }
26 }
27
28 pub fn permissive(timeout: Duration) -> Self {
29 Self {
30 client: reqwest::Client::new(),
31 policy: shared_access_policy(std::collections::HashSet::new(), true),
32 timeout,
33 }
34 }
35
36 pub fn with_shared_policy(policy: SharedAccessPolicy, timeout: Duration) -> Self {
37 Self {
38 client: reqwest::Client::new(),
39 policy,
40 timeout,
41 }
42 }
43
44 async fn is_allowed(&self, url: &str) -> bool {
45 let policy = self.policy.read().await;
46 if policy.permissive {
47 return true;
48 }
49 reqwest::Url::parse(url)
50 .ok()
51 .and_then(|parsed: reqwest::Url| parsed.host_str().map(|h| h.to_string()))
52 .is_some_and(|host| policy.allowlist.contains(&host))
53 }
54
55 pub async fn access_policy(&self) -> AccessPolicy {
56 self.policy.read().await.clone()
57 }
58}
59
60pub fn manifest() -> SkillManifest {
61 crate::base_manifest(
62 "http_fetch",
63 "Make an HTTP request and return the response. Hosts must be on the allowlist.",
64 json!({
65 "type": "object",
66 "properties": {
67 "url": { "type": "string", "description": "The URL to fetch" },
68 "method": { "type": "string", "description": "HTTP method (GET, POST, etc.)", "default": "GET" },
69 "headers": { "type": "object", "description": "Optional headers" },
70 "body": { "type": "string", "description": "Optional request body" }
71 },
72 "required": ["url"]
73 }),
74 )
75}
76
77#[async_trait]
78impl NativeSkill for HttpFetchSkill {
79 async fn execute(&self, invocation: SkillInvocation) -> Result<Value, SkillExecutionError> {
80 let url = invocation.args["url"].as_str().ok_or_else(|| {
81 SkillExecutionError::new(SkillFailureKind::InvalidResponse, "missing 'url' arg")
82 })?;
83
84 if !self.is_allowed(url).await {
85 warn!(url = %url, "http_fetch: host not on allowlist");
86 return Err(SkillExecutionError::new(
87 SkillFailureKind::PermissionDenied,
88 "host not allowed",
89 ));
90 }
91
92 let method_str = invocation.args["method"].as_str().unwrap_or("GET");
93 let method: reqwest::Method = method_str.parse().map_err(|_| {
94 SkillExecutionError::new(
95 SkillFailureKind::InvalidResponse,
96 format!("invalid method: {method_str}"),
97 )
98 })?;
99
100 let mut builder = self.client.request(method, url).timeout(self.timeout);
101
102 if let Some(headers) = invocation.args["headers"].as_object() {
103 for (key, value) in headers {
104 if let Some(val) = value.as_str() {
105 builder = builder.header(key.as_str(), val);
106 }
107 }
108 }
109
110 if let Some(body) = invocation.args["body"].as_str() {
111 builder = builder.body(body.to_string());
112 }
113
114 let response = builder.send().await.map_err(|err| {
115 SkillExecutionError::new(SkillFailureKind::Internal, format!("request failed: {err}"))
116 })?;
117
118 let status = response.status().as_u16();
119 let response_headers: serde_json::Map<String, Value> = response
120 .headers()
121 .iter()
122 .map(|(k, v)| {
123 (
124 k.to_string(),
125 Value::String(v.to_str().unwrap_or_default().to_string()),
126 )
127 })
128 .collect();
129 let body = response.text().await.map_err(|err| {
130 SkillExecutionError::new(
131 SkillFailureKind::Internal,
132 format!("body read failed: {err}"),
133 )
134 })?;
135
136 Ok(json!({
137 "status": status,
138 "headers": response_headers,
139 "body": body,
140 }))
141 }
142
143 fn executor_kind(&self) -> &'static str {
144 "native:http_fetch"
145 }
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use rain_engine_core::{
        AgentContextSnapshot, AgentId, AgentStateSnapshot, EnginePolicy, SkillInvocation,
    };
    use std::collections::HashSet;

    /// Builds an invocation of this skill carrying the given raw `args` object.
    fn invocation_with_args(args: Value) -> SkillInvocation {
        SkillInvocation {
            call_id: "call-1".to_string(),
            manifest: manifest(),
            args,
            dry_run: false,
            context: AgentContextSnapshot {
                session_id: "session".to_string(),
                granted_scopes: vec!["tool:run".to_string()],
                trigger_id: "trigger".to_string(),
                idempotency_key: None,
                current_step: 0,
                max_steps: 1,
                history: Vec::new(),
                prior_tool_results: Vec::new(),
                session_cost_usd: 0.0,
                state: AgentStateSnapshot {
                    agent_id: AgentId("session".to_string()),
                    profile: None,
                    goals: Vec::new(),
                    tasks: Vec::new(),
                    observations: Vec::new(),
                    artifacts: Vec::new(),
                    resources: Vec::new(),
                    relationships: Vec::new(),
                    pending_wake: None,
                },
                policy: EnginePolicy::default(),
                active_execution_plan: None,
            },
        }
    }

    /// Builds an invocation whose only argument is `url`.
    fn invocation(url: &str) -> SkillInvocation {
        invocation_with_args(json!({ "url": url }))
    }

    /// Builds an allowlist from host literals.
    fn allow(hosts: &[&str]) -> HashSet<String> {
        hosts.iter().map(|host| host.to_string()).collect()
    }

    #[tokio::test]
    async fn empty_allowlist_denies_by_default() {
        let skill = HttpFetchSkill::new(HashSet::new(), Duration::from_secs(1));
        let err = skill
            .execute(invocation("https://example.com"))
            .await
            .expect_err("empty allowlist denies");
        assert_eq!(err.kind, SkillFailureKind::PermissionDenied);
    }

    #[tokio::test]
    async fn missing_url_is_rejected() {
        let skill = HttpFetchSkill::permissive(Duration::from_secs(1));
        let err = skill
            .execute(invocation_with_args(json!({})))
            .await
            .expect_err("url is required");
        assert_eq!(err.kind, SkillFailureKind::InvalidResponse);
    }

    #[tokio::test]
    async fn invalid_method_is_rejected_before_sending() {
        let skill = HttpFetchSkill::new(allow(&["example.com"]), Duration::from_secs(1));
        let err = skill
            .execute(invocation_with_args(json!({
                "url": "https://example.com",
                "method": "BAD METHOD",
            })))
            .await
            .expect_err("a method containing a space is not a valid token");
        assert_eq!(err.kind, SkillFailureKind::InvalidResponse);
    }

    #[tokio::test]
    async fn allowlist_matches_on_url_host() {
        let skill = HttpFetchSkill::new(allow(&["example.com"]), Duration::from_secs(1));
        assert!(skill.is_allowed("https://example.com/path?q=1").await);
        assert!(!skill.is_allowed("https://other.example.org/").await);
        // Userinfo must not be mistaken for the host.
        assert!(!skill.is_allowed("https://example.com@evil.test/").await);
        assert!(!skill.is_allowed("not a url").await);
    }

    #[tokio::test]
    async fn permissive_policy_allows_any_url() {
        let skill = HttpFetchSkill::permissive(Duration::from_secs(1));
        assert!(skill.is_allowed("https://anything.test/").await);
        assert!(skill.is_allowed("not a url").await);
        assert!(skill.access_policy().await.permissive);
    }
}