// llm_memory_graph/query/async_query.rs
1//! Async query builder with streaming support for memory-efficient queries
2//!
3//! This module provides a fluent API for building and executing async queries
4//! over the graph data with support for streaming large result sets.
5
6use crate::error::Result;
7use crate::storage::AsyncStorageBackend;
8use crate::types::{Node, NodeType, SessionId};
9use chrono::{DateTime, Utc};
10use futures::stream::Stream;
11use std::pin::Pin;
12use std::sync::Arc;
13
14/// Builder for constructing async queries over the graph
15///
16/// Provides a fluent API for filtering and executing queries asynchronously.
17/// Supports both batch loading and streaming for memory-efficient processing.
18///
19/// # Examples
20///
21/// ```no_run
22/// use llm_memory_graph::query::AsyncQueryBuilder;
23/// use llm_memory_graph::types::NodeType;
24/// use futures::stream::StreamExt;
25///
26/// # async fn example(builder: AsyncQueryBuilder) -> Result<(), Box<dyn std::error::Error>> {
27/// // Query with filters
28/// let nodes = builder
29/// .node_type(NodeType::Prompt)
30/// .limit(100)
31/// .execute()
32/// .await?;
33/// # Ok(())
34/// # }
35///
36/// # async fn example2(builder: AsyncQueryBuilder) -> Result<(), Box<dyn std::error::Error>> {
37/// // Stream large result sets
38/// let mut stream = builder.execute_stream();
39/// while let Some(node) = stream.next().await {
40/// // Process node...
41/// }
42/// # Ok(())
43/// # }
44/// ```
pub struct AsyncQueryBuilder {
    /// Storage backend the query executes against.
    storage: Arc<dyn AsyncStorageBackend>,
    /// Restrict results to nodes belonging to this session, if set.
    session_filter: Option<SessionId>,
    /// Restrict results to a single node variant, if set.
    node_type_filter: Option<NodeType>,
    /// Inclusive `(start, end)` timestamp window, if set.
    time_range: Option<(DateTime<Utc>, DateTime<Utc>)>,
    /// Maximum number of results to return, if set.
    limit: Option<usize>,
    /// Number of leading results to skip (for pagination); defaults to 0.
    offset: usize,
}
53
54impl AsyncQueryBuilder {
55 /// Create a new async query builder
56 ///
57 /// # Examples
58 ///
59 /// ```no_run
60 /// use llm_memory_graph::query::AsyncQueryBuilder;
61 /// use llm_memory_graph::storage::AsyncSledBackend;
62 /// use std::sync::Arc;
63 ///
64 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
65 /// let backend = AsyncSledBackend::open("./data/graph.db").await?;
66 /// let builder = AsyncQueryBuilder::new(Arc::new(backend));
67 /// # Ok(())
68 /// # }
69 /// ```
70 pub fn new(storage: Arc<dyn AsyncStorageBackend>) -> Self {
71 Self {
72 storage,
73 session_filter: None,
74 node_type_filter: None,
75 time_range: None,
76 limit: None,
77 offset: 0,
78 }
79 }
80
81 /// Filter by session ID
82 ///
83 /// # Examples
84 ///
85 /// ```no_run
86 /// # use llm_memory_graph::query::AsyncQueryBuilder;
87 /// # use llm_memory_graph::types::SessionId;
88 /// # async fn example(builder: AsyncQueryBuilder, session_id: SessionId) -> Result<(), Box<dyn std::error::Error>> {
89 /// let nodes = builder
90 /// .session(session_id)
91 /// .execute()
92 /// .await?;
93 /// # Ok(())
94 /// # }
95 /// ```
96 pub fn session(mut self, session_id: SessionId) -> Self {
97 self.session_filter = Some(session_id);
98 self
99 }
100
101 /// Filter by node type
102 ///
103 /// # Examples
104 ///
105 /// ```no_run
106 /// # use llm_memory_graph::query::AsyncQueryBuilder;
107 /// # use llm_memory_graph::types::NodeType;
108 /// # async fn example(builder: AsyncQueryBuilder) -> Result<(), Box<dyn std::error::Error>> {
109 /// let prompts = builder
110 /// .node_type(NodeType::Prompt)
111 /// .execute()
112 /// .await?;
113 /// # Ok(())
114 /// # }
115 /// ```
116 pub fn node_type(mut self, node_type: NodeType) -> Self {
117 self.node_type_filter = Some(node_type);
118 self
119 }
120
121 /// Filter by time range (inclusive)
122 ///
123 /// # Examples
124 ///
125 /// ```no_run
126 /// # use llm_memory_graph::query::AsyncQueryBuilder;
127 /// # use chrono::Utc;
128 /// # async fn example(builder: AsyncQueryBuilder) -> Result<(), Box<dyn std::error::Error>> {
129 /// let start = Utc::now() - chrono::Duration::hours(24);
130 /// let end = Utc::now();
131 ///
132 /// let recent_nodes = builder
133 /// .time_range(start, end)
134 /// .execute()
135 /// .await?;
136 /// # Ok(())
137 /// # }
138 /// ```
139 pub fn time_range(mut self, start: DateTime<Utc>, end: DateTime<Utc>) -> Self {
140 self.time_range = Some((start, end));
141 self
142 }
143
144 /// Limit the number of results
145 ///
146 /// # Examples
147 ///
148 /// ```no_run
149 /// # use llm_memory_graph::query::AsyncQueryBuilder;
150 /// # async fn example(builder: AsyncQueryBuilder) -> Result<(), Box<dyn std::error::Error>> {
151 /// let first_10 = builder
152 /// .limit(10)
153 /// .execute()
154 /// .await?;
155 /// # Ok(())
156 /// # }
157 /// ```
158 pub fn limit(mut self, limit: usize) -> Self {
159 self.limit = Some(limit);
160 self
161 }
162
163 /// Skip the first N results
164 ///
165 /// # Examples
166 ///
167 /// ```no_run
168 /// # use llm_memory_graph::query::AsyncQueryBuilder;
169 /// # async fn example(builder: AsyncQueryBuilder) -> Result<(), Box<dyn std::error::Error>> {
170 /// // Get results 11-20 (skip first 10, take next 10)
171 /// let page2 = builder
172 /// .offset(10)
173 /// .limit(10)
174 /// .execute()
175 /// .await?;
176 /// # Ok(())
177 /// # }
178 /// ```
179 pub fn offset(mut self, offset: usize) -> Self {
180 self.offset = offset;
181 self
182 }
183
184 /// Execute the query and return all matching nodes
185 ///
186 /// This loads all results into memory. For large result sets, consider using
187 /// `execute_stream()` instead.
188 ///
189 /// # Examples
190 ///
191 /// ```no_run
192 /// # use llm_memory_graph::query::AsyncQueryBuilder;
193 /// # async fn example(builder: AsyncQueryBuilder) -> Result<(), Box<dyn std::error::Error>> {
194 /// let nodes = builder.execute().await?;
195 /// println!("Found {} nodes", nodes.len());
196 /// # Ok(())
197 /// # }
198 /// ```
199 pub async fn execute(&self) -> Result<Vec<Node>> {
200 // Get base nodes from session or all nodes
201 let mut nodes = if let Some(session_id) = &self.session_filter {
202 self.storage.get_session_nodes(session_id).await?
203 } else {
204 // For now, we'll need to iterate through sessions
205 // In production, you'd want a more efficient approach
206 vec![]
207 };
208
209 // Apply node type filter
210 if let Some(node_type) = &self.node_type_filter {
211 nodes.retain(|node| node.node_type() == *node_type);
212 }
213
214 // Apply time range filter
215 if let Some((start, end)) = &self.time_range {
216 nodes.retain(|node| {
217 let timestamp = match node {
218 Node::Prompt(p) => p.timestamp,
219 Node::Response(r) => r.timestamp,
220 Node::Session(s) => s.created_at,
221 Node::ToolInvocation(t) => t.timestamp,
222 Node::Agent(a) => a.created_at,
223 Node::Template(t) => t.created_at,
224 };
225 timestamp >= *start && timestamp <= *end
226 });
227 }
228
229 // Sort by timestamp (newest first)
230 nodes.sort_by(|a, b| {
231 let ts_a = match a {
232 Node::Prompt(p) => p.timestamp,
233 Node::Response(r) => r.timestamp,
234 Node::Session(s) => s.created_at,
235 Node::ToolInvocation(t) => t.timestamp,
236 Node::Agent(a) => a.created_at,
237 Node::Template(t) => t.created_at,
238 };
239 let ts_b = match b {
240 Node::Prompt(p) => p.timestamp,
241 Node::Response(r) => r.timestamp,
242 Node::Session(s) => s.created_at,
243 Node::ToolInvocation(t) => t.timestamp,
244 Node::Agent(a) => a.created_at,
245 Node::Template(t) => t.created_at,
246 };
247 ts_b.cmp(&ts_a)
248 });
249
250 // Apply offset
251 let nodes: Vec<_> = nodes.into_iter().skip(self.offset).collect();
252
253 // Apply limit
254 let nodes = if let Some(limit) = self.limit {
255 nodes.into_iter().take(limit).collect()
256 } else {
257 nodes
258 };
259
260 Ok(nodes)
261 }
262
263 /// Execute the query and return a stream of results
264 ///
265 /// This is memory-efficient for large result sets as it processes nodes
266 /// one at a time without loading everything into memory. The stream uses
267 /// storage-level streaming to avoid loading all nodes at once.
268 ///
269 /// # Examples
270 ///
271 /// ```no_run
272 /// # use llm_memory_graph::query::AsyncQueryBuilder;
273 /// # use futures::stream::StreamExt;
274 /// # async fn example(builder: AsyncQueryBuilder) -> Result<(), Box<dyn std::error::Error>> {
275 /// let mut stream = builder.execute_stream();
276 ///
277 /// let mut count = 0;
278 /// while let Some(result) = stream.next().await {
279 /// match result {
280 /// Ok(node) => {
281 /// // Process node without loading all into memory
282 /// count += 1;
283 /// }
284 /// Err(e) => eprintln!("Error: {}", e),
285 /// }
286 /// }
287 ///
288 /// println!("Processed {} nodes", count);
289 /// # Ok(())
290 /// # }
291 /// ```
292 pub fn execute_stream(&self) -> Pin<Box<dyn Stream<Item = Result<Node>> + Send + '_>> {
293 use futures::StreamExt;
294
295 let session_filter = self.session_filter;
296 let node_type_filter = self.node_type_filter.clone();
297 let time_range = self.time_range;
298 let limit = self.limit;
299 let offset = self.offset;
300
301 Box::pin(async_stream::stream! {
302 // Use storage-level streaming for better memory efficiency
303 let mut stream = if let Some(session_id) = session_filter {
304 self.storage.get_session_nodes_stream(&session_id)
305 } else {
306 // Empty stream if no session filter
307 Box::pin(futures::stream::empty()) as Pin<Box<dyn Stream<Item = Result<Node>> + Send + '_>>
308 };
309
310 // Apply filters and stream results
311 let mut skipped = 0;
312 let mut emitted = 0;
313
314 while let Some(result) = stream.next().await {
315 let node = match result {
316 Ok(n) => n,
317 Err(e) => {
318 yield Err(e);
319 continue;
320 }
321 };
322
323 // Apply node type filter
324 if let Some(ref nt) = node_type_filter {
325 if node.node_type() != *nt {
326 continue;
327 }
328 }
329
330 // Apply time range filter
331 if let Some((start, end)) = time_range {
332 let timestamp = match &node {
333 Node::Prompt(p) => p.timestamp,
334 Node::Response(r) => r.timestamp,
335 Node::Session(s) => s.created_at,
336 Node::ToolInvocation(t) => t.timestamp,
337 Node::Agent(a) => a.created_at,
338 Node::Template(t) => t.created_at,
339 };
340
341 if timestamp < start || timestamp > end {
342 continue;
343 }
344 }
345
346 // Apply offset
347 if skipped < offset {
348 skipped += 1;
349 continue;
350 }
351
352 // Apply limit
353 if let Some(lim) = limit {
354 if emitted >= lim {
355 break;
356 }
357 }
358
359 emitted += 1;
360 yield Ok(node);
361 }
362 })
363 }
364
365 /// Count the number of matching nodes without loading them
366 ///
367 /// This is more efficient than `execute().await?.len()` for large result sets
368 /// as it uses storage-level counting when possible.
369 ///
370 /// # Examples
371 ///
372 /// ```no_run
373 /// # use llm_memory_graph::query::AsyncQueryBuilder;
374 /// # use llm_memory_graph::types::NodeType;
375 /// # async fn example(builder: AsyncQueryBuilder) -> Result<(), Box<dyn std::error::Error>> {
376 /// let prompt_count = builder
377 /// .node_type(NodeType::Prompt)
378 /// .count()
379 /// .await?;
380 ///
381 /// println!("Total prompts: {}", prompt_count);
382 /// # Ok(())
383 /// # }
384 /// ```
385 pub async fn count(&self) -> Result<usize> {
386 use futures::StreamExt;
387
388 // If we only have a session filter and no other filters, use efficient count
389 if self.session_filter.is_some()
390 && self.node_type_filter.is_none()
391 && self.time_range.is_none()
392 && self.offset == 0
393 && self.limit.is_none()
394 {
395 return self
396 .storage
397 .count_session_nodes(&self.session_filter.unwrap())
398 .await;
399 }
400
401 // Otherwise, stream and count to avoid loading all into memory
402 let mut stream = self.execute_stream();
403 let mut count = 0;
404 while let Some(result) = stream.next().await {
405 result?;
406 count += 1;
407 }
408 Ok(count)
409 }
410}
411
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::AsyncSledBackend;
    use crate::types::{ConversationSession, PromptNode};
    use futures::stream::StreamExt;
    use tempfile::{tempdir, TempDir};

    /// Open a fresh sled backend in a temp directory and seed it with one
    /// session plus `prompt_count` prompt nodes.
    ///
    /// The returned `TempDir` must stay alive for the duration of the test,
    /// otherwise the backing files are deleted out from under the backend.
    async fn seeded_backend(
        prompt_count: usize,
    ) -> (
        TempDir,
        Arc<dyn crate::storage::AsyncStorageBackend>,
        ConversationSession,
    ) {
        let dir = tempdir().unwrap();
        let backend = Arc::new(AsyncSledBackend::open(dir.path()).await.unwrap())
            as Arc<dyn crate::storage::AsyncStorageBackend>;

        let session = ConversationSession::new();
        backend
            .store_node(&Node::Session(session.clone()))
            .await
            .unwrap();

        for i in 0..prompt_count {
            let prompt = PromptNode::new(session.id, format!("Prompt {}", i));
            backend.store_node(&Node::Prompt(prompt)).await.unwrap();
        }

        (dir, backend, session)
    }

    #[tokio::test]
    async fn test_query_builder_creation() {
        let dir = tempdir().unwrap();
        let backend = AsyncSledBackend::open(dir.path()).await.unwrap();
        let builder = AsyncQueryBuilder::new(
            Arc::new(backend) as Arc<dyn crate::storage::AsyncStorageBackend>
        );

        // Without a session filter, `execute` currently returns no nodes.
        let results = builder.execute().await.unwrap();
        assert_eq!(results.len(), 0);
    }

    #[tokio::test]
    async fn test_query_with_session_filter() {
        let (_dir, backend, session) = seeded_backend(5).await;

        let builder = AsyncQueryBuilder::new(backend);
        let results = builder.session(session.id).execute().await.unwrap();

        assert_eq!(results.len(), 6); // 1 session + 5 prompts
    }

    #[tokio::test]
    async fn test_query_with_node_type_filter() {
        let (_dir, backend, session) = seeded_backend(3).await;

        // Query only prompts; the session node must be filtered out.
        let builder = AsyncQueryBuilder::new(backend);
        let results = builder
            .session(session.id)
            .node_type(NodeType::Prompt)
            .execute()
            .await
            .unwrap();

        assert_eq!(results.len(), 3);
        for node in results {
            assert!(matches!(node, Node::Prompt(_)));
        }
    }

    #[tokio::test]
    async fn test_query_with_limit_and_offset() {
        let (_dir, backend, session) = seeded_backend(10).await;

        // Test limit
        let builder = AsyncQueryBuilder::new(Arc::clone(&backend));
        let results = builder
            .session(session.id)
            .node_type(NodeType::Prompt)
            .limit(5)
            .execute()
            .await
            .unwrap();
        assert_eq!(results.len(), 5);

        // Test offset + limit (pagination)
        let builder = AsyncQueryBuilder::new(backend);
        let results = builder
            .session(session.id)
            .node_type(NodeType::Prompt)
            .offset(5)
            .limit(3)
            .execute()
            .await
            .unwrap();
        assert_eq!(results.len(), 3);
    }

    #[tokio::test]
    async fn test_query_streaming() {
        let (_dir, backend, session) = seeded_backend(10).await;

        // Stream results and count them without collecting.
        let query = AsyncQueryBuilder::new(backend)
            .session(session.id)
            .node_type(NodeType::Prompt);
        let mut stream = query.execute_stream();

        let mut count = 0;
        while let Some(result) = stream.next().await {
            result.unwrap();
            count += 1;
        }

        assert_eq!(count, 10);
    }

    #[tokio::test]
    async fn test_query_count() {
        let (_dir, backend, session) = seeded_backend(7).await;

        // Count prompts (type filter forces the streaming slow path).
        let builder = AsyncQueryBuilder::new(backend);
        let count = builder
            .session(session.id)
            .node_type(NodeType::Prompt)
            .count()
            .await
            .unwrap();

        assert_eq!(count, 7);
    }

    #[tokio::test]
    async fn test_streaming_with_limit() {
        let (_dir, backend, session) = seeded_backend(20).await;

        // Stream with limit: only the first 5 matching nodes are emitted.
        let query = AsyncQueryBuilder::new(backend)
            .session(session.id)
            .node_type(NodeType::Prompt)
            .limit(5);
        let mut stream = query.execute_stream();

        let mut count = 0;
        while let Some(result) = stream.next().await {
            result.unwrap();
            count += 1;
        }

        assert_eq!(count, 5);
    }
}