use std::io;
use kevy_resp::Reply;
use crate::codec::AsyncRespCodec;
use crate::url::parse_url;
// The concrete TCP stream type used by `AsyncConnection`, chosen by the enabled
// async-runtime Cargo feature. Each alias below is gated on a different feature,
// so exactly one of `tokio`, `smol`, or `async-std` must be enabled: enabling
// none leaves `DefaultTransport` undefined, and enabling several produces
// conflicting definitions of the same name.
#[cfg(feature = "tokio")]
type DefaultTransport = tokio::net::TcpStream;
#[cfg(feature = "smol")]
type DefaultTransport = smol::net::TcpStream;
#[cfg(feature = "async-std")]
type DefaultTransport = async_std::net::TcpStream;
// Opens a `DefaultTransport` to `host:port` by delegating to the connect helper
// of whichever runtime feature is enabled. As with `DefaultTransport`, one
// definition exists per feature; the same "exactly one runtime feature"
// requirement applies. Any I/O error from the runtime helper is propagated
// unchanged.
#[cfg(feature = "tokio")]
async fn connect_default(host: &str, port: u16) -> io::Result<DefaultTransport> {
    crate::rt_tokio::connect(host, port).await
}
#[cfg(feature = "smol")]
async fn connect_default(host: &str, port: u16) -> io::Result<DefaultTransport> {
    crate::rt_smol::connect(host, port).await
}
#[cfg(feature = "async-std")]
async fn connect_default(host: &str, port: u16) -> io::Result<DefaultTransport> {
    crate::rt_async_std::connect(host, port).await
}
/// An async client connection: an `AsyncRespCodec` wrapped around the TCP
/// stream type of the enabled async-runtime feature.
///
/// Construct one with `AsyncConnection::open` (from a URL) or
/// `AsyncConnection::from_transport` (from an already-connected stream).
pub struct AsyncConnection {
    // Framing/request layer over the underlying transport; exposed to callers
    // through `codec_mut`.
    codec: AsyncRespCodec<DefaultTransport>,
}
impl AsyncConnection {
pub async fn open(url: &str) -> io::Result<Self> {
let parsed = parse_url(url)?;
let transport = connect_default(&parsed.host, parsed.port).await?;
let mut codec = AsyncRespCodec::new(transport);
if let Some(db) = parsed.db {
let reply = codec
.request(&[b"SELECT".to_vec(), db.to_string().into_bytes()])
.await?;
if let Reply::Error(msg) = reply {
let text = String::from_utf8_lossy(&msg);
return Err(io::Error::other(format!("SELECT {db} rejected: {text}")));
}
}
Ok(Self { codec })
}
pub fn from_transport(transport: DefaultTransport) -> Self {
Self {
codec: AsyncRespCodec::new(transport),
}
}
pub async fn ping(&mut self) -> io::Result<()> {
let reply = self.codec.request(&[b"PING".to_vec()]).await?;
expect_pong(reply)
}
pub fn codec_mut(&mut self) -> &mut AsyncRespCodec<DefaultTransport> {
&mut self.codec
}
}
fn expect_pong(reply: Reply) -> io::Result<()> {
match reply {
Reply::Simple(s) if s == b"PONG" => Ok(()),
Reply::Bulk(s) if s == b"PONG" => Ok(()),
Reply::Error(msg) => Err(io::Error::other(format!(
"PING failed: {}",
String::from_utf8_lossy(&msg)
))),
other => Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("PING returned unexpected reply: {other:?}"),
)),
}
}