use pyo3::prelude::*;
use std::sync::Arc;
use crate::error::PhosttError;
use crate::inference::Engine;
/// Python-visible handle (exposed to Python as `Engine`) wrapping a loaded inference [`Engine`].
#[pyclass(name = "Engine")]
pub struct PyEngine {
    /// The loaded engine, held behind an `Arc` so a cheap clone can be handed to worker threads.
    engine: Arc<Engine>,
}
#[pymethods]
impl PyEngine {
#[new]
fn new(model_dir: &str) -> PyResult<Self> {
let engine = Engine::load(model_dir).map_err(|e| match e {
PhosttError::ModelLoad(msg) => PyErr::new::<pyo3::exceptions::PyFileNotFoundError, _>(msg),
PhosttError::Inference(msg) => PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(msg),
PhosttError::InvalidAudio(msg) => PyErr::new::<pyo3::exceptions::PyValueError, _>(msg),
PhosttError::Io(err) => PyErr::new::<pyo3::exceptions::PyOSError, _>(format!("{err}")),
})?;
Ok(Self {
engine: Arc::new(engine),
})
}
fn transcribe_file(&self, path: &str) -> PyResult<String> {
let engine = self.engine.clone();
let path = path.to_string();
let text = std::thread::spawn(move || {
let mut guard = engine.pool.checkout_blocking()
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!("Pool checkout failed: {e}")))?;
let result = engine.transcribe_file(&path, &mut *guard)
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!("Transcription failed: {e}")))?;
Ok::<String, PyErr>(result.text)
})
.join()
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!("Thread panicked: {e:?}")))??;
Ok(text)
}
fn transcribe_bytes<'py>(&self, data: &[u8]) -> PyResult<String> {
let engine = self.engine.clone();
let data = data.to_vec();
let text = std::thread::spawn(move || {
let mut guard = engine.pool.checkout_blocking()
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!("Pool checkout failed: {e}")))?;
let result = engine.transcribe_bytes(&data, &mut *guard)
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!("Transcription failed: {e}")))?;
Ok::<String, PyErr>(result.text)
})
.join()
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!("Thread panicked: {e:?}")))??;
Ok(text)
}
}
/// Python extension module `phostt`; registers the `Engine` class.
#[pymodule]
fn phostt(m: &Bound<'_, PyModule>) -> PyResult<()> {
    // `add_class` already yields `PyResult<()>`, so its result is the module's init result.
    m.add_class::<PyEngine>()
}