#[encrypted_library]
mod arcis_library {
/// Log-odds transform `ln(p / (1 - p))`, evaluated as `ln(p) - ln(1 - p)`.
///
/// Values outside the open interval `(0, 1)` (`p <= 0.0` or `p >= 1.0`) yield `0.0`
/// rather than the mathematical `-inf` / `+inf`. Callers that care about those boundaries
/// must handle them before calling. A NaN input fails both comparisons, reaches the
/// logarithm branch and comes back as NaN.
pub fn logit(p: f64) -> f64 {
    let outside_open_unit_interval = p <= 0.0 || p >= 1.0;
    if outside_open_unit_interval {
        0.0
    } else {
        let log_p = p.ln();
        let log_complement = (1.0 - p).ln();
        log_p - log_complement
    }
}
/// Logistic sigmoid: maps a log-odds value `x` to a probability.
///
/// This is a thin wrapper that forwards `x` unchanged to `ArcisMath::sigmoid` and returns
/// its result.
pub fn expit(x: f64) -> f64 {
    let probability = ArcisMath::sigmoid(x);
    probability
}
/// Binary logistic-regression model: the positive-class probability is
/// `expit(intercept + Σ coef[i] * x[i])`.
pub struct LogisticRegression {
    /// Per-feature weights; every `x` passed to the prediction methods must have this length.
    pub coef: Box<[f64]>,
    /// Bias term added to the weighted sum of features.
    pub intercept: f64,
}
impl LogisticRegression {
    /// Builds a model from a weight slice and an intercept.
    ///
    /// `coef` is copied into an owned boxed slice via `box_from_slice`.
    pub fn new(coef: &[f64], intercept: f64) -> Self {
        let coef = box_from_slice(coef);
        Self { coef, intercept }
    }

    /// Returns the linear score `intercept + Σ coef[i] * x[i]`.
    ///
    /// Despite the name, this is the *log-odds* of the positive class (the value passed to
    /// `expit` by [`LogisticRegression::predict_proba`]), not `ln(P(y = 1 | x))`.
    /// The name is kept for API compatibility.
    ///
    /// # Panics
    /// Invokes `arcis_static_panic!` when `x.len() != self.coef.len()`.
    pub fn predict_log_proba(&self, x: &[f64]) -> f64 {
        if x.len() != self.coef.len() {
            arcis_static_panic!("Wrong length in `predict_log_proba`.");
        }
        let mut acc = self.intercept;
        for (i, xi) in x.iter().enumerate() {
            acc += self.coef[i] * xi;
        }
        acc
    }

    /// Returns the positive-class probability: `expit` applied to the linear score.
    ///
    /// # Panics
    /// Same length check as [`LogisticRegression::predict_log_proba`].
    pub fn predict_proba(&self, x: &[f64]) -> f64 {
        expit(self.predict_log_proba(x))
    }

    /// Classifies `x` as positive when its predicted probability exceeds `threshold`.
    ///
    /// For `threshold` strictly inside `(0, 1)` the comparison is made in log-odds space
    /// (`score > logit(threshold)`). This is equivalent because the sigmoid is strictly
    /// increasing, and it avoids evaluating the sigmoid.
    ///
    /// Thresholds at or beyond the ends of the probability range are resolved directly:
    /// - `threshold <= 0.0` returns `true`.
    /// - `threshold >= 1.0` returns `false`.
    ///
    /// This handling is needed because `logit` maps those inputs to `0.0`, which would
    /// otherwise silently turn them into a 0.5 cut-off. A NaN threshold returns `false`.
    ///
    /// # Panics
    /// Same length check as [`LogisticRegression::predict_log_proba`]. The score is computed
    /// before branching so the check runs for every threshold.
    pub fn predict(&self, x: &[f64], threshold: f64) -> bool {
        let score = self.predict_log_proba(x);
        if threshold <= 0.0 {
            // Any probability is >= 0, so every input clears a non-positive threshold.
            true
        } else if threshold >= 1.0 {
            // No probability exceeds 1.
            false
        } else {
            score > logit(threshold)
        }
    }
}
/// Ordinary linear model: the prediction is `intercept + Σ coef[i] * x[i]`.
pub struct LinearRegression {
    /// Per-feature weights; every `x` passed to `predict` must have this length.
    pub coef: Box<[f64]>,
    /// Constant offset added to every prediction.
    pub intercept: f64,
}
impl LinearRegression {
    /// Creates a model, copying `coef` into an owned boxed slice via `box_from_slice`.
    pub fn new(coef: &[f64], intercept: f64) -> Self {
        Self {
            coef: box_from_slice(coef),
            intercept,
        }
    }

    /// Evaluates the affine function `intercept + Σ coef[j] * x[j]`, accumulating in
    /// ascending feature order.
    ///
    /// # Panics
    /// Invokes `arcis_static_panic!` when `x.len() != self.coef.len()`.
    pub fn predict(&self, x: &[f64]) -> f64 {
        let n_features = self.coef.len();
        if x.len() != n_features {
            arcis_static_panic!("Wrong length in `predict`.");
        }
        let mut total = self.intercept;
        for (j, feature) in x.iter().enumerate() {
            let weight = self.coef[j];
            total += weight * feature;
        }
        total
    }
}
}