1#![allow(unsafe_op_in_unsafe_fn)]
2use pyo3::prelude::Bound;
3use pyo3::exceptions::PyValueError;
4use pyo3::prelude::*;
5
6#[pyfunction]
7pub fn euclidean(p1: Vec<f64>, p2: Vec<f64>) -> PyResult<f64> {
8 if p1.len() != p2.len() {
9 return Err(PyValueError::new_err("length mismatch"));
10 }
11 let sum_sq: f64 = p1
12 .iter()
13 .zip(p2.iter())
14 .map(|(a, b)| {
15 let d = a - b;
16 d * d
17 })
18 .sum();
19 Ok(sum_sq.sqrt())
20}
21
22#[pyfunction]
23pub fn dot_product(a: Vec<f64>, b: Vec<f64>) -> PyResult<f64> {
24 if a.len() != b.len() {
25 return Err(PyValueError::new_err("length mismatch"));
26 }
27 let mut total = 0.0;
28 for (x, y) in a.iter().zip(b.iter()) {
29 total += x * y;
30 }
31 Ok(total)
32}
33
34#[pyfunction]
35pub fn moving_average(signal: Vec<f64>, window: usize) -> PyResult<Vec<f64>> {
36 if window == 0 {
37 return Err(PyValueError::new_err("invalid window"));
38 }
39 let n = signal.len();
40 if n < window {
41 return Err(PyValueError::new_err("signal shorter than window"));
42 }
43 let mut result = vec![0.0f64; n - window + 1];
44 for i in 0..=n - window {
45 let mut total = 0.0;
46 for j in 0..window {
47 total += signal[i + j];
48 }
49 result[i] = total / window as f64;
50 }
51 Ok(result)
52}
53
54#[pyfunction]
55pub fn convolve1d(signal: Vec<f64>, kernel: Vec<f64>) -> PyResult<Vec<f64>> {
56 let n = signal.len();
57 let k = kernel.len();
58 if k == 0 {
59 return Err(PyValueError::new_err("invalid kernel length"));
60 }
61 if n < k {
62 return Err(PyValueError::new_err("signal shorter than kernel"));
63 }
64 let mut result = vec![0.0f64; n - k + 1];
65 for i in 0..=n - k {
66 let mut total = 0.0;
67 for j in 0..k {
68 total += signal[i + j] * kernel[j];
69 }
70 result[i] = total;
71 }
72 Ok(result)
73}
74
75#[pyfunction]
76pub fn bpe_encode(text: String, _merges: Vec<(String, String)>) -> PyResult<Vec<i64>> {
77 Ok(text.bytes().map(|b| b as i64).collect())
78}
79
80#[pymodule]
81fn rustify_stdlib(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
82 m.add_function(wrap_pyfunction!(euclidean, m)?)?;
83 m.add_function(wrap_pyfunction!(dot_product, m)?)?;
84 m.add_function(wrap_pyfunction!(moving_average, m)?)?;
85 m.add_function(wrap_pyfunction!(convolve1d, m)?)?;
86 m.add_function(wrap_pyfunction!(bpe_encode, m)?)?;
87 Ok(())
88}
89
#[cfg(test)]
mod tests {
    //! Unit tests for the numeric helpers, covering both the success paths and
    //! every input-validation error branch.
    use super::*;

    #[test]
    fn test_euclidean() {
        let v = euclidean(vec![0.0, 3.0, 4.0], vec![0.0, 0.0, 0.0]).unwrap();
        assert!((v - 5.0).abs() < 1e-9);
    }

    #[test]
    fn test_euclidean_length_mismatch() {
        assert!(euclidean(vec![1.0, 2.0], vec![1.0]).is_err());
    }

    #[test]
    fn test_dot_product() {
        let v = dot_product(vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]).unwrap();
        assert!((v - 32.0).abs() < 1e-9);
    }

    #[test]
    fn test_dot_product_empty_and_mismatch() {
        assert_eq!(dot_product(vec![], vec![]).unwrap(), 0.0);
        assert!(dot_product(vec![1.0], vec![1.0, 2.0]).is_err());
    }

    #[test]
    fn test_moving_average_valid() {
        let v = moving_average(vec![1.0, 2.0, 3.0, 4.0, 5.0], 3).unwrap();
        assert_eq!(v, vec![2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_moving_average_window_equals_len() {
        assert_eq!(moving_average(vec![2.0, 4.0], 2).unwrap(), vec![3.0]);
    }

    #[test]
    fn test_moving_average_rejects_bad_window() {
        // Zero-width window.
        assert!(moving_average(vec![1.0, 2.0], 0).is_err());
        // Window longer than the signal.
        assert!(moving_average(vec![1.0, 2.0], 3).is_err());
    }

    #[test]
    fn test_convolve1d() {
        // The kernel is applied without reversal (cross-correlation).
        let v = convolve1d(vec![1.0, 2.0, 3.0, 4.0], vec![1.0, 0.0, -1.0]).unwrap();
        assert_eq!(v, vec![-2.0, -2.0]);
    }

    #[test]
    fn test_convolve1d_rejects_bad_kernel() {
        // Empty kernel.
        assert!(convolve1d(vec![1.0, 2.0], vec![]).is_err());
        // Kernel longer than the signal.
        assert!(convolve1d(vec![1.0], vec![1.0, 1.0]).is_err());
    }

    #[test]
    fn test_bpe_encode() {
        let v = bpe_encode("ab".to_string(), vec![]).unwrap();
        assert_eq!(v, vec![97, 98]);
    }

    #[test]
    fn test_bpe_encode_multibyte_utf8() {
        // Non-ASCII characters expand to one id per UTF-8 byte.
        assert_eq!(bpe_encode("é".to_string(), vec![]).unwrap(), vec![195, 169]);
    }
}