Skip to main content

aurora_db/
client.rs

1use crate::error::{AqlError, ErrorCode, Result};
2use crate::network::protocol::{Request, Response};
3use crate::query::SimpleQueryBuilder;
4use crate::types::{Document, FieldType, Value};
5use std::collections::HashMap;
6use tokio::io::{AsyncReadExt, AsyncWriteExt};
7use tokio::net::TcpStream;
8
9pub struct Client {
10    stream: TcpStream,
11    current_transaction: Option<u64>,
12}
13
14impl Client {
15    /// Connect to an Aurora server at the given address.
16    pub async fn connect(addr: &str) -> Result<Self> {
17        let stream = TcpStream::connect(addr).await?;
18        Ok(Self {
19            stream,
20            current_transaction: None,
21        })
22    }
23
24    /// Sends a request to the server and awaits a response.
25    async fn send_request(&mut self, request: Request) -> Result<Response> {
26        let request_bytes = bincode::serialize(&request).map_err(AqlError::from)?;
27        let len_bytes = (request_bytes.len() as u32).to_le_bytes();
28
29        self.stream.write_all(&len_bytes).await?;
30        self.stream.write_all(&request_bytes).await?;
31
32        let mut len_bytes = [0u8; 4];
33        self.stream.read_exact(&mut len_bytes).await?;
34        const MAX_FRAME_SIZE: usize = 8 * 1024 * 1024; // 8 MiB
35        let len = u32::from_le_bytes(len_bytes) as usize;
36        if len > MAX_FRAME_SIZE {
37            return Err(AqlError::new(
38                ErrorCode::ProtocolError,
39                format!("Response too large: {} bytes", len),
40            ));
41        }
42
43        let mut buffer = vec![0u8; len];
44        self.stream.read_exact(&mut buffer).await?;
45
46        let response: Response = bincode::deserialize(&buffer).map_err(AqlError::from)?;
47
48        Ok(response)
49    }
50
51    pub async fn new_collection(
52        &mut self,
53        name: &str,
54        fields: Vec<(String, FieldType, bool)>,
55    ) -> Result<()> {
56        let req = Request::NewCollection {
57            name: name.to_string(),
58            fields,
59        };
60        match self.send_request(req).await? {
61            Response::Done => Ok(()),
62            Response::Error(e) => Err(AqlError::new(ErrorCode::ProtocolError, e)),
63            _ => Err(AqlError::new(
64                ErrorCode::ProtocolError,
65                "Unexpected response".to_string(),
66            )),
67        }
68    }
69
70    pub async fn insert(
71        &mut self,
72        collection: &str,
73        data: HashMap<String, Value>,
74    ) -> Result<String> {
75        let req = Request::Insert {
76            collection: collection.to_string(),
77            data,
78        };
79        match self.send_request(req).await? {
80            Response::Message(id) => Ok(id),
81            Response::Error(e) => Err(AqlError::new(ErrorCode::ProtocolError, e)),
82            _ => Err(AqlError::new(
83                ErrorCode::ProtocolError,
84                "Unexpected response".to_string(),
85            )),
86        }
87    }
88
89    pub async fn get_document(&mut self, collection: &str, id: &str) -> Result<Option<Document>> {
90        let req = Request::GetDocument {
91            collection: collection.to_string(),
92            id: id.to_string(),
93        };
94        match self.send_request(req).await? {
95            Response::Document(doc) => Ok(doc),
96            Response::Error(e) => Err(AqlError::new(ErrorCode::ProtocolError, e)),
97            _ => Err(AqlError::new(
98                ErrorCode::ProtocolError,
99                "Unexpected response".to_string(),
100            )),
101        }
102    }
103
104    pub async fn query(&mut self, builder: SimpleQueryBuilder) -> Result<Vec<Document>> {
105        let req = Request::Query(builder);
106        match self.send_request(req).await? {
107            Response::Documents(docs) => Ok(docs),
108            Response::Error(e) => Err(AqlError::new(ErrorCode::ProtocolError, e)),
109            _ => Err(AqlError::new(
110                ErrorCode::ProtocolError,
111                "Unexpected response".to_string(),
112            )),
113        }
114    }
115
116    pub async fn begin_transaction(&mut self) -> Result<u64> {
117        match self.send_request(Request::BeginTransaction).await? {
118            Response::TransactionId(tx_id) => {
119                self.current_transaction = Some(tx_id);
120                Ok(tx_id)
121            }
122            Response::Error(e) => Err(AqlError::new(ErrorCode::ProtocolError, e)),
123            _ => Err(AqlError::new(
124                ErrorCode::ProtocolError,
125                "Unexpected response".to_string(),
126            )),
127        }
128    }
129
130    pub async fn commit_transaction(&mut self) -> Result<()> {
131        let tx_id = self
132            .current_transaction
133            .ok_or_else(|| AqlError::invalid_operation("No active transaction".to_string()))?;
134
135        match self.send_request(Request::CommitTransaction(tx_id)).await? {
136            Response::Done => {
137                self.current_transaction = None;
138                Ok(())
139            }
140            Response::Error(e) => Err(AqlError::new(ErrorCode::ProtocolError, e)),
141            _ => Err(AqlError::new(
142                ErrorCode::ProtocolError,
143                "Unexpected response".to_string(),
144            )),
145        }
146    }
147
148    pub async fn rollback_transaction(&mut self) -> Result<()> {
149        let tx_id = self
150            .current_transaction
151            .ok_or_else(|| AqlError::invalid_operation("No active transaction".to_string()))?;
152
153        match self
154            .send_request(Request::RollbackTransaction(tx_id))
155            .await?
156        {
157            Response::Done => {
158                self.current_transaction = None;
159                Ok(())
160            }
161            Response::Error(e) => Err(AqlError::new(ErrorCode::ProtocolError, e)),
162            _ => Err(AqlError::new(
163                ErrorCode::ProtocolError,
164                "Unexpected response".to_string(),
165            )),
166        }
167    }
168
169    pub async fn delete(&mut self, key: &str) -> Result<()> {
170        match self.send_request(Request::Delete(key.to_string())).await? {
171            Response::Done => Ok(()),
172            Response::Error(e) => Err(AqlError::new(ErrorCode::ProtocolError, e)),
173            _ => Err(AqlError::new(
174                ErrorCode::ProtocolError,
175                "Unexpected response".to_string(),
176            )),
177        }
178    }
179}