Skip to main content

aurora_db/network/
server.rs

1use crate::db::Aurora;
2use crate::error::{AqlError, Result};
3use crate::network::protocol::Request;
4use std::sync::Arc;
5use tokio::io::{AsyncReadExt, AsyncWriteExt};
6use tokio::net::{TcpListener, TcpStream};
7
8pub struct BincodeServer {
9    db: Arc<Aurora>,
10    addr: String,
11}
12
13impl BincodeServer {
14    pub fn new(db: Arc<Aurora>, addr: &str) -> Self {
15        Self {
16            db,
17            addr: addr.to_string(),
18        }
19    }
20
21    pub async fn run(&self) -> Result<()> {
22        let listener = TcpListener::bind(&self.addr).await?;
23        println!("Bincode server listening on {}", self.addr);
24
25        loop {
26            let (stream, _) = listener.accept().await?;
27            let db_clone = self.db.clone();
28            tokio::spawn(async move {
29                if let Err(e) = Self::handle_bincode_connection(stream, db_clone).await {
30                    eprintln!("Error handling bincode connection: {}", e);
31                }
32            });
33        }
34    }
35
36    async fn handle_bincode_connection(mut stream: TcpStream, db: Arc<Aurora>) -> Result<()> {
37        loop {
38            let mut len_bytes = [0u8; 4];
39            match stream.read_exact(&mut len_bytes).await {
40                Ok(_) => (),
41                Err(ref e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
42                    // Client disconnected
43                    break;
44                }
45                Err(e) => return Err(e.into()),
46            }
47
48            let len = u32::from_le_bytes(len_bytes) as usize;
49            const MAX_FRAME_SIZE: usize = 8 * 1024 * 1024; // 8 MiB
50            if len > MAX_FRAME_SIZE {
51                return Err(AqlError::new(
52                    crate::error::ErrorCode::ProtocolError,
53                    format!("Frame too large: {} bytes", len),
54                ));
55            }
56            let mut buffer = vec![0u8; len];
57            stream.read_exact(&mut buffer).await?;
58
59            let request: Request =
60                bincode::deserialize(&buffer).map_err(AqlError::from)?;
61
62            let response = db.process_network_request(request).await;
63
64            let response_bytes = bincode::serialize(&response).map_err(AqlError::from)?;
65            let len_bytes = (response_bytes.len() as u32).to_le_bytes();
66
67            stream.write_all(&len_bytes).await?;
68            stream.write_all(&response_bytes).await?;
69        }
70        Ok(())
71    }
72}