use pyo3::prelude::*;
use pyo3::types::PyAny;
use oxillama_runtime::{SpeculativeConfig as RustSpeculativeConfig, SpeculativeEngine};
use crate::engine::PyEngineConfig;
use crate::error::runtime_to_py;
/// Python-facing configuration for the speculative engine, exposed to Python
/// under the class name `SpeculativeConfig`.
///
/// Every field is readable and writable from Python (`#[pyo3(get, set)]`).
#[pyclass(name = "SpeculativeConfig", from_py_object)]
#[derive(Clone)]
pub struct PySpeculativeConfig {
/// Engine configuration used for the target model.
#[pyo3(get, set)]
pub target: PyEngineConfig,
/// Engine configuration used for the draft model.
#[pyo3(get, set)]
pub draft: PyEngineConfig,
/// Value forwarded to the runtime config's `num_speculative` field.
/// NOTE(review): presumably the number of draft tokens proposed per step —
/// confirm against `oxillama_runtime`.
#[pyo3(get, set)]
pub num_speculative: usize,
/// Optional seed forwarded unchanged to the runtime config; `None` leaves it unset.
#[pyo3(get, set)]
pub seed: Option<u64>,
}
#[pymethods]
impl PySpeculativeConfig {
    /// Create a configuration from a target and a draft engine config.
    ///
    /// On the Python side `num_speculative` and `seed` are keyword-only and
    /// default to `4` and `None`. Rust callers must pass all four arguments.
    #[new]
    #[pyo3(signature = (target, draft, *, num_speculative = 4, seed = None))]
    pub fn new(
        target: PyEngineConfig,
        draft: PyEngineConfig,
        num_speculative: usize,
        seed: Option<u64>,
    ) -> Self {
        Self { target, draft, num_speculative, seed }
    }

    /// Summarise the config as
    /// `SpeculativeConfig(target=…, draft=…, num_speculative=…, seed=…)`,
    /// showing only the model path of each engine config.
    fn __repr__(&self) -> String {
        let target_path = &self.target.model_path;
        let draft_path = &self.draft.model_path;
        let depth = self.num_speculative;
        let seed = self.seed;
        format!(
            "SpeculativeConfig(target={:?}, draft={:?}, num_speculative={}, seed={:?})",
            target_path, draft_path, depth, seed,
        )
    }
}
impl PySpeculativeConfig {
    /// Convert this Python-facing config into the runtime's `SpeculativeConfig`.
    ///
    /// Both engine configs are converted through their own `to_rust`;
    /// `num_speculative` and `seed` are copied across unchanged.
    pub fn to_rust(&self) -> RustSpeculativeConfig {
        let Self {
            target,
            draft,
            num_speculative,
            seed,
        } = self;

        RustSpeculativeConfig {
            target: target.to_rust(),
            draft: draft.to_rust(),
            num_speculative: *num_speculative,
            seed: *seed,
        }
    }
}
/// Python-facing wrapper around the runtime `SpeculativeEngine`, exposed to
/// Python under the class name `SpeculativeEngine`.
#[pyclass(name = "SpeculativeEngine")]
pub struct PySpeculativeEngine {
/// The wrapped runtime engine; all generation calls are delegated to it.
inner: SpeculativeEngine,
}
#[pymethods]
#[allow(clippy::useless_conversion)]
impl PySpeculativeEngine {
#[new]
pub fn new(py: Python<'_>, config: &PySpeculativeConfig) -> PyResult<Self> {
let rust_cfg = config.to_rust();
let inner = py
.detach(|| SpeculativeEngine::new(rust_cfg))
.map_err(runtime_to_py)?;
Ok(Self { inner })
}
#[pyo3(signature = (prompt, max_tokens = 128))]
pub fn generate(
&mut self,
py: Python<'_>,
prompt: &str,
max_tokens: usize,
) -> PyResult<String> {
let inner = &mut self.inner;
py.detach(|| inner.generate(prompt, max_tokens, |_| {}))
.map_err(runtime_to_py)
}
#[pyo3(signature = (prompt, max_tokens = 128, callback = None))]
pub fn generate_streaming(
&mut self,
py: Python<'_>,
prompt: &str,
max_tokens: usize,
callback: Option<Py<PyAny>>,
) -> PyResult<String> {
let inner = &mut self.inner;
py.detach(|| {
inner.generate(prompt, max_tokens, |tok| {
if let Some(ref cb) = callback {
Python::attach(|py| {
let _ = cb.call1(py, (tok,));
});
}
})
})
.map_err(runtime_to_py)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sampler::PySamplerConfig;

    /// Engine config for `path` with no context size, 4 threads, and the
    /// remaining constructor arguments left as `None`.
    fn make_engine_config(path: &str) -> PyEngineConfig {
        PyEngineConfig::new(path.to_string(), None, 4, None, None)
    }

    /// Speculative config whose target/draft engine configs come from
    /// `make_engine_config`.
    fn make_spec_config(
        target_path: &str,
        draft_path: &str,
        num_speculative: usize,
        seed: Option<u64>,
    ) -> PySpeculativeConfig {
        PySpeculativeConfig::new(
            make_engine_config(target_path),
            make_engine_config(draft_path),
            num_speculative,
            seed,
        )
    }

    // The `num_speculative = 4` default lives in the pyo3 signature and is not
    // reachable from Rust, so this only checks that an explicit 4 is stored.
    #[test]
    fn test_speculative_config_default_k() {
        let cfg = make_spec_config("target.gguf", "draft.gguf", 4, None);
        assert_eq!(cfg.num_speculative, 4);
    }

    #[test]
    fn test_speculative_config_override_k() {
        let cfg = make_spec_config("target.gguf", "draft.gguf", 8, None);
        assert_eq!(cfg.num_speculative, 8);
    }

    #[test]
    fn test_speculative_config_to_rust() {
        let rust = make_spec_config("target.gguf", "draft.gguf", 4, Some(42)).to_rust();
        assert_eq!(rust.target.model_path, "target.gguf");
        assert_eq!(rust.draft.model_path, "draft.gguf");
        assert_eq!(rust.num_speculative, 4);
        assert_eq!(rust.seed, Some(42));
    }

    #[test]
    fn test_engine_config_default_sampler() {
        let rust = make_engine_config("x.gguf").to_rust();
        let default_sampler = PySamplerConfig::default_config().to_rust();
        assert!(
            (rust.sampler.temperature - default_sampler.temperature).abs() < 1e-6,
            "sampler temperature should match default"
        );
    }

    #[test]
    fn test_speculative_config_repr() {
        let repr = make_spec_config("big.gguf", "tiny.gguf", 6, Some(99)).__repr__();
        assert!(
            repr.contains("big.gguf"),
            "repr missing target path: {repr}"
        );
        assert!(
            repr.contains("tiny.gguf"),
            "repr missing draft path: {repr}"
        );
        // Match the full `key=value` fragments so a stray digit elsewhere in
        // the string cannot satisfy the assertion by accident.
        assert!(
            repr.contains("num_speculative=6"),
            "repr missing num_speculative: {repr}"
        );
        assert!(repr.contains("seed=Some(99)"), "repr missing seed: {repr}");
    }

    #[test]
    fn test_speculative_config_to_rust_no_seed() {
        let rust = make_spec_config("t.gguf", "d.gguf", 4, None).to_rust();
        assert!(rust.seed.is_none(), "seed should be None");
    }

    #[test]
    fn test_speculative_config_custom_threads_propagate() {
        let target = PyEngineConfig::new("t.gguf".to_string(), None, 12, None, None);
        let draft = PyEngineConfig::new("d.gguf".to_string(), None, 3, None, None);
        let rust = PySpeculativeConfig::new(target, draft, 4, None).to_rust();
        assert_eq!(rust.target.num_threads, 12);
        assert_eq!(rust.draft.num_threads, 3);
    }

    #[test]
    fn test_speculative_config_context_sizes_propagate() {
        let target = PyEngineConfig::new("t.gguf".to_string(), Some(8192), 4, None, None);
        let draft = PyEngineConfig::new("d.gguf".to_string(), Some(2048), 4, None, None);
        let rust = PySpeculativeConfig::new(target, draft, 4, None).to_rust();
        assert_eq!(rust.target.context_size, Some(8192));
        assert_eq!(rust.draft.context_size, Some(2048));
    }

    #[test]
    fn test_speculative_config_full_roundtrip() {
        let rust = make_spec_config("llama-7b.gguf", "llama-1b.gguf", 3, Some(777)).to_rust();
        assert_eq!(rust.target.model_path, "llama-7b.gguf");
        assert_eq!(rust.draft.model_path, "llama-1b.gguf");
        assert_eq!(rust.num_speculative, 3);
        assert_eq!(rust.seed, Some(777));
    }
}