1pub mod error;
26pub mod handle;
27pub mod ptx_kernels;
28
29pub mod deepfm;
30pub mod dlrm;
31pub mod factorization;
32pub mod fibinet;
33pub mod graph_recsys;
34pub mod metrics;
35pub mod multitask;
36pub mod ncf;
37pub mod ranking;
38pub mod sampling;
39pub mod sequential;
40pub mod two_tower;
41
42pub use crate::deepfm::dcn::{CrossKind, Dcn, DcnConfig};
43pub use crate::dlrm::{Dlrm, DlrmConfig};
44pub use crate::factorization::ffm::{Ffm, FfmConfig, FfmEntry};
45pub use crate::factorization::fism::{Fism, FismConfig};
46pub use crate::factorization::ials::{Ials, IalsConfig};
47pub use crate::fibinet::{BilinearType, Fibinet, FibinetConfig};
48pub use crate::graph_recsys::graphrec::GraphRec;
49pub use crate::ranking::fairness_ranking::{FairnessRanker, FairnessRankerConfig};
50pub use crate::sequential::cl4srec::{Cl4sRec, Cl4sRecConfig};
51
#[cfg(test)]
mod e2e_tests {
    //! Crate-level smoke tests.
    //!
    //! Each test builds one model or utility, usually from an `LcgRng` with a
    //! fixed seed. It then runs a single fit / forward / score step on a tiny
    //! hand-written input and checks that the result is finite or lies in the
    //! expected range. These tests catch NaN/Inf and wiring regressions; they
    //! do not measure model quality.

    use std::collections::{BTreeSet, HashSet};

    use crate::{
        deepfm::{deepfm::DeepFm, wide_deep::WideDeep},
        factorization::{als::Als, bpr::Bpr, nmf::Nmf},
        graph_recsys::lightgcn::LightGcn,
        handle::LcgRng,
        metrics::recsys_metrics::{ndcg_at_k, precision_at_k},
        ncf::ncf::Ncf,
        ptx_kernels,
        sampling::uniform_neg::UniformNegSampler,
        sequential::sasrec::SasRec,
        two_tower::two_tower::TwoTower,
    };

    /// Fitting ALS on a few weighted interactions and scoring the observed
    /// pair (0, 0) must yield a finite value.
    #[test]
    fn als_score_finite() {
        let mut rng = LcgRng::new(42);
        let mut model = Als::new(5, 8, 4, 0.01, &mut rng).expect("new should succeed");
        // Presumably (user, item, value) triples — confirm against `Als::fit`.
        let interactions = vec![(0, 0, 1.0), (0, 2, 2.0), (1, 1, 1.0), (2, 3, 3.0)];
        model.fit(&interactions, 2).expect("fit should succeed");
        let s = model.score(0, 0).expect("score should succeed");
        assert!(s.is_finite(), "ALS score must be finite, got {s}");
    }

    /// A single BPR training step over three triplets must report a finite loss.
    #[test]
    fn bpr_step_loss_finite() {
        let mut rng = LcgRng::new(7);
        let mut model = Bpr::new(5, 10, 4, 0.01, 0.001, &mut rng).expect("new should succeed");
        // Presumably (user, positive item, negative item) — confirm against `Bpr::train_step`.
        let triplets = vec![(0, 1, 3), (1, 2, 4), (2, 0, 5)];
        let loss = model.train_step(&triplets);
        assert!(loss.is_finite(), "BPR loss must be finite, got {loss}");
    }

    /// NMF fitted on four observed entries must produce a finite score.
    #[test]
    fn nmf_fit_no_error() {
        let mut rng = LcgRng::new(13);
        let data = vec![(0, 0, 1.0), (0, 1, 2.0), (1, 2, 3.0), (2, 0, 1.5)];
        let model = Nmf::fit(&data, 3, 4, 3, 5, &mut rng).expect("fit should succeed");
        let s = model.score(0, 0).expect("score should succeed");
        assert!(s.is_finite(), "NMF score must be finite, got {s}");
    }

    /// NCF predictions must lie in the closed unit interval.
    #[test]
    fn ncf_forward_in_0_1() {
        let mut rng = LcgRng::new(99);
        let model = Ncf::new(10, 20, 8, vec![16, 8], &mut rng).expect("new should succeed");
        let pred = model.forward(0, 0).expect("forward should succeed");
        assert!(
            (0.0..=1.0).contains(&pred),
            "NCF output {pred} not in [0,1]"
        );
    }

    /// Scoring two 8-dimensional feature vectors through the two towers must be finite.
    #[test]
    fn two_tower_score_finite() {
        let mut rng = LcgRng::new(11);
        let model = TwoTower::new(8, 16, 8, 2, &mut rng).expect("new should succeed");
        // The user vector is drawn before the item vector, so the RNG stream
        // (and therefore the inputs) match across runs.
        let user_x: Vec<f32> = (0..8).map(|_| rng.next_f32()).collect();
        let item_x: Vec<f32> = (0..8).map(|_| rng.next_f32()).collect();
        let score = model.score(&user_x, &item_x).expect("score should succeed");
        assert!(
            score.is_finite(),
            "TwoTower score must be finite, got {score}"
        );
    }

    /// DeepFM over three categorical fields must output a value in [0, 1].
    /// Each id in `field_ids` is below its field's cardinality in `field_dims`.
    #[test]
    fn deepfm_forward_in_0_1() {
        let mut rng = LcgRng::new(55);
        let field_dims = vec![10, 20, 5];
        let model = DeepFm::new(field_dims, 8, &[32, 16], &mut rng).expect("new should succeed");
        let field_ids = vec![0usize, 3, 2];
        let pred = model.forward(&field_ids).expect("forward should succeed");
        assert!(
            (0.0..=1.0).contains(&pred),
            "DeepFM output {pred} not in [0,1]"
        );
    }

    /// Wide & Deep over a 16-dimensional dense input must output a value in [0, 1].
    #[test]
    fn wide_deep_forward_in_0_1() {
        let mut rng = LcgRng::new(33);
        let model = WideDeep::new(16, &[32, 16], &mut rng).expect("new should succeed");
        let x: Vec<f32> = (0..16).map(|_| rng.next_f32()).collect();
        let pred = model.forward(&x).expect("forward should succeed");
        assert!(
            (0.0..=1.0).contains(&pred),
            "WideDeep output {pred} not in [0,1]"
        );
    }

    /// SASRec over a four-item history must return 50 logits (matching the
    /// first constructor argument), all finite.
    #[test]
    fn sasrec_forward_finite() {
        let mut rng = LcgRng::new(77);
        let model = SasRec::new(50, 16, 2, 2, 20, &mut rng).expect("new should succeed");
        let item_ids = vec![0usize, 5, 12, 3];
        let logits = model.forward(&item_ids).expect("forward should succeed");
        assert_eq!(logits.len(), 50);
        assert!(
            logits.iter().all(|v| v.is_finite()),
            "SASRec logits must all be finite"
        );
    }

    /// After propagation over a small edge list, LightGCN scores must be finite.
    #[test]
    fn lightgcn_score_finite() {
        let mut rng = LcgRng::new(21);
        let mut model = LightGcn::new(5, 8, 8, 2, &mut rng).expect("new should succeed");
        // Presumably (user, item) pairs — confirm against `LightGcn::propagate`.
        let edges = vec![(0, 0), (0, 1), (1, 2), (2, 3), (3, 4)];
        model.propagate(&edges).expect("propagate should succeed");
        let s = model.score(0, 1);
        assert!(s.is_finite(), "LightGCN score must be finite, got {s}");
    }

    /// The top five recommendations are exactly the relevant set, so the
    /// ranking is perfect: NDCG@10 and Precision@5 must both equal 1.
    #[test]
    fn ndcg_perfect_ranking() {
        let recommended: Vec<usize> = (0..10).collect();
        let relevant: HashSet<usize> = (0..5).collect();
        let ndcg = ndcg_at_k(&recommended, &relevant, 10);
        let prec = precision_at_k(&recommended, &relevant, 5);
        assert!(
            (ndcg - 1.0).abs() < 1e-5,
            "NDCG for perfect ranking must be 1.0, got {ndcg}"
        );
        assert!(
            (prec - 1.0).abs() < 1e-5,
            "Precision@5 for perfect ranking must be 1.0, got {prec}"
        );
    }

    /// Over 50 draws, the uniform negative sampler must never return one of
    /// the user's positive items.
    #[test]
    fn uniform_neg_not_in_positives() {
        let mut rng = LcgRng::new(3);
        let sampler = UniformNegSampler::new(100).expect("new should succeed");
        let positives = BTreeSet::from([0usize, 1, 2]);
        for _ in 0..50 {
            let neg = sampler
                .sample(0, &positives, &mut rng)
                .expect("sample should succeed");
            assert!(
                !positives.contains(&neg),
                "Sampled negative {neg} must not be in positives"
            );
        }
    }

    /// For every supported SM version, each PTX generator must emit non-empty
    /// text containing a `.target` directive that names that architecture.
    #[test]
    fn ptx_kernels_non_empty_all_sm() {
        const SM_VERSIONS: [u32; 6] = [75, 80, 86, 89, 90, 100];
        // The generators are free functions, so plain fn pointers suffice.
        // Boxing them as trait objects would add allocations and dynamic
        // dispatch for no benefit.
        let generators: [(&str, fn(u32) -> String); 7] = [
            ("als_step", ptx_kernels::als_step_ptx),
            ("bpr_grad", ptx_kernels::bpr_grad_ptx),
            ("embedding_lookup", ptx_kernels::embedding_lookup_ptx),
            ("dot_score", ptx_kernels::dot_score_ptx),
            ("softmax_topk", ptx_kernels::softmax_topk_ptx),
            ("negsample_uniform", ptx_kernels::negsample_uniform_ptx),
            ("lightgcn_propagate", ptx_kernels::lightgcn_propagate_ptx),
        ];
        for sm in SM_VERSIONS {
            let target = format!("sm_{sm}");
            for &(name, generate) in generators.iter() {
                let ptx = generate(sm);
                assert!(
                    !ptx.is_empty(),
                    "PTX kernel {name} for SM {sm} must be non-empty"
                );
                // Look for the architecture on the `.target` line itself rather
                // than anywhere in the text, so a stray `sm_XX` in a comment or
                // symbol name cannot satisfy the check. Substring matching
                // still accepts arch-specific suffixes such as `sm_90a`.
                let has_target = ptx.lines().any(|line| {
                    line.trim_start().starts_with(".target") && line.contains(&target)
                });
                assert!(
                    has_target,
                    "PTX kernel {name} for SM {sm} must contain .target sm_{sm}"
                );
            }
        }
    }
}