shardcache-client-rs 0.3.2

Blocking Rust client for shardcache's native SCNP protocol
Documentation
use std::io::{BufReader, BufWriter, Read, Write};
use std::net::{TcpStream, ToSocketAddrs};

use crate::commands::ScnpCommand;
#[cfg(feature = "redis")]
use crate::commands::redis::RedisResponse;
use crate::error::{Result, ShardCacheClientError};
use crate::protocol::{
    FAST_PROTOCOL_VERSION, FAST_REQUEST_MAGIC, FAST_RESPONSE_MAGIC, STATUS_ERROR, STATUS_INTEGER,
    STATUS_NULL, STATUS_OK, STATUS_VALUE,
};
#[cfg(feature = "redis")]
use crate::protocol::{STATUS_ARRAY, STATUS_FLOAT};

/// One blocking TCP connection with separately buffered read and write halves.
///
/// Requests are written through `w`, responses are read through `r`, and `scratch` is a
/// reusable byte buffer for response bodies that are read and then parsed.
#[derive(Debug)]
pub(crate) struct ScnpConnection {
    /// Buffered reader over the socket; all response bytes are consumed through it.
    r: BufReader<TcpStream>,
    /// Buffered writer over a clone of the same socket. Crate-visible: command bodies are
    /// written straight into it, and nothing reaches the peer until it is flushed.
    pub(crate) w: BufWriter<TcpStream>,
    /// Reusable buffer for response bodies (error text, array payloads, RESP values).
    scratch: Vec<u8>,
}

impl ScnpConnection {
    pub(crate) fn connect(addr: impl ToSocketAddrs) -> Result<Self> {
        let s = TcpStream::connect(addr)?;
        s.set_nodelay(true)?;
        tune_tcp_stream_buffers(&s);
        let s2 = s.try_clone()?;
        Ok(Self {
            r: BufReader::with_capacity(64 * 1024, s),
            w: BufWriter::with_capacity(64 * 1024, s2),
            scratch: Vec::with_capacity(64),
        })
    }

    pub(crate) fn execute<C: ScnpCommand>(&mut self, command: C) -> Result<C::Output> {
        self.write_header(command.opcode(), command.flags(), command.body_len() as u32)?;
        command.write_body(&mut self.w)?;
        self.w.flush()?;
        command.read_response(self)
    }

    pub(crate) fn flush(&mut self) -> Result<()> {
        self.w.flush()?;
        Ok(())
    }

    pub(crate) fn expect_ok(&mut self, op: &str) -> Result<()> {
        let (status, body_len) = self.read_response_header()?;
        match status {
            STATUS_OK => {
                self.discard(body_len)?;
                Ok(())
            }
            STATUS_ERROR => Err(ShardCacheClientError::Protocol(self.read_error(body_len)?)),
            other => Err(ShardCacheClientError::Protocol(format!(
                "{op} unexpected response status: {other}"
            ))),
        }
    }

    pub(crate) fn read_value(&mut self, op: &str, out: &mut Vec<u8>) -> Result<bool> {
        out.clear();
        let (status, body_len) = self.read_response_header()?;
        match status {
            STATUS_VALUE => {
                out.resize(body_len, 0);
                self.r.read_exact(out.as_mut_slice())?;
                Ok(true)
            }
            STATUS_NULL => {
                self.discard(body_len)?;
                Ok(false)
            }
            STATUS_ERROR => Err(ShardCacheClientError::Protocol(self.read_error(body_len)?)),
            other => Err(ShardCacheClientError::Protocol(format!(
                "{op} unexpected response status: {other}"
            ))),
        }
    }

    pub(crate) fn read_integer(&mut self, op: &str) -> Result<i64> {
        let (status, body_len) = self.read_response_header()?;
        match status {
            STATUS_INTEGER => {
                if body_len != 8 {
                    self.discard(body_len)?;
                    return Err(ShardCacheClientError::Protocol(format!(
                        "{op} integer response body length was {body_len}, expected 8"
                    )));
                }
                let mut value = [0u8; 8];
                self.r.read_exact(&mut value)?;
                Ok(i64::from_le_bytes(value))
            }
            STATUS_ERROR => Err(ShardCacheClientError::Protocol(self.read_error(body_len)?)),
            other => Err(ShardCacheClientError::Protocol(format!(
                "{op} unexpected response status: {other}"
            ))),
        }
    }

    #[cfg(feature = "redis")]
    pub(crate) fn read_native_redis_response(&mut self, op: &str) -> Result<RedisResponse> {
        let (status, body_len) = self.read_response_header()?;
        match status {
            STATUS_OK => {
                self.discard(body_len)?;
                Ok(RedisResponse::Ok)
            }
            STATUS_NULL => {
                self.discard(body_len)?;
                Ok(RedisResponse::Null)
            }
            STATUS_ERROR => Err(ShardCacheClientError::Protocol(self.read_error(body_len)?)),
            STATUS_INTEGER => {
                if body_len != 8 {
                    self.discard(body_len)?;
                    return Err(ShardCacheClientError::Protocol(format!(
                        "{op} integer response body length was {body_len}, expected 8"
                    )));
                }
                let mut value = [0u8; 8];
                self.r.read_exact(&mut value)?;
                Ok(RedisResponse::Integer(i64::from_le_bytes(value)))
            }
            STATUS_VALUE => {
                let mut value = vec![0; body_len];
                self.r.read_exact(value.as_mut_slice())?;
                Ok(RedisResponse::Value(value))
            }
            STATUS_ARRAY => self.read_array_response(op, body_len),
            STATUS_FLOAT => {
                if body_len != 8 {
                    self.discard(body_len)?;
                    return Err(ShardCacheClientError::Protocol(format!(
                        "{op} float response body length was {body_len}, expected 8"
                    )));
                }
                let mut value = [0u8; 8];
                self.r.read_exact(&mut value)?;
                Ok(RedisResponse::Float(f64::from_le_bytes(value)))
            }
            other => Err(ShardCacheClientError::Protocol(format!(
                "{op} unexpected response status: {other}"
            ))),
        }
    }

    #[cfg(feature = "redis")]
    pub(crate) fn read_resp_redis_response(&mut self, op: &str) -> Result<RedisResponse> {
        let (status, body_len) = self.read_response_header()?;
        match status {
            STATUS_OK => {
                self.discard(body_len)?;
                Ok(RedisResponse::Ok)
            }
            STATUS_NULL => {
                self.discard(body_len)?;
                Ok(RedisResponse::Null)
            }
            STATUS_ERROR => Err(ShardCacheClientError::Protocol(self.read_error(body_len)?)),
            STATUS_INTEGER => {
                if body_len != 8 {
                    self.discard(body_len)?;
                    return Err(ShardCacheClientError::Protocol(format!(
                        "{op} integer response body length was {body_len}, expected 8"
                    )));
                }
                let mut value = [0u8; 8];
                self.r.read_exact(&mut value)?;
                Ok(RedisResponse::Integer(i64::from_le_bytes(value)))
            }
            STATUS_VALUE => {
                self.scratch.resize(body_len, 0);
                self.r.read_exact(&mut self.scratch[..body_len])?;
                RedisResponse::from_resp_bytes(&self.scratch[..body_len])
            }
            STATUS_ARRAY => self.read_array_response(op, body_len),
            STATUS_FLOAT => {
                if body_len != 8 {
                    self.discard(body_len)?;
                    return Err(ShardCacheClientError::Protocol(format!(
                        "{op} float response body length was {body_len}, expected 8"
                    )));
                }
                let mut value = [0u8; 8];
                self.r.read_exact(&mut value)?;
                Ok(RedisResponse::Float(f64::from_le_bytes(value)))
            }
            other => Err(ShardCacheClientError::Protocol(format!(
                "{op} unexpected response status: {other}"
            ))),
        }
    }

    pub(crate) fn write_header(&mut self, cmd: u8, flags: u8, body_len: u32) -> Result<()> {
        let header = [
            FAST_REQUEST_MAGIC,
            FAST_PROTOCOL_VERSION,
            cmd,
            flags,
            body_len as u8,
            (body_len >> 8) as u8,
            (body_len >> 16) as u8,
            (body_len >> 24) as u8,
        ];
        self.w.write_all(&header)?;
        Ok(())
    }

    fn read_response_header(&mut self) -> Result<(u8, usize)> {
        let mut header = [0u8; 8];
        self.r.read_exact(&mut header)?;
        if header[0] != FAST_RESPONSE_MAGIC {
            return Err(ShardCacheClientError::Protocol(format!(
                "bad response magic: 0x{:02X}",
                header[0]
            )));
        }
        if header[1] != FAST_PROTOCOL_VERSION {
            return Err(ShardCacheClientError::Protocol(format!(
                "bad response version: {}",
                header[1]
            )));
        }
        let status = header[2];
        let body_len = u32::from_le_bytes([header[4], header[5], header[6], header[7]]) as usize;
        Ok((status, body_len))
    }

    #[cfg(feature = "redis")]
    fn read_array_response(&mut self, op: &str, body_len: usize) -> Result<RedisResponse> {
        self.scratch.resize(body_len, 0);
        self.r.read_exact(&mut self.scratch[..body_len])?;
        if body_len < 4 {
            return Err(ShardCacheClientError::Protocol(format!(
                "{op} array response body length was {body_len}, expected at least 4"
            )));
        }

        let mut cursor = 0usize;
        let count = read_u32(&self.scratch, &mut cursor, op)? as usize;
        let mut values = Vec::with_capacity(count);
        for _ in 0..count {
            let len = read_u32(&self.scratch, &mut cursor, op)?;
            if len == u32::MAX {
                values.push(RedisResponse::Null);
                continue;
            }
            let len = len as usize;
            let end = cursor.checked_add(len).ok_or_else(|| {
                ShardCacheClientError::Protocol(format!("{op} array item length overflow"))
            })?;
            if end > self.scratch.len() {
                return Err(ShardCacheClientError::Protocol(format!(
                    "{op} array item exceeds response body"
                )));
            }
            values.push(RedisResponse::Value(self.scratch[cursor..end].to_vec()));
            cursor = end;
        }
        if cursor != self.scratch.len() {
            return Err(ShardCacheClientError::Protocol(format!(
                "{op} array response has trailing bytes"
            )));
        }
        Ok(RedisResponse::Array(values))
    }

    fn discard(&mut self, n: usize) -> Result<()> {
        if n == 0 {
            return Ok(());
        }
        self.scratch.resize(n, 0);
        self.r.read_exact(&mut self.scratch[..n])?;
        Ok(())
    }

    fn read_error(&mut self, body_len: usize) -> Result<String> {
        self.scratch.resize(body_len, 0);
        self.r.read_exact(&mut self.scratch[..body_len])?;
        Ok(String::from_utf8_lossy(&self.scratch[..body_len]).into_owned())
    }
}

/// Returns the socket buffer size requested through the environment, if any.
///
/// `SCNP_CLIENT_TCP_BUFFER_BYTES` is consulted first; `SHARDCACHE_TCP_BUFFER_BYTES` is used
/// only when the first variable cannot be read. The chosen value must parse as a positive
/// `usize`, otherwise the result is `None`. The lookup runs once per process and is cached.
#[cfg(unix)]
fn configured_tcp_buffer_bytes() -> Option<usize> {
    static CACHED: std::sync::OnceLock<Option<usize>> = std::sync::OnceLock::new();

    /// Performs the uncached environment lookup.
    fn lookup() -> Option<usize> {
        let raw = match std::env::var("SCNP_CLIENT_TCP_BUFFER_BYTES") {
            Ok(text) => text,
            Err(_) => std::env::var("SHARDCACHE_TCP_BUFFER_BYTES").ok()?,
        };
        match raw.parse::<usize>() {
            Ok(bytes) if bytes > 0 => Some(bytes),
            _ => None,
        }
    }

    *CACHED.get_or_init(lookup)
}

/// Applies the configured buffer size to `stream`'s send and receive buffers.
///
/// Does nothing when no size is configured or when the size does not fit in a C `int`.
/// Failures reported by `setsockopt` are deliberately ignored: tuning is best-effort.
#[cfg(unix)]
fn tune_tcp_stream_buffers(stream: &TcpStream) {
    use std::os::fd::AsRawFd;

    let requested = match configured_tcp_buffer_bytes() {
        Some(bytes) => match libc::c_int::try_from(bytes) {
            Ok(size) => size,
            Err(_) => return,
        },
        None => return,
    };

    let fd = stream.as_raw_fd();
    let optval: *const libc::c_void = (&requested as *const libc::c_int).cast();
    let optlen = std::mem::size_of_val(&requested) as libc::socklen_t;

    for option in [libc::SO_SNDBUF, libc::SO_RCVBUF] {
        // SAFETY: `fd` belongs to `stream`, which is borrowed for the whole call, and `optval`
        // points at the live local `requested` whose exact size is passed as `optlen`.
        unsafe {
            let _ = libc::setsockopt(fd, libc::SOL_SOCKET, option, optval, optlen);
        }
    }
}

/// Non-Unix stand-in for the socket buffer tuning hook: intentionally does nothing.
#[cfg(not(unix))]
fn tune_tcp_stream_buffers(_stream: &TcpStream) {}

/// Decodes a little-endian `u32` from `buf` at `*cursor` and advances the cursor by four.
///
/// `op` only labels the error message. The cursor is left untouched on failure.
///
/// # Errors
///
/// Returns a `Protocol` error if `*cursor + 4` overflows or reaches past the end of `buf`.
#[cfg(feature = "redis")]
fn read_u32(buf: &[u8], cursor: &mut usize, op: &str) -> Result<u32> {
    let start = *cursor;
    let Some(end) = start.checked_add(4) else {
        return Err(ShardCacheClientError::Protocol(format!("{op} array cursor overflow")));
    };
    let Some(bytes) = buf.get(start..end) else {
        return Err(ShardCacheClientError::Protocol(format!(
            "{op} truncated array response"
        )));
    };
    // `bytes` is exactly four bytes long because it was sliced as `start..start + 4`.
    let mut word = [0u8; 4];
    word.copy_from_slice(bytes);
    *cursor = end;
    Ok(u32::from_le_bytes(word))
}