use std::sync::Arc;
use std::time::Duration;
use serde_json;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tracing::{debug, error, info};
use crate::error::{ProxyError, ProxyResult};
use crate::proxy::{BackendConnector, IdTranslator};
/// Settings for the TCP frontend: where to listen and per-connection limits.
#[derive(Debug, Clone)]
pub struct TcpFrontendConfig {
    /// Address handed to `TcpListener::bind`, e.g. `"127.0.0.1:5000"`.
    pub bind: String,
    /// Timeout applied to each individual read from a client socket.
    pub timeout: Duration,
    /// Upper bound, in bytes, on a single buffered request line.
    pub max_request_size: usize,
}

impl TcpFrontendConfig {
    /// Builds a configuration, taking ownership of `bind` as a `String`.
    #[must_use]
    pub fn new(bind: impl Into<String>, timeout: Duration, max_request_size: usize) -> Self {
        let bind: String = bind.into();
        Self {
            timeout,
            max_request_size,
            bind,
        }
    }
}
pub struct TcpFrontend {
config: TcpFrontendConfig,
backend: BackendConnector,
id_translator: Arc<IdTranslator>,
}
impl TcpFrontend {
#[must_use]
pub fn new(
config: TcpFrontendConfig,
backend: BackendConnector,
id_translator: Arc<IdTranslator>,
) -> Self {
Self {
config,
backend,
id_translator,
}
}
pub async fn run(&self) -> ProxyResult<()> {
let listener = TcpListener::bind(&self.config.bind).await.map_err(|e| {
ProxyError::backend_connection(format!(
"Failed to bind TCP listener to {}: {}",
self.config.bind, e
))
})?;
let addr = listener.local_addr().map_err(|e| {
ProxyError::backend_connection(format!("Failed to get listener address: {e}"))
})?;
info!("TCP frontend listening on {}", addr);
loop {
let (socket, peer_addr) = listener
.accept()
.await
.map_err(|e| ProxyError::backend_connection(format!("TCP accept error: {e}")))?;
debug!("Accepted TCP connection from {}", peer_addr);
let backend = self.backend.clone();
let id_translator = Arc::clone(&self.id_translator);
let timeout = self.config.timeout;
let max_request_size = self.config.max_request_size;
tokio::spawn(async move {
if let Err(e) =
handle_connection(socket, backend, id_translator, timeout, max_request_size)
.await
{
error!("TCP connection error from {}: {}", peer_addr, e);
}
});
}
}
}
async fn handle_connection(
mut socket: TcpStream,
_backend: BackendConnector,
_id_translator: Arc<IdTranslator>,
timeout: Duration,
max_request_size: usize,
) -> ProxyResult<()> {
let mut buf = vec![0; max_request_size];
let mut read_pos = 0;
loop {
let _n = match tokio::time::timeout(timeout, socket.read(&mut buf[read_pos..])).await {
Ok(Ok(0)) => {
debug!("TCP client closed connection");
break;
}
Ok(Ok(n)) => {
read_pos += n;
n
}
Ok(Err(e)) => {
error!("TCP read error: {}", e);
break;
}
Err(_) => {
error!("TCP read timeout");
break;
}
};
if let Some(line_end) = buf[..read_pos].windows(1).position(|w| w == b"\n") {
let line = &buf[..line_end];
match serde_json::from_slice::<serde_json::Value>(line) {
Ok(message) => {
debug!("Received JSON-RPC message: {}", message);
let response = serde_json::json!({
"jsonrpc": "2.0",
"id": message.get("id"),
"error": {
"code": -32603,
"message": "TCP frontend not yet fully implemented"
}
});
let response_bytes = serde_json::to_vec(&response)?;
socket.write_all(&response_bytes).await.map_err(|e| {
ProxyError::backend_connection(format!("TCP write error: {e}"))
})?;
socket.write_all(b"\n").await.map_err(|e| {
ProxyError::backend_connection(format!("TCP write error: {e}"))
})?;
read_pos = 0;
}
Err(e) => {
debug!("Failed to parse JSON-RPC: {}", e);
read_pos = 0; }
}
} else if read_pos >= max_request_size {
error!("TCP message exceeds maximum size");
break;
}
}
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `TcpFrontendConfig::new` must keep every argument exactly as given.
    #[test]
    fn test_tcp_frontend_config() {
        const TEN_MIB: usize = 10 * 1024 * 1024;
        let timeout = Duration::from_secs(30);

        let config = TcpFrontendConfig::new("127.0.0.1:5000", timeout, TEN_MIB);

        assert_eq!(config.bind, "127.0.0.1:5000");
        assert_eq!(config.timeout, timeout);
        assert_eq!(config.max_request_size, TEN_MIB);
    }
}