1use std::ffi::c_void;
23use std::fmt;
24
25#[cfg(feature = "cuda")]
26pub mod cuda;
27
28#[cfg(feature = "omp")]
29pub mod omp;
30
31#[cfg(feature = "mpi")]
32pub mod mpi;
33
// Local aliases for the C scalar types used in the FFI signatures of this
// file. They forward to `std::os::raw`, so they always match the platform's
// C ABI instead of assuming that C `int` is 32 bits. The lowercase names
// mirror the C spelling, which is why the naming lint is silenced.
#[allow(non_camel_case_types)]
type c_int = std::os::raw::c_int;
#[allow(non_camel_case_types)]
type c_uint = std::os::raw::c_uint;
#[allow(non_camel_case_types)]
type c_double = std::os::raw::c_double;
42
/// Raw `extern "C"` declarations for the native DRESS library.
///
/// NOTE(review): every function declared here is called from this file, so the
/// `dead_code` allowance may be a leftover or exist for other build
/// configurations — confirm before removing it.
#[allow(dead_code)]
mod ffi {
    use super::*;
    extern "C" {
        /// Builds a native graph with `n` vertices and `e` edges, where edge `i`
        /// joins `u[i]` and `v[i]`. `w` (per-edge weights) and `nw` (per-vertex
        /// weights) may be NULL. Returns an opaque handle, or NULL on failure.
        ///
        /// NOTE(review): the Rust caller passes freshly malloc'd buffers and never
        /// frees them itself, so the library presumably takes ownership of `u`,
        /// `v`, `w` and `nw` — confirm against the C implementation.
        pub(crate) fn dress_init_graph(
            n: c_int, e: c_int,
            u: *mut c_int, v: *mut c_int,
            w: *mut c_double, nw: *mut c_double,
            variant: c_int, precompute_intercepts: c_int,
        ) -> *mut c_void;

        /// Fits graph `g` using at most `max_iterations` iterations and
        /// `epsilon` (presumably a convergence tolerance), then writes the
        /// number of iterations performed and the final delta through the two
        /// out-pointers.
        pub(crate) fn dress_fit(
            g: *mut c_void, max_iterations: c_int, epsilon: c_double,
            iterations: *mut c_int, delta: *mut c_double,
        );

        /// Releases a handle returned by `dress_init_graph`. The Rust wrapper
        /// nulls its copy of the handle afterwards so this runs at most once.
        pub(crate) fn dress_free_graph(g: *mut c_void);

        /// Returns a value for the vertex pair `(u, v)` computed with the given
        /// iteration limit, `epsilon` and `edge_weight`.
        ///
        /// NOTE(review): the exact semantics (e.g. whether `(u, v)` must be an
        /// existing edge) are not visible from this file.
        pub(crate) fn dress_get(
            g: *mut c_void, u: c_int, v: c_int,
            max_iterations: c_int, epsilon: c_double, edge_weight: c_double,
        ) -> c_double;

        /// Runs the delta fit with parameter `k` on graph `g`.
        ///
        /// Returns a histogram array that the caller releases with C `free`.
        /// Its length is written through `hist_size` when that pointer is
        /// non-NULL. When `keep_multisets` is non-zero and `multisets` is
        /// non-NULL, a buffer of `*num_subgraphs * E` doubles is stored there;
        /// the caller also releases it with `free`.
        ///
        /// `offset` / `stride` presumably select a slice of the work for
        /// distributed callers; the wrapper in this file always passes 0 and 1.
        pub(crate) fn dress_delta_fit_strided(
            g: *mut c_void, k: c_int, iterations: c_int, epsilon: c_double,
            n_samples: c_int, seed: c_uint,
            hist_size: *mut c_int, keep_multisets: c_int,
            multisets: *mut *mut c_double, num_subgraphs: *mut i64,
            offset: c_int, stride: c_int,
        ) -> *mut HistogramEntry;

        /// Runs the nabla fit with parameter `k` on graph `g`.
        ///
        /// Output conventions match `dress_delta_fit_strided`: the returned
        /// histogram and the optional `multisets` buffer (`*num_tuples * E`
        /// doubles) are released by the caller with C `free`.
        pub(crate) fn dress_nabla_fit(
            g: *mut c_void, k: c_int, iterations: c_int, epsilon: c_double,
            n_samples: c_int, seed: c_uint,
            hist_size: *mut c_int, keep_multisets: c_int,
            multisets: *mut *mut c_double, num_tuples: *mut i64,
        ) -> *mut HistogramEntry;
    }
}
82
/// Variant selector forwarded to the native library.
///
/// The discriminant is passed unchanged (via `variant as c_int`) as the
/// `variant` argument of `dress_init_graph`, so these values must stay in
/// sync with the C side.
#[repr(i32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Variant {
    /// C value 0.
    Undirected = 0,
    /// C value 1.
    Directed = 1,
    /// C value 2.
    Forward = 2,
    /// C value 3.
    Backward = 3,
}
94
/// Snapshot of a graph's edge and vertex values, as produced by
/// `DRESS::result` and the free function `fit`.
#[derive(Debug, Clone)]
pub struct DressResult {
    /// Source endpoint of each edge, in input order.
    pub sources: Vec<i32>,
    /// Target endpoint of each edge, in input order.
    pub targets: Vec<i32>,
    /// Per-edge weight as stored by the native graph (one value per edge).
    pub edge_weight: Vec<f64>,
    /// Per-edge DRESS value (one value per edge).
    pub edge_dress: Vec<f64>,
    /// Per-vertex DRESS value (one value per vertex).
    pub vertex_dress: Vec<f64>,
    /// Per-vertex weights; `None` when the native graph holds none.
    pub vertex_weights: Option<Vec<f64>>,
    /// Iterations reported by the fit (0 when built by `DRESS::result` alone).
    pub iterations: i32,
    /// Final delta reported by the fit (0.0 when built by `DRESS::result` alone).
    pub delta: f64,
}

/// One bucket of a histogram returned by the native delta/nabla fits.
///
/// `#[repr(C)]` because arrays of this type are read directly from memory
/// returned by the native library; field order and types must match the
/// C struct.
#[repr(C)]
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct HistogramEntry {
    /// The value being counted.
    pub value: f64,
    /// How many times `value` was observed.
    pub count: i64,
}

impl fmt::Display for DressResult {
    /// One-line summary: edge count, iterations, and delta in scientific notation.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let edges = self.sources.len();
        let iters = self.iterations;
        let delta = self.delta;
        write!(f, "DressResult(E={edges}, iterations={iters}, delta={delta:.6e})")
    }
}
127
/// Owning wrapper around a native DRESS graph handle.
///
/// The handle is released by `close`, which `Drop` also runs. Because the
/// struct stores a raw pointer, it is neither `Send` nor `Sync`.
pub struct DRESS {
    /// Opaque handle returned by `dress_init_graph`; NULL once closed.
    g: *mut c_void,
    /// Vertex count given at construction.
    n: i32,
    /// Edge count (the common length of `sources` and `targets`).
    e: usize,
    /// Edge source endpoints, kept so results can echo them back.
    sources: Vec<i32>,
    /// Edge target endpoints, kept so results can echo them back.
    targets: Vec<i32>,
}
149
150impl DRESS {
151 pub fn new(
153 n: i32,
154 sources: Vec<i32>,
155 targets: Vec<i32>,
156 weights: Option<Vec<f64>>,
157 vertex_weights: Option<Vec<f64>>,
158 variant: Variant,
159 precompute_intercepts: bool,
160 ) -> Result<DRESS, DressError> {
161 let e = sources.len();
162 if targets.len() != e {
163 return Err(DressError::LengthMismatch(
164 "sources and targets must have equal length".into(),
165 ));
166 }
167 unsafe {
168 let u_ptr = libc_malloc_copy_i32(&sources);
169 let v_ptr = libc_malloc_copy_i32(&targets);
170 let w_ptr = weights.as_ref().map_or(std::ptr::null_mut(), |w| libc_malloc_copy_f64(w));
171 let nw_ptr = vertex_weights.as_ref().map_or(std::ptr::null_mut(), |nw| libc_malloc_copy_f64(nw));
172 let g = ffi::dress_init_graph(n, e as c_int, u_ptr, v_ptr, w_ptr, nw_ptr,
173 variant as c_int, precompute_intercepts as c_int);
174 if g.is_null() {
175 return Err(DressError::InitFailed);
176 }
177 Ok(DRESS { g, n, e, sources, targets })
178 }
179 }
180
181 pub fn fit(&mut self, max_iterations: i32, epsilon: f64) -> (i32, f64) {
183 assert!(!self.g.is_null(), "DRESS already closed");
184 let mut iterations: c_int = 0;
185 let mut delta: c_double = 0.0;
186 unsafe {
187 ffi::dress_fit(self.g, max_iterations, epsilon, &mut iterations, &mut delta);
188 }
189 (iterations, delta)
190 }
191
192 pub fn get(&self, u: i32, v: i32, max_iterations: i32, epsilon: f64, edge_weight: f64) -> f64 {
194 assert!(!self.g.is_null(), "DRESS already closed");
195 unsafe { ffi::dress_get(self.g, u, v, max_iterations, epsilon, edge_weight) }
196 }
197
198 pub fn result(&self) -> DressResult {
200 assert!(!self.g.is_null(), "DRESS already closed");
201 let e = self.e;
202 let n = self.n as usize;
203 unsafe {
204 let base = self.g as *const u8;
205 let ew_ptr = *(base.add(72) as *const *const f64);
206 let ed_ptr = *(base.add(80) as *const *const f64);
207 let nd_ptr = *(base.add(96) as *const *const f64);
208 let nw_ptr = *(base.add(104) as *const *const f64);
209
210 let vertex_weights = if !nw_ptr.is_null() {
211 Some(std::slice::from_raw_parts(nw_ptr, n).to_vec())
212 } else {
213 None
214 };
215
216 DressResult {
217 sources: self.sources.clone(),
218 targets: self.targets.clone(),
219 edge_weight: std::slice::from_raw_parts(ew_ptr, e).to_vec(),
220 edge_dress: std::slice::from_raw_parts(ed_ptr, e).to_vec(),
221 vertex_dress: std::slice::from_raw_parts(nd_ptr, n).to_vec(),
222 vertex_weights,
223 iterations: 0,
224 delta: 0.0,
225 }
226 }
227 }
228
229 pub fn delta_fit(
233 &self,
234 k: i32,
235 max_iterations: i32,
236 epsilon: f64,
237 n_samples: i32,
238 seed: u32,
239 keep_multisets: bool,
240 compute_histogram: bool,
241 ) -> Result<DeltaDressResult, DressError> {
242 assert!(!self.g.is_null(), "DRESS already closed");
243 let e = self.e;
244
245 unsafe {
246 let mut hsize: c_int = 0;
247 let mut ms_ptr: *mut c_double = std::ptr::null_mut();
248 let mut num_sub: i64 = 0;
249 let h = ffi::dress_delta_fit_strided(
250 self.g,
251 k,
252 max_iterations,
253 epsilon,
254 n_samples,
255 seed,
256 if compute_histogram { &mut hsize } else { std::ptr::null_mut() },
257 if keep_multisets { 1 } else { 0 },
258 if keep_multisets { &mut ms_ptr } else { std::ptr::null_mut() },
259 &mut num_sub,
260 0,
261 1,
262 );
263
264 let histogram = histogram_from_raw(h, hsize);
265
266 extern "C" { fn free(ptr: *mut std::ffi::c_void); }
267
268 let multisets = if keep_multisets && !ms_ptr.is_null() && num_sub > 0 {
269 let len = (num_sub as usize) * e;
270 let ms = std::slice::from_raw_parts(ms_ptr, len).to_vec();
271 free(ms_ptr as *mut std::ffi::c_void);
272 Some(ms)
273 } else {
274 if keep_multisets && !ms_ptr.is_null() {
275 free(ms_ptr as *mut std::ffi::c_void);
276 }
277 None
278 };
279
280 if !h.is_null() {
281 free(h as *mut std::ffi::c_void);
282 }
283
284 Ok(DeltaDressResult {
285 histogram,
286 multisets,
287 num_subgraphs: num_sub,
288 })
289 }
290 }
291
292 pub fn nabla_fit(
294 &self,
295 k: i32,
296 max_iterations: i32,
297 epsilon: f64,
298 n_samples: i32,
299 seed: u32,
300 keep_multisets: bool,
301 compute_histogram: bool,
302 ) -> Result<NablaDressResult, DressError> {
303 assert!(!self.g.is_null(), "DRESS already closed");
304 let e = self.e;
305
306 unsafe {
307 let mut hsize: c_int = 0;
308 let mut ms_ptr: *mut c_double = std::ptr::null_mut();
309 let mut num_tup: i64 = 0;
310 let h = ffi::dress_nabla_fit(
311 self.g,
312 k,
313 max_iterations,
314 epsilon,
315 n_samples,
316 seed,
317 if compute_histogram { &mut hsize } else { std::ptr::null_mut() },
318 if keep_multisets { 1 } else { 0 },
319 if keep_multisets { &mut ms_ptr } else { std::ptr::null_mut() },
320 &mut num_tup,
321 );
322
323 let histogram = histogram_from_raw(h, hsize);
324
325 extern "C" { fn free(ptr: *mut std::ffi::c_void); }
326
327 let multisets = if keep_multisets && !ms_ptr.is_null() && num_tup > 0 {
328 let len = (num_tup as usize) * e;
329 let ms = std::slice::from_raw_parts(ms_ptr, len).to_vec();
330 free(ms_ptr as *mut std::ffi::c_void);
331 Some(ms)
332 } else {
333 if keep_multisets && !ms_ptr.is_null() {
334 free(ms_ptr as *mut std::ffi::c_void);
335 }
336 None
337 };
338
339 if !h.is_null() {
340 free(h as *mut std::ffi::c_void);
341 }
342
343 Ok(NablaDressResult {
344 histogram,
345 multisets,
346 num_tuples: num_tup,
347 })
348 }
349 }
350
351 pub fn close(&mut self) {
353 if !self.g.is_null() {
354 unsafe { ffi::dress_free_graph(self.g); }
355 self.g = std::ptr::null_mut();
356 }
357 }
358}
359
360impl Drop for DRESS {
361 fn drop(&mut self) {
362 self.close();
363 }
364}
365
366#[derive(Debug, Clone)]
368pub struct DeltaDressResult {
369 pub histogram: Vec<HistogramEntry>,
371 pub multisets: Option<Vec<f64>>,
375 pub num_subgraphs: i64,
377}
378
379#[derive(Debug, Clone)]
381pub struct NablaDressResult {
382 pub histogram: Vec<HistogramEntry>,
384 pub multisets: Option<Vec<f64>>,
387 pub num_tuples: i64,
389}
390
391impl fmt::Display for DeltaDressResult {
392 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
393 let total: i64 = self.histogram.iter().map(|entry| entry.count).sum();
394 write!(
395 f,
396 "DeltaDressResult(histogram_entries={}, total_values={})",
397 self.histogram.len(), total,
398 )
399 }
400}
401
402impl fmt::Display for NablaDressResult {
403 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
404 let total: i64 = self.histogram.iter().map(|entry| entry.count).sum();
405 write!(
406 f,
407 "NablaDressResult(histogram_entries={}, total_values={})",
408 self.histogram.len(), total,
409 )
410 }
411}
412
/// Errors reported while constructing a `DRESS` graph.
#[derive(Debug)]
pub enum DressError {
    /// Input arrays have inconsistent lengths; the payload says which.
    LengthMismatch(String),
    /// `dress_init_graph` returned NULL.
    InitFailed,
}

impl fmt::Display for DressError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            DressError::LengthMismatch(detail) => write!(f, "length mismatch: {detail}"),
            DressError::InitFailed => f.write_str("dress_init_graph returned NULL"),
        }
    }
}

/// Lets `DressError` be boxed as `dyn Error` and propagated with `?`.
impl std::error::Error for DressError {}
432
433pub fn fit(
437 n: i32, sources: Vec<i32>, targets: Vec<i32>,
438 weights: Option<Vec<f64>>, vertex_weights: Option<Vec<f64>>,
439 variant: Variant, precompute: bool,
440 max_iterations: i32, epsilon: f64,
441) -> Result<DressResult, DressError> {
442 let mut g = DRESS::new(n, sources, targets, weights, vertex_weights, variant, precompute)?;
443 let (iterations, delta) = g.fit(max_iterations, epsilon);
444 let mut r = g.result();
445 r.iterations = iterations;
446 r.delta = delta;
447 Ok(r)
448}
449
450pub fn delta_fit(
452 n: i32, sources: Vec<i32>, targets: Vec<i32>,
453 weights: Option<Vec<f64>>, vertex_weights: Option<Vec<f64>>,
454 variant: Variant, precompute: bool,
455 k: i32, max_iterations: i32, epsilon: f64,
456 n_samples: i32, seed: u32,
457 keep_multisets: bool, compute_histogram: bool,
458) -> Result<DeltaDressResult, DressError> {
459 let g = DRESS::new(n, sources, targets, weights, vertex_weights, variant, precompute)?;
460 g.delta_fit(k, max_iterations, epsilon, n_samples, seed,
461 keep_multisets, compute_histogram)
462}
463
464pub fn nabla_fit(
466 n: i32, sources: Vec<i32>, targets: Vec<i32>,
467 weights: Option<Vec<f64>>, vertex_weights: Option<Vec<f64>>,
468 variant: Variant, precompute: bool,
469 k: i32, max_iterations: i32, epsilon: f64,
470 n_samples: i32, seed: u32,
471 keep_multisets: bool, compute_histogram: bool,
472) -> Result<NablaDressResult, DressError> {
473 let g = DRESS::new(n, sources, targets, weights, vertex_weights, variant, precompute)?;
474 g.nabla_fit(k, max_iterations, epsilon, n_samples, seed,
475 keep_multisets, compute_histogram)
476}
477
478pub(crate) unsafe fn libc_malloc_copy_i32(data: &[i32]) -> *mut c_int {
482 let bytes = data.len() * std::mem::size_of::<c_int>();
483 let ptr = libc::malloc(bytes) as *mut c_int;
484 assert!(!ptr.is_null(), "malloc failed");
485 std::ptr::copy_nonoverlapping(data.as_ptr(), ptr, data.len());
486 ptr
487}
488
489pub(crate) unsafe fn libc_malloc_copy_f64(data: &[f64]) -> *mut c_double {
491 let bytes = data.len() * std::mem::size_of::<c_double>();
492 let ptr = libc::malloc(bytes) as *mut c_double;
493 assert!(!ptr.is_null(), "malloc failed");
494 std::ptr::copy_nonoverlapping(data.as_ptr(), ptr, data.len());
495 ptr
496}
497
498pub(crate) unsafe fn histogram_from_raw(
499 data: *mut HistogramEntry,
500 hist_size: c_int,
501) -> Vec<HistogramEntry> {
502 if !data.is_null() && hist_size > 0 {
503 std::slice::from_raw_parts(data, hist_size as usize).to_vec()
504 } else {
505 vec![]
506 }
507}
508
/// Minimal hand-written binding to the C allocator.
///
/// Buffers passed to the native library are allocated here. Presumably this
/// is so they come from the same heap the library uses to free them — confirm.
/// NOTE(review): this local module shares its name with the `libc` crate; keep
/// that in mind if the crate is ever added as a dependency.
pub(crate) mod libc {
    extern "C" {
        /// C `malloc`: returns `size` bytes of uninitialized memory, or NULL.
        pub fn malloc(size: usize) -> *mut u8;
    }
}