apfsds_transport/
exit_client.rs1use crate::SharedPacketDispatcher;
6use apfsds_protocol::PlainPacket;
7use bytes::{Buf, Bytes, BytesMut};
8use futures::StreamExt;
9use reqwest::Client;
10use rkyv::rancor::Error as RkyvError;
11use std::sync::Arc;
12use std::time::Duration;
13use thiserror::Error;
14use tracing::{debug, error, info, trace, warn};
15
16#[derive(Error, Debug)]
18pub enum ExitClientError {
19 #[error("Connection failed: {0}")]
20 ConnectionFailed(String),
21
22 #[error("Request failed: {0}")]
23 RequestFailed(String),
24
25 #[error("Serialization error: {0}")]
26 SerializationError(String),
27
28 #[error("Timeout")]
29 Timeout,
30
31 #[error("Exit node unhealthy")]
32 Unhealthy,
33}
34
/// Connection settings for an [`ExitClient`].
#[derive(Debug, Clone)]
pub struct ExitClientConfig {
    /// Base URL of the exit node; endpoint paths such as `/forward`, `/stream`
    /// and `/health` are appended to it verbatim, so it should not end in `/`.
    pub base_url: String,

    /// HTTP timeout used by [`ExitClient`] when talking to the exit node.
    pub timeout: Duration,

    /// When `true`, the HTTP client is restricted to HTTP/2 with prior knowledge
    /// instead of negotiating the protocol version.
    pub http2: bool,
}
47
48impl Default for ExitClientConfig {
49 fn default() -> Self {
50 Self {
51 base_url: "http://127.0.0.1:8081".to_string(),
52 timeout: Duration::from_secs(10),
53 http2: true,
54 }
55 }
56}
57
58pub struct ExitClient {
60 client: Client,
61 config: ExitClientConfig,
62 healthy: std::sync::atomic::AtomicBool,
63}
64
65impl ExitClient {
66 pub fn new(config: ExitClientConfig) -> Result<Self, ExitClientError> {
68 let mut builder = Client::builder()
69 .timeout(config.timeout)
70 .pool_max_idle_per_host(10);
71
72 if config.http2 {
73 builder = builder.http2_prior_knowledge();
74 }
75
76 let client = builder
77 .build()
78 .map_err(|e| ExitClientError::ConnectionFailed(e.to_string()))?;
79
80 Ok(Self {
81 client,
82 config,
83 healthy: std::sync::atomic::AtomicBool::new(true),
84 })
85 }
86
87 pub async fn forward(&self, packet: &PlainPacket) -> Result<(), ExitClientError> {
89 if !self.is_healthy() {
90 return Err(ExitClientError::Unhealthy);
91 }
92
93 let bytes = rkyv::to_bytes::<RkyvError>(packet)
95 .map_err(|e| ExitClientError::SerializationError(e.to_string()))?;
96
97 let url = format!("{}/forward", self.config.base_url);
98 trace!("Forwarding packet to {}", url);
99
100 let response = self
101 .client
102 .post(&url)
103 .header("Content-Type", "application/octet-stream")
104 .body(bytes.to_vec())
105 .send()
106 .await
107 .map_err(|e| {
108 self.mark_unhealthy();
109 ExitClientError::RequestFailed(e.to_string())
110 })?;
111
112 if !response.status().is_success() {
113 error!("Exit node returned error: {}", response.status());
114 return Err(ExitClientError::RequestFailed(format!(
115 "HTTP {}",
116 response.status()
117 )));
118 }
119
120 debug!("Packet forwarded successfully");
121 Ok(())
122 }
123
124 pub fn subscribe(self: Arc<Self>, handler_id: u64, dispatcher: SharedPacketDispatcher) {
126 tokio::spawn(async move {
127 let url = format!("{}/stream?handler_id={}", self.config.base_url, handler_id);
128 let mut backoff = Duration::from_secs(1);
129
130 loop {
131 info!("Connecting to exit node stream at {}", url);
132 match self.client.get(&url).send().await {
133 Ok(mut resp) => {
134 if !resp.status().is_success() {
135 warn!("Stream failed HTTP {}", resp.status());
136 tokio::time::sleep(backoff).await;
137 continue;
138 }
139
140 self.healthy
141 .store(true, std::sync::atomic::Ordering::Relaxed);
142 backoff = Duration::from_secs(1);
143
144 let mut buffer = BytesMut::new();
146
147 loop {
148 match resp.chunk().await {
149 Ok(Some(chunk)) => {
150 buffer.extend_from_slice(&chunk);
151
152 loop {
154 if buffer.len() < 4 {
155 break;
156 }
157
158 let mut len_bytes = [0u8; 4];
159 len_bytes.copy_from_slice(&buffer[..4]);
160 let len = u32::from_le_bytes(len_bytes) as usize;
161
162 if buffer.len() < 4 + len {
163 break; }
165
166 buffer.advance(4);
168 let payload = buffer.split_to(len);
170
171 match rkyv::from_bytes::<PlainPacket, rkyv::rancor::Error>(
173 &payload,
174 ) {
175 Ok(packet) => {
176 dispatcher.dispatch(packet).await;
177 }
178 Err(e) => {
179 error!("Stream deserialization error: {}", e);
180 }
181 }
182 }
183 }
184 Ok(None) => {
185 break; }
187 Err(e) => {
188 error!("Stream read error: {}", e);
189 break;
190 }
191 }
192 }
193 warn!("Stream disconnected");
194 }
195 Err(e) => {
196 error!("Failed to connect stream: {}", e);
197 self.mark_unhealthy();
198 }
199 }
200
201 tokio::time::sleep(backoff).await;
202 backoff = std::cmp::min(backoff * 2, Duration::from_secs(30));
203 }
204 });
205 }
206
207 pub async fn health_check(&self) -> bool {
209 let url = format!("{}/health", self.config.base_url);
210
211 match self.client.get(&url).send().await {
212 Ok(resp) if resp.status().is_success() => {
213 self.healthy
214 .store(true, std::sync::atomic::Ordering::Relaxed);
215 true
216 }
217 _ => {
218 self.healthy
219 .store(false, std::sync::atomic::Ordering::Relaxed);
220 false
221 }
222 }
223 }
224
225 pub fn is_healthy(&self) -> bool {
227 self.healthy.load(std::sync::atomic::Ordering::Relaxed)
228 }
229
230 fn mark_unhealthy(&self) {
232 self.healthy
233 .store(false, std::sync::atomic::Ordering::Relaxed);
234 }
235
236 pub fn base_url(&self) -> &str {
238 &self.config.base_url
239 }
240}
241
242pub type SharedExitClient = Arc<ExitClient>;
244
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration enables HTTP/2 and uses a 10 second timeout.
    #[test]
    fn test_exit_client_config_default() {
        let ExitClientConfig { timeout, http2, .. } = ExitClientConfig::default();
        assert!(http2, "HTTP/2 should be enabled by default");
        assert_eq!(timeout, Duration::from_secs(10));
    }
}