Skip to main content

nika_engine/runtime/
policy.rs

//! Policy Enforcer - Security policy enforcement
//!
//! Enforces allow/block rules for:
//! - Shell commands (exec: verb)
//! - Network access (fetch: verb), including SSRF protection against cloud
//!   metadata endpoints, loopback and private/reserved address ranges
//! - Token budget limits
//! - Host restrictions

9use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
10
11use crate::error::NikaError;
12use crate::runtime::boot::PolicyConfig;
13use url::Url;
14
/// Exact-match SSRF blocklist: cloud metadata hostnames and special addresses.
///
/// These are ALWAYS blocked regardless of user configuration.
/// Cloud metadata services (AWS/GCP/Alibaba) and special addresses are
/// common SSRF targets that should never be reachable from workflow fetch: verbs.
/// Matching against this list ignores ASCII case and trailing root dots.
const SSRF_BLOCKED_EXACT: &[&str] = &["metadata.google.internal", "localhost", "0.0.0.0"];

/// Decide whether `host` is a Server-Side Request Forgery (SSRF) target that
/// must never be fetched, regardless of user configuration.
///
/// `host` is expected to be the host component of a parsed URL. Callers usually
/// lowercase it and strip IPv6 brackets, but this check is defensive about both:
/// surrounding `[` `]` and trailing root-label dots (`localhost.`) are removed,
/// and name comparison ignores ASCII case.
///
/// A host is blocked when:
/// 1. it matches an entry of `SSRF_BLOCKED_EXACT`, or is `localhost` / any name
///    under `.localhost` (RFC 6761 reserves that namespace for loopback);
/// 2. it is an IPv4 literal rejected by `is_blocked_v4`;
/// 3. it is an IPv6 literal rejected by `is_blocked_v6`.
///
/// Only literals are inspected. A DNS name that *resolves* to a private address
/// is not caught here, because no resolution is performed.
pub(crate) fn is_ssrf_blocked(host: &str) -> bool {
    let host = host
        .strip_prefix('[')
        .and_then(|h| h.strip_suffix(']'))
        .unwrap_or(host)
        // "localhost." resolves exactly like "localhost"; don't let the root dot
        // dodge the name checks below.
        .trim_end_matches('.');

    // 1. Known names.
    if SSRF_BLOCKED_EXACT
        .iter()
        .any(|blocked| blocked.eq_ignore_ascii_case(host))
    {
        return true;
    }
    // `rsplit` always yields at least one item: the last label of the name.
    if host
        .rsplit('.')
        .next()
        .is_some_and(|label| label.eq_ignore_ascii_case("localhost"))
    {
        return true;
    }

    // 2./3. IP literals.
    match host.parse::<IpAddr>() {
        Ok(IpAddr::V4(v4)) => is_blocked_v4(v4),
        Ok(IpAddr::V6(v6)) => is_blocked_v6(v6),
        Err(_) => false, // Ordinary DNS name that is not on the blocklist.
    }
}

/// Check an IPv4 address against blocked private/reserved ranges.
fn is_blocked_v4(v4: Ipv4Addr) -> bool {
    match v4.octets() {
        // 0.0.0.0/8 — "this network"; 0.0.0.0 reaches local listeners on many systems.
        [0, ..] => true,
        // 10.0.0.0/8 — private class A
        [10, ..] => true,
        // 100.64.0.0/10 — CGN / shared space (covers Alibaba 100.100.100.200)
        [100, 64..=127, ..] => true,
        // 127.0.0.0/8 — loopback
        [127, ..] => true,
        // 169.254.0.0/16 — link-local (covers AWS/GCP metadata 169.254.169.254)
        [169, 254, ..] => true,
        // 172.16.0.0/12 — private class B (172.16.x.x – 172.31.x.x)
        [172, 16..=31, ..] => true,
        // 192.0.0.0/24 — IETF protocol assignments (not globally routable)
        [192, 0, 0, _] => true,
        // 192.168.0.0/16 — private class C
        [192, 168, ..] => true,
        // 224.0.0.0/4 multicast and 240.0.0.0/4 reserved (incl. 255.255.255.255)
        [224..=255, ..] => true,
        _ => false,
    }
}

/// Check an IPv6 address against loopback, unspecified and local-scope ranges,
/// and against any IPv4 address embedded in it.
fn is_blocked_v6(v6: Ipv6Addr) -> bool {
    // ::1 loopback and :: unspecified (the IPv6 analogue of 0.0.0.0).
    if v6.is_loopback() || v6.is_unspecified() {
        return true;
    }
    // IPv4-mapped (::ffff:a.b.c.d) and deprecated IPv4-compatible (::a.b.c.d)
    // addresses embed an IPv4 address; judge them by it. `to_ipv4` returns
    // Some for exactly these two forms (::1 and :: were handled above).
    if let Some(v4) = v6.to_ipv4() {
        return is_blocked_v4(v4);
    }
    let first = v6.segments()[0];
    // fc00::/7 unique local (includes the AWS IMDS IPv6 endpoint fd00:ec2::254)
    (first & 0xfe00) == 0xfc00
        // fe80::/10 link-local
        || (first & 0xffc0) == 0xfe80
        // fec0::/10 site-local (deprecated, but may still be routed internally)
        || (first & 0xffc0) == 0xfec0
}
96
/// Outcome of a single policy check.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PolicyDecision {
    /// The action may proceed.
    Allow,
    /// The action is denied; the payload explains why.
    Block(String),
    /// The action needs explicit user confirmation; the payload explains why.
    RequiresApproval(String),
}

impl PolicyDecision {
    /// Returns `true` only for [`PolicyDecision::Allow`].
    pub fn is_allowed(&self) -> bool {
        match self {
            Self::Allow => true,
            Self::Block(_) => false,
            Self::RequiresApproval(_) => false,
        }
    }

    /// Returns `true` only for [`PolicyDecision::Block`].
    ///
    /// `RequiresApproval` counts as neither allowed nor blocked here.
    pub fn is_blocked(&self) -> bool {
        match self {
            Self::Block(_) => true,
            Self::Allow => false,
            Self::RequiresApproval(_) => false,
        }
    }
}
117
/// Running tally of tokens spent against an optional cap.
///
/// `limit == None` means the budget is unlimited.
#[derive(Debug, Clone, Default)]
pub struct TokenBudget {
    pub limit: Option<u64>,
    pub used: u64,
}

impl TokenBudget {
    /// Fresh budget with nothing spent yet.
    pub fn new(limit: Option<u64>) -> Self {
        TokenBudget { used: 0, limit }
    }

    /// Would spending `tokens` more stay within the cap?
    ///
    /// Always `true` when unlimited. An addition that would overflow `u64` is
    /// treated as exceeding the cap.
    pub fn can_spend(&self, tokens: u64) -> bool {
        let Some(cap) = self.limit else {
            return true;
        };
        matches!(self.used.checked_add(tokens), Some(total) if total <= cap)
    }

    /// Add `tokens` to the tally, clamping at `u64::MAX` instead of overflowing.
    pub fn spend(&mut self, tokens: u64) {
        self.used = self.used.saturating_add(tokens);
    }

    /// Tokens still available (never negative), or `None` when unlimited.
    pub fn remaining(&self) -> Option<u64> {
        let spent = self.used;
        self.limit.map(move |cap| cap.saturating_sub(spent))
    }
}
151
152/// Policy enforcer instance
153#[derive(Debug, Clone)]
154pub struct PolicyEnforcer {
155    config: PolicyConfig,
156    token_budget: TokenBudget,
157}
158
159impl Default for PolicyEnforcer {
160    fn default() -> Self {
161        Self::new(PolicyConfig::default())
162    }
163}
164
165impl PolicyEnforcer {
166    /// Create a new policy enforcer with configuration
167    pub fn new(config: PolicyConfig) -> Self {
168        let token_budget = TokenBudget::new(config.max_token_spend);
169        Self {
170            config,
171            token_budget,
172        }
173    }
174
175    /// Check if the exec verb is allowed for a command
176    pub fn check_exec(&self, command: &str) -> PolicyDecision {
177        // Check if exec is globally disabled
178        if !self.config.allow_exec {
179            return PolicyDecision::Block("exec: verb is disabled by policy".into());
180        }
181
182        // Check for blocked command patterns
183        let command_lower = command.to_lowercase();
184        for blocked in &self.config.blocked_commands {
185            if command_lower.contains(&blocked.to_lowercase()) {
186                return PolicyDecision::Block(format!(
187                    "Command contains blocked pattern: '{}'",
188                    blocked
189                ));
190            }
191        }
192
193        PolicyDecision::Allow
194    }
195
196    /// Check if fetch: verb is allowed for a URL
197    pub fn check_fetch(&self, url: &str) -> PolicyDecision {
198        // Check if network is globally disabled
199        if !self.config.allow_network {
200            return PolicyDecision::Block(
201                "fetch: verb (network access) is disabled by policy".into(),
202            );
203        }
204
205        // Parse URL to check host — fail-closed on invalid URLs
206        let parsed = match Url::parse(url) {
207            Ok(u) => u,
208            Err(_) => {
209                return PolicyDecision::Block(format!(
210                    "Unparseable URL rejected (fail-closed): '{}'",
211                    url
212                ));
213            }
214        };
215
216        let host = match parsed.host_str() {
217            Some(h) => h.to_lowercase(),
218            None => {
219                return PolicyDecision::Block(format!("URL has no host (fail-closed): '{}'", url));
220            }
221        };
222
223        // Normalize IPv6: url crate returns "[::1]" with brackets, blocklist uses "::1"
224        let host_normalized = host.trim_start_matches('[').trim_end_matches(']');
225
226        // SSRF protection: block cloud metadata, loopback, and private ranges.
227        // Exception: explicit allowed_hosts override SSRF blocklist (for testing, local services)
228        let explicitly_allowed = self
229            .config
230            .allowed_hosts
231            .iter()
232            .any(|allowed| host_normalized == allowed.to_lowercase());
233        if !explicitly_allowed && is_ssrf_blocked(host_normalized) {
234            return PolicyDecision::Block(format!(
235                "SSRF protection: access to '{}' is blocked",
236                host
237            ));
238        }
239
240        // Check blocked hosts first (takes precedence).
241        // Uses proper domain-suffix matching: "evil.com" blocks "evil.com" and
242        // "sub.evil.com" but NOT "not-evil.com".
243        for blocked in &self.config.blocked_hosts {
244            let blocked_lower = blocked.to_lowercase();
245            if host == blocked_lower || host.ends_with(&format!(".{}", blocked_lower)) {
246                return PolicyDecision::Block(format!("Host '{}' is blocked by policy", host));
247            }
248        }
249
250        // If allowed_hosts is non-empty, only those hosts are allowed.
251        // Uses proper domain-suffix matching: "openai.com" allows "openai.com"
252        // and "api.openai.com" but NOT "openai.com.evil.com".
253        if !self.config.allowed_hosts.is_empty() {
254            let is_allowed = self.config.allowed_hosts.iter().any(|allowed| {
255                let allowed_lower = allowed.to_lowercase();
256                host == allowed_lower || host.ends_with(&format!(".{}", allowed_lower))
257            });
258            if !is_allowed {
259                return PolicyDecision::Block(format!(
260                    "Host '{}' is not in allowed hosts list",
261                    host
262                ));
263            }
264        }
265
266        PolicyDecision::Allow
267    }
268
269    /// Check if token spend is within budget
270    pub fn check_token_spend(&self, tokens: u64) -> PolicyDecision {
271        if !self.token_budget.can_spend(tokens) {
272            let remaining = self.token_budget.remaining().unwrap_or(0);
273            return PolicyDecision::Block(format!(
274                "Token budget exceeded: requested {} but only {} remaining",
275                tokens, remaining
276            ));
277        }
278        PolicyDecision::Allow
279    }
280
281    /// Atomically reserve tokens from the budget (check + spend in one call).
282    ///
283    /// Prevents TOCTOU races where concurrent for_each tasks all pass the
284    /// check before any records spending.
285    pub fn reserve_tokens(&mut self, estimated: u64) -> Result<(), String> {
286        if !self.token_budget.can_spend(estimated) {
287            return Err(format!(
288                "Token budget exceeded: {} used + {} estimated > {} limit",
289                self.token_budget.used,
290                estimated,
291                self.token_budget.limit.unwrap_or(u64::MAX),
292            ));
293        }
294        self.token_budget.spend(estimated);
295        Ok(())
296    }
297
298    /// Adjust a previous reservation to match actual token usage.
299    pub fn adjust_reservation(&mut self, estimated: u64, actual: u64) {
300        if actual < estimated {
301            self.token_budget.used = self.token_budget.used.saturating_sub(estimated - actual);
302        } else if actual > estimated {
303            self.token_budget.spend(actual - estimated);
304        }
305    }
306
307    /// Record token usage
308    pub fn record_token_spend(&mut self, tokens: u64) {
309        self.token_budget.spend(tokens);
310    }
311
312    /// Get remaining token budget
313    pub fn remaining_budget(&self) -> Option<u64> {
314        self.token_budget.remaining()
315    }
316
317    /// Get total tokens used
318    pub fn tokens_used(&self) -> u64 {
319        self.token_budget.used
320    }
321
322    /// Convert policy decision to result
323    pub fn enforce(&self, decision: PolicyDecision) -> Result<(), NikaError> {
324        match decision {
325            PolicyDecision::Allow => Ok(()),
326            PolicyDecision::Block(reason) => Err(NikaError::PolicyViolation { reason }),
327            PolicyDecision::RequiresApproval(reason) => {
328                // For now, treat as block. HITL integration can handle approval flow.
329                Err(NikaError::PolicyViolation {
330                    reason: format!("Requires approval: {}", reason),
331                })
332            }
333        }
334    }
335}
336
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_policy_allows_exec() {
        let enforcer = PolicyEnforcer::default();
        assert!(enforcer.check_exec("ls -la").is_allowed());
    }

    #[test]
    fn test_policy_blocks_dangerous_commands() {
        let enforcer = PolicyEnforcer::default();

        // Default blocked commands
        assert!(enforcer.check_exec("sudo apt install").is_blocked());
        assert!(enforcer.check_exec("rm -rf /").is_blocked());
        assert!(enforcer.check_exec("chmod 777 /etc").is_blocked());

        // Safe commands allowed
        assert!(enforcer.check_exec("echo hello").is_allowed());
        assert!(enforcer.check_exec("npm run build").is_allowed());
    }

    #[test]
    fn test_policy_disables_exec() {
        let config = PolicyConfig {
            allow_exec: false,
            ..Default::default()
        };
        let enforcer = PolicyEnforcer::new(config);

        assert!(enforcer.check_exec("echo hello").is_blocked());
    }

    #[test]
    fn test_exec_blocklist_is_case_insensitive() {
        let config = PolicyConfig {
            allow_exec: true,
            blocked_commands: vec!["sudo".into()],
            ..Default::default()
        };
        let enforcer = PolicyEnforcer::new(config);

        assert!(enforcer.check_exec("SUDO reboot").is_blocked());
        assert!(enforcer.check_exec("Sudo ls").is_blocked());
        assert_eq!(
            enforcer.check_exec("sudo ls"),
            PolicyDecision::Block("Command contains blocked pattern: 'sudo'".into())
        );
        assert!(enforcer.check_exec("ls -la").is_allowed());
    }

    #[test]
    fn test_default_policy_allows_fetch() {
        let enforcer = PolicyEnforcer::default();
        assert!(enforcer
            .check_fetch("https://api.example.com/data")
            .is_allowed());
    }

    #[test]
    fn test_policy_disables_network() {
        let config = PolicyConfig {
            allow_network: false,
            ..Default::default()
        };
        let enforcer = PolicyEnforcer::new(config);

        assert!(enforcer.check_fetch("https://example.com").is_blocked());
    }

    #[test]
    fn test_policy_blocks_hosts() {
        let config = PolicyConfig {
            blocked_hosts: vec!["evil.com".into(), "malware.io".into()],
            ..Default::default()
        };
        let enforcer = PolicyEnforcer::new(config);

        assert!(enforcer.check_fetch("https://evil.com/path").is_blocked());
        assert!(enforcer
            .check_fetch("https://sub.evil.com/path")
            .is_blocked());
        assert!(enforcer.check_fetch("https://malware.io/api").is_blocked());
        assert!(enforcer.check_fetch("https://api.example.com").is_allowed());
    }

    #[test]
    fn test_policy_allowed_hosts_whitelist() {
        let config = PolicyConfig {
            allowed_hosts: vec!["api.openai.com".into(), "anthropic.com".into()],
            ..Default::default()
        };
        let enforcer = PolicyEnforcer::new(config);

        assert!(enforcer
            .check_fetch("https://api.openai.com/v1")
            .is_allowed());
        assert!(enforcer
            .check_fetch("https://anthropic.com/api")
            .is_allowed());
        assert!(enforcer.check_fetch("https://other.com/api").is_blocked());
    }

    #[test]
    fn test_token_budget_unlimited() {
        let budget = TokenBudget::new(None);
        assert!(budget.can_spend(1_000_000));
        assert!(budget.remaining().is_none());
    }

    #[test]
    fn test_token_budget_limited() {
        let mut budget = TokenBudget::new(Some(10000));
        assert!(budget.can_spend(5000));
        budget.spend(5000);
        assert_eq!(budget.used, 5000);
        assert_eq!(budget.remaining(), Some(5000));

        assert!(budget.can_spend(5000));
        assert!(!budget.can_spend(5001));
    }

    #[test]
    fn test_token_budget_checked_add_overflow() {
        let budget = TokenBudget {
            limit: Some(u64::MAX),
            used: u64::MAX - 1,
        };
        assert!(budget.can_spend(1));
        // used + 2 overflows u64: must be rejected, not wrapped.
        assert!(!budget.can_spend(2));
    }

    #[test]
    fn test_enforcer_token_budget() {
        let config = PolicyConfig {
            max_token_spend: Some(1000),
            ..Default::default()
        };
        let mut enforcer = PolicyEnforcer::new(config);

        assert!(enforcer.check_token_spend(500).is_allowed());
        enforcer.record_token_spend(500);

        assert!(enforcer.check_token_spend(500).is_allowed());
        enforcer.record_token_spend(500);

        // Now at limit
        assert!(enforcer.check_token_spend(1).is_blocked());
        assert_eq!(enforcer.remaining_budget(), Some(0));
    }

    #[test]
    fn test_check_token_spend_reports_remaining() {
        let config = PolicyConfig {
            max_token_spend: Some(100),
            ..Default::default()
        };
        let mut enforcer = PolicyEnforcer::new(config);
        enforcer.record_token_spend(70);

        assert_eq!(
            enforcer.check_token_spend(40),
            PolicyDecision::Block(
                "Token budget exceeded: requested 40 but only 30 remaining".into()
            )
        );
    }

    #[test]
    fn test_reserve_tokens_respects_limit() {
        let config = PolicyConfig {
            max_token_spend: Some(100),
            ..Default::default()
        };
        let mut enforcer = PolicyEnforcer::new(config);

        assert!(enforcer.reserve_tokens(60).is_ok());
        assert_eq!(enforcer.tokens_used(), 60);

        let err = enforcer.reserve_tokens(50).unwrap_err();
        assert_eq!(
            err,
            "Token budget exceeded: 60 used + 50 estimated > 100 limit"
        );
        // A failed reservation must not consume budget.
        assert_eq!(enforcer.tokens_used(), 60);

        assert!(enforcer.reserve_tokens(40).is_ok());
        assert_eq!(enforcer.remaining_budget(), Some(0));
    }

    #[test]
    fn test_reserve_tokens_unlimited() {
        let config = PolicyConfig {
            max_token_spend: None,
            ..Default::default()
        };
        let mut enforcer = PolicyEnforcer::new(config);

        assert!(enforcer.reserve_tokens(u64::MAX).is_ok());
        // Spending saturates instead of overflowing.
        assert!(enforcer.reserve_tokens(1).is_ok());
        assert_eq!(enforcer.tokens_used(), u64::MAX);
        assert_eq!(enforcer.remaining_budget(), None);
    }

    #[test]
    fn test_adjust_reservation_refunds_and_charges() {
        let config = PolicyConfig {
            max_token_spend: Some(1000),
            ..Default::default()
        };
        let mut enforcer = PolicyEnforcer::new(config);

        enforcer.reserve_tokens(300).unwrap();
        enforcer.adjust_reservation(300, 200); // refund 100
        assert_eq!(enforcer.tokens_used(), 200);

        enforcer.adjust_reservation(100, 150); // charge 50
        assert_eq!(enforcer.tokens_used(), 250);

        enforcer.adjust_reservation(80, 80); // exact estimate: no change
        assert_eq!(enforcer.tokens_used(), 250);
        assert_eq!(enforcer.remaining_budget(), Some(750));
    }

    #[test]
    fn test_enforce_maps_decisions() {
        let enforcer = PolicyEnforcer::default();

        assert!(enforcer.enforce(PolicyDecision::Allow).is_ok());

        match enforcer.enforce(PolicyDecision::Block("nope".into())) {
            Err(NikaError::PolicyViolation { reason }) => assert_eq!(reason, "nope"),
            _ => panic!("Block should map to PolicyViolation"),
        }

        match enforcer.enforce(PolicyDecision::RequiresApproval("confirm".into())) {
            Err(NikaError::PolicyViolation { reason }) => {
                assert_eq!(reason, "Requires approval: confirm")
            }
            _ => panic!("RequiresApproval should map to PolicyViolation"),
        }
    }

    #[test]
    fn test_policy_decision_properties() {
        let allow = PolicyDecision::Allow;
        let block = PolicyDecision::Block("reason".into());

        assert!(allow.is_allowed());
        assert!(!allow.is_blocked());
        assert!(block.is_blocked());
        assert!(!block.is_allowed());
    }

    // =========================================================================
    // Regression: Bug 18 — unparseable URLs must be blocked (fail-closed)
    // =========================================================================

    #[test]
    fn test_policy_blocks_unparseable_url() {
        let enforcer = PolicyEnforcer::default();
        let decision = enforcer.check_fetch("not a url at all %%%");
        assert!(
            decision.is_blocked(),
            "Unparseable URL should be blocked (fail-closed), got: {:?}",
            decision
        );
    }

    #[test]
    fn test_policy_blocks_url_without_host() {
        let enforcer = PolicyEnforcer::default();
        // data: URIs have no host
        let decision = enforcer.check_fetch("data:text/html,<script>alert(1)</script>");
        assert!(
            decision.is_blocked(),
            "URL without host should be blocked (fail-closed), got: {:?}",
            decision
        );
    }

    #[test]
    fn test_policy_still_allows_valid_urls() {
        let enforcer = PolicyEnforcer::default();
        assert!(enforcer.check_fetch("https://example.com/api").is_allowed());
    }

    // =========================================================================
    // SSRF protection: cloud metadata + loopback always blocked
    // =========================================================================

    #[test]
    fn test_ssrf_blocks_cloud_metadata() {
        let enforcer = PolicyEnforcer::default();

        // AWS/GCP metadata endpoint
        assert!(enforcer
            .check_fetch("http://169.254.169.254/latest/meta-data/")
            .is_blocked());
        // GCP internal DNS
        assert!(enforcer
            .check_fetch("http://metadata.google.internal/computeMetadata/v1/")
            .is_blocked());
        // Alibaba metadata
        assert!(enforcer
            .check_fetch("http://100.100.100.200/latest/meta-data/")
            .is_blocked());
    }

    #[test]
    fn test_ssrf_blocks_loopback() {
        let enforcer = PolicyEnforcer::default();

        assert!(enforcer.check_fetch("http://localhost:8080").is_blocked());
        assert!(enforcer
            .check_fetch("http://127.0.0.1:3000/api")
            .is_blocked());
        assert!(enforcer
            .check_fetch("http://[::1]:9090/health")
            .is_blocked());
        assert!(enforcer.check_fetch("http://0.0.0.0/admin").is_blocked());
    }

    #[test]
    fn test_ssrf_does_not_block_external_hosts() {
        let enforcer = PolicyEnforcer::default();
        assert!(enforcer
            .check_fetch("https://api.openai.com/v1")
            .is_allowed());
        assert!(enforcer.check_fetch("https://example.com").is_allowed());
    }

    #[test]
    fn test_ssrf_explicit_allowlist_override() {
        let config = PolicyConfig {
            allowed_hosts: vec!["localhost".into()],
            blocked_hosts: vec![],
            ..Default::default()
        };
        let enforcer = PolicyEnforcer::new(config);

        // An exact allowlist entry lifts SSRF protection for that host only.
        assert!(enforcer
            .check_fetch("http://localhost:8080/health")
            .is_allowed());
        assert!(enforcer
            .check_fetch("http://127.0.0.1:8080/health")
            .is_blocked());
    }

    // =========================================================================
    // H4: SSRF blocks private/reserved IP ranges
    // =========================================================================

    #[test]
    fn test_ssrf_blocks_private_ranges() {
        let enforcer = PolicyEnforcer::default();

        // 10.0.0.0/8
        assert!(enforcer.check_fetch("http://10.0.0.1/admin").is_blocked());
        assert!(enforcer.check_fetch("http://10.255.255.255/x").is_blocked());

        // 172.16.0.0/12
        assert!(enforcer.check_fetch("http://172.16.0.1/api").is_blocked());
        assert!(enforcer.check_fetch("http://172.31.255.255/x").is_blocked());
        // 172.15.x.x is NOT private — should be allowed
        assert!(enforcer.check_fetch("http://172.15.0.1/api").is_allowed());
        // 172.32.x.x is NOT private — should be allowed
        assert!(enforcer.check_fetch("http://172.32.0.1/api").is_allowed());

        // 192.168.0.0/16
        assert!(enforcer
            .check_fetch("http://192.168.1.1/admin")
            .is_blocked());
        assert!(enforcer.check_fetch("http://192.168.0.100/x").is_blocked());

        // 127.0.0.0/8 — full loopback range
        assert!(enforcer.check_fetch("http://127.0.0.2:8080/x").is_blocked());
        assert!(enforcer
            .check_fetch("http://127.255.255.255/x")
            .is_blocked());

        // 169.254.0.0/16 — link-local
        assert!(enforcer.check_fetch("http://169.254.0.1/x").is_blocked());
        assert!(enforcer
            .check_fetch("http://169.254.169.254/latest")
            .is_blocked());

        // 100.64.0.0/10 — CGN / shared (covers Alibaba 100.100.100.200)
        assert!(enforcer.check_fetch("http://100.64.0.1/x").is_blocked());
        assert!(enforcer
            .check_fetch("http://100.100.100.200/meta")
            .is_blocked());
        assert!(enforcer
            .check_fetch("http://100.127.255.255/x")
            .is_blocked());
        // 100.128.x.x is outside CGN — should be allowed
        assert!(enforcer.check_fetch("http://100.128.0.1/api").is_allowed());
    }

    #[test]
    fn test_ssrf_blocks_ipv6_mapped() {
        let enforcer = PolicyEnforcer::default();

        // ::ffff:127.0.0.1 (IPv4-mapped loopback)
        assert!(enforcer
            .check_fetch("http://[::ffff:127.0.0.1]:8080/x")
            .is_blocked());
        // ::ffff:10.0.0.1 (IPv4-mapped private)
        assert!(enforcer
            .check_fetch("http://[::ffff:10.0.0.1]/admin")
            .is_blocked());
        // ::ffff:192.168.1.1
        assert!(enforcer
            .check_fetch("http://[::ffff:192.168.1.1]/x")
            .is_blocked());
        // ::ffff:169.254.169.254
        assert!(enforcer
            .check_fetch("http://[::ffff:169.254.169.254]/meta")
            .is_blocked());

        // ::1 (pure IPv6 loopback)
        assert!(enforcer
            .check_fetch("http://[::1]:9090/health")
            .is_blocked());
    }

    // =========================================================================
    // H5: Proper domain-suffix matching (no substring bypass)
    // =========================================================================

    #[test]
    fn test_host_matching_no_substring_bypass() {
        // Blocked hosts: should NOT over-block unrelated domains
        let config = PolicyConfig {
            blocked_hosts: vec!["evil.com".into()],
            allowed_hosts: vec![], // no whitelist
            ..Default::default()
        };
        let enforcer = PolicyEnforcer::new(config);

        // "evil.com" and subdomains blocked
        assert!(enforcer.check_fetch("https://evil.com/x").is_blocked());
        assert!(enforcer.check_fetch("https://sub.evil.com/x").is_blocked());
        // "not-evil.com" must NOT be blocked (old substring match would block it)
        assert!(enforcer.check_fetch("https://not-evil.com/x").is_allowed());
        // "evil.com.attacker.com" must NOT be blocked
        assert!(enforcer
            .check_fetch("https://evil.com.attacker.com/x")
            .is_allowed());

        // Allowed hosts: should NOT allow spoofed domains
        let config2 = PolicyConfig {
            allowed_hosts: vec!["api.openai.com".into()],
            ..Default::default()
        };
        let enforcer2 = PolicyEnforcer::new(config2);

        // Exact match and subdomains allowed
        assert!(enforcer2
            .check_fetch("https://api.openai.com/v1")
            .is_allowed());
        assert!(enforcer2
            .check_fetch("https://sub.api.openai.com/v1")
            .is_allowed());
        // Attacker domain with allowed host as prefix must be BLOCKED
        assert!(enforcer2
            .check_fetch("https://api.openai.com.evil.com/v1")
            .is_blocked());
        // Unrelated domain must be blocked
        assert!(enforcer2.check_fetch("https://other.com/api").is_blocked());
    }
}