use pyo3::prelude::*;
use crate::embedded::{self, RiskLevel};
/// Python-facing mirror of the core [`RiskLevel`] classification.
///
/// Exposed to Python as `RiskLevel`. `eq`/`eq_int` give value equality (and
/// comparison with the integer discriminant). `frozen` + `hash` make members
/// hashable. Without `hash`, defining `__eq__` leaves Python's `__hash__` as
/// `None`, so members could not be used as dict keys or set elements.
#[pyclass(name = "RiskLevel", eq, eq_int, frozen, hash)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum PyRiskLevel {
    Safe,
    Low,
    Medium,
    High,
    Critical,
}
#[pymethods]
impl PyRiskLevel {
fn __repr__(&self) -> &'static str {
match self {
PyRiskLevel::Safe => "RiskLevel.Safe",
PyRiskLevel::Low => "RiskLevel.Low",
PyRiskLevel::Medium => "RiskLevel.Medium",
PyRiskLevel::High => "RiskLevel.High",
PyRiskLevel::Critical => "RiskLevel.Critical",
}
}
fn __str__(&self) -> &'static str {
match self {
PyRiskLevel::Safe => "Safe",
PyRiskLevel::Low => "Low",
PyRiskLevel::Medium => "Medium",
PyRiskLevel::High => "High",
PyRiskLevel::Critical => "Critical",
}
}
}
/// Map the core [`RiskLevel`] onto the Python-exposed enum, variant for variant.
impl From<RiskLevel> for PyRiskLevel {
    fn from(level: RiskLevel) -> Self {
        match level {
            RiskLevel::Safe => Self::Safe,
            RiskLevel::Low => Self::Low,
            RiskLevel::Medium => Self::Medium,
            RiskLevel::High => Self::High,
            RiskLevel::Critical => Self::Critical,
        }
    }
}
/// Outcome of running the embedded detector on one input, as exposed to Python.
///
/// Exposed as `DetectionResult`; `get_all` generates a getter for every field
/// (no setters are generated, so the attributes are read-only from Python).
#[pyclass(name = "DetectionResult", get_all)]
#[derive(Clone)]
pub struct PyDetectionResult {
    /// Detector's injection verdict for the input.
    pub is_injection: bool,
    /// Detector score, copied from `embedded::DetectionOutput::score`.
    pub score: f32,
    /// Detector confidence, copied from `embedded::DetectionOutput::confidence`.
    pub confidence: f32,
    /// Risk bucket, converted from the core `RiskLevel`.
    pub risk: PyRiskLevel,
}
#[pymethods]
impl PyDetectionResult {
    /// Debug-style representation with score and confidence rounded to 4 decimals.
    fn __repr__(&self) -> String {
        // Spell the flag the way Python prints booleans, not Rust's `true`/`false`.
        let flag = match self.is_injection {
            true => "True",
            false => "False",
        };
        let risk = self.risk.__repr__();
        format!(
            "DetectionResult(is_injection={}, score={:.4}, confidence={:.4}, risk={})",
            flag, self.score, self.confidence, risk,
        )
    }

    /// Truthiness mirrors the injection verdict, so `if detect(text):` reads naturally.
    fn __bool__(&self) -> bool {
        self.is_injection
    }
}
impl From<embedded::DetectionOutput> for PyDetectionResult {
fn from(r: embedded::DetectionOutput) -> Self {
Self {
is_injection: r.is_injection,
score: r.score,
confidence: r.confidence,
risk: r.risk.into(),
}
}
}
#[pyfunction]
fn detect(text: &str) -> PyDetectionResult {
embedded::detect(text).into()
}
#[pyfunction]
fn is_injection(text: &str) -> bool {
embedded::is_injection(text)
}
#[pyfunction]
fn score(text: &str) -> f32 {
embedded::score(text)
}
/// Run the embedded detector over every string in `texts`.
///
/// Delegates to `embedded::detect_batch` and converts each output to a
/// `DetectionResult`.
#[pyfunction]
fn detect_batch(texts: Vec<String>) -> Vec<PyDetectionResult> {
    // The embedded API takes borrowed strings; borrow each owned Python string.
    let borrowed = texts.iter().map(|t| t.as_str()).collect::<Vec<&str>>();
    let outputs = embedded::detect_batch(&borrowed);
    outputs.into_iter().map(PyDetectionResult::from).collect()
}
#[pyfunction]
fn download_model() -> PyResult<String> {
crate::model_manager::download_model()
.map(|p| p.display().to_string())
.map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))
}
/// Report the model cache directory as a string.
///
/// Resolution order: `$JAILGUARD_MODEL_DIR` verbatim, then
/// `$HOME/.cache/jailguard`, then the literal `~/.cache/jailguard`.
///
/// NOTE(review): the final fallback is returned unexpanded; Python file APIs
/// do not expand `~`, so callers would need `os.path.expanduser`.
/// NOTE(review): confirm this stays in sync with the directory
/// `crate::model_manager` actually downloads into.
#[pyfunction]
fn model_cache_dir() -> String {
    std::env::var("JAILGUARD_MODEL_DIR")
        .or_else(|_| std::env::var("HOME").map(|home| format!("{home}/.cache/jailguard")))
        .unwrap_or_else(|_| "~/.cache/jailguard".to_string())
}
/// Initializer for the `_jailguard` extension module.
///
/// Publishes `__version__` (the crate's Cargo version), the `RiskLevel` and
/// `DetectionResult` classes, and the detection / model-management functions.
#[pymodule]
fn _jailguard(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add("__version__", env!("CARGO_PKG_VERSION"))?;

    m.add_class::<PyRiskLevel>()?;
    m.add_class::<PyDetectionResult>()?;

    let functions = [
        wrap_pyfunction!(detect, m)?,
        wrap_pyfunction!(is_injection, m)?,
        wrap_pyfunction!(score, m)?,
        wrap_pyfunction!(detect_batch, m)?,
        wrap_pyfunction!(download_model, m)?,
        wrap_pyfunction!(model_cache_dir, m)?,
    ];
    for function in functions {
        m.add_function(function)?;
    }
    Ok(())
}