Skip to main content

guts_p2p/
protocol.rs

//! Replication protocol implementation.
2
3use std::collections::HashMap;
4use std::sync::Arc;
5
6use bytes::Bytes;
7use guts_storage::{ObjectId, Reference, Repository};
8use parking_lot::RwLock;
9use tracing::{debug, info, warn};
10
11use crate::message::{Message, ObjectData, RefUpdate, RepoAnnounce, SyncRequest};
12use crate::{P2PError, Result};
13
14/// Callback trait for sending messages to peers.
15pub trait ReplicationHandler: Send + Sync + 'static {
16    /// Send a message to all connected peers.
17    fn broadcast(&self, message: Bytes);
18
19    /// Send a message to a specific peer.
20    fn send_to(&self, peer: &[u8], message: Bytes);
21}
22
23/// Repository replication protocol.
24///
25/// Handles incoming P2P messages and coordinates repository synchronization.
26pub struct ReplicationProtocol {
27    /// Repository store (owner/name -> Repository).
28    repos: Arc<RwLock<HashMap<String, Arc<Repository>>>>,
29    /// Message handler for sending responses.
30    handler: Option<Arc<dyn ReplicationHandler>>,
31}
32
33impl ReplicationProtocol {
34    /// Create a new replication protocol instance.
35    pub fn new() -> Self {
36        Self {
37            repos: Arc::new(RwLock::new(HashMap::new())),
38            handler: None,
39        }
40    }
41
42    /// Set the message handler for sending responses.
43    pub fn set_handler(&mut self, handler: Arc<dyn ReplicationHandler>) {
44        self.handler = Some(handler);
45    }
46
47    /// Get the repository store reference.
48    pub fn repos(&self) -> Arc<RwLock<HashMap<String, Arc<Repository>>>> {
49        self.repos.clone()
50    }
51
52    /// Register a repository for replication.
53    pub fn register_repo(&self, key: String, repo: Arc<Repository>) {
54        self.repos.write().insert(key, repo);
55    }
56
57    /// Get a repository by key.
58    pub fn get_repo(&self, key: &str) -> Option<Arc<Repository>> {
59        self.repos.read().get(key).cloned()
60    }
61
62    /// Get or create a repository by key.
63    pub fn get_or_create_repo(&self, key: &str) -> Arc<Repository> {
64        {
65            let repos = self.repos.read();
66            if let Some(repo) = repos.get(key) {
67                return repo.clone();
68            }
69        }
70
71        let mut repos = self.repos.write();
72        // Double-check after acquiring write lock
73        if let Some(repo) = repos.get(key) {
74            return repo.clone();
75        }
76
77        // Create new repo
78        let parts: Vec<&str> = key.split('/').collect();
79        let (owner, name) = if parts.len() == 2 {
80            (parts[0], parts[1])
81        } else {
82            ("unknown", key)
83        };
84
85        let repo = Arc::new(Repository::new(name, owner));
86        repos.insert(key.to_string(), repo.clone());
87        repo
88    }
89
90    /// Handle an incoming message from a peer.
91    ///
92    /// Returns an optional response message.
93    pub fn handle_message(&self, peer_id: &[u8], data: &[u8]) -> Result<Option<Message>> {
94        let message = Message::decode(data)?;
95
96        match message {
97            Message::RepoAnnounce(announce) => self.handle_announce(peer_id, announce),
98            Message::SyncRequest(request) => self.handle_sync_request(peer_id, request),
99            Message::ObjectData(object_data) => self.handle_object_data(peer_id, object_data),
100            Message::RefUpdate(ref_update) => self.handle_ref_update(peer_id, ref_update),
101        }
102    }
103
104    /// Handle a repository announcement.
105    fn handle_announce(&self, peer_id: &[u8], announce: RepoAnnounce) -> Result<Option<Message>> {
106        info!(
107            repo = %announce.repo_key,
108            objects = announce.object_ids.len(),
109            refs = announce.refs.len(),
110            peer = %hex::encode(peer_id),
111            "Received repo announce"
112        );
113
114        // Get or create the repository
115        let repo = self.get_or_create_repo(&announce.repo_key);
116
117        // Find objects we don't have
118        let missing: Vec<ObjectId> = announce
119            .object_ids
120            .iter()
121            .filter(|oid| !repo.objects.contains(oid))
122            .copied()
123            .collect();
124
125        if missing.is_empty() {
126            // We have all objects, just apply ref updates
127            for (ref_name, oid) in announce.refs {
128                repo.refs.set(&ref_name, oid);
129            }
130            debug!(repo = %announce.repo_key, "All objects already present");
131            return Ok(None);
132        }
133
134        // Request missing objects
135        info!(
136            repo = %announce.repo_key,
137            missing = missing.len(),
138            "Requesting missing objects"
139        );
140
141        Ok(Some(Message::SyncRequest(SyncRequest {
142            repo_key: announce.repo_key,
143            want: missing,
144        })))
145    }
146
147    /// Handle a sync request.
148    fn handle_sync_request(&self, peer_id: &[u8], request: SyncRequest) -> Result<Option<Message>> {
149        debug!(
150            repo = %request.repo_key,
151            want = request.want.len(),
152            peer = %hex::encode(peer_id),
153            "Received sync request"
154        );
155
156        let repo = self
157            .get_repo(&request.repo_key)
158            .ok_or_else(|| P2PError::RepoNotFound(request.repo_key.clone()))?;
159
160        // Collect requested objects
161        let mut objects = Vec::new();
162        for oid in &request.want {
163            match repo.objects.get(oid) {
164                Ok(obj) => objects.push(obj),
165                Err(e) => {
166                    warn!(
167                        object = %oid.to_hex(),
168                        error = %e,
169                        "Requested object not found"
170                    );
171                }
172            }
173        }
174
175        if objects.is_empty() {
176            debug!(repo = %request.repo_key, "No objects to send");
177            return Ok(None);
178        }
179
180        info!(
181            repo = %request.repo_key,
182            objects = objects.len(),
183            "Sending objects to peer"
184        );
185
186        Ok(Some(Message::ObjectData(ObjectData {
187            repo_key: request.repo_key,
188            objects,
189        })))
190    }
191
192    /// Handle object data response.
193    fn handle_object_data(
194        &self,
195        peer_id: &[u8],
196        object_data: ObjectData,
197    ) -> Result<Option<Message>> {
198        info!(
199            repo = %object_data.repo_key,
200            objects = object_data.objects.len(),
201            peer = %hex::encode(peer_id),
202            "Received objects"
203        );
204
205        let repo = self.get_or_create_repo(&object_data.repo_key);
206
207        // Store all received objects
208        for obj in object_data.objects {
209            let oid = repo.objects.put(obj);
210            debug!(object = %oid.to_hex(), "Stored object");
211        }
212
213        Ok(None)
214    }
215
216    /// Handle a reference update.
217    fn handle_ref_update(&self, peer_id: &[u8], ref_update: RefUpdate) -> Result<Option<Message>> {
218        info!(
219            repo = %ref_update.repo_key,
220            ref_name = %ref_update.ref_name,
221            old = %ref_update.old_id.to_hex(),
222            new = %ref_update.new_id.to_hex(),
223            peer = %hex::encode(peer_id),
224            "Received ref update"
225        );
226
227        let repo = self.get_or_create_repo(&ref_update.repo_key);
228
229        // Check if we have the target object
230        let zero_id = ObjectId::from_bytes([0u8; 20]);
231        if ref_update.new_id != zero_id && !repo.objects.contains(&ref_update.new_id) {
232            // We don't have the target object, need to sync
233            warn!(
234                object = %ref_update.new_id.to_hex(),
235                "Missing target object for ref update"
236            );
237            return Ok(Some(Message::SyncRequest(SyncRequest {
238                repo_key: ref_update.repo_key,
239                want: vec![ref_update.new_id],
240            })));
241        }
242
243        // Apply the ref update
244        if ref_update.new_id == zero_id {
245            // Deletion
246            let _ = repo.refs.delete(&ref_update.ref_name);
247        } else {
248            repo.refs.set(&ref_update.ref_name, ref_update.new_id);
249        }
250
251        Ok(None)
252    }
253
254    /// Broadcast a repository update to all peers.
255    ///
256    /// Called after a push to notify peers about new objects.
257    pub fn broadcast_update(
258        &self,
259        repo_key: &str,
260        new_objects: Vec<ObjectId>,
261        refs: Vec<(String, ObjectId)>,
262    ) {
263        if let Some(handler) = &self.handler {
264            let announce = RepoAnnounce {
265                repo_key: repo_key.to_string(),
266                object_ids: new_objects,
267                refs,
268            };
269            handler.broadcast(announce.encode());
270        }
271    }
272
273    /// Broadcast a reference update to all peers.
274    pub fn broadcast_ref_update(
275        &self,
276        repo_key: &str,
277        ref_name: &str,
278        old_id: ObjectId,
279        new_id: ObjectId,
280    ) {
281        if let Some(handler) = &self.handler {
282            let update = RefUpdate {
283                repo_key: repo_key.to_string(),
284                ref_name: ref_name.to_string(),
285                old_id,
286                new_id,
287            };
288            handler.broadcast(update.encode());
289        }
290    }
291
292    /// Get repository state summary for a given repo.
293    pub fn get_repo_state(&self, key: &str) -> Option<RepoState> {
294        let repos = self.repos.read();
295        let repo = repos.get(key)?;
296
297        let objects = repo.objects.list_objects();
298        let refs: Vec<(String, ObjectId)> = repo
299            .refs
300            .list_all()
301            .into_iter()
302            .filter_map(|(name, reference)| match reference {
303                Reference::Direct(oid) => Some((name, oid)),
304                Reference::Symbolic(_) => None,
305            })
306            .collect();
307
308        Some(RepoState { objects, refs })
309    }
310}
311
312impl Default for ReplicationProtocol {
313    fn default() -> Self {
314        Self::new()
315    }
316}
317
318/// Repository state summary.
319#[derive(Debug, Clone)]
320pub struct RepoState {
321    /// All object IDs in the repository.
322    pub objects: Vec<ObjectId>,
323    /// All direct references (name -> target).
324    pub refs: Vec<(String, ObjectId)>,
325}
326
327#[cfg(test)]
328mod tests {
329    use super::*;
330    use guts_storage::GitObject;
331    use std::sync::atomic::{AtomicUsize, Ordering};
332
333    struct MockHandler {
334        broadcast_count: AtomicUsize,
335        messages: RwLock<Vec<Bytes>>,
336    }
337
338    impl MockHandler {
339        fn new() -> Self {
340            Self {
341                broadcast_count: AtomicUsize::new(0),
342                messages: RwLock::new(Vec::new()),
343            }
344        }
345    }
346
347    impl ReplicationHandler for MockHandler {
348        fn broadcast(&self, message: Bytes) {
349            self.broadcast_count.fetch_add(1, Ordering::SeqCst);
350            self.messages.write().push(message);
351        }
352
353        fn send_to(&self, _peer: &[u8], message: Bytes) {
354            self.messages.write().push(message);
355        }
356    }
357
358    #[test]
359    fn test_protocol_register_repo() {
360        let protocol = ReplicationProtocol::new();
361        let repo = Arc::new(Repository::new("test", "alice"));
362        protocol.register_repo("alice/test".to_string(), repo.clone());
363
364        let retrieved = protocol.get_repo("alice/test").unwrap();
365        assert_eq!(retrieved.name, "test");
366        assert_eq!(retrieved.owner, "alice");
367    }
368
369    #[test]
370    fn test_protocol_handle_announce() {
371        let protocol = ReplicationProtocol::new();
372
373        // Create announce with unknown objects
374        let announce = RepoAnnounce {
375            repo_key: "bob/repo".to_string(),
376            object_ids: vec![ObjectId::from_bytes([1u8; 20])],
377            refs: vec![],
378        };
379
380        let peer_id = [0u8; 32];
381        let result = protocol
382            .handle_message(&peer_id, &announce.encode())
383            .unwrap();
384
385        // Should request the missing object
386        match result {
387            Some(Message::SyncRequest(req)) => {
388                assert_eq!(req.repo_key, "bob/repo");
389                assert_eq!(req.want.len(), 1);
390            }
391            _ => panic!("expected sync request"),
392        }
393    }
394
395    #[test]
396    fn test_protocol_handle_sync_request() {
397        let protocol = ReplicationProtocol::new();
398
399        // Create a repo with an object
400        let repo = Arc::new(Repository::new("repo", "alice"));
401        let obj = GitObject::blob(b"hello".to_vec());
402        let oid = repo.objects.put(obj);
403        protocol.register_repo("alice/repo".to_string(), repo);
404
405        // Request that object
406        let request = SyncRequest {
407            repo_key: "alice/repo".to_string(),
408            want: vec![oid],
409        };
410
411        let peer_id = [0u8; 32];
412        let result = protocol
413            .handle_message(&peer_id, &request.encode())
414            .unwrap();
415
416        // Should return the object
417        match result {
418            Some(Message::ObjectData(data)) => {
419                assert_eq!(data.repo_key, "alice/repo");
420                assert_eq!(data.objects.len(), 1);
421                assert_eq!(data.objects[0].id, oid);
422            }
423            _ => panic!("expected object data"),
424        }
425    }
426
427    #[test]
428    fn test_protocol_handle_object_data() {
429        let protocol = ReplicationProtocol::new();
430
431        let obj = GitObject::blob(b"world".to_vec());
432        let oid = obj.id;
433
434        let object_data = ObjectData {
435            repo_key: "carol/code".to_string(),
436            objects: vec![obj],
437        };
438
439        let peer_id = [0u8; 32];
440        let result = protocol
441            .handle_message(&peer_id, &object_data.encode())
442            .unwrap();
443
444        assert!(result.is_none());
445
446        // Verify object was stored
447        let repo = protocol.get_repo("carol/code").unwrap();
448        assert!(repo.objects.contains(&oid));
449    }
450
451    #[test]
452    fn test_protocol_broadcast() {
453        let mut protocol = ReplicationProtocol::new();
454        let handler = Arc::new(MockHandler::new());
455        protocol.set_handler(handler.clone());
456
457        protocol.broadcast_update("test/repo", vec![ObjectId::from_bytes([1u8; 20])], vec![]);
458
459        assert_eq!(handler.broadcast_count.load(Ordering::SeqCst), 1);
460        assert_eq!(handler.messages.read().len(), 1);
461    }
462}