use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::fmt::Formatter;
use std::sync::Arc;
use thiserror::Error;
use yrs::block::ClientID;
use yrs::updates::decoder::{Decode, Decoder};
use yrs::updates::encoder::{Encode, Encoder};
use yrs::{Doc, Observer, Subscription};
/// JSON literal `null`. It is emitted as the serialized state of a client that has a clock but
/// no stored state, and `Awareness::apply_update` compares incoming states against it to detect
/// a removal.
const NULL_STR: &str = "null";
/// Observer storing the callbacks registered through `Awareness::on_update`. Every callback is
/// invoked with the awareness instance and the `Event` describing the change.
type AwarenessObserver = Observer<Arc<dyn Fn(&Awareness, &Event) + Send + Sync + 'static>>;
/// Per-client state (stored as JSON strings) associated with a `Doc`, together with a logical
/// clock per client that is used to decide whether an incoming update is newer than what is
/// already known.
pub struct Awareness {
/// Document whose client ID identifies the local client.
pub doc: Doc,
/// Current JSON state of every client that has one.
states: HashMap<ClientID, String>,
/// Clock of every client this instance has seen, including clients whose state was removed.
meta: HashMap<ClientID, MetaClientState>,
/// Registered change callbacks; created lazily on the first `on_update` call.
on_update: Option<AwarenessObserver>,
}
impl Awareness {
pub fn new(doc: Doc) -> Self {
Awareness {
doc,
on_update: None,
states: HashMap::new(),
meta: HashMap::new(),
}
}
/// Registers `f` to be called whenever client states are added, updated or removed.
///
/// The observer is created on first use. The returned `Subscription` is the handle produced by
/// the observer for this callback.
pub fn on_update<F>(&mut self, f: F) -> Subscription
where
    F: Fn(&Awareness, &Event) + Send + Sync + 'static,
{
    let observer = self.on_update.get_or_insert_with(Observer::default);
    let callback: Arc<dyn Fn(&Awareness, &Event) + Send + Sync + 'static> = Arc::new(f);
    observer.subscribe(callback)
}
/// Returns a shared reference to the underlying document.
pub fn doc(&self) -> &Doc {
&self.doc
}
/// Returns a mutable reference to the underlying document.
pub fn doc_mut(&mut self) -> &mut Doc {
&mut self.doc
}
/// Returns the client ID of the underlying document, which identifies the local client.
pub fn client_id(&self) -> ClientID {
self.doc.client_id()
}
/// Returns the JSON states of all clients that currently have one, keyed by client ID.
pub fn clients(&self) -> &HashMap<ClientID, String> {
&self.states
}
/// Returns the JSON state stored for the local client, or `None` when it has none.
pub fn local_state(&self) -> Option<&str> {
    let local_id = self.doc.client_id();
    self.states.get(&local_id).map(String::as_str)
}
/// Stores `json` as the local client's state and advances the local clock.
///
/// Registered callbacks receive an event listing the local client as `added` when it had no
/// state before, and as `updated` when an existing state was replaced.
pub fn set_local_state<S: Into<String>>(&mut self, json: S) {
    let client_id = self.doc.client_id();
    self.update_meta(client_id);

    // `insert` hands back the previous value, which tells "added" apart from "updated".
    let replaced = self.states.insert(client_id, json.into()).is_some();

    if let Some(observer) = self.on_update.as_ref() {
        let event = if replaced {
            Event::new(vec![], vec![client_id], vec![])
        } else {
            Event::new(vec![client_id], vec![], vec![])
        };
        observer.trigger(|callback| {
            callback(self, &event);
        });
    }
}
/// Removes the state of `client_id` and advances that client's clock.
///
/// The clock is advanced even when no state was stored. Callbacks are only notified (with the
/// client listed as `removed`) when a state was actually removed.
pub fn remove_state(&mut self, client_id: ClientID) {
    let had_state = self.states.remove(&client_id).is_some();
    self.update_meta(client_id);
    if !had_state {
        return;
    }
    if let Some(observer) = self.on_update.as_ref() {
        let event = Event::new(Vec::new(), Vec::new(), vec![client_id]);
        observer.trigger(|callback| {
            callback(self, &event);
        });
    }
}
/// Removes the local client's state; shorthand for `remove_state` with the document's own
/// client ID.
pub fn clean_local_state(&mut self) {
    let local_id = self.doc.client_id();
    self.remove_state(local_id)
}
/// Advances the clock of `client_id` by one, starting at 1 for a client without a clock.
fn update_meta(&mut self, client_id: ClientID) {
    self.meta
        .entry(client_id)
        .and_modify(|meta| meta.clock += 1)
        .or_insert_with(|| MetaClientState::new(1));
}
/// Builds an update containing every client that currently has a state.
///
/// # Errors
/// Propagates `Error::ClientNotFound` from `update_with_clients`.
pub fn update(&self) -> Result<AwarenessUpdate, Error> {
    self.update_with_clients(self.states.keys().cloned())
}
pub fn update_with_clients<I: IntoIterator<Item = ClientID>>(
&self,
clients: I,
) -> Result<AwarenessUpdate, Error> {
let mut res = HashMap::new();
for client_id in clients {
let clock = if let Some(meta) = self.meta.get(&client_id) {
meta.clock
} else {
return Err(Error::ClientNotFound(client_id));
};
let json = if let Some(json) = self.states.get(&client_id) {
json.clone()
} else {
String::from(NULL_STR)
};
res.insert(client_id, AwarenessUpdateEntry { clock, json });
}
Ok(AwarenessUpdate { clients: res })
}
/// Merges an update received from another peer into this instance.
///
/// For every client in `update`:
/// * Unknown client (no clock recorded): the clock is stored. A non-`null` state is stored as
///   well and the client is reported as `added`. A `null` state means the client has no state,
///   so nothing is stored in `states` and no event entry is produced.
/// * Known client: the entry is applied only when its clock is greater than the recorded one,
///   or when it has the same clock, carries `null`, and a state is still stored locally.
///   A `null` state removes the stored state (`removed`); any other state replaces or
///   re-creates it (`updated`).
/// * A `null` for the local client while it still has a state is never applied; instead the
///   recorded clock is set one past the incoming clock.
///
/// Callbacks are triggered once, after the whole update was processed, and only if at least
/// one client ended up in `added`, `updated` or `removed`.
///
/// # Errors
/// The current implementation never fails; the `Result` is kept for interface stability.
pub fn apply_update(&mut self, update: AwarenessUpdate) -> Result<(), Error> {
    // Change lists are only collected when somebody is listening.
    let notify = self.on_update.is_some();
    let local_id = self.doc.client_id();

    let mut added = Vec::new();
    let mut updated = Vec::new();
    let mut removed = Vec::new();

    for (client_id, entry) in update.clients {
        let mut clock = entry.clock;
        let is_null = entry.json.as_str() == NULL_STR;

        match self.meta.entry(client_id) {
            Entry::Occupied(mut known) => {
                let prev_clock = known.get().clock;
                let is_removed =
                    prev_clock == clock && is_null && self.states.contains_key(&client_id);
                let is_new = prev_clock < clock;
                if !(is_new || is_removed) {
                    // Stale entry: what we already have is at least as recent.
                    continue;
                }

                if is_null {
                    if client_id == local_id && self.states.contains_key(&client_id) {
                        // A remote peer must not remove the local state: keep it and move
                        // the clock past the incoming one instead.
                        clock += 1;
                    } else {
                        self.states.remove(&client_id);
                        if notify {
                            removed.push(client_id);
                        }
                    }
                } else {
                    // The client is already known through its clock, so this is reported
                    // as `updated` whether or not a state was stored before.
                    self.states.insert(client_id, entry.json);
                    if notify {
                        updated.push(client_id);
                    }
                }
                known.insert(MetaClientState::new(clock));
            }
            Entry::Vacant(unknown) => {
                unknown.insert(MetaClientState::new(clock));
                // A `null` payload means "no state": remember the clock only. Storing the
                // literal "null" would surface a phantom client in `clients()` and `added`.
                if !is_null {
                    self.states.insert(client_id, entry.json);
                    if notify {
                        added.push(client_id);
                    }
                }
            }
        }
    }

    if let Some(observer) = self.on_update.as_ref() {
        if !added.is_empty() || !updated.is_empty() || !removed.is_empty() {
            let event = Event::new(added, updated, removed);
            observer.trigger(|callback| {
                callback(self, &event);
            });
        }
    }
    Ok(())
}
}
/// The default instance wraps a freshly created `Doc`.
impl Default for Awareness {
    fn default() -> Self {
        let doc = Doc::new();
        Self::new(doc)
    }
}
/// Debug output shows the client states, the per-client clocks and the document; the
/// registered callbacks are left out.
impl std::fmt::Debug for Awareness {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("Awareness");
        out.field("state", &self.states);
        out.field("meta", &self.meta);
        out.field("doc", &self.doc);
        out.finish()
    }
}
/// A set of client states with their clocks, as exchanged between peers. Produced by
/// `Awareness::update` / `Awareness::update_with_clients` and consumed by
/// `Awareness::apply_update`.
#[derive(Debug, Eq, PartialEq)]
pub struct AwarenessUpdate {
/// Clock and JSON state per client contained in this update.
pub(crate) clients: HashMap<ClientID, AwarenessUpdateEntry>,
}
/// Wire format: the number of entries, followed by `(client ID, clock, JSON string)` for each
/// entry. Entries are written in the map's iteration order.
impl Encode for AwarenessUpdate {
    fn encode<E: Encoder>(&self, encoder: &mut E) {
        let entries = &self.clients;
        encoder.write_var(entries.len());
        for (client_id, entry) in entries {
            encoder.write_var(*client_id);
            encoder.write_var(entry.clock);
            encoder.write_string(entry.json.as_str());
        }
    }
}
/// Reads the format written by `Encode`: an entry count followed by
/// `(client ID, clock, JSON string)` per entry. When a client ID occurs more than once, the
/// last occurrence wins.
impl Decode for AwarenessUpdate {
    /// # Errors
    /// Returns the decoder's error when the input ends early or contains a malformed value.
    fn decode<D: Decoder>(decoder: &mut D) -> Result<Self, yrs::encoding::read::Error> {
        // The entry count comes straight from the (untrusted) payload. Reserving that many
        // slots up front would let a few bytes request an arbitrarily large allocation, so
        // the reservation is capped; the map still grows as needed while entries are read.
        const MAX_PREALLOCATED_CLIENTS: usize = 1024;

        let len: usize = decoder.read_var()?;
        let mut clients = HashMap::with_capacity(len.min(MAX_PREALLOCATED_CLIENTS));
        for _ in 0..len {
            let client_id: ClientID = decoder.read_var()?;
            let clock: u32 = decoder.read_var()?;
            let json = decoder.read_string()?.to_string();
            clients.insert(client_id, AwarenessUpdateEntry { clock, json });
        }
        Ok(AwarenessUpdate { clients })
    }
}
/// State of a single client inside an `AwarenessUpdate`.
#[derive(Debug, Eq, PartialEq)]
pub struct AwarenessUpdateEntry {
/// Clock of the client at the time the update was created.
pub(crate) clock: u32,
/// JSON state of the client; the literal `null` when the client has no state.
pub(crate) json: String,
}
/// Errors reported by `Awareness`.
#[derive(Error, Debug)]
pub enum Error {
/// An update was requested for a client that has no clock recorded.
#[error("client ID `{0}` not found")]
ClientNotFound(ClientID),
}
/// Bookkeeping kept per client: the clock of the most recent state change that was accepted.
#[derive(Debug, Clone)]
struct MetaClientState {
    /// Logical clock; compared against incoming clocks to decide whether an update is newer.
    clock: u32,
}

impl MetaClientState {
    /// Wraps `clock` into a new bookkeeping record.
    fn new(clock: u32) -> Self {
        Self { clock }
    }
}
/// Summary of one change to the awareness state, handed to callbacks registered with
/// `Awareness::on_update`.
#[derive(Debug, Default, Clone, Eq, PartialEq)]
pub struct Event {
    /// Clients reported as newly added.
    added: Vec<ClientID>,
    /// Clients reported as updated.
    updated: Vec<ClientID>,
    /// Clients reported as removed.
    removed: Vec<ClientID>,
}

impl Event {
    /// Creates an event from the three lists of affected clients.
    pub fn new(added: Vec<ClientID>, updated: Vec<ClientID>, removed: Vec<ClientID>) -> Self {
        Self {
            added,
            updated,
            removed,
        }
    }

    /// Clients reported as added by this event.
    pub fn added(&self) -> &[ClientID] {
        self.added.as_slice()
    }

    /// Clients reported as updated by this event.
    pub fn updated(&self) -> &[ClientID] {
        self.updated.as_slice()
    }

    /// Clients reported as removed by this event.
    pub fn removed(&self) -> &[ClientID] {
        self.removed.as_slice()
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use std::sync::mpsc::{channel, Receiver};
    use yrs::Doc;

    /// Takes the next event observed on `from`, builds an update for all clients mentioned in
    /// it and applies that update to `to`. Returns the event that was taken.
    fn update(
        recv: &mut Receiver<Event>,
        from: &Awareness,
        to: &mut Awareness,
    ) -> Result<Event, Box<dyn std::error::Error>> {
        let e = recv.try_recv()?;
        let u = from.update_with_clients([e.added(), e.updated(), e.removed()].concat())?;
        to.apply_update(u)?;
        Ok(e)
    }

    /// Round-trips add / update / remove of the local state from `local` to `remote` and
    /// checks that `remote` ends up with the same state and emits the same events.
    #[test]
    fn awareness() -> Result<(), Box<dyn std::error::Error>> {
        let (s1, mut o_local) = channel();
        let mut local = Awareness::new(Doc::with_client_id(1));
        let _sub_local = local.on_update(move |_, e| {
            s1.send(e.clone()).unwrap();
        });

        let (s2, o_remote) = channel();
        let mut remote = Awareness::new(Doc::with_client_id(2));
        // Subscribe on `remote` itself, so that `o_remote` carries the events produced by
        // `apply_update` rather than a second copy of `local`'s events.
        let _sub_remote = remote.on_update(move |_, e| {
            s2.send(e.clone()).unwrap();
        });

        // Local state is created -> remote sees client 1 as added.
        local.set_local_state("{x:3}");
        let _e_local = update(&mut o_local, &local, &mut remote)?;
        assert_eq!(remote.clients()[&1], "{x:3}");
        assert_eq!(remote.meta[&1].clock, 1);
        assert_eq!(o_remote.try_recv()?.added, &[1]);

        // Local state is replaced -> remote sees client 1 as updated.
        local.set_local_state("{x:4}");
        let e_local = update(&mut o_local, &local, &mut remote)?;
        let e_remote = o_remote.try_recv()?;
        assert_eq!(remote.clients()[&1], "{x:4}");
        assert_eq!(e_remote, Event::new(vec![], vec![1], vec![]));
        assert_eq!(e_remote, e_local);

        // Local state is removed -> remote drops it and sees client 1 as removed.
        local.clean_local_state();
        let e_local = update(&mut o_local, &local, &mut remote)?;
        let e_remote = o_remote.try_recv()?;
        assert_eq!(e_remote.removed.len(), 1);
        assert_eq!(local.clients().get(&1), None);
        assert_eq!(remote.clients().get(&1), None);
        assert_eq!(e_remote, e_local);
        Ok(())
    }
}