Skip to main content

is_it_slop/
lib.rs

1//! # is-it-slop
2//!
3//! A fast and accurate AI text detector built with Rust.
4//!
5//! Supports multiple platforms including macOS (Apple Silicon), Linux, and Windows.
6//!
//! This crate provides tools to classify whether text was written by AI or a human,
//! using a machine-learning model built on TF-IDF features, with inference run by ONNX Runtime.
9//!
10//! ## Quick Start
11//!
12//! ```rust
13//! use is_it_slop::Predictor;
14//!
15//! let predictor = Predictor::new();
16//!
17//! // Get raw probabilities
18//! let result = predictor.predict("Some text to analyze")?;
19//! println!(
20//!     "AI probability: {:.2}%",
21//!     result.prediction.ai_probability() * 100.0
22//! );
23//!
24//! // Or get a classification directly
25//! let class = predictor.classify("Some text to analyze")?;
26//! println!("Classification: {:?}", class);
27//! # Ok::<(), anyhow::Error>(())
28//! ```
29//!
30//! ## Custom Threshold
31//!
32//! ```rust
33//! use is_it_slop::Predictor;
34//!
//! // Use a custom threshold (the default is `is_it_slop::model::CLASSIFICATION_THRESHOLD`)
36//! let predictor = Predictor::new().with_threshold(0.7);
37//! let class = predictor.classify("Some text")?;
38//! # Ok::<(), anyhow::Error>(())
39//! ```
40//!
41//! ## Batch Processing
42//!
43//! ```rust
44//! use is_it_slop::Predictor;
45//!
46//! let predictor = Predictor::new();
47//! let texts = vec!["First text", "Second text", "Third text"];
48//! let predictions = predictor.predict_batch(&texts)?;
49//! # Ok::<(), anyhow::Error>(())
50//! ```
51
52#[cfg(feature = "python")]
53mod python;
54
55#[cfg(feature = "cli")]
56pub mod cli;
57
58pub mod model;
59pub mod pipeline;
60
61use anyhow::Context;
62pub use model::MODEL_VERSION;
63pub use pipeline::{AggregationMethod, Classification, Prediction, UnifiedPrediction};
64
65use crate::model::{CHUNK_CLASSIFICATION_THRESHOLD, CLASSIFICATION_THRESHOLD, MODEL};
66
67/// Builder struct for configuring and running predictions.
68///
69/// Use `Predictor::new()` to create with default threshold, or chain
70/// `.with_threshold()` to customize.
71///
72/// # Examples
73///
74/// ```rust
75/// use is_it_slop::Predictor;
76///
77/// // Use default threshold
78/// let predictor = Predictor::new();
79/// let prediction = predictor.predict("some text")?;
80///
81/// // Custom threshold
82/// let predictor = Predictor::new().with_threshold(0.7);
83/// let class = predictor.classify("some text")?;
84/// # Ok::<(), anyhow::Error>(())
85/// ```
86pub struct Predictor {
87    threshold: f32,
88    agg_method: AggregationMethod,
89}
90
91impl Predictor {
92    /// Create a new predictor with the default classification threshold.
93    #[must_use]
94    pub fn new() -> Self {
95        Self {
96            threshold: CLASSIFICATION_THRESHOLD,
97            agg_method: AggregationMethod::WeightedMean(CHUNK_CLASSIFICATION_THRESHOLD),
98        }
99    }
100
101    /// Set a custom classification threshold.
102    ///
103    /// The threshold determines the AI probability cutoff for classification:
104    /// - If P(AI) >= threshold: classified as AI
105    /// - If P(AI) < threshold: classified as Human
106    #[must_use]
107    pub fn with_threshold(mut self, threshold: f32) -> Self {
108        self.threshold = threshold;
109        self
110    }
111
112    /// Set a custom aggregation method for chunk predictions.
113    ///
114    /// The aggregation method determines how chunk-level predictions
115    /// are combined into a final document-level prediction.
116    #[must_use]
117    pub fn with_aggregation_method(mut self, method: AggregationMethod) -> Self {
118        self.agg_method = method;
119        self
120    }
121
122    /// Get the current threshold value.
123    #[must_use]
124    pub fn threshold(&self) -> f32 {
125        self.threshold
126    }
127
128    /// Predict probabilities for a single text.
129    ///
130    /// Returns a `Prediction` containing P(Human) and P(AI).
131    pub fn predict<T: AsRef<str>>(&self, text: T) -> anyhow::Result<UnifiedPrediction> {
132        pipeline::predict(&MODEL, text.as_ref(), self.agg_method)
133            .with_context(|| "Failed to predict probabilities for the given text")
134    }
135
136    /// Predict probabilities for multiple texts.
137    ///
138    /// Returns a vector of `Prediction` values, one for each input text.
139    pub fn predict_batch<T: AsRef<str> + Sync>(
140        &self,
141        texts: &[T],
142    ) -> anyhow::Result<Vec<UnifiedPrediction>> {
143        pipeline::predict_batch(&MODEL, texts, self.agg_method)
144            .with_context(|| "Failed to predict probabilities for the given texts")
145    }
146
147    /// Classify a single text using the configured threshold.
148    ///
149    /// Returns `Classification::Human` or `Classification::AI`.
150    pub fn classify<T: AsRef<str>>(&self, text: T) -> anyhow::Result<Classification> {
151        self.predict(text)
152            .map(|pred| pred.classification(self.threshold))
153    }
154
155    /// Classify multiple texts using the configured threshold.
156    ///
157    /// Returns a vector of classifications, one for each input text.
158    pub fn classify_batch<T: AsRef<str> + Sync>(
159        &self,
160        texts: &[T],
161    ) -> anyhow::Result<Vec<Classification>> {
162        self.predict_batch(texts).map(|preds| {
163            preds
164                .into_iter()
165                .map(|pred| pred.classification(self.threshold))
166                .collect()
167        })
168    }
169}
170
171impl Default for Predictor {
172    fn default() -> Self {
173        Self::new()
174    }
175}
176
#[cfg(test)]
mod tests {
    use std::mem::discriminant;

    use super::*;

    const SAMPLE: &str = "This is a test text";

    /// Asserts that a (P(Human), P(AI)) pair is a valid two-class distribution:
    /// both values lie in `[0, 1]` and they sum to 1 within a small tolerance.
    ///
    /// Takes `f64` so callers can pass probabilities through `f64::from`, whatever
    /// float width the prediction types use.
    fn assert_valid_pair(human: f64, ai: f64) {
        assert!((0.0..=1.0).contains(&human), "P(Human) out of range: {human}");
        assert!((0.0..=1.0).contains(&ai), "P(AI) out of range: {ai}");
        assert!(
            (human + ai - 1.0).abs() < 0.001,
            "probabilities do not sum to 1: {human} + {ai}"
        );
    }

    #[test]
    fn test_predict_probabilities() {
        let predictor = Predictor::new();
        let prediction = predictor
            .predict(SAMPLE)
            .expect("Prediction should succeed");

        for chunk in prediction.chunk_predictions.iter() {
            assert_valid_pair(
                f64::from(chunk.human_probability()),
                f64::from(chunk.ai_probability()),
            );
        }
        assert_valid_pair(
            f64::from(prediction.prediction.human_probability()),
            f64::from(prediction.prediction.ai_probability()),
        );
    }

    #[test]
    fn test_predict_class() {
        let predictor = Predictor::new();
        let class = predictor
            .classify(SAMPLE)
            .expect("Classification should succeed");

        assert!(matches!(class, Classification::Human | Classification::AI));
    }

    #[test]
    fn test_predict_class_with_threshold() {
        // `classify` must agree with applying the configured threshold to `predict`'s
        // output. The old `Human | AI` check held for any threshold, so it never
        // exercised the threshold at all.
        for threshold in [0.0_f32, 0.5, 0.99, 1.0] {
            let predictor = Predictor::new().with_threshold(threshold);
            let class = predictor
                .classify(SAMPLE)
                .expect("Classification should succeed");
            let expected = predictor
                .predict(SAMPLE)
                .expect("Prediction should succeed")
                .classification(threshold);

            // `discriminant` compares variants without requiring `PartialEq`.
            assert_eq!(
                discriminant(&class),
                discriminant(&expected),
                "classify disagrees with predict at threshold {threshold}"
            );
        }
    }

    #[test]
    fn test_batch_predictions() {
        let predictor = Predictor::new();
        let texts = vec!["Text 1", "Text 2", "Text 3"];

        let predictions = predictor
            .predict_batch(&texts)
            .expect("Batch prediction should succeed");

        assert_eq!(predictions.len(), texts.len());
        for pred in &predictions {
            assert_valid_pair(
                f64::from(pred.prediction.human_probability()),
                f64::from(pred.prediction.ai_probability()),
            );
        }
    }

    #[test]
    fn test_classify_batch() {
        let predictor = Predictor::new().with_threshold(0.5);
        let texts = vec!["Text 1", "Text 2", "Text 3"];

        let classes = predictor
            .classify_batch(&texts)
            .expect("Batch classification should succeed");
        let predictions = predictor
            .predict_batch(&texts)
            .expect("Batch prediction should succeed");

        assert_eq!(classes.len(), texts.len());
        for (class, pred) in classes.iter().zip(predictions) {
            let expected = pred.classification(predictor.threshold());
            assert_eq!(discriminant(class), discriminant(&expected));
        }
    }

    #[test]
    fn test_with_aggregation_method() {
        // Explicitly selecting the default aggregation method must give the same
        // document-level result as the default predictor.
        let default = Predictor::new()
            .predict(SAMPLE)
            .expect("Prediction should succeed");
        let explicit = Predictor::new()
            .with_aggregation_method(AggregationMethod::WeightedMean(
                CHUNK_CLASSIFICATION_THRESHOLD,
            ))
            .predict(SAMPLE)
            .expect("Prediction should succeed");

        let diff = (f64::from(default.prediction.ai_probability())
            - f64::from(explicit.prediction.ai_probability()))
        .abs();
        assert!(diff < 1e-6, "aggregation results differ by {diff}");
    }

    #[test]
    fn test_threshold_accessor() {
        let predictor = Predictor::new().with_threshold(0.75);
        assert!((predictor.threshold() - 0.75).abs() < f32::EPSILON);
    }

    #[test]
    fn test_default_threshold() {
        let predictor = Predictor::new();
        assert!((predictor.threshold() - CLASSIFICATION_THRESHOLD).abs() < f32::EPSILON);

        let defaulted = Predictor::default();
        assert!((defaulted.threshold() - CLASSIFICATION_THRESHOLD).abs() < f32::EPSILON);
    }
}