use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use serde_json::{json, Value};
use tokio::sync::{mpsc, oneshot, Mutex};
use crate::error::{Result, SlopError};
use crate::state_mirror::StateMirror;
use crate::types::{PatchOp, SlopNode};
/// Channel pair produced by a successful transport connection.
///
/// `SlopConsumer::connect` stores the sender half for outgoing requests and reads
/// provider messages from the receiver half.
type TransportChannels = (
mpsc::UnboundedSender<Value>,
mpsc::UnboundedReceiver<Value>,
);
/// Boxed, sendable future returned by [`ClientTransport::connect`].
type ConnectFuture = Pin<Box<dyn Future<Output = Result<TransportChannels>> + Send>>;
/// Patch listener; invoked with `(subscription id, patch ops, version)`.
type PatchCallback = Arc<dyn Fn(&str, &[PatchOp], u64) + Send + Sync>;
/// Listener invoked once the incoming message stream has ended.
type DisconnectCallback = Arc<dyn Fn() + Send + Sync>;
/// Error listener; invoked with `(message id, error code, error message)`.
type ErrorCallback = Arc<dyn Fn(&str, &str, &str) + Send + Sync>;
/// Event listener; invoked with `(event name, optional event data)`.
type EventCallback = Arc<dyn Fn(&str, Option<&Value>) + Send + Sync>;
/// Sequence-gap listener; invoked with `(subscription id, expected seq, received seq)`.
type GapCallback = Arc<dyn Fn(&str, u64, u64) + Send + Sync>;
/// Parameters of a subscription request, remembered per subscription id so the
/// same request can be re-issued when a sequence gap is detected.
#[derive(Clone)]
struct SubscriptionRecord {
// Path sent in the `subscribe` request.
path: String,
// Depth sent in the `subscribe` request.
depth: i32,
}
/// Abstraction over the link to a provider.
///
/// Implementations establish the connection and hand back a channel pair:
/// a sender for consumer-to-provider messages and a receiver for
/// provider-to-consumer messages.
pub trait ClientTransport: Send + Sync {
/// Starts connecting and returns a future resolving to the channel pair.
fn connect(&self) -> ConnectFuture;
}
/// Consumer-side client: issues subscribe/query/invoke requests, keeps one
/// state mirror per subscription and fans incoming messages out to callbacks.
pub struct SlopConsumer {
// State shared between the public API and the background reader task
// spawned by `connect`.
inner: Arc<Mutex<ConsumerInner>>,
}
/// Mutable consumer state guarded by the mutex inside [`SlopConsumer`].
struct ConsumerInner {
// Outgoing channel to the transport; `None` before `connect` and after `disconnect`.
sender: Option<mpsc::UnboundedSender<Value>>,
// Local copy of each subscribed tree, keyed by subscription id.
mirrors: HashMap<String, StateMirror>,
// Request parameters of each subscription, keyed by subscription id.
subscriptions: HashMap<String, SubscriptionRecord>,
// Waiters for in-flight requests, keyed by request/subscription id; each is
// resolved with the matching reply message by `dispatch`.
pending: HashMap<String, oneshot::Sender<Value>>,
// Counter behind the generated `sub-N` subscription ids.
sub_counter: u32,
// Counter behind the generated `q-N` and `inv-N` request ids.
req_counter: u32,
// Listeners registered through the `on_*` methods.
patch_callbacks: Vec<PatchCallback>,
disconnect_callbacks: Vec<DisconnectCallback>,
error_callbacks: Vec<ErrorCallback>,
event_callbacks: Vec<EventCallback>,
gap_callbacks: Vec<GapCallback>,
}
impl SlopConsumer {
/// Creates a consumer that is not connected: no sender, no mirrors, no
/// subscriptions, no pending requests and no registered callbacks.
pub fn new() -> Self {
    let state = ConsumerInner {
        sender: None,
        mirrors: HashMap::new(),
        subscriptions: HashMap::new(),
        pending: HashMap::new(),
        sub_counter: 0,
        req_counter: 0,
        patch_callbacks: Vec::new(),
        disconnect_callbacks: Vec::new(),
        error_callbacks: Vec::new(),
        event_callbacks: Vec::new(),
        gap_callbacks: Vec::new(),
    };
    let inner = Arc::new(Mutex::new(state));
    Self { inner }
}
/// Connects through `transport`, waits for the provider's first message and
/// starts the background task that dispatches every later message.
///
/// Returns the first message received (the provider's greeting).
///
/// # Errors
///
/// Propagates any error from `transport.connect()`, and returns
/// `SlopError::ConnectionClosed` when the stream ends before a first message
/// arrives.
pub async fn connect(&self, transport: &dyn ClientTransport) -> Result<Value> {
    let (tx, mut rx) = transport.connect().await?;
    {
        let mut inner = self.inner.lock().await;
        inner.sender = Some(tx);
    }
    let hello = match rx.recv().await {
        Some(msg) => msg,
        None => {
            // The handshake never completed: do not leave a sender to a dead
            // transport installed, so later requests fail fast.
            self.inner.lock().await.sender = None;
            return Err(SlopError::ConnectionClosed);
        }
    };
    let shared = Arc::clone(&self.inner);
    tokio::spawn(async move {
        while let Some(msg) = rx.recv().await {
            Self::dispatch(Arc::clone(&shared), msg).await;
        }
        // The incoming stream is finished, so no reply can arrive any more.
        // Dropping the pending senders wakes every in-flight request with
        // `ConnectionClosed` instead of letting it wait forever.
        let mut inner = shared.lock().await;
        inner.pending.clear();
        let callbacks = inner.disconnect_callbacks.clone();
        // Release the lock before running user code, as the other callback
        // sites do.
        drop(inner);
        for cb in &callbacks {
            cb();
        }
    });
    Ok(hello)
}
pub async fn subscribe(&self, path: &str, depth: i32) -> Result<(String, SlopNode)> {
let (sub_id, rx) = {
let mut inner = self.inner.lock().await;
inner.sub_counter += 1;
let sub_id = format!("sub-{}", inner.sub_counter);
let (tx, rx) = oneshot::channel();
inner.pending.insert(sub_id.clone(), tx);
inner.subscriptions.insert(
sub_id.clone(),
SubscriptionRecord {
path: path.to_string(),
depth,
},
);
self.send_inner(
&inner,
json!({
"type": "subscribe",
"id": sub_id,
"path": path,
"depth": depth,
}),
)?;
(sub_id, rx)
};
let snapshot = rx.await.map_err(|_| SlopError::ConnectionClosed)?;
let version = snapshot["version"].as_u64().unwrap_or(0);
let seq = snapshot["seq"].as_u64().unwrap_or(0);
let tree: SlopNode =
serde_json::from_value(snapshot["tree"].clone()).map_err(SlopError::Serialization)?;
{
let mut inner = self.inner.lock().await;
inner.mirrors.insert(
sub_id.clone(),
StateMirror::new_with_seq(tree.clone(), version, seq),
);
}
Ok((sub_id, tree))
}
/// Drops the local mirror and subscription record for `id`, then tells the
/// provider to stop the subscription.
///
/// The notification is best effort: a failure to send it is ignored.
pub async fn unsubscribe(&self, id: &str) {
    let mut state = self.inner.lock().await;
    state.mirrors.remove(id);
    state.subscriptions.remove(id);
    let notice = json!({"type": "unsubscribe", "id": id});
    let _ = self.send_inner(&state, notice);
}
pub async fn query(&self, path: &str, depth: i32) -> Result<SlopNode> {
let rx = {
let mut inner = self.inner.lock().await;
inner.req_counter += 1;
let req_id = format!("q-{}", inner.req_counter);
let (tx, rx) = oneshot::channel();
inner.pending.insert(req_id.clone(), tx);
self.send_inner(
&inner,
json!({
"type": "query",
"id": req_id,
"path": path,
"depth": depth,
}),
)?;
rx
};
let snapshot = rx.await.map_err(|_| SlopError::ConnectionClosed)?;
let tree: SlopNode =
serde_json::from_value(snapshot["tree"].clone()).map_err(SlopError::Serialization)?;
Ok(tree)
}
/// Invokes `action` on the node at `path`, optionally with `params`, and
/// returns the provider's reply message.
///
/// # Errors
///
/// * `SlopError::ConnectionClosed` when there is no sender or the reply
///   channel is dropped before a reply arrives.
/// * `SlopError::Transport` when the request cannot be handed to the transport.
/// * `SlopError::ActionFailed` when the reply reports a failure, either as a
///   result with `status == "error"` or as an `error`-typed message. The
///   code and message are read from `error.code` / `error.message`, falling
///   back to `"unknown"` / `"unknown error"`.
pub async fn invoke(
    &self,
    path: &str,
    action: &str,
    params: Option<Value>,
) -> Result<Value> {
    let rx = {
        let mut inner = self.inner.lock().await;
        inner.req_counter += 1;
        let req_id = format!("inv-{}", inner.req_counter);
        let mut request = json!({
            "type": "invoke",
            "id": req_id,
            "path": path,
            "action": action,
        });
        if let Some(p) = params {
            request["params"] = p;
        }
        // Send first and register afterwards: the state lock is held across both
        // steps, so the reply cannot be dispatched in between, and a failed send
        // leaves no orphaned waiter in `pending`.
        self.send_inner(&inner, request)?;
        let (tx, rx) = oneshot::channel();
        inner.pending.insert(req_id, tx);
        rx
    };
    let result = rx.await.map_err(|_| SlopError::ConnectionClosed)?;
    // The dispatcher also forwards `error`-typed messages to the waiter; they use
    // the same `error.code` / `error.message` layout and must not be returned as
    // a success.
    let failed = result["status"] == "error" || result["type"] == "error";
    if failed {
        let code = result["error"]["code"].as_str().unwrap_or("unknown");
        let message = result["error"]["message"]
            .as_str()
            .unwrap_or("unknown error");
        return Err(SlopError::ActionFailed {
            code: code.to_string(),
            message: message.to_string(),
        });
    }
    Ok(result)
}
/// Returns a copy of the mirrored tree for `subscription_id`, or `None` when
/// no mirror is stored under that id.
pub async fn tree(&self, subscription_id: &str) -> Option<SlopNode> {
    let state = self.inner.lock().await;
    let mirror = state.mirrors.get(subscription_id)?;
    Some(mirror.tree().clone())
}
/// Forgets the connection: drops the outgoing sender and clears all mirrors,
/// subscription records and pending request waiters.
///
/// Registered callbacks are kept.
pub async fn disconnect(&self) {
    let mut guard = self.inner.lock().await;
    let state = &mut *guard;
    // Dropping the sender releases this side of the outgoing channel.
    drop(state.sender.take());
    state.mirrors.clear();
    state.subscriptions.clear();
    // Dropping the waiters makes in-flight requests observe a closed channel.
    state.pending.clear();
}
pub async fn on_patch<F>(&self, callback: F)
where
F: Fn(&str, &[PatchOp], u64) + Send + Sync + 'static,
{
let mut inner = self.inner.lock().await;
inner.patch_callbacks.push(Arc::new(callback));
}
pub async fn on_disconnect<F>(&self, callback: F)
where
F: Fn() + Send + Sync + 'static,
{
let mut inner = self.inner.lock().await;
inner.disconnect_callbacks.push(Arc::new(callback));
}
pub async fn on_error<F>(&self, callback: F)
where
F: Fn(&str, &str, &str) + Send + Sync + 'static,
{
let mut inner = self.inner.lock().await;
inner.error_callbacks.push(Arc::new(callback));
}
pub async fn on_event<F>(&self, callback: F)
where
F: Fn(&str, Option<&Value>) + Send + Sync + 'static,
{
let mut inner = self.inner.lock().await;
inner.event_callbacks.push(Arc::new(callback));
}
pub async fn on_gap<F>(&self, callback: F)
where
F: Fn(&str, u64, u64) + Send + Sync + 'static,
{
let mut inner = self.inner.lock().await;
inner.gap_callbacks.push(Arc::new(callback));
}
/// Pushes `msg` onto the outgoing channel of an already-locked state.
///
/// Returns `SlopError::ConnectionClosed` when no sender is installed and
/// `SlopError::Transport` (carrying the channel error text) when the send fails.
fn send_inner(
    &self,
    inner: &ConsumerInner,
    msg: Value,
) -> Result<()> {
    match inner.sender.as_ref() {
        None => Err(SlopError::ConnectionClosed),
        Some(outgoing) => outgoing
            .send(msg)
            .map_err(|err| SlopError::Transport(err.to_string())),
    }
}
/// Routes one provider message according to its `type` field.
///
/// * `snapshot` — resolves the waiter registered under the message id, and
///   (re)builds the mirror for that id when one exists or when the id is a
///   recorded subscription with no waiter (the re-subscribe issued after a
///   sequence gap).
/// * `patch` — applies the ops to the subscription's mirror. When the mirror
///   rejects the sequence number, the mirror is dropped, an unsubscribe and a
///   fresh subscribe are sent under the same id, and gap listeners are
///   notified instead of patch listeners.
/// * `result` — resolves the waiter registered under the message id.
/// * `error` — resolves the waiter (if any) and notifies error listeners.
/// * `event` — notifies event listeners.
/// * `batch` — dispatches every entry of `messages` in order.
///
/// Any other type is ignored. Listeners always run after the state lock has
/// been released.
async fn dispatch(inner: Arc<Mutex<ConsumerInner>>, msg: Value) {
    let msg_type = msg["type"].as_str().unwrap_or("");
    let msg_id = msg["id"].as_str().unwrap_or("").to_string();
    match msg_type {
        "snapshot" => {
            let version = msg["version"].as_u64().unwrap_or(0);
            let seq = msg["seq"].as_u64().unwrap_or(0);
            let mut locked = inner.lock().await;
            let waiter = locked.pending.remove(&msg_id);
            // A gap recovery removes the mirror before re-subscribing, so the
            // recovery snapshot must be allowed to create it again; otherwise the
            // subscription would stay without a mirror forever. An initial
            // subscribe (which has a waiter) still installs its own mirror.
            let rebuild_mirror = locked.mirrors.contains_key(&msg_id)
                || (waiter.is_none() && locked.subscriptions.contains_key(&msg_id));
            if rebuild_mirror {
                if let Ok(tree) = serde_json::from_value::<SlopNode>(msg["tree"].clone()) {
                    locked.mirrors.insert(
                        msg_id.clone(),
                        StateMirror::new_with_seq(tree, version, seq),
                    );
                }
            }
            if let Some(tx) = waiter {
                // The message is no longer needed here, so hand it over by value.
                let _ = tx.send(msg);
            }
        }
        "patch" => {
            let sub_id = msg["subscription"].as_str().unwrap_or(&msg_id).to_string();
            let version = msg["version"].as_u64().unwrap_or(0);
            let seq_opt = msg["seq"].as_u64();
            // NOTE(review): ops that fail to decode are skipped silently, so a
            // partially decodable patch is applied partially.
            let ops: Vec<PatchOp> = msg["ops"]
                .as_array()
                .map(|arr| {
                    arr.iter()
                        .filter_map(|v| serde_json::from_value(v.clone()).ok())
                        .collect()
                })
                .unwrap_or_default();
            let mut locked = inner.lock().await;
            let mut gap: Option<(u64, u64)> = None;
            if let Some(mirror) = locked.mirrors.get_mut(&sub_id) {
                match seq_opt {
                    Some(seq) => {
                        if let Err(err) = mirror.apply_patch_with_seq(&ops, version, seq) {
                            gap = Some((err.expected, err.received));
                        }
                    }
                    None => {
                        mirror.apply_patch(&ops, version);
                    }
                }
            }
            if let Some((expected, received)) = gap {
                // The mirror can no longer be trusted: drop it and ask for a fresh
                // snapshot under the same subscription id.
                locked.mirrors.remove(&sub_id);
                if let Some(tx) = locked.sender.as_ref() {
                    let _ = tx.send(json!({"type": "unsubscribe", "id": &sub_id}));
                    if let Some(sub) = locked.subscriptions.get(&sub_id) {
                        let _ = tx.send(json!({
                            "type": "subscribe",
                            "id": &sub_id,
                            "path": sub.path,
                            "depth": sub.depth,
                        }));
                    }
                }
                let callbacks: Vec<_> = locked.gap_callbacks.clone();
                drop(locked);
                for cb in &callbacks {
                    cb(&sub_id, expected, received);
                }
                return;
            }
            let callbacks: Vec<_> = locked.patch_callbacks.clone();
            drop(locked);
            for cb in &callbacks {
                cb(&sub_id, &ops, version);
            }
        }
        "result" => {
            let mut locked = inner.lock().await;
            if let Some(tx) = locked.pending.remove(&msg_id) {
                let _ = tx.send(msg);
            }
        }
        "error" => {
            // Copy the two short strings so the message itself can be moved to the
            // waiter instead of being cloned.
            let code = msg["error"]["code"].as_str().unwrap_or("unknown").to_string();
            let message = msg["error"]["message"].as_str().unwrap_or("").to_string();
            let mut locked = inner.lock().await;
            if !msg_id.is_empty() {
                if let Some(tx) = locked.pending.remove(&msg_id) {
                    let _ = tx.send(msg);
                }
            }
            let callbacks: Vec<_> = locked.error_callbacks.clone();
            drop(locked);
            for cb in &callbacks {
                cb(&msg_id, &code, &message);
            }
        }
        "event" => {
            let name = msg["name"].as_str().unwrap_or("");
            let data = msg.get("data");
            let locked = inner.lock().await;
            let callbacks: Vec<_> = locked.event_callbacks.clone();
            drop(locked);
            for cb in &callbacks {
                cb(name, data);
            }
        }
        "batch" => {
            if let Some(messages) = msg["messages"].as_array() {
                for sub_msg in messages {
                    // Boxing gives the recursive async call a finite size.
                    Box::pin(Self::dispatch(Arc::clone(&inner), sub_msg.clone())).await;
                }
            }
        }
        _ => {}
    }
}
}
impl Default for SlopConsumer {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::sync::atomic::{AtomicBool, Ordering};

    /// Transport double that replays a fixed list of provider messages.
    struct MockTransport {
        provider_messages: Vec<Value>,
    }

    impl ClientTransport for MockTransport {
        fn connect(&self) -> ConnectFuture {
            let scripted = self.provider_messages.clone();
            Box::pin(async move {
                // The receiving end of the consumer channel is dropped when this
                // block finishes, as is the provider sender once the scripted
                // messages are queued.
                let (consumer_tx, _consumer_rx) = mpsc::unbounded_channel();
                let (provider_tx, provider_rx) = mpsc::unbounded_channel();
                for scripted_msg in scripted {
                    provider_tx.send(scripted_msg).unwrap();
                }
                Ok((consumer_tx, provider_rx))
            })
        }
    }

    /// `connect` hands back the first message queued by the provider.
    #[tokio::test]
    async fn test_connect_hello() {
        let greeting = json!({
            "type": "hello",
            "provider": {"id": "app", "name": "App", "slop_version": "0.1"}
        });
        let transport = MockTransport {
            provider_messages: vec![greeting],
        };
        let consumer = SlopConsumer::new();
        let hello = consumer.connect(&transport).await.unwrap();
        assert_eq!(hello["type"], "hello");
        assert_eq!(hello["provider"]["id"], "app");
    }

    /// Registering a disconnect listener stores exactly one callback.
    #[tokio::test]
    async fn test_disconnect_callback() {
        let called = Arc::new(AtomicBool::new(false));
        let flag = called.clone();
        let consumer = SlopConsumer::new();
        consumer
            .on_disconnect(move || flag.store(true, Ordering::SeqCst))
            .await;
        let state = consumer.inner.lock().await;
        assert_eq!(state.disconnect_callbacks.len(), 1);
    }

    /// Registering a patch listener stores exactly one callback.
    #[tokio::test]
    async fn test_patch_callback() {
        let consumer = SlopConsumer::new();
        let patch_count = Arc::new(std::sync::atomic::AtomicU32::new(0));
        let counter = patch_count.clone();
        consumer
            .on_patch(move |_sub, _ops, _v| {
                counter.fetch_add(1, Ordering::SeqCst);
            })
            .await;
        let state = consumer.inner.lock().await;
        assert_eq!(state.patch_callbacks.len(), 1);
    }

    /// `new` and `default` both start without a sender.
    #[tokio::test]
    async fn test_new_default() {
        let from_new = SlopConsumer::new();
        let from_default = SlopConsumer::default();
        let state_new = from_new.inner.lock().await;
        let state_default = from_default.inner.lock().await;
        assert!(state_new.sender.is_none());
        assert!(state_default.sender.is_none());
    }
}