use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Instant;
use crossbeam_queue::SegQueue;
use tokio::sync::broadcast;
/// Stage of a load operation, as carried by [`ProgressEvent::Phase`].
///
/// The enum is `#[repr(u8)]` with explicit discriminants `0..=4`; those are
/// the raw values that [`LoadPhase::from_u8`] maps back to a variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum LoadPhase {
    Connecting = 0,
    Querying = 1,
    Fetching = 2,
    Parsing = 3,
    Converting = 4,
}

impl LoadPhase {
    /// Converts a raw discriminant back into a phase.
    ///
    /// Returns `None` when `value` is not the discriminant of any variant.
    pub fn from_u8(value: u8) -> Option<Self> {
        // Every variant, searched by discriminant so the lookup does not
        // depend on the order of this table.
        const ALL: [LoadPhase; 5] = [
            LoadPhase::Connecting,
            LoadPhase::Querying,
            LoadPhase::Fetching,
            LoadPhase::Parsing,
            LoadPhase::Converting,
        ];
        ALL.iter().copied().find(|phase| *phase as u8 == value)
    }

    /// Returns the variant's name as a static string, e.g. `"Fetching"`.
    pub fn as_str(&self) -> &'static str {
        match *self {
            LoadPhase::Connecting => "Connecting",
            LoadPhase::Querying => "Querying",
            LoadPhase::Fetching => "Fetching",
            LoadPhase::Parsing => "Parsing",
            LoadPhase::Converting => "Converting",
        }
    }
}
/// How much detail an operation reports.
///
/// `Coarse` is the default. [`ProgressHandle::progress`] only emits an event
/// when the handle's granularity is `Fine`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[repr(u8)]
pub enum ProgressGranularity {
    #[default]
    Coarse = 0,
    Fine = 1,
}

impl ProgressGranularity {
    /// Lenient conversion from a raw byte.
    ///
    /// `1` maps to `Fine`; every other value (including unknown ones) maps to
    /// `Coarse`, so this never fails.
    pub fn from_u8(value: u8) -> Self {
        if value == Self::Fine as u8 {
            Self::Fine
        } else {
            Self::Coarse
        }
    }
}
/// An event describing the state of one tracked operation.
///
/// Every variant carries the `operation_id` of the operation it belongs to;
/// use [`ProgressEvent::operation_id`] to read it without matching.
#[derive(Debug, Clone)]
pub enum ProgressEvent {
    /// The operation reported a [`LoadPhase`], together with its source string.
    Phase {
        operation_id: u64,
        phase: LoadPhase,
        source: String,
    },
    /// Row/byte counters; `total_rows` is `None` when no total was supplied.
    Progress {
        operation_id: u64,
        rows_processed: u64,
        total_rows: Option<u64>,
        bytes_processed: u64,
    },
    /// The operation finished, with the row count and duration in milliseconds.
    Complete {
        operation_id: u64,
        rows_loaded: u64,
        duration_ms: u64,
    },
    /// The operation reported an error, described by `message`.
    Error { operation_id: u64, message: String },
}

impl ProgressEvent {
    /// Returns the id of the operation this event belongs to.
    pub fn operation_id(&self) -> u64 {
        // All variants bind the same field, so one or-pattern arm covers them
        // while the match stays exhaustive.
        match self {
            Self::Phase { operation_id, .. }
            | Self::Progress { operation_id, .. }
            | Self::Complete { operation_id, .. }
            | Self::Error { operation_id, .. } => *operation_id,
        }
    }
}
/// Reports the progress of a single operation to a [`ProgressRegistry`].
///
/// Handles are created by [`ProgressRegistry::start_operation`]. The terminal
/// methods [`complete`](Self::complete) and [`error`](Self::error) take `self`
/// by value, so a handle cannot emit further events after either of them.
pub struct ProgressHandle {
    /// Id stamped on every event emitted through this handle.
    operation_id: u64,
    /// Source string attached to `Phase` events.
    source: String,
    /// Registry that receives the emitted events.
    registry: Arc<ProgressRegistry>,
    /// Captured when the handle was created; used to compute `duration_ms`.
    start_time: Instant,
    /// Controls whether `progress` emits anything.
    granularity: ProgressGranularity,
}

impl ProgressHandle {
    /// Emits a [`ProgressEvent::Phase`] for this operation.
    ///
    /// Phase events are emitted regardless of granularity.
    pub fn phase(&self, phase: LoadPhase) {
        self.registry.emit(ProgressEvent::Phase {
            operation_id: self.operation_id,
            phase,
            source: self.source.clone(),
        });
    }

    /// Emits a [`ProgressEvent::Progress`] when the granularity is `Fine`.
    ///
    /// With `Coarse` granularity this is a no-op.
    pub fn progress(&self, rows_processed: u64, total_rows: Option<u64>, bytes_processed: u64) {
        if self.granularity == ProgressGranularity::Fine {
            self.registry.emit(ProgressEvent::Progress {
                operation_id: self.operation_id,
                rows_processed,
                total_rows,
                bytes_processed,
            });
        }
    }

    /// Consumes the handle and emits [`ProgressEvent::Complete`].
    ///
    /// `duration_ms` is the time elapsed since the handle was created, in
    /// whole milliseconds, saturating at `u64::MAX`.
    pub fn complete(self, rows_loaded: u64) {
        // `as_millis` yields a `u128`; clamp before narrowing so an
        // out-of-range value saturates instead of silently wrapping.
        let elapsed_ms = self.start_time.elapsed().as_millis();
        let duration_ms = elapsed_ms.min(u128::from(u64::MAX)) as u64;
        self.registry.emit(ProgressEvent::Complete {
            operation_id: self.operation_id,
            rows_loaded,
            duration_ms,
        });
    }

    /// Consumes the handle and emits [`ProgressEvent::Error`] with `message`.
    pub fn error(self, message: String) {
        self.registry.emit(ProgressEvent::Error {
            operation_id: self.operation_id,
            message,
        });
    }

    /// Returns the id assigned to this operation.
    pub fn operation_id(&self) -> u64 {
        self.operation_id
    }

    /// Returns the granularity this handle was created with.
    pub fn granularity(&self) -> ProgressGranularity {
        self.granularity
    }
}
/// Collects [`ProgressEvent`]s and hands them out in two ways.
///
/// Every emitted event is pushed onto an internal queue, drained with
/// [`poll`](Self::poll), [`poll_all`](Self::poll_all) or
/// [`try_recv`](Self::try_recv), and is also sent on a broadcast channel whose
/// receivers come from [`subscribe`](Self::subscribe). Events stay in the
/// queue until something polls them.
pub struct ProgressRegistry {
    /// Queue backing the polling API.
    events: SegQueue<ProgressEvent>,
    /// Sending half of the broadcast channel backing `subscribe`.
    broadcast_tx: broadcast::Sender<ProgressEvent>,
    /// Next operation id to hand out; starts at 1.
    next_id: AtomicU64,
}

impl ProgressRegistry {
    /// Creates a registry wrapped in an `Arc`, which `start_operation` requires.
    pub fn new() -> Arc<Self> {
        Arc::new(Self::default())
    }

    /// Allocates a fresh operation id and returns a handle bound to this registry.
    ///
    /// The handle's timer starts at the moment of this call.
    pub fn start_operation(
        self: &Arc<Self>,
        source: &str,
        granularity: ProgressGranularity,
    ) -> ProgressHandle {
        ProgressHandle {
            operation_id: self.next_id.fetch_add(1, Ordering::SeqCst),
            source: source.to_owned(),
            registry: Arc::clone(self),
            start_time: Instant::now(),
            granularity,
        }
    }

    /// Records `event` in the queue and forwards it to broadcast subscribers.
    fn emit(&self, event: ProgressEvent) {
        let queued = event.clone();
        self.events.push(queued);
        // Best-effort delivery: the send result is deliberately discarded
        // (per tokio's docs it only errs when no receiver is subscribed).
        let _ = self.broadcast_tx.send(event);
    }

    /// Returns a new receiver on the broadcast channel.
    pub fn subscribe(&self) -> broadcast::Receiver<ProgressEvent> {
        self.broadcast_tx.subscribe()
    }

    /// Removes and returns the oldest queued event, or `None` if the queue is empty.
    pub fn poll(&self) -> Option<ProgressEvent> {
        self.events.pop()
    }

    /// Drains the queue, returning the events in the order they were popped.
    pub fn poll_all(&self) -> Vec<ProgressEvent> {
        // Keep popping until the queue reports empty.
        std::iter::from_fn(|| self.events.pop()).collect()
    }

    /// Same as [`poll`](Self::poll).
    pub fn try_recv(&self) -> Option<ProgressEvent> {
        self.events.pop()
    }

    /// Returns `true` when no events are waiting in the queue.
    pub fn is_empty(&self) -> bool {
        self.events.is_empty()
    }
}

impl Default for ProgressRegistry {
    /// Builds an empty registry: broadcast capacity 256, first operation id 1.
    fn default() -> Self {
        const BROADCAST_CAPACITY: usize = 256;
        const FIRST_OPERATION_ID: u64 = 1;

        // The initial receiver is dropped right away; subscribers obtain
        // their own through `subscribe`.
        let (broadcast_tx, _) = broadcast::channel(BROADCAST_CAPACITY);
        Self {
            events: SegQueue::new(),
            broadcast_tx,
            next_id: AtomicU64::new(FIRST_OPERATION_ID),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A `Fine` handle emits phase, progress and completion events, in order.
    #[test]
    fn test_progress_handle() {
        let registry = ProgressRegistry::new();
        let handle = registry.start_operation("test-source", ProgressGranularity::Fine);
        handle.phase(LoadPhase::Connecting);
        handle.progress(100, Some(1000), 8000);
        handle.complete(1000);

        let events = registry.poll_all();
        assert_eq!(events.len(), 3);

        // `matches!` only yields a bool, so each check must be wrapped in
        // `assert!` to actually fail the test on a mismatch.
        assert!(
            matches!(
                &events[0],
                ProgressEvent::Phase {
                    phase: LoadPhase::Connecting,
                    ..
                }
            ),
            "unexpected first event: {:?}",
            events[0]
        );
        assert!(
            matches!(
                &events[1],
                ProgressEvent::Progress {
                    rows_processed: 100,
                    ..
                }
            ),
            "unexpected second event: {:?}",
            events[1]
        );
        assert!(
            matches!(
                &events[2],
                ProgressEvent::Complete {
                    rows_loaded: 1000,
                    ..
                }
            ),
            "unexpected third event: {:?}",
            events[2]
        );
    }

    /// A `Coarse` handle drops `progress` calls but still emits phase and completion.
    #[test]
    fn test_coarse_granularity_skips_progress() {
        let registry = ProgressRegistry::new();
        let handle = registry.start_operation("test-source", ProgressGranularity::Coarse);
        handle.phase(LoadPhase::Fetching);
        handle.progress(100, Some(1000), 8000);
        handle.complete(1000);

        let events = registry.poll_all();
        assert_eq!(events.len(), 2);
        assert!(
            matches!(
                &events[0],
                ProgressEvent::Phase {
                    phase: LoadPhase::Fetching,
                    ..
                }
            ),
            "unexpected first event: {:?}",
            events[0]
        );
        assert!(
            matches!(
                &events[1],
                ProgressEvent::Complete {
                    rows_loaded: 1000,
                    ..
                }
            ),
            "unexpected second event: {:?}",
            events[1]
        );
    }

    /// Known discriminants round-trip; unknown ones are rejected.
    #[test]
    fn test_load_phase_from_u8() {
        assert_eq!(LoadPhase::from_u8(0), Some(LoadPhase::Connecting));
        assert_eq!(LoadPhase::from_u8(4), Some(LoadPhase::Converting));
        assert_eq!(LoadPhase::from_u8(99), None);
    }
}