#[cfg(feature = "preview")]
use aranya_runtime::SyncHelloType;
use aranya_runtime::{
PolicyStore, StorageError, StorageProvider, SyncRequestMessage, SyncResponder, SyncType,
TraversalBuffers, MAX_SYNC_MESSAGE_SIZE,
};
use aranya_util::{error::ReportExt as _, ready};
use buggy::{bug, BugExt as _};
use derive_where::derive_where;
use tracing::{debug, error, info, instrument, trace, warn};
use super::{
transport::{SyncListener, SyncStream},
Addr, Error, SyncHandle, SyncPeer,
};
use crate::{aranya::Client, sync::SyncResponse};
/// Accepts incoming sync streams from a [`SyncListener`] and answers the request
/// carried on each one.
///
/// `Debug` is implemented with an `SL: Debug` bound only (no bounds on `PS`/`SP`).
#[derive_where(Debug; SL)]
pub(crate) struct SyncServer<SL, PS, SP> {
    /// Client used to read graph state when answering poll requests.
    client: Client<PS, SP>,
    /// Source of incoming sync streams.
    listener: SL,
    /// Handle used to service sync-hello requests (only with the `preview` feature).
    handle: SyncHandle,
}
impl<SL, PS, SP> SyncServer<SL, PS, SP>
where
    SL: SyncListener + Sync,
    PS: PolicyStore + Send + 'static,
    SP: StorageProvider + Send + 'static,
{
    /// Creates a new sync server.
    ///
    /// * `listener` - source of incoming [`SyncStream`]s.
    /// * `client` - used to answer poll (sync) requests.
    /// * `handle` - used to service sync-hello requests (only with the `preview`
    ///   feature).
    pub(crate) const fn new(listener: SL, client: Client<PS, SP>, handle: SyncHandle) -> Self {
        Self {
            client,
            listener,
            handle,
        }
    }

    /// Returns the local address reported by the underlying listener.
    pub(crate) fn local_addr(&self) -> Addr {
        self.listener.local_addr()
    }

    /// Runs the server's accept loop.
    ///
    /// Calls `ready.notify()` once, then accepts streams until the listener's `accept`
    /// yields `None`. Each accepted stream is handled in a task spawned on an
    /// `aranya_util::task::scope`; errors from a single stream are logged and do not
    /// stop the loop. An error is logged once the server terminates.
    // NOTE(review): the reason for this `allow` is not visible in this file — document
    // which disallowed macro it is needed for.
    #[allow(clippy::disallowed_macros)]
    #[instrument(skip_all, fields(addr = ?self.local_addr()))]
    pub(crate) async fn serve(mut self, ready: ready::Notifier) {
        info!("sync server listening for incoming connections");
        ready.notify();
        aranya_util::task::scope(async |s| {
            while let Some(stream) = self.listener.accept().await {
                let peer = stream.peer();
                trace!(?peer, "accepted stream");
                // Each task gets its own clones so it can outlive this loop iteration.
                let client = self.client.clone();
                let handle = self.handle.clone();
                s.spawn(async move {
                    if let Err(e) = Self::handle_stream(client, handle, stream).await {
                        warn!(?peer, error = %e.report(), "error handling sync request");
                    }
                });
            }
        })
        .await;
        error!("sync server terminated");
    }

    /// Serves a single request/response exchange on `stream`.
    ///
    /// Reads one request of at most [`MAX_SYNC_MESSAGE_SIZE`] bytes, decodes it as a
    /// postcard-encoded [`SyncType`], and dispatches it. It then writes back one
    /// postcard-encoded [`SyncResponse`] and finishes the stream.
    ///
    /// Failures while *processing* a request are reported to the peer as
    /// [`SyncResponse::Err`]. This includes request types this server does not support.
    ///
    /// # Errors
    ///
    /// Returns an error, and sends no response, if receiving, decoding, encoding,
    /// sending, or finishing the stream fails.
    #[instrument(skip_all, fields(peer = %stream.peer().addr, graph = %stream.peer().graph_id))]
    // `handle` is only used by the sync-hello path, which requires `preview`.
    #[cfg_attr(not(feature = "preview"), expect(unused_variables))]
    async fn handle_stream<S: SyncStream>(
        client: Client<PS, SP>,
        handle: SyncHandle,
        mut stream: S,
    ) -> Result<(), Error> {
        trace!("received sync request");

        // One buffer is reused for the request, the poll response, and the encoded reply.
        let mut buf = vec![0u8; MAX_SYNC_MESSAGE_SIZE].into_boxed_slice();
        let len = stream.receive(&mut buf).await.map_err(Error::transport)?;
        trace!(len, "received request bytes");
        let buffer = buf.get(..len).assume("valid offset")?;
        let sync_type = postcard::from_bytes(buffer)?;
        debug!(sync_type = ?std::mem::discriminant(&sync_type), "processing request");

        let response = match sync_type {
            SyncType::Poll { request } => {
                match Self::process_poll_request(&client, stream.peer(), request, &mut buf).await {
                    Ok(len) => {
                        debug!(response_bytes = len, "poll request succeeded");
                        SyncResponse::Ok(buf.get(..len).assume("valid offset")?.into())
                    }
                    Err(e) => {
                        error!(error = %e.report(), "error processing poll message");
                        SyncResponse::Err(e.report().to_string())
                    }
                }
            }
            SyncType::Hello(hello_msg) => {
                #[cfg(not(feature = "preview"))]
                {
                    let _ = hello_msg;
                    // The message comes from a remote peer, so this is not an internal
                    // bug: tell the peer instead of raising `bug!` and dropping the stream.
                    warn!("received sync hello request, but sync hello is not enabled");
                    SyncResponse::Err(String::from("sync hello not enabled"))
                }
                #[cfg(feature = "preview")]
                {
                    match Self::process_hello_request(&handle, stream.peer(), hello_msg).await {
                        Ok(()) => {
                            debug!("hello request succeeded");
                            SyncResponse::Ok(Box::new([]))
                        }
                        Err(e) => {
                            error!(error = %e.report(), "error processing hello message");
                            SyncResponse::Err(e.report().to_string())
                        }
                    }
                }
            }
            SyncType::Subscribe { .. } | SyncType::Unsubscribe { .. } | SyncType::Push { .. } => {
                // Peer-controlled input: report it to the peer rather than as a `bug!`.
                warn!("received a sync message type that is not implemented");
                SyncResponse::Err(String::from("message type not currently implemented!"))
            }
        };

        // NOTE(review): the reply envelope (variant tag + length prefix + payload) is
        // encoded into the same buffer the poll response was written to. If
        // `SyncResponder::poll` can fill (nearly) the whole buffer, `to_slice` may fail
        // with a full buffer. Confirm the responder leaves headroom.
        let data = postcard::to_slice(&response, &mut buf)?;
        stream.send(data).await.map_err(Error::transport)?;
        stream.finish().await.map_err(Error::transport)?;
        trace!(n = data.len(), "sent response");
        Ok(())
    }

    /// Answers a poll request from `peer`.
    ///
    /// Only [`SyncRequestMessage::SyncRequest`] is accepted; any other variant yields
    /// [`Error::InvalidRequest`]. The peer must pass `check_request` for the requested
    /// graph. The response is written into `buf`, and the number of bytes written is
    /// returned.
    ///
    /// If the requested graph's storage does not exist, a warning is logged and an
    /// empty (0-byte) response is returned instead of an error.
    ///
    /// The Aranya client and its per-peer caches stay locked while the response is
    /// generated.
    async fn process_poll_request(
        client: &Client<PS, SP>,
        peer: SyncPeer,
        request: SyncRequestMessage,
        buf: &mut [u8],
    ) -> Result<usize, Error> {
        match request {
            SyncRequestMessage::SyncRequest { graph_id, .. } => {
                peer.check_request(graph_id)?;
                let mut resp = SyncResponder::new();
                resp.receive(request)?;
                let (mut aranya, mut caches) = client.lock_aranya_and_caches().await;
                let cache = caches.entry(peer).or_default();
                let len = resp
                    .poll(buf, aranya.provider(), cache, &mut TraversalBuffers::new())
                    .or_else(|err| {
                        // A missing graph is answered with an empty response, not an error.
                        if matches!(
                            err,
                            aranya_runtime::SyncError::Storage(StorageError::NoSuchStorage)
                        ) {
                            warn!(team = %peer.graph_id, "missing requested graph");
                            Ok(0)
                        } else {
                            Err(err)
                        }
                    })
                    .map_err(Error::Runtime)?;
                trace!(response_bytes = len, "generated poll response");
                Ok(len)
            }
            other => {
                warn!(
                    variant = ?std::mem::discriminant(&other),
                    "received an unexpected SyncRequestMessage variant"
                );
                Err(Error::InvalidRequest)
            }
        }
    }

    /// Handles a sync-hello message from `peer`.
    ///
    /// Every variant first checks, via `check_request`, that the peer may act on the
    /// graph named in the message. It then forwards the request to `handle`:
    /// * `Subscribe` → `hello_subscribe_request`
    /// * `Unsubscribe` → `hello_unsubscribe_request`
    /// * `Hello` → `sync_on_hello` with the advertised head
    #[cfg(feature = "preview")]
    pub(super) async fn process_hello_request(
        handle: &SyncHandle,
        peer: SyncPeer,
        hello_msg: SyncHelloType,
    ) -> Result<(), Error> {
        match hello_msg {
            SyncHelloType::Subscribe {
                graph_change_delay,
                duration,
                schedule_delay,
                graph_id,
            } => {
                peer.check_request(graph_id)?;
                handle
                    .hello_subscribe_request(peer, graph_change_delay, duration, schedule_delay)
                    .await?;
            }
            SyncHelloType::Unsubscribe { graph_id } => {
                peer.check_request(graph_id)?;
                handle.hello_unsubscribe_request(peer).await?;
            }
            SyncHelloType::Hello { head, graph_id } => {
                peer.check_request(graph_id)?;
                handle.sync_on_hello(peer, head).await?;
            }
        }
        Ok(())
    }
}