use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::time::{Duration, Instant};
use pyo3::prelude::*;
use pyo3::types::{PyAny, PyModule, PyTuple};
/// Wraps a Python callable so Rust code can hand it streamed tokens.
pub struct StreamingCallbackBridge {
    /// The Python object invoked once per [`StreamingCallbackBridge::call`].
    callback: Py<PyAny>,
}

impl StreamingCallbackBridge {
    /// Builds a bridge that takes ownership of `callback`.
    pub fn new(callback: Py<PyAny>) -> Self {
        StreamingCallbackBridge { callback }
    }

    /// Invokes the wrapped callable with `token` as its single positional
    /// argument.
    ///
    /// The call runs inside `Python::attach`. Whatever the callable returns
    /// (or raises) is discarded, so a failing callback never surfaces here.
    pub fn call(&self, token: &str) {
        Python::attach(|py| {
            let args = (token,);
            // Best-effort delivery: the outcome is dropped on purpose.
            drop(self.callback.call1(py, args));
        });
    }
}
/// Turns an optional Python callable into a Rust closure taking one token.
///
/// With `None` the returned closure does nothing. With `Some(cb)` every
/// invocation calls `cb(tok)` inside `Python::attach`; the result of that
/// call, including any raised exception, is discarded.
pub fn make_callback(callback: Option<Py<PyAny>>) -> impl FnMut(&str) {
    move |tok: &str| {
        // Guard clause: nothing to call when no callback was supplied.
        let Some(cb) = callback.as_ref() else {
            return;
        };
        Python::attach(|py| {
            let _ = cb.call1(py, (tok,));
        });
    }
}
/// Default throttle interval, in milliseconds.
///
/// NOTE(review): not referenced elsewhere in this file; presumably meant as
/// the `min_interval_ms` argument of `Throttler::new` — confirm with callers.
pub const DEFAULT_THROTTLE_MS: u64 = 50;
/// Default throttle token count.
///
/// NOTE(review): not referenced elsewhere in this file; presumably meant as
/// the `min_tokens` argument of `Throttler::new` — confirm with callers.
pub const DEFAULT_THROTTLE_TOKENS: usize = 4;
/// Decides when a progress notification may be emitted.
///
/// A notification is allowed when it is forced, when at least `min_interval`
/// has passed since the previous one, or when at least `min_tokens` tokens
/// have been noted since the previous one.
pub struct Throttler {
    /// Time of the most recent fire (initially the construction time).
    last_fire: Instant,
    /// Tokens noted since the most recent fire.
    tokens_since_fire: usize,
    /// Elapsed time that opens the time gate.
    min_interval: Duration,
    /// Token count that opens the token gate.
    min_tokens: usize,
}

impl Throttler {
    /// Creates a throttler whose clock starts now and whose token counter is
    /// zero. `min_interval_ms` is interpreted as milliseconds.
    pub fn new(min_interval_ms: u64, min_tokens: usize) -> Self {
        let min_interval = Duration::from_millis(min_interval_ms);
        Self {
            last_fire: Instant::now(),
            tokens_since_fire: 0,
            min_interval,
            min_tokens,
        }
    }

    /// Records one more token since the last fire (saturating at `usize::MAX`).
    pub fn note_token(&mut self) {
        let bumped = self.tokens_since_fire.saturating_add(1);
        self.tokens_since_fire = bumped;
    }

    /// Returns `true` when a notification should be emitted now.
    ///
    /// `force` bypasses both gates. Whenever the answer is `true`, the clock
    /// and the token counter are reset, so the call has side effects.
    pub fn should_fire(&mut self, force: bool) -> bool {
        let interval_open = self.last_fire.elapsed() >= self.min_interval;
        let tokens_open = self.tokens_since_fire >= self.min_tokens;

        // Every gate closed and not forced: leave the state untouched.
        if !(force || interval_open || tokens_open) {
            return false;
        }

        self.last_fire = Instant::now();
        self.tokens_since_fire = 0;
        true
    }
}
/// State for reporting token-generation progress to a pair of Python
/// callables (a per-update callback and a one-shot finaliser).
pub struct ProgressBridge {
    /// Python callable invoked by `note_token` / `fire_final` with a single
    /// argument: the tuple `(tokens, elapsed_seconds, is_final, text)`.
    callback: Py<PyAny>,
    /// Python callable invoked by `finalise` with either an exception value
    /// or `None`.
    finaliser: Py<PyAny>,
    /// Gate deciding whether a given token triggers a callback invocation.
    throttle: Throttler,
    /// Reference point for the elapsed-seconds value in the payload.
    start: Instant,
    /// Number of tokens seen so far via `note_token`.
    tokens_total: AtomicUsize,
    /// When true, token text is accumulated and sent in the payload;
    /// otherwise the payload text is the empty string.
    capture_text: bool,
    /// Concatenation of all tokens seen so far (only filled when
    /// `capture_text` is true).
    accumulated_text: String,
    /// Set once `finalise` has run, so the finaliser is called at most once.
    finalised: AtomicBool,
    /// First callback error raised in non-strict mode, kept until
    /// `take_stashed_error` removes it.
    stashed_error: Option<PyErr>,
}
impl ProgressBridge {
pub fn tokens_total(&self) -> usize {
self.tokens_total.load(Ordering::Relaxed)
}
pub fn take_stashed_error(&mut self) -> Option<PyErr> {
self.stashed_error.take()
}
pub fn note_token(
&mut self,
py: Python<'_>,
token: &str,
is_final: bool,
strict: bool,
) -> PyResult<()> {
let prev = self.tokens_total.fetch_add(1, Ordering::Relaxed);
let tokens_now = prev + 1;
if self.capture_text {
self.accumulated_text.push_str(token);
}
let force = is_final || tokens_now == 1;
self.throttle.note_token();
if !self.throttle.should_fire(force) {
return Ok(());
}
let elapsed = self.start.elapsed().as_secs_f64();
let text_view: &str = if self.capture_text {
self.accumulated_text.as_str()
} else {
""
};
let payload = PyTuple::new(
py,
[
tokens_now.into_pyobject(py)?.into_any(),
elapsed.into_pyobject(py)?.into_any(),
is_final.into_pyobject(py)?.to_owned().into_any(),
text_view.into_pyobject(py)?.into_any(),
],
)?;
match self.callback.call1(py, (payload,)) {
Ok(_) => Ok(()),
Err(err) => {
if strict {
Err(err)
} else {
if self.stashed_error.is_none() {
self.stashed_error = Some(err);
}
Ok(())
}
}
}
}
pub fn fire_final(&mut self, py: Python<'_>) {
let tokens_now = self.tokens_total.load(Ordering::Relaxed);
let elapsed = self.start.elapsed().as_secs_f64();
let text_view: &str = if self.capture_text {
self.accumulated_text.as_str()
} else {
""
};
let payload = match (
tokens_now.into_pyobject(py),
elapsed.into_pyobject(py),
true.into_pyobject(py),
text_view.into_pyobject(py),
) {
(Ok(t), Ok(e), Ok(f), Ok(s)) => {
let owned_f = f.to_owned();
PyTuple::new(
py,
[t.into_any(), e.into_any(), owned_f.into_any(), s.into_any()],
)
}
_ => return,
};
let payload = match payload {
Ok(p) => p,
Err(_) => return,
};
let _ = self.callback.call1(py, (payload,));
}
pub fn finalise(&mut self, py: Python<'_>, error: Option<&PyErr>) {
if self.finalised.swap(true, Ordering::Relaxed) {
return;
}
let arg: Py<PyAny> = match error {
Some(err) => err.clone_ref(py).into_value(py).into_any(),
None => py.None(),
};
let _ = self.finaliser.call1(py, (arg,));
}
}
impl Drop for ProgressBridge {
    /// Ensures the Python finaliser runs (with `None`) if `finalise` was
    /// never called explicitly.
    fn drop(&mut self) {
        if self.finalised.load(Ordering::Relaxed) {
            return;
        }
        // A destructor must not panic. `Python::try_attach` returns `None`
        // instead of panicking when the interpreter cannot be attached to
        // (for example while it is shutting down); in that case the
        // finaliser cannot run anyway, so it is skipped.
        let _ = Python::try_attach(|py| {
            self.finalise(py, None);
        });
    }
}
/// Builds a [`ProgressBridge`] for an optional Python progress object.
///
/// Returns `Ok(None)` when `progress` is `None`. Otherwise it imports
/// `oxillama_py.progress`, calls its `_build_bridge(progress, max_tokens)`
/// and expects a 2-tuple `(callback, finaliser)` back.
///
/// * `throttle_ms`, `throttle_tokens` – forwarded to `Throttler::new`.
/// * `capture_text` – stored on the bridge; controls whether token text is
///   accumulated.
///
/// # Errors
///
/// Propagates import, attribute-lookup and call failures, and returns a
/// `TypeError` when `_build_bridge` does not return a tuple of length 2.
pub fn make_progress_bridge(
    py: Python<'_>,
    progress: Option<&Py<PyAny>>,
    max_tokens: usize,
    throttle_ms: u64,
    throttle_tokens: usize,
    capture_text: bool,
) -> PyResult<Option<ProgressBridge>> {
    let Some(progress) = progress else {
        return Ok(None);
    };

    let built = PyModule::import(py, "oxillama_py.progress")?
        .getattr("_build_bridge")?
        .call1((progress.bind(py), max_tokens))?;

    let parts: Bound<'_, PyTuple> = match built.cast_into::<PyTuple>() {
        Ok(tuple) => tuple,
        Err(e) => {
            return Err(pyo3::exceptions::PyTypeError::new_err(format!(
                "_build_bridge must return a (callback, finaliser) tuple: {e}"
            )));
        }
    };
    if parts.len() != 2 {
        return Err(pyo3::exceptions::PyTypeError::new_err(
            "_build_bridge must return a 2-tuple (callback, finaliser)",
        ));
    }

    let callback: Py<PyAny> = parts.get_item(0)?.unbind();
    let finaliser: Py<PyAny> = parts.get_item(1)?.unbind();

    let bridge = ProgressBridge {
        callback,
        finaliser,
        throttle: Throttler::new(throttle_ms, throttle_tokens),
        start: Instant::now(),
        tokens_total: AtomicUsize::new(0),
        capture_text,
        accumulated_text: String::new(),
        finalised: AtomicBool::new(false),
        stashed_error: None,
    };
    Ok(Some(bridge))
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread::sleep;

    /// A callback built from `None` can be invoked without doing anything.
    #[test]
    fn test_make_callback_none_is_noop() {
        let mut noop = make_callback(None);
        for tok in ["hello", "world"] {
            noop(tok);
        }
    }

    /// A forced fire succeeds even with both gates set very high.
    #[test]
    fn test_throttler_fires_on_first_token() {
        let mut throttler = Throttler::new(60_000, 999);
        let fired = throttler.should_fire(true);
        assert!(fired, "force=true must always fire");
    }

    /// After a fire, a couple of tokens are not enough to fire again.
    #[test]
    fn test_throttler_throttles_subsequent_calls() {
        let mut throttler = Throttler::new(60_000, 999);
        assert!(throttler.should_fire(true));
        (0..2).for_each(|_| throttler.note_token());
        assert!(
            !throttler.should_fire(false),
            "throttler should not fire while both gates are closed"
        );
    }

    /// The token gate opens exactly when the threshold is reached.
    #[test]
    fn test_throttler_fires_on_token_threshold() {
        let mut throttler = Throttler::new(60_000, 4);
        assert!(throttler.should_fire(true));
        (0..3).for_each(|_| {
            throttler.note_token();
            assert!(
                !throttler.should_fire(false),
                "should not fire before crossing the 4-token threshold"
            );
        });
        throttler.note_token();
        assert!(
            throttler.should_fire(false),
            "should fire once the 4-token threshold is reached"
        );
    }

    /// The time gate opens once the configured interval has passed.
    #[test]
    fn test_throttler_fires_on_interval() {
        let mut throttler = Throttler::new(20, 999);
        assert!(throttler.should_fire(true));
        // Sleep comfortably longer than the 20 ms interval.
        let pause = Duration::from_millis(35);
        sleep(pause);
        assert!(
            throttler.should_fire(false),
            "should fire once the 20 ms interval has elapsed"
        );
    }

    /// A forced fire clears tokens that were counted before it.
    #[test]
    fn test_throttler_force_resets_counters() {
        let mut throttler = Throttler::new(60_000, 4);
        (0..3).for_each(|_| throttler.note_token());
        assert!(throttler.should_fire(true), "force fire");
        (0..3).for_each(|_| {
            throttler.note_token();
            assert!(
                !throttler.should_fire(false),
                "counters were not reset by the force fire"
            );
        });
    }
}