use std::sync::atomic::AtomicU64;
use std::sync::Arc;
use tokio::sync::mpsc::UnboundedSender;
use crate::commands::Responder;
use swiftide::traits::{Transformer, WithIndexingDefaults};
use tokio_util::task::AbortOnDropHandle;
#[derive(Debug, Clone)]
pub struct ProgressUpdater {
responder: Option<Arc<dyn Responder>>,
updater: Option<UnboundedSender<String>>,
total_chunks: Arc<AtomicU64>,
processed_chunks: Arc<AtomicU64>,
}
impl From<Option<Arc<dyn Responder>>> for ProgressUpdater {
fn from(responder: Option<Arc<dyn Responder>>) -> Self {
Self {
responder,
updater: None,
total_chunks: Arc::new(AtomicU64::new(0)),
processed_chunks: Arc::new(AtomicU64::new(0)),
}
}
}
impl ProgressUpdater {
pub fn spawn(&mut self) -> AbortOnDropHandle<()> {
let (sender, mut receiver) = tokio::sync::mpsc::unbounded_channel();
self.updater = Some(sender);
let responder = self.responder.take();
tracing::info!("Spawning progress updater");
AbortOnDropHandle::new(tokio::spawn(async move {
tracing::info!("Starting progress updater");
while let Some(update) = receiver.recv().await {
tracing::info!("Sending update: {}", update);
if let Some(responder) = &responder {
responder.update(&update).await;
}
}
}))
}
pub fn count_processed_fn(&self) -> impl Transformer + WithIndexingDefaults + 'static {
let processed_chunks = Arc::clone(&self.processed_chunks);
let total_chunks = Arc::clone(&self.total_chunks);
let updater = self
.updater
.as_ref()
.expect("Progress updater not initialized")
.clone();
move |node| {
let current = processed_chunks.fetch_add(1, std::sync::atomic::Ordering::Relaxed) + 1;
tracing::debug!("Processed chunks: {}", current);
let _ = updater.send(format!(
"Indexing a bit of code {}/{}",
current,
total_chunks.load(std::sync::atomic::Ordering::Relaxed)
));
Ok(node)
}
}
pub fn count_total_fn(&self) -> impl Transformer + WithIndexingDefaults + 'static {
let processed_chunks = Arc::clone(&self.processed_chunks);
let total_chunks = Arc::clone(&self.total_chunks);
let updater = self
.updater
.as_ref()
.expect("Progress updater not initialized")
.clone();
move |node| {
let total_chunks = total_chunks.fetch_add(1, std::sync::atomic::Ordering::Relaxed) + 1;
tracing::debug!("Total chunks: {}", total_chunks);
let _ = updater.send(format!(
"Indexing a bit of code {}/{}",
processed_chunks.load(std::sync::atomic::Ordering::Relaxed),
total_chunks
));
Ok(node)
}
}
pub fn send_update(&self, update: impl Into<String>) {
if let Some(updater) = &self.updater {
let _ = updater.send(update.into());
}
}
}
#[cfg(test)]
mod tests {
use mockall::predicate;
use swiftide::indexing::Node;
use crate::commands::{CommandResponse, MockResponder};
use super::*;
use std::sync::atomic::Ordering;
#[tokio::test]
async fn test_spawn() {
let mut responder = MockResponder::default();
responder
.expect_send()
.with(predicate::eq(CommandResponse::Activity(
"Test update".to_string(),
)))
.once();
let responder = Arc::new(responder);
let mut updater = ProgressUpdater::from(Some(responder.clone() as Arc<dyn Responder>));
let _handle = updater.spawn();
updater.send_update("Test update");
}
#[test_log::test(tokio::test)]
async fn test_count_processed_fn() {
let mut responder = MockResponder::new();
responder
.expect_send()
.with(predicate::eq(CommandResponse::Activity(
"Indexing a bit of code 1/0".to_string(),
)))
.returning(|_| ())
.once();
let responder = Arc::new(responder);
let mut updater = ProgressUpdater::from(Some(responder.clone() as Arc<dyn Responder>));
let _handle = updater.spawn();
let transformer = updater.count_processed_fn();
let node = Node::default();
let result = transformer.transform_node(node).await.unwrap();
tokio::task::yield_now().await;
assert_eq!(updater.processed_chunks.load(Ordering::Relaxed), 1);
assert_eq!(result, Node::default());
}
#[tokio::test]
async fn test_count_total_fn() {
let mut responder = MockResponder::default();
responder
.expect_send()
.with(predicate::eq(CommandResponse::Activity(
"Indexing a bit of code 0/1".to_string(),
)))
.returning(|_| ())
.once();
let responder = Arc::new(responder);
let mut updater = ProgressUpdater::from(Some(responder.clone() as Arc<dyn Responder>));
let _handle = updater.spawn();
let transformer = updater.count_total_fn();
let node = Node::default();
let result = transformer.transform_node(node).await.unwrap();
tokio::task::yield_now().await;
assert_eq!(updater.total_chunks.load(Ordering::Relaxed), 1);
assert_eq!(result, Node::default());
}
}