1use crate::{Error, Result, Pattern};
3use hyperscan_tokio_sys::{DatabasePtr, Mode, Platform, VectorScan, Flags, ExpressionExt, compile_extended};
4use arc_swap::ArcSwap;
5use parking_lot::RwLock;
6use std::sync::Arc;
7use std::path::Path;
8use tokio::sync::Notify;
9
/// Summary of a compiled database, as produced by `Database::info`.
#[derive(Debug, Clone)]
pub struct DatabaseInfo {
    /// Version string extracted from the native database info string.
    pub version: String,
    /// Scan mode the database is tagged with.
    pub mode: Mode,
    /// Number of expressions (patterns) in the database.
    pub expression_count: usize,
    /// Size of the compiled database as reported by `VectorScan::database_size`.
    pub database_size: usize,
    /// Per-stream state size as reported by `VectorScan::stream_size`; only populated
    /// when `mode` contains `Mode::STREAM`, otherwise `None`.
    pub stream_size: Option<usize>,
}
19
/// Compile-time settings of a single pattern, as returned by `Database::expression_info`.
#[derive(Debug, Clone)]
pub struct ExpressionInfo {
    /// Caller-assigned pattern ID.
    pub id: u32,
    /// Compile flags the pattern was built with.
    pub flags: hyperscan_tokio_sys::Flags,
    /// The pattern's `min_offset` extended parameter, if set.
    pub min_offset: Option<u64>,
    /// The pattern's `max_offset` extended parameter, if set.
    pub max_offset: Option<u64>,
    /// The pattern's `min_length` extended parameter, if set.
    pub min_length: Option<u64>,
    /// The pattern's `edit_distance` extended parameter, if set.
    pub edit_distance: Option<u32>,
    /// The pattern's `hamming_distance` extended parameter, if set.
    pub hamming_distance: Option<u32>,
}
31
32impl Default for DatabaseInfo {
33 fn default() -> Self {
34 Self {
35 version: String::new(),
36 mode: Mode::BLOCK,
37 expression_count: 0,
38 database_size: 0,
39 stream_size: None,
40 }
41 }
42}
43
/// A compiled pattern database.
///
/// The native handle and the pattern metadata live in a shared `DatabaseInner` behind an
/// `Arc`; the native handle is released when the last reference to that state is dropped.
pub struct Database {
    inner: Arc<DatabaseInner>,
}
48
/// Shared state behind a [`Database`]: the native handle plus the metadata it was built from.
pub(crate) struct DatabaseInner {
    /// Native database handle; freed in this type's `Drop` impl.
    ptr: DatabasePtr,
    /// Scan mode of the database (for deserialized databases, whatever
    /// `Database::deserialize` determined).
    mode: Mode,
    /// Number of patterns. Equals `patterns.len()` for compiled databases; for deserialized
    /// ones it is parsed from the native info string and may be 0.
    pattern_count: usize,
    /// Source patterns; empty for deserialized databases.
    patterns: Vec<Pattern>,
}
55
56impl Drop for DatabaseInner {
57 fn drop(&mut self) {
58 VectorScan::free_database(DatabasePtr(self.ptr.0));
59 }
60}
61
62impl Database {
63 pub(crate) fn from_raw(ptr: DatabasePtr, mode: Mode, patterns: Vec<Pattern>) -> Self {
65 Self {
66 inner: Arc::new(DatabaseInner {
67 ptr,
68 mode,
69 pattern_count: patterns.len(),
70 patterns,
71 }),
72 }
73 }
74
75 pub(crate) fn compile(
76 patterns: Vec<Pattern>,
77 mode: Mode,
78 platform: Option<Platform>,
79 _som_horizon: Option<u64>,
80 ) -> Result<Self> {
81 if patterns.is_empty() {
82 return Err(Error::Compile {
83 message: "No patterns provided".to_string(),
84 pattern_id: None,
85 position: None,
86 });
87 }
88
89 let needs_extended = patterns.iter().any(|p| p.has_extended_params());
91
92 let db_ptr = if needs_extended {
93 let expressions: Vec<&str> = patterns.iter().map(|p| p.expression.as_str()).collect();
95 let flags: Vec<Flags> = patterns.iter().map(|p| p.flags).collect();
96 let ids: Vec<u32> = patterns.iter().map(|p| p.id).collect();
97
98 let ext: Vec<ExpressionExt> = patterns.iter().map(|p| {
99 let mut ext = ExpressionExt::default();
100 let mut ext_flags = 0u64;
101
102 if let Some(v) = p.min_offset {
103 ext.min_offset = v;
104 ext_flags |= ExpressionExt::FLAG_MIN_OFFSET;
105 }
106 if let Some(v) = p.max_offset {
107 ext.max_offset = v;
108 ext_flags |= ExpressionExt::FLAG_MAX_OFFSET;
109 }
110 if let Some(v) = p.min_length {
111 ext.min_length = v;
112 ext_flags |= ExpressionExt::FLAG_MIN_LENGTH;
113 }
114 if let Some(v) = p.edit_distance {
115 ext.edit_distance = v;
116 ext_flags |= ExpressionExt::FLAG_EDIT_DISTANCE;
117 }
118 if let Some(v) = p.hamming_distance {
119 ext.hamming_distance = v;
120 ext_flags |= ExpressionExt::FLAG_HAMMING_DISTANCE;
121 }
122
123 ext.flags = ext_flags;
124 ext
125 }).collect();
126
127 compile_extended(&expressions, &flags, &ids, &ext, mode, platform.as_ref())
128 .map_err(|e| Error::Compile {
129 message: e.message,
130 pattern_id: Some(e.expression as u32),
131 position: e.position,
132 })?
133 } else {
134 let expressions: Vec<&str> = patterns.iter().map(|p| p.expression.as_str()).collect();
136 let flags: Vec<Flags> = patterns.iter().map(|p| p.flags).collect();
137 let ids: Vec<u32> = patterns.iter().map(|p| p.id).collect();
138
139 VectorScan::compile_multi(&expressions, &flags, &ids, mode, platform.as_ref())
140 .map_err(|e| Error::Compile {
141 message: e.message,
142 pattern_id: Some(e.expression as u32),
143 position: e.position,
144 })?
145 };
146
147 Ok(Self {
148 inner: Arc::new(DatabaseInner {
149 ptr: db_ptr,
150 mode,
151 pattern_count: patterns.len(),
152 patterns,
153 }),
154 })
155 }
156
157 pub fn mode(&self) -> Mode {
158 self.inner.mode
159 }
160
161 pub fn pattern_count(&self) -> usize {
162 self.inner.pattern_count
163 }
164
165 pub fn serialize(&self) -> Result<Vec<u8>> {
166 VectorScan::serialize_database(&self.inner.ptr)
167 .map_err(|e| Error::Serialization(e))
168 }
169
170 pub fn deserialize(data: &[u8]) -> Result<Self> {
171 let db_ptr = VectorScan::deserialize_database(data)
172 .map_err(|e| Error::Serialization(e))?;
173
174 let info = VectorScan::database_info(&db_ptr)
176 .map_err(|e| Error::Runtime(e))?;
177
178 let pattern_count = info
180 .split_whitespace()
181 .skip_while(|&s| s != "Expressions:")
182 .nth(1)
183 .and_then(|s| s.parse::<usize>().ok())
184 .unwrap_or(0);
185
186 Ok(Self {
187 inner: Arc::new(DatabaseInner {
188 ptr: db_ptr,
189 mode: Mode::BLOCK, pattern_count,
191 patterns: Vec::new(), }),
193 })
194 }
195
196 pub(crate) fn as_ptr(&self) -> &DatabasePtr {
197 &self.inner.ptr
198 }
199
200 pub async fn save_to_file(&self, path: impl AsRef<Path>) -> Result<()> {
202 let data = self.serialize()?;
203 tokio::fs::write(path, data).await
204 .map_err(|e| Error::Io(e))
205 }
206
207 pub async fn load_from_file(path: impl AsRef<Path>) -> Result<Self> {
209 let data = tokio::fs::read(path).await
210 .map_err(|e| Error::Io(e))?;
211 Self::deserialize(&data)
212 }
213
214 pub fn fingerprint(&self) -> u64 {
216 use std::collections::hash_map::DefaultHasher;
217 use std::hash::{Hash, Hasher};
218
219 let mut hasher = DefaultHasher::new();
220 for pattern in &self.inner.patterns {
221 pattern.expression.hash(&mut hasher);
222 pattern.id.hash(&mut hasher);
223 pattern.flags.bits().hash(&mut hasher);
224 }
225 hasher.finish()
226 }
227
228 pub fn info(&self) -> Result<DatabaseInfo> {
230 let info_str = VectorScan::database_info(&self.inner.ptr)
231 .map_err(|e| Error::Runtime(e))?;
232
233 let mut info = DatabaseInfo::default();
236
237 for part in info_str.split_whitespace() {
238 match part {
239 "Version:" => continue,
240 "Features:" => continue,
241 "Mode:" => continue,
242 "Expressions:" => continue,
243 s if s.contains('.') => info.version = s.to_string(),
244 s if s.chars().all(|c| c.is_numeric()) => {
245 if info.expression_count == 0 {
246 info.expression_count = s.parse().unwrap_or(0);
247 }
248 }
249 _ => {}
250 }
251 }
252
253 info.mode = self.inner.mode;
254 info.expression_count = self.inner.pattern_count;
255
256 info.database_size = VectorScan::database_size(&self.inner.ptr)
258 .map_err(|e| Error::Runtime(e))?;
259
260 if self.inner.mode.contains(Mode::STREAM) {
262 info.stream_size = Some(
263 VectorScan::stream_size(&self.inner.ptr)
264 .map_err(|e| Error::Runtime(e))?
265 );
266 }
267
268 Ok(info)
269 }
270
271 pub fn size(&self) -> Result<usize> {
273 VectorScan::database_size(&self.inner.ptr)
274 .map_err(|e| Error::Runtime(e))
275 }
276
277 pub fn expression_info(&self, id: u32) -> Result<ExpressionInfo> {
279 let pattern = self.inner.patterns
280 .iter()
281 .find(|p| p.id == id)
282 .ok_or_else(|| Error::Pattern {
283 id,
284 message: format!("Pattern with ID {} not found", id),
285 })?;
286
287 Ok(ExpressionInfo {
288 id: pattern.id,
289 flags: pattern.flags,
290 min_offset: pattern.min_offset,
291 max_offset: pattern.max_offset,
292 min_length: pattern.min_length,
293 edit_distance: pattern.edit_distance,
294 hamming_distance: pattern.hamming_distance,
295 })
296 }
297
298 pub fn expression_contexts(&self) -> Vec<crate::expression::ExpressionContext> {
300 self.inner.patterns
301 .iter()
302 .map(|p| p.into())
303 .collect()
304 }
305
306 pub fn expression_context(&self, id: u32) -> Result<crate::expression::ExpressionContext> {
308 let pattern = self.inner.patterns
309 .iter()
310 .find(|p| p.id == id)
311 .ok_or_else(|| Error::Pattern {
312 id,
313 message: format!("Pattern with ID {} not found", id),
314 })?;
315
316 Ok(pattern.into())
317 }
318
319 pub fn validate(&self) -> Result<()> {
321 let info = self.info()?;
323
324 if info.expression_count != self.inner.pattern_count {
325 return Err(Error::Runtime(
326 format!("Database pattern count mismatch: {} vs {}",
327 info.expression_count, self.inner.pattern_count)
328 ));
329 }
330
331 let scratch = VectorScan::alloc_scratch(&self.inner.ptr)
333 .map_err(|e| Error::Runtime(e))?;
334 VectorScan::free_scratch(scratch);
335
336 Ok(())
337 }
338}
339
/// A [`Database`] that can be replaced at runtime.
///
/// Readers take `Arc<Database>` snapshots via `current()`; a snapshot stays usable even
/// after a reload swaps in a newer database.
pub struct ReloadableDatabase {
    /// The database currently being served; replaced wholesale on reload.
    current: ArcSwap<Database>,
    /// Used to wake tasks waiting in `wait_for_reload` after a reload.
    reload_notify: Arc<Notify>,
    /// Count of reloads performed since construction (starts at 0).
    version: Arc<RwLock<u64>>,
}
346
347impl ReloadableDatabase {
348 pub fn new(database: Database) -> Self {
349 Self {
350 current: ArcSwap::from_pointee(database),
351 reload_notify: Arc::new(Notify::new()),
352 version: Arc::new(RwLock::new(0)),
353 }
354 }
355
356 pub async fn reload(&self, new_database: Database) -> Result<()> {
357 self.current.store(Arc::new(new_database));
358 *self.version.write() += 1;
359 self.reload_notify.notify_waiters();
360 Ok(())
361 }
362
363 pub fn current(&self) -> Arc<Database> {
364 self.current.load_full()
365 }
366
367 pub fn version(&self) -> u64 {
368 *self.version.read()
369 }
370
371 pub async fn wait_for_reload(&self) {
372 self.reload_notify.notified().await
373 }
374}