use std::sync::Arc;
use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};
use tokio::select;
use tokio::sync::RwLock;
use crate::codec::json_rpc;
use crate::codec::json_rpc::DapMessage;
use crate::config::Config;
use crate::error::DapzError;
use crate::interceptors::InterceptorChain;
use crate::transport::Transport;
/// Which way a DAP message is travelling through the proxy.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Direction {
    /// From the client (read from this process's stdin) towards the server.
    ClientToServer,
    /// From the server towards the client (written to this process's stdout).
    ServerToClient,
}
/// Lifecycle of a proxy session, as reported by `Proxy::state`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum State {
    /// Constructed; `start` has not been called yet.
    Created,
    /// `start` is running the start-up handshake.
    Initializing,
    /// Handshake finished; frames are being relayed.
    Ready,
    /// A client `disconnect` request was forwarded; waiting on the server.
    ShuttingDown,
    /// Session over: client stdin closed, disconnect completed, or the server exited.
    Exited,
}
/// Relays DAP frames between the client (this process's stdin/stdout) and a
/// debug-adapter server reached through a [`Transport`].
///
/// Field declaration order also determines drop order.
// `config` and `next_seq` are stored but never read anywhere in this file;
// the allow keeps them without dead-code warnings.
#[allow(dead_code)]
pub struct Proxy {
    /// Shared, lock-protected configuration (not read by the proxy itself here).
    config: Arc<RwLock<Config>>,
    /// Current lifecycle state, exposed through `Proxy::state`.
    state: State,
    /// Connection to the debug-adapter server; frames move as raw bytes.
    transport: Box<dyn Transport>,
    /// Interceptors applied to server-to-client messages.
    interceptor_chain: InterceptorChain,
    /// Sequence counter initialised to 1. NOTE(review): unused in this file;
    /// presumably meant for messages the proxy originates itself — confirm.
    next_seq: i64,
}
impl Proxy {
pub fn new(
config: Arc<RwLock<Config>>,
transport: Box<dyn Transport>,
interceptor_chain: InterceptorChain,
) -> Self {
Self {
config,
state: State::Created,
transport,
interceptor_chain,
next_seq: 1,
}
}
pub fn state(&self) -> State {
self.state
}
pub async fn start(&mut self) -> Result<(), DapzError> {
self.state = State::Initializing;
tracing::info!("Proxy starting (handshake)");
self.perform_handshake().await?;
self.state = State::Ready;
tracing::info!("Proxy ready, entering message loop");
self.message_loop().await
}
async fn perform_handshake(&mut self) -> Result<(), DapzError> {
let mut stdin = BufReader::new(tokio::io::stdin());
let mut stdout = tokio::io::stdout();
let init_req = read_stdin_frame(&mut stdin).await?;
self.transport.send(&init_req).await?;
tracing::debug!("Forwarded 'initialize' request to server");
let init_resp = self.transport.receive().await?;
stdout.write_all(&init_resp).await?;
stdout.flush().await?;
tracing::debug!("Forwarded 'initialize' response to client");
let launch_req = read_stdin_frame(&mut stdin).await?;
self.transport.send(&launch_req).await?;
tracing::debug!("Forwarded launch/attach request to server");
let launch_resp = self.transport.receive().await?;
stdout.write_all(&launch_resp).await?;
stdout.flush().await?;
tracing::debug!("Forwarded launch/attach response to client");
let config_done = read_stdin_frame(&mut stdin).await?;
self.transport.send(&config_done).await?;
tracing::debug!("Forwarded 'configurationDone' request");
let config_resp = self.transport.receive().await?;
stdout.write_all(&config_resp).await?;
stdout.flush().await?;
tracing::debug!("Configuration done, handshake complete");
Ok(())
}
async fn message_loop(&mut self) -> Result<(), DapzError> {
let mut stdin = BufReader::new(tokio::io::stdin());
let mut stdout = tokio::io::stdout();
loop {
select! {
client_msg = read_stdin_frame(&mut stdin) => {
let msg_bytes = match client_msg {
Ok(bytes) => bytes,
Err(DapzError::ServerExited) => {
tracing::info!("Client stdin closed, shutting down");
self.state = State::Exited;
return Ok(());
}
Err(e) => {
tracing::error!(error = %e, "Error reading from client");
return Err(e);
}
};
if is_disconnect(&msg_bytes) {
tracing::info!("Received 'disconnect' from client");
self.transport.send(&msg_bytes).await?;
self.state = State::ShuttingDown;
let resp = self.transport.receive().await?;
stdout.write_all(&resp).await?;
stdout.flush().await?;
self.state = State::Exited;
return Ok(());
}
self.transport.send(&msg_bytes).await?;
}
server_msg = self.transport.receive() => {
let msg_bytes = match server_msg {
Ok(bytes) => bytes,
Err(DapzError::ServerExited) => {
tracing::warn!("Server exited unexpectedly");
self.state = State::Exited;
return Err(DapzError::ServerExited);
}
Err(e) => {
tracing::error!(error = %e, "Error reading from server");
return Err(e);
}
};
let processed = self.process_server_message(&msg_bytes).await;
if !processed.is_empty() {
stdout.write_all(&processed).await?;
stdout.flush().await?;
}
}
}
}
}
async fn process_server_message(&mut self, raw: &[u8]) -> Vec<u8> {
let msg = match DapMessage::from_frame(raw) {
Ok(m) => m,
Err(_) => return raw.to_vec(),
};
let direction = Direction::ServerToClient;
let processed = match self.interceptor_chain.process(msg, direction).await {
Ok(Some(m)) => m,
Ok(None) => return Vec::new(),
Err(_) => return raw.to_vec(),
};
match processed.to_bytes() {
Ok(bytes) => bytes,
Err(_) => raw.to_vec(),
}
}
}
/// Reports whether a raw client frame is a DAP `disconnect` request.
///
/// A frame that fails to parse is treated as "not a disconnect".
fn is_disconnect(raw: &[u8]) -> bool {
    match DapMessage::from_frame(raw) {
        Ok(parsed) => {
            parsed.msg_type == "request" && parsed.command.as_deref() == Some("disconnect")
        }
        Err(_) => false,
    }
}
async fn read_stdin_frame(reader: &mut BufReader<tokio::io::Stdin>) -> Result<Vec<u8>, DapzError> {
let mut header = String::new();
loop {
let mut line = String::new();
let n = reader.read_line(&mut line).await.map_err(|e| {
if e.kind() == std::io::ErrorKind::UnexpectedEof {
DapzError::ServerExited
} else {
DapzError::Io(e)
}
})?;
if n == 0 {
return Err(DapzError::ServerExited);
}
header.push_str(&line);
if line == "\r\n" || line == "\n" {
break;
}
}
let content_length = json_rpc::parse_content_length(&header)?;
let mut body = vec![0u8; content_length as usize];
reader.read_exact(&mut body).await.map_err(|e| {
if e.kind() == std::io::ErrorKind::UnexpectedEof {
DapzError::ServerExited
} else {
DapzError::Io(e)
}
})?;
Ok([header.as_bytes(), &body].concat())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Config;
    use crate::transport::mock::MockTransport;

    /// Builds a proxy over a mock transport with an empty interceptor chain.
    fn mock_proxy() -> Proxy {
        let shared = Arc::new(RwLock::new(Config {
            backend_cmd: "test".into(),
            ..Default::default()
        }));
        let transport = Box::new(MockTransport::new());
        let chain = InterceptorChain::new(vec![], Arc::clone(&shared));
        Proxy::new(shared, transport, chain)
    }

    #[tokio::test]
    async fn test_new_proxy_state() {
        let fresh = mock_proxy();
        assert_eq!(fresh.state(), State::Created);
    }
}