use std::sync::RwLock;
use bytes::{Bytes, BytesMut};
use redust_resp::{from_data, Data};
use tracing::instrument;
use crate::{Connection, Result};
/// A script together with a locally cached copy of the hash returned when it
/// was last loaded through [`Script::load`].
///
/// The cache lives behind a lock so it can be updated through `&self`.
#[derive(Debug)]
pub struct Script {
/// Source bytes of the script, sent verbatim by [`Script::load`].
contents: Bytes,
/// Hash returned by the most recent successful load; empty until then.
hash: RwLock<BytesMut>,
}
impl Script {
    /// Creates a script from its static source bytes.
    ///
    /// The cached hash starts out empty, so [`Script::is_loaded`] returns
    /// `false` until a load has stored one.
    pub fn new(contents: &'static [u8]) -> Self {
        Self {
            contents: Bytes::from_static(contents),
            hash: Default::default(),
        }
    }

    /// Starts an [`Invocation`] of this script on `connection`, with no keys
    /// and no arguments.
    ///
    /// This only builds the value; nothing is sent until
    /// [`Invocation::invoke`] is awaited.
    pub fn exec<'script, 'conn>(
        &'script self,
        connection: &'conn mut Connection,
        // The elided `'_` takes the receiver's lifetime (`'script`) under the
        // method lifetime-elision rule.
    ) -> Invocation<'script, 'conn, '_> {
        Invocation {
            connection,
            script: self,
            args: Vec::new(),
            keys: Vec::new(),
        }
    }

    /// Returns `true` once a hash has been cached by a previous load.
    ///
    /// This inspects the local cache only; it does not contact the server.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn is_loaded(&self) -> bool {
        !self.hash.read().unwrap().is_empty()
    }

    /// Overwrites the cached hash with `new`.
    ///
    /// Panics if the internal lock is poisoned.
    fn set_hash(&self, new: &[u8]) {
        let mut hash = self.hash.write().unwrap();
        hash.clear();
        hash.extend_from_slice(new);
    }

    /// Sends `script load <contents>` over `connection`, caches the returned
    /// hash and returns it.
    ///
    /// # Errors
    ///
    /// Propagates any error from sending the command or from decoding the
    /// reply as a byte buffer.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    #[instrument(level = "debug")]
    pub async fn load(&self, connection: &mut Connection) -> Result<Bytes> {
        let res = connection
            .cmd([b"script".as_slice(), b"load", &*self.contents])
            .await?;
        // Take ownership of the decoded buffer instead of re-collecting it
        // one byte at a time.
        let hash = Bytes::from(from_data::<serde_bytes::ByteBuf>(res)?.into_vec());
        self.set_hash(&hash);
        Ok(hash)
    }

    /// Returns the cached hash, loading the script over `connection` first if
    /// no hash has been cached yet.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Script::load`] when a load is required.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    #[instrument(level = "trace")]
    pub async fn get_hash(&self, connection: &mut Connection) -> Result<Bytes> {
        // The read guard is a temporary dropped at the end of this statement,
        // so the lock is never held across the `.await` below.
        let hash = self.hash.read().unwrap().clone();
        if hash.is_empty() {
            self.load(connection).await
        } else {
            Ok(hash.freeze())
        }
    }
}
/// A pending call of a [`Script`] on a connection, built up with keys and
/// arguments and finally sent by [`Invocation::invoke`].
///
/// Keys and arguments are borrowed for `'data`; nothing is copied while the
/// invocation is being assembled.
#[derive(Debug)]
pub struct Invocation<'script, 'conn, 'data> {
/// Connection the command will be sent on.
connection: &'conn mut Connection,
/// Script whose hash is used for the call.
script: &'script Script,
/// Arguments, placed after the keys in the command.
args: Vec<&'data [u8]>,
/// Keys, placed right after the key count in the command.
keys: Vec<&'data [u8]>,
}
impl<'data> Invocation<'_, '_, 'data> {
    /// Sets the argument list to `args`.
    ///
    /// This *replaces* any arguments added earlier (including ones pushed via
    /// [`Invocation::arg`]); it does not append.
    pub fn args<I, B>(mut self, args: I) -> Self
    where
        I: IntoIterator<Item = &'data B>,
        B: 'data + AsRef<[u8]> + ?Sized,
    {
        self.args = args.into_iter().map(|b| b.as_ref()).collect();
        self
    }

    /// Sets the key list to `keys`.
    ///
    /// This *replaces* any keys added earlier (including ones pushed via
    /// [`Invocation::key`]); it does not append.
    pub fn keys<I, B>(mut self, keys: I) -> Self
    where
        I: IntoIterator<Item = &'data B>,
        B: 'data + AsRef<[u8]> + ?Sized,
    {
        self.keys = keys.into_iter().map(|b| b.as_ref()).collect();
        self
    }

    /// Appends a single argument.
    ///
    /// `B` may be unsized, so `&str` and `&[u8]` can be passed directly, the
    /// same as the items accepted by [`Invocation::args`].
    pub fn arg<B>(mut self, arg: &'data B) -> Self
    where
        B: AsRef<[u8]> + ?Sized,
    {
        self.args.push(arg.as_ref());
        self
    }

    /// Appends a single key.
    ///
    /// `B` may be unsized, so `&str` and `&[u8]` can be passed directly, the
    /// same as the items accepted by [`Invocation::keys`].
    pub fn key<B>(mut self, key: &'data B) -> Self
    where
        B: AsRef<[u8]> + ?Sized,
    {
        self.keys.push(key.as_ref());
        self
    }

    /// Sends `evalsha <hash> <key count> <keys...> <args...>` on the
    /// connection and returns the reply.
    ///
    /// The hash comes from [`Script::get_hash`], which loads the script first
    /// when no hash is cached locally.
    ///
    /// # Errors
    ///
    /// Propagates any error from obtaining the hash or from sending the
    /// command.
    // NOTE(review): there is no retry here if the server rejects a cached
    // hash — confirm whether callers are expected to reload in that case.
    #[instrument(level = "debug")]
    pub async fn invoke(self) -> Result<Data<'static>> {
        let hash = self.script.get_hash(self.connection).await?;
        let key_len = self.keys.len().to_string().into_bytes();

        // Annotated so `extend` below resolves unambiguously to the by-value impl.
        let mut cmd: Vec<&[u8]> = Vec::with_capacity(3 + self.keys.len() + self.args.len());
        cmd.extend([b"evalsha".as_slice(), &hash[..], key_len.as_slice()]);
        cmd.extend_from_slice(&self.keys);
        cmd.extend_from_slice(&self.args);

        self.connection.cmd(cmd).await
    }
}