use std::io;
use std::path::Path;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use futures_core::Stream;
use tokio::sync::{mpsc, Mutex};
use tokio::task::JoinHandle;
use tokio::time::{interval, Duration};
use crate::live::Live;
use crate::progress::{Progress, TaskId};
use crate::text::Text;
/// Extension trait that lets a [`Stream`] be wrapped in a [`ProgressStream`].
///
/// A blanket implementation covers every `Stream`, so `stream.track_progress(..)` is
/// available wherever this trait is in scope.
pub trait ProgressStreamExt: Stream {
    /// Wraps `self` so that every item it yields advances a progress task by one.
    ///
    /// `description` labels the task; `total` is the expected number of items, or
    /// `None` when it is not known up front.
    fn track_progress(self, description: &str, total: Option<f64>) -> ProgressStream<Self>
    where
        Self: Sized;
}
impl<S> ProgressStreamExt for S
where
    S: Stream,
{
    fn track_progress(self, description: &str, total: Option<f64>) -> ProgressStream<Self>
    where
        Self: Sized,
    {
        // Pure delegation to the inherent constructor; `start()` is deferred to the
        // first poll of the returned stream.
        ProgressStream::<S>::new(self, description, total)
    }
}
/// A [`Stream`] adapter that drives a single-task [`Progress`] display while the
/// wrapped stream is consumed.
///
/// Build one with [`ProgressStream::new`] or [`ProgressStreamExt::track_progress`].
pub struct ProgressStream<S> {
    /// The stream whose items are being counted.
    inner: S,
    /// Display that owns the single tracked task.
    progress: Progress,
    /// Identifier of the tracked task inside `progress`.
    task: TaskId,
    /// Whether `progress.start()` has been called (happens on the first poll).
    started: bool,
}
impl<S: Stream> ProgressStream<S> {
    /// Wraps `inner` with a fresh auto-refreshing progress display holding one task.
    ///
    /// `description` labels the task and `total` is its expected size (`None` when
    /// unknown). The display is not started here; the first `poll_next` starts it.
    pub fn new(inner: S, description: &str, total: Option<f64>) -> Self {
        let columns = Progress::default_columns();
        let mut progress = Progress::new(columns).with_auto_refresh(true);
        let task = progress.add_task(description, total);
        Self {
            inner,
            progress,
            task,
            started: false,
        }
    }

    /// Identifier of the task this stream advances.
    pub fn task_id(&self) -> TaskId {
        self.task
    }

    /// Borrow of the underlying progress display.
    pub fn progress(&self) -> &Progress {
        &self.progress
    }

    /// Shared borrow of the wrapped stream.
    pub fn inner(&self) -> &S {
        &self.inner
    }

    /// Mutable borrow of the wrapped stream.
    pub fn inner_mut(&mut self) -> &mut S {
        &mut self.inner
    }
}
impl<S> Stream for ProgressStream<S>
where
    S: Stream + Unpin,
{
    type Item = S::Item;

    /// Polls the wrapped stream and mirrors its progress on the display.
    ///
    /// The first call starts the display. Each yielded item advances the task by one
    /// and forces a refresh. When the wrapped stream reports exhaustion, the display
    /// is stopped. The inner poll result is returned unchanged.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = Pin::into_inner(self);
        if !this.started {
            this.progress.start();
            this.started = true;
        }

        let polled = Pin::new(&mut this.inner).poll_next(cx);
        match &polled {
            Poll::Ready(Some(_)) => {
                this.progress.advance(this.task, 1.0);
                this.progress.refresh();
            }
            Poll::Ready(None) => {
                this.progress.stop();
            }
            Poll::Pending => {}
        }
        polled
    }

    /// Forwards the wrapped stream's size hint; this adapter neither adds nor drops items.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.inner.size_hint()
    }
}
impl<S> Drop for ProgressStream<S> {
    /// Stops the display if polling ever started it, so an abandoned stream does not
    /// leave a live progress bar behind.
    ///
    /// NOTE(review): when the stream ran to completion, `poll_next` has already called
    /// `stop()`, so this is a second call. Presumably `Progress::stop` tolerates that;
    /// confirm.
    fn drop(&mut self) {
        if !self.started {
            return;
        }
        self.progress.stop();
    }
}
/// State shared between a [`LiveAsync`] handle and its background refresh task.
struct LiveAsyncState {
    /// The live display being driven.
    live: Live,
    /// Set by `LiveAsync::stop`; the refresh loop exits when it observes this.
    stopped: bool,
}
/// Async wrapper around [`Live`] whose re-rendering is driven by a spawned tokio task
/// instead of `Live`'s own auto-refresh.
pub struct LiveAsync {
    /// State shared with the background refresh task.
    state: Arc<Mutex<LiveAsyncState>>,
    /// Handle to the refresh task; `Some` only between `start` and `stop`.
    refresh_handle: Option<JoinHandle<()>>,
    /// Requested period between refreshes.
    refresh_interval: Duration,
    /// Whether `start` has run without a matching `stop`.
    started: bool,
}
impl LiveAsync {
    /// Floor applied to the refresh period when the refresh task is spawned.
    ///
    /// `tokio::time::interval` panics on a zero period. That call runs inside the
    /// spawned task, so a zero interval would kill the task silently: `stop()` ignores
    /// the JoinError, and the display would simply never refresh.
    const MIN_REFRESH_INTERVAL: Duration = Duration::from_millis(1);

    /// Creates a stopped live display showing `renderable`.
    ///
    /// The inner [`Live`] is built with its own auto-refresh disabled. After
    /// [`start`](Self::start), refreshing is driven by a tokio task every
    /// [`refresh_interval`](Self::refresh_interval) (250 ms unless overridden).
    pub fn new(renderable: Text) -> Self {
        LiveAsync {
            state: Arc::new(Mutex::new(LiveAsyncState {
                live: Live::new(renderable).with_auto_refresh(false),
                stopped: false,
            })),
            refresh_handle: None,
            refresh_interval: Duration::from_millis(250),
            started: false,
        }
    }

    /// Sets how often the background task re-renders the display.
    ///
    /// The value is stored as given (and reported by `refresh_interval()`). Periods
    /// shorter than 1 ms, including zero, are raised to 1 ms when the task starts.
    #[must_use]
    pub fn with_refresh_interval(mut self, interval: Duration) -> Self {
        self.refresh_interval = interval;
        self
    }

    /// Starts the live display and spawns the periodic refresh task.
    ///
    /// This is a no-op if already started. It must be called from within a tokio
    /// runtime because it uses `tokio::spawn`.
    pub async fn start(&mut self) {
        if self.started {
            return;
        }
        self.started = true;
        {
            let mut state = self.state.lock().await;
            state.live.start();
            state.stopped = false;
        }
        let state = Arc::clone(&self.state);
        let period = self.refresh_interval.max(Self::MIN_REFRESH_INTERVAL);
        let handle = tokio::spawn(async move {
            let mut ticker = interval(period);
            // After a stall (e.g. `update` holding the lock), catch up with one refresh
            // rather than replaying every missed tick back-to-back (tokio's default).
            ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
            loop {
                ticker.tick().await;
                let state = state.lock().await;
                if state.stopped {
                    break;
                }
                state.live.refresh();
            }
        });
        self.refresh_handle = Some(handle);
    }

    /// Replaces the renderable shown by the live display.
    ///
    /// The `true` flag is forwarded to `Live::update_renderable`. It presumably requests
    /// an immediate refresh; NOTE(review): confirm against `Live`.
    pub async fn update(&mut self, renderable: Text) {
        let mut state = self.state.lock().await;
        state.live.update_renderable(renderable, true);
    }

    /// Stops the refresh task and then the live display.
    ///
    /// This is a no-op if not started. The task is flagged, aborted, and awaited before
    /// `Live::stop` runs, so no refresh can race with the teardown.
    pub async fn stop(&mut self) {
        if !self.started {
            return;
        }
        self.started = false;
        {
            let mut state = self.state.lock().await;
            state.stopped = true;
        }
        if let Some(handle) = self.refresh_handle.take() {
            handle.abort();
            // The expected result is a cancellation JoinError; there is nothing to act on.
            let _ = handle.await;
        }
        let mut state = self.state.lock().await;
        state.live.stop();
    }

    /// Whether `start` has been called without a subsequent `stop`.
    pub fn is_started(&self) -> bool {
        self.started
    }

    /// The refresh period as configured (before the 1 ms floor applied at start).
    pub fn refresh_interval(&self) -> Duration {
        self.refresh_interval
    }
}
impl Drop for LiveAsync {
    /// Best-effort teardown when a started `LiveAsync` is dropped without `stop()`.
    ///
    /// `Drop` cannot `.await`, so the refresh task is aborted rather than joined, and the
    /// shared state is reached only through `try_lock`. If another thread holds the lock
    /// at that moment (e.g. an in-flight refresh), the live display is left as-is rather
    /// than blocking here. Otherwise the `Live` is stopped, matching how
    /// `ProgressStream`'s `Drop` stops its display.
    fn drop(&mut self) {
        if !self.started {
            return;
        }
        if let Some(handle) = self.refresh_handle.take() {
            handle.abort();
        }
        if let Ok(mut state) = self.state.try_lock() {
            // Also signal the loop, in case the task reaches the lock before the abort lands.
            state.stopped = true;
            state.live.stop();
        }
    }
}
/// Message carried from a [`ProgressSender`] to its [`ProgressChannel`].
#[derive(Debug, Clone, Copy)]
enum ProgressUpdate {
    /// Set the task's completed amount to this absolute value.
    Set(f64),
    /// Finish the task and end the receiving loop.
    Finish,
}
/// Cloneable handle for pushing updates into a [`ProgressChannel`].
///
/// Every clone feeds the same channel, so several tasks can report progress.
#[derive(Debug, Clone)]
pub struct ProgressSender {
    /// Sending half of the bounded channel read by `ProgressChannel::run`.
    sender: mpsc::Sender<ProgressUpdate>,
}
impl ProgressSender {
pub async fn update(&self, completed: f64) {
let _ = self.sender.send(ProgressUpdate::Set(completed)).await;
}
pub async fn finish(&self) {
let _ = self.sender.send(ProgressUpdate::Finish).await;
}
}
/// Receiving side of a progress channel. It owns the display and applies updates sent
/// by any number of [`ProgressSender`] clones.
pub struct ProgressChannel {
    /// Receiving half of the bounded update channel.
    receiver: mpsc::Receiver<ProgressUpdate>,
    /// Display that owns the single tracked task.
    progress: Progress,
    /// Identifier of the tracked task inside `progress`.
    task: TaskId,
}
impl ProgressChannel {
    /// Maximum number of queued, not-yet-rendered updates. Senders wait once it is reached.
    const CHANNEL_CAPACITY: usize = 1024;

    /// Creates a sender/receiver pair for a task with no known total.
    pub fn new(description: &str) -> (ProgressSender, Self) {
        Self::build(description, None)
    }

    /// Creates a sender/receiver pair for a task expected to reach `total`.
    pub fn with_total(description: &str, total: f64) -> (ProgressSender, Self) {
        Self::build(description, Some(total))
    }

    /// Shared constructor: creates the bounded channel and an auto-refreshing display
    /// with a single task labelled `description`.
    fn build(description: &str, total: Option<f64>) -> (ProgressSender, Self) {
        let (sender, receiver) = mpsc::channel(Self::CHANNEL_CAPACITY);
        let mut progress = Progress::new(Progress::default_columns()).with_auto_refresh(true);
        let task = progress.add_task(description, total);
        (
            ProgressSender { sender },
            ProgressChannel {
                receiver,
                progress,
                task,
            },
        )
    }

    /// Renders updates until a `finish()` message arrives or every sender is dropped,
    /// then stops the display.
    ///
    /// Each `Set` update is applied and rendered immediately. On `Finish`, the task is
    /// snapped to its total (when one was given) before the loop ends. Updates still
    /// queued behind a `Finish` are discarded.
    pub async fn run(mut self) {
        self.progress.start();
        // NOTE(review): the purpose of these fixed pauses (here and before `stop`) is not
        // evident from this file; presumably they let the display render its first and
        // final frames. Confirm before changing.
        tokio::time::sleep(Duration::from_millis(50)).await;
        while let Some(update) = self.receiver.recv().await {
            match update {
                ProgressUpdate::Set(completed) => {
                    self.progress
                        .update(self.task, Some(completed), None, None, None, None);
                    self.progress.refresh();
                }
                ProgressUpdate::Finish => {
                    let total = self.progress.get_task(self.task).and_then(|t| t.total);
                    if let Some(total) = total {
                        self.progress
                            .update(self.task, Some(total), None, None, None, None);
                    }
                    self.progress.refresh();
                    break;
                }
            }
        }
        tokio::time::sleep(Duration::from_millis(100)).await;
        self.progress.stop();
    }

    /// Identifier of the task this channel drives.
    pub fn task_id(&self) -> TaskId {
        self.task
    }

    /// Borrow of the underlying progress display.
    pub fn progress(&self) -> &Progress {
        &self.progress
    }
}
pub mod fs {
use super::*;
pub async fn read_with_progress(path: &Path, description: &str) -> io::Result<Vec<u8>> {
use tokio::io::AsyncReadExt;
let metadata = tokio::fs::metadata(path).await?;
let total_size = metadata.len() as f64;
let mut file = tokio::fs::File::open(path).await?;
let mut progress = Progress::new(Progress::default_columns()).with_auto_refresh(true);
let task = progress.add_task(description, Some(total_size));
progress.start();
let mut buffer = Vec::with_capacity(total_size as usize);
let mut chunk = vec![0u8; 8192];
let mut bytes_read = 0u64;
loop {
match file.read(&mut chunk).await {
Ok(0) => break,
Ok(n) => {
buffer.extend_from_slice(&chunk[..n]);
bytes_read += n as u64;
progress.update(task, Some(bytes_read as f64), None, None, None, None);
}
Err(e) => {
progress.stop();
return Err(e);
}
}
}
progress.stop();
Ok(buffer)
}
pub async fn copy_with_progress(src: &Path, dst: &Path, description: &str) -> io::Result<u64> {
use tokio::io::{AsyncReadExt, AsyncWriteExt};
let metadata = tokio::fs::metadata(src).await?;
let total_size = metadata.len() as f64;
let mut src_file = tokio::fs::File::open(src).await?;
let mut dst_file = tokio::fs::File::create(dst).await?;
let mut progress = Progress::new(Progress::default_columns()).with_auto_refresh(true);
let task = progress.add_task(description, Some(total_size));
progress.start();
let mut buffer = vec![0u8; 8192];
let mut total_copied = 0u64;
loop {
match src_file.read(&mut buffer).await {
Ok(0) => break,
Ok(n) => {
if let Err(e) = dst_file.write_all(&buffer[..n]).await {
progress.stop();
return Err(e);
}
total_copied += n as u64;
progress.update(task, Some(total_copied as f64), None, None, None, None);
}
Err(e) => {
progress.stop();
return Err(e);
}
}
}
progress.stop();
Ok(total_copied)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::stream::{self, StreamExt};

    fn test_console() -> crate::console::Console {
        crate::console::Console::builder()
            .width(80)
            .height(25)
            .quiet(true)
            .markup(false)
            .no_color(true)
            .force_terminal(true)
            .build()
    }

    /// Returns a temp-dir path namespaced by this process id, so concurrent test runs
    /// (e.g. two checkouts running `cargo test` at once) never share fixture files.
    fn unique_temp_path(name: &str) -> std::path::PathBuf {
        std::env::temp_dir().join(format!("gilt_async_test_{}_{}", std::process::id(), name))
    }

    #[tokio::test]
    async fn test_progress_stream_tracks_items() {
        let items: Vec<i32> = vec![1, 2, 3, 4, 5];
        let stream = stream::iter(items);
        let mut progress_stream = stream.track_progress("Testing", Some(5.0));
        let mut count = 0;
        while progress_stream.next().await.is_some() {
            count += 1;
        }
        assert_eq!(count, 5);
    }

    #[tokio::test]
    async fn test_progress_stream_size_hint() {
        let stream = stream::iter(0..100);
        let progress_stream = ProgressStream::new(stream, "Testing", Some(100.0));
        let (lower, upper) = progress_stream.size_hint();
        assert_eq!(lower, 100);
        assert_eq!(upper, Some(100));
    }

    #[tokio::test]
    async fn test_progress_channel_basic() {
        let (tx, progress) = ProgressChannel::with_total("Test", 100.0);
        let worker = tokio::spawn(async move {
            for i in 0..=100 {
                tx.update(i as f64).await;
            }
            tx.finish().await;
        });
        let progress_handle = tokio::spawn(async move {
            progress.run().await;
        });
        // Unwrap the join results: discarding them would hide a panic in either task.
        let (worker, progress_handle) = tokio::join!(worker, progress_handle);
        worker.expect("sender task panicked");
        progress_handle.expect("progress task panicked");
    }

    #[tokio::test]
    async fn test_progress_channel_multiple_senders() {
        let (tx, progress) = ProgressChannel::with_total("Test", 200.0);
        let tx2 = tx.clone();
        let worker1 = tokio::spawn(async move {
            for i in 0..100 {
                tx.update(i as f64).await;
                tokio::task::yield_now().await;
            }
        });
        let worker2 = tokio::spawn(async move {
            for i in 100..=200 {
                tx2.update(i as f64).await;
                tokio::task::yield_now().await;
            }
            tx2.finish().await;
        });
        let progress_handle = tokio::spawn(async move {
            progress.run().await;
        });
        let (worker1, worker2, progress_handle) =
            tokio::join!(worker1, worker2, progress_handle);
        worker1.expect("first sender task panicked");
        worker2.expect("second sender task panicked");
        progress_handle.expect("progress task panicked");
    }

    #[tokio::test]
    async fn test_progress_channel_ends_when_senders_dropped() {
        // No `finish()`: `run` must still return once the only sender is dropped.
        let (tx, progress) = ProgressChannel::new("Test");
        let worker = tokio::spawn(async move {
            for i in 0..10 {
                tx.update(i as f64).await;
            }
        });
        let progress_handle = tokio::spawn(async move {
            progress.run().await;
        });
        let (worker, progress_handle) = tokio::join!(worker, progress_handle);
        worker.expect("sender task panicked");
        progress_handle.expect("progress task panicked");
    }

    #[tokio::test]
    async fn test_live_async_lifecycle() {
        let mut live = LiveAsync::new(Text::new("Test", crate::style::Style::null()));
        assert!(!live.is_started());
        live.start().await;
        assert!(live.is_started());
        live.update(Text::new("Updated", crate::style::Style::null()))
            .await;
        live.stop().await;
        assert!(!live.is_started());
    }

    #[tokio::test]
    async fn test_live_async_double_start_stop() {
        let mut live = LiveAsync::new(Text::new("Test", crate::style::Style::null()));
        live.start().await;
        live.start().await;
        assert!(live.is_started());
        live.stop().await;
        live.stop().await;
        assert!(!live.is_started());
    }

    #[tokio::test]
    async fn test_fs_read_with_progress_small_file() {
        let test_file = unique_temp_path("read.txt");
        let test_content = b"Hello, async world!";
        tokio::fs::write(&test_file, test_content).await.unwrap();
        let result = fs::read_with_progress(&test_file, "Reading test file").await;
        let _ = tokio::fs::remove_file(&test_file).await;
        assert_eq!(result.expect("read_with_progress failed"), test_content);
    }

    #[tokio::test]
    async fn test_fs_copy_with_progress() {
        let src_file = unique_temp_path("src.txt");
        let dst_file = unique_temp_path("dst.txt");
        let test_content = b"Copy this content!";
        tokio::fs::write(&src_file, test_content).await.unwrap();
        let result = fs::copy_with_progress(&src_file, &dst_file, "Copying test file").await;
        let copied = tokio::fs::read(&dst_file).await;
        let _ = tokio::fs::remove_file(&src_file).await;
        let _ = tokio::fs::remove_file(&dst_file).await;
        assert_eq!(
            result.expect("copy_with_progress failed"),
            test_content.len() as u64
        );
        assert_eq!(copied.expect("destination unreadable"), test_content);
    }
}