1use std::collections::HashMap;
6use std::sync::Arc;
7use std::time::Duration;
8use std::future::Future;
9
10use parking_lot::RwLock;
11
12use super::bolt::PackStreamValue;
13use super::driver::DriverConfig;
14use super::error::{DriverError, DriverResult};
15use super::pool::{ConnectionPool, PooledConnection};
16use super::record::{Record, RecordStream};
17use super::transaction::{Transaction, TransactionConfig};
18use super::types::Value;
19
/// Whether work in a session or transaction is meant to read or write data.
///
/// The enum's own `Default` is `Read`; note that `SessionConfig::default()`
/// explicitly selects `Write` instead.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum AccessMode {
    /// Read access (the enum's default variant).
    #[default]
    Read,
    /// Write access.
    Write,
}
33
/// A bookmark value wrapping a raw string, either reported by the server after a
/// query or supplied by the user.
///
/// Equality and hashing are based on the wrapped string.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Bookmark {
    value: String,
}

impl Bookmark {
    /// Creates a bookmark from any string-like value.
    pub fn new(value: impl Into<String>) -> Self {
        Self {
            value: value.into(),
        }
    }

    /// Returns the raw bookmark string.
    pub fn value(&self) -> &str {
        &self.value
    }

    /// Returns `true` when the wrapped string is empty.
    pub fn is_empty(&self) -> bool {
        self.value.is_empty()
    }

    /// Collapses a list of bookmarks into a single one by keeping the last element.
    ///
    /// Returns an empty bookmark when `bookmarks` is empty.
    pub fn from_bookmarks(bookmarks: &[Bookmark]) -> Self {
        // `last()` is `None` exactly when the slice is empty, so no separate
        // emptiness check is needed.
        bookmarks.last().cloned().unwrap_or_else(|| Self::new(""))
    }
}

impl std::fmt::Display for Bookmark {
    /// Writes the raw bookmark string (formatter width/fill flags are ignored).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.value)
    }
}

impl From<String> for Bookmark {
    fn from(s: String) -> Self {
        Self::new(s)
    }
}

impl From<&str> for Bookmark {
    fn from(s: &str) -> Self {
        Self::new(s)
    }
}
91
92#[derive(Debug, Clone)]
98pub struct SessionConfig {
99 pub database: Option<String>,
101 pub fetch_size: usize,
103 pub default_access_mode: AccessMode,
105 pub bookmarks: Vec<Bookmark>,
107 pub impersonated_user: Option<String>,
109}
110
111impl SessionConfig {
112 pub fn new() -> Self {
114 Self::default()
115 }
116
117 pub fn builder() -> SessionConfigBuilder {
119 SessionConfigBuilder::new()
120 }
121
122 pub fn with_database(mut self, database: impl Into<String>) -> Self {
124 self.database = Some(database.into());
125 self
126 }
127
128 pub fn with_fetch_size(mut self, size: usize) -> Self {
130 self.fetch_size = size;
131 self
132 }
133
134 pub fn with_access_mode(mut self, mode: AccessMode) -> Self {
136 self.default_access_mode = mode;
137 self
138 }
139
140 pub fn with_bookmarks(mut self, bookmarks: Vec<Bookmark>) -> Self {
142 self.bookmarks = bookmarks;
143 self
144 }
145}
146
147impl Default for SessionConfig {
148 fn default() -> Self {
149 Self {
150 database: None,
151 fetch_size: 1000,
152 default_access_mode: AccessMode::Write,
153 bookmarks: Vec::new(),
154 impersonated_user: None,
155 }
156 }
157}
158
159#[derive(Debug, Default)]
165pub struct SessionConfigBuilder {
166 config: SessionConfig,
167}
168
169impl SessionConfigBuilder {
170 pub fn new() -> Self {
172 Self::default()
173 }
174
175 pub fn with_database(mut self, database: impl Into<String>) -> Self {
177 self.config.database = Some(database.into());
178 self
179 }
180
181 pub fn with_fetch_size(mut self, size: usize) -> Self {
183 self.config.fetch_size = size;
184 self
185 }
186
187 pub fn with_read_access(mut self) -> Self {
189 self.config.default_access_mode = AccessMode::Read;
190 self
191 }
192
193 pub fn with_write_access(mut self) -> Self {
195 self.config.default_access_mode = AccessMode::Write;
196 self
197 }
198
199 pub fn with_bookmarks(mut self, bookmarks: Vec<Bookmark>) -> Self {
201 self.config.bookmarks = bookmarks;
202 self
203 }
204
205 pub fn with_bookmark(mut self, bookmark: Bookmark) -> Self {
207 self.config.bookmarks.push(bookmark);
208 self
209 }
210
211 pub fn with_impersonated_user(mut self, user: impl Into<String>) -> Self {
213 self.config.impersonated_user = Some(user.into());
214 self
215 }
216
217 pub fn build(self) -> SessionConfig {
219 self.config
220 }
221}
222
223#[derive(Debug, Clone)]
229pub struct Query {
230 pub text: String,
232 pub parameters: HashMap<String, Value>,
234}
235
236impl Query {
237 pub fn new(text: impl Into<String>) -> Self {
239 Self {
240 text: text.into(),
241 parameters: HashMap::new(),
242 }
243 }
244
245 pub fn with_param(mut self, key: impl Into<String>, value: impl Into<Value>) -> Self {
247 self.parameters.insert(key.into(), value.into());
248 self
249 }
250
251 pub fn with_params(mut self, params: HashMap<String, Value>) -> Self {
253 self.parameters.extend(params);
254 self
255 }
256}
257
258impl From<&str> for Query {
259 fn from(s: &str) -> Self {
260 Self::new(s)
261 }
262}
263
264impl From<String> for Query {
265 fn from(s: String) -> Self {
266 Self::new(s)
267 }
268}
269
270#[derive(Debug, Clone, Default)]
276pub struct ResultSummary {
277 pub query: Option<Query>,
279 pub query_type: QueryType,
281 pub counters: Counters,
283 pub result_available_after: Duration,
285 pub result_consumed_after: Duration,
287 pub database: Option<String>,
289 pub server: Option<String>,
291 pub notifications: Vec<Notification>,
293}
294
/// Classification of a query by the kind of work it performs.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum QueryType {
    /// Read-only query; this is the default variant.
    #[default]
    ReadOnly,
    /// Query that reads and writes.
    ReadWrite,
    /// Query that only writes.
    WriteOnly,
    /// Query that writes schema.
    SchemaWrite,
}
308
/// Update statistics for a query; every counter defaults to zero.
#[derive(Debug, Default, Clone)]
pub struct Counters {
    pub nodes_created: i64,
    pub nodes_deleted: i64,
    pub relationships_created: i64,
    pub relationships_deleted: i64,
    pub properties_set: i64,
    pub labels_added: i64,
    pub labels_removed: i64,
    pub indexes_added: i64,
    pub indexes_removed: i64,
    pub constraints_added: i64,
    pub constraints_removed: i64,
}

impl Counters {
    /// `true` if any node, relationship, property or label counter is positive.
    ///
    /// Index and constraint counters are not considered here; see
    /// `contains_system_updates`.
    pub fn contains_updates(&self) -> bool {
        let data_counters = [
            self.nodes_created,
            self.nodes_deleted,
            self.relationships_created,
            self.relationships_deleted,
            self.properties_set,
            self.labels_added,
            self.labels_removed,
        ];
        data_counters.iter().any(|&count| count > 0)
    }

    /// `true` if any index or constraint counter is positive.
    pub fn contains_system_updates(&self) -> bool {
        let schema_counters = [
            self.indexes_added,
            self.indexes_removed,
            self.constraints_added,
            self.constraints_removed,
        ];
        schema_counters.iter().any(|&count| count > 0)
    }
}
356
/// A notification attached to a query result.
#[derive(Debug, Clone)]
pub struct Notification {
    /// Notification code string.
    pub code: String,
    /// Short title.
    pub title: String,
    /// Longer description.
    pub description: String,
    /// Severity label as a string.
    pub severity: String,
    /// Position in the query text the notification refers to, if any.
    pub position: Option<InputPosition>,
}

/// A position in query text expressed as offset, line and column.
/// NOTE(review): whether line/column are 0- or 1-based is not established here.
#[derive(Debug, Clone)]
pub struct InputPosition {
    pub offset: i64,
    pub line: i64,
    pub column: i64,
}
382
383#[derive(Debug)]
389pub struct QueryResult {
390 pub records: RecordStream,
392 pub keys: Vec<String>,
394 pub summary: ResultSummary,
396}
397
398impl QueryResult {
399 pub fn new(records: Vec<Record>, keys: Vec<String>, summary: ResultSummary) -> Self {
401 Self {
402 records: RecordStream::new(records),
403 keys,
404 summary,
405 }
406 }
407
408 pub fn empty() -> Self {
410 Self {
411 records: RecordStream::empty(),
412 keys: Vec::new(),
413 summary: ResultSummary::default(),
414 }
415 }
416
417 pub fn single(self) -> DriverResult<Record> {
419 self.records.single()
420 }
421
422 pub fn first(self) -> Option<Record> {
424 self.records.first()
425 }
426
427 pub fn collect(self) -> Vec<Record> {
429 self.records.collect_all()
430 }
431}
432
433impl Iterator for QueryResult {
434 type Item = Record;
435
436 fn next(&mut self) -> Option<Self::Item> {
437 self.records.next()
438 }
439}
440
441pub struct Session {
447 driver_config: Arc<DriverConfig>,
449 pool: Arc<ConnectionPool>,
451 config: SessionConfig,
453 last_bookmark: RwLock<Option<Bookmark>>,
455 open: RwLock<bool>,
457}
458
459impl Session {
460 pub fn new(
462 driver_config: Arc<DriverConfig>,
463 pool: Arc<ConnectionPool>,
464 config: SessionConfig,
465 ) -> DriverResult<Self> {
466 Ok(Self {
467 driver_config,
468 pool,
469 config,
470 last_bookmark: RwLock::new(None),
471 open: RwLock::new(true),
472 })
473 }
474
475 pub async fn run(
477 &self,
478 query: impl Into<Query>,
479 params: Option<HashMap<String, Value>>,
480 ) -> DriverResult<QueryResult> {
481 self.ensure_open()?;
482
483 let mut query = query.into();
484 if let Some(p) = params {
485 query = query.with_params(p);
486 }
487
488 let mut conn = self.pool.acquire().await?;
489 let result = self.execute_query(&mut conn, &query).await?;
490 conn.return_to_pool();
491
492 Ok(result)
493 }
494
495 pub async fn begin_transaction(
497 &self,
498 config: Option<TransactionConfig>,
499 ) -> DriverResult<Transaction> {
500 self.ensure_open()?;
501
502 let conn = self.pool.acquire().await?;
503 let config = config.unwrap_or_default();
504
505 Transaction::begin(conn, config, self.config.database.clone()).await
506 }
507
508 pub async fn read_transaction<F, Fut, T>(&self, work: F) -> DriverResult<T>
510 where
511 F: Fn(Transaction) -> Fut,
512 Fut: Future<Output = DriverResult<T>>,
513 {
514 self.execute_transaction(AccessMode::Read, work).await
515 }
516
517 pub async fn write_transaction<F, Fut, T>(&self, work: F) -> DriverResult<T>
519 where
520 F: Fn(Transaction) -> Fut,
521 Fut: Future<Output = DriverResult<T>>,
522 {
523 self.execute_transaction(AccessMode::Write, work).await
524 }
525
526 async fn execute_transaction<F, Fut, T>(&self, _mode: AccessMode, work: F) -> DriverResult<T>
528 where
529 F: Fn(Transaction) -> Fut,
530 Fut: Future<Output = DriverResult<T>>,
531 {
532 self.ensure_open()?;
533
534 let max_retry_time = self.driver_config.max_transaction_retry_time;
535 let start = std::time::Instant::now();
536 let mut attempts = 0;
537
538 loop {
539 attempts += 1;
540
541 let tx = self.begin_transaction(None).await?;
542
543 match work(tx).await {
544 Ok(result) => return Ok(result),
545 Err(e) if e.is_retryable() && start.elapsed() < max_retry_time => {
546 let delay = std::cmp::min(
548 Duration::from_millis(100 * attempts),
549 Duration::from_secs(5),
550 );
551 tokio::time::sleep(delay).await;
552 continue;
553 }
554 Err(e) => return Err(e),
555 }
556 }
557 }
558
559 async fn execute_query(&self, conn: &mut PooledConnection, query: &Query) -> DriverResult<QueryResult> {
561 let client = conn.bolt_client_mut()
563 .ok_or_else(|| DriverError::connection("No Bolt connection available"))?;
564
565 let parameters: HashMap<String, PackStreamValue> = query.parameters
567 .iter()
568 .map(|(k, v)| (k.clone(), v.clone().into()))
569 .collect();
570
571 let bolt_result = client.run(&query.text, parameters, self.config.database.as_deref())
573 .await
574 .map_err(|e| DriverError::query("QueryExecutionError", format!("{}", e)))?;
575
576 if let Some(bookmark_str) = &bolt_result.bookmark {
578 let bookmark = Bookmark::new(bookmark_str);
579 *self.last_bookmark.write() = Some(bookmark);
580 }
581
582 let keys = bolt_result.keys.clone();
584
585 let records: Vec<Record> = bolt_result.records
587 .into_iter()
588 .map(|fields| {
589 let values: Vec<Value> = fields.into_iter().map(Into::into).collect();
590 Record::new(keys.clone(), values)
591 })
592 .collect();
593
594 let counters = bolt_result.stats.map(|stats| {
596 Counters {
597 nodes_created: stats.get("nodes-created").and_then(|v| v.as_int()).unwrap_or(0),
598 nodes_deleted: stats.get("nodes-deleted").and_then(|v| v.as_int()).unwrap_or(0),
599 relationships_created: stats.get("relationships-created").and_then(|v| v.as_int()).unwrap_or(0),
600 relationships_deleted: stats.get("relationships-deleted").and_then(|v| v.as_int()).unwrap_or(0),
601 properties_set: stats.get("properties-set").and_then(|v| v.as_int()).unwrap_or(0),
602 labels_added: stats.get("labels-added").and_then(|v| v.as_int()).unwrap_or(0),
603 labels_removed: stats.get("labels-removed").and_then(|v| v.as_int()).unwrap_or(0),
604 indexes_added: stats.get("indexes-added").and_then(|v| v.as_int()).unwrap_or(0),
605 indexes_removed: stats.get("indexes-removed").and_then(|v| v.as_int()).unwrap_or(0),
606 constraints_added: stats.get("constraints-added").and_then(|v| v.as_int()).unwrap_or(0),
607 constraints_removed: stats.get("constraints-removed").and_then(|v| v.as_int()).unwrap_or(0),
608 }
609 }).unwrap_or_default();
610
611 let summary = ResultSummary {
612 query: Some(query.clone()),
613 counters,
614 database: bolt_result.database,
615 ..Default::default()
616 };
617
618 Ok(QueryResult::new(records, keys, summary))
619 }
620
621 pub fn last_bookmark(&self) -> Option<Bookmark> {
623 self.last_bookmark.read().clone()
624 }
625
626 pub fn last_bookmarks(&self) -> Vec<Bookmark> {
628 let mut bookmarks = self.config.bookmarks.clone();
629 if let Some(bookmark) = self.last_bookmark() {
630 bookmarks.push(bookmark);
631 }
632 bookmarks
633 }
634
635 pub async fn close(&self) -> DriverResult<()> {
637 *self.open.write() = false;
638 Ok(())
639 }
640
641 fn ensure_open(&self) -> DriverResult<()> {
643 if *self.open.read() {
644 Ok(())
645 } else {
646 Err(DriverError::session("Session is closed"))
647 }
648 }
649
650 pub fn config(&self) -> &SessionConfig {
652 &self.config
653 }
654}
655
656impl std::fmt::Debug for Session {
657 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
658 f.debug_struct("Session")
659 .field("database", &self.config.database)
660 .field("open", &*self.open.read())
661 .finish()
662 }
663}
664
// Unit tests for the value types in this module (access mode, bookmarks, session
// configuration, queries, counters, results, notifications). No `Session` is
// constructed, so none of these needs a live connection.
#[cfg(test)]
mod tests {
    use super::*;

    // The enum's own default is `Read`.
    #[test]
    fn test_access_mode() {
        assert_eq!(AccessMode::default(), AccessMode::Read);
        assert_ne!(AccessMode::Read, AccessMode::Write);
    }

    #[test]
    fn test_bookmark() {
        let bookmark = Bookmark::new("zeta4g:bookmark:v1:tx123");
        assert_eq!(bookmark.value(), "zeta4g:bookmark:v1:tx123");
        assert!(!bookmark.is_empty());

        let empty = Bookmark::new("");
        assert!(empty.is_empty());
    }

    // Both `From` conversions wrap the string unchanged.
    #[test]
    fn test_bookmark_from() {
        let b1: Bookmark = "bookmark1".into();
        assert_eq!(b1.value(), "bookmark1");

        let b2: Bookmark = String::from("bookmark2").into();
        assert_eq!(b2.value(), "bookmark2");
    }

    #[test]
    fn test_bookmark_from_bookmarks() {
        let bookmarks = vec![
            Bookmark::new("b1"),
            Bookmark::new("b2"),
            Bookmark::new("b3"),
        ];

        // With several bookmarks, the last one is kept.
        let combined = Bookmark::from_bookmarks(&bookmarks);
        assert_eq!(combined.value(), "b3");

        // An empty slice yields an empty bookmark.
        let empty = Bookmark::from_bookmarks(&[]);
        assert!(empty.is_empty());
    }

    // By-value `with_*` chain on `SessionConfig`.
    #[test]
    fn test_session_config() {
        let config = SessionConfig::new()
            .with_database("mydb")
            .with_fetch_size(500)
            .with_access_mode(AccessMode::Read);

        assert_eq!(config.database, Some("mydb".to_string()));
        assert_eq!(config.fetch_size, 500);
        assert_eq!(config.default_access_mode, AccessMode::Read);
    }

    // Same settings through `SessionConfigBuilder`, plus a single appended bookmark.
    #[test]
    fn test_session_config_builder() {
        let config = SessionConfig::builder()
            .with_database("mydb")
            .with_fetch_size(500)
            .with_read_access()
            .with_bookmark(Bookmark::new("b1"))
            .build();

        assert_eq!(config.database, Some("mydb".to_string()));
        assert_eq!(config.fetch_size, 500);
        assert_eq!(config.default_access_mode, AccessMode::Read);
        assert_eq!(config.bookmarks.len(), 1);
    }

    // Parameters are converted into `Value` on insertion.
    #[test]
    fn test_query() {
        let query = Query::new("MATCH (n) RETURN n")
            .with_param("name", "Alice")
            .with_param("age", 30i64);

        assert_eq!(query.text, "MATCH (n) RETURN n");
        assert_eq!(query.parameters.len(), 2);
        assert_eq!(query.parameters.get("name"), Some(&Value::String("Alice".into())));
        assert_eq!(query.parameters.get("age"), Some(&Value::Integer(30)));
    }

    #[test]
    fn test_query_from() {
        let q1: Query = "RETURN 1".into();
        assert_eq!(q1.text, "RETURN 1");

        let q2: Query = String::from("RETURN 2").into();
        assert_eq!(q2.text, "RETURN 2");
    }

    // Data counters and schema counters are reported by separate predicates.
    #[test]
    fn test_counters() {
        let counters = Counters {
            nodes_created: 1,
            ..Default::default()
        };

        assert!(counters.contains_updates());
        assert!(!counters.contains_system_updates());

        let schema_counters = Counters {
            indexes_added: 1,
            ..Default::default()
        };

        assert!(schema_counters.contains_system_updates());
    }

    #[test]
    fn test_query_result_empty() {
        let result = QueryResult::empty();
        assert!(result.keys.is_empty());
    }

    // `result.collect()` resolves to the inherent `QueryResult::collect`.
    #[test]
    fn test_query_result_collect() {
        let records = vec![
            Record::new(vec!["n".into()], vec![Value::Integer(1)]),
            Record::new(vec!["n".into()], vec![Value::Integer(2)]),
        ];

        let result = QueryResult::new(records, vec!["n".into()], ResultSummary::default());
        let collected = result.collect();

        assert_eq!(collected.len(), 2);
    }

    #[test]
    fn test_result_summary() {
        let summary = ResultSummary {
            query_type: QueryType::ReadWrite,
            counters: Counters {
                nodes_created: 5,
                ..Default::default()
            },
            database: Some("zeta4g".to_string()),
            ..Default::default()
        };

        assert_eq!(summary.query_type, QueryType::ReadWrite);
        assert!(summary.counters.contains_updates());
        assert_eq!(summary.database, Some("zeta4g".to_string()));
    }

    #[test]
    fn test_notification() {
        let notification = Notification {
            code: "Neo.ClientNotification.Statement.UnknownLabelWarning".into(),
            title: "Unknown label".into(),
            description: "Label 'Foo' does not exist".into(),
            severity: "WARNING".into(),
            position: Some(InputPosition {
                offset: 10,
                line: 1,
                column: 11,
            }),
        };

        assert!(notification.code.contains("Warning"));
        assert_eq!(notification.severity, "WARNING");
    }
}