1use std::path::PathBuf;
2use std::sync::Arc;
3
4use anyhow::Result;
5use axum::{
6 Json, Router,
7 extract::Query,
8 extract::State,
9 extract::ws::{Message, WebSocket, WebSocketUpgrade},
10 routing::get,
11};
12use colored::*;
13use futures::{SinkExt, StreamExt};
14
15use crate::crdt::Operation;
16use crate::storage::{Database, OperationLog};
17use crate::sync::{GLOBAL_CLOCK, SyncManager, SyncMessage};
18use dashmap::DashSet;
19use serde::Deserialize;
20use sha2::{Digest, Sha256};
21use uuid::Uuid;
22
/// Shared state handed to every axum handler through `State<AppState>`.
///
/// The struct derives `Clone`; the `Arc` fields make every clone share the same
/// operation log, database handle and dedup cache. Whether clones of `sync` share one
/// hub depends on `SyncManager`'s `Clone` impl.
/// NOTE(review): fan-out across WebSocket sessions relies on that sharing — confirm.
#[derive(Clone)]
pub struct AppState {
    /// Operation log built on top of `db`; operations received from peers are appended here.
    pub oplog: Arc<OperationLog>,
    /// Database handle; `serve` opens it under `<repo>/.dx/forge` and `GET /ops` queries it.
    pub db: Arc<Database>,
    /// Publish/subscribe hub: each WebSocket session subscribes to it and forwards what is
    /// published; ingested peer operations are published to it.
    pub sync: SyncManager,
    /// This server's actor id, sent to peers in the WebSocket handshake.
    pub actor_id: String,
    /// Repository id, sent to peers in the WebSocket handshake.
    pub repo_id: String,
    /// Ids of operations already ingested, used to drop duplicates; kept size-bounded by
    /// the `insert_seen` helper.
    pub seen: Arc<DashSet<Uuid>>,
}
32
33pub async fn serve(port: u16, path: PathBuf) -> Result<()> {
34 let forge_path = path.join(".dx/forge");
36 let db = Arc::new(Database::new(&forge_path)?);
37 db.initialize()?;
38 let oplog = Arc::new(OperationLog::new(db.clone()));
39
40 let config_path = forge_path.join("config.json");
42 let default_repo_id = {
43 let mut hasher = Sha256::new();
44 let path_string = forge_path.to_string_lossy().into_owned();
45 hasher.update(path_string.as_bytes());
46 format!("repo-{:x}", hasher.finalize())
47 };
48
49 let (actor_id, repo_id) = if let Ok(bytes) = tokio::fs::read(&config_path).await {
50 if let Ok(cfg) = serde_json::from_slice::<serde_json::Value>(&bytes) {
51 let actor = cfg
52 .get("actor_id")
53 .and_then(|s| s.as_str())
54 .map(|s| s.to_string())
55 .unwrap_or_else(|| whoami::username());
56 let repo = cfg
57 .get("repo_id")
58 .and_then(|s| s.as_str())
59 .map(|s| s.to_string())
60 .unwrap_or_else(|| default_repo_id.clone());
61 (actor, repo)
62 } else {
63 (whoami::username(), default_repo_id.clone())
64 }
65 } else {
66 (whoami::username(), default_repo_id)
67 };
68
69 let state = AppState {
70 oplog,
71 db,
72 sync: SyncManager::new(),
73 actor_id,
74 repo_id,
75 seen: Arc::new(DashSet::new()),
76 };
77
78 let app = Router::new()
79 .route("/", get(|| async { "Forge DeltaDB Server" }))
80 .route("/health", get(|| async { Json("OK") }))
81 .route("/ops", get(get_ops))
82 .route("/ws", get(ws_handler))
83 .with_state(state);
84
85 let addr = format!("0.0.0.0:{}", port);
86 println!(
87 "{} Server running at {}",
88 "✓".green(),
89 format!("http://{}", addr).bright_blue()
90 );
91
92 let listener = tokio::net::TcpListener::bind(&addr).await?;
93 axum::serve(listener, app).await?;
94
95 Ok(())
96}
97
98async fn ws_handler(
99 State(state): State<AppState>,
100 ws: WebSocketUpgrade,
101) -> impl axum::response::IntoResponse {
102 ws.on_upgrade(move |socket| handle_ws(state, socket))
103}
104
105async fn handle_ws(state: AppState, socket: WebSocket) {
106 let (mut sender, mut receiver) = socket.split();
107
108 let handshake = SyncMessage::handshake(state.actor_id.clone(), state.repo_id.clone());
110 if let Ok(text) = serde_json::to_string(&handshake) {
111 let _ = sender.send(Message::Text(text.into())).await;
112 }
113
114 let mut rx = state.sync.subscribe();
116 let send_task = tokio::spawn(async move {
117 while let Ok(op_arc) = rx.recv().await {
118 if let Ok(text) = serde_json::to_string(&SyncMessage::operation((*op_arc).clone())) {
120 if sender.send(Message::Text(text.into())).await.is_err() {
121 break;
122 }
123 }
124 }
125 });
126
127 let state_recv = state.clone();
129 let recv_task = tokio::spawn(async move {
130 let oplog = state_recv.oplog.clone();
131 while let Some(msg) = receiver.next().await {
132 match msg {
133 Ok(Message::Text(text)) => {
134 let text: String = text.to_string();
135 if let Ok(msg) = serde_json::from_str::<SyncMessage>(&text) {
136 match msg {
137 SyncMessage::Handshake { actor_id, repo_id } => {
138 println!(
139 "{} Peer handshake: actor={} repo={}",
140 "↔".bright_blue(),
141 actor_id.bright_yellow(),
142 repo_id.bright_white()
143 );
144 }
145 SyncMessage::Operation { operation: op } => {
146 if insert_seen(&state_recv.seen, op.id) {
147 if let Some(lamport) = op.lamport() {
148 GLOBAL_CLOCK.observe(lamport);
149 }
150 let _ = oplog.append(op.clone());
151 let _ = state_recv.sync.publish(Arc::new(op));
152 }
153 }
154 }
155 } else if let Ok(op) = serde_json::from_str::<Operation>(&text) {
156 if insert_seen(&state_recv.seen, op.id) {
157 if let Some(lamport) = op.lamport() {
158 GLOBAL_CLOCK.observe(lamport);
159 }
160 let _ = oplog.append(op.clone());
161 let _ = state_recv.sync.publish(Arc::new(op));
162 }
163 }
164 }
165 Ok(Message::Binary(bin)) => {
166 if let Ok(op) = serde_cbor::from_slice::<Operation>(&bin) {
167 if insert_seen(&state_recv.seen, op.id) {
168 if let Some(lamport) = op.lamport() {
169 GLOBAL_CLOCK.observe(lamport);
170 }
171 let _ = oplog.append(op.clone());
172 let _ = state_recv.sync.publish(Arc::new(op));
173 }
174 }
175 }
176 Ok(Message::Close(_)) | Ok(Message::Ping(_)) | Ok(Message::Pong(_)) => {}
177 Err(_) => break,
178 }
179 }
180 });
181
182 let _ = tokio::join!(send_task, recv_task);
183}
184
/// Query-string parameters accepted by `GET /ops` (deserialized via axum's `Query`).
#[derive(Deserialize)]
struct OpsQuery {
    /// Optional file path; `get_ops` converts it to a `PathBuf` and passes it to
    /// `Database::get_operations` (presumably a per-file filter — confirm).
    file: Option<String>,
    /// Maximum number of operations to return; `get_ops` uses 50 when it is absent.
    limit: Option<usize>,
}
190
191async fn get_ops(
192 State(state): State<AppState>,
193 Query(query): Query<OpsQuery>,
194) -> Result<Json<Vec<Operation>>, axum::http::StatusCode> {
195 let limit = query.limit.unwrap_or(50);
196 let result = if let Some(file) = query.file.as_deref() {
197 let p = std::path::PathBuf::from(file);
198 state.db.get_operations(Some(&p), limit)
199 } else {
200 state.db.get_operations(None, limit)
201 };
202
203 match result {
204 Ok(ops) => Ok(Json(ops)),
205 Err(_) => Err(axum::http::StatusCode::INTERNAL_SERVER_ERROR),
206 }
207}
208
209const SEEN_LIMIT: usize = 10_000;
210
211fn insert_seen(cache: &DashSet<Uuid>, id: Uuid) -> bool {
212 let inserted = cache.insert(id);
213 if inserted {
214 enforce_seen_limit(cache);
215 }
216 inserted
217}
218
219fn enforce_seen_limit(cache: &DashSet<Uuid>) {
220 while cache.len() > SEEN_LIMIT {
221 if let Some(entry) = cache.iter().next() {
222 let key = *entry.key();
223 drop(entry);
224 cache.remove(&key);
225 } else {
226 break;
227 }
228 }
229}