use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Condvar, Mutex};
use std::time::{Duration, Instant};

use pyo3::prelude::*;

use super::conversions::{entry_to_event_dict, parse_hex_hash};
use super::PyGraphStore;
use crate::entry::Hash;
/// A counter-based wake-up signal for blocking readers.
///
/// Producers call [`notify`](Self::notify) to bump a generation counter.
/// Readers snapshot it with [`current`](Self::current), then block in
/// [`wait_until_changed`](Self::wait_until_changed). They wake when the counter
/// moves, when the bell is [`close`](Self::close)d, or when the timeout elapses.
/// Comparing against a snapshot means a notification that lands between the
/// snapshot and the wait is not lost.
#[derive(Default)]
pub struct NotifyBell {
    /// Generation counter; bumped (wrapping) on every `notify`.
    counter: Mutex<u64>,
    /// Signalled whenever `counter` changes or the bell is closed.
    cvar: Condvar,
    /// Once set, `wait_until_changed` stops blocking. Written while holding
    /// `counter`'s lock so a waiter cannot miss the transition.
    closed: AtomicBool,
}

impl NotifyBell {
    /// Creates an open bell with its counter at zero, wrapped for sharing.
    pub fn new() -> Arc<Self> {
        Arc::new(Self::default())
    }

    /// Locks the counter, recovering from poisoning: a plain `u64` cannot be
    /// left in an inconsistent state by a holder that panicked.
    fn lock_counter(&self) -> std::sync::MutexGuard<'_, u64> {
        self.counter
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Bumps the counter (wrapping on overflow) and wakes every waiter.
    pub fn notify(&self) {
        let mut generation = self.lock_counter();
        *generation = generation.wrapping_add(1);
        self.cvar.notify_all();
    }

    /// Returns the current counter value, for use as `last_seen` in a later wait.
    pub fn current(&self) -> u64 {
        *self.lock_counter()
    }

    /// Blocks until the counter differs from `last_seen`, the bell is closed,
    /// or `timeout` elapses, whichever happens first.
    pub fn wait_until_changed(&self, last_seen: u64, timeout: Duration) {
        let guard = self.lock_counter();
        // `closed` must be part of the predicate. `wait_timeout_while`
        // re-evaluates it after every wake-up, so a check that only looked at
        // the counter would put the thread straight back to sleep after `close`.
        let _ = self
            .cvar
            .wait_timeout_while(guard, timeout, |c| *c == last_seen && !self.is_closed())
            .unwrap_or_else(|poisoned| poisoned.into_inner());
    }

    /// Marks the bell closed and wakes every waiter. Idempotent.
    pub fn close(&self) {
        // Hold the lock while flipping the flag. A waiter then either sees
        // `closed == true` in its predicate, or is already parked and receives
        // the `notify_all` below. It cannot fall in between.
        let _guard = self.lock_counter();
        self.closed.store(true, Ordering::Release);
        self.cvar.notify_all();
    }

    /// Returns `true` once [`close`](Self::close) has been called.
    pub fn is_closed(&self) -> bool {
        self.closed.load(Ordering::Acquire)
    }
}
/// Python-visible subscription that yields oplog entries appended after a cursor.
#[pyclass(name = "TailSubscription", module = "silk")]
pub struct PyTailSubscription {
    /// Store whose oplog this subscription reads from.
    store: Py<PyGraphStore>,
    /// Head hashes marking the position the next batch starts from.
    cursor: Mutex<Vec<Hash>>,
    /// Subscription-local shutdown flag, independent of the bell's own flag.
    closed: AtomicBool,
    /// Shared wake-up signal waited on while no new entries are available.
    bell: Arc<NotifyBell>,
}

impl PyTailSubscription {
    /// Builds an open subscription positioned at `cursor`.
    pub fn new(store: Py<PyGraphStore>, cursor: Vec<Hash>, bell: Arc<NotifyBell>) -> Self {
        let cursor = Mutex::new(cursor);
        let closed = AtomicBool::default();
        Self {
            store,
            cursor,
            closed,
            bell,
        }
    }
}
#[pymethods]
impl PyTailSubscription {
#[pyo3(signature = (timeout_ms=0, max_count=1000))]
fn next_batch(
&self,
py: Python<'_>,
timeout_ms: u64,
max_count: usize,
) -> PyResult<Vec<PyObject>> {
if self.closed.load(Ordering::Acquire) {
return Ok(vec![]);
}
let last_seen = self.bell.current();
if let Some(entries) = self.try_fetch(py, max_count)? {
return Ok(entries);
}
if timeout_ms == 0 {
return Ok(vec![]);
}
let bell = Arc::clone(&self.bell);
py.allow_threads(move || {
bell.wait_until_changed(last_seen, Duration::from_millis(timeout_ms));
});
if self.closed.load(Ordering::Acquire) || self.bell.is_closed() {
return Ok(vec![]);
}
Ok(self.try_fetch(py, max_count)?.unwrap_or_default())
}
fn current_cursor(&self) -> Vec<String> {
self.cursor
.lock()
.unwrap()
.iter()
.map(hex::encode)
.collect()
}
fn close(&self) {
self.closed.store(true, Ordering::Release);
self.bell.notify();
}
fn __repr__(&self) -> String {
let heads = self.cursor.lock().unwrap().len();
format!(
"TailSubscription(cursor={} heads, closed={})",
heads,
self.closed.load(Ordering::Acquire)
)
}
}
impl PyTailSubscription {
    /// Fetches up to `max_count` entries appended since the cursor and advances
    /// the cursor past the ones returned.
    ///
    /// Returns `Ok(None)` when nothing new is available. A `max_count` of 0
    /// yields `Ok(Some(vec![]))` and leaves the cursor untouched.
    ///
    /// # Errors
    /// - `ValueError("stale cursor: ...")` if the oplog rejects the cursor.
    /// - A Python error if the store is currently mutably borrowed.
    /// - Any error from converting an entry to an event dict. The cursor is only
    ///   written after every conversion succeeds, so in that case the same
    ///   entries are offered again on the next call.
    fn try_fetch(&self, py: Python<'_>, max_count: usize) -> PyResult<Option<Vec<PyObject>>> {
        if max_count == 0 {
            // Guards the `max_count - 1` index below against underflow.
            return Ok(Some(Vec::new()));
        }
        let cursor_snapshot = self.cursor.lock().unwrap().clone();
        // `try_borrow` surfaces a conflicting mutable borrow as a Python error
        // instead of panicking.
        let borrowed = self.store.bind(py).try_borrow()?;
        let oplog = borrowed.backend_oplog();
        let entries = oplog
            .entries_since_heads(&cursor_snapshot)
            .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("stale cursor: {e}")))?;
        if entries.is_empty() {
            return Ok(None);
        }
        let truncated = entries.len() > max_count;
        let new_cursor = if truncated {
            // NOTE(review): collapsing the cursor to the last delivered hash
            // assumes every earlier entry in this batch (and the old cursor) is
            // an ancestor of it. With concurrent branches, some already-delivered
            // entries may be redelivered. Confirm against `entries_since_heads`'s
            // ordering.
            vec![entries[max_count - 1].hash]
        } else {
            // NOTE(review): assumes the oplog cannot gain entries between
            // `entries_since_heads` and this `heads()` call (the store stays
            // borrowed here). Otherwise those entries would be skipped. Confirm.
            oplog.heads()
        };
        let py_entries: Vec<PyObject> = entries
            .iter()
            .take(max_count)
            .map(|e| entry_to_event_dict(py, e, false))
            .collect::<PyResult<Vec<_>>>()?;
        *self.cursor.lock().unwrap() = new_cursor;
        Ok(Some(py_entries))
    }
}
/// Decodes a cursor given as hex strings into head hashes.
///
/// # Errors
/// Returns `ValueError("invalid cursor: ...")` for the first string that
/// `parse_hex_hash` rejects; later strings are not examined.
pub fn parse_cursor(cursor: Vec<String>) -> PyResult<Vec<Hash>> {
    let mut hashes = Vec::with_capacity(cursor.len());
    for hex_head in &cursor {
        match parse_hex_hash(hex_head) {
            Ok(hash) => hashes.push(hash),
            Err(e) => {
                return Err(pyo3::exceptions::PyValueError::new_err(format!(
                    "invalid cursor: {e}"
                )))
            }
        }
    }
    Ok(hashes)
}