use apalis_core::task::Task;
use apalis_libsql::{
CompactType, Config, SqlContext,
sink::{LibsqlSink, push_tasks},
};
use futures::Sink;
use libsql::Builder;
use std::sync::Arc;
use tempfile::TempDir;
/// Handle to a temporary on-disk libsql database used by a single test.
struct TestDb {
    // Leaked in `setup_test_db` via `Box::leak` to obtain a `'static`
    // reference; the allocation lives for the rest of the test process.
    db: &'static libsql::Database,
    // Held only to keep the temporary directory (and the .db file inside it)
    // alive until the test finishes; cleaned up when the struct drops.
    _temp_dir: Arc<TempDir>,
}
/// Creates a fresh on-disk libsql database in a temporary directory and
/// applies the initial schema migration.
///
/// The database is intentionally leaked (`Box::leak`) so tests can hold a
/// `'static` reference; this is acceptable for a test process.
async fn setup_test_db() -> TestDb {
    let temp_dir = Arc::new(TempDir::new().unwrap());
    let file_path = temp_dir.path().join("test_sink.db");

    let database = Builder::new_local(file_path.to_str().unwrap())
        .build()
        .await
        .unwrap();
    let leaked: &'static libsql::Database = Box::leak(Box::new(database));

    // Run the schema migration before handing the database to the test.
    leaked
        .connect()
        .unwrap()
        .execute_batch(include_str!("../migrations/001_initial.sql"))
        .await
        .unwrap();

    TestDb {
        db: leaked,
        _temp_dir: temp_dir,
    }
}
#[tokio::test]
async fn test_sink_new() {
    // A freshly constructed sink should expose its type name, the configured
    // job type, and a buffer-length field in its Debug output.
    let test_db = setup_test_db().await;
    let config = Config::new("TestTask");
    let sink = LibsqlSink::<(), ()>::new(test_db.db, &config);

    let rendered = format!("{:?}", sink);
    for needle in ["LibsqlSink", "TestTask", "buffer_len"] {
        assert!(rendered.contains(needle));
    }
}
#[tokio::test]
async fn test_push_tasks_empty() {
    // Pushing an empty batch must be a no-op and leave the Jobs table empty.
    let test_db = setup_test_db().await;
    let db = test_db.db;
    let config = Config::new("TestTask");

    let empty_buffer = Vec::new();
    push_tasks(db, &config, empty_buffer).await.unwrap();

    let conn = db.connect().unwrap();
    let mut rows = conn
        .query("SELECT COUNT(*) FROM Jobs", libsql::params![])
        .await
        .unwrap();
    // COUNT(*) always yields exactly one row. Previously this was wrapped in
    // `if let Some(row)`, so a missing row silently skipped the assertion and
    // let the test pass vacuously; fail loudly instead.
    let row = rows
        .next()
        .await
        .unwrap()
        .expect("COUNT(*) should return a row");
    let count: i64 = row.get(0).unwrap();
    assert_eq!(count, 0);
}
#[tokio::test]
async fn test_push_tasks_batch() {
    // Pushing a batch of three tasks should insert three rows tagged with the
    // configured job type.
    let test_db = setup_test_db().await;
    let db = test_db.db;
    let job_type = "TestTask";
    let config = Config::new(job_type);

    let mut tasks = Vec::new();
    for i in 0..3 {
        let ctx = SqlContext::new().with_max_attempts(5);
        let mut task = Task::new(CompactType::from(vec![i as u8]));
        task.parts.ctx = ctx;
        tasks.push(task);
    }
    push_tasks(db, &config, tasks).await.unwrap();

    let conn = db.connect().unwrap();
    let mut rows = conn
        .query(
            "SELECT COUNT(*) FROM Jobs WHERE job_type = ?1",
            libsql::params![job_type],
        )
        .await
        .unwrap();
    // COUNT(*) always returns exactly one row. Previously a missing row would
    // silently skip the assertion (vacuous pass); assert it is present.
    let row = rows
        .next()
        .await
        .unwrap()
        .expect("COUNT(*) should return a row");
    let count: i64 = row.get(0).unwrap();
    assert_eq!(count, 3);
}
#[tokio::test]
async fn test_sink_poll_ready() {
    use std::pin::Pin;
    use std::task::{Context, Poll};

    // A freshly created (empty) sink is always ready to accept a task.
    let test_db = setup_test_db().await;
    let config = Config::new("TestTask");
    let mut sink = LibsqlSink::<(), ()>::new(test_db.db, &config);

    let mut cx = Context::from_waker(futures::task::noop_waker_ref());
    let poll = Pin::new(&mut sink).poll_ready(&mut cx);
    assert!(matches!(poll, Poll::Ready(Ok(()))));
}
#[tokio::test]
async fn test_sink_start_send() {
    use std::pin::Pin;
    use std::task::{Context, Poll};

    // start_send should only buffer the task in memory; nothing reaches the
    // database until the sink is flushed.
    let test_db = setup_test_db().await;
    let db = test_db.db;
    let config = Config::new("TestTask");
    let mut sink = LibsqlSink::<(), ()>::new(db, &config);

    let ctx = SqlContext::new().with_max_attempts(5);
    let mut task = Task::new(CompactType::from(vec![1, 2, 3]));
    task.parts.ctx = ctx;

    let mut pinned_sink = Pin::new(&mut sink);
    let cx = &mut Context::from_waker(futures::task::noop_waker_ref());
    // Previously the readiness result was discarded (`let _ready = ...`);
    // assert it so an unexpectedly not-ready sink fails loudly.
    let ready = pinned_sink.as_mut().poll_ready(cx);
    assert!(matches!(ready, Poll::Ready(Ok(()))));
    pinned_sink.as_mut().start_send(task).unwrap();

    let conn = db.connect().unwrap();
    let mut rows = conn
        .query("SELECT COUNT(*) FROM Jobs", libsql::params![])
        .await
        .unwrap();
    // COUNT(*) always returns a row; the old `if let` could silently skip
    // this check.
    let row = rows
        .next()
        .await
        .unwrap()
        .expect("COUNT(*) should return a row");
    let count: i64 = row.get(0).unwrap();
    // No flush yet, so the table must still be empty.
    assert_eq!(count, 0);
}
#[tokio::test]
async fn test_sink_poll_flush_empty() {
    use std::pin::Pin;
    use std::task::{Context, Poll};

    // Flushing a sink with an empty buffer completes immediately.
    let test_db = setup_test_db().await;
    let config = Config::new("TestTask");
    let mut sink = LibsqlSink::<(), ()>::new(test_db.db, &config);

    let mut cx = Context::from_waker(futures::task::noop_waker_ref());
    let poll = Pin::new(&mut sink).poll_flush(&mut cx);
    assert!(matches!(poll, Poll::Ready(Ok(()))));
}
#[tokio::test]
async fn test_sink_poll_flush_with_tasks() {
    use std::pin::Pin;
    use std::task::{Context, Poll};

    // Buffer two tasks via start_send, then verify poll_flush writes both to
    // the database.
    let test_db = setup_test_db().await;
    let db = test_db.db;
    let job_type = "TestTask";
    let config = Config::new(job_type);
    let mut sink = LibsqlSink::<(), ()>::new(db, &config);

    for i in 0..2 {
        let ctx = SqlContext::new().with_max_attempts(5);
        let mut task = Task::new(CompactType::from(vec![i as u8]));
        task.parts.ctx = ctx;

        let mut pinned_sink = Pin::new(&mut sink);
        let cx = &mut Context::from_waker(futures::task::noop_waker_ref());
        // Previously the readiness result was discarded; assert it so a
        // not-ready sink cannot slip through unnoticed.
        let ready = pinned_sink.as_mut().poll_ready(cx);
        assert!(matches!(ready, Poll::Ready(Ok(()))));
        pinned_sink.as_mut().start_send(task).unwrap();
    }

    let mut pinned_sink = Pin::new(&mut sink);
    let cx = &mut Context::from_waker(futures::task::noop_waker_ref());
    let result = pinned_sink.as_mut().poll_flush(cx);
    assert!(matches!(result, Poll::Ready(Ok(()))));

    let conn = db.connect().unwrap();
    let mut rows = conn
        .query(
            "SELECT COUNT(*) FROM Jobs WHERE job_type = ?1",
            libsql::params![job_type],
        )
        .await
        .unwrap();
    // COUNT(*) always returns a row; the old `if let` made this assertion
    // skippable (vacuous pass).
    let row = rows
        .next()
        .await
        .unwrap()
        .expect("COUNT(*) should return a row");
    let count: i64 = row.get(0).unwrap();
    assert_eq!(count, 2);
}
#[tokio::test]
async fn test_sink_poll_close() {
    use std::pin::Pin;
    use std::task::{Context, Poll};

    // Closing a sink with nothing buffered completes immediately.
    let test_db = setup_test_db().await;
    let config = Config::new("TestTask");
    let mut sink = LibsqlSink::<(), ()>::new(test_db.db, &config);

    let mut cx = Context::from_waker(futures::task::noop_waker_ref());
    let poll = Pin::new(&mut sink).poll_close(&mut cx);
    assert!(matches!(poll, Poll::Ready(Ok(()))));
}
#[tokio::test]
async fn test_sink_clone_does_not_copy_buffer() {
    use std::pin::Pin;
    use std::task::{Context, Poll};

    // Cloning a sink must yield an independent, empty buffer; only the
    // original sink's buffered task should reach the database on flush.
    let test_db = setup_test_db().await;
    let db = test_db.db;
    let config = Config::new("TestTask");
    let mut sink1 = LibsqlSink::<(), ()>::new(db, &config);

    let ctx = SqlContext::new().with_max_attempts(5);
    let mut task = Task::new(CompactType::from(vec![1, 2, 3]));
    task.parts.ctx = ctx;

    let mut pinned_sink = Pin::new(&mut sink1);
    let cx = &mut Context::from_waker(futures::task::noop_waker_ref());
    // Previously the readiness result was discarded; assert it instead.
    let ready = pinned_sink.as_mut().poll_ready(cx);
    assert!(matches!(ready, Poll::Ready(Ok(()))));
    pinned_sink.as_mut().start_send(task).unwrap();

    let debug_str1 = format!("{:?}", sink1);
    assert!(
        debug_str1.contains("buffer_len: 1"),
        "Original sink should have 1 task in buffer: {}",
        debug_str1
    );

    let sink2 = sink1.clone();
    let debug_str2 = format!("{:?}", sink2);
    assert!(
        debug_str2.contains("buffer_len: 0"),
        "Cloned sink should have empty buffer: {}",
        debug_str2
    );

    let mut pinned_sink = Pin::new(&mut sink1);
    // Previously the flush result was ignored (`let _result = ...`); if the
    // flush failed or stayed pending, the count check below could run before
    // any write happened. Assert completion, matching the other flush test.
    let result = pinned_sink.as_mut().poll_flush(cx);
    assert!(matches!(result, Poll::Ready(Ok(()))));

    let conn = db.connect().unwrap();
    let mut rows = conn
        .query("SELECT COUNT(*) FROM Jobs", libsql::params![])
        .await
        .unwrap();
    // COUNT(*) always returns a row; fail loudly rather than skipping the
    // assertion as the old `if let` did.
    let row = rows
        .next()
        .await
        .unwrap()
        .expect("COUNT(*) should return a row");
    let count: i64 = row.get(0).unwrap();
    assert_eq!(count, 1);
}
#[tokio::test]
async fn test_sink_debug() {
    // The Debug representation should identify both the sink type and the
    // configured job type.
    let test_db = setup_test_db().await;
    let config = Config::new("TestTask");
    let sink = LibsqlSink::<(), ()>::new(test_db.db, &config);

    let rendered = format!("{:?}", sink);
    assert!(rendered.contains("LibsqlSink"));
    assert!(rendered.contains("TestTask"));
}
#[tokio::test]
async fn test_push_tasks_rollback_on_error() {
    // If any task in a batch fails to insert (duplicate primary key here),
    // the whole batch must roll back, leaving pre-existing rows untouched.
    let test_db = setup_test_db().await;
    let db = test_db.db;
    let job_type = "TestTask";
    let config = Config::new(job_type);
    let conn = db.connect().unwrap();

    // Seed one row whose id will collide with a task in the batch below.
    let duplicate_ulid = ulid::Ulid::new();
    let duplicate_id = duplicate_ulid.to_string();
    conn.execute(
        "INSERT INTO Jobs (job, id, job_type, status, attempts, max_attempts, run_at, priority, metadata)
         VALUES (?1, ?2, ?3, 'Pending', 0, 3, strftime('%s', 'now'), 0, '{}')",
        libsql::params![b"existing_task", duplicate_id.clone(), job_type],
    )
    .await
    .unwrap();

    // Confirm the seeded row is present before attempting the batch. The old
    // `if let Some(row)` wrapper could silently skip this (and the later)
    // assertions, letting the test pass vacuously; use `.expect` instead.
    let mut rows = conn
        .query(
            "SELECT COUNT(*) FROM Jobs WHERE job_type = ?1",
            libsql::params![job_type],
        )
        .await
        .unwrap();
    let row = rows
        .next()
        .await
        .unwrap()
        .expect("COUNT(*) should return a row");
    let initial_count: i64 = row.get(0).unwrap();
    assert_eq!(initial_count, 1);

    // Build a batch of three tasks; the middle one reuses the seeded id and
    // must trigger a constraint violation.
    let mut tasks = Vec::new();
    let ctx1 = SqlContext::new().with_max_attempts(5);
    let mut task1 = Task::new(CompactType::from(vec![1u8]));
    task1.parts.ctx = ctx1;
    tasks.push(task1);
    let ctx2 = SqlContext::new().with_max_attempts(5);
    let mut task2 = Task::new(CompactType::from(vec![2u8]));
    task2.parts.ctx = ctx2;
    task2.parts.task_id = Some(apalis_core::task::task_id::TaskId::new(duplicate_ulid));
    tasks.push(task2);
    let ctx3 = SqlContext::new().with_max_attempts(5);
    let mut task3 = Task::new(CompactType::from(vec![3u8]));
    task3.parts.ctx = ctx3;
    tasks.push(task3);

    let result = push_tasks(db, &config, tasks).await;
    assert!(
        result.is_err(),
        "Expected push_tasks to fail due to duplicate ID"
    );

    // The batch must have rolled back entirely: only the seeded row remains.
    let mut rows = conn
        .query(
            "SELECT COUNT(*) FROM Jobs WHERE job_type = ?1",
            libsql::params![job_type],
        )
        .await
        .unwrap();
    let row = rows
        .next()
        .await
        .unwrap()
        .expect("COUNT(*) should return a row");
    let final_count: i64 = row.get(0).unwrap();
    assert_eq!(
        final_count, 1,
        "Expected only the original task to remain after rollback"
    );

    // The seeded row's payload must be untouched by the failed batch.
    let mut rows = conn
        .query(
            "SELECT job FROM Jobs WHERE job_type = ?1 AND id = ?2",
            libsql::params![job_type, duplicate_id.clone()],
        )
        .await
        .unwrap();
    let row = rows
        .next()
        .await
        .unwrap()
        .expect("seeded row should still exist");
    let job_data: Vec<u8> = row.get(0).unwrap();
    assert_eq!(
        job_data, b"existing_task",
        "Original task data should be preserved"
    );
}