kindly_guard_server/security/hardening.rs

1// Copyright 2025 Kindly Software Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//! Security hardening measures for production deployment
15
16use anyhow::{bail, Result};
17use parking_lot::RwLock;
18use std::sync::Arc;
19use std::time::{Duration, Instant};
20use tracing::{error, info, warn};
21
22/// Rate limiter for command execution
23pub struct CommandRateLimiter {
24    limits: Arc<RwLock<RateLimitState>>,
25}
26
27#[derive(Debug)]
28struct RateLimitState {
29    /// Command counts per window
30    command_counts: std::collections::HashMap<String, WindowedCounter>,
31    /// Global rate limit
32    global_counter: WindowedCounter,
33}
34
35#[derive(Debug)]
36struct WindowedCounter {
37    count: u64,
38    window_start: Instant,
39    window_duration: Duration,
40    max_count: u64,
41}
42
43impl WindowedCounter {
44    fn new(max_count: u64, window_duration: Duration) -> Self {
45        Self {
46            count: 0,
47            window_start: Instant::now(),
48            window_duration,
49            max_count,
50        }
51    }
52
53    fn check_and_increment(&mut self) -> Result<()> {
54        // Reset window if expired
55        if self.window_start.elapsed() > self.window_duration {
56            self.count = 0;
57            self.window_start = Instant::now();
58        }
59
60        if self.count >= self.max_count {
61            bail!(
62                "Rate limit exceeded: {} requests per {:?}",
63                self.max_count,
64                self.window_duration
65            );
66        }
67
68        self.count += 1;
69        Ok(())
70    }
71}
72
73impl Default for CommandRateLimiter {
74    fn default() -> Self {
75        Self::new()
76    }
77}
78
79impl CommandRateLimiter {
80    pub fn new() -> Self {
81        Self {
82            limits: Arc::new(RwLock::new(RateLimitState {
83                command_counts: std::collections::HashMap::new(),
84                global_counter: WindowedCounter::new(100, Duration::from_secs(60)),
85            })),
86        }
87    }
88
89    /// Check if command is allowed
90    pub fn check_command(&self, command: &str) -> Result<()> {
91        let mut state = self.limits.write();
92
93        // Check global limit
94        state.global_counter.check_and_increment()?;
95
96        // Check per-command limit
97        let limit = match command {
98            "scan" => (10, Duration::from_secs(60)), // 10 scans per minute
99            "dashboard" => (5, Duration::from_secs(300)), // 5 dashboard starts per 5 minutes
100            "status" => (60, Duration::from_secs(60)), // 60 status checks per minute
101            _ => (30, Duration::from_secs(60)),      // Default: 30 per minute
102        };
103
104        let counter = state
105            .command_counts
106            .entry(command.to_string())
107            .or_insert_with(|| WindowedCounter::new(limit.0, limit.1));
108
109        counter.check_and_increment()
110    }
111}
112
113/// Resource usage monitor
114pub struct ResourceMonitor {
115    max_memory_mb: usize,
116    #[allow(dead_code)] // CPU monitoring planned for future release
117    max_cpu_percent: f32,
118}
119
120impl Default for ResourceMonitor {
121    fn default() -> Self {
122        Self::new()
123    }
124}
125
126impl ResourceMonitor {
127    pub const fn new() -> Self {
128        Self {
129            max_memory_mb: 512,    // 512MB max
130            max_cpu_percent: 80.0, // 80% CPU max
131        }
132    }
133
134    /// Check current resource usage
135    pub fn check_resources(&self) -> Result<()> {
136        // Get current memory usage
137        let memory_usage = self.get_memory_usage_mb();
138        if memory_usage > self.max_memory_mb {
139            bail!(
140                "Memory usage too high: {}MB (max: {}MB)",
141                memory_usage,
142                self.max_memory_mb
143            );
144        }
145
146        // Note: CPU usage checking would require platform-specific code
147        // For now, we'll just log memory usage
148        if memory_usage > self.max_memory_mb * 80 / 100 {
149            warn!("Memory usage approaching limit: {}MB", memory_usage);
150        }
151
152        Ok(())
153    }
154
155    const fn get_memory_usage_mb(&self) -> usize {
156        // Simple approximation using jemalloc stats if available
157        // In production, use proper system monitoring
158
159        // This is a placeholder - in production use proper metrics
160        50 // Return dummy value for now
161    }
162}
163
164/// Security context for command execution
165#[derive(Clone)]
166pub struct SecurityContext {
167    /// User identifier (for future authentication)
168    pub user_id: Option<String>,
169    /// Source of command (cli, api, etc)
170    pub source: CommandSource,
171    /// Timestamp
172    pub timestamp: chrono::DateTime<chrono::Utc>,
173    /// Request ID for correlation
174    pub request_id: String,
175    /// Neutralization tracking
176    pub neutralization: NeutralizationContext,
177}
178
179impl SecurityContext {
180    /// Create a new security context
181    pub fn new(source: CommandSource) -> Self {
182        Self {
183            user_id: None,
184            source,
185            timestamp: chrono::Utc::now(),
186            request_id: uuid::Uuid::new_v4().to_string(),
187            neutralization: NeutralizationContext::default(),
188        }
189    }
190
191    /// Create with user ID
192    pub fn with_user(mut self, user_id: String) -> Self {
193        self.user_id = Some(user_id);
194        self
195    }
196
197    /// Set neutralization mode
198    pub const fn with_neutralization_mode(mut self, mode: NeutralizationMode) -> Self {
199        self.neutralization.mode = mode;
200        self
201    }
202
203    /// Enable enhanced mode
204    pub const fn with_enhanced_mode(mut self, enhanced: bool) -> Self {
205        self.neutralization.enhanced_mode = enhanced;
206        self
207    }
208
209    /// Record neutralization result
210    pub fn record_neutralization(&mut self, success: bool) {
211        if success {
212            self.neutralization.record_success();
213        } else {
214            self.neutralization.record_failure();
215        }
216    }
217
218    /// Check if neutralization should be attempted based on context
219    pub const fn should_neutralize(&self) -> bool {
220        match self.neutralization.mode {
221            NeutralizationMode::Automatic => true,
222            NeutralizationMode::Interactive => self.neutralization.auto_neutralize,
223            NeutralizationMode::ReportOnly => false,
224        }
225    }
226}
227
/// Channel through which a command reached the server.
#[derive(Debug, Clone)]
pub enum CommandSource {
    /// Command-line interface.
    Cli,
    /// Web dashboard.
    WebDashboard,
    /// Programmatic API.
    Api,
    /// Origin not known.
    Unknown,
}
235
236/// Neutralization context for tracking threat mitigation
237#[derive(Debug, Clone)]
238pub struct NeutralizationContext {
239    /// Total threats neutralized in this context
240    pub threats_neutralized: u32,
241    /// Failed neutralization attempts
242    pub neutralization_failures: u32,
243    /// Whether automatic neutralization is enabled
244    pub auto_neutralize: bool,
245    /// Neutralization mode
246    pub mode: NeutralizationMode,
247    /// Performance mode (standard vs enhanced)
248    pub enhanced_mode: bool,
249    /// Last neutralization timestamp
250    pub last_neutralization: Option<chrono::DateTime<chrono::Utc>>,
251}
252
253impl Default for NeutralizationContext {
254    fn default() -> Self {
255        Self {
256            threats_neutralized: 0,
257            neutralization_failures: 0,
258            auto_neutralize: false,
259            mode: NeutralizationMode::ReportOnly,
260            enhanced_mode: false,
261            last_neutralization: None,
262        }
263    }
264}
265
266impl NeutralizationContext {
267    /// Record successful neutralization
268    pub fn record_success(&mut self) {
269        self.threats_neutralized += 1;
270        self.last_neutralization = Some(chrono::Utc::now());
271    }
272
273    /// Record failed neutralization
274    pub const fn record_failure(&mut self) {
275        self.neutralization_failures += 1;
276    }
277
278    /// Get neutralization success rate
279    pub fn success_rate(&self) -> f64 {
280        let total = self.threats_neutralized + self.neutralization_failures;
281        if total == 0 {
282            1.0
283        } else {
284            f64::from(self.threats_neutralized) / f64::from(total)
285        }
286    }
287}
288
/// Policy deciding whether detected threats are neutralized or only reported.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NeutralizationMode {
    /// Report threats; never neutralize.
    ReportOnly,
    /// Neutralize only when the context's `auto_neutralize` flag is set.
    Interactive,
    /// Always neutralize.
    Automatic,
}
295
296/// Audit logger for security events
297pub struct SecurityAuditLogger {
298    log_path: Option<std::path::PathBuf>,
299}
300
301impl SecurityAuditLogger {
302    pub const fn new(log_path: Option<std::path::PathBuf>) -> Self {
303        Self { log_path }
304    }
305
306    /// Log command execution
307    pub fn log_command(
308        &self,
309        context: &SecurityContext,
310        command: &str,
311        args: &serde_json::Value,
312        result: &Result<()>,
313    ) {
314        let event = serde_json::json!({
315            "timestamp": context.timestamp,
316            "request_id": context.request_id,
317            "user_id": context.user_id,
318            "source": format!("{:?}", context.source),
319            "command": command,
320            "args": args,
321            "success": result.is_ok(),
322            "error": result.as_ref().err().map(std::string::ToString::to_string),
323        });
324
325        // Log to tracing
326        if let Ok(()) = result {
327            info!(event = %event, "Command executed")
328        } else {
329            warn!(event = %event, "Command failed")
330        }
331
332        // Write to audit file if configured
333        if let Some(ref path) = self.log_path {
334            if let Err(e) = self.write_to_file(path, &event) {
335                error!("Failed to write audit log: {}", e);
336            }
337        }
338    }
339
340    fn write_to_file(&self, path: &std::path::Path, event: &serde_json::Value) -> Result<()> {
341        use std::fs::OpenOptions;
342        use std::io::Write;
343
344        let mut file = OpenOptions::new().create(true).append(true).open(path)?;
345
346        writeln!(file, "{event}")?;
347        Ok(())
348    }
349}
350
351/// Sandbox for file operations
352pub struct FileSandbox {
353    allowed_paths: Vec<std::path::PathBuf>,
354}
355
356impl FileSandbox {
357    pub const fn new(allowed_paths: Vec<std::path::PathBuf>) -> Self {
358        Self { allowed_paths }
359    }
360
361    /// Check if path access is allowed
362    pub fn check_path(&self, path: &std::path::Path) -> Result<()> {
363        let canonical = path
364            .canonicalize()
365            .map_err(|e| anyhow::anyhow!("Invalid path: {}", e))?;
366
367        // Check if path is under any allowed directory
368        for allowed in &self.allowed_paths {
369            if canonical.starts_with(allowed) {
370                return Ok(());
371            }
372        }
373
374        bail!(
375            "Access denied: path '{}' is outside allowed directories",
376            path.display()
377        );
378    }
379}
380
381/// Command injection prevention
382pub mod injection {
383    use super::{bail, Result};
384
385    use regex::Regex;
386
387    // Patterns that might indicate command injection
388    static DANGEROUS_PATTERNS: std::sync::LazyLock<Vec<Regex>> = std::sync::LazyLock::new(|| {
389        vec![
390            Regex::new(r"[;&|]").unwrap(),    // Command separators
391            Regex::new(r"\$\(.*\)").unwrap(), // Command substitution
392            Regex::new(r"`.*`").unwrap(),     // Backticks
393            Regex::new(r"<<.*>>").unwrap(),   // Heredoc
394            Regex::new(r"[<>]").unwrap(),     // Redirections
395        ]
396    });
397
398    /// Check for potential command injection
399    pub fn check_command_injection(input: &str) -> Result<()> {
400        for pattern in DANGEROUS_PATTERNS.iter() {
401            if pattern.is_match(input) {
402                bail!("Potential command injection detected");
403            }
404        }
405        Ok(())
406    }
407}
408
409/// Information disclosure prevention  
410pub mod info_disclosure {
411
412    /// Sanitize error messages for production
413    pub fn sanitize_error(error: anyhow::Error) -> String {
414        // In production, hide internal details
415        if cfg!(debug_assertions) {
416            format!("{error:#}")
417        } else {
418            // Generic messages for production
419            match error.to_string().to_lowercase() {
420                s if s.contains("permission") => "Permission denied".to_string(),
421                s if s.contains("not found") => "Resource not found".to_string(),
422                s if s.contains("timeout") => "Operation timed out".to_string(),
423                s if s.contains("rate limit") => "Rate limit exceeded".to_string(),
424                _ => "An error occurred. Please try again.".to_string(),
425            }
426        }
427    }
428
429    /// Mask sensitive configuration values
430    pub fn mask_sensitive(key: &str, value: &str) -> String {
431        let sensitive_keys = ["password", "token", "secret", "key", "auth"];
432
433        if sensitive_keys
434            .iter()
435            .any(|&k| key.to_lowercase().contains(k))
436        {
437            "***MASKED***".to_string()
438        } else {
439            value.to_string()
440        }
441    }
442}
443
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rate_limiter() {
        let limiter = CommandRateLimiter::new();

        // Should allow initial requests
        for _ in 0..5 {
            assert!(limiter.check_command("scan").is_ok());
        }

        // Should eventually hit limit
        let mut hit_limit = false;
        for _ in 0..20 {
            if limiter.check_command("scan").is_err() {
                hit_limit = true;
                break;
            }
        }
        assert!(hit_limit);
    }

    #[test]
    fn test_command_injection_detection() {
        use injection::check_command_injection;

        // Safe inputs
        assert!(check_command_injection("normal text").is_ok());
        assert!(check_command_injection("/path/to/file.txt").is_ok());

        // Dangerous inputs
        assert!(check_command_injection("test; rm -rf /").is_err());
        assert!(check_command_injection("$(cat /etc/passwd)").is_err());
        assert!(check_command_injection("`whoami`").is_err());
        assert!(check_command_injection("test > /dev/null").is_err());
    }

    #[test]
    fn test_file_sandbox() {
        use tempfile::tempdir;

        // Two independent temp directories. One is the sandbox; the other is
        // guaranteed to lie outside it. Both are removed on drop, even if an
        // assertion fails. A fixed file name in the shared temp dir would
        // race between parallel test runs and leak on failure.
        let sandbox_dir = tempdir().unwrap();
        let outside_dir = tempdir().unwrap();
        let allowed_path = sandbox_dir.path().to_path_buf();

        // Create a test file inside the sandbox
        let test_file = allowed_path.join("test.txt");
        std::fs::write(&test_file, "test").unwrap();

        // Create a test file outside the sandbox
        let outside_file = outside_dir.path().join("outside.txt");
        std::fs::write(&outside_file, "test").unwrap();

        let sandbox = FileSandbox::new(vec![allowed_path.clone()]);

        // Allowed paths
        assert!(sandbox.check_path(&test_file).is_ok());

        // Disallowed paths (outside sandbox)
        assert!(sandbox.check_path(&outside_file).is_err());

        // `..` traversal out of the sandbox is resolved and rejected
        let outside_name = outside_dir
            .path()
            .file_name()
            .expect("tempdir path has a final component");
        let escaped = allowed_path
            .join("..")
            .join(outside_name)
            .join("outside.txt");
        assert!(sandbox.check_path(&escaped).is_err());

        // Nonexistent paths cannot be canonicalized and are rejected
        assert!(sandbox.check_path(&allowed_path.join("missing.txt")).is_err());
    }

    #[test]
    fn test_info_disclosure_prevention() {
        use info_disclosure::{mask_sensitive, sanitize_error};

        // Error sanitization - in debug mode it shows full error, in release it's generic
        let error = anyhow::anyhow!("Connection to database at 192.168.1.1:5432 failed");
        let sanitized = sanitize_error(error);

        if cfg!(debug_assertions) {
            // In debug mode, we get the full error
            assert!(sanitized.contains("database"));
        } else {
            // In release mode, we get generic message
            assert_eq!(sanitized, "An error occurred. Please try again.");
        }

        // Test known error patterns
        let perm_error = anyhow::anyhow!("Permission denied for user");
        assert!(sanitize_error(perm_error).contains("Permission"));

        // Sensitive value masking
        assert_eq!(mask_sensitive("password", "secret123"), "***MASKED***");
        assert_eq!(mask_sensitive("api_token", "xyz"), "***MASKED***");
        assert_eq!(mask_sensitive("username", "john"), "john");
    }
}