use crate::{error, request, response};
use http_types::headers;
use std::{cell, collections, rc, time};
#[cfg(not(feature = "picodata_tarantool"))]
use tarantool::{
fiber::{
self,
r#async::timeout::{self, IntoTimeout},
},
network::client::tcp,
time as ttime,
};
#[cfg(feature = "picodata_tarantool")]
use picodata_tarantool::system::tarantool::{
fiber::{
self,
r#async::timeout::{self, IntoTimeout},
},
network::client::tcp,
time as ttime,
};
/// Idle connections of a single `(host, port)` container, guarded by a `fiber::Mutex`.
type IdleConns = fiber::Mutex<Vec<Inner>>;
/// The pool map, guarded by a `fiber::Mutex`: host -> port -> shared per-endpoint container.
type MappedPool =
    fiber::Mutex<collections::HashMap<String, collections::HashMap<u16, rc::Rc<Container>>>>;
/// Acquirers waiting on a full container, guarded by a `fiber::Mutex`. Each waiter owns one
/// `fiber::Channel` through which a released connection can be handed to it.
type WaitQueue = fiber::Mutex<collections::LinkedList<fiber::Channel<Inner>>>;
#[derive(Debug)]
struct Inner {
host: String,
port: u16,
ttl: time::Duration,
stream: tcp::TcpStream,
created: ttime::Instant,
connect_timeout: time::Duration,
}
impl Inner {
fn try_new(
host: String,
port: u16,
ttl: time::Duration,
connect_timeout: time::Duration,
) -> Result<Self, Box<error::Error>> {
let created = ttime::Instant::now_fiber();
let stream = tcp::TcpStream::connect_timeout(&host, port, connect_timeout)
.map_err(|e| Box::new(error::Error::TCP(e)))?;
Ok(Self {
host,
port,
ttl,
stream,
created,
connect_timeout,
})
}
fn recreate(self) -> Result<Self, Box<error::Error>> {
drop(self.stream);
Self::try_new(self.host, self.port, self.ttl, self.connect_timeout)
}
fn is_expired(&self) -> bool {
self.created.elapsed() > self.ttl
}
async fn execute(
&self,
request: request::Request,
) -> Result<response::Response, Box<error::Error>> {
let (mut request, headers, body, tls_timeout, request_timeout, response_timeout) =
request.pieces();
for (key, value) in headers {
request.insert_header(key, &value);
}
if let Some(b) = body {
request.set_body(b);
}
let url = request.url().to_owned();
let stream = self.stream.clone();
let response = if url.scheme() == "https" {
let stream = async_native_tls::connect(&self.host, stream)
.timeout(tls_timeout)
.await
.map_err(|x| match x {
timeout::Error::Failed(e) => Box::new(error::Error::TLS(e)),
timeout::Error::Expired => Box::new(error::Error::Timeout),
})?;
async_h1::connect(stream, request)
.timeout(request_timeout)
.await
} else {
async_h1::connect(stream, request)
.timeout(request_timeout)
.await
}
.map_err(|x| match x {
timeout::Error::Failed(e) => Box::new(error::Error::HTTP(e)),
timeout::Error::Expired => Box::new(error::Error::Timeout),
})?;
Ok(response::Response::new(url, response, response_timeout))
}
}
#[derive(Debug)]
struct Container {
host: String,
port: u16,
max_size: usize,
current_size: cell::Cell<usize>,
conn_ttl: time::Duration,
connect_timeout: time::Duration,
acquire_timeout: time::Duration,
idle_conns: IdleConns,
wait_queue: WaitQueue,
}
impl Container {
fn new(
host: String,
port: u16,
max_size: usize,
conn_ttl: time::Duration,
connect_timeout: time::Duration,
acquire_timeout: time::Duration,
) -> Self {
Self {
host,
port,
max_size,
conn_ttl,
current_size: cell::Cell::new(0),
connect_timeout,
acquire_timeout,
idle_conns: fiber::Mutex::new(Vec::with_capacity(max_size)),
wait_queue: fiber::Mutex::new(collections::LinkedList::new()),
}
}
fn acquire(&self) -> Result<Inner, Box<error::Error>> {
let mut guard = self.idle_conns.lock();
if let Some(mut v) = guard.pop() {
if v.is_expired() {
v = v.recreate()?;
}
return Ok(v);
}
if self.current_size.get() < self.max_size {
self.current_size.set(self.current_size.get() + 1);
let v = match Inner::try_new(
self.host.clone(),
self.port,
self.conn_ttl,
self.connect_timeout,
) {
Ok(v) => v,
Err(e) => {
self.current_size.set(self.current_size.get() - 1);
return Err(e);
}
};
return Ok(v);
}
drop(guard);
let mut guard = self.wait_queue.lock();
let chan = fiber::Channel::new(1);
guard.push_back(fiber::Channel::clone(&chan));
drop(guard);
let Ok(mut inner) = chan.recv_timeout(self.acquire_timeout) else {
return Err(Box::new(error::Error::Timeout));
};
if inner.is_expired() {
inner = inner.recreate()?;
}
Ok(inner)
}
fn close(&self, inner: Inner) {
drop(inner);
self.current_size.set(self.current_size.get() - 1);
}
fn release(&self, mut inner: Inner) {
let mut lock = self.wait_queue.lock();
while let Some(chan) = lock.pop_front() {
inner = match chan.send(inner) {
Ok(()) => return,
Err(v) => v,
}
}
self.idle_conns.lock().push(inner);
}
}
/// A connection checked out of a [`Container`]. On drop it is handed back to the container
/// when the last exchange left it reusable, and closed otherwise.
#[derive(Debug)]
pub(crate) struct Connection {
    /// Whether the most recent exchange left the stream safe to hand to another request.
    is_reusable: bool,
    /// Always `Some` until `Drop` takes it.
    inner: Option<Inner>,
    /// The container this connection came from and is returned to.
    container: rc::Rc<Container>,
}
impl Connection {
    fn new(inner: Inner, container: rc::Rc<Container>) -> Self {
        Self {
            container,
            inner: Some(inner),
            is_reusable: false,
        }
    }

    /// Runs `request` on this connection, blocking the current fiber until it completes.
    ///
    /// After every call `is_reusable` reflects only the latest exchange: any error marks
    /// the connection as not reusable, because a timeout or TLS/HTTP failure can leave
    /// unread or half-written bytes on the stream.
    ///
    /// # Errors
    /// Propagates the error from `Inner::execute`.
    pub(crate) fn execute(
        &mut self,
        request: request::Request,
    ) -> Result<response::Response, Box<error::Error>> {
        let inner = self
            .inner
            .as_ref()
            .expect("`Connection::inner` is only taken in `Drop`");
        let response = match fiber::block_on(inner.execute(request)) {
            Ok(v) => v,
            Err(e) => {
                self.is_reusable = false;
                return Err(e);
            }
        };
        self.is_reusable = Self::allows_reuse(&response);
        Ok(response)
    }

    /// Decides from the response's `Connection` header whether the stream may be reused.
    ///
    /// The header value is read as a comma-separated, case-insensitive token list. A
    /// `close` token always forbids reuse. Otherwise reuse is allowed for HTTP/1.1, or when
    /// `keep-alive` is present. A response without the header is conservatively treated as
    /// not reusable, even though HTTP/1.1 defaults to persistent connections.
    fn allows_reuse(response: &response::Response) -> bool {
        let Some(value) = response
            .headers()
            .get(&headers::CONNECTION)
            .map(|header| header.as_str().to_lowercase())
        else {
            return false;
        };
        let has_token = |token: &str| value.split(',').any(|t| t.trim() == token);
        if has_token("close") {
            return false;
        }
        response.version() == Some(http_types::Version::Http1_1) || has_token("keep-alive")
    }
}
impl Drop for Connection {
    fn drop(&mut self) {
        let Some(inner) = self.inner.take() else {
            return;
        };
        if self.is_reusable {
            self.container.release(inner);
        } else {
            self.container.close(inner);
        }
    }
}
/// Connection pool keyed by host and port. Each endpoint gets its own lazily created
/// [`Container`] configured with the settings below.
#[derive(Debug)]
pub(crate) struct Pool {
    /// Maximum number of connections per `(host, port)`.
    max_conns: usize,
    /// TTL of each pooled connection.
    conn_ttl: time::Duration,
    /// Connect timeout used when dialling.
    connect_timeout: time::Duration,
    /// How long `get` may wait for a connection when an endpoint is saturated.
    acquire_timeout: time::Duration,
    /// host -> port -> container map.
    inner: MappedPool,
}
impl Pool {
    pub(crate) fn new(
        max_conns: usize,
        conn_ttl: time::Duration,
        connect_timeout: time::Duration,
        acquire_timeout: time::Duration,
    ) -> Self {
        Self {
            max_conns,
            conn_ttl,
            connect_timeout,
            acquire_timeout,
            inner: fiber::Mutex::new(collections::HashMap::with_capacity(16)),
        }
    }

    /// Checks out a connection to `host:port`, creating the endpoint's container on first
    /// use.
    ///
    /// # Errors
    /// Whatever `Container::acquire` returns (connect failure or acquire timeout).
    pub(crate) fn get(&self, host: &str, port: u16) -> Result<Connection, Box<error::Error>> {
        let container = {
            let mut hosts = self.inner.lock();
            // Look up by `&str` first so the host `String` is only allocated on first use.
            if !hosts.contains_key(host) {
                hosts.insert(host.to_owned(), collections::HashMap::with_capacity(16));
            }
            let ports = hosts
                .get_mut(host)
                .expect("host entry was inserted just above");
            let slot = ports.entry(port).or_insert_with(|| {
                rc::Rc::new(Container::new(
                    host.to_owned(),
                    port,
                    self.max_conns,
                    self.conn_ttl,
                    self.connect_timeout,
                    self.acquire_timeout,
                ))
            });
            rc::Rc::clone(slot)
        };
        // The map lock is released before acquiring. `acquire` may dial or wait for up to
        // `acquire_timeout`, and holding the pool-wide lock that long would stall every other
        // endpoint. It would also deadlock a fiber that still holds a connection while
        // requesting another.
        let inner = container.acquire()?;
        Ok(Connection::new(inner, container))
    }
}