Skip to main content

oxicuda_recsys/
lib.rs

//! `oxicuda-recsys` — Recommender system primitives for OxiCUDA.
//!
//! Pure-Rust implementation of collaborative filtering, neural recommendation models,
//! and graph-based recommenders, suitable for CPU simulation and PTX kernel generation
//! for GPU execution.
//!
//! # Architecture
//!
//! ```text
//! oxicuda-recsys
//! ├── factorization/  — ALS, BPR, NMF matrix factorization; FFM, FISM, iALS
//! ├── ncf/            — Neural Collaborative Filtering
//! ├── two_tower/      — Two-Tower retrieval model
//! ├── deepfm/         — DeepFM, Wide & Deep, and DCN models
//! ├── dlrm/           — DLRM model (`Dlrm`, `DlrmConfig`)
//! ├── fibinet/        — FiBiNET model (`Fibinet`, `FibinetConfig`, `BilinearType`)
//! ├── sequential/     — SASRec self-attention sequential model; CL4SRec
//! ├── graph_recsys/   — LightGCN graph-based recommendation; GraphRec
//! ├── multitask/      — MMoE / PLE multi-task learning
//! ├── ranking/        — Fairness-aware ranking (`FairnessRanker`)
//! ├── sampling/       — Uniform negative sampling
//! ├── metrics/        — NDCG@k, Precision@k ranking metrics
//! ├── handle          — LcgRng (deterministic PRNG)
//! ├── error           — RecSysError / RecSysResult
//! └── ptx_kernels     — GPU PTX kernel strings (7 kernels × 6 SM versions)
//! ```
24
25pub mod error;
26pub mod handle;
27pub mod ptx_kernels;
28
29pub mod deepfm;
30pub mod dlrm;
31pub mod factorization;
32pub mod fibinet;
33pub mod graph_recsys;
34pub mod metrics;
35pub mod multitask;
36pub mod ncf;
37pub mod ranking;
38pub mod sampling;
39pub mod sequential;
40pub mod two_tower;
41
42pub use crate::deepfm::dcn::{CrossKind, Dcn, DcnConfig};
43pub use crate::dlrm::{Dlrm, DlrmConfig};
44pub use crate::factorization::ffm::{Ffm, FfmConfig, FfmEntry};
45pub use crate::factorization::fism::{Fism, FismConfig};
46pub use crate::factorization::ials::{Ials, IalsConfig};
47pub use crate::fibinet::{BilinearType, Fibinet, FibinetConfig};
48pub use crate::graph_recsys::graphrec::GraphRec;
49pub use crate::ranking::fairness_ranking::{FairnessRanker, FairnessRankerConfig};
50pub use crate::sequential::cl4srec::{Cl4sRec, Cl4sRecConfig};
51
// End-to-end smoke tests.
//
// Each test builds one model or utility from a fixed `LcgRng` seed, drives it
// through its public API once, and checks a coarse property of the result
// (finite value, value inside `[0, 1]`, expected length, ...). They are meant
// to catch wiring mistakes, not to validate model quality.
#[cfg(test)]
mod e2e_tests {
    use std::collections::{BTreeSet, HashSet};

    use crate::{
        deepfm::{deepfm::DeepFm, wide_deep::WideDeep},
        factorization::{als::Als, bpr::Bpr, nmf::Nmf},
        graph_recsys::lightgcn::LightGcn,
        handle::LcgRng,
        metrics::recsys_metrics::{ndcg_at_k, precision_at_k},
        ncf::ncf::Ncf,
        ptx_kernels,
        sampling::uniform_neg::UniformNegSampler,
        sequential::sasrec::SasRec,
        two_tower::two_tower::TwoTower,
    };

    /// Draws `len` consecutive `f32` values from `rng`, in order.
    fn random_features(rng: &mut LcgRng, len: usize) -> Vec<f32> {
        (0..len).map(|_| rng.next_f32()).collect()
    }

    /// A fitted ALS model must return a finite score for a (user, item) pair.
    #[test]
    fn als_score_finite() {
        let mut prng = LcgRng::new(42);
        let mut als = Als::new(5, 8, 4, 0.01, &mut prng).expect("new should succeed");
        let ratings = vec![(0, 0, 1.0), (0, 2, 2.0), (1, 1, 1.0), (2, 3, 3.0)];
        als.fit(&ratings, 2).expect("fit should succeed");

        let value = als.score(0, 0).expect("score should succeed");
        assert!(value.is_finite(), "ALS score must be finite, got {s}", s = value);
    }

    /// One BPR training step over three triplets must report a finite loss.
    #[test]
    fn bpr_step_loss_finite() {
        let mut prng = LcgRng::new(7);
        let mut bpr = Bpr::new(5, 10, 4, 0.01, 0.001, &mut prng).expect("new should succeed");
        let batch = vec![(0, 1, 3), (1, 2, 4), (2, 0, 5)];

        let loss = bpr.train_step(&batch);
        assert!(loss.is_finite(), "BPR loss must be finite, got {loss}");
    }

    /// `Nmf::fit` must succeed on a tiny data set and yield a finite score.
    #[test]
    fn nmf_fit_no_error() {
        let mut prng = LcgRng::new(13);
        let observations = vec![(0, 0, 1.0), (0, 1, 2.0), (1, 2, 3.0), (2, 0, 1.5)];
        let nmf = Nmf::fit(&observations, 3, 4, 3, 5, &mut prng).expect("fit should succeed");

        let value = nmf.score(0, 0).expect("score should succeed");
        assert!(value.is_finite(), "NMF score must be finite, got {s}", s = value);
    }

    /// The NCF forward pass must produce a value inside `[0, 1]`.
    #[test]
    fn ncf_forward_in_0_1() {
        let mut prng = LcgRng::new(99);
        let ncf = Ncf::new(10, 20, 8, vec![16, 8], &mut prng).expect("new should succeed");

        let pred = ncf.forward(0, 0).expect("forward should succeed");
        let in_unit_interval = (0.0..=1.0).contains(&pred);
        assert!(in_unit_interval, "NCF output {pred} not in [0,1]");
    }

    /// Scoring two random 8-element feature vectors must give a finite value.
    #[test]
    fn two_tower_score_finite() {
        let mut prng = LcgRng::new(11);
        let towers = TwoTower::new(8, 16, 8, 2, &mut prng).expect("new should succeed");
        // Draw order matters for reproducibility: user features first, then item features.
        let user_features = random_features(&mut prng, 8);
        let item_features = random_features(&mut prng, 8);

        let score = towers
            .score(&user_features, &item_features)
            .expect("score should succeed");
        assert!(score.is_finite(), "TwoTower score must be finite, got {score}");
    }

    /// The DeepFM forward pass must produce a value inside `[0, 1]`.
    #[test]
    fn deepfm_forward_in_0_1() {
        let mut prng = LcgRng::new(55);
        let deepfm = DeepFm::new(vec![10, 20, 5], 8, &[32, 16], &mut prng)
            .expect("new should succeed");
        let active_ids = vec![0usize, 3, 2];

        let pred = deepfm.forward(&active_ids).expect("forward should succeed");
        let in_unit_interval = (0.0..=1.0).contains(&pred);
        assert!(in_unit_interval, "DeepFM output {pred} not in [0,1]");
    }

    /// The Wide & Deep forward pass must produce a value inside `[0, 1]`.
    #[test]
    fn wide_deep_forward_in_0_1() {
        let mut prng = LcgRng::new(33);
        let wide_deep = WideDeep::new(16, &[32, 16], &mut prng).expect("new should succeed");
        let features = random_features(&mut prng, 16);

        let pred = wide_deep.forward(&features).expect("forward should succeed");
        let in_unit_interval = (0.0..=1.0).contains(&pred);
        assert!(in_unit_interval, "WideDeep output {pred} not in [0,1]");
    }

    /// SASRec must return 50 logits for a short history, all of them finite.
    #[test]
    fn sasrec_forward_finite() {
        let mut prng = LcgRng::new(77);
        let sasrec = SasRec::new(50, 16, 2, 2, 20, &mut prng).expect("new should succeed");
        let history = vec![0usize, 5, 12, 3];

        let logits = sasrec.forward(&history).expect("forward should succeed");
        assert_eq!(logits.len(), 50);
        let has_non_finite = logits.iter().any(|logit| !logit.is_finite());
        assert!(!has_non_finite, "SASRec logits must all be finite");
    }

    /// After propagating over a small edge list, LightGCN scores must be finite.
    #[test]
    fn lightgcn_score_finite() {
        let mut prng = LcgRng::new(21);
        let mut gcn = LightGcn::new(5, 8, 8, 2, &mut prng).expect("new should succeed");
        let edge_list = vec![(0, 0), (0, 1), (1, 2), (2, 3), (3, 4)];
        gcn.propagate(&edge_list).expect("propagate should succeed");

        let value = gcn.score(0, 1);
        assert!(
            value.is_finite(),
            "LightGCN score must be finite, got {s}",
            s = value
        );
    }

    /// With the five relevant items ranked first, NDCG@10 and Precision@5 must both be 1.0.
    #[test]
    fn ndcg_perfect_ranking() {
        let ranked: Vec<usize> = (0..10).collect();
        let relevant_items: HashSet<usize> = (0..5).collect();

        // Compute both metrics before asserting on either of them.
        let ndcg = ndcg_at_k(&ranked, &relevant_items, 10);
        let prec = precision_at_k(&ranked, &relevant_items, 5);

        let ndcg_error = (ndcg - 1.0).abs();
        assert!(
            ndcg_error < 1e-5,
            "NDCG for perfect ranking must be 1.0, got {ndcg}"
        );
        let prec_error = (prec - 1.0).abs();
        assert!(
            prec_error < 1e-5,
            "Precision@5 for perfect ranking must be 1.0, got {prec}"
        );
    }

    /// Fifty draws from the uniform negative sampler must never return a positive item.
    #[test]
    fn uniform_neg_not_in_positives() {
        let mut prng = LcgRng::new(3);
        let sampler = UniformNegSampler::new(100).expect("new should succeed");
        let positives = BTreeSet::from([0usize, 1, 2]);

        for _draw in 0..50 {
            let neg = sampler
                .sample(0, &positives, &mut prng)
                .expect("sample should succeed");
            let is_positive = positives.contains(&neg);
            assert!(
                !is_positive,
                "Sampled negative {neg} must not be in positives"
            );
        }
    }

    /// Every PTX generator must emit a non-empty string mentioning the requested
    /// `sm_<version>` for each SM version listed in the test.
    #[test]
    fn ptx_kernels_non_empty_all_sm() {
        // The generators are plain functions, so function pointers suffice.
        type KernelFn = fn(u32) -> String;

        let sm_versions: [u32; 6] = [75, 80, 86, 89, 90, 100];
        let generators: [(&str, KernelFn); 7] = [
            ("als_step", ptx_kernels::als_step_ptx),
            ("bpr_grad", ptx_kernels::bpr_grad_ptx),
            ("embedding_lookup", ptx_kernels::embedding_lookup_ptx),
            ("dot_score", ptx_kernels::dot_score_ptx),
            ("softmax_topk", ptx_kernels::softmax_topk_ptx),
            ("negsample_uniform", ptx_kernels::negsample_uniform_ptx),
            ("lightgcn_propagate", ptx_kernels::lightgcn_propagate_ptx),
        ];

        for &sm in &sm_versions {
            // The expected target tag depends only on the SM version.
            let target_tag = format!("sm_{sm}");
            for &(name, generate) in &generators {
                let ptx = generate(sm);
                assert!(
                    !ptx.is_empty(),
                    "PTX kernel {name} for SM {sm} must be non-empty"
                );
                assert!(
                    ptx.contains(target_tag.as_str()),
                    "PTX kernel {name} for SM {sm} must contain .target sm_{sm}"
                );
            }
        }
    }
}