use std::sync::Arc;
/// Observer for progress events emitted while a document is converted page by page.
///
/// Every hook has a default body that simply ignores its arguments, so an
/// implementor overrides only the events it cares about; an empty `impl` block
/// is a valid do-nothing callback. The `Send + Sync` supertraits make trait
/// objects shareable across threads, e.g. behind an
/// `Arc<dyn ConversionProgressCallback>`.
pub trait ConversionProgressCallback: Send + Sync {
    /// Hook for the start of a conversion covering `total_pages` pages.
    fn on_conversion_start(&self, total_pages: usize) {
        // Discarding into `_` keeps the descriptive parameter name visible in
        // rustdoc while silencing the unused-variable lint.
        let _ = total_pages;
    }

    /// Hook for the start of page `page_num` out of `total_pages`.
    fn on_page_start(&self, page_num: usize, total_pages: usize) {
        let _ = (page_num, total_pages);
    }

    /// Hook for a page that finished successfully.
    ///
    /// `markdown_len` is the size of the page's markdown output as reported by
    /// the caller (presumably a byte length; confirm at the call site).
    fn on_page_complete(&self, page_num: usize, total_pages: usize, markdown_len: usize) {
        let _ = (page_num, total_pages, markdown_len);
    }

    /// Hook for a page that failed.
    ///
    /// `error` is handed over by value, so an implementor can keep the message
    /// without cloning it. The default simply drops it.
    fn on_page_error(&self, page_num: usize, total_pages: usize, error: String) {
        let _ = (page_num, total_pages, error);
    }

    /// Hook for the end of a conversion; `success_count` is the number of pages
    /// the caller counted as successful out of `total_pages`.
    fn on_conversion_complete(&self, total_pages: usize, success_count: usize) {
        let _ = (total_pages, success_count);
    }
}
/// A `ConversionProgressCallback` that ignores every event.
///
/// Use it wherever a progress callback is required but no reporting is wanted.
/// It is a zero-sized unit struct, so it is free to construct and copy. The
/// standard derives let it live inside `#[derive(Debug)]` containers and be
/// produced generically via `Default`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct NoopProgressCallback;

// An empty impl: every hook falls back to the trait's argument-discarding default.
impl ConversionProgressCallback for NoopProgressCallback {}
/// Shared, type-erased handle to a progress callback.
///
/// Cloning the handle bumps a reference count rather than copying the callback.
/// Because the trait requires `Send + Sync`, the handle can be moved into other
/// threads or spawned tasks. The `'static` bound is what `Arc<dyn Trait>` would
/// default to anyway; it is spelled out here for clarity.
pub type ProgressCallback = Arc<dyn ConversionProgressCallback + 'static>;
#[cfg(test)]
mod tests {
    //! Tests for the progress-callback trait: the no-op implementation, a
    //! counting implementation, dynamic dispatch through `Arc<dyn ...>`, and
    //! moving a callback into a spawned tokio task.

    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Mutex;

    /// Counts page-level events and records the totals passed to the
    /// conversion-level hooks. Every counter starts at zero via `Default`.
    #[derive(Default)]
    struct TrackingCallback {
        starts: Arc<AtomicUsize>,
        completes: Arc<AtomicUsize>,
        errors: Arc<AtomicUsize>,
        started_total: Arc<AtomicUsize>,
        completed_total: Arc<AtomicUsize>,
    }

    /// Loads a counter with the same `SeqCst` ordering the callback writes with.
    fn read(counter: &AtomicUsize) -> usize {
        counter.load(Ordering::SeqCst)
    }

    impl ConversionProgressCallback for TrackingCallback {
        fn on_conversion_start(&self, total_pages: usize) {
            self.started_total.store(total_pages, Ordering::SeqCst);
        }

        fn on_page_start(&self, _page_num: usize, _total_pages: usize) {
            self.starts.fetch_add(1, Ordering::SeqCst);
        }

        fn on_page_complete(&self, _page_num: usize, _total_pages: usize, _markdown_len: usize) {
            self.completes.fetch_add(1, Ordering::SeqCst);
        }

        fn on_page_error(&self, _page_num: usize, _total_pages: usize, _error: String) {
            self.errors.fetch_add(1, Ordering::SeqCst);
        }

        fn on_conversion_complete(&self, _total_pages: usize, success_count: usize) {
            self.completed_total.store(success_count, Ordering::SeqCst);
        }
    }

    #[test]
    fn noop_callback_does_not_panic() {
        let noop = NoopProgressCallback;
        noop.on_conversion_start(5);
        noop.on_page_start(1, 5);
        noop.on_page_complete(1, 5, 42);
        noop.on_page_error(2, 5, String::from("some error"));
        noop.on_conversion_complete(5, 4);
    }

    #[test]
    fn tracking_callback_receives_events() {
        let tracker = TrackingCallback::default();

        tracker.on_conversion_start(3);
        assert_eq!(read(&tracker.started_total), 3);

        // Pages 1 and 2 succeed; page 3 fails.
        for (page, markdown_len) in [(1, 100), (2, 200)] {
            tracker.on_page_start(page, 3);
            tracker.on_page_complete(page, 3, markdown_len);
        }
        tracker.on_page_start(3, 3);
        tracker.on_page_error(3, 3, String::from("VLM timeout"));

        assert_eq!(read(&tracker.starts), 3);
        assert_eq!(read(&tracker.completes), 2);
        assert_eq!(read(&tracker.errors), 1);

        tracker.on_conversion_complete(3, 2);
        assert_eq!(read(&tracker.completed_total), 2);
    }

    #[test]
    fn arc_dyn_callback_works() {
        // `ProgressCallback` is the alias for `Arc<dyn ConversionProgressCallback>`.
        let cb: ProgressCallback = Arc::new(NoopProgressCallback);
        cb.on_conversion_start(10);
        cb.on_page_start(1, 10);
        cb.on_page_complete(1, 10, 512);
    }

    #[tokio::test]
    async fn on_page_error_is_send_when_used_in_spawn() {
        // Appends every reported error message to a shared list.
        struct StringCollector {
            errors: Arc<Mutex<Vec<String>>>,
        }

        impl ConversionProgressCallback for StringCollector {
            fn on_page_error(&self, _page_num: usize, _total_pages: usize, error: String) {
                self.errors.lock().unwrap().push(error);
            }
        }

        let collector = Arc::new(StringCollector {
            errors: Arc::new(Mutex::new(Vec::new())),
        });

        // The trait object must be `Send` for `tokio::spawn` to accept the future.
        let cb = Arc::clone(&collector) as ProgressCallback;
        let task = tokio::spawn(async move {
            cb.on_page_error(1, 5, String::from("error from spawn"));
        });
        task.await.unwrap();

        let errors = collector.errors.lock().unwrap();
        assert_eq!(*errors, vec![String::from("error from spawn")]);
    }

    #[test]
    fn on_page_error_receives_owned_string() {
        // Stores the most recent error message it was handed.
        struct ErrorCapture {
            captured: Arc<Mutex<Option<String>>>,
        }

        impl ConversionProgressCallback for ErrorCapture {
            fn on_page_error(&self, _p: usize, _t: usize, error: String) {
                *self.captured.lock().unwrap() = Some(error);
            }
        }

        let capture = ErrorCapture {
            captured: Arc::new(Mutex::new(None)),
        };
        let long_error = "x".repeat(200);
        capture.on_page_error(3, 10, long_error.clone());

        let got = capture.captured.lock().unwrap().take().unwrap();
        assert_eq!(got, long_error, "Full error string should be forwarded");
    }
}