Skip to main content

apfsds_transport/
exit_client.rs

1//! Exit node client for Handler → Exit communication
2//!
3//! Uses HTTP/2 + rkyv serialization for high performance.
4
5use crate::SharedPacketDispatcher;
6use apfsds_protocol::PlainPacket;
7use bytes::{Buf, Bytes, BytesMut};
8use futures::StreamExt;
9use reqwest::Client;
10use rkyv::rancor::Error as RkyvError;
11use std::sync::Arc;
12use std::time::Duration;
13use thiserror::Error;
14use tracing::{debug, error, info, trace, warn};
15
16/// Exit client errors
17#[derive(Error, Debug)]
18pub enum ExitClientError {
19    #[error("Connection failed: {0}")]
20    ConnectionFailed(String),
21
22    #[error("Request failed: {0}")]
23    RequestFailed(String),
24
25    #[error("Serialization error: {0}")]
26    SerializationError(String),
27
28    #[error("Timeout")]
29    Timeout,
30
31    #[error("Exit node unhealthy")]
32    Unhealthy,
33}
34
/// Settings used to build an [`ExitClient`].
#[derive(Clone, Debug)]
pub struct ExitClientConfig {
    /// Base URL of the exit node, without a trailing path
    /// (for example `"http://exit-1.internal:8081"`).
    pub base_url: String,

    /// Request timeout used by the exit client.
    pub timeout: Duration,

    /// When `true`, the client talks HTTP/2 with prior knowledge (no HTTP/1.1
    /// negotiation or fallback).
    pub http2: bool,
}

impl Default for ExitClientConfig {
    /// A local exit node at `http://127.0.0.1:8081`, a 10 second timeout, and
    /// HTTP/2 enabled.
    fn default() -> Self {
        let base_url = String::from("http://127.0.0.1:8081");
        let timeout = Duration::from_secs(10);
        Self {
            base_url,
            timeout,
            http2: true,
        }
    }
}
57
58/// Client for communicating with exit nodes
59pub struct ExitClient {
60    client: Client,
61    config: ExitClientConfig,
62    healthy: std::sync::atomic::AtomicBool,
63}
64
65impl ExitClient {
66    /// Create a new exit client
67    pub fn new(config: ExitClientConfig) -> Result<Self, ExitClientError> {
68        let mut builder = Client::builder()
69            .timeout(config.timeout)
70            .pool_max_idle_per_host(10);
71
72        if config.http2 {
73            builder = builder.http2_prior_knowledge();
74        }
75
76        let client = builder
77            .build()
78            .map_err(|e| ExitClientError::ConnectionFailed(e.to_string()))?;
79
80        Ok(Self {
81            client,
82            config,
83            healthy: std::sync::atomic::AtomicBool::new(true),
84        })
85    }
86
87    /// Forward a packet to the exit node
88    pub async fn forward(&self, packet: &PlainPacket) -> Result<(), ExitClientError> {
89        if !self.is_healthy() {
90            return Err(ExitClientError::Unhealthy);
91        }
92
93        // Serialize with rkyv
94        let bytes = rkyv::to_bytes::<RkyvError>(packet)
95            .map_err(|e| ExitClientError::SerializationError(e.to_string()))?;
96
97        let url = format!("{}/forward", self.config.base_url);
98        trace!("Forwarding packet to {}", url);
99
100        let response = self
101            .client
102            .post(&url)
103            .header("Content-Type", "application/octet-stream")
104            .body(bytes.to_vec())
105            .send()
106            .await
107            .map_err(|e| {
108                self.mark_unhealthy();
109                ExitClientError::RequestFailed(e.to_string())
110            })?;
111
112        if !response.status().is_success() {
113            error!("Exit node returned error: {}", response.status());
114            return Err(ExitClientError::RequestFailed(format!(
115                "HTTP {}",
116                response.status()
117            )));
118        }
119
120        debug!("Packet forwarded successfully");
121        Ok(())
122    }
123
124    /// Subscribe to return traffic stream
125    pub fn subscribe(self: Arc<Self>, handler_id: u64, dispatcher: SharedPacketDispatcher) {
126        tokio::spawn(async move {
127            let url = format!("{}/stream?handler_id={}", self.config.base_url, handler_id);
128            let mut backoff = Duration::from_secs(1);
129
130            loop {
131                info!("Connecting to exit node stream at {}", url);
132                match self.client.get(&url).send().await {
133                    Ok(mut resp) => {
134                        if !resp.status().is_success() {
135                            warn!("Stream failed HTTP {}", resp.status());
136                            tokio::time::sleep(backoff).await;
137                            continue;
138                        }
139
140                        self.healthy
141                            .store(true, std::sync::atomic::Ordering::Relaxed);
142                        backoff = Duration::from_secs(1);
143
144                        // let mut stream = resp.bytes_stream();
145                        let mut buffer = BytesMut::new();
146
147                        loop {
148                            match resp.chunk().await {
149                                Ok(Some(chunk)) => {
150                                    buffer.extend_from_slice(&chunk);
151
152                                    // Process frames (Length + Payload)
153                                    loop {
154                                        if buffer.len() < 4 {
155                                            break;
156                                        }
157
158                                        let mut len_bytes = [0u8; 4];
159                                        len_bytes.copy_from_slice(&buffer[..4]);
160                                        let len = u32::from_le_bytes(len_bytes) as usize;
161
162                                        if buffer.len() < 4 + len {
163                                            break; // Wait for more data
164                                        }
165
166                                        // Consume header
167                                        buffer.advance(4);
168                                        // Extract payload
169                                        let payload = buffer.split_to(len);
170
171                                        // Deserialize PlainPacket
172                                        match rkyv::from_bytes::<PlainPacket, rkyv::rancor::Error>(
173                                            &payload,
174                                        ) {
175                                            Ok(packet) => {
176                                                dispatcher.dispatch(packet).await;
177                                            }
178                                            Err(e) => {
179                                                error!("Stream deserialization error: {}", e);
180                                            }
181                                        }
182                                    }
183                                }
184                                Ok(None) => {
185                                    break; // EOF
186                                }
187                                Err(e) => {
188                                    error!("Stream read error: {}", e);
189                                    break;
190                                }
191                            }
192                        }
193                        warn!("Stream disconnected");
194                    }
195                    Err(e) => {
196                        error!("Failed to connect stream: {}", e);
197                        self.mark_unhealthy();
198                    }
199                }
200
201                tokio::time::sleep(backoff).await;
202                backoff = std::cmp::min(backoff * 2, Duration::from_secs(30));
203            }
204        });
205    }
206
207    /// Check health of exit node
208    pub async fn health_check(&self) -> bool {
209        let url = format!("{}/health", self.config.base_url);
210
211        match self.client.get(&url).send().await {
212            Ok(resp) if resp.status().is_success() => {
213                self.healthy
214                    .store(true, std::sync::atomic::Ordering::Relaxed);
215                true
216            }
217            _ => {
218                self.healthy
219                    .store(false, std::sync::atomic::Ordering::Relaxed);
220                false
221            }
222        }
223    }
224
225    /// Check if client is marked healthy
226    pub fn is_healthy(&self) -> bool {
227        self.healthy.load(std::sync::atomic::Ordering::Relaxed)
228    }
229
230    /// Mark as unhealthy
231    fn mark_unhealthy(&self) {
232        self.healthy
233            .store(false, std::sync::atomic::Ordering::Relaxed);
234    }
235
236    /// Get base URL
237    pub fn base_url(&self) -> &str {
238        &self.config.base_url
239    }
240}
241
242/// Shared exit client
243pub type SharedExitClient = Arc<ExitClient>;
244
#[cfg(test)]
mod tests {
    use super::*;

    /// The default config targets a local exit node over HTTP/2 with a 10 s timeout.
    /// Destructuring exhaustively forces this test to be revisited when a field is added.
    #[test]
    fn test_exit_client_config_default() {
        let ExitClientConfig {
            base_url,
            timeout,
            http2,
        } = ExitClientConfig::default();
        assert_eq!(base_url, "http://127.0.0.1:8081");
        assert_eq!(timeout, Duration::from_secs(10));
        assert!(http2);
    }

    /// A freshly built client starts out healthy and reports its configured base URL.
    #[test]
    fn test_new_client_starts_healthy() {
        let client =
            ExitClient::new(ExitClientConfig::default()).expect("default config should build");
        assert!(client.is_healthy());
        assert_eq!(client.base_url(), "http://127.0.0.1:8081");
    }
}