pub mod sub;
use corro_api_types::{
ChangeId, ExecResponse, ExecResult, SqliteValue, Statement, QUERY_HASH_HEADER, QUERY_ID_HEADER,
};
use hickory_resolver::net::NetError as ResolveError;
use serde::de::DeserializeOwned;
use std::{
fmt::Write as _,
net::SocketAddr,
ops::Deref,
path::Path,
sync::Arc,
time::{self, Duration, Instant},
};
use sub::{QueryStream, SubscriptionStream, UpdatesStream};
use tokio::{
sync::{RwLock, RwLockReadGuard},
time::timeout,
};
use tracing::{debug, info};
use uuid::Uuid;
/// Upper bound on the connect phase when opening a connection to a Corrosion API server.
const HTTP2_CONNECT_TIMEOUT: Duration = Duration::from_millis(3_000);
/// Interval between HTTP/2 keep-alive pings sent to a Corrosion API server.
const HTTP2_KEEP_ALIVE_INTERVAL: Duration = Duration::from_millis(10_000);
/// Upper bound on a single DNS lookup of a configured Corrosion server hostname.
const DNS_RESOLVE_TIMEOUT: Duration = Duration::from_millis(3_000);
/// Hickory DNS resolver on the Tokio runtime, used by `AddrPicker` to look up the hostnames
/// in the configured Corrosion server list.
type Resolver = hickory_resolver::Resolver<hickory_resolver::net::runtime::TokioRuntimeProvider>;
#[derive(Clone)]
pub struct CorrosionApiClient {
api_addr: SocketAddr,
api_client: reqwest::Client,
}
impl CorrosionApiClient {
pub fn new(api_addr: SocketAddr) -> Result<Self, reqwest::Error> {
Ok(Self {
api_addr,
api_client: reqwest::ClientBuilder::new()
.http2_prior_knowledge()
.connect_timeout(HTTP2_CONNECT_TIMEOUT)
.http2_keep_alive_interval(Some(HTTP2_KEEP_ALIVE_INTERVAL))
.http2_keep_alive_timeout(HTTP2_KEEP_ALIVE_INTERVAL / 2)
.build()?,
})
}
pub async fn query_typed<T: DeserializeOwned + Unpin>(
&self,
statement: &Statement,
timeout: Option<u64>,
) -> Result<QueryStream<T>, Error> {
let mut uri = format!("http://{}/v1/queries", self.api_addr);
if let Some(t) = timeout {
write!(&mut uri, "?timeout={t}").unwrap();
}
let res = self
.api_client
.post(uri)
.header(http::header::CONTENT_TYPE, "application/json")
.header(http::header::ACCEPT, "application/json")
.body(serde_json::to_vec(statement)?)
.send()
.await?;
if !res.status().is_success() {
let status = res.status();
match res.bytes().await {
Ok(b) => match serde_json::from_slice(&b) {
Ok(res) => match res {
ExecResult::Error { error } => return Err(Error::ResponseError(error)),
res => return Err(Error::UnexpectedResult(res)),
},
Err(error) => {
debug!(
%error,
"could not deserialize response body, sending generic error..."
);
return Err(Error::UnexpectedStatusCode(status));
}
},
Err(error) => {
debug!(
%error,
"could not aggregate response body bytes, sending generic error..."
);
return Err(Error::UnexpectedStatusCode(status));
}
}
}
Ok(QueryStream::new(res.into()))
}
pub async fn query(
&self,
statement: &Statement,
timeout: Option<u64>,
) -> Result<QueryStream<Vec<SqliteValue>>, Error> {
self.query_typed(statement, timeout).await
}
pub async fn subscribe_typed<T: DeserializeOwned + Unpin>(
&self,
statement: &Statement,
skip_rows: bool,
from: Option<ChangeId>,
) -> Result<SubscriptionStream<T>, Error> {
let mut uri = format!(
"http://{}/v1/subscriptions?skip_rows={skip_rows}",
self.api_addr
);
if let Some(change_id) = from {
write!(&mut uri, "&from={change_id}").unwrap();
}
let res = self
.api_client
.post(uri)
.header(http::header::CONTENT_TYPE, "application/json")
.header(http::header::ACCEPT, "application/json")
.body(serde_json::to_vec(statement)?)
.send()
.await?;
if !res.status().is_success() {
return Err(Error::UnexpectedStatusCode(res.status()));
}
let id = res
.headers()
.get(QUERY_ID_HEADER)
.and_then(|v| v.to_str().ok().and_then(|v| v.parse().ok()))
.ok_or(Error::ExpectedQueryId)?;
let hash = res
.headers()
.get(QUERY_HASH_HEADER)
.and_then(|v| v.to_str().map(ToOwned::to_owned).ok());
Ok(SubscriptionStream::new(
id,
hash,
self.api_client.clone(),
self.api_addr,
res.into(),
from,
))
}
pub async fn subscribe(
&self,
statement: &Statement,
skip_rows: bool,
from: Option<ChangeId>,
) -> Result<SubscriptionStream<Vec<SqliteValue>>, Error> {
self.subscribe_typed(statement, skip_rows, from).await
}
pub async fn subscription_typed<T: DeserializeOwned + Unpin>(
&self,
id: Uuid,
skip_rows: bool,
from: Option<ChangeId>,
) -> Result<SubscriptionStream<T>, Error> {
let mut uri = format!(
"http://{}/v1/subscriptions/{id}?skip_rows={skip_rows}",
self.api_addr
);
if let Some(change_id) = from {
write!(&mut uri, "&from={change_id}").unwrap();
}
let res = self
.api_client
.get(uri)
.header(http::header::ACCEPT, "application/json")
.send()
.await?;
if !res.status().is_success() {
return Err(Error::UnexpectedStatusCode(res.status()));
}
let hash = res
.headers()
.get(QUERY_HASH_HEADER)
.and_then(|v| v.to_str().map(ToOwned::to_owned).ok());
Ok(SubscriptionStream::new(
id,
hash,
self.api_client.clone(),
self.api_addr,
res.into(),
from,
))
}
pub async fn subscription(
&self,
id: Uuid,
skip_rows: bool,
from: Option<ChangeId>,
) -> Result<SubscriptionStream<Vec<SqliteValue>>, Error> {
self.subscription_typed(id, skip_rows, from).await
}
pub async fn updates_typed<T: DeserializeOwned + Unpin>(
&self,
table: &str,
) -> Result<UpdatesStream<T>, Error> {
let res = self
.api_client
.post(format!("http://{}/v1/updates/{table}", self.api_addr))
.header(http::header::CONTENT_TYPE, "application/json")
.header(http::header::ACCEPT, "application/json")
.send()
.await?;
if !res.status().is_success() {
return Err(Error::UnexpectedStatusCode(res.status()));
}
let id = res
.headers()
.get(QUERY_ID_HEADER)
.and_then(|v| v.to_str().ok().and_then(|v| v.parse().ok()))
.ok_or(Error::ExpectedQueryId)?;
Ok(UpdatesStream::new(id, res.into()))
}
pub async fn updates(&self, table: &str) -> Result<UpdatesStream<Vec<SqliteValue>>, Error> {
self.updates_typed(table).await
}
pub async fn execute(
&self,
statements: &[Statement],
timeout: Option<u64>,
) -> Result<ExecResponse, Error> {
let uri = if let Some(timeout) = timeout {
format!("http://{}/v1/transactions?timeout={timeout}", self.api_addr)
} else {
format!("http://{}/v1/transactions", self.api_addr)
};
let res = self
.api_client
.post(uri)
.header(http::header::CONTENT_TYPE, "application/json")
.header(http::header::ACCEPT, "application/json")
.body(serde_json::to_vec(statements)?)
.send()
.await?;
let status = res.status();
if !status.is_success() {
match res.bytes().await {
Ok(b) => match serde_json::from_slice(&b) {
Ok(ExecResponse { results, .. }) => {
if let Some(ExecResult::Error { error }) = results
.into_iter()
.find(|r| matches!(r, ExecResult::Error { .. }))
{
return Err(Error::ResponseError(error));
}
return Err(Error::UnexpectedStatusCode(status));
}
Err(error) => {
debug!(
%error,
"could not deserialize response body, sending generic error..."
);
return Err(Error::UnexpectedStatusCode(status));
}
},
Err(error) => {
debug!(
%error,
"could not aggregate response body bytes, sending generic error..."
);
return Err(Error::UnexpectedStatusCode(status));
}
}
}
Ok(serde_json::from_slice(&res.bytes().await?)?)
}
}
/// A [`CorrosionApiClient`] bundled with a pool of SQLite connections.
///
/// Dereferences to the wrapped [`CorrosionApiClient`], so every API method can be called
/// directly on this type.
#[derive(Clone)]
pub struct CorrosionClient {
    api_client: CorrosionApiClient,
    pool: sqlite_pool::RusqlitePool,
}

impl CorrosionClient {
    /// Builds an API client for `api_addr` and a connection pool (configured with
    /// `max_size(5)`) for the database at `db_path`.
    ///
    /// The API client is built first, so a failure there returns before any pool is created.
    ///
    /// # Errors
    ///
    /// Returns the [`reqwest::Error`] raised when the HTTP client cannot be built.
    ///
    /// # Panics
    ///
    /// Panics if the connection pool cannot be created.
    pub fn new<P: AsRef<Path>>(api_addr: SocketAddr, db_path: P) -> Result<Self, reqwest::Error> {
        let api_client = CorrosionApiClient::new(api_addr)?;
        let pool = sqlite_pool::Config::new(db_path.as_ref())
            .max_size(5)
            .create_pool()
            .expect("could not build pool, this can't fail because we specified a runtime");
        Ok(Self { api_client, pool })
    }

    /// Builds an API client for `api_addr` that reuses an existing connection `pool`.
    ///
    /// # Errors
    ///
    /// Returns the [`reqwest::Error`] raised when the HTTP client cannot be built.
    pub fn with_sqlite_pool(
        api_addr: SocketAddr,
        pool: sqlite_pool::RusqlitePool,
    ) -> Result<Self, reqwest::Error> {
        CorrosionApiClient::new(api_addr).map(|api_client| Self { api_client, pool })
    }

    /// Borrows the SQLite connection pool.
    pub fn pool(&self) -> &sqlite_pool::RusqlitePool {
        let Self { pool, .. } = self;
        pool
    }
}

impl Deref for CorrosionClient {
    type Target = CorrosionApiClient;

    /// Exposes the wrapped API client.
    fn deref(&self) -> &CorrosionApiClient {
        let Self { api_client, .. } = self;
        api_client
    }
}
#[derive(Clone)]
pub struct CorrosionPooledClient {
inner: Arc<RwLock<PooledClientInner>>,
}
struct PooledClientInner {
picker: AddrPicker,
stickiness_timeout: time::Duration,
client: Option<CorrosionApiClient>,
had_success: bool,
first_fail_at: Option<Instant>,
generation: u64,
}
impl CorrosionPooledClient {
pub fn new(addrs: Vec<String>, stickiness_timeout: time::Duration, resolver: Resolver) -> Self {
Self {
inner: Arc::new(RwLock::new(PooledClientInner {
picker: AddrPicker::new(addrs, resolver),
stickiness_timeout,
client: None,
had_success: false,
first_fail_at: None,
generation: 0,
})),
}
}
pub async fn query_typed<T: DeserializeOwned + Unpin>(
&self,
statement: &Statement,
timeout: Option<u64>,
) -> Result<QueryStream<T>, Error> {
let (response, generation) = {
let (client, generation) = self.get_client().await?;
let response = client.query_typed(statement, timeout).await;
(response, generation)
};
if matches!(response, Err(Error::Reqwest(_))) {
self.handle_error(generation).await;
} else {
self.handle_success(generation).await;
}
response
}
pub async fn subscribe_typed<T: DeserializeOwned + Unpin>(
&self,
statement: &Statement,
skip_rows: bool,
from: Option<ChangeId>,
) -> Result<SubscriptionStream<T>, Error> {
let (response, generation) = {
let (client, generation) = self.get_client().await?;
let response = client.subscribe_typed(statement, skip_rows, from).await;
(response, generation)
};
if matches!(response, Err(Error::Reqwest(_))) {
self.handle_error(generation).await;
} else {
self.handle_success(generation).await;
}
response
}
pub async fn subscription_typed<T: DeserializeOwned + Unpin>(
&self,
id: Uuid,
skip_rows: bool,
from: Option<ChangeId>,
) -> Result<SubscriptionStream<T>, Error> {
let (response, generation) = {
let (client, generation) = self.get_client().await?;
let response = client.subscription_typed(id, skip_rows, from).await;
(response, generation)
};
if matches!(response, Err(Error::Reqwest(_))) {
self.handle_error(generation).await;
} else {
self.handle_success(generation).await;
}
response
}
async fn get_client(&self) -> Result<(RwLockReadGuard<'_, CorrosionApiClient>, u64), Error> {
let mut inner = self.inner.write().await;
let generation = inner.generation;
if inner.client.is_none() {
let addr = inner.picker.next().await?;
info!(
"next Corrosion server to attempt: {}, generation: {}",
addr, generation
);
inner.client = Some(CorrosionApiClient::new(addr)?)
}
Ok((
RwLockReadGuard::map(inner.downgrade(), |inner| inner.client.as_ref().unwrap()),
generation,
))
}
async fn handle_success(&self, generation: u64) {
let mut inner = self.inner.write().await;
if inner.generation != generation {
return;
}
inner.had_success = true;
inner.first_fail_at = None;
}
async fn handle_error(&self, generation: u64) {
let mut inner = self.inner.write().await;
if generation != inner.generation {
return;
}
match inner.first_fail_at {
None if inner.had_success => {
inner.first_fail_at = Some(Instant::now());
}
Some(first) if Instant::now().duration_since(first) < inner.stickiness_timeout => {}
_ => {
if inner.had_success {
inner.picker.reset()
}
inner.client = None;
inner.first_fail_at = None;
inner.had_success = false;
inner.generation += 1;
}
}
}
}
struct AddrPicker {
resolver: Resolver,
addrs: Vec<String>,
next_addr: usize,
last_resolved_addrs: Option<Vec<SocketAddr>>,
next_resolved_addr: usize,
}
impl AddrPicker {
fn new(addrs: Vec<String>, resolver: Resolver) -> AddrPicker {
Self {
resolver,
addrs,
next_addr: 0,
last_resolved_addrs: None,
next_resolved_addr: 0,
}
}
async fn next(&mut self) -> Result<SocketAddr, Error> {
if self.next_resolved_addr
>= self
.last_resolved_addrs
.as_ref()
.map(|v| v.len())
.unwrap_or_default()
{
let host_port = self
.addrs
.get(self.next_addr)
.ok_or(ResolveError::from("No addresses available"))?;
self.next_addr = (self.next_addr + 1) % self.addrs.len();
let mut addrs = if let Ok(addr) = host_port.parse() {
vec![addr]
} else {
let (host, port) = host_port
.rsplit_once(':')
.and_then(|(host, port)| Some((host, port.parse().ok()?)))
.ok_or(ResolveError::from("Invalid Corrosion server address"))?;
timeout(DNS_RESOLVE_TIMEOUT, self.resolver.lookup_ip(host))
.await
.map_err(|_| ResolveError::Timeout)??
.iter()
.map(|addr| (addr, port).into())
.collect::<Vec<_>>()
};
addrs.sort();
debug!("got the following Corrosion servers: {:?}", addrs);
self.last_resolved_addrs = Some(addrs);
self.next_resolved_addr = 0;
}
if let Some(addr) = self
.last_resolved_addrs
.as_ref()
.and_then(|a| a.get(self.next_resolved_addr).copied())
{
self.next_resolved_addr += 1;
Ok(addr)
} else {
Err(ResolveError::from("DNS didn't return any addresses").into())
}
}
fn reset(&mut self) {
self.next_addr = 0;
self.last_resolved_addrs = None;
self.next_resolved_addr = 0;
}
}
/// Errors returned by the Corrosion API clients in this module.
#[derive(Debug, thiserror::Error)]
pub enum Error {
/// Address selection failed for one of these reasons:
/// * the server list is empty or an entry is malformed,
/// * the DNS lookup failed or timed out, or
/// * the lookup returned no address.
#[error(transparent)]
Dns(#[from] ResolveError),
/// Error from the `reqwest` HTTP client: building it, sending a request, or reading a body.
/// `CorrosionPooledClient` treats this variant as a server failure for fail-over.
#[error(transparent)]
Reqwest(#[from] reqwest::Error),
/// Conversion from an invalid URI.
#[error(transparent)]
InvalidUri(#[from] http::uri::InvalidUri),
/// Conversion from a generic `http` crate error.
#[error(transparent)]
Http(#[from] http::Error),
/// JSON serialization of a request, or deserialization of a response body, failed.
#[error(transparent)]
Serde(#[from] serde_json::Error),
/// The server answered with a non-success HTTP status that could not be mapped to a more
/// specific error.
#[error("received unexpected response code: {0}")]
UnexpectedStatusCode(http::StatusCode),
/// Error message reported by the server in an `ExecResult::Error` body.
#[error("{0}")]
ResponseError(String),
/// A non-success response whose body decoded to an `ExecResult` other than `Error`.
#[error("unexpected result: {0:?}")]
UnexpectedResult(ExecResult),
/// The response lacked a parseable `QUERY_ID_HEADER` header.
#[error("could not retrieve subscription id from headers")]
ExpectedQueryId,
}
#[cfg(test)]
mod tests {
    use crate::{CorrosionPooledClient, Error};
    use corro_api_types::{SqliteValue, QUERY_ID_HEADER};
    use hickory_resolver::Resolver;
    use hyper::{header::HeaderValue, service::service_fn, Request, Response};
    use std::{
        convert::Infallible,
        net::SocketAddr,
        sync::{
            atomic::{AtomicBool, Ordering},
            Arc,
        },
        time::{Duration, Instant},
    };
    use tokio::{net::TcpListener, pin, sync::broadcast};
    use uuid::Uuid;

    /// Upper bound on how long a test keeps retrying while waiting for the client to fail
    /// over. Without it, a broken fail-over would hang the test instead of failing it.
    const FAILOVER_DEADLINE: Duration = Duration::from_secs(10);

    /// HTTP body with no data frames.
    struct Empty<D>(std::marker::PhantomData<D>);

    impl Empty<bytes::Bytes> {
        fn new() -> Self {
            Self(std::marker::PhantomData)
        }
    }

    impl<D: bytes::Buf> http_body::Body for Empty<D> {
        type Data = D;
        type Error = std::convert::Infallible;

        fn poll_frame(
            self: std::pin::Pin<&mut Self>,
            _cx: &mut std::task::Context<'_>,
        ) -> std::task::Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
            std::task::Poll::Ready(None)
        }

        fn is_end_stream(&self) -> bool {
            true
        }

        fn size_hint(&self) -> http_body::SizeHint {
            http_body::SizeHint::with_exact(0)
        }
    }

    /// Minimal HTTP/2 test server that answers every request with an empty body and its `id`
    /// in the `QUERY_ID_HEADER` header.
    ///
    /// New connections can be refused (dropped right after accept), and existing connections
    /// can be shut down gracefully.
    struct Server {
        id: Uuid,
        addr: SocketAddr,
        refuse: Arc<AtomicBool>,
        drop_conn_tx: broadcast::Sender<()>,
    }

    impl Server {
        async fn new(id: Uuid) -> Self {
            let refuse = Arc::new(AtomicBool::new(false));
            let (drop_conn_tx, drop_conn_rx) = broadcast::channel::<()>(1);
            let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let addr = listener.local_addr().unwrap();
            tokio::spawn({
                let refuse = refuse.clone();
                async move {
                    loop {
                        let (stream, _) = listener.accept().await.unwrap();
                        if refuse.load(Ordering::Relaxed) {
                            drop(stream);
                            continue;
                        }
                        let io = hyper_util::rt::TokioIo::new(stream);
                        let mut drop_conn_rx = drop_conn_rx.resubscribe();
                        tokio::spawn(async move {
                            let conn = hyper::server::conn::http2::Builder::new(
                                hyper_util::rt::TokioExecutor::new(),
                            )
                            .serve_connection(
                                io,
                                service_fn(move |_: Request<hyper::body::Incoming>| async move {
                                    let mut res = Response::new(Empty::new());
                                    res.headers_mut().insert(
                                        QUERY_ID_HEADER,
                                        HeaderValue::from_str(&id.to_string()).unwrap(),
                                    );
                                    Ok::<_, Infallible>(res)
                                }),
                            );
                            pin!(conn);
                            tokio::select! {
                                _ = conn.as_mut() => (),
                                _ = drop_conn_rx.recv() => {
                                    conn.as_mut().graceful_shutdown()
                                },
                            }
                        });
                    }
                }
            });
            Server {
                id,
                addr,
                refuse,
                drop_conn_tx,
            }
        }

        fn refuse_new_conns(&self, refuse: bool) {
            self.refuse.store(refuse, Ordering::Relaxed)
        }

        fn kill_existing_conns(&self) {
            _ = self.drop_conn_tx.send(())
        }
    }

    /// Starts `num` servers. Returns them sorted by address, together with their addresses
    /// as strings in the same order.
    async fn gen_servers(num: usize) -> (Vec<Server>, Vec<String>) {
        let mut servers = Vec::with_capacity(num);
        for _ in 0..num {
            servers.push(Server::new(Uuid::new_v4()).await);
        }
        servers.sort_by_key(|s| s.addr);
        let addrs = servers.iter().map(|s| s.addr.to_string()).collect();
        (servers, addrs)
    }

    #[tokio::test]
    async fn test_single_address() {
        let statement = "".into();
        let (servers, addresses) = gen_servers(1).await;
        let resolver = Resolver::builder_tokio().unwrap().build().unwrap();
        let client = CorrosionPooledClient::new(addresses, Duration::from_nanos(1), resolver);
        let sub = client
            .subscribe_typed::<SqliteValue>(&statement, false, None)
            .await
            .unwrap();
        assert_eq!(sub.id(), servers[0].id);
        servers[0].kill_existing_conns();
        let res = client
            .subscribe_typed::<SqliteValue>(&statement, false, None)
            .await;
        assert!(matches!(res, Result::Err(Error::Reqwest(_))));
        let sub = client
            .subscribe_typed::<SqliteValue>(&statement, false, None)
            .await
            .unwrap();
        assert_eq!(sub.id(), servers[0].id);
    }

    #[tokio::test]
    async fn test_multiple_addresses() {
        let statement = "".into();
        let (servers, addresses) = gen_servers(3).await;
        let resolver = Resolver::builder_tokio().unwrap().build().unwrap();
        let client = CorrosionPooledClient::new(addresses, Duration::from_nanos(1), resolver);
        servers[0].refuse_new_conns(true);
        let res = client
            .subscribe_typed::<SqliteValue>(&statement, false, None)
            .await;
        assert!(matches!(res, Result::Err(Error::Reqwest(_))));
        let sub = client
            .subscribe_typed::<SqliteValue>(&statement, false, None)
            .await
            .unwrap();
        assert_eq!(sub.id(), servers[1].id);
        servers[1].kill_existing_conns();
        servers[1].refuse_new_conns(true);
        servers[0].refuse_new_conns(false);
        // First error starts the stickiness timer; second one (after 1ns) switches servers.
        for _ in 0..2 {
            let res = client
                .subscribe_typed::<SqliteValue>(&statement, false, None)
                .await;
            assert!(matches!(res, Result::Err(Error::Reqwest(_))));
        }
        let sub = client
            .subscribe_typed::<SqliteValue>(&statement, false, None)
            .await
            .unwrap();
        assert_eq!(sub.id(), servers[0].id);
    }

    #[tokio::test]
    async fn test_multiple_addresses_sticky() {
        let statement = "".into();
        let (servers, addresses) = gen_servers(3).await;
        let resolver = Resolver::builder_tokio().unwrap().build().unwrap();
        let client = CorrosionPooledClient::new(addresses, Duration::from_millis(50), resolver);
        servers[0].refuse_new_conns(true);
        let res = client
            .subscribe_typed::<SqliteValue>(&statement, false, None)
            .await;
        assert!(matches!(res, Result::Err(Error::Reqwest(_))));
        let sub = client
            .subscribe_typed::<SqliteValue>(&statement, false, None)
            .await
            .unwrap();
        assert_eq!(sub.id(), servers[1].id);
        servers[1].kill_existing_conns();
        servers[1].refuse_new_conns(true);
        servers[0].refuse_new_conns(false);
        // The client keeps retrying servers[1] for the 50ms stickiness window before
        // switching back to servers[0].
        let deadline = Instant::now() + FAILOVER_DEADLINE;
        let mut attempts = 0;
        loop {
            assert!(
                Instant::now() < deadline,
                "client did not fail over to servers[0] after {attempts} failed attempts"
            );
            let res = client
                .subscribe_typed::<SqliteValue>(&statement, false, None)
                .await;
            match res {
                Ok(sub) => {
                    assert_eq!(sub.id(), servers[0].id);
                    break;
                }
                Err(_) => attempts += 1,
            };
        }
        assert!(attempts > 2);
    }

    #[tokio::test]
    async fn test_more_servers() {
        let statement = "".into();
        let (pool1_servers, pool1_addresses) = gen_servers(2).await;
        let (pool2_servers, pool2_addresses) = gen_servers(2).await;
        let mut addresses = pool1_addresses;
        addresses.extend_from_slice(&pool2_addresses);
        let resolver = Resolver::builder_tokio().unwrap().build().unwrap();
        let client = CorrosionPooledClient::new(addresses, Duration::from_nanos(1), resolver);
        for server in pool1_servers.iter().chain(&pool2_servers) {
            server.refuse_new_conns(true);
        }
        for _ in 0..15 {
            let res = client
                .subscribe_typed::<SqliteValue>(&statement, false, None)
                .await;
            assert!(matches!(res, Result::Err(Error::Reqwest(_))));
        }
        pool2_servers[0].refuse_new_conns(false);
        for i in 0..4 {
            let res = client
                .subscribe_typed::<SqliteValue>(&statement, false, None)
                .await;
            match res {
                Result::Err(_) => (),
                Ok(sub) => {
                    assert_eq!(sub.id(), pool2_servers[0].id);
                    break;
                }
            }
            assert!(i != 3);
        }
        pool2_servers[0].kill_existing_conns();
        pool2_servers[0].refuse_new_conns(true);
        pool1_servers[0].refuse_new_conns(false);
        pool1_servers[1].refuse_new_conns(false);
        for _ in 0..2 {
            let res = client
                .subscribe_typed::<SqliteValue>(&statement, false, None)
                .await;
            assert!(matches!(res, Result::Err(Error::Reqwest(_))));
        }
        let sub = client
            .subscribe_typed::<SqliteValue>(&statement, false, None)
            .await
            .unwrap();
        assert_eq!(sub.id(), pool1_servers[0].id);
    }
}