// common/peer/jax_protocol/handler.rs

use std::sync::Arc;

use anyhow::anyhow;
use futures::future::BoxFuture;
use iroh::endpoint::Connection;
use iroh::protocol::AcceptError;

use super::messages::{FetchBucketResponse, PingResponse, Request, Response};
use super::state::BucketStateProvider;
/// ALPN protocol identifier for JAX peer connections.
///
/// The trailing `/1` presumably denotes the protocol version; confirm before
/// changing it, since both peers must advertise the same bytes to connect.
pub const JAX_ALPN: &[u8] = b"/iroh-jax/1";
14#[derive(Clone, Debug)]
18pub struct JaxProtocol {
19 state: Arc<dyn BucketStateProvider>,
20}
21
22impl JaxProtocol {
23 pub fn new(state: Arc<dyn BucketStateProvider>) -> Self {
25 Self { state }
26 }
27
28 pub fn handle_connection(
33 self,
34 conn: Connection,
35 ) -> BoxFuture<'static, Result<(), AcceptError>> {
36 Box::pin(async move {
37 tracing::debug!(
38 "JAX handler: Accepted new connection from {:?}",
39 conn.remote_node_id()
40 );
41
42 tracing::debug!("JAX handler: Accepting bidirectional stream");
44 let (mut send, mut recv) = conn.accept_bi().await.map_err(|e| {
45 tracing::error!("JAX handler: Failed to accept bidirectional stream: {}", e);
46 AcceptError::from(e)
47 })?;
48 tracing::debug!("JAX handler: Bidirectional stream accepted");
49
50 let remote_node_id = conn.remote_node_id().map(|id| id.to_string());
52
53 tracing::debug!("JAX handler: Reading request from stream");
55 let request_bytes = recv.read_to_end(1024 * 1024).await.map_err(|e| {
56 tracing::error!("JAX handler: Failed to read request: {}", e);
57 AcceptError::from(std::io::Error::other(e))
58 })?; tracing::debug!(
60 "JAX handler: Read {} bytes from stream",
61 request_bytes.len()
62 );
63
64 tracing::debug!("JAX handler: Deserializing request");
65 let request: Request = bincode::deserialize(&request_bytes).map_err(|e| {
66 tracing::error!("JAX handler: Failed to deserialize request: {}", e);
67 let err: Box<dyn std::error::Error + Send + Sync> =
68 anyhow!("Failed to deserialize request: {}", e).into();
69 AcceptError::from(err)
70 })?;
71 tracing::debug!("JAX handler: Successfully deserialized request");
72
73 match request {
75 Request::Ping(ping_req) => {
76 tracing::info!(
77 "JAX handler: Received ping request for bucket {} with link {:?}",
78 ping_req.bucket_id,
79 ping_req.current_link
80 );
81
82 tracing::debug!("JAX handler: Checking bucket sync status");
84 let status = self
85 .state
86 .check_bucket_sync(ping_req.bucket_id, &ping_req.current_link)
87 .await
88 .unwrap_or_else(|e| {
89 tracing::error!("JAX handler: Error checking bucket sync: {}", e);
90 super::messages::SyncStatus::NotFound
91 });
92 tracing::debug!("JAX handler: Bucket sync status: {:?}", status);
93
94 let response = Response::Ping(PingResponse::new(status));
95 tracing::debug!("JAX handler: Created ping response");
96
97 tracing::debug!("JAX handler: Serializing ping response");
99 let response_bytes = bincode::serialize(&response).map_err(|e| {
100 tracing::error!("JAX handler: Failed to serialize response: {}", e);
101 let err: Box<dyn std::error::Error + Send + Sync> =
102 anyhow!("Failed to serialize response: {}", e).into();
103 AcceptError::from(err)
104 })?;
105 tracing::debug!(
106 "JAX handler: Serialized response to {} bytes",
107 response_bytes.len()
108 );
109
110 tracing::debug!("JAX handler: Writing response to stream");
111 send.write_all(&response_bytes).await.map_err(|e| {
112 tracing::error!("JAX handler: Failed to write response: {}", e);
113 AcceptError::from(std::io::Error::other(e))
114 })?;
115
116 tracing::debug!("JAX handler: Finishing send stream");
117 send.finish().map_err(|e| {
118 tracing::error!("JAX handler: Failed to finish send stream: {}", e);
119 AcceptError::from(std::io::Error::other(e))
120 })?;
121
122 tracing::debug!("JAX handler: Waiting for connection to close");
123 conn.closed().await;
124
125 tracing::info!(
126 "JAX handler: Successfully sent ping response: {:?}",
127 response
128 );
129 }
130
131 Request::FetchBucket(fetch_req) => {
132 tracing::debug!(
133 "Received fetch bucket request for bucket {}",
134 fetch_req.bucket_id
135 );
136
137 let current_link = self
139 .state
140 .get_bucket_link(fetch_req.bucket_id)
141 .await
142 .unwrap_or_else(|e| {
143 tracing::error!("Error fetching bucket link: {}", e);
144 None
145 });
146
147 let response = Response::FetchBucket(FetchBucketResponse::new(current_link));
148
149 let response_bytes = bincode::serialize(&response).map_err(|e| {
151 let err: Box<dyn std::error::Error + Send + Sync> =
152 anyhow!("Failed to serialize response: {}", e).into();
153 AcceptError::from(err)
154 })?;
155
156 send.write_all(&response_bytes)
157 .await
158 .map_err(|e| AcceptError::from(std::io::Error::other(e)))?;
159
160 send.finish()
161 .map_err(|e| AcceptError::from(std::io::Error::other(e)))?;
162
163 conn.closed().await;
164
165 tracing::debug!("Sent fetch bucket response: {:?}", response);
166 }
167
168 Request::Announce(announce_msg) => {
169 let peer_id = remote_node_id.unwrap_or_else(|_| "unknown".to_string());
170
171 tracing::info!(
172 "Received announce from peer {} for bucket {} with new link {:?}",
173 peer_id,
174 announce_msg.bucket_id,
175 announce_msg.new_link
176 );
177
178 if let Err(e) = self
180 .state
181 .handle_announce(
182 announce_msg.bucket_id,
183 peer_id,
184 announce_msg.new_link,
185 announce_msg.previous_link,
186 )
187 .await
188 {
189 tracing::error!("Error handling announce: {}", e);
190 }
191
192 send.finish()
194 .map_err(|e| AcceptError::from(std::io::Error::other(e)))?;
195 }
196 }
197
198 Ok(())
199 })
200 }
201}
202
203impl iroh::protocol::ProtocolHandler for JaxProtocol {
206 #[allow(refining_impl_trait)]
207 fn accept(
208 &self,
209 conn: iroh::endpoint::Connection,
210 ) -> BoxFuture<'static, Result<(), AcceptError>> {
211 let this = self.clone();
212 this.handle_connection(conn)
213 }
214}