Skip to main content

dress_graph/
lib.rs

//! # dress-graph
//!
//! Safe Rust bindings for the **DRESS** C library — A Continuous Framework for Structural Graph Refinement. See the [DRESS repository](https://github.com/velicat/dress-graph) for more information.
//!
//! ```no_run
//! use dress_graph::{DRESS, Variant};
//!
//! let sources = vec![0, 1, 2, 0];
//! let targets = vec![1, 2, 3, 3];
//!
//! let mut g = DRESS::new(4, sources, targets,
//!                        None, None, Variant::Undirected, false).unwrap();
//! let (iters, delta) = g.fit(100, 1e-6);
//! let result = g.result();
//!
//! println!("iterations: {}, final delta: {:e}", iters, delta);
//! for (i, d) in result.edge_dress.iter().enumerate() {
//!     println!("  edge {}: dress = {:.6}", i, d);
//! }
//! ```
21
22use std::ffi::c_void;
23use std::fmt;
24
25#[cfg(feature = "cuda")]
26pub mod cuda;
27
28#[cfg(feature = "omp")]
29pub mod omp;
30
31#[cfg(feature = "mpi")]
32pub mod mpi;
33
34// ── FFI declarations ────────────────────────────────────────────────
35
// C scalar types used by the FFI signatures in this file. They are named after
// their C counterparts, hence the `non_camel_case_types` allowance. They alias
// `std::os::raw` so they track the target's real C ABI instead of hard-coding
// a width (identical to i32/u32/f64 on every mainstream target).
#[allow(non_camel_case_types)]
type c_int = std::os::raw::c_int;
#[allow(non_camel_case_types)]
type c_uint = std::os::raw::c_uint;
#[allow(non_camel_case_types)]
type c_double = std::os::raw::c_double;
42
43#[allow(dead_code)]
44mod ffi {
45    use super::*;
46    extern "C" {
47        pub(crate) fn dress_init_graph(
48            n: c_int, e: c_int,
49            u: *mut c_int, v: *mut c_int,
50            w: *mut c_double, nw: *mut c_double,
51            variant: c_int, precompute_intercepts: c_int,
52        ) -> *mut c_void;
53
54        pub(crate) fn dress_fit(
55            g: *mut c_void, max_iterations: c_int, epsilon: c_double,
56            iterations: *mut c_int, delta: *mut c_double,
57        );
58
59        pub(crate) fn dress_free_graph(g: *mut c_void);
60
61        pub(crate) fn dress_get(
62            g: *mut c_void, u: c_int, v: c_int,
63            max_iterations: c_int, epsilon: c_double, edge_weight: c_double,
64        ) -> c_double;
65
66        pub(crate) fn dress_delta_fit_strided(
67            g: *mut c_void, k: c_int, iterations: c_int, epsilon: c_double,
68            n_samples: c_int, seed: c_uint,
69            hist_size: *mut c_int, keep_multisets: c_int,
70            multisets: *mut *mut c_double, num_subgraphs: *mut i64,
71            offset: c_int, stride: c_int,
72        ) -> *mut HistogramEntry;
73
74        pub(crate) fn dress_nabla_fit(
75            g: *mut c_void, k: c_int, iterations: c_int, epsilon: c_double,
76            n_samples: c_int, seed: c_uint,
77            hist_size: *mut c_int, keep_multisets: c_int,
78            multisets: *mut *mut c_double, num_tuples: *mut i64,
79        ) -> *mut HistogramEntry;
80    }
81}
82
83// ── Public types ────────────────────────────────────────────────────
84
/// Graph variant: selects how vertex neighbourhoods are built from the edges.
///
/// Each discriminant is the integer code forwarded to `dress_init_graph`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(i32)]
pub enum Variant {
    /// Variant code `0`.
    Undirected = 0,
    /// Variant code `1`.
    Directed = 1,
    /// Variant code `2`.
    Forward = 2,
    /// Variant code `3`.
    Backward = 3,
}
94
/// Snapshot of a DRESS graph's per-edge and per-vertex values after fitting.
#[derive(Debug, Clone)]
pub struct DressResult {
    /// Source endpoint of each edge, in input order.
    pub sources: Vec<i32>,
    /// Target endpoint of each edge, in input order.
    pub targets: Vec<i32>,
    /// Per-edge weight as held by the C graph.
    pub edge_weight: Vec<f64>,
    /// Per-edge DRESS value.
    pub edge_dress: Vec<f64>,
    /// Per-vertex DRESS value.
    pub vertex_dress: Vec<f64>,
    /// Per-vertex weights; `None` when the C graph holds no vertex-weight array.
    pub vertex_weights: Option<Vec<f64>>,
    /// Iteration count of the fit, when the producer records it.
    pub iterations: i32,
    /// Delta reported by the fit, when the producer records it.
    pub delta: f64,
}
107
/// One bucket of the exact sparse histogram produced by Δ^k-/∇^k-DRESS:
/// a distinct value together with its multiplicity.
///
/// The layout is `#[repr(C)]` because the C library returns arrays of this type
/// directly.
#[derive(Debug, Clone, Copy, PartialEq)]
#[repr(C)]
pub struct HistogramEntry {
    /// The observed value.
    pub value: f64,
    /// How many times `value` was observed.
    pub count: i64,
}
115
116impl fmt::Display for DressResult {
117    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
118        write!(
119            f,
120            "DressResult(E={}, iterations={}, delta={:.6e})",
121            self.sources.len(),
122            self.iterations,
123            self.delta,
124        )
125    }
126}
127
128// ── Persistent graph object ─────────────────────────────────────────
129
130/// A persistent DRESS graph that supports repeated `fit` and `get` calls.
131///
132/// The underlying C graph is freed automatically when dropped.
133///
134/// ```no_run
135/// use dress_graph::{DRESS, Variant};
136///
137/// let mut g = DRESS::new(4, vec![0,1,2,0], vec![1,2,3,3],
138///                        None, None, Variant::Undirected, false)?;
139/// g.fit(100, 1e-6);
140/// let d = g.get(0, 2, 100, 1e-6, 1.0);
141/// ```
142pub struct DRESS {
143    g: *mut c_void,
144    n: i32,
145    e: usize,
146    sources: Vec<i32>,
147    targets: Vec<i32>,
148}
149
150impl DRESS {
151    /// Construct a persistent DRESS graph.
152    pub fn new(
153        n: i32,
154        sources: Vec<i32>,
155        targets: Vec<i32>,
156        weights: Option<Vec<f64>>,
157        vertex_weights: Option<Vec<f64>>,
158        variant: Variant,
159        precompute_intercepts: bool,
160    ) -> Result<DRESS, DressError> {
161        let e = sources.len();
162        if targets.len() != e {
163            return Err(DressError::LengthMismatch(
164                "sources and targets must have equal length".into(),
165            ));
166        }
167        unsafe {
168            let u_ptr = libc_malloc_copy_i32(&sources);
169            let v_ptr = libc_malloc_copy_i32(&targets);
170            let w_ptr = weights.as_ref().map_or(std::ptr::null_mut(), |w| libc_malloc_copy_f64(w));
171            let nw_ptr = vertex_weights.as_ref().map_or(std::ptr::null_mut(), |nw| libc_malloc_copy_f64(nw));
172            let g = ffi::dress_init_graph(n, e as c_int, u_ptr, v_ptr, w_ptr, nw_ptr,
173                                     variant as c_int, precompute_intercepts as c_int);
174            if g.is_null() {
175                return Err(DressError::InitFailed);
176            }
177            Ok(DRESS { g, n, e, sources, targets })
178        }
179    }
180
181    /// Fit the DRESS model.  Returns `(iterations, delta)`.
182    pub fn fit(&mut self, max_iterations: i32, epsilon: f64) -> (i32, f64) {
183        assert!(!self.g.is_null(), "DRESS already closed");
184        let mut iterations: c_int = 0;
185        let mut delta: c_double = 0.0;
186        unsafe {
187            ffi::dress_fit(self.g, max_iterations, epsilon, &mut iterations, &mut delta);
188        }
189        (iterations, delta)
190    }
191
192    /// Query the DRESS value for an edge (existing or virtual).
193    pub fn get(&self, u: i32, v: i32, max_iterations: i32, epsilon: f64, edge_weight: f64) -> f64 {
194        assert!(!self.g.is_null(), "DRESS already closed");
195        unsafe { ffi::dress_get(self.g, u, v, max_iterations, epsilon, edge_weight) }
196    }
197
198    /// Extract a snapshot of the current results without freeing.
199    pub fn result(&self) -> DressResult {
200        assert!(!self.g.is_null(), "DRESS already closed");
201        let e = self.e;
202        let n = self.n as usize;
203        unsafe {
204            let base = self.g as *const u8;
205            let ew_ptr = *(base.add(72) as *const *const f64);
206            let ed_ptr = *(base.add(80) as *const *const f64);
207            let nd_ptr = *(base.add(96) as *const *const f64);
208            let nw_ptr = *(base.add(104) as *const *const f64);
209
210            let vertex_weights = if !nw_ptr.is_null() {
211                Some(std::slice::from_raw_parts(nw_ptr, n).to_vec())
212            } else {
213                None
214            };
215
216            DressResult {
217                sources:     self.sources.clone(),
218                targets:     self.targets.clone(),
219                edge_weight: std::slice::from_raw_parts(ew_ptr, e).to_vec(),
220                edge_dress:  std::slice::from_raw_parts(ed_ptr, e).to_vec(),
221                vertex_dress:  std::slice::from_raw_parts(nd_ptr, n).to_vec(),
222                vertex_weights,
223                iterations:  0,
224                delta:       0.0,
225            }
226        }
227    }
228
229    /// Run Δ^k-DRESS on the persistent graph: enumerate all C(N,k)
230    /// vertex-deletion subsets, fit DRESS on each subgraph, and return the
231    /// pooled histogram.
232    pub fn delta_fit(
233        &self,
234        k: i32,
235        max_iterations: i32,
236        epsilon: f64,
237        n_samples: i32,
238        seed: u32,
239        keep_multisets: bool,
240        compute_histogram: bool,
241    ) -> Result<DeltaDressResult, DressError> {
242        assert!(!self.g.is_null(), "DRESS already closed");
243        let e = self.e;
244
245        unsafe {
246            let mut hsize: c_int = 0;
247            let mut ms_ptr: *mut c_double = std::ptr::null_mut();
248            let mut num_sub: i64 = 0;
249            let h = ffi::dress_delta_fit_strided(
250                self.g,
251                k,
252                max_iterations,
253                epsilon,
254                n_samples,
255                seed,
256                if compute_histogram { &mut hsize } else { std::ptr::null_mut() },
257                if keep_multisets { 1 } else { 0 },
258                if keep_multisets { &mut ms_ptr } else { std::ptr::null_mut() },
259                &mut num_sub,
260                0,
261                1,
262            );
263
264            let histogram = histogram_from_raw(h, hsize);
265
266            extern "C" { fn free(ptr: *mut std::ffi::c_void); }
267
268            let multisets = if keep_multisets && !ms_ptr.is_null() && num_sub > 0 {
269                let len = (num_sub as usize) * e;
270                let ms = std::slice::from_raw_parts(ms_ptr, len).to_vec();
271                free(ms_ptr as *mut std::ffi::c_void);
272                Some(ms)
273            } else {
274                if keep_multisets && !ms_ptr.is_null() {
275                    free(ms_ptr as *mut std::ffi::c_void);
276                }
277                None
278            };
279
280            if !h.is_null() {
281                free(h as *mut std::ffi::c_void);
282            }
283
284            Ok(DeltaDressResult {
285                histogram,
286                multisets,
287                num_subgraphs: num_sub,
288            })
289        }
290    }
291
292    /// Run ∇^k-DRESS on the persistent graph.
293    pub fn nabla_fit(
294        &self,
295        k: i32,
296        max_iterations: i32,
297        epsilon: f64,
298        n_samples: i32,
299        seed: u32,
300        keep_multisets: bool,
301        compute_histogram: bool,
302    ) -> Result<NablaDressResult, DressError> {
303        assert!(!self.g.is_null(), "DRESS already closed");
304        let e = self.e;
305
306        unsafe {
307            let mut hsize: c_int = 0;
308            let mut ms_ptr: *mut c_double = std::ptr::null_mut();
309            let mut num_tup: i64 = 0;
310            let h = ffi::dress_nabla_fit(
311                self.g,
312                k,
313                max_iterations,
314                epsilon,
315                n_samples,
316                seed,
317                if compute_histogram { &mut hsize } else { std::ptr::null_mut() },
318                if keep_multisets { 1 } else { 0 },
319                if keep_multisets { &mut ms_ptr } else { std::ptr::null_mut() },
320                &mut num_tup,
321            );
322
323            let histogram = histogram_from_raw(h, hsize);
324
325            extern "C" { fn free(ptr: *mut std::ffi::c_void); }
326
327            let multisets = if keep_multisets && !ms_ptr.is_null() && num_tup > 0 {
328                let len = (num_tup as usize) * e;
329                let ms = std::slice::from_raw_parts(ms_ptr, len).to_vec();
330                free(ms_ptr as *mut std::ffi::c_void);
331                Some(ms)
332            } else {
333                if keep_multisets && !ms_ptr.is_null() {
334                    free(ms_ptr as *mut std::ffi::c_void);
335                }
336                None
337            };
338
339            if !h.is_null() {
340                free(h as *mut std::ffi::c_void);
341            }
342
343            Ok(NablaDressResult {
344                histogram,
345                multisets,
346                num_tuples: num_tup,
347            })
348        }
349    }
350
351    /// Explicitly free the underlying C graph.
352    pub fn close(&mut self) {
353        if !self.g.is_null() {
354            unsafe { ffi::dress_free_graph(self.g); }
355            self.g = std::ptr::null_mut();
356        }
357    }
358}
359
360impl Drop for DRESS {
361    fn drop(&mut self) {
362        self.close();
363    }
364}
365
366/// Result of the Δ^k-DRESS fitting procedure.
367#[derive(Debug, Clone)]
368pub struct DeltaDressResult {
369    /// Sorted exact histogram entries as `(value, count)` pairs.
370    pub histogram: Vec<HistogramEntry>,
371    /// Per-subgraph edge values, row-major C(N,k) × E.
372    /// `NaN` marks edges removed in a given subgraph.
373    /// `None` when `keep_multisets` is `false`.
374    pub multisets: Option<Vec<f64>>,
375    /// Number of subgraphs: C(N,k).
376    pub num_subgraphs: i64,
377}
378
379/// Result of the ∇^k-DRESS fitting procedure.
380#[derive(Debug, Clone)]
381pub struct NablaDressResult {
382    /// Sorted exact histogram entries as `(value, count)` pairs.
383    pub histogram: Vec<HistogramEntry>,
384    /// Per-tuple edge values, row-major.
385    /// `None` when `keep_multisets` is `false`.
386    pub multisets: Option<Vec<f64>>,
387    /// Number of tuples.
388    pub num_tuples: i64,
389}
390
391impl fmt::Display for DeltaDressResult {
392    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
393        let total: i64 = self.histogram.iter().map(|entry| entry.count).sum();
394        write!(
395            f,
396            "DeltaDressResult(histogram_entries={}, total_values={})",
397            self.histogram.len(), total,
398        )
399    }
400}
401
402impl fmt::Display for NablaDressResult {
403    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
404        let total: i64 = self.histogram.iter().map(|entry| entry.count).sum();
405        write!(
406            f,
407            "NablaDressResult(histogram_entries={}, total_values={})",
408            self.histogram.len(), total,
409        )
410    }
411}
412
/// Errors that can occur when building or fitting a DRESS graph.
#[derive(Debug)]
pub enum DressError {
    /// Invalid input sizes; the payload is a human-readable description.
    LengthMismatch(String),
    /// `dress_init_graph` returned a null pointer.
    InitFailed,
}

impl fmt::Display for DressError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            DressError::InitFailed => f.write_str("dress_init_graph returned NULL"),
            DressError::LengthMismatch(ref msg) => write!(f, "length mismatch: {}", msg),
        }
    }
}

/// Lets `DressError` flow through `?` into `Box<dyn Error>` and similar.
impl std::error::Error for DressError {}
432
433// ── One-shot free functions ─────────────────────────────────────────
434
435/// One-shot DRESS fit: build graph, fit, return results, free graph.
436pub fn fit(
437    n: i32, sources: Vec<i32>, targets: Vec<i32>,
438    weights: Option<Vec<f64>>, vertex_weights: Option<Vec<f64>>,
439    variant: Variant, precompute: bool,
440    max_iterations: i32, epsilon: f64,
441) -> Result<DressResult, DressError> {
442    let mut g = DRESS::new(n, sources, targets, weights, vertex_weights, variant, precompute)?;
443    let (iterations, delta) = g.fit(max_iterations, epsilon);
444    let mut r = g.result();
445    r.iterations = iterations;
446    r.delta = delta;
447    Ok(r)
448}
449
450/// One-shot Δ^k-DRESS: build graph, run delta fit, return results, free graph.
451pub fn delta_fit(
452    n: i32, sources: Vec<i32>, targets: Vec<i32>,
453    weights: Option<Vec<f64>>, vertex_weights: Option<Vec<f64>>,
454    variant: Variant, precompute: bool,
455    k: i32, max_iterations: i32, epsilon: f64,
456    n_samples: i32, seed: u32,
457    keep_multisets: bool, compute_histogram: bool,
458) -> Result<DeltaDressResult, DressError> {
459    let g = DRESS::new(n, sources, targets, weights, vertex_weights, variant, precompute)?;
460    g.delta_fit(k, max_iterations, epsilon, n_samples, seed,
461                keep_multisets, compute_histogram)
462}
463
464/// One-shot ∇^k-DRESS: build graph, run nabla fit, return results, free graph.
465pub fn nabla_fit(
466    n: i32, sources: Vec<i32>, targets: Vec<i32>,
467    weights: Option<Vec<f64>>, vertex_weights: Option<Vec<f64>>,
468    variant: Variant, precompute: bool,
469    k: i32, max_iterations: i32, epsilon: f64,
470    n_samples: i32, seed: u32,
471    keep_multisets: bool, compute_histogram: bool,
472) -> Result<NablaDressResult, DressError> {
473    let g = DRESS::new(n, sources, targets, weights, vertex_weights, variant, precompute)?;
474    g.nabla_fit(k, max_iterations, epsilon, n_samples, seed,
475                keep_multisets, compute_histogram)
476}
477
478// ── Internal helpers ────────────────────────────────────────────────
479
480/// Allocate a C-compatible (malloc'd) copy of an i32 slice.
481pub(crate) unsafe fn libc_malloc_copy_i32(data: &[i32]) -> *mut c_int {
482    let bytes = data.len() * std::mem::size_of::<c_int>();
483    let ptr = libc::malloc(bytes) as *mut c_int;
484    assert!(!ptr.is_null(), "malloc failed");
485    std::ptr::copy_nonoverlapping(data.as_ptr(), ptr, data.len());
486    ptr
487}
488
489/// Allocate a C-compatible (malloc'd) copy of an f64 slice.
490pub(crate) unsafe fn libc_malloc_copy_f64(data: &[f64]) -> *mut c_double {
491    let bytes = data.len() * std::mem::size_of::<c_double>();
492    let ptr = libc::malloc(bytes) as *mut c_double;
493    assert!(!ptr.is_null(), "malloc failed");
494    std::ptr::copy_nonoverlapping(data.as_ptr(), ptr, data.len());
495    ptr
496}
497
498pub(crate) unsafe fn histogram_from_raw(
499    data: *mut HistogramEntry,
500    hist_size: c_int,
501) -> Vec<HistogramEntry> {
502    if !data.is_null() && hist_size > 0 {
503        std::slice::from_raw_parts(data, hist_size as usize).to_vec()
504    } else {
505        vec![]
506    }
507}
508
// Minimal hand-written binding to the C allocator, used by the
// `libc_malloc_copy_*` helpers, among others. It is a direct `extern "C"`
// declaration; the crate does not depend on the `libc` crate.
pub(crate) mod libc {
    extern "C" {
        /// C `malloc`: returns `size` uninitialized bytes, or null on failure.
        /// The result may also be null when `size == 0`.
        pub fn malloc(size: usize) -> *mut u8;
    }
}