Skip to main content

webgraph_algo/llp/
preds.rs

1/*
2 * SPDX-FileCopyrightText: 2024 Tommaso Fontana
3 * SPDX-FileCopyrightText: 2024 Sebastiano Vigna
4 *
5 * SPDX-License-Identifier: Apache-2.0 OR LGPL-2.1-or-later
6 */
7
8//! Predicates implementing stopping conditions.
9//!
10//! The implementation of [layered label propagation](super) requires a
11//! [predicate](Predicate) to stop the algorithm. This module provides a few
12//! such predicates: they evaluate to true if the updates should be stopped.
13//!
//! You can combine the predicates using the `and` and `or` methods of the
//! [`PredicateBooleanExt`](predicates::prelude::PredicateBooleanExt)
//! extension trait, which is implemented for every [`Predicate`].
16//!
17//! # Examples
18//! ```
19//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
20//! use predicates::prelude::*;
21//! use webgraph_algo::llp::preds::{MinGain, MaxUpdates};
22//!
23//! let mut predicate = MinGain::try_from(0.001)?.boxed();
24//! predicate = predicate.or(MaxUpdates::from(100)).boxed();
25//! #     Ok(())
26//! # }
27//! ```
28
29use anyhow::ensure;
30use predicates::{Predicate, reflection::PredicateReflection};
31use std::fmt::Display;
32
#[doc(hidden)]
/// This structure is passed to stopping predicates to provide the information
/// that is needed to evaluate them.
#[derive(Debug)]
pub struct PredParams {
    /// Number of nodes of the graph; [`MinModified`] and [`PercModified`]
    /// compare [`modified`](Self::modified) against a function of it.
    pub num_nodes: usize,
    /// Presumably the number of arcs of the graph; not read by any predicate
    /// in this module.
    pub num_arcs: u64,
    /// Gain of the objective function; compared against the threshold of
    /// [`MinGain`].
    pub gain: f64,
    /// Average improvement of the gain; compared against the threshold of
    /// [`MinAvgImprov`] (whose documentation describes it as averaged over a
    /// window of ten updates).
    pub avg_gain_impr: f64,
    /// Number of modified nodes (presumably during the last update — confirm
    /// with the caller); read by [`MinModified`] and [`PercModified`].
    pub modified: usize,
    /// Current update; [`MaxUpdates`] adds one to it before comparing with its
    /// limit, so it is presumably 0-based — TODO confirm with the caller.
    pub update: usize,
}
45
/// Stopping predicate that halts the updates for a given ɣ once a maximum
/// number of updates has been reached.
///
/// When no limit is specified,
/// [`DEFAULT_MAX_UPDATES`](Self::DEFAULT_MAX_UPDATES) is used.
#[derive(Debug, Clone)]
pub struct MaxUpdates {
    /// Upper bound on the number of updates.
    max_updates: usize,
}

impl MaxUpdates {
    /// Limit used when none is given (`usize::MAX`, i.e., effectively no
    /// limit).
    pub const DEFAULT_MAX_UPDATES: usize = usize::MAX;
}

impl From<Option<usize>> for MaxUpdates {
    /// Uses the given limit, or [`MaxUpdates::default`] when `None`.
    fn from(max_updates: Option<usize>) -> Self {
        max_updates.map_or_else(Self::default, |limit| MaxUpdates { max_updates: limit })
    }
}

impl From<usize> for MaxUpdates {
    /// Uses exactly the given limit.
    fn from(max_updates: usize) -> Self {
        MaxUpdates { max_updates }
    }
}

impl Default for MaxUpdates {
    /// Uses [`MaxUpdates::DEFAULT_MAX_UPDATES`] as the limit.
    fn default() -> Self {
        MaxUpdates {
            max_updates: Self::DEFAULT_MAX_UPDATES,
        }
    }
}

impl Display for MaxUpdates {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "(max updates: {})", self.max_updates)
    }
}
82
83impl PredicateReflection for MaxUpdates {}
84impl Predicate<PredParams> for MaxUpdates {
85    fn eval(&self, pred_params: &PredParams) -> bool {
86        pred_params.update + 1 >= self.max_updates
87    }
88}
89
/// Stop when the gain of the objective function is at most the given
/// threshold.
///
/// The [default threshold](Self::DEFAULT_THRESHOLD) matches the one used by
/// the Java implementation.
#[derive(Debug, Clone)]
pub struct MinGain {
    /// Gain at or below which the updates stop.
    threshold: f64,
}

impl MinGain {
    /// Gain threshold used when none is specified.
    pub const DEFAULT_THRESHOLD: f64 = 1e-3;
}
102
103impl TryFrom<Option<f64>> for MinGain {
104    type Error = anyhow::Error;
105    fn try_from(threshold: Option<f64>) -> anyhow::Result<Self> {
106        Ok(match threshold {
107            Some(threshold) => {
108                ensure!(!threshold.is_nan());
109                ensure!(threshold >= 0.0, "The threshold must be nonnegative");
110                MinGain { threshold }
111            }
112            None => Self::default(),
113        })
114    }
115}
116
117impl TryFrom<f64> for MinGain {
118    type Error = anyhow::Error;
119    fn try_from(threshold: f64) -> anyhow::Result<Self> {
120        Some(threshold).try_into()
121    }
122}
123
124impl Default for MinGain {
125    fn default() -> Self {
126        Self::try_from(Self::DEFAULT_THRESHOLD).unwrap()
127    }
128}
129
130impl Display for MinGain {
131    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
132        f.write_fmt(format_args!("(min gain: {})", self.threshold))
133    }
134}
135
136impl PredicateReflection for MinGain {}
137impl Predicate<PredParams> for MinGain {
138    fn eval(&self, pred_params: &PredParams) -> bool {
139        pred_params.gain <= self.threshold
140    }
141}
142
/// Stop when the average improvement of the gain of the objective function on
/// a window of ten updates is at most the given threshold.
///
/// This is a second-order version of [`MinGain`]. It is useful to avoid
/// running many iterations that do not improve the objective function
/// significantly.
#[derive(Debug, Clone)]
pub struct MinAvgImprov {
    /// Average gain improvement at or below which the updates stop.
    threshold: f64,
}

impl MinAvgImprov {
    /// Threshold used when none is specified.
    pub const DEFAULT_THRESHOLD: f64 = 1e-1;
}
157
158impl TryFrom<Option<f64>> for MinAvgImprov {
159    type Error = anyhow::Error;
160    fn try_from(threshold: Option<f64>) -> anyhow::Result<Self> {
161        Ok(match threshold {
162            Some(threshold) => {
163                ensure!(!threshold.is_nan());
164                MinAvgImprov { threshold }
165            }
166            None => Self::default(),
167        })
168    }
169}
170
171impl TryFrom<f64> for MinAvgImprov {
172    type Error = anyhow::Error;
173    fn try_from(threshold: f64) -> anyhow::Result<Self> {
174        Some(threshold).try_into()
175    }
176}
177
178impl Default for MinAvgImprov {
179    fn default() -> Self {
180        Self::try_from(Self::DEFAULT_THRESHOLD).unwrap()
181    }
182}
183
184impl Display for MinAvgImprov {
185    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
186        f.write_fmt(format_args!(
187            "(min avg gain improvement: {})",
188            self.threshold
189        ))
190    }
191}
192
193impl PredicateReflection for MinAvgImprov {}
194impl Predicate<PredParams> for MinAvgImprov {
195    fn eval(&self, pred_params: &PredParams) -> bool {
196        pred_params.avg_gain_impr <= self.threshold
197    }
198}
199
/// Stop when the number of modified nodes is at most the square root of the
/// number of nodes.
#[derive(Debug, Clone, Default)]
pub struct MinModified {}

impl Display for MinModified {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "(min modified: √n)")
    }
}
210
211impl PredicateReflection for MinModified {}
212impl Predicate<PredParams> for MinModified {
213    fn eval(&self, pred_params: &PredParams) -> bool {
214        (pred_params.modified as f64) <= (pred_params.num_nodes as f64).sqrt()
215    }
216}
217
218#[derive(Debug, Clone, Default)]
219/// Stop after the number of modified nodes falls below
220/// a specified percentage of the number of nodes.
221pub struct PercModified {
222    threshold: f64,
223}
224
225impl Display for PercModified {
226    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
227        f.write_fmt(format_args!("(min modified: {}%)", self.threshold * 100.0))
228    }
229}
230
231impl TryFrom<f64> for PercModified {
232    type Error = anyhow::Error;
233    fn try_from(threshold: f64) -> anyhow::Result<Self> {
234        ensure!(
235            threshold >= 0.0,
236            "The percent threshold must be nonnegative"
237        );
238        ensure!(
239            threshold <= 100.0,
240            "The percent threshold must be at most 100"
241        );
242        Ok(PercModified {
243            threshold: threshold / 100.0,
244        })
245    }
246}
247
248impl PredicateReflection for PercModified {}
249impl Predicate<PredParams> for PercModified {
250    fn eval(&self, pred_params: &PredParams) -> bool {
251        (pred_params.modified as f64) <= (pred_params.num_nodes as f64) * self.threshold
252    }
253}