Skip to main content

aurora_db/
client.rs

1use crate::error::{AqlError, Result, ErrorCode};
2use crate::network::protocol::{Request, Response};
3use crate::query::SimpleQueryBuilder;
4use crate::types::{Document, FieldType, Value};
5use std::collections::HashMap;
6use tokio::io::{AsyncReadExt, AsyncWriteExt};
7use tokio::net::TcpStream;
8
9pub struct Client {
10    stream: TcpStream,
11    current_transaction: Option<u64>,
12}
13
14impl Client {
15    /// Connect to an Aurora server at the given address.
16    pub async fn connect(addr: &str) -> Result<Self> {
17        let stream = TcpStream::connect(addr).await?;
18        Ok(Self {
19            stream,
20            current_transaction: None,
21        })
22    }
23
24    /// Sends a request to the server and awaits a response.
25    async fn send_request(&mut self, request: Request) -> Result<Response> {
26        let request_bytes = bincode::serialize(&request).map_err(AqlError::from)?;
27        let len_bytes = (request_bytes.len() as u32).to_le_bytes();
28
29        self.stream.write_all(&len_bytes).await?;
30        self.stream.write_all(&request_bytes).await?;
31
32        let mut len_bytes = [0u8; 4];
33        self.stream.read_exact(&mut len_bytes).await?;
34        const MAX_FRAME_SIZE: usize = 8 * 1024 * 1024; // 8 MiB
35        let len = u32::from_le_bytes(len_bytes) as usize;
36        if len > MAX_FRAME_SIZE {
37            return Err(AqlError::new(
38                ErrorCode::ProtocolError,
39                format!("Response too large: {} bytes", len),
40            ));
41        }
42
43        let mut buffer = vec![0u8; len];
44        self.stream.read_exact(&mut buffer).await?;
45
46        let response: Response = bincode::deserialize(&buffer).map_err(AqlError::from)?;
47
48        Ok(response)
49    }
50
51    pub async fn new_collection(
52        &mut self,
53        name: &str,
54        fields: Vec<(String, FieldType, bool)>,
55    ) -> Result<()> {
56        let req = Request::NewCollection {
57            name: name.to_string(),
58            fields,
59        };
60        match self.send_request(req).await? {
61            Response::Done => Ok(()),
62            Response::Error(e) => Err(AqlError::new(ErrorCode::ProtocolError, e)),
63            _ => Err(AqlError::new(ErrorCode::ProtocolError, "Unexpected response".to_string())),
64        }
65    }
66
67    pub async fn insert(
68        &mut self,
69        collection: &str,
70        data: HashMap<String, Value>,
71    ) -> Result<String> {
72        let req = Request::Insert {
73            collection: collection.to_string(),
74            data,
75        };
76        match self.send_request(req).await? {
77            Response::Message(id) => Ok(id),
78            Response::Error(e) => Err(AqlError::new(ErrorCode::ProtocolError, e)),
79            _ => Err(AqlError::new(ErrorCode::ProtocolError, "Unexpected response".to_string())),
80        }
81    }
82
83    pub async fn get_document(&mut self, collection: &str, id: &str) -> Result<Option<Document>> {
84        let req = Request::GetDocument {
85            collection: collection.to_string(),
86            id: id.to_string(),
87        };
88        match self.send_request(req).await? {
89            Response::Document(doc) => Ok(doc),
90            Response::Error(e) => Err(AqlError::new(ErrorCode::ProtocolError, e)),
91            _ => Err(AqlError::new(ErrorCode::ProtocolError, "Unexpected response".to_string())),
92        }
93    }
94
95    pub async fn query(&mut self, builder: SimpleQueryBuilder) -> Result<Vec<Document>> {
96        let req = Request::Query(builder);
97        match self.send_request(req).await? {
98            Response::Documents(docs) => Ok(docs),
99            Response::Error(e) => Err(AqlError::new(ErrorCode::ProtocolError, e)),
100            _ => Err(AqlError::new(ErrorCode::ProtocolError, "Unexpected response".to_string())),
101        }
102    }
103
104    pub async fn begin_transaction(&mut self) -> Result<u64> {
105        match self.send_request(Request::BeginTransaction).await? {
106            Response::TransactionId(tx_id) => {
107                self.current_transaction = Some(tx_id);
108                Ok(tx_id)
109            }
110            Response::Error(e) => Err(AqlError::new(ErrorCode::ProtocolError, e)),
111            _ => Err(AqlError::new(ErrorCode::ProtocolError, "Unexpected response".to_string())),
112        }
113    }
114
115    pub async fn commit_transaction(&mut self) -> Result<()> {
116        let tx_id = self
117            .current_transaction
118            .ok_or_else(|| AqlError::invalid_operation("No active transaction".to_string()))?;
119
120        match self.send_request(Request::CommitTransaction(tx_id)).await? {
121            Response::Done => {
122                self.current_transaction = None;
123                Ok(())
124            }
125            Response::Error(e) => Err(AqlError::new(ErrorCode::ProtocolError, e)),
126            _ => Err(AqlError::new(ErrorCode::ProtocolError, "Unexpected response".to_string())),
127        }
128    }
129
130    pub async fn rollback_transaction(&mut self) -> Result<()> {
131        let tx_id = self
132            .current_transaction
133            .ok_or_else(|| AqlError::invalid_operation("No active transaction".to_string()))?;
134
135        match self
136            .send_request(Request::RollbackTransaction(tx_id))
137            .await?
138        {
139            Response::Done => {
140                self.current_transaction = None;
141                Ok(())
142            }
143            Response::Error(e) => Err(AqlError::new(ErrorCode::ProtocolError, e)),
144            _ => Err(AqlError::new(ErrorCode::ProtocolError, "Unexpected response".to_string())),
145        }
146    }
147
148    pub async fn delete(&mut self, key: &str) -> Result<()> {
149        match self.send_request(Request::Delete(key.to_string())).await? {
150            Response::Done => Ok(()),
151            Response::Error(e) => Err(AqlError::new(ErrorCode::ProtocolError, e)),
152            _ => Err(AqlError::new(ErrorCode::ProtocolError, "Unexpected response".to_string())),
153        }
154    }
155}