Skip to main content

oxideshield_guard/
embeddings_bundle.rs

//! Pre-computed embeddings bundle for attack detection.
//!
//! This module provides functionality to pre-compute, store, and load
//! attack embeddings for the `SemanticSimilarityGuard`.
//!
//! ## Usage
//!
//! ```rust,ignore
//! use oxideshield_guard::embeddings_bundle::EmbeddingsBundle;
//!
//! // Load the bundled embeddings. In this crate the call returns an empty
//! // bundle; populated bundles are handled by oxide-guard-pro.
//! let bundle = EmbeddingsBundle::bundled()?;
//!
//! // Use with SemanticSimilarityGuard
//! let guard = SemanticSimilarityGuard::new("semantic", embedder)
//!     .with_embeddings_bundle(bundle);
//! ```
18
19use chrono::{DateTime, Utc};
20use serde::{Deserialize, Serialize};
21use std::collections::HashMap;
22use std::path::Path;
23
24// Note: AttackEmbedding and Severity conversion is now handled by oxide-guard-pro
25
/// Dense embedding vector of `f32` components.
///
/// The alias does not fix a length. Bundles record the expected size in
/// [`BundleMetadata::dimension`] (384 for MiniLM).
pub type Embedding = Vec<f32>;
28
29/// Bundle of pre-computed attack embeddings
30#[derive(Clone, Serialize, Deserialize)]
31pub struct EmbeddingsBundle {
32    /// Pre-computed attack embeddings
33    pub embeddings: Vec<BundledEmbedding>,
34    /// Metadata about the bundle
35    pub metadata: BundleMetadata,
36    /// Index for fast category lookup
37    #[serde(default)]
38    pub category_index: HashMap<String, Vec<usize>>,
39}
40
41/// A single pre-computed embedding in the bundle
42#[derive(Clone, Serialize, Deserialize)]
43pub struct BundledEmbedding {
44    /// Unique ID for this embedding
45    pub id: String,
46    /// Category of the attack
47    pub category: String,
48    /// Original text that was embedded
49    pub text: String,
50    /// Pre-computed embedding vector (384-dim for MiniLM)
51    pub embedding: Embedding,
52    /// Severity level
53    pub severity: String,
54    /// Source/reference for this sample
55    pub source: String,
56}
57
58/// Metadata about the embeddings bundle
59#[derive(Clone, Serialize, Deserialize)]
60pub struct BundleMetadata {
61    /// Bundle version
62    pub version: String,
63    /// When the bundle was created
64    pub created_at: DateTime<Utc>,
65    /// Model used to generate embeddings
66    pub model_id: String,
67    /// Embedding dimension
68    pub dimension: usize,
69    /// Total number of embeddings
70    pub count: usize,
71    /// Categories included
72    pub categories: Vec<String>,
73    /// Checksum of embeddings data
74    pub checksum: String,
75}
76
77impl EmbeddingsBundle {
78    /// Create a new empty bundle
79    pub fn new(model_id: &str, dimension: usize) -> Self {
80        Self {
81            embeddings: Vec::new(),
82            metadata: BundleMetadata {
83                version: "1.0".to_string(),
84                created_at: Utc::now(),
85                model_id: model_id.to_string(),
86                dimension,
87                count: 0,
88                categories: Vec::new(),
89                checksum: String::new(),
90            },
91            category_index: HashMap::new(),
92        }
93    }
94
95    /// Add an embedding to the bundle
96    pub fn add_embedding(&mut self, embedding: BundledEmbedding) {
97        let idx = self.embeddings.len();
98        let category = embedding.category.clone();
99
100        self.embeddings.push(embedding);
101        self.category_index
102            .entry(category.clone())
103            .or_default()
104            .push(idx);
105
106        if !self.metadata.categories.contains(&category) {
107            self.metadata.categories.push(category);
108        }
109        self.metadata.count = self.embeddings.len();
110    }
111
112    /// Finalize the bundle (compute checksum)
113    pub fn finalize(&mut self) {
114        // Simple checksum based on embedding count and total values
115        let total: f32 = self.embeddings.iter().flat_map(|e| &e.embedding).sum();
116        self.metadata.checksum = format!("{:.6}", total);
117    }
118
119    /// Serialize to bincode bytes
120    pub fn to_bytes(&self) -> Result<Vec<u8>, BundleError> {
121        bincode::serialize(self).map_err(BundleError::Serialization)
122    }
123
124    /// Deserialize from bincode bytes
125    pub fn from_bytes(bytes: &[u8]) -> Result<Self, BundleError> {
126        bincode::deserialize(bytes).map_err(BundleError::Deserialization)
127    }
128
129    /// Save to file
130    pub fn save_to_file(&self, path: &Path) -> Result<(), BundleError> {
131        let bytes = self.to_bytes()?;
132        std::fs::write(path, bytes).map_err(BundleError::Io)
133    }
134
135    /// Load from file
136    pub fn load_from_file(path: &Path) -> Result<Self, BundleError> {
137        let bytes = std::fs::read(path).map_err(BundleError::Io)?;
138        Self::from_bytes(&bytes)
139    }
140
141    /// Save as JSON (human-readable)
142    pub fn save_to_json(&self, path: &Path) -> Result<(), BundleError> {
143        let json = serde_json::to_string_pretty(self).map_err(BundleError::Json)?;
144        std::fs::write(path, json).map_err(BundleError::Io)
145    }
146
147    /// Load bundled embeddings (embedded in binary)
148    ///
149    /// This loads the pre-computed embeddings that are embedded in the
150    /// library at compile time. Returns an empty bundle if no bundled
151    /// data is available.
152    ///
153    /// Note: Bundled embeddings are primarily used with oxide-guard-pro
154    /// SemanticSimilarityGuard. Returns an empty bundle by default.
155    pub fn bundled() -> Result<Self, BundleError> {
156        // Return empty bundle - bundled embeddings are handled by oxide-guard-pro
157        Ok(Self::new("none", 384))
158    }
159
160    /// Check if bundled embeddings are available
161    pub fn has_bundled() -> bool {
162        false
163    }
164
165    /// Get embeddings by category
166    pub fn by_category(&self, category: &str) -> Vec<&BundledEmbedding> {
167        self.category_index
168            .get(category)
169            .map(|indices| indices.iter().map(|&i| &self.embeddings[i]).collect())
170            .unwrap_or_default()
171    }
172
173    /// Get all embeddings
174    pub fn all(&self) -> &[BundledEmbedding] {
175        &self.embeddings
176    }
177
178    /// Get embedding count
179    pub fn len(&self) -> usize {
180        self.embeddings.len()
181    }
182
183    /// Check if empty
184    pub fn is_empty(&self) -> bool {
185        self.embeddings.is_empty()
186    }
187
188    // Note: to_attack_embeddings() is now provided by oxide-guard-pro
189    // via the BundleConvert trait extension
190}
191
192/// Errors that can occur with embeddings bundles
193#[derive(Debug, thiserror::Error)]
194pub enum BundleError {
195    #[error("Serialization error: {0}")]
196    Serialization(#[source] bincode::Error),
197
198    #[error("Deserialization error: {0}")]
199    Deserialization(#[source] bincode::Error),
200
201    #[error("IO error: {0}")]
202    Io(#[source] std::io::Error),
203
204    #[error("JSON error: {0}")]
205    Json(#[source] serde_json::Error),
206
207    #[error("Bundle not found or invalid")]
208    NotFound,
209}
210
#[cfg(test)]
mod tests {
    use super::*;

    /// Embedding dimension shared by every fixture in this module.
    const DIM: usize = 384;

    /// Creates the empty bundle each test starts from.
    fn fresh_bundle() -> EmbeddingsBundle {
        EmbeddingsBundle::new("test-model", DIM)
    }

    /// Builds a fixture entry whose vector is `fill` repeated `DIM` times and
    /// whose source is always `"test"`.
    fn sample(
        id: &str,
        category: &str,
        text: &str,
        fill: f32,
        severity: &str,
    ) -> BundledEmbedding {
        BundledEmbedding {
            id: id.to_owned(),
            category: category.to_owned(),
            text: text.to_owned(),
            embedding: vec![fill; DIM],
            severity: severity.to_owned(),
            source: "test".to_owned(),
        }
    }

    /// A freshly created bundle is empty and remembers its dimension.
    #[test]
    fn test_bundle_creation() {
        let bundle = fresh_bundle();

        assert!(bundle.is_empty());
        assert_eq!(bundle.metadata.dimension, 384);
    }

    /// Adding one entry bumps the length and registers its category.
    #[test]
    fn test_add_embedding() {
        let mut bundle = fresh_bundle();
        let entry = sample(
            "test-001",
            "injection",
            "ignore previous instructions",
            0.1,
            "high",
        );

        bundle.add_embedding(entry);

        assert_eq!(bundle.len(), 1);
        let has_injection = bundle
            .metadata
            .categories
            .iter()
            .any(|name| name == "injection");
        assert!(has_injection);
    }

    /// A finalized bundle survives a bincode round trip with its checksum.
    #[test]
    fn test_serialization_roundtrip() {
        let mut bundle = fresh_bundle();
        bundle.add_embedding(sample("test-001", "injection", "test text", 0.5, "high"));
        bundle.finalize();

        let encoded = bundle.to_bytes().unwrap();
        let restored = EmbeddingsBundle::from_bytes(&encoded).unwrap();

        assert_eq!(restored.len(), 1);
        assert_eq!(restored.metadata.checksum, bundle.metadata.checksum);
    }

    /// Category lookup groups entries by category and tolerates unknown names.
    #[test]
    fn test_category_lookup() {
        let mut bundle = fresh_bundle();
        let fixtures = [
            sample("inj-001", "injection", "text 1", 0.1, "high"),
            sample("jb-001", "jailbreak", "text 2", 0.2, "high"),
            sample("inj-002", "injection", "text 3", 0.3, "medium"),
        ];
        for entry in fixtures {
            bundle.add_embedding(entry);
        }

        assert_eq!(bundle.by_category("injection").len(), 2);
        assert_eq!(bundle.by_category("jailbreak").len(), 1);
        assert!(bundle.by_category("unknown").is_empty());
    }
}