use crate::Result;
#[cfg(feature = "candle")]
pub mod candle_local;
pub mod external;
#[cfg(feature = "candle")]
pub use candle_local::{
CandleLocalConfig, CandleLocalPredictionLoss, DEFAULT_CANDLE_MODEL_ID,
DEFAULT_LOSS_SCALE as CANDLE_DEFAULT_LOSS_SCALE,
};
pub use external::{
ExternalPredictionLossBackend, ExternalPredictionLossConfig, DEFAULT_LOSS_SCALE,
};
/// A backend capable of scoring text with a prediction-loss value.
///
/// Implementations must be `Send + Sync` so one backend instance can be
/// shared across threads.
pub trait PredictionLossBackend: Send + Sync {
    /// Returns a prediction-loss score for `content`.
    ///
    /// NOTE(review): the implementation visible in this file produces values
    /// in `[0.0, 1.0]`, but the trait itself does not enforce a range —
    /// confirm against the other backends before relying on it.
    fn predict_loss(&self, content: &str) -> Result<f32>;
}
/// Selects which prediction-loss backend to use.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum PredictionLossBackendKind {
    /// Prediction-loss scoring disabled (the default).
    #[default]
    None,
    /// Remote OpenAI-compatible HTTP server (accepts aliases such as
    /// `vllm` and `llama-cpp` when parsing).
    OpenAiCompat,
    /// In-process inference (gated behind the `candle` feature).
    CandleLocal,
}

impl PredictionLossBackendKind {
    /// Parses a user-supplied backend name.
    ///
    /// Matching is case-insensitive, ignores surrounding whitespace, and
    /// accepts common aliases (`off`, `openai`, `vllm`, `local`, ...).
    ///
    /// # Errors
    /// Returns a human-readable message listing the canonical names when
    /// `s` does not match any known backend.
    pub fn parse(s: &str) -> std::result::Result<Self, String> {
        match s.trim().to_ascii_lowercase().as_str() {
            "none" | "off" | "disabled" => Ok(Self::None),
            "openai-compat" | "openai" | "vllm" | "llamacpp" | "llama-cpp" => {
                Ok(Self::OpenAiCompat)
            }
            "candle-local" | "candle" | "local" => Ok(Self::CandleLocal),
            other => Err(format!(
                "unknown prediction-loss backend: {other:?} \
                 (expected: none, openai-compat, candle-local)"
            )),
        }
    }

    /// Canonical name for this backend kind; round-trips through [`Self::parse`].
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::None => "none",
            Self::OpenAiCompat => "openai-compat",
            Self::CandleLocal => "candle-local",
        }
    }
}

/// Standard-trait counterpart of [`PredictionLossBackendKind::parse`];
/// enables `"vllm".parse::<PredictionLossBackendKind>()`.
impl std::str::FromStr for PredictionLossBackendKind {
    type Err = String;

    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        Self::parse(s)
    }
}

/// Standard-trait counterpart of [`PredictionLossBackendKind::as_str`];
/// also provides `to_string()` for free.
impl std::fmt::Display for PredictionLossBackendKind {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
/// Deterministic stand-in backend for tests and dry runs.
///
/// Hashes the content with SHA-256 and maps the leading digest byte onto
/// `[0.0, 1.0]`, so equal inputs always receive equal scores without
/// loading any model.
pub struct MockPredictionLoss;

impl PredictionLossBackend for MockPredictionLoss {
    fn predict_loss(&self, content: &str) -> Result<f32> {
        use sha2::{Digest, Sha256};

        // One-shot digest; only the first byte contributes to the score.
        let digest = Sha256::digest(content.as_bytes());
        Ok(f32::from(digest[0]) / 255.0)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn backend_kind_parse_canonical() {
        // Canonical spellings map to their variants.
        let cases = [
            ("none", PredictionLossBackendKind::None),
            ("openai-compat", PredictionLossBackendKind::OpenAiCompat),
        ];
        for (input, expected) in cases {
            assert_eq!(PredictionLossBackendKind::parse(input).unwrap(), expected);
        }
    }

    #[test]
    fn backend_kind_parse_aliases() {
        // Aliases and case-insensitive spellings resolve to the same variants.
        let cases = [
            ("vllm", PredictionLossBackendKind::OpenAiCompat),
            ("OFF", PredictionLossBackendKind::None),
            ("candle-local", PredictionLossBackendKind::CandleLocal),
            ("candle", PredictionLossBackendKind::CandleLocal),
            ("LOCAL", PredictionLossBackendKind::CandleLocal),
        ];
        for (input, expected) in cases {
            assert_eq!(PredictionLossBackendKind::parse(input).unwrap(), expected);
        }
    }

    #[test]
    fn backend_kind_parse_rejects_unknown() {
        assert!(PredictionLossBackendKind::parse("gpt-7").is_err());
    }

    #[test]
    fn backend_kind_default_is_none() {
        assert_eq!(
            PredictionLossBackendKind::default(),
            PredictionLossBackendKind::None
        );
    }

    #[test]
    fn mock_returns_deterministic_in_range() {
        let mock = MockPredictionLoss;

        // Same input scores identically across calls, within [0, 1].
        let first = mock.predict_loss("alpha").unwrap();
        let second = mock.predict_loss("alpha").unwrap();
        assert_eq!(first, second);
        assert!((0.0..=1.0).contains(&first));

        let other = mock.predict_loss("bravo").unwrap();
        assert!(first != other, "different inputs must produce different scores");
    }
}