use bytes::BytesMut;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::sync::{Mutex, Notify};
use tracing::{debug, error};
use crate::error::{RbkError, RbkResult};
use crate::protocol::{RbkDecoder, encode_request};
pub(crate) struct RbkPortClient {
host: String,
port: u16,
state: Arc<Mutex<ClientState>>,
}
struct ClientState {
connection: Option<Connection>,
flow_no_counter: u16,
response_map: HashMap<u16, String>,
notify: Arc<Notify>,
disposed: bool,
}
struct Connection {
stream: TcpStream,
read_task: tokio::task::JoinHandle<()>,
}
impl RbkPortClient {
pub fn new(host: String, port: u16) -> Self {
Self {
host,
port,
state: Arc::new(Mutex::new(ClientState {
connection: None,
flow_no_counter: 0,
response_map: HashMap::new(),
notify: Arc::new(Notify::new()),
disposed: false,
})),
}
}
pub async fn request(
&self,
api_no: u16,
req_str: &str,
timeout: Duration,
) -> RbkResult<String> {
let result = self.do_request(api_no, req_str, timeout).await;
if let Err(ref e) = result {
debug!(
"Request failed (API {}), resetting client: {:?}",
api_no, e
);
self.reset().await;
}
result
}
async fn do_request(
&self,
api_no: u16,
req_str: &str,
timeout: Duration,
) -> RbkResult<String> {
let mut state = self.state.lock().await;
if state.disposed {
return Err(RbkError::Disposed);
}
if state.connection.is_none() {
drop(state);
self.connect().await?;
state = self.state.lock().await;
}
let flow_no = state.next_flow_no();
let notify = state.notify.clone();
let request_bytes = encode_request(api_no, req_str, flow_no);
if let Some(ref mut conn) = state.connection {
conn.stream.write_all(&request_bytes).await.map_err(|e| {
error!("Write error for API {}: {}", api_no, e.kind());
RbkError::WriteError(e.to_string())
})?;
}
drop(state);
tokio::time::timeout(timeout, async {
loop {
notify.notified().await;
let mut state = self.state.lock().await;
if state.disposed {
return Err(RbkError::Disposed);
}
if let Some(res_str) = state.response_map.remove(&flow_no) {
return Ok(res_str);
}
}
})
.await
.map_err(|_| RbkError::Timeout)?
}
async fn connect(&self) -> RbkResult<()> {
let addr = format!("{}:{}", self.host, self.port);
let stream = tokio::time::timeout(
Duration::from_secs(10),
TcpStream::connect(&addr),
)
.await
.map_err(|_| RbkError::Timeout)?
.map_err(|e| RbkError::ConnectionFailed(e.to_string()))?;
let state_clone = self.state.clone();
let read_task = tokio::spawn(async move {
read_loop(state_clone).await;
});
let mut state = self.state.lock().await;
state.connection = Some(Connection { stream, read_task });
state.disposed = false;
Ok(())
}
async fn reset(&self) {
let mut state = self.state.lock().await;
state.response_map.clear();
state.disposed = true;
if let Some(mut conn) = state.connection.take() {
conn.read_task.abort();
let _ = conn.stream.shutdown().await;
}
state.notify.notify_waiters();
}
}
impl Drop for RbkPortClient {
fn drop(&mut self) {
}
}
impl ClientState {
fn next_flow_no(&mut self) -> u16 {
self.flow_no_counter = (self.flow_no_counter + 1) % 512;
self.flow_no_counter
}
}
async fn read_loop(state: Arc<Mutex<ClientState>>) {
let mut decoder = RbkDecoder::new();
let mut buf = BytesMut::with_capacity(4096);
let mut read_buf = vec![0u8; 4096];
loop {
let mut state_lock = state.lock().await;
let has_connection = state_lock.connection.is_some();
if !has_connection {
break;
}
let mut conn = match state_lock.connection.take() {
Some(c) => c,
None => break,
};
drop(state_lock);
match conn.stream.read(&mut read_buf).await {
Ok(0) => {
break;
}
Ok(n) => {
buf.extend_from_slice(&read_buf[..n]);
while let Some(frame) = decoder.decode(&mut buf) {
let mut state = state.lock().await;
state.response_map.insert(frame.flow_no, frame.body);
state.notify.notify_waiters();
}
let mut state = state.lock().await;
state.connection = Some(conn);
}
Err(e) => {
error!("Read error: {}", e);
break;
}
}
}
}