Skip to main content

armour_rpc/
client.rs

1use std::ops::Bound;
2use std::path::Path;
3
4use compio::io::{
5    AsyncRead, AsyncWrite,
6    framed::{Framed, frame::LengthDelimited},
7};
8use compio::net::{OwnedReadHalf, OwnedWriteHalf, TcpStream, ToSocketAddrsAsync, UnixStream};
9use futures_util::{SinkExt, StreamExt};
10
11use crate::codec::ClientCodec;
12use crate::error::{Result, RpcError};
13use crate::protocol::*;
14
15pub struct RpcClient<R, W> {
16    hashname: u64,
17    framed: Framed<R, W, ClientCodec, LengthDelimited, Request, Response>,
18}
19
20fn hash_name(name: &str) -> u64 {
21    xxhash_rust::xxh3::xxh3_64(name.as_bytes())
22}
23
24impl RpcClient<OwnedReadHalf<TcpStream>, OwnedWriteHalf<TcpStream>> {
25    pub async fn connect_tcp(addr: impl ToSocketAddrsAsync, hashname: u64) -> Result<Self> {
26        let stream = TcpStream::connect(addr).await?;
27        let (reader, writer) = stream.into_split();
28        Ok(Self::new(reader, writer, hashname))
29    }
30
31    pub async fn connect_tcp_by_name(addr: impl ToSocketAddrsAsync, name: &str) -> Result<Self> {
32        Self::connect_tcp(addr, hash_name(name)).await
33    }
34}
35
36impl RpcClient<OwnedReadHalf<UnixStream>, OwnedWriteHalf<UnixStream>> {
37    pub async fn connect_uds(path: impl AsRef<Path>, hashname: u64) -> Result<Self> {
38        let stream = UnixStream::connect(path).await?;
39        let (reader, writer) = stream.into_split();
40        Ok(Self::new(reader, writer, hashname))
41    }
42
43    pub async fn connect_uds_by_name(path: impl AsRef<Path>, name: &str) -> Result<Self> {
44        Self::connect_uds(path, hash_name(name)).await
45    }
46}
47
48impl<R: AsyncRead + Unpin + 'static, W: AsyncWrite + Unpin + 'static> RpcClient<R, W> {
49    fn new(reader: R, writer: W, hashname: u64) -> Self {
50        let framed = Framed::new::<Request, Response>(ClientCodec, LengthDelimited::new())
51            .with_reader(reader)
52            .with_writer(writer);
53        Self { hashname, framed }
54    }
55
56    /// Send a request and return the raw Ok payload bytes, or an error.
57    async fn call(&mut self, op: OpCode, payload: RequestPayload) -> Result<Vec<u8>> {
58        let request = Request {
59            op,
60            hashname: self.hashname,
61            payload,
62        };
63        self.framed.send(request).await?;
64        let response = self.framed.next().await.ok_or(RpcError::UnexpectedEof)??;
65        match response {
66            Response::Ok(data) => Ok(data),
67            Response::Err { code, message } => Err(RpcError::Server { code, message }),
68        }
69    }
70
71    pub async fn get(&mut self, key: &[u8]) -> Result<Option<Vec<u8>>> {
72        let data = self
73            .call(OpCode::Get, RequestPayload::Key(key.to_vec()))
74            .await?;
75        decode_optional_data(&data)
76    }
77
78    pub async fn contains(&mut self, key: &[u8]) -> Result<Option<u32>> {
79        let data = self
80            .call(OpCode::Contains, RequestPayload::Key(key.to_vec()))
81            .await?;
82        decode_optional_len(&data)
83    }
84
85    pub async fn first(&mut self) -> Result<Option<(Vec<u8>, Vec<u8>)>> {
86        let data = self.call(OpCode::First, RequestPayload::Empty).await?;
87        decode_optional_kv(&data)
88    }
89
90    pub async fn last(&mut self) -> Result<Option<(Vec<u8>, Vec<u8>)>> {
91        let data = self.call(OpCode::Last, RequestPayload::Empty).await?;
92        decode_optional_kv(&data)
93    }
94
95    pub async fn range(
96        &mut self,
97        start: Bound<Vec<u8>>,
98        end: Bound<Vec<u8>>,
99    ) -> Result<Vec<(Vec<u8>, Vec<u8>)>> {
100        let data = self
101            .call(OpCode::Range, RequestPayload::Range { start, end })
102            .await?;
103        decode_key_values(&data)
104    }
105
106    pub async fn range_keys(
107        &mut self,
108        start: Bound<Vec<u8>>,
109        end: Bound<Vec<u8>>,
110    ) -> Result<Vec<Vec<u8>>> {
111        let data = self
112            .call(OpCode::RangeKeys, RequestPayload::Range { start, end })
113            .await?;
114        decode_keys(&data)
115    }
116
117    pub async fn upsert(
118        &mut self,
119        key: UpsertKey,
120        flag: Option<bool>,
121        value: Vec<u8>,
122    ) -> Result<Vec<u8>> {
123        let data = self
124            .call(OpCode::Upsert, RequestPayload::Upsert { key, flag, value })
125            .await?;
126        decode_key(&data)
127    }
128
129    pub async fn remove(&mut self, key: &[u8], soft: bool) -> Result<()> {
130        self.call(
131            OpCode::Remove,
132            RequestPayload::Remove {
133                key: key.to_vec(),
134                soft,
135            },
136        )
137        .await?;
138        Ok(())
139    }
140
141    pub async fn take(&mut self, key: &[u8], soft: bool) -> Result<Option<Vec<u8>>> {
142        let data = self
143            .call(
144                OpCode::Take,
145                RequestPayload::Take {
146                    key: key.to_vec(),
147                    soft,
148                },
149            )
150            .await?;
151        decode_optional_data(&data)
152    }
153
154    pub async fn count(&mut self, exact: bool) -> Result<u64> {
155        let data = self
156            .call(OpCode::Count, RequestPayload::Count { exact })
157            .await?;
158        decode_count(&data)
159    }
160
161    pub async fn apply_batch(&mut self, items: Vec<(Vec<u8>, Option<Vec<u8>>)>) -> Result<()> {
162        self.call(OpCode::ApplyBatch, RequestPayload::Batch(items))
163            .await?;
164        Ok(())
165    }
166}
167
168// --- Response payload decoders ---
169
170fn decode_optional_data(buf: &[u8]) -> Result<Option<Vec<u8>>> {
171    let mut pos = 0;
172    let flag = read_u8(buf, &mut pos)?;
173    if flag == 0 {
174        return Ok(None);
175    }
176    Ok(Some(read_bytes(buf, &mut pos)?))
177}
178
179fn decode_optional_len(buf: &[u8]) -> Result<Option<u32>> {
180    let mut pos = 0;
181    let flag = read_u8(buf, &mut pos)?;
182    if flag == 0 {
183        return Ok(None);
184    }
185    Ok(Some(read_u32_be(buf, &mut pos)?))
186}
187
188fn decode_optional_kv(buf: &[u8]) -> Result<Option<(Vec<u8>, Vec<u8>)>> {
189    let mut pos = 0;
190    let flag = read_u8(buf, &mut pos)?;
191    if flag == 0 {
192        return Ok(None);
193    }
194    let key = read_bytes(buf, &mut pos)?;
195    let val = read_bytes(buf, &mut pos)?;
196    Ok(Some((key, val)))
197}
198
199fn decode_key_values(buf: &[u8]) -> Result<Vec<(Vec<u8>, Vec<u8>)>> {
200    let mut pos = 0;
201    let count = read_u32_be(buf, &mut pos)? as usize;
202    let mut pairs = Vec::with_capacity(count);
203    for _ in 0..count {
204        let key = read_bytes(buf, &mut pos)?;
205        let val = read_bytes(buf, &mut pos)?;
206        pairs.push((key, val));
207    }
208    Ok(pairs)
209}
210
211fn decode_keys(buf: &[u8]) -> Result<Vec<Vec<u8>>> {
212    let mut pos = 0;
213    let count = read_u32_be(buf, &mut pos)? as usize;
214    let mut keys = Vec::with_capacity(count);
215    for _ in 0..count {
216        keys.push(read_bytes(buf, &mut pos)?);
217    }
218    Ok(keys)
219}
220
221fn decode_count(buf: &[u8]) -> Result<u64> {
222    let mut pos = 0;
223    read_u64_be(buf, &mut pos)
224}
225
226fn decode_key(buf: &[u8]) -> Result<Vec<u8>> {
227    let mut pos = 0;
228    read_bytes(buf, &mut pos)
229}