common/peer/jax_protocol/
handler.rs

1use std::sync::Arc;
2
3use anyhow::anyhow;
4use futures::future::BoxFuture;
5use iroh::endpoint::Connection;
6use iroh::protocol::AcceptError;
7
8use super::messages::{FetchBucketResponse, PingResponse, Request, Response};
9use super::state::BucketStateProvider;
10
/// ALPN identifier for the JAX protocol.
///
/// Connections negotiated with this ALPN are the ones
/// `JaxProtocol::handle_connection` is documented to serve.
pub const JAX_ALPN: &[u8] = b"/iroh-jax/1";
13
/// Protocol handler for the JAX protocol
///
/// Accepts incoming connections, reads one request from the first
/// bidirectional stream and serves it. Ping, fetch-bucket and announce
/// requests are all handled (see `handle_connection`).
#[derive(Clone, Debug)]
pub struct JaxProtocol {
    /// Shared state provider consulted to answer requests
    state: Arc<dyn BucketStateProvider>,
}
21
22impl JaxProtocol {
23    /// Create a new JAX protocol handler with the given state provider
24    pub fn new(state: Arc<dyn BucketStateProvider>) -> Self {
25        Self { state }
26    }
27
28    /// Handle an incoming connection
29    ///
30    /// This is called by the iroh router for each incoming connection
31    /// with the JAX ALPN.
32    pub fn handle_connection(
33        self,
34        conn: Connection,
35    ) -> BoxFuture<'static, Result<(), AcceptError>> {
36        Box::pin(async move {
37            tracing::debug!(
38                "JAX handler: Accepted new connection from {:?}",
39                conn.remote_node_id()
40            );
41
42            // Accept the first bidirectional stream from the connection
43            tracing::debug!("JAX handler: Accepting bidirectional stream");
44            let (mut send, mut recv) = conn.accept_bi().await.map_err(|e| {
45                tracing::error!("JAX handler: Failed to accept bidirectional stream: {}", e);
46                AcceptError::from(e)
47            })?;
48            tracing::debug!("JAX handler: Bidirectional stream accepted");
49
50            // Get remote peer ID for announce handling
51            let remote_node_id = conn.remote_node_id().map(|id| id.to_string());
52
53            // Read the request from the stream
54            tracing::debug!("JAX handler: Reading request from stream");
55            let request_bytes = recv.read_to_end(1024 * 1024).await.map_err(|e| {
56                tracing::error!("JAX handler: Failed to read request: {}", e);
57                AcceptError::from(std::io::Error::other(e))
58            })?; // 1MB limit
59            tracing::debug!(
60                "JAX handler: Read {} bytes from stream",
61                request_bytes.len()
62            );
63
64            tracing::debug!("JAX handler: Deserializing request");
65            let request: Request = bincode::deserialize(&request_bytes).map_err(|e| {
66                tracing::error!("JAX handler: Failed to deserialize request: {}", e);
67                let err: Box<dyn std::error::Error + Send + Sync> =
68                    anyhow!("Failed to deserialize request: {}", e).into();
69                AcceptError::from(err)
70            })?;
71            tracing::debug!("JAX handler: Successfully deserialized request");
72
73            // Dispatch based on request type
74            match request {
75                Request::Ping(ping_req) => {
76                    tracing::info!(
77                        "JAX handler: Received ping request for bucket {} with link {:?}",
78                        ping_req.bucket_id,
79                        ping_req.current_link
80                    );
81
82                    // Check the bucket sync status using the state provider
83                    tracing::debug!("JAX handler: Checking bucket sync status");
84                    let status = self
85                        .state
86                        .check_bucket_sync(ping_req.bucket_id, &ping_req.current_link)
87                        .await
88                        .unwrap_or_else(|e| {
89                            tracing::error!("JAX handler: Error checking bucket sync: {}", e);
90                            super::messages::SyncStatus::NotFound
91                        });
92                    tracing::debug!("JAX handler: Bucket sync status: {:?}", status);
93
94                    let response = Response::Ping(PingResponse::new(status));
95                    tracing::debug!("JAX handler: Created ping response");
96
97                    // Serialize and send the response
98                    tracing::debug!("JAX handler: Serializing ping response");
99                    let response_bytes = bincode::serialize(&response).map_err(|e| {
100                        tracing::error!("JAX handler: Failed to serialize response: {}", e);
101                        let err: Box<dyn std::error::Error + Send + Sync> =
102                            anyhow!("Failed to serialize response: {}", e).into();
103                        AcceptError::from(err)
104                    })?;
105                    tracing::debug!(
106                        "JAX handler: Serialized response to {} bytes",
107                        response_bytes.len()
108                    );
109
110                    tracing::debug!("JAX handler: Writing response to stream");
111                    send.write_all(&response_bytes).await.map_err(|e| {
112                        tracing::error!("JAX handler: Failed to write response: {}", e);
113                        AcceptError::from(std::io::Error::other(e))
114                    })?;
115
116                    tracing::debug!("JAX handler: Finishing send stream");
117                    send.finish().map_err(|e| {
118                        tracing::error!("JAX handler: Failed to finish send stream: {}", e);
119                        AcceptError::from(std::io::Error::other(e))
120                    })?;
121
122                    tracing::debug!("JAX handler: Waiting for connection to close");
123                    conn.closed().await;
124
125                    tracing::info!(
126                        "JAX handler: Successfully sent ping response: {:?}",
127                        response
128                    );
129                }
130
131                Request::FetchBucket(fetch_req) => {
132                    tracing::debug!(
133                        "Received fetch bucket request for bucket {}",
134                        fetch_req.bucket_id
135                    );
136
137                    // Get the current bucket link
138                    let current_link = self
139                        .state
140                        .get_bucket_link(fetch_req.bucket_id)
141                        .await
142                        .unwrap_or_else(|e| {
143                            tracing::error!("Error fetching bucket link: {}", e);
144                            None
145                        });
146
147                    let response = Response::FetchBucket(FetchBucketResponse::new(current_link));
148
149                    // Serialize and send the response
150                    let response_bytes = bincode::serialize(&response).map_err(|e| {
151                        let err: Box<dyn std::error::Error + Send + Sync> =
152                            anyhow!("Failed to serialize response: {}", e).into();
153                        AcceptError::from(err)
154                    })?;
155
156                    send.write_all(&response_bytes)
157                        .await
158                        .map_err(|e| AcceptError::from(std::io::Error::other(e)))?;
159
160                    send.finish()
161                        .map_err(|e| AcceptError::from(std::io::Error::other(e)))?;
162
163                    conn.closed().await;
164
165                    tracing::debug!("Sent fetch bucket response: {:?}", response);
166                }
167
168                Request::Announce(announce_msg) => {
169                    let peer_id = remote_node_id.unwrap_or_else(|_| "unknown".to_string());
170
171                    tracing::info!(
172                        "Received announce from peer {} for bucket {} with new link {:?}",
173                        peer_id,
174                        announce_msg.bucket_id,
175                        announce_msg.new_link
176                    );
177
178                    // Handle the announce message (triggers sync event)
179                    if let Err(e) = self
180                        .state
181                        .handle_announce(
182                            announce_msg.bucket_id,
183                            peer_id,
184                            announce_msg.new_link,
185                            announce_msg.previous_link,
186                        )
187                        .await
188                    {
189                        tracing::error!("Error handling announce: {}", e);
190                    }
191
192                    // No response needed for announce - just finish the stream
193                    send.finish()
194                        .map_err(|e| AcceptError::from(std::io::Error::other(e)))?;
195                }
196            }
197
198            Ok(())
199        })
200    }
201}
202
203// Implement the iroh protocol handler trait
204// This allows the router to accept connections for this protocol
205impl iroh::protocol::ProtocolHandler for JaxProtocol {
206    #[allow(refining_impl_trait)]
207    fn accept(
208        &self,
209        conn: iroh::endpoint::Connection,
210    ) -> BoxFuture<'static, Result<(), AcceptError>> {
211        let this = self.clone();
212        this.handle_connection(conn)
213    }
214}