1use anyhow::Result;
2use futures::SinkExt;
3use spatio_rpc::{Command, ResponsePayload, ResponseStatus, RpcClientCodec};
4use spatio_types::config::SetOptions;
5use spatio_types::point::{Point3d, TemporalPoint};
6use spatio_types::stats::DbStats;
7use std::time::{Duration, SystemTime};
8use tokio::net::TcpStream;
9use tokio::sync::Mutex;
10use tokio_stream::StreamExt;
11use tokio_util::codec::Framed;
12
13pub struct SpatioClient {
14 host: String,
15 port: u16,
16 inner: Mutex<Option<Framed<TcpStream, RpcClientCodec>>>,
17 timeout: Duration,
18}
19
20impl SpatioClient {
21 pub fn new(host: String, port: u16) -> Self {
22 Self {
23 host,
24 port,
25 inner: Mutex::new(None),
26 timeout: Duration::from_secs(10),
27 }
28 }
29
30 pub fn with_timeout(mut self, timeout: Duration) -> Self {
31 self.timeout = timeout;
32 self
33 }
34
35 async fn get_connection(&self) -> Result<Framed<TcpStream, RpcClientCodec>> {
36 let mut inner = self.inner.lock().await;
37
38 if let Some(framed) = inner.take() {
39 return Ok(framed);
40 }
41
42 let addr = format!("{}:{}", self.host, self.port);
43 let stream = tokio::time::timeout(self.timeout, TcpStream::connect(&addr)).await??;
44 Ok(Framed::new(stream, RpcClientCodec))
45 }
46
47 async fn call(&self, cmd: Command) -> Result<ResponsePayload> {
48 let mut framed = self.get_connection().await?;
49
50 let res: Result<ResponsePayload> = async {
51 framed.send(cmd).await?;
52 let (status, payload) = framed
53 .next()
54 .await
55 .ok_or_else(|| anyhow::anyhow!("Connection closed"))??;
56
57 match status {
58 ResponseStatus::Ok => Ok(payload),
59 ResponseStatus::Error => {
60 if let ResponsePayload::Error(e) = payload {
61 Err(anyhow::anyhow!(e))
62 } else {
63 Err(anyhow::anyhow!("Unknown error"))
64 }
65 }
66 }
67 }
68 .await;
69
70 if let Ok(ref payload) = res {
71 *self.inner.lock().await = Some(framed);
72 Ok(payload.clone())
73 } else {
74 res
75 }
76 }
77
78 pub async fn upsert(
79 &self,
80 namespace: &str,
81 object_id: &str,
82 point: Point3d,
83 metadata: serde_json::Value,
84 opts: Option<SetOptions>,
85 ) -> Result<()> {
86 let metadata_bytes = serde_json::to_vec(&metadata)?;
87 let cmd = Command::Upsert {
88 namespace: namespace.to_string(),
89 id: object_id.to_string(),
90 point,
91 metadata: metadata_bytes,
92 opts,
93 };
94
95 match self.call(cmd).await? {
96 ResponsePayload::Ok => Ok(()),
97 _ => Err(anyhow::anyhow!("Unexpected response")),
98 }
99 }
100
101 pub async fn get(
102 &self,
103 namespace: &str,
104 object_id: &str,
105 ) -> Result<Option<(Point3d, serde_json::Value)>> {
106 let cmd = Command::Get {
107 namespace: namespace.to_string(),
108 id: object_id.to_string(),
109 };
110
111 match self.call(cmd).await? {
112 ResponsePayload::Object {
113 point, metadata, ..
114 } => {
115 let metadata_json = serde_json::from_slice(&metadata)?;
116 Ok(Some((point, metadata_json)))
117 }
118 ResponsePayload::Error(e) if e == "Not found" => Ok(None),
119 _ => Err(anyhow::anyhow!("Unexpected response")),
120 }
121 }
122
123 pub async fn query_radius(
124 &self,
125 namespace: &str,
126 center: &Point3d,
127 radius: f64,
128 limit: usize,
129 ) -> Result<Vec<(String, Point3d, serde_json::Value, f64)>> {
130 let cmd = Command::QueryRadius {
131 namespace: namespace.to_string(),
132 center: center.clone(),
133 radius,
134 limit,
135 };
136
137 match self.call(cmd).await? {
138 ResponsePayload::Objects(results) => {
139 let mut formatted = Vec::with_capacity(results.len());
140 for (id, point, metadata, dist) in results {
141 formatted.push((id, point, serde_json::from_slice(&metadata)?, dist));
142 }
143 Ok(formatted)
144 }
145 _ => Err(anyhow::anyhow!("Unexpected response")),
146 }
147 }
148
149 pub async fn knn(
150 &self,
151 namespace: &str,
152 center: &Point3d,
153 k: usize,
154 ) -> Result<Vec<(String, Point3d, serde_json::Value, f64)>> {
155 let cmd = Command::Knn {
156 namespace: namespace.to_string(),
157 center: center.clone(),
158 k,
159 };
160
161 match self.call(cmd).await? {
162 ResponsePayload::Objects(results) => {
163 let mut formatted = Vec::with_capacity(results.len());
164 for (id, point, metadata, dist) in results {
165 formatted.push((id, point, serde_json::from_slice(&metadata)?, dist));
166 }
167 Ok(formatted)
168 }
169 _ => Err(anyhow::anyhow!("Unexpected response")),
170 }
171 }
172
173 pub async fn stats(&self) -> Result<DbStats> {
174 let cmd = Command::Stats;
175
176 match self.call(cmd).await? {
177 ResponsePayload::Stats(stats) => Ok(stats),
178 _ => Err(anyhow::anyhow!("Unexpected response")),
179 }
180 }
181
182 pub async fn delete(&self, namespace: &str, object_id: &str) -> Result<()> {
183 let cmd = Command::Delete {
184 namespace: namespace.to_string(),
185 id: object_id.to_string(),
186 };
187
188 match self.call(cmd).await? {
189 ResponsePayload::Ok => Ok(()),
190 _ => Err(anyhow::anyhow!("Unexpected response")),
191 }
192 }
193
194 pub async fn query_bbox(
195 &self,
196 namespace: &str,
197 min_x: f64,
198 min_y: f64,
199 max_x: f64,
200 max_y: f64,
201 limit: usize,
202 ) -> Result<Vec<(String, Point3d, serde_json::Value)>> {
203 let cmd = Command::QueryBbox {
204 namespace: namespace.to_string(),
205 min_x,
206 min_y,
207 max_x,
208 max_y,
209 limit,
210 };
211
212 match self.call(cmd).await? {
213 ResponsePayload::ObjectList(results) => {
214 let mut formatted = Vec::with_capacity(results.len());
215 for (id, point, metadata) in results {
216 formatted.push((id, point, serde_json::from_slice(&metadata)?));
217 }
218 Ok(formatted)
219 }
220 _ => Err(anyhow::anyhow!("Unexpected response")),
221 }
222 }
223
224 pub async fn insert_trajectory(
225 &self,
226 namespace: &str,
227 object_id: &str,
228 trajectory: Vec<TemporalPoint>,
229 ) -> Result<()> {
230 let cmd = Command::InsertTrajectory {
231 namespace: namespace.to_string(),
232 id: object_id.to_string(),
233 trajectory,
234 };
235
236 match self.call(cmd).await? {
237 ResponsePayload::Ok => Ok(()),
238 _ => Err(anyhow::anyhow!("Unexpected response")),
239 }
240 }
241
242 pub async fn query_trajectory(
243 &self,
244 namespace: &str,
245 object_id: &str,
246 start_time: SystemTime,
247 end_time: SystemTime,
248 limit: usize,
249 ) -> Result<Vec<(Point3d, serde_json::Value, SystemTime)>> {
250 let cmd = Command::QueryTrajectory {
251 namespace: namespace.to_string(),
252 id: object_id.to_string(),
253 start_time,
254 end_time,
255 limit,
256 };
257
258 match self.call(cmd).await? {
259 ResponsePayload::Trajectory(results) => {
260 let mut formatted = Vec::with_capacity(results.len());
261 for upd in results {
262 formatted.push((
263 upd.position,
264 serde_json::from_slice(&upd.metadata)?,
265 upd.timestamp,
266 ));
267 }
268 Ok(formatted)
269 }
270 _ => Err(anyhow::anyhow!("Unexpected response")),
271 }
272 }
273}