// llm_security/types.rs

//! Type definitions for LLM security

use serde::{Deserialize, Serialize};

/// Configuration for the LLM security layer
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMSecurityConfig {
    /// Enable prompt injection detection
    pub enable_injection_detection: bool,

    /// Enable output validation
    pub enable_output_validation: bool,

    /// Maximum code size to analyze (prevent DoS)
    pub max_code_size_bytes: usize,

    /// Block suspicious patterns even if detection is uncertain
    pub strict_mode: bool,

    /// Log all detected attacks
    pub log_attacks: bool,

    /// Maximum LLM calls allowed per hour from a single client IP
    pub max_llm_calls_per_hour: u32,
}

impl Default for LLMSecurityConfig {
    fn default() -> Self {
        Self {
            enable_injection_detection: true,
            enable_output_validation: true,
            max_code_size_bytes: crate::constants::DEFAULT_MAX_CODE_SIZE_BYTES,
            strict_mode: true,
            log_attacks: true,
            max_llm_calls_per_hour: crate::constants::DEFAULT_MAX_LLM_CALLS_PER_HOUR,
        }
    }
}

impl LLMSecurityConfig {
    /// Create a new configuration with custom values
    pub fn new(
        enable_injection_detection: bool,
        enable_output_validation: bool,
        max_code_size_bytes: usize,
        strict_mode: bool,
    ) -> Self {
        Self {
            enable_injection_detection,
            enable_output_validation,
            max_code_size_bytes,
            strict_mode,
            log_attacks: true,
            max_llm_calls_per_hour: crate::constants::DEFAULT_MAX_LLM_CALLS_PER_HOUR,
        }
    }

    /// Create a permissive configuration
    pub fn permissive() -> Self {
        Self {
            enable_injection_detection: false,
            enable_output_validation: false,
            max_code_size_bytes: crate::constants::DEFAULT_MAX_CODE_SIZE_BYTES,
            strict_mode: false,
            log_attacks: false,
            max_llm_calls_per_hour: crate::constants::DEFAULT_MAX_LLM_CALLS_PER_HOUR,
        }
    }

    /// Create a strict configuration
    pub fn strict() -> Self {
        Self {
            enable_injection_detection: true,
            enable_output_validation: true,
            max_code_size_bytes: 100_000, // Smaller limit for strict mode
            strict_mode: true,
            log_attacks: true,
            max_llm_calls_per_hour: 50, // Lower rate limit for strict mode
        }
    }

    /// Validate the configuration
    pub fn validate(&self) -> Result<(), String> {
        if self.max_code_size_bytes == 0 {
            return Err("Maximum code size cannot be zero".to_string());
        }

        if self.max_llm_calls_per_hour == 0 {
            return Err("Maximum LLM calls per hour cannot be zero".to_string());
        }

        Ok(())
    }

    /// Check if this is a development configuration
    pub fn is_development(&self) -> bool {
        !self.strict_mode && !self.enable_injection_detection
    }

    /// Check if this is a production configuration
    pub fn is_production(&self) -> bool {
        self.strict_mode && self.enable_injection_detection
    }

    /// Get a human-readable description of the configuration
    pub fn describe(&self) -> String {
        format!(
            "LLMSecurityConfig: injection_detection={}, output_validation={}, max_size={}B, strict={}, log_attacks={}, rate_limit={}/hour",
            self.enable_injection_detection,
            self.enable_output_validation,
            self.max_code_size_bytes,
            self.strict_mode,
            self.log_attacks,
            self.max_llm_calls_per_hour
        )
    }
}
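
// Illustrative usage sketch for the configuration presets and `validate()`
// above. The values behind `crate::constants::*` are not shown in this file,
// so these checks only rely on fields the presets set explicitly.
#[cfg(test)]
mod config_preset_examples {
    use super::*;

    #[test]
    fn strict_preset_is_production_grade_and_valid() {
        let config = LLMSecurityConfig::strict();
        assert!(config.is_production());
        assert!(config.validate().is_ok());
    }

    #[test]
    fn permissive_preset_is_treated_as_development() {
        assert!(LLMSecurityConfig::permissive().is_development());
    }

    #[test]
    fn zero_code_size_limit_is_rejected() {
        let config = LLMSecurityConfig::new(true, true, 0, true);
        assert!(config.validate().is_err());
    }
}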

/// Result of injection detection analysis
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InjectionDetectionResult {
    /// Whether malicious patterns were detected
    pub is_malicious: bool,
    /// Confidence score (0.0 - 1.0)
    pub confidence: f32,
    /// List of detected malicious patterns
    pub detected_patterns: Vec<String>,
    /// Overall risk score
    pub risk_score: u32,
}

impl InjectionDetectionResult {
    /// Create a new detection result
    pub fn new(
        is_malicious: bool,
        confidence: f32,
        detected_patterns: Vec<String>,
        risk_score: u32,
    ) -> Self {
        Self {
            is_malicious,
            confidence,
            detected_patterns,
            risk_score,
        }
    }

    /// Create a safe result (no malicious patterns detected)
    pub fn safe() -> Self {
        Self {
            is_malicious: false,
            confidence: 0.0,
            detected_patterns: Vec::new(),
            risk_score: 0,
        }
    }

    /// Create a malicious result
    pub fn malicious(confidence: f32, detected_patterns: Vec<String>, risk_score: u32) -> Self {
        Self {
            is_malicious: true,
            confidence,
            detected_patterns,
            risk_score,
        }
    }

    /// Check if this result indicates high risk
    pub fn is_high_risk(&self) -> bool {
        self.risk_score >= crate::constants::DEFAULT_HIGH_RISK_THRESHOLD
    }

    /// Check if this result indicates critical risk
    pub fn is_critical_risk(&self) -> bool {
        self.risk_score >= crate::constants::REGEX_DOS_RISK_SCORE
    }

    /// Get risk level as a string
    pub fn risk_level(&self) -> &'static str {
        if self.is_critical_risk() {
            "CRITICAL"
        } else if self.is_high_risk() {
            "HIGH"
        } else if self.risk_score >= crate::constants::DEFAULT_MALICIOUS_THRESHOLD {
            "MEDIUM"
        } else if self.risk_score > 0 {
            "LOW"
        } else {
            "NONE"
        }
    }

    /// Get a summary of the detection result
    pub fn summary(&self) -> String {
        if self.is_malicious {
            format!(
                "MALICIOUS ({}): {} patterns detected, risk score: {}, confidence: {:.2}",
                self.risk_level(),
                self.detected_patterns.len(),
                self.risk_score,
                self.confidence
            )
        } else {
            "SAFE: No malicious patterns detected".to_string()
        }
    }
}
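
// Illustrative sketch for `InjectionDetectionResult`. The risk thresholds
// referenced by `is_high_risk()` and `risk_level()` live in `crate::constants`
// and are not visible here, so these assertions avoid depending on them.
#[cfg(test)]
mod detection_result_examples {
    use super::*;

    #[test]
    fn safe_result_reports_no_findings() {
        let result = InjectionDetectionResult::safe();
        assert!(!result.is_malicious);
        assert!(result.detected_patterns.is_empty());
        assert_eq!(result.summary(), "SAFE: No malicious patterns detected");
    }

    #[test]
    fn malicious_result_summarizes_findings() {
        let patterns = vec!["ignore previous instructions".to_string()];
        let result = InjectionDetectionResult::malicious(0.9, patterns, 80);
        assert!(result.is_malicious);
        assert_eq!(result.detected_patterns.len(), 1);
        assert!(result.summary().starts_with("MALICIOUS"));
    }
}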

/// Main LLM Security struct
pub struct LLMSecurity {
    config: LLMSecurityConfig,
}

impl Default for LLMSecurity {
    /// Create a new instance with default configuration
    fn default() -> Self {
        Self::new(LLMSecurityConfig::default())
    }
}

impl LLMSecurity {
    /// Create a new LLM Security instance
    pub fn new(config: LLMSecurityConfig) -> Self {
        Self { config }
    }

    /// Get the current configuration
    pub fn config(&self) -> &LLMSecurityConfig {
        &self.config
    }

    /// Update the configuration
    pub fn update_config(&mut self, config: LLMSecurityConfig) {
        self.config = config;
    }
}
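
// Illustrative sketch of the `LLMSecurity` wrapper: construct it with a
// preset, read the configuration back, and swap it at runtime. The defaults
// behind `LLMSecurityConfig::default()` come from `crate::constants`, so they
// are not asserted on here.
#[cfg(test)]
mod llm_security_examples {
    use super::*;

    #[test]
    fn config_can_be_read_and_replaced() {
        let mut security = LLMSecurity::new(LLMSecurityConfig::strict());
        assert!(security.config().strict_mode);

        security.update_config(LLMSecurityConfig::permissive());
        assert!(!security.config().strict_mode);
        assert!(security.config().is_development());
    }
}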