Skip to main content

rain_engine_skills/
http_fetch.rs

1//! HTTP fetch skill with host allowlist.
2
3use crate::{AccessPolicy, SharedAccessPolicy, shared_access_policy};
4use async_trait::async_trait;
5use rain_engine_core::{
6    NativeSkill, SkillExecutionError, SkillFailureKind, SkillInvocation, SkillManifest,
7};
8use serde_json::{Value, json};
9use std::time::Duration;
10use tracing::warn;
11
12pub struct HttpFetchSkill {
13    client: reqwest::Client,
14    policy: SharedAccessPolicy,
15    timeout: Duration,
16}
17
18impl HttpFetchSkill {
19    /// Create with host allowlist. Empty set = deny all.
20    pub fn new(allowed_hosts: std::collections::HashSet<String>, timeout: Duration) -> Self {
21        Self {
22            client: reqwest::Client::new(),
23            policy: shared_access_policy(allowed_hosts, false),
24            timeout,
25        }
26    }
27
28    pub fn permissive(timeout: Duration) -> Self {
29        Self {
30            client: reqwest::Client::new(),
31            policy: shared_access_policy(std::collections::HashSet::new(), true),
32            timeout,
33        }
34    }
35
36    pub fn with_shared_policy(policy: SharedAccessPolicy, timeout: Duration) -> Self {
37        Self {
38            client: reqwest::Client::new(),
39            policy,
40            timeout,
41        }
42    }
43
44    async fn is_allowed(&self, url: &str) -> bool {
45        let policy = self.policy.read().await;
46        if policy.permissive {
47            return true;
48        }
49        reqwest::Url::parse(url)
50            .ok()
51            .and_then(|parsed: reqwest::Url| parsed.host_str().map(|h| h.to_string()))
52            .is_some_and(|host| policy.allowlist.contains(&host))
53    }
54
55    pub async fn access_policy(&self) -> AccessPolicy {
56        self.policy.read().await.clone()
57    }
58}
59
60pub fn manifest() -> SkillManifest {
61    crate::base_manifest(
62        "http_fetch",
63        "Make an HTTP request and return the response. Hosts must be on the allowlist.",
64        json!({
65            "type": "object",
66            "properties": {
67                "url": { "type": "string", "description": "The URL to fetch" },
68                "method": { "type": "string", "description": "HTTP method (GET, POST, etc.)", "default": "GET" },
69                "headers": { "type": "object", "description": "Optional headers" },
70                "body": { "type": "string", "description": "Optional request body" }
71            },
72            "required": ["url"]
73        }),
74    )
75}
76
77#[async_trait]
78impl NativeSkill for HttpFetchSkill {
79    async fn execute(&self, invocation: SkillInvocation) -> Result<Value, SkillExecutionError> {
80        let url = invocation.args["url"].as_str().ok_or_else(|| {
81            SkillExecutionError::new(SkillFailureKind::InvalidResponse, "missing 'url' arg")
82        })?;
83
84        if !self.is_allowed(url).await {
85            warn!(url = %url, "http_fetch: host not on allowlist");
86            return Err(SkillExecutionError::new(
87                SkillFailureKind::PermissionDenied,
88                "host not allowed",
89            ));
90        }
91
92        let method_str = invocation.args["method"].as_str().unwrap_or("GET");
93        let method: reqwest::Method = method_str.parse().map_err(|_| {
94            SkillExecutionError::new(
95                SkillFailureKind::InvalidResponse,
96                format!("invalid method: {method_str}"),
97            )
98        })?;
99
100        let mut builder = self.client.request(method, url).timeout(self.timeout);
101
102        if let Some(headers) = invocation.args["headers"].as_object() {
103            for (key, value) in headers {
104                if let Some(val) = value.as_str() {
105                    builder = builder.header(key.as_str(), val);
106                }
107            }
108        }
109
110        if let Some(body) = invocation.args["body"].as_str() {
111            builder = builder.body(body.to_string());
112        }
113
114        let response = builder.send().await.map_err(|err| {
115            SkillExecutionError::new(SkillFailureKind::Internal, format!("request failed: {err}"))
116        })?;
117
118        let status = response.status().as_u16();
119        let response_headers: serde_json::Map<String, Value> = response
120            .headers()
121            .iter()
122            .map(|(k, v)| {
123                (
124                    k.to_string(),
125                    Value::String(v.to_str().unwrap_or_default().to_string()),
126                )
127            })
128            .collect();
129        let body = response.text().await.map_err(|err| {
130            SkillExecutionError::new(
131                SkillFailureKind::Internal,
132                format!("body read failed: {err}"),
133            )
134        })?;
135
136        Ok(json!({
137            "status": status,
138            "headers": response_headers,
139            "body": body,
140        }))
141    }
142
143    fn executor_kind(&self) -> &'static str {
144        "native:http_fetch"
145    }
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use rain_engine_core::{
        AgentContextSnapshot, AgentId, AgentStateSnapshot, EnginePolicy, SkillInvocation,
    };

    /// Builds a minimal, non-dry-run invocation whose only argument is `url`.
    fn invocation(url: &str) -> SkillInvocation {
        // Agent state with every collection empty and nothing pending.
        let state = AgentStateSnapshot {
            agent_id: AgentId(String::from("session")),
            profile: None,
            goals: Vec::new(),
            tasks: Vec::new(),
            observations: Vec::new(),
            artifacts: Vec::new(),
            resources: Vec::new(),
            relationships: Vec::new(),
            pending_wake: None,
        };

        let context = AgentContextSnapshot {
            session_id: String::from("session"),
            granted_scopes: vec![String::from("tool:run")],
            trigger_id: String::from("trigger"),
            idempotency_key: None,
            current_step: 0,
            max_steps: 1,
            history: Vec::new(),
            prior_tool_results: Vec::new(),
            session_cost_usd: 0.0,
            state,
            policy: EnginePolicy::default(),
            active_execution_plan: None,
        };

        SkillInvocation {
            call_id: String::from("call-1"),
            manifest: manifest(),
            args: json!({ "url": url }),
            dry_run: false,
            context,
        }
    }

    /// With an empty allowlist the request is rejected with `PermissionDenied`.
    #[tokio::test]
    async fn empty_allowlist_denies_by_default() {
        let no_hosts = std::collections::HashSet::new();
        let skill = HttpFetchSkill::new(no_hosts, Duration::from_secs(1));

        let outcome = skill.execute(invocation("https://example.com")).await;

        let err = outcome.expect_err("empty allowlist denies");
        assert_eq!(err.kind, SkillFailureKind::PermissionDenied);
    }
}