use std::{
cmp,
collections::{
BTreeMap,
BTreeSet,
VecDeque,
},
future::Future,
pin::Pin,
};
use convex_sync_types::{
types::SerializedArgs,
AuthenticationToken,
CanonicalizedUdfPath,
ClientMessage,
IdentityVersion,
QueryId,
QuerySetModification,
QuerySetVersion,
SessionRequestSeqNumber,
StateModification,
StateVersion,
Timestamp,
UdfPath,
};
use serde_json::json;
use tokio::sync::oneshot;
#[cfg(doc)]
use crate::ConvexClient;
use crate::{
convex_logs,
sync::{
ReconnectProtocolReason,
ServerMessage,
},
value::Value,
ConvexError,
};
mod request_manager;
use request_manager::{
RequestId,
RequestManager,
};
mod query_result;
pub use query_result::{
FunctionResult,
QueryResults,
};
use self::request_manager::RequestType;
/// Callback used by the client to obtain an [`AuthenticationToken`].
///
/// The `bool` argument is `false` when the fetcher is first installed through
/// [`BaseConvexClient::set_auth_fetcher`] and `true` when the session is re-established on
/// reconnect (`LocalSyncState::restart`). Presumably `true` asks the fetcher to bypass any
/// cached token — TODO confirm against fetcher implementations.
///
/// The fetcher is `Send + Sync` and returns a boxed `Send` future resolving to the token or an
/// error; errors are logged by the callers in this module rather than propagated.
pub type AuthTokenFetcher = Box<
    dyn Fn(bool) -> Pin<Box<dyn Future<Output = anyhow::Result<AuthenticationToken>> + Send>>
        + Send
        + Sync,
>;
/// Deduplication key for a query: the JSON serialization of its canonicalized UDF path and
/// arguments, as produced by `serialize_path_and_args`.
///
/// Subscriptions that produce the same token share a single entry in the local query set (and
/// therefore a single [`QueryId`]).
#[derive(Clone, Eq, PartialEq, PartialOrd, Ord, Debug)]
struct QueryToken(String);
/// Client-side record of one active query in the local query set.
#[derive(Clone, Debug)]
struct LocalQuery {
    /// Identifier this query was registered with in `ModifyQuerySet` messages.
    id: QueryId,
    /// Canonicalized path of the query function; used to re-add the query on reconnect.
    canonicalized_udf_path: CanonicalizedUdfPath,
    /// Arguments the query was subscribed with.
    args: BTreeMap<String, Value>,
    /// Number of live subscribers; the query is removed when this would drop to zero.
    num_subscribers: usize,
    /// Index given to the most recently created subscriber of this query. It only ever grows,
    /// so a `SubscriberId` is never handed out twice while this query exists.
    subscription_index: usize,
}
/// A query result received from the server, paired with the path and arguments of the query
/// it belongs to.
///
/// `_udf_path` and `_args` are filled in by `on_query_result_changes` but are not read
/// anywhere in this module (hence the leading underscores).
#[derive(Clone, Debug)]
struct Query {
    result: FunctionResult,
    _udf_path: CanonicalizedUdfPath,
    _args: BTreeMap<String, Value>,
}
/// Handle for one subscription returned by [`BaseConvexClient::subscribe`].
///
/// The first field is the [`QueryId`] of the (possibly shared) underlying query; the second is
/// a per-query index that distinguishes multiple subscribers of the same query.
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash)]
#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
pub struct SubscriberId(QueryId, usize);
impl SubscriberId {
    /// Test-only accessor for the id of the query this subscriber is attached to.
    #[cfg(test)]
    pub fn query_id(&self) -> QueryId {
        let SubscriberId(query_id, _subscription_index) = *self;
        query_id
    }
}
/// Builds the deduplication token for a query.
///
/// The token is the JSON text of `{"udfPath": <canonical path>, "args": [<args object>]}`;
/// the arguments are wrapped in a one-element array, mirroring how they are sent to the server.
fn serialize_path_and_args(udf_path: UdfPath, args: BTreeMap<String, Value>) -> QueryToken {
    let path_string: String = udf_path.canonicalize().into();
    let args_json: serde_json::Value = Value::Array(vec![Value::Object(args)]).into();
    // Keys are inserted in the same order as before so the text is stable even when
    // serde_json preserves insertion order.
    let token_json = json!({
        "udfPath": path_string,
        "args": args_json,
    });
    QueryToken(token_json.to_string())
}
/// Client-side query set and authentication state that the client mirrors to the server.
#[derive(Default)]
struct LocalSyncState {
    /// Id that will be assigned to the next newly added query; it only increases, so ids are
    /// not reused.
    next_query_id: QueryId,
    /// Current query-set version; the next `ModifyQuerySet` uses it as its `base_version`.
    query_set_version: QuerySetVersion,
    /// Active queries keyed by their deduplication token.
    query_set: BTreeMap<QueryToken, LocalQuery>,
    /// Reverse index from query id to the token keying `query_set`.
    query_id_to_token: BTreeMap<QueryId, QueryToken>,
    /// Latest per-query results and the set of live subscribers.
    latest_results: QueryResults,
    /// Current identity version; the next `Authenticate` uses it as its `base_version`.
    identity_version: IdentityVersion,
    /// Optional callback used to fetch auth tokens, including on reconnect.
    auth_fetcher: Option<AuthTokenFetcher>,
}
impl LocalSyncState {
fn subscribe(
&mut self,
udf_path: UdfPath,
args: BTreeMap<String, Value>,
) -> (Option<ClientMessage>, SubscriberId) {
let canonicalized_udf_path = udf_path.clone().canonicalize();
let query_token = serialize_path_and_args(udf_path.clone(), args.clone());
if let Some(existing_entry) = self.query_set.get_mut(&query_token) {
existing_entry.num_subscribers += 1;
existing_entry.subscription_index += 1;
let query_id = existing_entry.id;
let subscription = SubscriberId(query_id, existing_entry.subscription_index);
let prev = self.latest_results.subscribers.insert(subscription);
assert!(prev.is_none(), "INTERNAL BUG: Subscriber ID already taken.");
return (None, subscription);
}
let query_id = self.next_query_id;
self.next_query_id = QueryId::new(self.next_query_id.get_id() + 1);
let base_version = self.query_set_version;
self.query_set_version += 1;
let new_version = self.query_set_version;
let add = QuerySetModification::Add(convex_sync_types::Query {
query_id,
udf_path,
args: SerializedArgs::from_args(vec![Value::Object(args.clone()).into()])
.expect("Could not serialize query arguments"),
journal: None,
component_path: None,
});
let message = ClientMessage::ModifyQuerySet {
base_version,
new_version,
modifications: vec![add],
};
let query = LocalQuery {
id: query_id,
canonicalized_udf_path,
args,
num_subscribers: 1,
subscription_index: 0,
};
self.query_set.insert(query_token.clone(), query);
self.query_id_to_token.insert(query_id, query_token.clone());
let subscription = SubscriberId(query_id, 0);
let prev = self.latest_results.subscribers.insert(subscription);
assert!(prev.is_none(), "INTERNAL BUG: Subscriber ID already taken.");
(Some(message), subscription)
}
fn remove_subscriber(&mut self, subscriber_id: SubscriberId) -> Option<ClientMessage> {
let query_id = self
.latest_results
.subscribers
.remove(&subscriber_id)
.expect("INTERNAL BUG: Dropped unknown Subscriber ID")
.0;
let query_token = match self.query_token(query_id) {
None => panic!("INTERNAL BUG: Unknown query id {query_id}"),
Some(t) => t,
};
let local_query = match self.query_set.get_mut(&query_token) {
None => panic!("INTERNAL BUG: No query found for query token {query_token:?}",),
Some(q) => q,
};
if local_query.num_subscribers > 1 {
local_query.num_subscribers -= 1;
return None;
}
self.query_set.remove(&query_token);
self.query_id_to_token.remove(&query_id);
self.latest_results.results.remove(&query_id);
let base_version = self.query_set_version;
self.query_set_version += 1;
let new_version = self.query_set_version;
let remove = QuerySetModification::Remove { query_id };
Some(ClientMessage::ModifyQuerySet {
base_version,
new_version,
modifications: vec![remove],
})
}
fn query_token(&self, query_id: QueryId) -> Option<QueryToken> {
self.query_id_to_token.get(&query_id).cloned()
}
fn query_args(&self, query_id: QueryId) -> Option<BTreeMap<String, Value>> {
Some(
self.query_set
.get(&self.query_token(query_id)?)?
.args
.clone(),
)
}
fn query_path(&self, query_id: QueryId) -> Option<CanonicalizedUdfPath> {
Some(
self.query_set
.get(&self.query_token(query_id)?)?
.canonicalized_udf_path
.clone(),
)
}
fn authenticate(&mut self, token: AuthenticationToken) -> ClientMessage {
let base_version = self.identity_version;
self.identity_version += 1;
ClientMessage::Authenticate {
base_version,
token,
}
}
async fn restart(&mut self) -> Vec<ClientMessage> {
self.identity_version = 0;
let mut messages = Vec::new();
if let Some(ref fetcher) = self.auth_fetcher {
match fetcher(true).await {
Ok(token) if token != AuthenticationToken::None => {
messages.push(ClientMessage::Authenticate {
base_version: 0,
token,
});
self.identity_version += 1;
},
Ok(_) => {},
Err(e) => {
tracing::error!(
"Auth fetcher failed during reconnect: {e:?}. Skipping auth for this \
reconnect attempt."
);
},
}
}
let mut modifications = Vec::new();
for local_query in self.query_set.values() {
let add = QuerySetModification::Add(convex_sync_types::Query {
query_id: local_query.id,
udf_path: local_query.canonicalized_udf_path.clone().into(),
args: SerializedArgs::from_args(vec![
Value::Object(local_query.args.clone()).into()
])
.expect("Could not serialize query arguments"),
journal: None,
component_path: None,
});
modifications.push(add)
}
self.query_set_version = 1;
messages.push(ClientMessage::ModifyQuerySet {
base_version: 0,
new_version: 1,
modifications,
});
messages
}
}
/// Mirror of the server's query results, advanced by `Transition` messages.
#[derive(Debug)]
struct RemoteQuerySet {
    /// State version the next transition is required to start from.
    version: StateVersion,
    /// Latest result reported by the server for each query id.
    remote_query_set: BTreeMap<QueryId, FunctionResult>,
}
impl RemoteQuerySet {
fn new() -> Self {
Self {
version: StateVersion::initial(),
remote_query_set: Default::default(),
}
}
fn transition(&mut self, transition: ServerMessage) -> Result<(), ReconnectProtocolReason> {
let ServerMessage::Transition {
start_version,
end_version,
modifications,
client_clock_skew: _,
server_ts: _,
} = transition
else {
panic!("not transition");
};
if start_version != self.version {
tracing::error!(
"INTERNAL BUG: Protocol Error start_version {:?} is different from self.version \
{:?}",
start_version,
self.version
);
return Err("StartVersionMismatch".into());
}
for modification in modifications {
match modification {
StateModification::QueryUpdated {
query_id,
value,
log_lines,
journal: _,
} => {
for log_line in log_lines.0 {
convex_logs!("{}", log_line);
}
self.remote_query_set
.insert(query_id, FunctionResult::Value(value));
},
StateModification::QueryFailed {
query_id,
error_message,
log_lines,
journal: _,
error_data,
} => {
for log_line in log_lines.0 {
convex_logs!("{}", log_line);
}
let function_result = match error_data {
Some(v) => FunctionResult::ConvexError(ConvexError {
message: error_message,
data: v,
}),
None => FunctionResult::ErrorMessage(error_message),
};
self.remote_query_set.insert(query_id, function_result);
},
StateModification::QueryRemoved { query_id } => {
self.remote_query_set.remove(&query_id);
},
}
}
self.version = end_version;
Ok(())
}
}
/// Query results as last ingested from the server, keyed by query id.
///
/// Despite the name, no optimistic updates are applied in this module: the stored map is
/// replaced wholesale on each ingestion.
#[derive(Default, Debug)]
struct OptimisticQueryResults {
    query_results: BTreeMap<QueryId, Query>,
}
impl OptimisticQueryResults {
    /// Replaces the stored results with `server_query_results` and returns every query whose
    /// result is new or differs from the previously stored one.
    ///
    /// Queries that were stored before but are absent from `server_query_results` are dropped
    /// and are *not* reported in the returned map. `_optimistic_updates_to_drop` is currently
    /// unused.
    fn ingest_query_results_from_server(
        &mut self,
        server_query_results: BTreeMap<QueryId, Query>,
        _optimistic_updates_to_drop: BTreeSet<RequestId>,
    ) -> BTreeMap<QueryId, FunctionResult> {
        // Take ownership of the previous map instead of deep-cloning it: it is only needed for
        // the comparison below, and cloning copied every stored `Value` on each transition.
        let old_query_results = std::mem::replace(&mut self.query_results, server_query_results);
        self.query_results
            .iter()
            .filter(|(query_id, query)| {
                old_query_results
                    .get(*query_id)
                    .map_or(true, |old_query| old_query.result != query.result)
            })
            .map(|(query_id, query)| (*query_id, query.result.clone()))
            .collect()
    }

    /// Returns a copy of the stored result for `query_id`, if any.
    fn query_result(&self, query_id: QueryId) -> Option<FunctionResult> {
        self.query_results.get(&query_id).map(|q| q.result.clone())
    }
}
/// Core Convex client state machine.
///
/// It does not hold a connection itself. Callers drain outgoing [`ClientMessage`]s with
/// [`BaseConvexClient::pop_next_message`] and feed incoming server messages to
/// [`BaseConvexClient::receive_message`]. After a reconnect, callers call
/// [`BaseConvexClient::resend_ongoing_queries_mutations`].
pub struct BaseConvexClient {
    /// Local query set, subscribers, published results and auth state.
    state: LocalSyncState,
    /// Mirror of the server's query results, advanced by transitions.
    remote_query_set: RemoteQuerySet,
    /// Results last ingested from the server; backs `get_query`.
    optimistic_query_results: OptimisticQueryResults,
    /// Tracks in-flight mutations and actions and their result channels.
    request_manager: RequestManager,
    /// Request id for the next mutation or action.
    next_request_id: SessionRequestSeqNumber,
    /// FIFO of messages waiting to be sent to the server.
    outgoing_message_queue: VecDeque<ClientMessage>,
    /// Largest timestamp seen in transitions and mutation responses.
    max_observed_timestamp: Option<Timestamp>,
}
impl BaseConvexClient {
    /// Creates a client with no subscriptions, no auth fetcher and an empty outgoing queue.
    pub fn new() -> Self {
        BaseConvexClient {
            request_manager: RequestManager::new(),
            state: LocalSyncState::default(),
            remote_query_set: RemoteQuerySet::new(),
            optimistic_query_results: OptimisticQueryResults::default(),
            next_request_id: 0,
            outgoing_message_queue: VecDeque::new(),
            max_observed_timestamp: None,
        }
    }

    /// Subscribes to the query `udf_path(args)`.
    ///
    /// An add-query message is queued only when this is the first subscriber of an identical
    /// query; later subscribers share the existing query.
    pub fn subscribe(&mut self, udf_path: UdfPath, args: BTreeMap<String, Value>) -> SubscriberId {
        let (modification, subscription) = self.state.subscribe(udf_path, args);
        self.outgoing_message_queue.extend(modification);
        subscription
    }

    /// Removes a subscriber.
    ///
    /// When it was the last subscriber of its query, a remove-query message is queued and all
    /// cached results for the query are dropped, so [`BaseConvexClient::get_query`] stops
    /// returning a stale value for it.
    ///
    /// # Panics
    ///
    /// Panics if `subscriber_id` is not a live subscriber.
    pub fn unsubscribe(&mut self, subscriber_id: SubscriberId) {
        let query_id = subscriber_id.0;
        // `remove_subscriber` only returns a message when it removed the query entirely.
        let Some(message) = self.state.remove_subscriber(subscriber_id) else {
            return;
        };
        // `latest_results` was already cleared by `remove_subscriber`; the ingested copy must
        // go too, otherwise `get_query` keeps serving it until the next transition.
        self.optimistic_query_results
            .query_results
            .remove(&query_id);
        self.outgoing_message_queue.push_back(message);
    }

    /// Returns the most recently ingested result for `query_id`, if any.
    pub fn get_query(&self, query_id: QueryId) -> Option<FunctionResult> {
        self.local_query_result(query_id)
    }

    /// Queues a mutation and returns a receiver for its result, which is delivered through the
    /// request manager.
    ///
    /// # Panics
    ///
    /// Panics if `args` cannot be serialized.
    pub fn mutation(
        &mut self,
        udf_path: UdfPath,
        args: BTreeMap<String, Value>,
    ) -> oneshot::Receiver<FunctionResult> {
        let request_id = self.allocate_request_id();
        tracing::info!("Starting mutation {udf_path} with id {request_id}");
        let message = ClientMessage::Mutation {
            request_id,
            udf_path,
            args: SerializedArgs::from_args(vec![Value::Object(args).into()])
                .expect("Failed to serialize arguments"),
            component_path: None,
        };
        self.track_and_enqueue(message, request_id, RequestType::Mutation)
    }

    /// Queues an action and returns a receiver for its result, which is delivered through the
    /// request manager.
    ///
    /// # Panics
    ///
    /// Panics if `args` cannot be serialized.
    pub fn action(
        &mut self,
        udf_path: UdfPath,
        args: BTreeMap<String, Value>,
    ) -> oneshot::Receiver<FunctionResult> {
        let request_id = self.allocate_request_id();
        tracing::info!("Starting action {udf_path:?} with id {request_id:?}");
        let message = ClientMessage::Action {
            request_id,
            udf_path,
            // Same failure message as `mutation`, instead of an anonymous `unwrap` panic.
            args: SerializedArgs::from_args(vec![Value::Object(args).into()])
                .expect("Failed to serialize arguments"),
            component_path: None,
        };
        self.track_and_enqueue(message, request_id, RequestType::Action)
    }

    /// Installs or clears the auth token fetcher.
    ///
    /// With `Some`, the fetcher is called once with `false`. On success an `Authenticate`
    /// message is queued. On failure the error is logged, and the fetcher is still kept for
    /// later reconnects. With `None`, the fetcher is removed and an `Authenticate` message
    /// carrying `AuthenticationToken::None` is queued.
    pub async fn set_auth_fetcher(&mut self, fetcher: Option<AuthTokenFetcher>) {
        match fetcher {
            Some(fetcher) => {
                match fetcher(false).await {
                    Ok(token) => {
                        let message = self.state.authenticate(token);
                        self.outgoing_message_queue.push_back(message);
                    },
                    Err(e) => {
                        tracing::error!("Auth token fetcher failed: {e:?}");
                    },
                }
                self.state.auth_fetcher = Some(fetcher);
            },
            None => {
                self.state.auth_fetcher = None;
                let message = self.state.authenticate(AuthenticationToken::None);
                self.outgoing_message_queue.push_back(message);
            },
        }
    }

    /// Pops the oldest queued message to send to the server, if any.
    pub fn pop_next_message(&mut self) -> Option<ClientMessage> {
        self.outgoing_message_queue.pop_front()
    }

    /// Raises `max_observed_timestamp` to `ts` if `ts` is larger (or nothing was seen yet).
    fn observe_timestamp(&mut self, ts: Timestamp) {
        self.max_observed_timestamp = Some(match self.max_observed_timestamp {
            Some(max_observed_timestamp) => cmp::max(ts, max_observed_timestamp),
            None => ts,
        });
    }

    /// Largest timestamp observed from the server so far.
    pub fn max_observed_timestamp(&self) -> Option<Timestamp> {
        self.max_observed_timestamp
    }

    /// Applies a message received from the server.
    ///
    /// Returns `Ok(Some(results))` with the updated published results after a `Transition`,
    /// and `Ok(None)` for messages that do not publish results.
    ///
    /// # Errors
    ///
    /// Returns a reason to restart the protocol on auth errors, fatal errors, unexpected
    /// `TransitionChunk`s, transition version mismatches, or errors reported by the request
    /// manager.
    pub fn receive_message(
        &mut self,
        message: ServerMessage,
    ) -> Result<Option<QueryResults>, ReconnectProtocolReason> {
        match message {
            ServerMessage::Transition { end_version, .. } => {
                self.observe_timestamp(end_version.ts);
                self.remote_query_set.transition(message)?;
                let completed_requests = self
                    .request_manager
                    .remove_and_notify_completed(end_version.ts);
                let changed_results = self.on_query_result_changes(completed_requests)?;
                for (id, result) in changed_results {
                    self.state.latest_results.results.insert(id, result);
                }
                return Ok(Some(self.state.latest_results.clone()));
            },
            ServerMessage::MutationResponse {
                request_id,
                result,
                ts,
                log_lines,
            } => {
                for log_line in log_lines.0 {
                    convex_logs!("{}", log_line);
                }
                if let Some(ts) = ts {
                    self.observe_timestamp(ts);
                }
                let request_id = RequestId::new(request_id);
                self.request_manager.update_request(
                    &request_id,
                    RequestType::Mutation,
                    result.into(),
                    ts,
                )?;
            },
            ServerMessage::AuthError {
                error_message,
                base_version,
                ..
            } => {
                tracing::error!(
                    "AuthError: {error_message} for identity version {base_version:?}. Restarting \
                     protocol."
                );
                return Err(format!(
                    "AuthError: {error_message} for identity version {base_version:?}"
                ));
            },
            ServerMessage::FatalError { error_message } => {
                tracing::error!("FatalError: {error_message}. Restarting protocol.");
                return Err(format!("FatalError: {error_message}"));
            },
            ServerMessage::ActionResponse {
                request_id,
                result,
                log_lines,
            } => {
                for log_line in log_lines.0 {
                    convex_logs!("{}", log_line);
                }
                let request_id = RequestId::new(request_id);
                self.request_manager.update_request(
                    &request_id,
                    RequestType::Action,
                    result.into(),
                    None,
                )?;
            },
            ServerMessage::Ping => {},
            ServerMessage::TransitionChunk { .. } => {
                return Err("Unexpected TransitionChunk message received".to_string());
            },
        }
        Ok(None)
    }

    /// Latest published results and subscribers.
    pub fn latest_results(&self) -> &QueryResults {
        &self.state.latest_results
    }

    /// Rebuilds the outgoing queue for a fresh connection.
    ///
    /// Anything still queued is discarded. The queue is refilled with the session-restore
    /// messages (auth, then the full query set) followed by the in-flight requests reported by
    /// the request manager. The remote query mirror is reset to its initial version.
    pub async fn resend_ongoing_queries_mutations(&mut self) {
        self.outgoing_message_queue.clear();
        let state_restart_messages = self.state.restart().await;
        let mut ongoing_mutation_messages = self.request_manager.restart();
        self.remote_query_set = RemoteQuerySet::new();
        self.outgoing_message_queue.extend(state_restart_messages);
        self.outgoing_message_queue
            .append(&mut ongoing_mutation_messages);
    }

    /// Ingests the current remote results for all locally active queries and returns the
    /// results that changed. Remote results for queries no longer in the local set are skipped.
    ///
    /// # Panics
    ///
    /// Panics if a query has a path but no arguments (an internal invariant violation).
    fn on_query_result_changes(
        &mut self,
        completed_requests: BTreeSet<RequestId>,
    ) -> Result<BTreeMap<QueryId, FunctionResult>, ReconnectProtocolReason> {
        let remote_query_results = &self.remote_query_set.remote_query_set;
        let mut query_id_to_value = BTreeMap::new();
        for (query_id, result) in remote_query_results.iter() {
            let Some(_udf_path) = self.state.query_path(*query_id) else {
                continue;
            };
            let _args = self
                .state
                .query_args(*query_id)
                .expect("INTERNAL BUG: Query args exist, but not query path.");
            query_id_to_value.insert(
                *query_id,
                Query {
                    result: result.clone(),
                    _udf_path,
                    _args,
                },
            );
        }
        Ok(self
            .optimistic_query_results
            .ingest_query_results_from_server(query_id_to_value, completed_requests))
    }

    /// Looks up the most recently ingested result for `query_id`.
    fn local_query_result(&self, query_id: QueryId) -> Option<FunctionResult> {
        self.optimistic_query_results.query_result(query_id)
    }

    /// Returns the next session request id and advances the counter.
    fn allocate_request_id(&mut self) -> SessionRequestSeqNumber {
        let request_id = self.next_request_id;
        self.next_request_id = request_id + 1;
        request_id
    }

    /// Registers `message` with the request manager under `request_id`, queues it for
    /// sending, and returns the receiver for its result.
    fn track_and_enqueue(
        &mut self,
        message: ClientMessage,
        request_id: SessionRequestSeqNumber,
        request_type: RequestType,
    ) -> oneshot::Receiver<FunctionResult> {
        let result_receiver = self.request_manager.track_request(
            &message,
            RequestId::new(request_id),
            request_type,
        );
        self.outgoing_message_queue.push_back(message);
        result_receiver
    }
}
/// Emits a `DEBUG`-level `tracing` event under the `"convex_logs"` target.
///
/// This module uses it to forward the `log_lines` carried by server messages (transitions,
/// mutation and action responses).
///
/// NOTE(review): the `target: ...` form accepts a target expression but ignores it; events are
/// always emitted under `"convex_logs"`. Confirm whether that is intentional.
#[macro_export]
macro_rules! convex_logs {
    // Explicit-target form: `$target` is matched but not used.
    (target: $target:expr, $($arg:tt)+) => {
        tracing::event!(target: "convex_logs", tracing::Level::DEBUG, $($arg)+);
    };
    // Plain form: format arguments are passed straight through to `tracing::event!`.
    ($($arg:tt)+) => {
        tracing::event!(target: "convex_logs", tracing::Level::DEBUG, $($arg)+);
    };
}
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use convex_sync_types::{
        AuthenticationToken,
        ClientMessage,
        LogLinesMessage,
        QuerySetVersion,
        UdfPath,
    };
    use maplit::btreemap;

    use super::*;

    /// Replays the query-set version check over `messages`, starting from version 0.
    ///
    /// Every `ModifyQuerySet` must have `base_version` equal to the version reached so far;
    /// the first mismatch is returned as an error. Presumably this mirrors the server's own
    /// check (which would respond with a FatalError) — confirm against the server.
    fn simulate_server_version_check(messages: &[ClientMessage]) -> Result<(), String> {
        let mut query_set_version: QuerySetVersion = 0;
        for msg in messages {
            if let ClientMessage::ModifyQuerySet {
                base_version,
                new_version,
                ..
            } = msg
            {
                if *base_version != query_set_version {
                    return Err(format!(
                        "Base version {base_version} passed up doesn't match the current version \
                         {query_set_version}"
                    ));
                }
                query_set_version = *new_version;
            }
        }
        Ok(())
    }

    /// Two back-to-back reconnects, with the queue only partially drained in between, must
    /// still leave a version-consistent sequence of messages in the queue.
    #[tokio::test]
    async fn test_reconnect_does_not_send_duplicate_version_messages() {
        let mut client = BaseConvexClient::new();
        client
            .set_auth_fetcher(Some(Box::new(|_force_refetch| {
                Box::pin(async { Ok(AuthenticationToken::User("test-token".into())) })
            })))
            .await;
        let udf = UdfPath::from_str("some:query").unwrap();
        client.subscribe(udf, btreemap! {});
        // Simulate the initial messages having been sent.
        while client.pop_next_message().is_some() {}
        client.resend_ongoing_queries_mutations().await;
        // Only the first restart message goes out before the second reconnect.
        let _ = client.pop_next_message();
        client.resend_ongoing_queries_mutations().await;
        let mut messages = vec![];
        while let Some(msg) = client.pop_next_message() {
            messages.push(msg);
        }
        simulate_server_version_check(&messages)
            .expect("Server would reject these messages with a FatalError");
    }

    /// On reconnect the fetcher is invoked with `force_refetch == true`, and the resulting
    /// token is the first message queued, with identity base version 0.
    #[tokio::test]
    async fn test_reconnect_path_requests_refreshed_token() {
        let mut client = BaseConvexClient::new();
        client
            .set_auth_fetcher(Some(Box::new(|force_refetch| {
                Box::pin(async move {
                    if force_refetch {
                        Ok(AuthenticationToken::User("refetched-token".into()))
                    } else {
                        Ok(AuthenticationToken::User("original-token".into()))
                    }
                })
            })))
            .await;
        let udf = UdfPath::from_str("some:query").unwrap();
        client.subscribe(udf, btreemap! {});
        while client.pop_next_message().is_some() {}
        client.resend_ongoing_queries_mutations().await;
        assert_eq!(
            client
                .pop_next_message()
                .expect("Expected an authentication message."),
            ClientMessage::Authenticate {
                base_version: 0,
                token: AuthenticationToken::User("refetched-token".into()),
            }
        );
    }

    /// Pops the next message, asserts it is a `ModifyQuerySet` with exactly one `Add`, and
    /// returns the id of the added query.
    fn drain_add_message(client: &mut BaseConvexClient) -> QueryId {
        match client.pop_next_message() {
            Some(ClientMessage::ModifyQuerySet { modifications, .. }) => {
                let [QuerySetModification::Add(query)] = modifications.as_slice() else {
                    panic!("expected a single add modification, got {modifications:?}");
                };
                query.query_id
            },
            other => panic!("expected add query message, got {other:?}"),
        }
    }

    /// Pops the next message, asserts it is a `ModifyQuerySet` with exactly one `Remove`, and
    /// returns the id of the removed query.
    fn drain_remove_message(client: &mut BaseConvexClient) -> QueryId {
        match client.pop_next_message() {
            Some(ClientMessage::ModifyQuerySet { modifications, .. }) => {
                let [QuerySetModification::Remove { query_id }] = modifications.as_slice() else {
                    panic!("expected a single remove modification, got {modifications:?}");
                };
                *query_id
            },
            other => panic!("expected remove query message, got {other:?}"),
        }
    }

    /// Feeds the client a one-step `Transition` from `*version` that sets `query_id` to
    /// `value`. The transition's end version differs only in `ts` (advanced by `succ()`), and
    /// `*version` is moved to that end version.
    fn apply_query_update(
        client: &mut BaseConvexClient,
        version: &mut StateVersion,
        query_id: QueryId,
        value: Value,
    ) {
        let end_version = StateVersion {
            ts: version.ts.succ().expect("timestamp overflow in test"),
            ..*version
        };
        let transition = ServerMessage::Transition {
            start_version: *version,
            end_version,
            modifications: vec![StateModification::QueryUpdated {
                query_id,
                value,
                log_lines: LogLinesMessage(vec![]),
                journal: None,
            }],
            client_clock_skew: None,
            server_ts: None,
        };
        let latest_results = client
            .receive_message(transition)
            .expect("transition should be accepted");
        assert!(
            latest_results.is_some(),
            "query update should publish results"
        );
        *version = end_version;
    }

    /// Removing the only subscriber queues a `Remove` and drops the cached result from
    /// `latest_results`.
    #[test]
    fn test_final_unsubscribe_removes_cached_query_result() {
        let mut client = BaseConvexClient::new();
        let mut version = StateVersion::initial();
        let subscriber_id = client.subscribe("getValue1".parse().unwrap(), BTreeMap::new());
        let query_id = drain_add_message(&mut client);
        assert!(client.pop_next_message().is_none());
        apply_query_update(&mut client, &mut version, query_id, 10.into());
        assert!(client.state.latest_results.results.contains_key(&query_id));
        assert_eq!(
            client.latest_results().get(&subscriber_id),
            Some(&FunctionResult::Value(10.into()))
        );
        client.unsubscribe(subscriber_id);
        assert_eq!(drain_remove_message(&mut client), query_id);
        assert!(client.pop_next_message().is_none());
        assert!(client.state.latest_results.subscribers.is_empty());
        assert!(!client.state.latest_results.results.contains_key(&query_id));
    }

    /// A second subscriber to the same query shares it (no extra message), and the cached
    /// result survives while at least one subscriber remains.
    #[test]
    fn test_cached_query_result_persists_while_subscribers_exist() {
        let mut client = BaseConvexClient::new();
        let mut version = StateVersion::initial();
        let subscriber_a = client.subscribe("getValue1".parse().unwrap(), BTreeMap::new());
        let query_id = drain_add_message(&mut client);
        let subscriber_b = client.subscribe("getValue1".parse().unwrap(), BTreeMap::new());
        assert!(client.pop_next_message().is_none());
        apply_query_update(&mut client, &mut version, query_id, 10.into());
        client.unsubscribe(subscriber_a);
        assert!(client.pop_next_message().is_none());
        assert!(client.state.latest_results.results.contains_key(&query_id));
        assert_eq!(
            client.latest_results().get(&subscriber_b),
            Some(&FunctionResult::Value(10.into()))
        );
    }
}