use super::handler::OracleProtocolHandler;
use super::tns::TnsPacket;
use super::DEFAULT_ORACLE_PORT;
use crate::{Result, Error, storage::StorageEngine};
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::Semaphore;
#[derive(Debug, Clone)]
pub struct OracleServerConfig {
pub listen_addr: String,
pub port: u16,
pub max_connections: usize,
}
impl Default for OracleServerConfig {
fn default() -> Self {
Self {
listen_addr: "127.0.0.1".to_string(),
port: DEFAULT_ORACLE_PORT,
max_connections: 100,
}
}
}
pub struct OracleServer {
config: OracleServerConfig,
storage: Arc<StorageEngine>,
connection_limiter: Arc<Semaphore>,
}
impl OracleServer {
pub fn new(storage: Arc<StorageEngine>, config: OracleServerConfig) -> Self {
let connection_limiter = Arc::new(Semaphore::new(config.max_connections));
Self {
config,
storage,
connection_limiter,
}
}
pub async fn start(self) -> Result<()> {
let addr = format!("{}:{}", self.config.listen_addr, self.config.port);
tracing::info!("Starting Oracle TNS server on {}", addr);
let listener = TcpListener::bind(&addr)
.await
.map_err(|e| Error::io(format!("Failed to bind to {}: {}", addr, e)))?;
tracing::info!("Oracle TNS server listening on {}", addr);
loop {
let permit = self.connection_limiter.clone().acquire_owned().await
.map_err(|e| Error::io(format!("Semaphore error: {}", e)))?;
let (socket, peer_addr) = match listener.accept().await {
Ok((socket, addr)) => (socket, addr),
Err(e) => {
tracing::error!("Failed to accept connection: {}", e);
continue;
}
};
tracing::info!("New Oracle connection from {}", peer_addr);
let storage = self.storage.clone();
tokio::spawn(async move {
let result = handle_connection(socket, storage).await;
if let Err(e) = result {
tracing::error!("Connection error from {}: {}", peer_addr, e);
}
drop(permit);
tracing::info!("Connection closed from {}", peer_addr);
});
}
}
pub fn config(&self) -> &OracleServerConfig {
&self.config
}
}
async fn handle_connection(mut socket: TcpStream, storage: Arc<StorageEngine>) -> Result<()> {
let mut handler = OracleProtocolHandler::new(storage);
let mut buffer = vec![0u8; 65536];
loop {
let n = socket.read(&mut buffer).await
.map_err(|e| Error::io(format!("Failed to read from socket: {}", e)))?;
if n == 0 {
break;
}
tracing::debug!("Received {} bytes from client", n);
let recv_data = buffer.get(..n).ok_or_else(|| Error::io("Buffer read out of bounds"))?;
let packet = match TnsPacket::parse(recv_data) {
Ok(pkt) => pkt,
Err(e) => {
tracing::error!("Failed to parse TNS packet: {}", e);
continue;
}
};
tracing::debug!("Received TNS packet: type={:?}", packet.header.packet_type);
let response_packets = match handler.handle_packet(packet) {
Ok(packets) => packets,
Err(e) => {
tracing::error!("Handler error: {}", e);
let error_msg = format!("Error: {}", e);
vec![TnsPacket::refuse(1, error_msg)]
}
};
for response in response_packets {
let encoded = response.encode();
socket.write_all(&encoded).await
.map_err(|e| Error::io(format!("Failed to write to socket: {}", e)))?;
tracing::debug!("Sent {} bytes to client", encoded.len());
}
if handler.is_closed() {
break;
}
}
Ok(())
}
pub async fn start_oracle_server(
storage: Arc<StorageEngine>,
config: OracleServerConfig,
) -> Result<()> {
let server = OracleServer::new(storage, config);
server.start().await
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
use super::*;
use crate::Config;
#[test]
fn test_server_config_default() {
let config = OracleServerConfig::default();
assert_eq!(config.port, 1521);
assert_eq!(config.max_connections, 100);
}
#[test]
fn test_server_creation() {
let config = Config::in_memory();
let storage = StorageEngine::open_in_memory(&config).unwrap();
let server_config = OracleServerConfig::default();
let server = OracleServer::new(Arc::new(storage), server_config);
assert_eq!(server.config().port, 1521);
}
#[tokio::test]
async fn test_server_bind_error() {
let config = Config::in_memory();
let storage = StorageEngine::open_in_memory(&config).unwrap();
let server_config = OracleServerConfig {
listen_addr: "999.999.999.999".to_string(),
port: 1521,
max_connections: 10,
};
let server = OracleServer::new(Arc::new(storage), server_config);
let result = tokio::time::timeout(
std::time::Duration::from_millis(100),
server.start()
).await;
assert!(result.is_err() || result.unwrap().is_err());
}
}