use std::path::Path;
use std::sync::{Arc, Mutex};
use pyo3::prelude::*;
use pyo3::types::{PyAny, PyTuple, PyType};
use oxillama_runtime::{SpeculativeConfig as RustSpeculativeConfig, SpeculativeEngine};
use crate::callback::{
make_progress_bridge, ProgressBridge, DEFAULT_THROTTLE_MS, DEFAULT_THROTTLE_TOKENS,
};
use crate::engine::PyEngineConfig;
use crate::error::runtime_to_py;
/// Configuration object exposed to Python under the class name `SpeculativeConfig`.
///
/// Holds one `PyEngineConfig` for the target model and one for the draft model,
/// plus two scalar settings. `PySpeculativeConfig::to_rust` copies all four
/// values into the runtime's `SpeculativeConfig`.
#[pyclass(name = "SpeculativeConfig", from_py_object)]
#[derive(Clone)]
pub struct PySpeculativeConfig {
/// Engine configuration of the target model; readable and writable from Python.
#[pyo3(get, set)]
pub target: PyEngineConfig,
/// Engine configuration of the draft model; readable and writable from Python.
#[pyo3(get, set)]
pub draft: PyEngineConfig,
/// Forwarded unchanged to the runtime config's `num_speculative` field.
// NOTE(review): presumably the number of draft tokens proposed per step —
// confirm against `oxillama_runtime`.
#[pyo3(get, set)]
pub num_speculative: usize,
/// Optional seed forwarded unchanged to the runtime config's `seed` field.
#[pyo3(get, set)]
pub seed: Option<u64>,
}
#[pymethods]
impl PySpeculativeConfig {
#[new]
#[pyo3(signature = (target, draft, *, num_speculative = 4, seed = None))]
pub fn new(
target: PyEngineConfig,
draft: PyEngineConfig,
num_speculative: usize,
seed: Option<u64>,
) -> Self {
Self {
target,
draft,
num_speculative,
seed,
}
}
fn __repr__(&self) -> String {
format!(
"SpeculativeConfig(target={:?}, draft={:?}, num_speculative={}, seed={:?})",
self.target.model_path, self.draft.model_path, self.num_speculative, self.seed,
)
}
}
impl PySpeculativeConfig {
    /// Convert this Python-facing configuration into the runtime's
    /// `SpeculativeConfig`.
    ///
    /// Both engine configurations are converted through their own `to_rust`;
    /// `num_speculative` and `seed` are copied as-is.
    pub fn to_rust(&self) -> RustSpeculativeConfig {
        let Self {
            target,
            draft,
            num_speculative,
            seed,
        } = self;
        RustSpeculativeConfig {
            target: target.to_rust(),
            draft: draft.to_rust(),
            num_speculative: *num_speculative,
            seed: *seed,
        }
    }
}
/// Python-visible wrapper, exported as `SpeculativeEngine`, around the
/// runtime's `SpeculativeEngine`.
#[pyclass(name = "SpeculativeEngine")]
pub struct PySpeculativeEngine {
// The wrapped runtime engine; every method of this class delegates to it.
inner: SpeculativeEngine,
}
#[pymethods]
#[allow(clippy::useless_conversion)]
impl PySpeculativeEngine {
/// Construct an engine from `config`.
///
/// The configuration is converted with `PySpeculativeConfig::to_rust` and the
/// runtime engine is created inside `py.detach`, i.e. while this thread is
/// detached from the Python interpreter. Runtime errors are converted with
/// `runtime_to_py`.
#[new]
pub fn new(py: Python<'_>, config: &PySpeculativeConfig) -> PyResult<Self> {
    let engine_config = config.to_rust();
    py.detach(move || SpeculativeEngine::new(engine_config))
        .map(|inner| Self { inner })
        .map_err(runtime_to_py)
}
#[pyo3(signature = (
prompt,
max_tokens = 128,
*,
progress = None,
progress_throttle_ms = None,
progress_throttle_tokens = None,
progress_capture_text = false,
strict_progress = false,
))]
#[allow(clippy::too_many_arguments)]
pub fn generate(
&mut self,
py: Python<'_>,
prompt: &str,
max_tokens: usize,
progress: Option<Py<PyAny>>,
progress_throttle_ms: Option<u64>,
progress_throttle_tokens: Option<usize>,
progress_capture_text: bool,
strict_progress: bool,
) -> PyResult<String> {
let inner = &mut self.inner;
let bridge = make_progress_bridge(
py,
progress.as_ref(),
max_tokens,
progress_throttle_ms.unwrap_or(DEFAULT_THROTTLE_MS),
progress_throttle_tokens.unwrap_or(DEFAULT_THROTTLE_TOKENS),
progress_capture_text,
)?;
let bridge_arc: Option<Arc<Mutex<ProgressBridge>>> =
bridge.map(|b| Arc::new(Mutex::new(b)));
let result = py
.detach(|| {
let bridge_inner = bridge_arc.clone();
inner.generate(prompt, max_tokens, move |tok| {
if let Some(ref bridge) = bridge_inner {
Python::attach(|py| {
if let Ok(mut b) = bridge.lock() {
let _ = b.note_token(py, tok, false, false);
}
});
}
})
})
.map_err(runtime_to_py);
if let Some(bridge) = bridge_arc.as_ref() {
if let Ok(mut b) = bridge.lock() {
if result.is_ok() {
b.fire_final(py);
}
b.finalise(py, result.as_ref().err());
if strict_progress {
if let Some(err) = b.take_stashed_error() {
return Err(err);
}
}
}
}
result
}
fn snapshot(&self, py: Python<'_>, path: &str) -> PyResult<()> {
let inner = &self.inner;
let path = Path::new(path).to_path_buf();
py.detach(|| inner.snapshot_to_file(&path))
.map_err(runtime_to_py)
}
/// Return the engine snapshot as a byte vector.
///
/// The snapshot is produced inside `py.detach`; runtime errors are converted
/// with `runtime_to_py`.
fn snapshot_bytes(&self, py: Python<'_>) -> PyResult<Vec<u8>> {
    let engine = &self.inner;
    let encoded = py.detach(|| engine.snapshot());
    encoded.map_err(runtime_to_py)
}
#[classmethod]
fn restore(
_cls: &Bound<'_, PyType>,
py: Python<'_>,
path: &str,
target_model: &str,
draft_model: &str,
) -> PyResult<Self> {
let path = Path::new(path).to_path_buf();
let target_path = Path::new(target_model).to_path_buf();
let draft_path = Path::new(draft_model).to_path_buf();
let inner = py
.detach(|| SpeculativeEngine::resume_from_file(&path, &target_path, &draft_path))
.map_err(runtime_to_py)?;
Ok(Self { inner })
}
/// Pickle support: returns the tuple `(SpeculativeEngine.restore, (path, target, draft))`.
///
/// The engine is snapshotted to a fresh file in the system temp directory. The
/// snapshot is then read back and decoded to recover the target and draft model
/// paths, which together with the snapshot path form the arguments for
/// `restore`.
///
/// On success the temp file is intentionally left in place, because the
/// returned arguments reference it by path. If the snapshot cannot be written,
/// read or decoded, the file is removed on a best-effort basis before the error
/// is returned.
///
/// # Errors
///
/// * `ValueError` if the temp path is not valid UTF-8 (checked before anything
///   is written).
/// * `OSError` if the snapshot file cannot be read back.
/// * Whatever `runtime_to_py` produces for snapshot or decode failures.
fn __reduce__(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
    use pyo3::types::PyString;
    use std::sync::atomic::{AtomicU64, Ordering};

    // Per-process sequence number: keeps file names distinct even when two
    // snapshots are taken within the clock's resolution or the clock read
    // fails (in which case the timestamp component is 0).
    static SNAPSHOT_SEQ: AtomicU64 = AtomicU64::new(0);

    let nanos = std::time::SystemTime::now()
        .duration_since(std::time::SystemTime::UNIX_EPOCH)
        .map(|d| d.as_nanos())
        .unwrap_or(0);
    let snap_path = std::env::temp_dir().join(format!(
        "oxillama_spec_snap_{}_{}_{}.bin",
        std::process::id(),
        nanos,
        SNAPSHOT_SEQ.fetch_add(1, Ordering::Relaxed),
    ));
    // Validate the path before writing so a failure here leaves no file behind.
    let path_str = snap_path.to_str().ok_or_else(|| {
        pyo3::exceptions::PyValueError::new_err(
            "snapshot temp path contains non-UTF-8 characters",
        )
    })?;

    let write_and_read_model_paths = || -> PyResult<_> {
        self.inner
            .snapshot_to_file(&snap_path)
            .map_err(runtime_to_py)?;
        let bytes = std::fs::read(&snap_path)
            .map_err(|e| pyo3::exceptions::PyOSError::new_err(e.to_string()))?;
        let spec_snap = oxillama_runtime::snapshot::SpeculativeEngineSnapshot::decode(&bytes)
            .map_err(runtime_to_py)?;
        Ok((
            spec_snap.target_snapshot.model_path.clone(),
            spec_snap.draft_snapshot.model_path.clone(),
        ))
    };
    let (target_model_path, draft_model_path) = match write_and_read_model_paths() {
        Ok(paths) => paths,
        Err(err) => {
            // Best-effort cleanup; the original error is what the caller needs.
            let _ = std::fs::remove_file(&snap_path);
            return Err(err);
        }
    };

    let cls = py.get_type::<Self>();
    let args = PyTuple::new(
        py,
        &[
            PyString::new(py, path_str).into_any(),
            PyString::new(py, &target_model_path).into_any(),
            PyString::new(py, &draft_model_path).into_any(),
        ],
    )?;
    let restore_method = cls.getattr("restore")?;
    let result = PyTuple::new(py, &[restore_method.into_any(), args.into_any()])?;
    Ok(result.into())
}
fn __reduce_ex__(&self, py: Python<'_>, _protocol: i32) -> PyResult<Py<PyAny>> {
self.__reduce__(py)
}
/// Generate text for `prompt`, invoking `callback` for every token and
/// optionally reporting progress to Python.
///
/// Python signature: `prompt`, `max_tokens` (default `128`) and `callback`
/// (default `None`) are positional; all progress options are keyword-only.
///
/// * `callback` - optional Python callable invoked as `callback(tok)` for each
///   token. It cannot abort generation; an exception it raises is reported
///   through `sys.unraisablehook` rather than silently dropped.
/// * `progress` - optional Python object handed to `make_progress_bridge`.
/// * `progress_throttle_ms` / `progress_throttle_tokens` - fall back to
///   `DEFAULT_THROTTLE_MS` / `DEFAULT_THROTTLE_TOKENS` when `None`.
/// * `progress_capture_text` - forwarded to `make_progress_bridge`.
/// * `strict_progress` - when `true`, an error stashed by the progress bridge
///   is returned instead of the generation result.
///
/// Generation runs inside `py.detach`. The per-token hook attaches to the
/// interpreter at most once per token, and not at all when neither a callback
/// nor a progress bridge is present.
#[pyo3(signature = (
    prompt,
    max_tokens = 128,
    callback = None,
    *,
    progress = None,
    progress_throttle_ms = None,
    progress_throttle_tokens = None,
    progress_capture_text = false,
    strict_progress = false,
))]
#[allow(clippy::too_many_arguments)]
pub fn generate_streaming(
    &mut self,
    py: Python<'_>,
    prompt: &str,
    max_tokens: usize,
    callback: Option<Py<PyAny>>,
    progress: Option<Py<PyAny>>,
    progress_throttle_ms: Option<u64>,
    progress_throttle_tokens: Option<usize>,
    progress_capture_text: bool,
    strict_progress: bool,
) -> PyResult<String> {
    let inner = &mut self.inner;
    let bridge = make_progress_bridge(
        py,
        progress.as_ref(),
        max_tokens,
        progress_throttle_ms.unwrap_or(DEFAULT_THROTTLE_MS),
        progress_throttle_tokens.unwrap_or(DEFAULT_THROTTLE_TOKENS),
        progress_capture_text,
    )?;
    // Shared between the per-token hook and the finalisation step below.
    let bridge_arc: Option<Arc<Mutex<ProgressBridge>>> =
        bridge.map(|b| Arc::new(Mutex::new(b)));
    let result = py
        .detach(|| {
            let bridge_inner = bridge_arc.clone();
            inner.generate(prompt, max_tokens, |tok| {
                // Nothing to notify: avoid attaching to the interpreter at all.
                if callback.is_none() && bridge_inner.is_none() {
                    return;
                }
                // One attach per token covers both the callback and the bridge.
                Python::attach(|py| {
                    if let Some(ref cb) = callback {
                        if let Err(err) = cb.call1(py, (tok,)) {
                            // The hook has no way to propagate the exception,
                            // so surface it instead of discarding it.
                            err.write_unraisable(py, Some(cb.bind(py)));
                        }
                    }
                    if let Some(ref bridge) = bridge_inner {
                        // A poisoned lock is skipped rather than treated as fatal.
                        if let Ok(mut b) = bridge.lock() {
                            let _ = b.note_token(py, tok, false, false);
                        }
                    }
                });
            })
        })
        .map_err(runtime_to_py);
    if let Some(bridge) = bridge_arc.as_ref() {
        if let Ok(mut b) = bridge.lock() {
            // The final progress event is fired only for a successful
            // generation; `finalise` always runs and receives the error, if any.
            if result.is_ok() {
                b.fire_final(py);
            }
            b.finalise(py, result.as_ref().err());
            if strict_progress {
                if let Some(err) = b.take_stashed_error() {
                    return Err(err);
                }
            }
        }
    }
    result
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sampler::PySamplerConfig;

    /// Engine config for `path` using the argument values shared by most tests
    /// (no context size, 4 threads, remaining options unset).
    fn make_engine_config(path: &str) -> PyEngineConfig {
        PyEngineConfig::new(path.to_owned(), None, 4, None, None)
    }

    /// Speculative config whose two engine configs come from `make_engine_config`.
    fn make_speculative_config(
        target_path: &str,
        draft_path: &str,
        num_speculative: usize,
        seed: Option<u64>,
    ) -> PySpeculativeConfig {
        PySpeculativeConfig::new(
            make_engine_config(target_path),
            make_engine_config(draft_path),
            num_speculative,
            seed,
        )
    }

    /// A config built with `num_speculative = 4` stores 4.
    #[test]
    fn test_speculative_config_default_k() {
        let cfg = make_speculative_config("target.gguf", "draft.gguf", 4, None);
        assert_eq!(cfg.num_speculative, 4);
    }

    /// A non-default `num_speculative` is stored as given.
    #[test]
    fn test_speculative_config_override_k() {
        let cfg = make_speculative_config("target.gguf", "draft.gguf", 8, None);
        assert_eq!(cfg.num_speculative, 8);
    }

    /// `to_rust` carries over model paths, `num_speculative` and `seed`.
    #[test]
    fn test_speculative_config_to_rust() {
        let rust = make_speculative_config("target.gguf", "draft.gguf", 4, Some(42)).to_rust();
        assert_eq!(rust.target.model_path, "target.gguf");
        assert_eq!(rust.draft.model_path, "draft.gguf");
        assert_eq!(rust.num_speculative, 4);
        assert_eq!(rust.seed, Some(42));
    }

    /// An engine config without an explicit sampler uses the default temperature.
    #[test]
    fn test_engine_config_default_sampler() {
        let rust = make_engine_config("x.gguf").to_rust();
        let default_sampler = PySamplerConfig::default_config().to_rust();
        let delta = (rust.sampler.temperature - default_sampler.temperature).abs();
        assert!(delta < 1e-6, "sampler temperature should match default");
    }

    /// `__repr__` mentions both model paths, `num_speculative` and the seed.
    #[test]
    fn test_speculative_config_repr() {
        let repr = make_speculative_config("big.gguf", "tiny.gguf", 6, Some(99)).__repr__();
        assert!(
            repr.contains("big.gguf"),
            "repr missing target path: {repr}"
        );
        assert!(
            repr.contains("tiny.gguf"),
            "repr missing draft path: {repr}"
        );
        assert!(repr.contains('6'), "repr missing num_speculative: {repr}");
        assert!(repr.contains("99"), "repr missing seed: {repr}");
    }

    /// A missing seed stays `None` after conversion.
    #[test]
    fn test_speculative_config_to_rust_no_seed() {
        let rust = make_speculative_config("t.gguf", "d.gguf", 4, None).to_rust();
        assert_eq!(rust.seed, None, "seed should be None");
    }

    /// Per-engine thread counts survive conversion independently.
    #[test]
    fn test_speculative_config_custom_threads_propagate() {
        let target = PyEngineConfig::new(String::from("t.gguf"), None, 12, None, None);
        let draft = PyEngineConfig::new(String::from("d.gguf"), None, 3, None, None);
        let rust = PySpeculativeConfig::new(target, draft, 4, None).to_rust();
        assert_eq!(rust.target.num_threads, 12);
        assert_eq!(rust.draft.num_threads, 3);
    }

    /// Per-engine context sizes survive conversion independently.
    #[test]
    fn test_speculative_config_context_sizes_propagate() {
        let target = PyEngineConfig::new(String::from("t.gguf"), Some(8192), 4, None, None);
        let draft = PyEngineConfig::new(String::from("d.gguf"), Some(2048), 4, None, None);
        let rust = PySpeculativeConfig::new(target, draft, 4, None).to_rust();
        assert_eq!(rust.target.context_size, Some(8192));
        assert_eq!(rust.draft.context_size, Some(2048));
    }

    /// All four fields round-trip together through `to_rust`.
    #[test]
    fn test_speculative_config_full_roundtrip() {
        let rust =
            make_speculative_config("llama-7b.gguf", "llama-1b.gguf", 3, Some(777)).to_rust();
        assert_eq!(rust.target.model_path, "llama-7b.gguf");
        assert_eq!(rust.draft.model_path, "llama-1b.gguf");
        assert_eq!(rust.num_speculative, 3);
        assert_eq!(rust.seed, Some(777));
    }
}