rustkernel_graph/
messages.rs

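//! Ring message types for Graph Analytics kernels.
//!
//! This module defines request/response message types for GPU-native
//! persistent actor communication, plus the batch input/output types used by
//! the centrality, community-detection, similarity, and motif kernels.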
5
6use crate::motif::MotifResult;
7use crate::types::{CentralityResult, CommunityResult, CsrGraph, SimilarityResult};
8use rustkernel_core::messages::CorrelationId;
9use serde::{Deserialize, Serialize};
10
11// ============================================================================
12// PageRank Messages
13// ============================================================================
14
15/// PageRank operation type.
16#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
17pub enum PageRankOp {
18    /// Query the current PageRank score for a node.
19    Query {
20        /// Node to query.
21        node_id: u64,
22    },
23    /// Perform one iteration of PageRank.
24    Iterate,
25    /// Initialize with a new graph.
26    Initialize,
27    /// Reset all scores to initial values.
28    Reset,
29    /// Run until convergence with threshold.
30    ConvergeUntil {
31        /// Convergence threshold.
32        threshold: f64,
33        /// Maximum iterations.
34        max_iterations: u32,
35    },
36}
37
38/// PageRank request message.
39#[derive(Debug, Clone, Serialize, Deserialize)]
40pub struct PageRankRequest {
41    /// Correlation ID for request-response pairing.
42    pub correlation_id: CorrelationId,
43    /// The operation to perform.
44    pub operation: PageRankOp,
45    /// Graph data (for Initialize operation).
46    pub graph: Option<CsrGraph>,
47    /// Damping factor (default: 0.85).
48    pub damping: Option<f32>,
49}
50
51impl PageRankRequest {
52    /// Create a query request for a specific node.
53    pub fn query(node_id: u64) -> Self {
54        Self {
55            correlation_id: CorrelationId::new(),
56            operation: PageRankOp::Query { node_id },
57            graph: None,
58            damping: None,
59        }
60    }
61
62    /// Create an iterate request.
63    pub fn iterate() -> Self {
64        Self {
65            correlation_id: CorrelationId::new(),
66            operation: PageRankOp::Iterate,
67            graph: None,
68            damping: None,
69        }
70    }
71
72    /// Create an initialize request with graph data.
73    pub fn initialize(graph: CsrGraph, damping: f32) -> Self {
74        Self {
75            correlation_id: CorrelationId::new(),
76            operation: PageRankOp::Initialize,
77            graph: Some(graph),
78            damping: Some(damping),
79        }
80    }
81
82    /// Create a converge request.
83    pub fn converge(threshold: f64, max_iterations: u32) -> Self {
84        Self {
85            correlation_id: CorrelationId::new(),
86            operation: PageRankOp::ConvergeUntil {
87                threshold,
88                max_iterations,
89            },
90            graph: None,
91            damping: None,
92        }
93    }
94}
95
96/// PageRank response message.
97#[derive(Debug, Clone, Serialize, Deserialize)]
98pub struct PageRankResponse {
99    /// Correlation ID matching the request.
100    pub correlation_id: CorrelationId,
101    /// Score for the queried node (Query operation).
102    pub score: Option<f64>,
103    /// Whether the algorithm has converged.
104    pub converged: bool,
105    /// Current iteration count.
106    pub iteration: u32,
107    /// Full centrality result (for converged operations).
108    pub result: Option<CentralityResult>,
109    /// Error message if operation failed.
110    pub error: Option<String>,
111}
112
113impl PageRankResponse {
114    /// Create a successful query response.
115    pub fn score(correlation_id: CorrelationId, score: f64, iteration: u32) -> Self {
116        Self {
117            correlation_id,
118            score: Some(score),
119            converged: false,
120            iteration,
121            result: None,
122            error: None,
123        }
124    }
125
126    /// Create a convergence response with full result.
127    pub fn converged(
128        correlation_id: CorrelationId,
129        result: CentralityResult,
130        iteration: u32,
131    ) -> Self {
132        Self {
133            correlation_id,
134            score: None,
135            converged: true,
136            iteration,
137            result: Some(result),
138            error: None,
139        }
140    }
141
142    /// Create an error response.
143    pub fn error(correlation_id: CorrelationId, error: impl Into<String>) -> Self {
144        Self {
145            correlation_id,
146            score: None,
147            converged: false,
148            iteration: 0,
149            result: None,
150            error: Some(error.into()),
151        }
152    }
153}
154
155// ============================================================================
156// Centrality Batch Input/Output Types
157// ============================================================================
158
159/// Input for batch centrality computation.
160#[derive(Debug, Clone, Serialize, Deserialize)]
161pub struct CentralityInput {
162    /// The graph to analyze.
163    pub graph: CsrGraph,
164    /// Whether to normalize the results.
165    pub normalize: bool,
166    /// Maximum iterations (for iterative algorithms).
167    pub max_iterations: Option<u32>,
168    /// Convergence tolerance (for iterative algorithms).
169    pub tolerance: Option<f64>,
170    /// Algorithm-specific parameters.
171    pub params: CentralityParams,
172}
173
174/// Algorithm-specific parameters for centrality computation.
175#[derive(Debug, Clone, Serialize, Deserialize)]
176pub enum CentralityParams {
177    /// PageRank parameters.
178    PageRank {
179        /// Damping factor (typically 0.85).
180        damping: f32,
181    },
182    /// Degree centrality (no extra params).
183    Degree,
184    /// Betweenness centrality (no extra params).
185    Betweenness,
186    /// Closeness centrality parameters.
187    Closeness {
188        /// Use harmonic centrality variant.
189        harmonic: bool,
190    },
191    /// Eigenvector centrality (no extra params).
192    Eigenvector,
193    /// Katz centrality parameters.
194    Katz {
195        /// Attenuation factor.
196        alpha: f64,
197        /// Bias term.
198        beta: f64,
199    },
200}
201
202impl Default for CentralityInput {
203    fn default() -> Self {
204        Self {
205            graph: CsrGraph::empty(),
206            normalize: true,
207            max_iterations: Some(100),
208            tolerance: Some(1e-6),
209            params: CentralityParams::Degree,
210        }
211    }
212}
213
214/// Output from batch centrality computation.
215#[derive(Debug, Clone, Serialize, Deserialize)]
216pub struct CentralityOutput {
217    /// The centrality result.
218    pub result: CentralityResult,
219    /// Computation time in microseconds.
220    pub compute_time_us: u64,
221}
222
223// ============================================================================
224// Community Detection Messages
225// ============================================================================
226
227/// Community detection algorithm.
228#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
229pub enum CommunityAlgorithm {
230    /// Louvain algorithm (modularity optimization).
231    Louvain,
232    /// Modularity score calculation.
233    ModularityScore,
234}
235
236/// Input for community detection.
237#[derive(Debug, Clone, Serialize, Deserialize)]
238pub struct CommunityInput {
239    /// The graph to analyze.
240    pub graph: CsrGraph,
241    /// Algorithm to use.
242    pub algorithm: CommunityAlgorithm,
243    /// Resolution parameter for Louvain (default: 1.0).
244    pub resolution: f64,
245}
246
247/// Output from community detection.
248#[derive(Debug, Clone, Serialize, Deserialize)]
249pub struct CommunityOutput {
250    /// The community detection result.
251    pub result: CommunityResult,
252    /// Computation time in microseconds.
253    pub compute_time_us: u64,
254}
255
256// ============================================================================
257// Similarity Messages
258// ============================================================================
259
260/// Similarity metric type.
261#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
262pub enum SimilarityMetric {
263    /// Jaccard similarity coefficient.
264    Jaccard,
265    /// Cosine similarity.
266    Cosine,
267    /// Adamic-Adar index.
268    AdamicAdar,
269}
270
271/// Input for similarity computation.
272#[derive(Debug, Clone, Serialize, Deserialize)]
273pub struct SimilarityInput {
274    /// The graph to analyze.
275    pub graph: CsrGraph,
276    /// Similarity metric to use.
277    pub metric: SimilarityMetric,
278    /// Node pairs to compute similarity for (if None, compute all pairs).
279    pub node_pairs: Option<Vec<(u64, u64)>>,
280}
281
282/// Output from similarity computation.
283#[derive(Debug, Clone, Serialize, Deserialize)]
284pub struct SimilarityOutput {
285    /// The similarity result.
286    pub result: SimilarityResult,
287    /// Computation time in microseconds.
288    pub compute_time_us: u64,
289}
290
291// ============================================================================
292// Motif Detection Messages
293// ============================================================================
294
295/// Input for motif/triangle detection.
296#[derive(Debug, Clone, Serialize, Deserialize)]
297pub struct MotifInput {
298    /// The graph to analyze.
299    pub graph: CsrGraph,
300    /// Motif size (3 for triangles).
301    pub motif_size: usize,
302    /// Whether to enumerate all instances.
303    pub enumerate: bool,
304}
305
306/// Output from motif detection.
307#[derive(Debug, Clone, Serialize, Deserialize)]
308pub struct MotifOutput {
309    /// The motif detection result.
310    pub result: MotifResult,
311    /// Computation time in microseconds.
312    pub compute_time_us: u64,
313}
314
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pagerank_request_query() {
        let req = PageRankRequest::query(42);
        assert!(matches!(req.operation, PageRankOp::Query { node_id: 42 }));
        assert!(req.graph.is_none());
        assert!(req.damping.is_none());
    }

    #[test]
    fn test_pagerank_request_iterate() {
        let req = PageRankRequest::iterate();
        assert!(matches!(req.operation, PageRankOp::Iterate));
        assert!(req.graph.is_none());
        assert!(req.damping.is_none());
    }

    #[test]
    fn test_pagerank_request_initialize() {
        let req = PageRankRequest::initialize(CsrGraph::empty(), 0.85);
        assert!(matches!(req.operation, PageRankOp::Initialize));
        assert!(req.graph.is_some());
        assert_eq!(req.damping, Some(0.85));
    }

    #[test]
    fn test_pagerank_request_converge() {
        let req = PageRankRequest::converge(1e-6, 100);
        // Check both fields, not only `max_iterations`.
        match req.operation {
            PageRankOp::ConvergeUntil {
                threshold,
                max_iterations,
            } => {
                assert!((threshold - 1e-6).abs() < f64::EPSILON);
                assert_eq!(max_iterations, 100);
            }
            other => panic!("expected ConvergeUntil, got {:?}", other),
        }
        assert!(req.graph.is_none());
        assert!(req.damping.is_none());
    }

    #[test]
    fn test_pagerank_response_score() {
        let cid = CorrelationId::new();
        let resp = PageRankResponse::score(cid, 0.5, 10);
        assert_eq!(resp.score, Some(0.5));
        assert_eq!(resp.iteration, 10);
        assert!(!resp.converged);
        assert!(resp.result.is_none());
        assert!(resp.error.is_none());
    }

    #[test]
    fn test_pagerank_response_error() {
        let resp = PageRankResponse::error(CorrelationId::new(), "boom");
        assert_eq!(resp.error.as_deref(), Some("boom"));
        assert!(resp.score.is_none());
        assert!(resp.result.is_none());
        assert!(!resp.converged);
        assert_eq!(resp.iteration, 0);
    }

    #[test]
    fn test_centrality_input_default() {
        let input = CentralityInput::default();
        assert!(input.normalize);
        assert_eq!(input.max_iterations, Some(100));
        assert_eq!(input.tolerance, Some(1e-6));
        assert!(matches!(input.params, CentralityParams::Degree));
    }
}