use std::{
collections::{BTreeMap, HashMap},
net::SocketAddr,
ops::{Deref, RangeInclusive},
time::{Duration, Instant},
};
use axum::Extension;
use bytes;
use futures::{StreamExt, TryStreamExt, future, stream::FuturesUnordered};
use http_body_util::BodyExt;
use hyper::StatusCode;
use hyper_util::client::legacy::{Client, connect::HttpConnector};
use hyper_util::rt::TokioExecutor;
use rand::{SeedableRng, distr::Uniform, prelude::Distribution, rngs::StdRng, seq::IteratorRandom};
use rangemap::RangeInclusiveSet;
use serde::Deserialize;
use serde_json::json;
use tokio::{
sync::mpsc,
time::{MissedTickBehavior, sleep, timeout},
};
use tracing::{debug, info_span};
use uuid::Uuid;
use crate::{
agent::process_multiple_changes,
api::{
peer::parallel_sync,
public::{TimeoutParams, api_v1_db_schema, api_v1_transactions},
},
transport::Transport,
};
use klukai_tests::*;
use klukai_types::{
actor::ActorId,
agent::Agent,
api::{ColumnName, TableName},
api::{ExecResponse, ExecResult, Statement},
base::{CrsqlDbVersion, CrsqlDbVersionRange, CrsqlSeq},
broadcast::{ChangeSource, ChangeV1, Changeset},
change::Change,
change::row_to_change,
dbsr, dbsri, dbvri,
pubsub::pack_columns,
spawn::wait_for_all_pending_handles,
sync::generate_sync,
tripwire::Tripwire,
};
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn insert_rows_and_gossip() -> eyre::Result<()> {
_ = tracing_subscriber::fmt::try_init();
let (tripwire, tripwire_worker, tripwire_tx) = Tripwire::new_simple();
let ta1 = launch_test_agent(|conf| conf.build(), tripwire.clone()).await?;
let ta2 = launch_test_agent(
|conf| {
conf.bootstrap(vec![ta1.agent.gossip_addr().to_string()])
.build()
},
tripwire.clone(),
)
.await?;
let client: Client<HttpConnector, http_body_util::Full<bytes::Bytes>> =
Client::builder(TokioExecutor::new())
.pool_max_idle_per_host(5)
.pool_idle_timeout(Duration::from_secs(300))
.build(HttpConnector::new());
let req_body: Vec<Statement> = serde_json::from_value(json!([[
"INSERT INTO tests (id,text) VALUES (?,?)",
[1, "hello world 1"]
],]))?;
let res = timeout(
Duration::from_secs(5),
client.request(
hyper::Request::builder()
.method(hyper::Method::POST)
.uri(format!("http://{}/v1/transactions", ta1.agent.api_addr()))
.header(hyper::header::CONTENT_TYPE, "application/json")
.body(serde_json::to_vec(&req_body)?.into())?,
),
)
.await??;
let body: ExecResponse = serde_json::from_slice(&res.into_body().collect().await?.to_bytes())?;
let db_version: CrsqlDbVersion =
ta1.agent
.pool()
.read()
.await?
.query_row("SELECT crsql_db_version();", (), |row| row.get(0))?;
assert_eq!(db_version, CrsqlDbVersion(1));
println!("body: {body:?}");
let svc: TestRecord = ta1.agent.pool().read().await?.query_row(
"SELECT id, text FROM tests WHERE id = 1;",
[],
|row| {
Ok(TestRecord {
id: row.get(0)?,
text: row.get(1)?,
})
},
)?;
assert_eq!(svc.id, 1);
assert_eq!(svc.text, "hello world 1");
sleep(Duration::from_secs(1)).await;
let svc: TestRecord = ta2.agent.pool().read().await?.query_row(
"SELECT id, text FROM tests WHERE id = 1;",
[],
|row| {
Ok(TestRecord {
id: row.get(0)?,
text: row.get(1)?,
})
},
)?;
assert_eq!(svc.id, 1);
assert_eq!(svc.text, "hello world 1");
let req_body: Vec<Statement> = serde_json::from_value(json!([[
"INSERT INTO tests (id,text) VALUES (?,?)",
[2, "hello world 2"]
]]))?;
let res = client
.request(
hyper::Request::builder()
.method(hyper::Method::POST)
.uri(format!("http://{}/v1/transactions", ta1.agent.api_addr()))
.header(hyper::header::CONTENT_TYPE, "application/json")
.body(serde_json::to_vec(&req_body)?.into())?,
)
.await?;
let body: ExecResponse = serde_json::from_slice(&res.into_body().collect().await?.to_bytes())?;
println!("body: {body:?}");
println!("checking crsql_changes");
#[allow(clippy::type_complexity)]
let bk: Vec<(ActorId, CrsqlDbVersion, Option<CrsqlSeq>)> = ta1
.agent
.pool()
.read()
.await?
.prepare("SELECT site_id, db_version, max(seq) FROM crsql_changes group by db_version")?
.query_map((), |row| {
Ok((
row.get::<_, ActorId>(0)?,
row.get::<_, CrsqlDbVersion>(1)?,
row.get::<_, Option<CrsqlSeq>>(2)?,
))
})?
.collect::<rusqlite::Result<_>>()?;
assert_eq!(
bk,
vec![
(ta1.agent.actor_id(), CrsqlDbVersion(1), Some(CrsqlSeq(0))),
(ta1.agent.actor_id(), CrsqlDbVersion(2), Some(CrsqlSeq(0)))
]
);
let svc: TestRecord = ta1.agent.pool().read().await?.query_row(
"SELECT id, text FROM tests WHERE id = 2;",
[],
|row| {
Ok(TestRecord {
id: row.get(0)?,
text: row.get(1)?,
})
},
)?;
assert_eq!(svc.id, 2);
assert_eq!(svc.text, "hello world 2");
sleep(Duration::from_secs(1)).await;
let svc: TestRecord = ta2.agent.pool().read().await?.query_row(
"SELECT id, text FROM tests WHERE id = 2;",
[],
|row| {
Ok(TestRecord {
id: row.get(0)?,
text: row.get(1)?,
})
},
)?;
assert_eq!(svc.id, 2);
assert_eq!(svc.text, "hello world 2");
let values: Vec<serde_json::Value> = (3..1000)
.map(|id| {
serde_json::json!([
"INSERT INTO tests (id,text) VALUES (?,?)",
[id, format!("hello world #{id}")],
])
})
.collect();
let req_body: Vec<Statement> = serde_json::from_value(json!(values))?;
timeout(
Duration::from_secs(5),
client.request(
hyper::Request::builder()
.method(hyper::Method::POST)
.uri(format!("http://{}/v1/transactions", ta1.agent.api_addr()))
.header(hyper::header::CONTENT_TYPE, "application/json")
.body(serde_json::to_vec(&req_body)?.into())?,
),
)
.await??;
let db_version: CrsqlDbVersion =
ta1.agent
.pool()
.read()
.await?
.query_row("SELECT crsql_db_version();", (), |row| row.get(0))?;
assert_eq!(db_version, CrsqlDbVersion(3));
let expected_count: i64 =
ta1.agent
.pool()
.read()
.await?
.query_row("SELECT COUNT(*) FROM tests", (), |row| row.get(0))?;
sleep(Duration::from_secs(5)).await;
let got_count: i64 =
ta2.agent
.pool()
.read()
.await?
.query_row("SELECT COUNT(*) FROM tests", (), |row| row.get(0))?;
assert_eq!(expected_count, got_count);
tripwire_tx.send(()).await.ok();
tripwire_worker.await;
wait_for_all_pending_handles().await;
Ok(())
}
/// Smallest run of `configurable_stress_test`: 2 nodes, each bootstrapping
/// from at most 1 earlier node, and 4 inputs.
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn chill_test() -> eyre::Result<()> {
    let (num_nodes, connectivity, input_count) = (2, 1, 4);
    configurable_stress_test(num_nodes, connectivity, input_count).await
}
/// Medium run of `configurable_stress_test`: 30 nodes, each bootstrapping
/// from up to 10 earlier nodes, and 200 inputs.
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn stress_test() -> eyre::Result<()> {
    let (num_nodes, connectivity, input_count) = (30, 10, 200);
    configurable_stress_test(num_nodes, connectivity, input_count).await
}
/// Largest run of `configurable_stress_test`: 45 nodes, each bootstrapping
/// from up to 15 earlier nodes, and 1500 inputs. Ignored by default; run it
/// explicitly with `--ignored`.
#[ignore]
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn stresser_test() -> eyre::Result<()> {
    let (num_nodes, connectivity, input_count) = (45, 15, 1500);
    configurable_stress_test(num_nodes, connectivity, input_count).await
}
#[allow(unused)]
pub async fn configurable_stress_test(
num_nodes: usize,
connectivity: usize,
input_count: usize,
) -> eyre::Result<()> {
_ = tracing_subscriber::fmt::try_init();
let (tripwire, tripwire_worker, tripwire_tx) = Tripwire::new_simple();
let agents = futures::stream::iter(
(0..num_nodes).map(|n| "127.0.0.1:0".parse().map(move |addr| (n, addr))),
)
.try_chunks(50)
.try_fold(vec![], {
let tripwire = tripwire.clone();
move |mut agents: Vec<TestAgent>, to_launch| {
let tripwire = tripwire.clone();
async move {
for (n, gossip_addr) in to_launch {
println!("LAUNCHING AGENT #{n}");
let mut rng = StdRng::from_os_rng();
let bootstrap = agents
.iter()
.map(|ta| ta.agent.gossip_addr())
.choose_multiple(&mut rng, connectivity);
agents.push(
launch_test_agent(
|conf| {
conf.gossip_addr(gossip_addr)
.bootstrap(
bootstrap
.iter()
.map(SocketAddr::to_string)
.collect::<Vec<String>>(),
)
.build()
},
tripwire.clone(),
)
.await
.unwrap(),
);
}
tokio::time::sleep(Duration::from_secs(1)).await;
Ok(agents)
}
}
})
.await?;
let client: Client<HttpConnector, http_body_util::Full<bytes::Bytes>> =
Client::builder(TokioExecutor::new()).build(HttpConnector::new());
let addrs: Vec<(ActorId, SocketAddr)> = agents
.iter()
.map(|ta| (ta.agent.actor_id(), ta.agent.api_addr()))
.collect();
let iter = (0..input_count).flat_map(|n| {
serde_json::from_value::<Vec<Statement>>(json!([
[
"INSERT INTO tests (id,text) VALUES (?,?)",
[n, format!("hello world {n}")]
],
[
"INSERT INTO tests2 (id,text) VALUES (?,?)",
[n, format!("hello world {n}")]
],
[
"INSERT INTO tests (id,text) VALUES (?,?)",
[n + 10000, format!("hello world {n}")]
],
[
"INSERT INTO tests2 (id,text) VALUES (?,?)",
[n + 10000, format!("hello world {n}")]
]
]))
.unwrap()
});
let actor_versions = {
let client: Client<HttpConnector, http_body_util::Full<bytes::Bytes>> =
Client::builder(TokioExecutor::new()).build(HttpConnector::new());
tokio_stream::StreamExt::map(futures::stream::iter(iter).chunks(20), {
let addrs = addrs.clone();
let client = client.clone();
move |statements| {
let addrs = addrs.clone();
let client = client.clone();
Ok(async move {
let mut rng = StdRng::from_os_rng();
let (actor_id, chosen) = addrs.iter().choose(&mut rng).unwrap();
let res = client
.request(
hyper::Request::builder()
.method(hyper::Method::POST)
.uri(format!("http://{chosen}/v1/transactions"))
.header(hyper::header::CONTENT_TYPE, "application/json")
.body(serde_json::to_vec(&statements)?.into())?,
)
.await?;
if res.status() != StatusCode::OK {
eyre::bail!("unexpected status code: {}", res.status());
}
let body: ExecResponse =
serde_json::from_slice(&res.into_body().collect().await?.to_bytes())?;
for (i, statement) in statements.iter().enumerate() {
if !matches!(
body.results[i],
ExecResult::Execute {
rows_affected: 1,
..
}
) {
eyre::bail!("unexpected exec result for statement {i}: {statement:?}");
}
}
Ok::<_, eyre::Report>((*actor_id, 1))
})
}
})
.try_buffer_unordered(10)
.try_fold(BTreeMap::new(), |mut acc, item| {
{
*acc.entry(item.0).or_insert(0) += item.1
}
future::ready(Ok(acc))
})
.await?
};
let changes_count: i64 = 4 * input_count as i64;
println!("expecting {changes_count} ops");
let start = Instant::now();
let mut interval = tokio::time::interval(Duration::from_secs(1));
interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
loop {
debug!("looping");
for ta in agents.iter() {
let registry = ta.bookie.registry();
let r = registry.map.read();
for v in r.values() {
println!(
"{}: GOT A LOCK: {} has been locked for {:?}",
ta.agent.actor_id(),
v.label,
v.started_at.elapsed()
);
}
}
tokio::time::sleep(Duration::from_secs(1)).await;
println!("checking status after {}s", start.elapsed().as_secs_f32());
let mut v = vec![];
for ta in agents.iter() {
let span = info_span!("consistency", actor_id = %ta.agent.actor_id().0);
let _entered = span.enter();
let conn = ta.agent.pool().read().await?;
let counts: HashMap<ActorId, i64> = conn
.prepare_cached("SELECT site_id, count(*) FROM crsql_changes GROUP BY site_id;")?
.query_map([], |row| Ok((row.get(0)?, row.get(1)?)))?
.collect::<rusqlite::Result<_>>()?;
debug!("versions count: {counts:?}");
let actual_count: i64 =
conn.query_row("SELECT count(*) FROM crsql_changes;", (), |row| row.get(0))?;
debug!("actual count: {actual_count}");
debug!(
"last version: {:?}",
ta.bookie
.write::<&str, _>("test", None)
.await
.ensure(ta.agent.actor_id())
.read::<&str, _>("test", None)
.await
.last()
);
let sync = generate_sync(&ta.bookie, ta.agent.actor_id()).await;
let needed = sync.need_len();
debug!("generated sync: {sync:?}");
v.push((counts.values().sum::<i64>(), needed));
}
if v.len() != agents.len() {
println!("got {} actors, expecting {}", v.len(), agents.len());
}
if v.len() == agents.len()
&& v.iter()
.all(|(n, needed)| *n == changes_count && *needed == 0)
{
break;
}
println!("we're not done yet...");
if start.elapsed() > Duration::from_secs(30) {
for ta in agents.iter() {
let conn = ta.agent.pool().read().await?;
let mut per_actor: BTreeMap<ActorId, RangeInclusiveSet<CrsqlDbVersion>> =
BTreeMap::new();
let mut prepped =
conn.prepare("SELECT DISTINCT site_id, db_version FROM crsql_changes;")?;
let mut rows = prepped.query(())?;
while let Ok(Some(row)) = rows.next() {
per_actor
.entry(row.get(0)?)
.or_default()
.insert(row.get(1)?..=row.get(1)?);
}
let actual_count: i64 =
conn.query_row("SELECT count(*) FROM crsql_changes;", (), |row| row.get(0))?;
debug!("actual count: {actual_count}");
if actual_count != changes_count {
println!(
"{}: still missing {} rows in crsql_changes",
ta.agent.actor_id(),
changes_count - actual_count
);
}
for (actor_id, versions) in per_actor {
if let Some(versions_len) = actor_versions.get(&actor_id) {
let full_range = CrsqlDbVersion(1)..=CrsqlDbVersion(*versions_len as u64);
let gaps = versions.gaps(&full_range);
for gap in gaps {
println!("{} db gap! {actor_id} => {gap:?}", ta.agent.actor_id());
}
}
}
let sync = generate_sync(&ta.bookie, ta.agent.actor_id()).await;
for (actor, versions) in sync.need {
println!(
"{}: in-memory gap: {actor:?} from {versions:?}",
ta.agent.actor_id()
);
}
let recorded_gaps = conn
.prepare("SELECT actor_id, start, end FROM __corro_bookkeeping_gaps")?
.query_map([], |row| Ok((row.get(0)?, row.get(1)?, row.get(2)?)))?
.collect::<Result<Vec<(ActorId, CrsqlDbVersion, CrsqlDbVersion)>, _>>()?;
for (actor_id, start, end) in recorded_gaps {
println!(
"{} recorded gap: {actor_id} => {start}..={end}",
ta.agent.actor_id()
);
}
}
panic!(
"failed to disseminate all updates to all nodes in {}s",
start.elapsed().as_secs_f32()
);
}
}
println!("fully disseminated in {}s", start.elapsed().as_secs_f32());
println!("checking gaps in db...");
for ta in agents {
let conn = ta.agent.pool().read().await?;
let gaps_count: u64 =
conn.query_row("SELECT count(*) FROM __corro_bookkeeping_gaps", [], |row| {
row.get(0)
})?;
assert_eq!(
gaps_count,
0,
"expected {} to have 0 gaps in DB",
ta.agent.actor_id()
);
}
println!("waiting for things to shut down");
tripwire_tx.send(()).await.ok();
tripwire_worker.await;
wait_for_all_pending_handles().await;
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn large_tx_sync() -> eyre::Result<()> {
_ = tracing_subscriber::fmt::try_init();
let (tripwire, tripwire_worker, tripwire_tx) = Tripwire::new_simple();
let ta1 = launch_test_agent(|conf| conf.build(), tripwire.clone()).await?;
let client: Client<HttpConnector, http_body_util::Full<bytes::Bytes>> =
Client::builder(TokioExecutor::new())
.pool_max_idle_per_host(5)
.pool_idle_timeout(Duration::from_secs(300))
.build(HttpConnector::new());
let counts = [
10000, 1000, 900, 800, 700, 600, 500, 400, 300, 200, 100, 1000, 900, 800, 700, 600, 500,
400, 300, 200, 100, 1000, 900, 800, 700, 600, 500, 400, 300, 200, 100, 1000, 900, 800, 700,
600, 500, 400, 300, 200, 100, 1000, 900, 800, 700, 600, 500, 400, 300, 200, 100, 1000, 900,
800, 700, 600, 500, 400, 300, 200, 100, 1000, 900, 800, 700, 600, 500, 400, 300, 200, 100,
1000, 900, 800, 700, 600, 500, 400, 300, 200, 100, 1000, 900, 800, 700, 600, 500, 400, 300,
200, 100, 1000, 900, 800, 700, 600, 500, 400, 300, 200, 100,
];
for n in counts.iter() {
let req_body: Vec<Statement> = serde_json::from_value(json!([format!(
"INSERT INTO testsbool (id) WITH RECURSIVE cte(id) AS ( SELECT random() UNION ALL SELECT random() FROM cte LIMIT {n} ) SELECT id FROM cte;"
)]))?;
let res = timeout(
Duration::from_secs(5),
client.request(
hyper::Request::builder()
.method(hyper::Method::POST)
.uri(format!("http://{}/v1/transactions", ta1.agent.api_addr()))
.header(hyper::header::CONTENT_TYPE, "application/json")
.body(serde_json::to_vec(&req_body)?.into())?,
),
)
.await??;
let body: ExecResponse =
serde_json::from_slice(&res.into_body().collect().await?.to_bytes())?;
println!("body: {body:?}");
}
let expected_count = counts.into_iter().sum::<usize>();
let db_version: CrsqlDbVersion =
ta1.agent
.pool()
.read()
.await?
.query_row("SELECT crsql_db_version();", (), |row| row.get(0))?;
assert_eq!(db_version, CrsqlDbVersion(counts.len() as u64));
println!("expected count: {expected_count}");
let ta2 = launch_test_agent(|conf| conf.build(), tripwire.clone()).await?;
let ta3 = launch_test_agent(|conf| conf.build(), tripwire.clone()).await?;
let ta4 = launch_test_agent(|conf| conf.build(), tripwire.clone()).await?;
let (rtt_tx, _rtt_rx) = mpsc::channel(1024);
let ta2_transport = Transport::new(&ta2.agent.config().gossip, rtt_tx.clone()).await?;
let ta3_transport = Transport::new(&ta3.agent.config().gossip, rtt_tx.clone()).await?;
let ta4_transport = Transport::new(&ta4.agent.config().gossip, rtt_tx.clone()).await?;
println!("starting sync!?");
for _ in 0..7 {
let res = parallel_sync(
&ta2.agent,
&ta2_transport,
vec![(ta1.agent.actor_id(), ta1.agent.gossip_addr())],
generate_sync(&ta2.bookie, ta2.agent.actor_id()).await,
)
.await?;
println!("ta2 synced {res}");
let res = parallel_sync(
&ta3.agent,
&ta3_transport,
vec![
(ta1.agent.actor_id(), ta1.agent.gossip_addr()),
(ta2.agent.actor_id(), ta2.agent.gossip_addr()),
],
generate_sync(&ta3.bookie, ta3.agent.actor_id()).await,
)
.await?;
println!("ta3 synced {res}");
let res = parallel_sync(
&ta4.agent,
&ta4_transport,
vec![
(ta3.agent.actor_id(), ta3.agent.gossip_addr()),
(ta2.agent.actor_id(), ta2.agent.gossip_addr()),
],
generate_sync(&ta4.bookie, ta4.agent.actor_id()).await,
)
.await?;
println!("ta4 synced {res}");
tokio::time::sleep(Duration::from_secs(2)).await;
}
tokio::time::sleep(Duration::from_secs(10)).await;
let mut ta_counts = vec![];
for (name, ta) in [("ta2", &ta2), ("ta3", &ta3), ("ta4", &ta4)] {
let agent = &ta.agent;
let conn = agent.pool().read().await?;
let count: u64 = conn
.prepare_cached("SELECT COUNT(*) FROM testsbool;")?
.query_row((), |row| row.get(0))?;
println!(
"{name}: {:#?}",
generate_sync(&ta.bookie, agent.actor_id()).await
);
println!(
"{name}: bookie: {:?}",
ta.bookie
.read::<&str, _>("test", None)
.await
.get(&ta1.agent.actor_id())
.unwrap()
.read::<&str, _>("test", None)
.await
.deref()
);
if count as usize != expected_count {
let buf_count: Vec<(CrsqlDbVersion, u64)> = conn
.prepare(
"select db_version,count(*) from __corro_buffered_changes group by db_version",
)?
.query_map([], |row| Ok((row.get(0)?, row.get(1)?)))?
.collect::<rusqlite::Result<Vec<_>>>()?;
println!(
"{name}: BUFFERED COUNT: {buf_count:?} (actor_id: {})",
agent.actor_id()
);
let ranges = conn
.prepare("select start_seq, end_seq from __corro_seq_bookkeeping")?
.query_map([], |row| Ok(row.get::<_, u64>(0)?..=row.get::<_, u64>(1)?))?
.collect::<rusqlite::Result<Vec<_>>>()?;
println!("{name}: ranges: {ranges:?}");
}
ta_counts.push((name, agent.actor_id(), count as usize));
}
for (name, actor_id, count) in ta_counts {
assert_eq!(
count, expected_count,
"{name}: actor {actor_id} did not reach {expected_count} rows",
);
}
println!("now waiting for all futures to end");
tripwire_tx.send(()).await.ok();
tripwire_worker.await;
wait_for_all_pending_handles().await;
Ok(())
}
/// A row of the `tests` table, built by hand from query results in
/// `insert_rows_and_gossip`.
#[derive(Debug, Deserialize)]
struct TestRecord {
    /// Value of the `tests.id` column.
    id: i64,
    /// Value of the `tests.text` column.
    text: String,
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn test_clear_empty_versions() -> eyre::Result<()> {
_ = tracing_subscriber::fmt::try_init();
let (tripwire, tripwire_worker, tripwire_tx) = Tripwire::new_simple();
let ta1 = launch_test_agent(|conf| conf.build(), tripwire.clone()).await?;
let ta2 = launch_test_agent(|conf| conf.build(), tripwire.clone()).await?;
let tx_timeout = Duration::from_secs(60);
let (rtt_tx, _rtt_rx) = mpsc::channel(1024);
let ta2_transport = Transport::new(&ta2.agent.config().gossip, rtt_tx.clone()).await?;
let (status_code, _body) = api_v1_db_schema(
Extension(ta1.agent.clone()),
axum::Json(vec![klukai_tests::TEST_SCHEMA.into()]),
)
.await;
assert_eq!(status_code, StatusCode::OK);
let (status_code, _body) = api_v1_db_schema(
Extension(ta2.agent.clone()),
axum::Json(vec![klukai_tests::TEST_SCHEMA.into()]),
)
.await;
assert_eq!(status_code, StatusCode::OK);
insert_rows(ta1.agent.clone(), 1, 50).await;
let rows = get_rows(ta1.agent.clone(), vec![(dbvri!(1, 50), None)]).await?;
process_multiple_changes(ta2.agent.clone(), ta2.bookie.clone(), rows, tx_timeout).await?;
insert_rows(ta1.agent.clone(), 1, 5).await;
insert_rows(ta1.agent.clone(), 10, 10).await;
insert_rows(ta1.agent.clone(), 23, 25).await;
insert_rows(ta1.agent.clone(), 30, 31).await;
let rows = get_rows(
ta1.agent.clone(),
vec![
(dbvri!(51, 55), None),
(dbvri!(56, 56), None),
(dbvri!(57, 59), None),
(dbvri!(60, 60), None),
],
)
.await?;
process_multiple_changes(ta2.agent.clone(), ta2.bookie.clone(), rows, tx_timeout).await?;
check_bookie_versions(
ta2.clone(),
ta1.agent.actor_id(),
vec![dbvri!(1, 50)],
vec![],
vec![],
vec![],
)
.await?;
let res = parallel_sync(
&ta2.agent,
&ta2_transport,
vec![(ta1.agent.actor_id(), ta1.agent.gossip_addr())],
generate_sync(&ta2.bookie, ta2.agent.actor_id()).await,
)
.await?;
println!("ta2 synced {res}");
sleep(Duration::from_secs(2)).await;
check_bookie_versions(
ta2.clone(),
ta1.agent.actor_id(),
vec![],
vec![],
vec![],
vec![dbvri!(1, 5), dbvri!(10, 10), dbvri!(23, 25), dbvri!(30, 31)],
)
.await?;
tripwire_tx.send(()).await.ok();
tripwire_worker.await;
wait_for_all_pending_handles().await;
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn process_failed_changes() -> eyre::Result<()> {
_ = tracing_subscriber::fmt::try_init();
let (tripwire, tripwire_worker, tripwire_tx) = Tripwire::new_simple();
let ta1 = launch_test_agent(|conf| conf.build(), tripwire.clone()).await?;
let uuid = Uuid::parse_str("00000000-0000-0000-a716-446655440000")?;
let actor_id = ActorId(uuid);
let (status_code, _body) = api_v1_db_schema(
Extension(ta1.agent.clone()),
axum::Json(vec![klukai_tests::TEST_SCHEMA.into()]),
)
.await;
assert_eq!(status_code, StatusCode::OK);
let ta2 = launch_test_agent(|conf| conf.build(), tripwire.clone()).await?;
let (status_code, _body) = api_v1_db_schema(
Extension(ta2.agent.clone()),
axum::Json(vec![klukai_tests::TEST_SCHEMA.into()]),
)
.await;
assert_eq!(status_code, StatusCode::OK);
for i in 1..=5_i64 {
let (status_code, _) = api_v1_transactions(
Extension(ta2.agent.clone()),
axum::extract::Query(TimeoutParams { timeout: None }),
axum::Json(vec![Statement::WithParams(
"INSERT OR REPLACE INTO tests (id,text) VALUES (?,?)".into(),
vec![i.into(), "service-text".into()],
)]),
)
.await;
assert_eq!(status_code, StatusCode::OK);
}
let mut good_changes = get_rows(ta2.agent.clone(), vec![(dbvri!(1, 5), None)]).await?;
let change6 = Change {
table: TableName("tests".into()),
pk: pack_columns(&[6i64.into()])?,
cid: ColumnName("text".into()),
val: "six".into(),
col_version: 1,
db_version: CrsqlDbVersion(6),
seq: CrsqlSeq(0),
site_id: actor_id.to_bytes(),
cl: 1,
};
let bad_change = Change {
table: TableName("tests".into()),
pk: pack_columns(&[6i64.into()])?,
cid: ColumnName("nonexistent".into()),
val: "six".into(),
col_version: 1,
db_version: CrsqlDbVersion(6),
seq: CrsqlSeq(1),
site_id: actor_id.to_bytes(),
cl: 1,
};
let mut rows = vec![(
ChangeV1 {
actor_id,
changeset: Changeset::Full {
version: CrsqlDbVersion(1),
changes: vec![change6.clone(), bad_change],
seqs: dbsr!(0, 1),
last_seq: CrsqlSeq(1),
ts: Default::default(),
},
},
ChangeSource::Sync,
Instant::now(),
)];
rows.append(&mut good_changes);
let res = process_multiple_changes(
ta1.agent.clone(),
ta1.bookie.clone(),
rows,
Duration::from_secs(60),
)
.await;
assert!(res.is_ok());
let conn = ta1.agent.pool().read().await?;
for i in 1..=5_i64 {
let pk = pack_columns(&[i.into()])?;
let crsql_dbv = conn
.prepare_cached(
r#"SELECT db_version from crsql_changes where "table" = "tests" and pk = ? and site_id = ?"#,
)?
.query_row((pk, ta2.agent.actor_id()), |row| row.get::<_, CrsqlDbVersion>(0))?;
assert_eq!(crsql_dbv, CrsqlDbVersion(i as u64));
let conn = ta1.agent.pool().read().await?;
conn.prepare_cached("SELECT text from tests where id = ?")?
.query_row([i], |row| row.get::<_, String>(0))?;
}
let res = conn
.prepare_cached("SELECT text from tests where id = 6")?
.query_row([], |row| row.get::<_, String>(0));
assert!(res.is_err());
assert_eq!(res, Err(rusqlite::Error::QueryReturnedNoRows));
tripwire_tx.send(()).await.ok();
tripwire_worker.await;
wait_for_all_pending_handles().await;
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn test_process_multiple_changes() -> eyre::Result<()> {
_ = tracing_subscriber::fmt::try_init();
let (tripwire, tripwire_worker, tripwire_tx) = Tripwire::new_simple();
let ta1 = launch_test_agent(|conf| conf.build(), tripwire.clone()).await?;
let ta2 = launch_test_agent(|conf| conf.build(), tripwire.clone()).await?;
let tx_timeout = Duration::from_secs(60);
let (status_code, _body) = api_v1_db_schema(
Extension(ta1.agent.clone()),
axum::Json(vec![klukai_tests::TEST_SCHEMA.into()]),
)
.await;
assert_eq!(status_code, StatusCode::OK);
let (status_code, _body) = api_v1_db_schema(
Extension(ta2.agent.clone()),
axum::Json(vec![klukai_tests::TEST_SCHEMA.into()]),
)
.await;
assert_eq!(status_code, StatusCode::OK);
insert_rows(ta1.agent.clone(), 1, 50).await;
let rows = get_rows(ta1.agent.clone(), vec![(dbvri!(1, 5), None)]).await?;
process_multiple_changes(ta2.agent.clone(), ta2.bookie.clone(), rows, tx_timeout).await?;
check_bookie_versions(
ta2.clone(),
ta1.agent.actor_id(),
vec![dbvri!(1, 5)],
vec![],
vec![],
vec![],
)
.await?;
let rows = get_rows(ta1.agent.clone(), vec![(dbvri!(9, 10), None)]).await?;
process_multiple_changes(ta2.agent.clone(), ta2.bookie.clone(), rows, tx_timeout).await?;
check_bookie_versions(
ta2.clone(),
ta1.agent.actor_id(),
vec![],
vec![dbvri!(6, 8)],
vec![],
vec![],
)
.await?;
let rows = get_rows(
ta1.agent.clone(),
vec![(dbvri!(20, 20), None), (dbvri!(15, 16), Some(dbsri!(0, 0)))],
)
.await?;
process_multiple_changes(ta2.agent.clone(), ta2.bookie.clone(), rows, tx_timeout).await?;
check_bookie_versions(
ta2.clone(),
ta1.agent.actor_id(),
vec![],
vec![dbvri!(11, 14), dbvri!(17, 19)],
vec![(dbvri!(15, 16), dbsri!(0, 0))],
vec![],
)
.await?;
insert_rows(ta1.agent.clone(), 21, 25).await;
let rows = get_rows(
ta1.agent.clone(),
vec![(dbvri!(21, 21), None), (dbvri!(25, 25), None)],
)
.await?;
process_multiple_changes(ta2.agent.clone(), ta2.bookie.clone(), rows, tx_timeout).await?;
check_bookie_versions(
ta2.clone(),
ta1.agent.actor_id(),
vec![],
vec![],
vec![],
vec![dbvri!(21, 21), dbvri!(25, 25)],
)
.await?;
let rows = get_rows(
ta1.agent.clone(),
vec![
(dbvri!(14, 18), None),
(dbvri!(15, 16), Some(dbsri!(1, 3))),
(dbvri!(23, 24), None),
],
)
.await?;
process_multiple_changes(ta2.agent.clone(), ta2.bookie.clone(), rows, tx_timeout).await?;
check_bookie_versions(
ta2.clone(),
ta1.agent.actor_id(),
vec![dbvri!(14, 18), dbvri!(15, 16)],
vec![dbvri!(11, 13), dbvri!(19, 19), dbvri!(22, 22)],
vec![],
vec![dbvri!(23, 25)],
)
.await?;
let rows = get_rows(
ta1.agent.clone(),
vec![
(dbvri!(6, 8), None),
(dbvri!(11, 19), None),
(dbvri!(22, 22), None),
],
)
.await?;
process_multiple_changes(ta2.agent.clone(), ta2.bookie.clone(), rows, tx_timeout).await?;
check_bookie_versions(
ta2.clone(),
ta1.agent.actor_id(),
vec![dbvri!(1, 20)],
vec![],
vec![],
vec![dbvri!(21, 25)],
)
.await?;
tripwire_tx.send(()).await.ok();
tripwire_worker.await;
wait_for_all_pending_handles().await;
Ok(())
}
/// Asserts that `ta`'s bookkeeping for `actor_id` matches the expected state,
/// both in the database and in the in-memory booked versions.
///
/// * `complete` — versions that must not fall inside any
///   `__corro_bookkeeping_gaps` range. In memory they must be fully known for
///   seqs `0..=3`.
/// * `gap` — ranges whose versions must all be reported by `needed()`. Each
///   range must also be recorded verbatim (same start and end) in
///   `__corro_bookkeeping_gaps`.
/// * `partials` — versions that must:
///   - have exactly one `__corro_seq_bookkeeping` row matching the given seq
///     range (the lookup filters on `db_version` only);
///   - not fall inside a recorded gap;
///   - be tracked as partial in memory.
/// * `cleared` — versions that must have no `crsql_changes` rows for
///   `actor_id`.
///
/// # Errors
/// Returns an error if acquiring the connection or any query fails.
/// Expectation mismatches panic.
async fn check_bookie_versions(
    ta: TestAgent,
    actor_id: ActorId,
    complete: Vec<RangeInclusive<CrsqlDbVersion>>,
    gap: Vec<RangeInclusive<CrsqlDbVersion>>,
    partials: Vec<(RangeInclusive<CrsqlDbVersion>, RangeInclusive<CrsqlSeq>)>,
    cleared: Vec<RangeInclusive<CrsqlDbVersion>>,
) -> eyre::Result<()> {
    let conn = ta.agent.pool().read().await?;
    let booked = ta
        .bookie
        .write::<&str, _>("test", None)
        .await
        .ensure(actor_id);
    let bookedv = booked.read::<&str, _>("test", None).await;

    for versions in complete {
        for version in CrsqlDbVersionRange::from(versions.clone()) {
            assert!(!conn.prepare_cached(
                "SELECT EXISTS (SELECT 1 FROM __corro_bookkeeping_gaps WHERE actor_id = ? and ? between start and end)")?
                .query_row((actor_id, version), |row| row.get::<usize, bool>(0))?);
        }
        // Seqs 0..=3 assume four changes per version, as `insert_rows` writes
        // four non-key columns per row. TODO: confirm for other callers.
        assert!(
            bookedv.contains_all(versions.clone(), Some(dbsr!(0, 3))),
            "{versions:?} should be fully booked for {actor_id}"
        );
    }

    for (versions, seq) in partials {
        // Partials hold only a subset of seqs, so the full-seq containment
        // check used for `complete` does not apply; `get_partial` is checked
        // instead.
        for version in CrsqlDbVersionRange::from(versions) {
            let bk: Vec<(ActorId, CrsqlDbVersion, CrsqlSeq, CrsqlSeq)> = conn
                .prepare(
                    "SELECT site_id, db_version, start_seq, end_seq FROM __corro_seq_bookkeeping where db_version = ?",
                )?
                .query_map([version], |row| {
                    Ok((row.get(0)?, row.get(1)?, row.get(2)?, row.get(3)?))
                })?
                .collect::<rusqlite::Result<Vec<_>>>()?;
            assert_eq!(bk, vec![(actor_id, version, *seq.start(), *seq.end())]);
            assert!(!conn.prepare_cached(
                "SELECT EXISTS (SELECT 1 FROM __corro_bookkeeping_gaps WHERE actor_id = ? and ? BETWEEN start and end)")?
                .query_row((actor_id, version), |row| row.get::<usize, bool>(0))?);
            assert!(
                bookedv.get_partial(&version).is_some(),
                "{version:?} should be a partial version"
            );
        }
    }

    // The needed set does not change while we hold the read guard, so fetch it once.
    let needed = bookedv.needed();
    for versions in gap {
        for version in CrsqlDbVersionRange::from(versions.clone()) {
            assert!(
                needed.contains(&version),
                "{version:?} should be in {needed:?}"
            );
        }
        assert!(conn.prepare_cached(
            "SELECT EXISTS (SELECT 1 FROM __corro_bookkeeping_gaps WHERE actor_id = ? and start = ? and end = ?)")?
            .query_row((actor_id, versions.start(), versions.end()), |row| row.get(0))?);
    }

    for versions in cleared {
        for version in CrsqlDbVersionRange::from(versions) {
            assert!(!conn.prepare_cached(
                "SELECT EXISTS (SELECT 1 FROM crsql_changes WHERE site_id = ? and db_version = ?)")?
                .query_row((actor_id, version), |row| row.get::<usize, bool>(0))?, "Version {version} not cleared in crsql_changes table");
        }
    }
    Ok(())
}
async fn get_rows(
agent: Agent,
v: Vec<(
RangeInclusive<CrsqlDbVersion>,
Option<RangeInclusive<CrsqlSeq>>,
)>,
) -> eyre::Result<Vec<(ChangeV1, ChangeSource, Instant)>> {
let mut result = vec![];
let conn = agent.pool().read().await?;
for versions in v {
for version in CrsqlDbVersionRange::from(versions.0) {
let count: u64 = conn.query_row(
"SELECT COUNT(*) FROM crsql_changes where db_version = ?",
[version],
|row| row.get(0),
)?;
let mut last = 4;
if count > 0 {
last = count - 1;
}
let mut query =
r#"SELECT "table", pk, cid, val, col_version, db_version, seq, site_id, cl
FROM crsql_changes where db_version = ?"#
.to_string();
let changes: Vec<Change>;
let seqs = if let Some(seq) = versions.1.clone() {
let seq_query = " and seq >= ? and seq <= ?";
query += seq_query;
let mut prepped = conn.prepare(&query)?;
changes = prepped
.query_map((version, seq.start(), seq.end()), row_to_change)?
.collect::<Result<Vec<_>, _>>()?;
seq
} else {
let mut prepped = conn.prepare(&query)?;
changes = prepped
.query_map([version], row_to_change)?
.collect::<Result<Vec<_>, _>>()?;
CrsqlSeq(0)..=CrsqlSeq(last)
};
result.push((
ChangeV1 {
actor_id: agent.actor_id(),
changeset: Changeset::Full {
version,
changes,
seqs: seqs.into(),
last_seq: CrsqlSeq(last),
ts: agent.clock().new_timestamp().into(),
},
},
ChangeSource::Broadcast,
Instant::now(),
))
}
}
Ok(result)
}
/// Upserts one `tests3` row for every id in `start..=end` (both inclusive) on
/// `agent`, one `api_v1_transactions` call per id, asserting `200 OK` each time.
///
/// Each row stores fixed text values plus `id + 20` and `id + 100` in the
/// numeric columns.
async fn insert_rows(agent: Agent, start: i64, end: i64) {
    for id in start..=end {
        let statement = Statement::WithParams(
            "INSERT OR REPLACE INTO tests3 (id,text,text2, num, num2) VALUES (?,?,?,?,?)".into(),
            vec![
                id.into(),
                "service-name".into(),
                "second text".into(),
                (id + 20).into(),
                (id + 100).into(),
            ],
        );
        let (status, _) = api_v1_transactions(
            Extension(agent.clone()),
            axum::extract::Query(TimeoutParams { timeout: None }),
            axum::Json(vec![statement]),
        )
        .await;
        assert_eq!(status, StatusCode::OK);
    }
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn many_small_changes() -> eyre::Result<()> {
_ = tracing_subscriber::fmt::try_init();
let (tripwire, tripwire_worker, tripwire_tx) = Tripwire::new_simple();
let agents = futures::StreamExt::fold(futures::stream::iter(0..10).chunks(50), vec![], {
let tripwire = tripwire.clone();
move |mut agents: Vec<TestAgent>, to_launch| {
let tripwire = tripwire.clone();
async move {
for n in to_launch {
println!("LAUNCHING AGENT #{n}");
let mut rng = StdRng::from_os_rng();
let bootstrap = agents
.iter()
.map(|ta| ta.agent.gossip_addr())
.choose_multiple(&mut rng, 10);
agents.push(
launch_test_agent(
|conf| {
conf.gossip_addr("127.0.0.1:0".parse().unwrap())
.bootstrap(
bootstrap
.iter()
.map(SocketAddr::to_string)
.collect::<Vec<String>>(),
)
.build()
},
tripwire.clone(),
)
.await
.unwrap(),
);
}
tokio::time::sleep(Duration::from_secs(1)).await;
agents
}
}
})
.await;
let mut start_id = 0;
let _: () = FuturesUnordered::from_iter(agents.iter().map(|ta| {
let ta = ta.clone();
start_id += 100000;
async move {
tokio::spawn(async move {
let client: Client<HttpConnector, http_body_util::Full<bytes::Bytes>> =
Client::builder(TokioExecutor::new()).build(HttpConnector::new());
let durs = {
let between = Uniform::try_from(100..=1000).unwrap();
let mut rng = rand::rng();
(0..100)
.map(|_| between.sample(&mut rng))
.collect::<Vec<_>>()
};
let api_addr = ta.agent.api_addr();
let actor_id = ta.agent.actor_id();
let _: () = FuturesUnordered::from_iter(durs.into_iter().map(|dur| {
let client = client.clone();
start_id += 1;
async move {
sleep(Duration::from_millis(dur)).await;
let req_body = serde_json::from_value::<Vec<Statement>>(json!([[
"INSERT INTO tests (id,text) VALUES (?,?)",
[start_id, format!("hello from {actor_id}")]
],]))?;
let res = client
.request(
hyper::Request::builder()
.method(hyper::Method::POST)
.uri(format!("http://{api_addr}/v1/transactions"))
.header(hyper::header::CONTENT_TYPE, "application/json")
.body(serde_json::to_vec(&req_body)?.into())?,
)
.await?;
if res.status() != StatusCode::OK {
eyre::bail!("bad status code: {}", res.status());
}
let body: ExecResponse =
serde_json::from_slice(&res.into_body().collect().await?.to_bytes())?;
match &body.results[0] {
ExecResult::Execute { .. } => {}
ExecResult::Error { error } => {
eyre::bail!("error: {error}");
}
}
Ok::<_, eyre::Report>(())
}
}))
.try_collect()
.await?;
Ok::<_, eyre::Report>(())
})
.await??;
Ok::<_, eyre::Report>(())
}
}))
.try_collect()
.await?;
sleep(Duration::from_secs(10)).await;
for ta in agents {
let conn = ta.agent.pool().read().await?;
let count: i64 = conn.query_row("SELECT count(*) FROM tests", (), |row| row.get(0))?;
println!("actor: {}, count: {count}", ta.agent.actor_id());
}
tripwire_tx.send(()).await.ok();
tripwire_worker.await;
wait_for_all_pending_handles().await;
Ok(())
}