hyperscan_tokio/
database.rs

1// src/database.rs
2use crate::{Error, Result, Pattern};
3use hyperscan_tokio_sys::{DatabasePtr, Mode, Platform, VectorScan, Flags, ExpressionExt, compile_extended};
4use arc_swap::ArcSwap;
5use parking_lot::RwLock;
6use std::sync::Arc;
7use std::path::Path;
8use tokio::sync::Notify;
9
/// Information about a compiled database
///
/// Returned by `Database::info`.
#[derive(Debug, Clone)]
pub struct DatabaseInfo {
    /// Version string taken from the library's database info string.
    pub version: String,
    /// Compilation mode recorded for the database.
    pub mode: Mode,
    /// Number of expressions (patterns) the database holds.
    pub expression_count: usize,
    /// Size of the compiled database in bytes, as reported by the library.
    pub database_size: usize,
    /// Stream-state size reported by the library; `None` for databases that
    /// were not compiled in streaming mode.
    pub stream_size: Option<usize>,
}
19
/// Information about a specific expression
///
/// Returned by `Database::expression_info`. Every field is copied from the
/// `Pattern` the database was compiled from.
#[derive(Debug, Clone)]
pub struct ExpressionInfo {
    /// The pattern's id (`Pattern::id`).
    pub id: u32,
    /// Compile flags of the pattern.
    pub flags: hyperscan_tokio_sys::Flags,
    /// Minimum-offset extended parameter; `None` when unset.
    pub min_offset: Option<u64>,
    /// Maximum-offset extended parameter; `None` when unset.
    pub max_offset: Option<u64>,
    /// Minimum-length extended parameter; `None` when unset.
    pub min_length: Option<u64>,
    /// Edit-distance extended parameter; `None` when unset.
    pub edit_distance: Option<u32>,
    /// Hamming-distance extended parameter; `None` when unset.
    pub hamming_distance: Option<u32>,
}
31
32impl Default for DatabaseInfo {
33    fn default() -> Self {
34        Self {
35            version: String::new(),
36            mode: Mode::BLOCK,
37            expression_count: 0,
38            database_size: 0,
39            stream_size: None,
40        }
41    }
42}
43
/// A compiled pattern database
///
/// A thin handle over `DatabaseInner`, which is held behind an `Arc`. The
/// library-side database is released by `DatabaseInner`'s `Drop` impl when
/// the last reference to that shared state is dropped.
pub struct Database {
    // Shared state: library handle, mode and source patterns.
    inner: Arc<DatabaseInner>,
}
48
/// Shared state behind a [`Database`]; owns the library database handle.
pub(crate) struct DatabaseInner {
    // Library handle; passed to `VectorScan::free_database` in `Drop`.
    ptr: DatabasePtr,
    // Mode the database was compiled for (or detected on deserialization).
    mode: Mode,
    // `patterns.len()` for compiled databases; for deserialized ones it is
    // taken from the library info string instead (see `Database::deserialize`).
    pattern_count: usize,
    // Source patterns; empty for databases restored via `Database::deserialize`.
    patterns: Vec<Pattern>,
}
55
56impl Drop for DatabaseInner {
57    fn drop(&mut self) {
58        VectorScan::free_database(DatabasePtr(self.ptr.0));
59    }
60}
61
62impl Database {
63    /// Create database from raw parts (internal use)
64    pub(crate) fn from_raw(ptr: DatabasePtr, mode: Mode, patterns: Vec<Pattern>) -> Self {
65        Self {
66            inner: Arc::new(DatabaseInner {
67                ptr,
68                mode,
69                pattern_count: patterns.len(),
70                patterns,
71            }),
72        }
73    }
74    
75    pub(crate) fn compile(
76        patterns: Vec<Pattern>,
77        mode: Mode,
78        platform: Option<Platform>,
79        _som_horizon: Option<u64>,
80    ) -> Result<Self> {
81        if patterns.is_empty() {
82            return Err(Error::Compile {
83                message: "No patterns provided".to_string(),
84                pattern_id: None,
85                position: None,
86            });
87        }
88        
89        // Check if we need extended compilation
90        let needs_extended = patterns.iter().any(|p| p.has_extended_params());
91        
92        let db_ptr = if needs_extended {
93            // Prepare extended parameters
94            let expressions: Vec<&str> = patterns.iter().map(|p| p.expression.as_str()).collect();
95            let flags: Vec<Flags> = patterns.iter().map(|p| p.flags).collect();
96            let ids: Vec<u32> = patterns.iter().map(|p| p.id).collect();
97            
98            let ext: Vec<ExpressionExt> = patterns.iter().map(|p| {
99                let mut ext = ExpressionExt::default();
100                let mut ext_flags = 0u64;
101                
102                if let Some(v) = p.min_offset {
103                    ext.min_offset = v;
104                    ext_flags |= ExpressionExt::FLAG_MIN_OFFSET;
105                }
106                if let Some(v) = p.max_offset {
107                    ext.max_offset = v;
108                    ext_flags |= ExpressionExt::FLAG_MAX_OFFSET;
109                }
110                if let Some(v) = p.min_length {
111                    ext.min_length = v;
112                    ext_flags |= ExpressionExt::FLAG_MIN_LENGTH;
113                }
114                if let Some(v) = p.edit_distance {
115                    ext.edit_distance = v;
116                    ext_flags |= ExpressionExt::FLAG_EDIT_DISTANCE;
117                }
118                if let Some(v) = p.hamming_distance {
119                    ext.hamming_distance = v;
120                    ext_flags |= ExpressionExt::FLAG_HAMMING_DISTANCE;
121                }
122                
123                ext.flags = ext_flags;
124                ext
125            }).collect();
126            
127            compile_extended(&expressions, &flags, &ids, &ext, mode, platform.as_ref())
128                .map_err(|e| Error::Compile {
129                    message: e.message,
130                    pattern_id: Some(e.expression as u32),
131                    position: e.position,
132                })?
133        } else {
134            // Simple compilation
135            let expressions: Vec<&str> = patterns.iter().map(|p| p.expression.as_str()).collect();
136            let flags: Vec<Flags> = patterns.iter().map(|p| p.flags).collect();
137            let ids: Vec<u32> = patterns.iter().map(|p| p.id).collect();
138            
139            VectorScan::compile_multi(&expressions, &flags, &ids, mode, platform.as_ref())
140                .map_err(|e| Error::Compile {
141                    message: e.message,
142                    pattern_id: Some(e.expression as u32),
143                    position: e.position,
144                })?
145        };
146        
147        Ok(Self {
148            inner: Arc::new(DatabaseInner {
149                ptr: db_ptr,
150                mode,
151                pattern_count: patterns.len(),
152                patterns,
153            }),
154        })
155    }
156    
157    pub fn mode(&self) -> Mode {
158        self.inner.mode
159    }
160    
161    pub fn pattern_count(&self) -> usize {
162        self.inner.pattern_count
163    }
164    
165    pub fn serialize(&self) -> Result<Vec<u8>> {
166        VectorScan::serialize_database(&self.inner.ptr)
167            .map_err(|e| Error::Serialization(e))
168    }
169    
170    pub fn deserialize(data: &[u8]) -> Result<Self> {
171        let db_ptr = VectorScan::deserialize_database(data)
172            .map_err(|e| Error::Serialization(e))?;
173            
174        // Get database info to determine pattern count
175        let info = VectorScan::database_info(&db_ptr)
176            .map_err(|e| Error::Runtime(e))?;
177            
178        // Parse pattern count from info (format: "Version: X.X.X Features: XX Expressions: N")
179        let pattern_count = info
180            .split_whitespace()
181            .skip_while(|&s| s != "Expressions:")
182            .nth(1)
183            .and_then(|s| s.parse::<usize>().ok())
184            .unwrap_or(0);
185        
186        Ok(Self {
187            inner: Arc::new(DatabaseInner {
188                ptr: db_ptr,
189                mode: Mode::BLOCK, // TODO: Detect from database
190                pattern_count,
191                patterns: Vec::new(), // Patterns not preserved in serialization
192            }),
193        })
194    }
195    
196    pub(crate) fn as_ptr(&self) -> &DatabasePtr {
197        &self.inner.ptr
198    }
199    
200    /// Save database to file
201    pub async fn save_to_file(&self, path: impl AsRef<Path>) -> Result<()> {
202        let data = self.serialize()?;
203        tokio::fs::write(path, data).await
204            .map_err(|e| Error::Io(e))
205    }
206    
207    /// Load database from file
208    pub async fn load_from_file(path: impl AsRef<Path>) -> Result<Self> {
209        let data = tokio::fs::read(path).await
210            .map_err(|e| Error::Io(e))?;
211        Self::deserialize(&data)
212    }
213    
214    /// Get a hash of the patterns for caching
215    pub fn fingerprint(&self) -> u64 {
216        use std::collections::hash_map::DefaultHasher;
217        use std::hash::{Hash, Hasher};
218        
219        let mut hasher = DefaultHasher::new();
220        for pattern in &self.inner.patterns {
221            pattern.expression.hash(&mut hasher);
222            pattern.id.hash(&mut hasher);
223            pattern.flags.bits().hash(&mut hasher);
224        }
225        hasher.finish()
226    }
227    
228    /// Get database information
229    pub fn info(&self) -> Result<DatabaseInfo> {
230        let info_str = VectorScan::database_info(&self.inner.ptr)
231            .map_err(|e| Error::Runtime(e))?;
232            
233        // Parse the info string
234        // Format: "Version: X.X.X Features: XX Mode: XXXX Expressions: N"
235        let mut info = DatabaseInfo::default();
236        
237        for part in info_str.split_whitespace() {
238            match part {
239                "Version:" => continue,
240                "Features:" => continue,
241                "Mode:" => continue,
242                "Expressions:" => continue,
243                s if s.contains('.') => info.version = s.to_string(),
244                s if s.chars().all(|c| c.is_numeric()) => {
245                    if info.expression_count == 0 {
246                        info.expression_count = s.parse().unwrap_or(0);
247                    }
248                }
249                _ => {}
250            }
251        }
252        
253        info.mode = self.inner.mode;
254        info.expression_count = self.inner.pattern_count;
255        
256        // Get database size
257        info.database_size = VectorScan::database_size(&self.inner.ptr)
258            .map_err(|e| Error::Runtime(e))?;
259        
260        // Get stream size if applicable
261        if self.inner.mode.contains(Mode::STREAM) {
262            info.stream_size = Some(
263                VectorScan::stream_size(&self.inner.ptr)
264                    .map_err(|e| Error::Runtime(e))?
265            );
266        }
267        
268        Ok(info)
269    }
270    
271    /// Get database size in bytes
272    pub fn size(&self) -> Result<usize> {
273        VectorScan::database_size(&self.inner.ptr)
274            .map_err(|e| Error::Runtime(e))
275    }
276    
277    /// Get information about a specific expression
278    pub fn expression_info(&self, id: u32) -> Result<ExpressionInfo> {
279        let pattern = self.inner.patterns
280            .iter()
281            .find(|p| p.id == id)
282            .ok_or_else(|| Error::Pattern {
283                id,
284                message: format!("Pattern with ID {} not found", id),
285            })?;
286        
287        Ok(ExpressionInfo {
288            id: pattern.id,
289            flags: pattern.flags,
290            min_offset: pattern.min_offset,
291            max_offset: pattern.max_offset,
292            min_length: pattern.min_length,
293            edit_distance: pattern.edit_distance,
294            hamming_distance: pattern.hamming_distance,
295        })
296    }
297    
298    /// Get expression contexts for all patterns
299    pub fn expression_contexts(&self) -> Vec<crate::expression::ExpressionContext> {
300        self.inner.patterns
301            .iter()
302            .map(|p| p.into())
303            .collect()
304    }
305    
306    /// Get expression context for a specific pattern
307    pub fn expression_context(&self, id: u32) -> Result<crate::expression::ExpressionContext> {
308        let pattern = self.inner.patterns
309            .iter()
310            .find(|p| p.id == id)
311            .ok_or_else(|| Error::Pattern {
312                id,
313                message: format!("Pattern with ID {} not found", id),
314            })?;
315        
316        Ok(pattern.into())
317    }
318    
319    /// Validate the database
320    pub fn validate(&self) -> Result<()> {
321        // Check database integrity
322        let info = self.info()?;
323        
324        if info.expression_count != self.inner.pattern_count {
325            return Err(Error::Runtime(
326                format!("Database pattern count mismatch: {} vs {}", 
327                    info.expression_count, self.inner.pattern_count)
328            ));
329        }
330        
331        // Verify we can create scratch space
332        let scratch = VectorScan::alloc_scratch(&self.inner.ptr)
333            .map_err(|e| Error::Runtime(e))?;
334        VectorScan::free_scratch(scratch);
335        
336        Ok(())
337    }
338}
339
/// A database that can be reloaded at runtime
///
/// `current()` hands out `Arc` snapshots. `reload()` swaps in a new database,
/// increments `version()`, and wakes tasks waiting in `wait_for_reload()`.
pub struct ReloadableDatabase {
    // Database currently in service, replaced via `ArcSwap::store`.
    current: ArcSwap<Database>,
    // Signalled (`notify_waiters`) by `reload()` after each swap.
    reload_notify: Arc<Notify>,
    // Number of reloads since construction; starts at 0.
    version: Arc<RwLock<u64>>,
}
346
347impl ReloadableDatabase {
348    pub fn new(database: Database) -> Self {
349        Self {
350            current: ArcSwap::from_pointee(database),
351            reload_notify: Arc::new(Notify::new()),
352            version: Arc::new(RwLock::new(0)),
353        }
354    }
355    
356    pub async fn reload(&self, new_database: Database) -> Result<()> {
357        self.current.store(Arc::new(new_database));
358        *self.version.write() += 1;
359        self.reload_notify.notify_waiters();
360        Ok(())
361    }
362    
363    pub fn current(&self) -> Arc<Database> {
364        self.current.load_full()
365    }
366    
367    pub fn version(&self) -> u64 {
368        *self.version.read()
369    }
370    
371    pub async fn wait_for_reload(&self) {
372        self.reload_notify.notified().await
373    }
374}