Skip to main content

rustify_stdlib/
lib.rs

1#![allow(unsafe_op_in_unsafe_fn)]
2use pyo3::prelude::Bound;
3use pyo3::exceptions::PyValueError;
4use pyo3::prelude::*;
5
6#[pyfunction]
7pub fn euclidean(p1: Vec<f64>, p2: Vec<f64>) -> PyResult<f64> {
8    if p1.len() != p2.len() {
9        return Err(PyValueError::new_err("length mismatch"));
10    }
11    let sum_sq: f64 = p1
12        .iter()
13        .zip(p2.iter())
14        .map(|(a, b)| {
15            let d = a - b;
16            d * d
17        })
18        .sum();
19    Ok(sum_sq.sqrt())
20}
21
22#[pyfunction]
23pub fn dot_product(a: Vec<f64>, b: Vec<f64>) -> PyResult<f64> {
24    if a.len() != b.len() {
25        return Err(PyValueError::new_err("length mismatch"));
26    }
27    let mut total = 0.0;
28    for (x, y) in a.iter().zip(b.iter()) {
29        total += x * y;
30    }
31    Ok(total)
32}
33
34#[pyfunction]
35pub fn moving_average(signal: Vec<f64>, window: usize) -> PyResult<Vec<f64>> {
36    if window == 0 {
37        return Err(PyValueError::new_err("invalid window"));
38    }
39    let n = signal.len();
40    if n < window {
41        return Err(PyValueError::new_err("signal shorter than window"));
42    }
43    let mut result = vec![0.0f64; n - window + 1];
44    for i in 0..=n - window {
45        let mut total = 0.0;
46        for j in 0..window {
47            total += signal[i + j];
48        }
49        result[i] = total / window as f64;
50    }
51    Ok(result)
52}
53
54#[pyfunction]
55pub fn convolve1d(signal: Vec<f64>, kernel: Vec<f64>) -> PyResult<Vec<f64>> {
56    let n = signal.len();
57    let k = kernel.len();
58    if k == 0 {
59        return Err(PyValueError::new_err("invalid kernel length"));
60    }
61    if n < k {
62        return Err(PyValueError::new_err("signal shorter than kernel"));
63    }
64    let mut result = vec![0.0f64; n - k + 1];
65    for i in 0..=n - k {
66        let mut total = 0.0;
67        for j in 0..k {
68            total += signal[i + j] * kernel[j];
69        }
70        result[i] = total;
71    }
72    Ok(result)
73}
74
75#[pyfunction]
76pub fn bpe_encode(text: String, _merges: Vec<(String, String)>) -> PyResult<Vec<i64>> {
77    Ok(text.bytes().map(|b| b as i64).collect())
78}
79
80#[pymodule]
81fn rustify_stdlib(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
82    m.add_function(wrap_pyfunction!(euclidean, m)?)?;
83    m.add_function(wrap_pyfunction!(dot_product, m)?)?;
84    m.add_function(wrap_pyfunction!(moving_average, m)?)?;
85    m.add_function(wrap_pyfunction!(convolve1d, m)?)?;
86    m.add_function(wrap_pyfunction!(bpe_encode, m)?)?;
87    Ok(())
88}
89
// Unit tests for the Rust entry points, covering both the happy paths and
// every `ValueError` branch.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_euclidean() {
        let v = euclidean(vec![0.0, 3.0, 4.0], vec![0.0, 0.0, 0.0]).unwrap();
        assert!((v - 5.0).abs() < 1e-9);
    }

    #[test]
    fn test_euclidean_length_mismatch() {
        assert!(euclidean(vec![1.0], vec![1.0, 2.0]).is_err());
    }

    #[test]
    fn test_dot_product() {
        let v = dot_product(vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]).unwrap();
        assert!((v - 32.0).abs() < 1e-9);
    }

    #[test]
    fn test_dot_product_empty_and_mismatch() {
        assert_eq!(dot_product(vec![], vec![]).unwrap(), 0.0);
        assert!(dot_product(vec![1.0, 2.0], vec![1.0]).is_err());
    }

    #[test]
    fn test_moving_average_valid() {
        let v = moving_average(vec![1.0, 2.0, 3.0, 4.0, 5.0], 3).unwrap();
        assert_eq!(v, vec![2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_moving_average_window_equals_len() {
        let v = moving_average(vec![1.0, 2.0, 3.0], 3).unwrap();
        assert_eq!(v, vec![2.0]);
    }

    #[test]
    fn test_moving_average_invalid_window() {
        assert!(moving_average(vec![1.0, 2.0], 0).is_err());
        assert!(moving_average(vec![1.0, 2.0], 3).is_err());
    }

    #[test]
    fn test_convolve1d() {
        let v = convolve1d(vec![1.0, 2.0, 3.0, 4.0], vec![1.0, 0.0, -1.0]).unwrap();
        assert_eq!(v, vec![-2.0, -2.0]);
    }

    #[test]
    fn test_convolve1d_invalid_kernel() {
        assert!(convolve1d(vec![1.0, 2.0], vec![]).is_err());
        assert!(convolve1d(vec![1.0], vec![1.0, 1.0]).is_err());
    }

    #[test]
    fn test_bpe_encode() {
        let v = bpe_encode("ab".to_string(), vec![]).unwrap();
        assert_eq!(v, vec![97, 98]);
    }

    #[test]
    fn test_bpe_encode_multibyte_utf8() {
        // 'é' is U+00E9, encoded in UTF-8 as 0xC3 0xA9.
        let v = bpe_encode("é".to_string(), vec![]).unwrap();
        assert_eq!(v, vec![195, 169]);
    }
}