1use std::sync::Arc;
2
3use anyhow::{anyhow, Result};
4use futures::{SinkExt, StreamExt};
5use tokio::task::JoinHandle;
6use tokio_tungstenite::tungstenite::Message;
7use url::Url;
8
9use super::protocol::SyncManager;
10use crate::crdt::Operation;
11use crate::storage::OperationLog;
12use crate::sync::{SyncMessage, GLOBAL_CLOCK};
13use colored::*;
14use dashmap::DashSet;
15use reqwest::Client;
16use uuid::Uuid;
17
18pub async fn connect_peer(
22 url: &str,
23 actor_id: String,
24 repo_id: String,
25 sync: SyncManager,
26 oplog: Arc<OperationLog>,
27) -> Result<JoinHandle<()>> {
28 let seen = Arc::new(DashSet::new());
29 let url = Url::parse(url).map_err(|e| anyhow!("invalid ws url: {e}"))?;
30 let (ws_stream, _) = tokio_tungstenite::connect_async(url.as_str()).await?;
31
32 let (mut ws_tx, mut ws_rx) = ws_stream.split();
33
34 let handshake = SyncMessage::handshake(actor_id.clone(), repo_id.clone());
36 let handshake_json = serde_json::to_string(&handshake)?;
37 ws_tx.send(Message::Text(handshake_json.into())).await?;
38
39 if let Some(ops_url) = derive_ops_url(&url) {
41 if let Ok(ops) = fetch_initial_ops(ops_url).await {
42 for op in ops.into_iter().rev() {
43 if insert_seen(&seen, op.id) {
44 if let Some(lamport) = op.lamport() {
45 GLOBAL_CLOCK.observe(lamport);
46 }
47 if let Ok(true) = oplog.append(op.clone()) {
48 let _ = sync.publish(Arc::new(op));
49 }
50 }
51 }
52 }
53 }
54
55 let mut rx = sync.subscribe();
57
58 let actor_id_clone = actor_id.clone();
60 let seen_forward = seen.clone();
61 let forward = tokio::spawn(async move {
62 loop {
63 match rx.recv().await {
64 Ok(op_arc) => {
65 if op_arc.actor_id == actor_id_clone && insert_seen(&seen_forward, op_arc.id) {
67 if let Ok(json) =
68 serde_json::to_string(&SyncMessage::operation((*op_arc).clone()))
69 {
70 if ws_tx.send(Message::Text(json.into())).await.is_err() {
71 break;
72 }
73 }
74 }
75 }
76 Err(_) => break,
77 }
78 }
79 });
80
81 let sync_clone = sync.clone();
83 let actor_id_clone2 = actor_id.clone();
84 let oplog_clone = oplog.clone();
85 let seen_recv = seen.clone();
86 let recv = tokio::spawn(async move {
87 while let Some(msg) = ws_rx.next().await {
88 match msg {
89 Ok(Message::Text(text)) => {
90 let text: String = text.to_string();
91 if let Ok(msg) = serde_json::from_str::<SyncMessage>(&text) {
92 match msg {
93 SyncMessage::Handshake { actor_id, repo_id } => {
94 println!(
95 "{} Connected peer handshake (actor={} repo={})",
96 "↔".bright_blue(),
97 actor_id.bright_yellow(),
98 repo_id.bright_white()
99 );
100 }
101 SyncMessage::Operation { operation: op } => {
102 if op.actor_id != actor_id_clone2 && insert_seen(&seen_recv, op.id)
103 {
104 if let Some(lamport) = op.lamport() {
105 GLOBAL_CLOCK.observe(lamport);
106 }
107 let _ = oplog_clone.append(op.clone());
108 let _ = sync_clone.publish(Arc::new(op));
109 }
110 }
111 }
112 } else if let Ok(op) = serde_json::from_str::<Operation>(&text) {
113 if op.actor_id != actor_id_clone2 && insert_seen(&seen_recv, op.id) {
114 if let Some(lamport) = op.lamport() {
115 GLOBAL_CLOCK.observe(lamport);
116 }
117 let _ = oplog_clone.append(op.clone());
118 let _ = sync_clone.publish(Arc::new(op));
119 }
120 }
121 }
122 Ok(Message::Binary(bin)) => {
123 if let Ok(op) = serde_cbor::from_slice::<Operation>(&bin) {
124 if op.actor_id != actor_id_clone2 && insert_seen(&seen_recv, op.id) {
125 if let Some(lamport) = op.lamport() {
126 GLOBAL_CLOCK.observe(lamport);
127 }
128 let _ = oplog_clone.append(op.clone());
129 let _ = sync_clone.publish(Arc::new(op));
130 }
131 }
132 }
133 Ok(Message::Frame(_)) => { }
134 Ok(Message::Close(_)) | Ok(Message::Ping(_)) | Ok(Message::Pong(_)) => {
135 }
137 Err(_) => break,
138 }
139 }
140 });
141
142 let handle = tokio::spawn(async move {
144 let _ = tokio::join!(forward, recv);
145 });
146
147 Ok(handle)
148}
149
150const SEEN_LIMIT: usize = 10_000;
151
152fn insert_seen(cache: &DashSet<Uuid>, id: Uuid) -> bool {
153 let inserted = cache.insert(id);
154 if inserted {
155 enforce_seen_limit(cache);
156 }
157 inserted
158}
159
160fn enforce_seen_limit(cache: &DashSet<Uuid>) {
161 while cache.len() > SEEN_LIMIT {
162 if let Some(entry) = cache.iter().next() {
163 let key = *entry.key();
164 drop(entry);
165 cache.remove(&key);
166 } else {
167 break;
168 }
169 }
170}
171
172fn derive_ops_url(ws_url: &Url) -> Option<Url> {
173 let mut http = ws_url.clone();
174 let scheme = match ws_url.scheme() {
175 "ws" => "http",
176 "wss" => "https",
177 _ => return None,
178 };
179
180 if http.set_scheme(scheme).is_err() {
181 return None;
182 }
183
184 http.set_path("/ops");
185 http.set_query(Some("limit=200"));
186 Some(http)
187}
188
189async fn fetch_initial_ops(url: Url) -> Result<Vec<Operation>> {
190 let client = Client::new();
191 let resp = client.get(url).send().await?;
192 let status = resp.status();
193 if !status.is_success() {
194 return Err(anyhow!("failed to fetch ops: {status}"));
195 }
196 let ops = resp.json::<Vec<Operation>>().await?;
197 Ok(ops)
198}