1#[cfg(feature = "python")]
53mod python;
54
55#[cfg(feature = "cli")]
56pub mod cli;
57
58pub mod model;
59pub mod pipeline;
60
61use anyhow::Context;
62pub use model::MODEL_VERSION;
63pub use pipeline::{AggregationMethod, Classification, Prediction, UnifiedPrediction};
64
65use crate::model::{CHUNK_CLASSIFICATION_THRESHOLD, CLASSIFICATION_THRESHOLD, MODEL};
66
67pub struct Predictor {
87 threshold: f32,
88 agg_method: AggregationMethod,
89}
90
91impl Predictor {
92 #[must_use]
94 pub fn new() -> Self {
95 Self {
96 threshold: CLASSIFICATION_THRESHOLD,
97 agg_method: AggregationMethod::WeightedMean(CHUNK_CLASSIFICATION_THRESHOLD),
98 }
99 }
100
101 #[must_use]
107 pub fn with_threshold(mut self, threshold: f32) -> Self {
108 self.threshold = threshold;
109 self
110 }
111
112 #[must_use]
117 pub fn with_aggregation_method(mut self, method: AggregationMethod) -> Self {
118 self.agg_method = method;
119 self
120 }
121
122 #[must_use]
124 pub fn threshold(&self) -> f32 {
125 self.threshold
126 }
127
128 pub fn predict<T: AsRef<str>>(&self, text: T) -> anyhow::Result<UnifiedPrediction> {
132 pipeline::predict(&MODEL, text.as_ref(), self.agg_method)
133 .with_context(|| "Failed to predict probabilities for the given text")
134 }
135
136 pub fn predict_batch<T: AsRef<str> + Sync>(
140 &self,
141 texts: &[T],
142 ) -> anyhow::Result<Vec<UnifiedPrediction>> {
143 pipeline::predict_batch(&MODEL, texts, self.agg_method)
144 .with_context(|| "Failed to predict probabilities for the given texts")
145 }
146
147 pub fn classify<T: AsRef<str>>(&self, text: T) -> anyhow::Result<Classification> {
151 self.predict(text)
152 .map(|pred| pred.classification(self.threshold))
153 }
154
155 pub fn classify_batch<T: AsRef<str> + Sync>(
159 &self,
160 texts: &[T],
161 ) -> anyhow::Result<Vec<Classification>> {
162 self.predict_batch(texts).map(|preds| {
163 preds
164 .into_iter()
165 .map(|pred| pred.classification(self.threshold))
166 .collect()
167 })
168 }
169}
170
171impl Default for Predictor {
172 fn default() -> Self {
173 Self::new()
174 }
175}
176
#[cfg(test)]
mod tests {
    use super::*;

    /// Input shared by the single-text tests.
    const SAMPLE_TEXT: &str = "This is a test text";

    /// Every chunk prediction and the overall prediction must expose probabilities that
    /// lie in `[0, 1]` and sum to 1 within a small tolerance.
    #[test]
    fn test_predict_probabilities() {
        let unified = Predictor::new()
            .predict(SAMPLE_TEXT)
            .expect("Prediction should succeed");

        let chunks_valid = unified.chunk_predictions.iter().all(|chunk| {
            let human = chunk.human_probability();
            let ai = chunk.ai_probability();
            (0.0..=1.0).contains(&human)
                && (0.0..=1.0).contains(&ai)
                && (human + ai - 1.0).abs() < 0.001
        });
        assert!(chunks_valid);

        let human = unified.prediction.human_probability();
        let ai = unified.prediction.ai_probability();
        assert!((0.0..=1.0).contains(&human));
        assert!((0.0..=1.0).contains(&ai));
        assert!((human + ai - 1.0).abs() < 0.001);
    }

    /// Classification with the default threshold yields one of the two expected variants.
    #[test]
    fn test_predict_class() {
        let label = Predictor::new()
            .classify(SAMPLE_TEXT)
            .expect("Classification should succeed");

        assert!(matches!(label, Classification::Human | Classification::AI));
    }

    /// Classification still succeeds when a custom threshold is configured.
    #[test]
    fn test_predict_class_with_threshold() {
        let label = Predictor::new()
            .with_threshold(0.99)
            .classify(SAMPLE_TEXT)
            .expect("Classification should succeed");

        assert!(matches!(label, Classification::Human | Classification::AI));
    }

    /// Batch prediction returns one result per input, each with probabilities in `[0, 1]`.
    #[test]
    fn test_batch_predictions() {
        let texts = ["Text 1", "Text 2", "Text 3"];

        let results = Predictor::new()
            .predict_batch(&texts)
            .expect("Batch prediction should succeed");

        assert_eq!(results.len(), 3);
        for unified in &results {
            let human = unified.prediction.human_probability();
            let ai = unified.prediction.ai_probability();
            assert!((0.0..=1.0).contains(&human));
            assert!((0.0..=1.0).contains(&ai));
        }
    }

    /// The accessor reports the value set through the builder.
    #[test]
    fn test_threshold_accessor() {
        let configured = Predictor::new().with_threshold(0.75).threshold();
        assert!((configured - 0.75).abs() < f32::EPSILON);
    }

    /// A freshly constructed predictor uses the crate-level default threshold.
    #[test]
    fn test_default_threshold() {
        let initial = Predictor::new().threshold();
        assert!((initial - CLASSIFICATION_THRESHOLD).abs() < f32::EPSILON);
    }
}