1use std::collections::HashMap;
4use std::sync::Arc;
5
6use bytes::Bytes;
7use guts_storage::{ObjectId, Reference, Repository};
8use parking_lot::RwLock;
9use tracing::{debug, info, warn};
10
11use crate::message::{Message, ObjectData, RefUpdate, RepoAnnounce, SyncRequest};
12use crate::{P2PError, Result};
13
14pub trait ReplicationHandler: Send + Sync + 'static {
16 fn broadcast(&self, message: Bytes);
18
19 fn send_to(&self, peer: &[u8], message: Bytes);
21}
22
23pub struct ReplicationProtocol {
27 repos: Arc<RwLock<HashMap<String, Arc<Repository>>>>,
29 handler: Option<Arc<dyn ReplicationHandler>>,
31}
32
33impl ReplicationProtocol {
34 pub fn new() -> Self {
36 Self {
37 repos: Arc::new(RwLock::new(HashMap::new())),
38 handler: None,
39 }
40 }
41
42 pub fn set_handler(&mut self, handler: Arc<dyn ReplicationHandler>) {
44 self.handler = Some(handler);
45 }
46
47 pub fn repos(&self) -> Arc<RwLock<HashMap<String, Arc<Repository>>>> {
49 self.repos.clone()
50 }
51
52 pub fn register_repo(&self, key: String, repo: Arc<Repository>) {
54 self.repos.write().insert(key, repo);
55 }
56
57 pub fn get_repo(&self, key: &str) -> Option<Arc<Repository>> {
59 self.repos.read().get(key).cloned()
60 }
61
62 pub fn get_or_create_repo(&self, key: &str) -> Arc<Repository> {
64 {
65 let repos = self.repos.read();
66 if let Some(repo) = repos.get(key) {
67 return repo.clone();
68 }
69 }
70
71 let mut repos = self.repos.write();
72 if let Some(repo) = repos.get(key) {
74 return repo.clone();
75 }
76
77 let parts: Vec<&str> = key.split('/').collect();
79 let (owner, name) = if parts.len() == 2 {
80 (parts[0], parts[1])
81 } else {
82 ("unknown", key)
83 };
84
85 let repo = Arc::new(Repository::new(name, owner));
86 repos.insert(key.to_string(), repo.clone());
87 repo
88 }
89
90 pub fn handle_message(&self, peer_id: &[u8], data: &[u8]) -> Result<Option<Message>> {
94 let message = Message::decode(data)?;
95
96 match message {
97 Message::RepoAnnounce(announce) => self.handle_announce(peer_id, announce),
98 Message::SyncRequest(request) => self.handle_sync_request(peer_id, request),
99 Message::ObjectData(object_data) => self.handle_object_data(peer_id, object_data),
100 Message::RefUpdate(ref_update) => self.handle_ref_update(peer_id, ref_update),
101 }
102 }
103
104 fn handle_announce(&self, peer_id: &[u8], announce: RepoAnnounce) -> Result<Option<Message>> {
106 info!(
107 repo = %announce.repo_key,
108 objects = announce.object_ids.len(),
109 refs = announce.refs.len(),
110 peer = %hex::encode(peer_id),
111 "Received repo announce"
112 );
113
114 let repo = self.get_or_create_repo(&announce.repo_key);
116
117 let missing: Vec<ObjectId> = announce
119 .object_ids
120 .iter()
121 .filter(|oid| !repo.objects.contains(oid))
122 .copied()
123 .collect();
124
125 if missing.is_empty() {
126 for (ref_name, oid) in announce.refs {
128 repo.refs.set(&ref_name, oid);
129 }
130 debug!(repo = %announce.repo_key, "All objects already present");
131 return Ok(None);
132 }
133
134 info!(
136 repo = %announce.repo_key,
137 missing = missing.len(),
138 "Requesting missing objects"
139 );
140
141 Ok(Some(Message::SyncRequest(SyncRequest {
142 repo_key: announce.repo_key,
143 want: missing,
144 })))
145 }
146
147 fn handle_sync_request(&self, peer_id: &[u8], request: SyncRequest) -> Result<Option<Message>> {
149 debug!(
150 repo = %request.repo_key,
151 want = request.want.len(),
152 peer = %hex::encode(peer_id),
153 "Received sync request"
154 );
155
156 let repo = self
157 .get_repo(&request.repo_key)
158 .ok_or_else(|| P2PError::RepoNotFound(request.repo_key.clone()))?;
159
160 let mut objects = Vec::new();
162 for oid in &request.want {
163 match repo.objects.get(oid) {
164 Ok(obj) => objects.push(obj),
165 Err(e) => {
166 warn!(
167 object = %oid.to_hex(),
168 error = %e,
169 "Requested object not found"
170 );
171 }
172 }
173 }
174
175 if objects.is_empty() {
176 debug!(repo = %request.repo_key, "No objects to send");
177 return Ok(None);
178 }
179
180 info!(
181 repo = %request.repo_key,
182 objects = objects.len(),
183 "Sending objects to peer"
184 );
185
186 Ok(Some(Message::ObjectData(ObjectData {
187 repo_key: request.repo_key,
188 objects,
189 })))
190 }
191
192 fn handle_object_data(
194 &self,
195 peer_id: &[u8],
196 object_data: ObjectData,
197 ) -> Result<Option<Message>> {
198 info!(
199 repo = %object_data.repo_key,
200 objects = object_data.objects.len(),
201 peer = %hex::encode(peer_id),
202 "Received objects"
203 );
204
205 let repo = self.get_or_create_repo(&object_data.repo_key);
206
207 for obj in object_data.objects {
209 let oid = repo.objects.put(obj);
210 debug!(object = %oid.to_hex(), "Stored object");
211 }
212
213 Ok(None)
214 }
215
216 fn handle_ref_update(&self, peer_id: &[u8], ref_update: RefUpdate) -> Result<Option<Message>> {
218 info!(
219 repo = %ref_update.repo_key,
220 ref_name = %ref_update.ref_name,
221 old = %ref_update.old_id.to_hex(),
222 new = %ref_update.new_id.to_hex(),
223 peer = %hex::encode(peer_id),
224 "Received ref update"
225 );
226
227 let repo = self.get_or_create_repo(&ref_update.repo_key);
228
229 let zero_id = ObjectId::from_bytes([0u8; 20]);
231 if ref_update.new_id != zero_id && !repo.objects.contains(&ref_update.new_id) {
232 warn!(
234 object = %ref_update.new_id.to_hex(),
235 "Missing target object for ref update"
236 );
237 return Ok(Some(Message::SyncRequest(SyncRequest {
238 repo_key: ref_update.repo_key,
239 want: vec![ref_update.new_id],
240 })));
241 }
242
243 if ref_update.new_id == zero_id {
245 let _ = repo.refs.delete(&ref_update.ref_name);
247 } else {
248 repo.refs.set(&ref_update.ref_name, ref_update.new_id);
249 }
250
251 Ok(None)
252 }
253
254 pub fn broadcast_update(
258 &self,
259 repo_key: &str,
260 new_objects: Vec<ObjectId>,
261 refs: Vec<(String, ObjectId)>,
262 ) {
263 if let Some(handler) = &self.handler {
264 let announce = RepoAnnounce {
265 repo_key: repo_key.to_string(),
266 object_ids: new_objects,
267 refs,
268 };
269 handler.broadcast(announce.encode());
270 }
271 }
272
273 pub fn broadcast_ref_update(
275 &self,
276 repo_key: &str,
277 ref_name: &str,
278 old_id: ObjectId,
279 new_id: ObjectId,
280 ) {
281 if let Some(handler) = &self.handler {
282 let update = RefUpdate {
283 repo_key: repo_key.to_string(),
284 ref_name: ref_name.to_string(),
285 old_id,
286 new_id,
287 };
288 handler.broadcast(update.encode());
289 }
290 }
291
292 pub fn get_repo_state(&self, key: &str) -> Option<RepoState> {
294 let repos = self.repos.read();
295 let repo = repos.get(key)?;
296
297 let objects = repo.objects.list_objects();
298 let refs: Vec<(String, ObjectId)> = repo
299 .refs
300 .list_all()
301 .into_iter()
302 .filter_map(|(name, reference)| match reference {
303 Reference::Direct(oid) => Some((name, oid)),
304 Reference::Symbolic(_) => None,
305 })
306 .collect();
307
308 Some(RepoState { objects, refs })
309 }
310}
311
312impl Default for ReplicationProtocol {
313 fn default() -> Self {
314 Self::new()
315 }
316}
317
318#[derive(Debug, Clone)]
320pub struct RepoState {
321 pub objects: Vec<ObjectId>,
323 pub refs: Vec<(String, ObjectId)>,
325}
326
#[cfg(test)]
mod tests {
    use super::*;
    use guts_storage::GitObject;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Dummy peer id shared by every test.
    const PEER: [u8; 32] = [0u8; 32];

    /// Test double that counts broadcasts and records every payload it receives.
    struct MockHandler {
        /// Number of `broadcast` calls seen so far.
        broadcast_count: AtomicUsize,
        /// Payloads from both `broadcast` and `send_to`, in call order.
        messages: RwLock<Vec<Bytes>>,
    }

    impl MockHandler {
        fn new() -> Self {
            let broadcast_count = AtomicUsize::new(0);
            let messages = RwLock::new(Vec::new());
            Self {
                broadcast_count,
                messages,
            }
        }
    }

    impl ReplicationHandler for MockHandler {
        fn broadcast(&self, message: Bytes) {
            self.broadcast_count.fetch_add(1, Ordering::SeqCst);
            self.messages.write().push(message);
        }

        fn send_to(&self, _peer: &[u8], message: Bytes) {
            self.messages.write().push(message);
        }
    }

    #[test]
    fn test_protocol_register_repo() {
        let protocol = ReplicationProtocol::new();
        protocol.register_repo(
            "alice/test".to_string(),
            Arc::new(Repository::new("test", "alice")),
        );

        let found = protocol.get_repo("alice/test").unwrap();
        assert_eq!(found.name, "test");
        assert_eq!(found.owner, "alice");
    }

    #[test]
    fn test_protocol_handle_announce() {
        let protocol = ReplicationProtocol::new();

        // An announced object this node lacks should trigger a request for it.
        let encoded = RepoAnnounce {
            repo_key: "bob/repo".to_string(),
            object_ids: vec![ObjectId::from_bytes([1u8; 20])],
            refs: vec![],
        }
        .encode();

        let reply = protocol.handle_message(&PEER, &encoded).unwrap();

        if let Some(Message::SyncRequest(req)) = reply {
            assert_eq!(req.repo_key, "bob/repo");
            assert_eq!(req.want.len(), 1);
        } else {
            panic!("expected sync request");
        }
    }

    #[test]
    fn test_protocol_handle_sync_request() {
        let protocol = ReplicationProtocol::new();

        let repo = Arc::new(Repository::new("repo", "alice"));
        let oid = repo.objects.put(GitObject::blob(b"hello".to_vec()));
        protocol.register_repo("alice/repo".to_string(), repo);

        let encoded = SyncRequest {
            repo_key: "alice/repo".to_string(),
            want: vec![oid],
        }
        .encode();

        let reply = protocol.handle_message(&PEER, &encoded).unwrap();

        if let Some(Message::ObjectData(data)) = reply {
            assert_eq!(data.repo_key, "alice/repo");
            assert_eq!(data.objects.len(), 1);
            assert_eq!(data.objects[0].id, oid);
        } else {
            panic!("expected object data");
        }
    }

    #[test]
    fn test_protocol_handle_object_data() {
        let protocol = ReplicationProtocol::new();

        let blob = GitObject::blob(b"world".to_vec());
        let expected_id = blob.id;

        let encoded = ObjectData {
            repo_key: "carol/code".to_string(),
            objects: vec![blob],
        }
        .encode();

        let reply = protocol.handle_message(&PEER, &encoded).unwrap();
        assert!(reply.is_none());

        // The repository is created on demand and must now hold the object.
        let stored = protocol.get_repo("carol/code").unwrap();
        assert!(stored.objects.contains(&expected_id));
    }

    #[test]
    fn test_protocol_broadcast() {
        let handler = Arc::new(MockHandler::new());
        let mut protocol = ReplicationProtocol::new();
        protocol.set_handler(handler.clone());

        let announced = ObjectId::from_bytes([1u8; 20]);
        protocol.broadcast_update("test/repo", vec![announced], vec![]);

        assert_eq!(handler.broadcast_count.load(Ordering::SeqCst), 1);
        assert_eq!(handler.messages.read().len(), 1);
    }
}