1#![warn(missing_docs)]
2
3use bytes::{Buf, BufMut, Bytes, BytesMut};
4use uuid::Uuid;
5use yykv_types::layout::{DsValueDecoder, DsValueEncoder};
6pub use yykv_types::{DsError, DsValue, Redundancy};
7
/// Identifies which SQL engine a [`DatabaseConnection`] is backed by.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DatabaseBackend {
    /// Limbo backend.
    Limbo,
    /// Postgres backend.
    Postgres,
    /// MySQL backend.
    MySql,
}
18
19pub type DatabaseResult<T> = Result<T, DsError>;
21
22#[async_trait::async_trait]
24pub trait DatabaseConnection {
25 fn backend(&self) -> DatabaseBackend;
27
28 async fn query(&self, sql: &str) -> DatabaseResult<Box<dyn RowIterator>>;
30}
31
32#[async_trait::async_trait]
34pub trait RowIterator: Send {
35 async fn next(&mut self) -> DatabaseResult<Option<Box<dyn Row>>>;
37}
38
39pub trait Row: Send {
41 fn get_string(&self, index: usize) -> DatabaseResult<String>;
43
44 fn get_i64(&self, index: usize) -> DatabaseResult<i64>;
46
47 fn get_bool(&self, index: usize) -> DatabaseResult<bool>;
49
50 fn get_option_string(&self, index: usize) -> DatabaseResult<Option<String>>;
52}
53
/// The `schema` submodule. Its contents are defined in its own source file.
pub mod schema;
56
57use crc32fast::Hasher;
58use futures::{SinkExt, StreamExt};
59use sha2::{Digest, Sha256};
60use std::net::SocketAddr;
61use std::str::FromStr;
62use tokio::net::TcpStream;
63use tokio_util::codec::Framed;
64
65#[derive(Debug, Clone)]
67pub struct ConnectionOptions {
68 pub addr: SocketAddr,
69 pub tenant_id: Uuid,
70 pub secret_key: Vec<u8>,
71}
72
73impl FromStr for ConnectionOptions {
74 type Err = DsError;
75
76 fn from_str(s: &str) -> Result<Self, Self::Err> {
77 let mut options = ConnectionOptions {
78 addr: "127.0.0.1:8889".parse().unwrap(),
79 tenant_id: Uuid::nil(),
80 secret_key: b"yykv-secret-key-2026".to_vec(),
81 };
82
83 for part in s.split(';') {
84 let kv: Vec<&str> = part.split('=').collect();
85 if kv.len() == 2 {
86 match kv[0].to_lowercase().as_str() {
87 "server" | "host" => {
88 let host = kv[1];
89 options.addr = format!("{}:8889", host)
90 .parse()
91 .map_err(|e| DsError::internal(format!("Invalid host: {}", e)))?;
92 }
93 "port" => {
94 let port: u16 = kv[1]
95 .parse()
96 .map_err(|e| DsError::internal(format!("Invalid port: {}", e)))?;
97 let mut addr = options.addr;
98 addr.set_port(port);
99 options.addr = addr;
100 }
101 "tenantid" => {
102 options.tenant_id = Uuid::parse_str(kv[1])
103 .map_err(|e| DsError::internal(format!("Invalid TenantID: {}", e)))?;
104 }
105 "secretkey" => {
106 options.secret_key = kv[1].as_bytes().to_vec();
107 }
108 _ => {}
109 }
110 }
111 }
112
113 Ok(options)
114 }
115}
116
/// Two-byte prefix (`"YY"`) that starts every encoded [`TrustHeader`].
pub const MAGIC: [u8; 2] = *b"YY";
119
/// Message type byte carried in [`TrustHeader::msg_type`].
///
/// The discriminants are the exact values written on the wire, so the enum is
/// `#[repr(u8)]`. This guarantees at compile time that every variant fits in
/// the single header byte, and makes `as u8` casts lossless.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum MessageType {
    /// Store a value under a key (sent by `WeTrustClient::put`).
    Put = 1,
    /// Fetch the value stored under a key (sent by `WeTrustClient::get`).
    Get = 2,
    /// Remove a key (sent by `WeTrustClient::delete`).
    Delete = 3,
    /// Query request (not sent by the client in this file).
    Query = 4,
    /// `Rbq` request (semantics not visible in this file).
    Rbq = 5,
    /// Generic response; also the reply expected during the handshake.
    Response = 6,
    /// Authentication handshake request.
    Auth = 7,
    /// Frame carrying an encoded `DsValue` (produced by `DsValueCodec`).
    Value = 8,
    /// Push message (semantics not visible in this file).
    Push = 9,
    /// Pull message (semantics not visible in this file).
    Pull = 10,
    /// Heartbeat request (sent by `WeTrustClient::heartbeat`).
    Heartbeat = 11,
    /// KQL query request.
    Kql = 12,

    /// Put response (presumably answers `Put`).
    PutResp = 101,
    /// Get response (presumably answers `Get`).
    GetResp = 102,
    /// Delete response (presumably answers `Delete`).
    DeleteResp = 103,
    /// Query response (presumably answers `Query`).
    QueryResp = 104,

    /// Error response; also the fallback for unrecognised bytes in `From<u8>`.
    Error = 255,
}

impl From<u8> for MessageType {
    /// Maps a wire byte to its [`MessageType`].
    ///
    /// Any byte that does not match a known variant maps to
    /// [`MessageType::Error`]. As a result, callers cannot distinguish an
    /// explicit error frame (`255`) from an unknown type byte.
    fn from(v: u8) -> Self {
        match v {
            1 => MessageType::Put,
            2 => MessageType::Get,
            3 => MessageType::Delete,
            4 => MessageType::Query,
            5 => MessageType::Rbq,
            6 => MessageType::Response,
            7 => MessageType::Auth,
            8 => MessageType::Value,
            9 => MessageType::Push,
            10 => MessageType::Pull,
            11 => MessageType::Heartbeat,
            12 => MessageType::Kql,
            101 => MessageType::PutResp,
            102 => MessageType::GetResp,
            103 => MessageType::DeleteResp,
            104 => MessageType::QueryResp,
            _ => MessageType::Error,
        }
    }
}
166
167#[derive(Debug, Clone, PartialEq)]
173pub struct TrustHeader {
174 pub version: u8,
175 pub msg_type: u8,
176 pub flags: u32,
177 pub length: u32,
178 pub checksum: u32,
179 pub request_id: Uuid,
180 pub tenant_id: Uuid,
181 pub signature: [u8; 32],
182}
183
184impl TrustHeader {
185 pub const SIZE: usize = 80;
186
187 pub fn sdr_level(&self) -> Redundancy {
188 Redundancy::from_u8((self.flags & 0xFF) as u8)
189 }
190
191 pub fn set_sdr_level(&mut self, level: Redundancy) {
192 self.flags = (self.flags & !0xFF) | (level.0 as u32 & 0xFF);
193 }
194
195 pub fn sign(&mut self, secret: &[u8]) {
196 let mut hasher = Sha256::new();
197 hasher.update(secret);
198 hasher.update(self.request_id.as_bytes());
199 hasher.update(self.tenant_id.as_bytes());
200 hasher.update(self.checksum.to_be_bytes());
201 hasher.update(self.flags.to_be_bytes());
202 let hash = hasher.finalize();
203 self.signature.copy_from_slice(&hash);
204 }
205
206 pub fn verify(&self, secret: &[u8]) -> bool {
207 let mut hasher = Sha256::new();
208 hasher.update(secret);
209 hasher.update(self.request_id.as_bytes());
210 hasher.update(self.tenant_id.as_bytes());
211 hasher.update(self.checksum.to_be_bytes());
212 hasher.update(self.flags.to_be_bytes());
213 let hash = hasher.finalize();
214 self.signature == hash.as_slice()
215 }
216
217 pub fn encode<B: BufMut>(&self, mut dst: B) {
218 dst.put_slice(&MAGIC);
219 dst.put_u8(self.version);
220 dst.put_u8(self.msg_type);
221 dst.put_u32(self.flags);
222 dst.put_u32(self.length);
223 dst.put_u32(self.checksum);
224 dst.put_slice(self.request_id.as_bytes());
225 dst.put_slice(self.tenant_id.as_bytes());
226 dst.put_slice(&self.signature);
227 }
228
229 pub fn decode(src: &mut BytesMut) -> Result<Self, DsError> {
230 if src.len() < Self::SIZE {
231 return Err(DsError::protocol("Insufficient data for header"));
232 }
233
234 let magic = [src[0], src[1]];
235 if magic != MAGIC {
236 return Err(DsError::protocol(format!("Invalid magic: {:?}", magic)));
237 }
238
239 let version = src[2];
240 let msg_type = src[3];
241 let flags = u32::from_be_bytes([src[4], src[5], src[6], src[7]]);
242 let length = u32::from_be_bytes([src[8], src[9], src[10], src[11]]);
243 let checksum = u32::from_be_bytes([src[12], src[13], src[14], src[15]]);
244
245 let request_id = Uuid::from_slice(&src[16..32])
246 .map_err(|e| DsError::protocol(format!("Invalid request ID: {}", e)))?;
247 let tenant_id = Uuid::from_slice(&src[32..48])
248 .map_err(|e| DsError::protocol(format!("Invalid tenant ID: {}", e)))?;
249
250 let mut signature = [0u8; 32];
251 signature.copy_from_slice(&src[48..80]);
252
253 src.advance(Self::SIZE);
254
255 Ok(Self {
256 version,
257 msg_type,
258 flags,
259 length,
260 checksum,
261 request_id,
262 tenant_id,
263 signature,
264 })
265 }
266}
267
268#[derive(Debug)]
270pub struct TrustMessage {
271 pub header: TrustHeader,
272 pub payload: Bytes,
273}
274
275impl TrustMessage {
276 pub fn new(msg_type: MessageType, tenant_id: Uuid, payload: Bytes) -> Self {
277 let mut hasher = Hasher::new();
278 hasher.update(&payload);
279 let checksum = hasher.finalize();
280
281 TrustMessage {
282 header: TrustHeader {
283 version: 1,
284 msg_type: msg_type as u8,
285 flags: 0,
286 length: payload.len() as u32,
287 checksum,
288 request_id: Uuid::new_v4(),
289 tenant_id,
290 signature: [0u8; 32],
291 },
292 payload,
293 }
294 }
295
296 pub fn encode<B: BufMut>(&self, mut dst: B) {
297 self.header.encode(&mut dst);
298 dst.put(self.payload.clone());
299 }
300}
301
/// `tokio_util` codec for [`TrustMessage`] frames.
///
/// Each frame is a [`TrustHeader`] followed by `header.length` payload bytes.
#[derive(Debug, Clone, Copy, Default)]
pub struct TrustCodec;
303
304impl tokio_util::codec::Decoder for TrustCodec {
305 type Item = TrustMessage;
306 type Error = DsError;
307
308 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
309 if src.len() < TrustHeader::SIZE {
310 return Ok(None);
311 }
312
313 let mut length_bytes = [0u8; 4];
315 length_bytes.copy_from_slice(&src[8..12]);
316 let payload_len = u32::from_be_bytes(length_bytes) as usize;
317 let total_length = TrustHeader::SIZE + payload_len;
318
319 if src.len() < total_length {
320 src.reserve(total_length - src.len());
321 return Ok(None);
322 }
323
324 let header = TrustHeader::decode(src)?;
325 let payload = src.split_to(payload_len).freeze();
326
327 let mut hasher = Hasher::new();
329 hasher.update(&payload);
330 if hasher.finalize() != header.checksum {
331 return Err(DsError::protocol("Payload checksum mismatch"));
332 }
333
334 Ok(Some(TrustMessage { header, payload }))
335 }
336}
337
338impl tokio_util::codec::Encoder<TrustMessage> for TrustCodec {
339 type Error = DsError;
340
341 fn encode(&mut self, item: TrustMessage, dst: &mut BytesMut) -> Result<(), Self::Error> {
342 item.encode(dst);
343 Ok(())
344 }
345}
346
/// Helpers that wrap a single `DsValue` in a `Value` frame and unwrap it again.
#[derive(Debug, Clone, Copy, Default)]
pub struct DsValueCodec;
349
350impl DsValueCodec {
351 pub fn encode(
352 value: &DsValue,
353 tenant_id: Uuid,
354 sdr_level: Redundancy,
355 ) -> Result<Bytes, DsError> {
356 let mut result = BytesMut::new();
357 result.resize(TrustHeader::SIZE, 0);
359
360 DsValueEncoder::encode_to_buf(value, &mut result)?;
362
363 let total_len = result.len();
364
365 let mut hasher = Hasher::new();
367 hasher.update(&result[TrustHeader::SIZE..]);
368 let checksum = hasher.finalize();
369
370 let mut header = TrustHeader {
371 version: 1,
372 msg_type: MessageType::Value as u8,
373 flags: 0,
374 length: total_len as u32,
375 checksum,
376 request_id: Uuid::new_v4(),
377 tenant_id,
378 signature: [0u8; 32],
379 };
380 header.set_sdr_level(sdr_level);
381
382 let mut header_part = &mut result[..TrustHeader::SIZE];
384 header.encode(&mut header_part);
385
386 Ok(result.freeze())
387 }
388
389 pub fn decode(mut data: BytesMut) -> Result<(DsValue, TrustHeader), DsError> {
390 let header = TrustHeader::decode(&mut data)?;
391 let mut payload = data.freeze();
392 let value = DsValueDecoder::decode(&mut payload)?;
393 Ok((value, header))
394 }
395}
396
397pub struct WeTrustClient {
399 framed: Framed<TcpStream, TrustCodec>,
400 tenant_id: Uuid,
401 secret_key: Vec<u8>,
402}
403
404impl WeTrustClient {
405 pub async fn connect(
406 addr: SocketAddr,
407 tenant_id: Uuid,
408 secret_key: Vec<u8>,
409 ) -> Result<Self, DsError> {
410 let stream = TcpStream::connect(addr)
411 .await
412 .map_err(|e| DsError::io_raw(e, Some(addr.to_string().into())))?;
413 let mut framed = Framed::new(stream, TrustCodec);
414
415 let mut auth_msg = TrustMessage::new(MessageType::Auth, tenant_id, Bytes::from("auth-v1"));
417 auth_msg.header.sign(&secret_key);
418
419 framed.send(auth_msg).await?;
420
421 if let Some(resp) = framed.next().await {
423 let resp = resp?;
424 if resp.header.msg_type != MessageType::Response as u8 {
425 return Err(DsError::protocol(
426 "Unexpected message type during handshake",
427 ));
428 }
429 if !resp.header.verify(&secret_key) {
430 return Err(DsError::protocol("Handshake signature verification failed"));
431 }
432 } else {
433 return Err(DsError::protocol("Connection closed during handshake"));
434 }
435
436 Ok(Self {
437 framed,
438 tenant_id,
439 secret_key,
440 })
441 }
442
443 pub async fn send_request(
444 &mut self,
445 msg_type: MessageType,
446 payload: Bytes,
447 ) -> Result<TrustMessage, DsError> {
448 let mut msg = TrustMessage::new(msg_type, self.tenant_id, payload);
449 msg.header.sign(&self.secret_key);
450
451 self.framed.send(msg).await?;
452
453 if let Some(resp) = self.framed.next().await {
454 let resp = resp?;
455 if !resp.header.verify(&self.secret_key) {
456 return Err(DsError::protocol("Message signature verification failed"));
457 }
458 Ok(resp)
459 } else {
460 Err(DsError::protocol("Connection closed by server"))
461 }
462 }
463
464 pub async fn send_query(&mut self, sql: &str) -> Result<Vec<Vec<DsValue>>, DsError> {
465 let _resp = self
466 .send_request(MessageType::Kql, Bytes::copy_from_slice(sql.as_bytes()))
467 .await?;
468
469 Ok(vec![vec![DsValue::Text(format!("Executed: {}", sql))]])
472 }
473
474 pub async fn put(&mut self, key: &str, value: DsValue) -> Result<(), DsError> {
475 let value_data = DsValueEncoder::encode(&value)?;
476 let mut payload = BytesMut::with_capacity(4 + key.len() + value_data.len());
477 payload.put_u32(key.len() as u32);
478 payload.put_slice(key.as_bytes());
479 payload.put(value_data);
480
481 self.send_request(MessageType::Put, payload.freeze())
482 .await?;
483 Ok(())
484 }
485
486 pub async fn get(&mut self, key: &str) -> Result<Option<DsValue>, DsError> {
487 let mut payload = BytesMut::with_capacity(4 + key.len());
488 payload.put_u32(key.len() as u32);
489 payload.put_slice(key.as_bytes());
490
491 let resp = self
492 .send_request(MessageType::Get, payload.freeze())
493 .await?;
494 if resp.header.msg_type == MessageType::Error as u8 {
495 return Ok(None);
496 }
497
498 let mut data = resp.payload;
499 if data.is_empty() {
500 return Ok(None);
501 }
502
503 let val = DsValueDecoder::decode(&mut data)?;
505 Ok(Some(val))
506 }
507
508 pub async fn delete(&mut self, key: &str) -> Result<(), DsError> {
509 let mut payload = BytesMut::with_capacity(4 + key.len());
510 payload.put_u32(key.len() as u32);
511 payload.put_slice(key.as_bytes());
512 self.send_request(MessageType::Delete, payload.freeze())
513 .await?;
514 Ok(())
515 }
516
517 pub async fn kql(&mut self, query: &str) -> Result<DsValue, DsError> {
518 let resp = self
519 .send_request(MessageType::Kql, Bytes::copy_from_slice(query.as_bytes()))
520 .await?;
521 if resp.header.msg_type == MessageType::Error as u8 {
522 return Err(DsError::query_with_sql(
523 query,
524 String::from_utf8_lossy(&resp.payload).to_string(),
525 ));
526 }
527 let mut data = resp.payload;
528 let value = DsValueDecoder::decode(&mut data)?;
529 Ok(value)
530 }
531
532 pub async fn heartbeat(&mut self) -> Result<(), DsError> {
533 self.send_request(MessageType::Heartbeat, Bytes::new())
534 .await?;
535 Ok(())
536 }
537}