1pub mod dsl;
12pub mod query_generator;
13
14use std::marker::PhantomData;
15
16pub use dsl::*;
20
21pub use dsl::prelude;
24
25use reqwest::{Client as ReqwestClient, StatusCode};
26use serde::{Deserialize, Serialize};
27use thiserror::Error;
28
/// HTTP client for talking to a HelixDB server.
///
/// Construct one with [`Client::new`], optionally attach an API key with
/// [`Client::with_api_key`], then start a request with [`Client::query`].
#[derive(Debug, Clone)]
pub struct Client {
    /// Underlying `reqwest` client used to issue the POST requests.
    client: ReqwestClient,
    /// Server URL with the `/v1/query` path already applied by [`Client::new`].
    url: reqwest::Url,
    /// Optional API key; when present it is sent as a bearer token.
    api_key: Option<String>,
}
38
/// Alias for [`Client`]; both names refer to the same type.
pub type HelixDBClient = Client;
41
/// Errors produced while building or sending a request.
#[derive(Debug, Error)]
pub enum HelixError {
    /// The HTTP layer failed (converted automatically from `reqwest::Error`).
    #[error("Error communicating with server: {0}")]
    ReqwestError(#[from] reqwest::Error),
    /// The server answered with a status other than `200 OK`; `details` carries
    /// the text describing the failure.
    #[error("Got Error from server: {details}")]
    RemoteError { details: String },
    /// JSON encoding or decoding failed (converted automatically from
    /// `sonic_rs::Error`).
    #[error("Error serializing data: {0}")]
    SerializationError(#[from] sonic_rs::Error),
    /// A URL could not be parsed or joined; the payload is the parser's message.
    #[error("Invalid URL: {0}")]
    InvalidURL(String),
}
53
54impl Client {
55 pub fn new(url: Option<&str>) -> Result<Self, HelixError> {
56 let url = reqwest::Url::parse(url.unwrap_or("http://localhost:6969"))
59 .map_err(|e| HelixError::InvalidURL(e.to_string()))?
60 .join("/v1/query")
61 .map_err(|e| HelixError::InvalidURL(e.to_string()))?;
62 Ok(Self {
63 client: ReqwestClient::new(),
64 url,
65 api_key: None,
66 })
67 }
68
69 pub fn with_api_key(mut self, api_key: Option<&str>) -> Self {
70 self.api_key = api_key.map(|key| key.to_string());
71 self
72 }
73
74 pub fn query<R: for<'de> Deserialize<'de>>(&self) -> QueryBuilder<'_, '_, R> {
75 QueryBuilder::new(self)
76 }
77}
78
/// Builder for a single query issued through a [`HelixDBClient`].
///
/// `R` is the type the response body is deserialized into once the request is
/// sent.
pub struct QueryBuilder<'hlx, 'a, R> {
    /// Client the request will be sent through.
    client: &'hlx HelixDBClient,
    /// Which kind of query to send; starts out as `QueryType::Empty`.
    query_type: QueryType,
    /// Fixed header slots: 0 = `Content-Type`, 1 = `x-helix-require-writer`,
    /// 2 = `x-helix-warm`, 3 = `x-helix-await-durable`. Unset slots are `None`.
    headers: [Option<(&'a str, &'a str)>; 4],
    /// Pre-serialized request body; used as the payload of stored queries.
    body: Option<Vec<u8>>,
    /// Carries the response type `R` without storing a value of it.
    _phantom: PhantomData<R>,
}
86
/// The kind of query a [`QueryBuilder`] has been configured to send.
#[derive(Default)]
pub(crate) enum QueryType {
    /// A query addressed by name; sent to `/v1/query/{name}`.
    Stored(String),
    /// A dynamic query; the request itself is serialized and sent to `/v1/query`.
    Dynamic(DynamicQueryRequest),
    /// No query selected yet; the state of a freshly created builder.
    #[default]
    Empty,
}
94
95impl<'hlx, 'a, R> QueryBuilder<'hlx, 'a, R> {
96 pub fn new(client: &'hlx HelixDBClient) -> Self {
97 let mut headers = [None; 4];
98 headers[0] = Some(("Content-Type", "application/json"));
99 Self {
100 client,
101 query_type: QueryType::default(),
102 headers,
103 body: None,
104 _phantom: PhantomData,
105 }
106 }
107
108 pub fn writer_only(mut self) -> Self {
109 self.headers[1] = Some(("x-helix-require-writer", "true"));
110 self
111 }
112
113 #[must_use]
114 pub fn warm_only(mut self) -> Self {
115 self.headers[2] = Some(("x-helix-warm", "true"));
116 self
117 }
118
119 pub fn should_await_durability(mut self, should: bool) -> Self {
120 self.headers[3] = Some((
121 "x-helix-await-durable",
122 if should { "true" } else { "false" },
123 ));
124 self
125 }
126
127 pub fn body<T: Serialize + Sync>(mut self, data: &T) -> Result<Self, HelixError> {
128 self.body = Some(sonic_rs::to_vec(data)?);
129 Ok(self)
130 }
131
132 pub fn stored(mut self, query_name: String) -> QueryRequest<'hlx, 'a, R> {
133 self.query_type = QueryType::Stored(query_name);
134 QueryRequest { request: self }
135 }
136
137 pub fn dynamic(mut self, query: DynamicQueryRequest) -> QueryRequest<'hlx, 'a, R> {
138 self.query_type = QueryType::Dynamic(query);
139 QueryRequest { request: self }
140 }
141}
142
/// A fully configured query, ready to be sent.
///
/// Produced by [`QueryBuilder::stored`] or [`QueryBuilder::dynamic`], so the
/// wrapped builder always has a query type selected.
pub struct QueryRequest<'hlx, 'a, R> {
    /// The builder holding the client, headers, body and query type.
    request: QueryBuilder<'hlx, 'a, R>,
}
146
147impl<'hlx, 'a, R: for<'de> Deserialize<'de>> QueryRequest<'hlx, 'a, R> {
148 pub async fn send(self) -> Result<R, HelixError> {
149 let query_request = self.request;
150 let (url, body) = match query_request.query_type {
151 QueryType::Dynamic(query) => ("/v1/query".to_string(), Some(sonic_rs::to_vec(&query)?)),
152 QueryType::Stored(name) => (format!("/v1/query/{name}"), query_request.body),
153 QueryType::Empty => {
154 unreachable!("send() is only reachable after stored() or dynamic() sets query_type")
155 }
156 };
157 let url = query_request
158 .client
159 .url
160 .join(&url)
161 .map_err(|e| HelixError::InvalidURL(e.to_string()))?;
162
163 let mut request = query_request.client.client.post(url);
164
165 for (k, v) in query_request.headers.into_iter().flatten() {
166 request = request.header(k, v);
167 }
168 if let Some(ref api_key) = query_request.client.api_key {
169 request = request.bearer_auth(api_key);
170 }
171 if let Some(body) = body {
172 request = request.body(body);
173 }
174
175 let response = request.send().await?;
176
177 match response.status() {
178 StatusCode::OK => {
179 let bytes = response.bytes().await?;
180 sonic_rs::from_slice::<R>(&bytes).map_err(Into::into)
181 }
182 code => match response.text().await {
183 Ok(t) => Err(HelixError::RemoteError { details: t }),
184 Err(_) => match code.canonical_reason() {
185 Some(r) => Err(HelixError::RemoteError {
186 details: r.to_string(),
187 }),
188 None => Err(HelixError::RemoteError {
189 details: format!("unkown error with code: {code}"),
190 }),
191 },
192 },
193 }
194 }
195}
196
// Lets this crate refer to itself as `helix_db`, as the test module below does
// with `helix_db::dsl::prelude`.
// NOTE(review): presumably also needed so paths emitted by the `#[register]`
// macro resolve inside this crate — confirm against the macro's output.
extern crate self as helix_db;
198
/// Tests for the query DSL: `#[register]`-generated request builders and the
/// JSON wire format of predicates.
#[cfg(test)]
mod tests {
    use helix_db::dsl::prelude::*;
    use std::collections::BTreeMap;

    // Registered query with a single string parameter; calling it yields the
    // dynamic request inspected by `query1_builds_dynamic_request`.
    #[register]
    fn query1(name: String) {
        read_batch()
            .var_as("user", g().n_where(SourcePredicate::eq("username", name)))
            .var_as(
                "friends",
                g().n(NodeRef::var("user"))
                    .out(Some("FOLLOWS"))
                    .dedup()
                    .limit(100),
            )
            .returning(["user", "friends"])
    }

    #[test]
    fn query1_builds_dynamic_request() {
        let query = query1(String::from("alice"));

        assert!(matches!(query.request_type, DynamicQueryRequestType::Read));
        let params = query.parameters.expect("parameters present");
        assert!(matches!(
            params.get("name"),
            Some(DynamicQueryValue::String(s)) if s == "alice"
        ));
    }

    // One registered query per supported parameter type; each uses its
    // parameter in a single equality predicate.
    #[register]
    fn q_bool(flag: bool) {
        read_batch()
            .var_as("v", g().n_where(SourcePredicate::eq("field", flag)))
            .returning(["v"])
    }
    #[register]
    fn q_i64(num: i64) {
        read_batch()
            .var_as("v", g().n_where(SourcePredicate::eq("field", num)))
            .returning(["v"])
    }
    #[register]
    fn q_f64(x: f64) {
        read_batch()
            .var_as("v", g().n_where(SourcePredicate::eq("field", x)))
            .returning(["v"])
    }
    #[register]
    fn q_f32(x: f32) {
        read_batch()
            .var_as("v", g().n_where(SourcePredicate::eq("field", x)))
            .returning(["v"])
    }
    #[register]
    fn q_datetime(ts: DateTime) {
        read_batch()
            .var_as("v", g().n_where(SourcePredicate::eq("field", ts)))
            .returning(["v"])
    }
    #[register]
    fn q_value(val: ParamValue) {
        read_batch()
            .var_as("v", g().n_where(SourcePredicate::eq("field", val)))
            .returning(["v"])
    }
    #[register]
    fn q_object(obj: ParamObject) {
        read_batch()
            .var_as("v", g().n_where(SourcePredicate::eq("field", obj)))
            .returning(["v"])
    }
    #[register]
    fn q_array(items: Vec<String>) {
        read_batch()
            .var_as("v", g().n_where(SourcePredicate::eq("field", items)))
            .returning(["v"])
    }
    #[register]
    fn q_map(map: BTreeMap<String, String>) {
        read_batch()
            .var_as("v", g().n_where(SourcePredicate::eq("field", map)))
            .returning(["v"])
    }
    // Byte parameters are expected to fail coercion; see
    // `bytes_param_panics_on_dynamic_call`.
    // NOTE(review): the `allow` presumably silences a warning in the
    // macro-expanded code — confirm it is still required.
    #[register]
    #[allow(unused_variables)]
    fn q_bytes(blob: Vec<u8>) {
        read_batch()
            .var_as("v", g().n_where(SourcePredicate::eq("field", blob)))
            .returning(["v"])
    }

    #[test]
    fn param_types_coerce_correctly() {
        // bool
        let r = q_bool(true);
        assert!(matches!(r.request_type, DynamicQueryRequestType::Read));
        assert!(matches!(
            r.parameters.as_ref().unwrap().get("flag"),
            Some(DynamicQueryValue::Bool(true))
        ));
        assert!(matches!(
            r.parameter_types.as_ref().unwrap().get("flag"),
            Some(QueryParamType::Bool)
        ));

        // i64
        let r = q_i64(7);
        assert!(matches!(
            r.parameters.as_ref().unwrap().get("num"),
            Some(DynamicQueryValue::I64(7))
        ));
        assert!(matches!(
            r.parameter_types.as_ref().unwrap().get("num"),
            Some(QueryParamType::I64)
        ));

        // f64
        let r = q_f64(1.5);
        assert!(matches!(
            r.parameters.as_ref().unwrap().get("x"),
            Some(DynamicQueryValue::F64(v)) if *v == 1.5
        ));
        assert!(matches!(
            r.parameter_types.as_ref().unwrap().get("x"),
            Some(QueryParamType::F64)
        ));

        // f32
        let r = q_f32(1.5f32);
        assert!(matches!(
            r.parameters.as_ref().unwrap().get("x"),
            Some(DynamicQueryValue::F32(v)) if *v == 1.5f32
        ));
        assert!(matches!(
            r.parameter_types.as_ref().unwrap().get("x"),
            Some(QueryParamType::F32)
        ));

        // DateTime: the value travels as an RFC 3339 string, the type stays DateTime.
        let r = q_datetime(DateTime::from_millis(0));
        let expected = DateTime::from_millis(0).to_rfc3339().unwrap();
        assert!(matches!(
            r.parameters.as_ref().unwrap().get("ts"),
            Some(DynamicQueryValue::String(s)) if *s == expected
        ));
        assert!(matches!(
            r.parameter_types.as_ref().unwrap().get("ts"),
            Some(QueryParamType::DateTime)
        ));

        // Untyped value: the payload keeps its concrete variant, the type is Value.
        let r = q_value(PropertyValue::I64(5));
        assert!(matches!(
            r.parameters.as_ref().unwrap().get("val"),
            Some(DynamicQueryValue::I64(5))
        ));
        assert!(matches!(
            r.parameter_types.as_ref().unwrap().get("val"),
            Some(QueryParamType::Value)
        ));

        // Object
        let mut obj = BTreeMap::new();
        obj.insert("k".to_string(), PropertyValue::String("x".to_string()));
        let r = q_object(obj);
        assert!(matches!(
            r.parameters.as_ref().unwrap().get("obj"),
            Some(DynamicQueryValue::Object(_))
        ));
        assert!(matches!(
            r.parameter_types.as_ref().unwrap().get("obj"),
            Some(QueryParamType::Object)
        ));

        // Array of strings
        let r = q_array(vec!["a".to_string(), "b".to_string()]);
        match r.parameters.as_ref().unwrap().get("items") {
            Some(DynamicQueryValue::Array(items)) => {
                assert_eq!(items.len(), 2);
                assert!(matches!(&items[0], DynamicQueryValue::String(s) if s == "a"));
                assert!(matches!(&items[1], DynamicQueryValue::String(s) if s == "b"));
            }
            other => panic!("expected array, got {other:?}"),
        }
        assert!(matches!(
            r.parameter_types.as_ref().unwrap().get("items"),
            Some(QueryParamType::Array(inner)) if matches!(**inner, QueryParamType::String)
        ));

        // String map: coerced to an object.
        let mut map = BTreeMap::new();
        map.insert("k".to_string(), "v".to_string());
        let r = q_map(map);
        assert!(matches!(
            r.parameters.as_ref().unwrap().get("map"),
            Some(DynamicQueryValue::Object(_))
        ));
        assert!(matches!(
            r.parameter_types.as_ref().unwrap().get("map"),
            Some(QueryParamType::Object)
        ));
    }

    #[test]
    #[should_panic(expected = "failed to coerce parameter")]
    fn bytes_param_panics_on_dynamic_call() {
        let _ = q_bytes(vec![1, 2, 3]);
    }

    // Literal operands must keep the plain `Eq`/`Gt`/`Between` wire shape.
    #[test]
    fn predicate_literal_json_is_unchanged() {
        assert_eq!(
            sonic_rs::to_string(&Predicate::eq("username", "alice")).unwrap(),
            r#"{"Eq":["username",{"String":"alice"}]}"#
        );
        assert_eq!(
            sonic_rs::to_string(&Predicate::gt("score", 10i64)).unwrap(),
            r#"{"Gt":["score",{"I64":10}]}"#
        );
        assert_eq!(
            sonic_rs::to_string(&Predicate::between("age", 18i64, 65i64)).unwrap(),
            r#"{"Between":["age",{"I64":18},{"I64":65}]}"#
        );
    }

    // Any `Expr` operand switches the predicate to its `*Expr` variant.
    #[test]
    fn predicate_param_json_uses_expr_variants() {
        assert_eq!(
            sonic_rs::to_string(&Predicate::eq("username", Expr::param("name"))).unwrap(),
            r#"{"EqExpr":["username",{"Param":"name"}]}"#
        );
        assert_eq!(
            sonic_rs::to_string(&Predicate::lte("score", Expr::param("max"))).unwrap(),
            r#"{"LteExpr":["score",{"Param":"max"}]}"#
        );
        assert_eq!(
            sonic_rs::to_string(&Predicate::between("age", Expr::param("lo"), 65i64)).unwrap(),
            r#"{"BetweenExpr":["age",{"Param":"lo"},{"Constant":{"I64":65}}]}"#
        );
    }

    #[test]
    fn predicate_json_round_trips() {
        for predicate in [
            Predicate::eq("username", "alice"),
            Predicate::eq("username", Expr::param("name")),
            Predicate::between("age", Expr::param("lo"), 65i64),
        ] {
            let json = sonic_rs::to_string(&predicate).unwrap();
            let back: Predicate = sonic_rs::from_str(&json).unwrap();
            assert_eq!(predicate, back);
        }
    }

    // The same three checks, repeated for `SourcePredicate`.
    #[test]
    fn source_predicate_literal_json_is_unchanged() {
        assert_eq!(
            sonic_rs::to_string(&SourcePredicate::eq("username", "alice")).unwrap(),
            r#"{"Eq":["username",{"String":"alice"}]}"#
        );
        assert_eq!(
            sonic_rs::to_string(&SourcePredicate::gt("score", 10i64)).unwrap(),
            r#"{"Gt":["score",{"I64":10}]}"#
        );
        assert_eq!(
            sonic_rs::to_string(&SourcePredicate::between("age", 18i64, 65i64)).unwrap(),
            r#"{"Between":["age",{"I64":18},{"I64":65}]}"#
        );
    }

    #[test]
    fn source_predicate_param_json_uses_expr_variants() {
        assert_eq!(
            sonic_rs::to_string(&SourcePredicate::eq("username", Expr::param("name"))).unwrap(),
            r#"{"EqExpr":["username",{"Param":"name"}]}"#
        );
        assert_eq!(
            sonic_rs::to_string(&SourcePredicate::lte("score", Expr::param("max"))).unwrap(),
            r#"{"LteExpr":["score",{"Param":"max"}]}"#
        );
        assert_eq!(
            sonic_rs::to_string(&SourcePredicate::between("age", Expr::param("lo"), 65i64))
                .unwrap(),
            r#"{"BetweenExpr":["age",{"Param":"lo"},{"Constant":{"I64":65}}]}"#
        );
    }

    #[test]
    fn source_predicate_json_round_trips() {
        for sp in [
            SourcePredicate::eq("username", "alice"),
            SourcePredicate::eq("username", Expr::param("name")),
            SourcePredicate::between("age", Expr::param("lo"), 65i64),
        ] {
            let json = sonic_rs::to_string(&sp).unwrap();
            let back: SourcePredicate = sonic_rs::from_str(&json).unwrap();
            assert_eq!(sp, back);
        }
    }

    // End-to-end shape of a whole query: literal and parameterised predicates
    // must serialize to different `NWhere` payloads.
    #[test]
    fn query_ast_literal_vs_param_json() {
        let literal = read_batch()
            .var_as(
                "user",
                g().n_where(SourcePredicate::eq("username", "alice")),
            )
            .returning(["user"]);
        let literal_json = sonic_rs::to_string(&literal).unwrap();
        assert!(
            literal_json.contains(r#"{"NWhere":{"Eq":["username",{"String":"alice"}]}}"#),
            "literal NWhere step changed shape: {literal_json}"
        );
        assert!(!literal_json.contains("EqExpr"));

        let param = read_batch()
            .var_as(
                "user",
                g().n_where(SourcePredicate::eq("username", Expr::param("name"))),
            )
            .returning(["user"]);
        let param_json = sonic_rs::to_string(&param).unwrap();
        assert!(
            param_json.contains(r#"{"NWhere":{"EqExpr":["username",{"Param":"name"}]}}"#),
            "param NWhere step missing EqExpr/Param: {param_json}"
        );
    }
}
541
/// Tests for `Client`, `QueryBuilder` and the request routing done by `send`.
#[cfg(test)]
mod client_tests {
    use super::*;
    use serde::Deserialize;

    /// Response placeholder for tests that only inspect the builder.
    #[derive(Deserialize)]
    struct Resp;

    /// Client pointed at the default endpoint.
    fn default_client() -> Client {
        Client::new(None).unwrap()
    }

    /// Minimal read request used wherever a dynamic query is needed.
    fn sample_request() -> DynamicQueryRequest {
        let by_username = SourcePredicate::eq("username", "alice");
        DynamicQueryRequest::read(
            read_batch()
                .var_as("user", g().n_where(by_username))
                .returning(["user"]),
        )
    }

    // --- Client construction -------------------------------------------------

    #[test]
    fn new_defaults_to_localhost() {
        let Client { url, api_key, .. } = default_client();
        assert_eq!(url.as_str(), "http://localhost:6969/v1/query");
        assert_eq!(api_key, None);
    }

    #[test]
    fn new_parses_custom_url() {
        let url = Client::new(Some("https://cluster.helix-db.com"))
            .unwrap()
            .url;
        assert_eq!(url.as_str(), "https://cluster.helix-db.com/v1/query");
    }

    #[test]
    fn new_rejects_invalid_url() {
        match Client::new(Some("not a url")) {
            Err(HelixError::InvalidURL(_)) => {}
            Err(other) => panic!("expected InvalidURL, got {other:?}"),
            Ok(client) => panic!("expected InvalidURL, got {client:?}"),
        }
    }

    #[test]
    fn with_api_key_sets_and_clears() {
        let keyed = default_client().with_api_key(Some("hx_secret"));
        assert_eq!(keyed.api_key, Some("hx_secret".to_string()));

        let unkeyed = keyed.with_api_key(None);
        assert_eq!(unkeyed.api_key, None);
    }

    // --- Builder headers -----------------------------------------------------

    #[test]
    fn query_builder_starts_with_only_content_type() {
        let client = default_client();
        let builder = client.query::<Resp>();
        assert_eq!(
            builder.headers[0],
            Some(("Content-Type", "application/json"))
        );
        for slot in &builder.headers[1..] {
            assert!(slot.is_none());
        }
    }

    #[test]
    fn header_toggles_populate_slots() {
        let client = default_client();
        let builder = client
            .query::<Resp>()
            .writer_only()
            .warm_only()
            .should_await_durability(true);
        let [_, writer, warm, durable] = builder.headers;
        assert_eq!(writer, Some(("x-helix-require-writer", "true")));
        assert_eq!(warm, Some(("x-helix-warm", "true")));
        assert_eq!(durable, Some(("x-helix-await-durable", "true")));
    }

    #[test]
    fn should_await_durability_false_sends_false() {
        let client = default_client();
        let durable = client
            .query::<Resp>()
            .should_await_durability(false)
            .headers[3];
        assert_eq!(durable, Some(("x-helix-await-durable", "false")));
    }

    // --- Query type selection ------------------------------------------------

    #[test]
    fn dynamic_query_sets_query_type() {
        let client = default_client();
        let request = client.query::<Resp>().dynamic(sample_request());
        let is_dynamic = matches!(request.request.query_type, QueryType::Dynamic(_));
        assert!(is_dynamic);
    }

    #[test]
    fn stored_query_sets_query_type() {
        let client = default_client();
        let request = client.query::<Resp>().stored("add_user".to_string());
        match &request.request.query_type {
            QueryType::Stored(name) => assert_eq!(name.as_str(), "add_user"),
            _ => panic!("expected a stored query"),
        }
    }

    #[derive(serde::Serialize)]
    struct Payload {
        name: String,
    }

    #[test]
    fn body_serializes_payload() {
        let client = default_client();
        let payload = Payload {
            name: "alice".to_string(),
        };
        let expected = sonic_rs::to_vec(&payload).unwrap();
        let builder = client.query::<Resp>().body(&payload).unwrap();
        assert_eq!(builder.body.as_deref(), Some(expected.as_slice()));
    }

    // --- Request routing -----------------------------------------------------

    /// Response type matching the `{}` body served by the capture server.
    #[derive(serde::Deserialize)]
    struct EmptyResp {}

    /// Starts a one-shot HTTP server on a free local port.
    ///
    /// Returns the server's base URL and a handle that resolves to the request
    /// target (second token of the request line) of the first request received.
    /// The server always answers `200 OK` with a `{}` body.
    async fn spawn_capture_server() -> (String, tokio::task::JoinHandle<String>) {
        use tokio::io::{AsyncReadExt, AsyncWriteExt};

        const CANNED_RESPONSE: &[u8] =
            b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\nConnection: close\r\n\r\n{}";

        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let server = tokio::spawn(async move {
            let (mut conn, _peer) = listener.accept().await.unwrap();

            // A single read is enough: only the request line is inspected.
            let mut scratch = [0u8; 4096];
            let read_len = conn.read(&mut scratch).await.unwrap();

            let head = String::from_utf8_lossy(&scratch[..read_len]);
            let request_line = head.lines().next().unwrap();
            let path = request_line.split_whitespace().nth(1).unwrap().to_string();

            conn.write_all(CANNED_RESPONSE).await.unwrap();
            path
        });

        (format!("http://{addr}"), server)
    }

    #[tokio::test]
    async fn dynamic_query_posts_to_v1_query() {
        let (base, captured) = spawn_capture_server().await;
        let client = Client::new(Some(&base)).unwrap();
        let _response = client
            .query::<EmptyResp>()
            .dynamic(sample_request())
            .send()
            .await
            .unwrap();
        let path = captured.await.unwrap();
        assert_eq!(path, "/v1/query");
    }

    #[tokio::test]
    async fn stored_query_posts_to_named_route() {
        let (base, captured) = spawn_capture_server().await;
        let client = Client::new(Some(&base)).unwrap();
        let _response = client
            .query::<EmptyResp>()
            .stored("add_user".to_string())
            .send()
            .await
            .unwrap();
        let path = captured.await.unwrap();
        assert_eq!(path, "/v1/query/add_user");
    }
}