llm_memory_graph/query/
async_query.rs

1//! Async query builder with streaming support for memory-efficient queries
2//!
3//! This module provides a fluent API for building and executing async queries
4//! over the graph data with support for streaming large result sets.
5
6use crate::error::Result;
7use crate::storage::AsyncStorageBackend;
8use crate::types::{Node, NodeType, SessionId};
9use chrono::{DateTime, Utc};
10use futures::stream::Stream;
11use std::pin::Pin;
12use std::sync::Arc;
13
/// Builder for constructing async queries over the graph
///
/// Provides a fluent API for filtering and executing queries asynchronously.
/// Supports both batch loading and streaming for memory-efficient processing.
///
/// # Examples
///
/// ```no_run
/// use llm_memory_graph::query::AsyncQueryBuilder;
/// use llm_memory_graph::types::NodeType;
/// use futures::stream::StreamExt;
///
/// # async fn example(builder: AsyncQueryBuilder) -> Result<(), Box<dyn std::error::Error>> {
/// // Query with filters
/// let nodes = builder
///     .node_type(NodeType::Prompt)
///     .limit(100)
///     .execute()
///     .await?;
/// # Ok(())
/// # }
///
/// # async fn example2(builder: AsyncQueryBuilder) -> Result<(), Box<dyn std::error::Error>> {
/// // Stream large result sets
/// let mut stream = builder.execute_stream();
/// while let Some(node) = stream.next().await {
///     // Process node...
/// }
/// # Ok(())
/// # }
/// ```
pub struct AsyncQueryBuilder {
    // Backend the query runs against; shared Arc so builders are cheap to hand around.
    storage: Arc<dyn AsyncStorageBackend>,
    // Restrict results to this session. Currently required for non-empty
    // results: without it `execute` returns an empty set (see `execute`).
    session_filter: Option<SessionId>,
    // Keep only nodes of this variant (e.g. `NodeType::Prompt`).
    node_type_filter: Option<NodeType>,
    // Inclusive (start, end) bounds applied to each node's timestamp.
    time_range: Option<(DateTime<Utc>, DateTime<Utc>)>,
    // Maximum number of results to yield; `None` means unbounded.
    limit: Option<usize>,
    // Number of leading (post-filter) results to skip, for pagination.
    offset: usize,
}
53
54impl AsyncQueryBuilder {
55    /// Create a new async query builder
56    ///
57    /// # Examples
58    ///
59    /// ```no_run
60    /// use llm_memory_graph::query::AsyncQueryBuilder;
61    /// use llm_memory_graph::storage::AsyncSledBackend;
62    /// use std::sync::Arc;
63    ///
64    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
65    /// let backend = AsyncSledBackend::open("./data/graph.db").await?;
66    /// let builder = AsyncQueryBuilder::new(Arc::new(backend));
67    /// # Ok(())
68    /// # }
69    /// ```
70    pub fn new(storage: Arc<dyn AsyncStorageBackend>) -> Self {
71        Self {
72            storage,
73            session_filter: None,
74            node_type_filter: None,
75            time_range: None,
76            limit: None,
77            offset: 0,
78        }
79    }
80
81    /// Filter by session ID
82    ///
83    /// # Examples
84    ///
85    /// ```no_run
86    /// # use llm_memory_graph::query::AsyncQueryBuilder;
87    /// # use llm_memory_graph::types::SessionId;
88    /// # async fn example(builder: AsyncQueryBuilder, session_id: SessionId) -> Result<(), Box<dyn std::error::Error>> {
89    /// let nodes = builder
90    ///     .session(session_id)
91    ///     .execute()
92    ///     .await?;
93    /// # Ok(())
94    /// # }
95    /// ```
96    pub fn session(mut self, session_id: SessionId) -> Self {
97        self.session_filter = Some(session_id);
98        self
99    }
100
101    /// Filter by node type
102    ///
103    /// # Examples
104    ///
105    /// ```no_run
106    /// # use llm_memory_graph::query::AsyncQueryBuilder;
107    /// # use llm_memory_graph::types::NodeType;
108    /// # async fn example(builder: AsyncQueryBuilder) -> Result<(), Box<dyn std::error::Error>> {
109    /// let prompts = builder
110    ///     .node_type(NodeType::Prompt)
111    ///     .execute()
112    ///     .await?;
113    /// # Ok(())
114    /// # }
115    /// ```
116    pub fn node_type(mut self, node_type: NodeType) -> Self {
117        self.node_type_filter = Some(node_type);
118        self
119    }
120
121    /// Filter by time range (inclusive)
122    ///
123    /// # Examples
124    ///
125    /// ```no_run
126    /// # use llm_memory_graph::query::AsyncQueryBuilder;
127    /// # use chrono::Utc;
128    /// # async fn example(builder: AsyncQueryBuilder) -> Result<(), Box<dyn std::error::Error>> {
129    /// let start = Utc::now() - chrono::Duration::hours(24);
130    /// let end = Utc::now();
131    ///
132    /// let recent_nodes = builder
133    ///     .time_range(start, end)
134    ///     .execute()
135    ///     .await?;
136    /// # Ok(())
137    /// # }
138    /// ```
139    pub fn time_range(mut self, start: DateTime<Utc>, end: DateTime<Utc>) -> Self {
140        self.time_range = Some((start, end));
141        self
142    }
143
144    /// Limit the number of results
145    ///
146    /// # Examples
147    ///
148    /// ```no_run
149    /// # use llm_memory_graph::query::AsyncQueryBuilder;
150    /// # async fn example(builder: AsyncQueryBuilder) -> Result<(), Box<dyn std::error::Error>> {
151    /// let first_10 = builder
152    ///     .limit(10)
153    ///     .execute()
154    ///     .await?;
155    /// # Ok(())
156    /// # }
157    /// ```
158    pub fn limit(mut self, limit: usize) -> Self {
159        self.limit = Some(limit);
160        self
161    }
162
163    /// Skip the first N results
164    ///
165    /// # Examples
166    ///
167    /// ```no_run
168    /// # use llm_memory_graph::query::AsyncQueryBuilder;
169    /// # async fn example(builder: AsyncQueryBuilder) -> Result<(), Box<dyn std::error::Error>> {
170    /// // Get results 11-20 (skip first 10, take next 10)
171    /// let page2 = builder
172    ///     .offset(10)
173    ///     .limit(10)
174    ///     .execute()
175    ///     .await?;
176    /// # Ok(())
177    /// # }
178    /// ```
179    pub fn offset(mut self, offset: usize) -> Self {
180        self.offset = offset;
181        self
182    }
183
184    /// Execute the query and return all matching nodes
185    ///
186    /// This loads all results into memory. For large result sets, consider using
187    /// `execute_stream()` instead.
188    ///
189    /// # Examples
190    ///
191    /// ```no_run
192    /// # use llm_memory_graph::query::AsyncQueryBuilder;
193    /// # async fn example(builder: AsyncQueryBuilder) -> Result<(), Box<dyn std::error::Error>> {
194    /// let nodes = builder.execute().await?;
195    /// println!("Found {} nodes", nodes.len());
196    /// # Ok(())
197    /// # }
198    /// ```
199    pub async fn execute(&self) -> Result<Vec<Node>> {
200        // Get base nodes from session or all nodes
201        let mut nodes = if let Some(session_id) = &self.session_filter {
202            self.storage.get_session_nodes(session_id).await?
203        } else {
204            // For now, we'll need to iterate through sessions
205            // In production, you'd want a more efficient approach
206            vec![]
207        };
208
209        // Apply node type filter
210        if let Some(node_type) = &self.node_type_filter {
211            nodes.retain(|node| node.node_type() == *node_type);
212        }
213
214        // Apply time range filter
215        if let Some((start, end)) = &self.time_range {
216            nodes.retain(|node| {
217                let timestamp = match node {
218                    Node::Prompt(p) => p.timestamp,
219                    Node::Response(r) => r.timestamp,
220                    Node::Session(s) => s.created_at,
221                    Node::ToolInvocation(t) => t.timestamp,
222                    Node::Agent(a) => a.created_at,
223                    Node::Template(t) => t.created_at,
224                };
225                timestamp >= *start && timestamp <= *end
226            });
227        }
228
229        // Sort by timestamp (newest first)
230        nodes.sort_by(|a, b| {
231            let ts_a = match a {
232                Node::Prompt(p) => p.timestamp,
233                Node::Response(r) => r.timestamp,
234                Node::Session(s) => s.created_at,
235                Node::ToolInvocation(t) => t.timestamp,
236                Node::Agent(a) => a.created_at,
237                Node::Template(t) => t.created_at,
238            };
239            let ts_b = match b {
240                Node::Prompt(p) => p.timestamp,
241                Node::Response(r) => r.timestamp,
242                Node::Session(s) => s.created_at,
243                Node::ToolInvocation(t) => t.timestamp,
244                Node::Agent(a) => a.created_at,
245                Node::Template(t) => t.created_at,
246            };
247            ts_b.cmp(&ts_a)
248        });
249
250        // Apply offset
251        let nodes: Vec<_> = nodes.into_iter().skip(self.offset).collect();
252
253        // Apply limit
254        let nodes = if let Some(limit) = self.limit {
255            nodes.into_iter().take(limit).collect()
256        } else {
257            nodes
258        };
259
260        Ok(nodes)
261    }
262
263    /// Execute the query and return a stream of results
264    ///
265    /// This is memory-efficient for large result sets as it processes nodes
266    /// one at a time without loading everything into memory. The stream uses
267    /// storage-level streaming to avoid loading all nodes at once.
268    ///
269    /// # Examples
270    ///
271    /// ```no_run
272    /// # use llm_memory_graph::query::AsyncQueryBuilder;
273    /// # use futures::stream::StreamExt;
274    /// # async fn example(builder: AsyncQueryBuilder) -> Result<(), Box<dyn std::error::Error>> {
275    /// let mut stream = builder.execute_stream();
276    ///
277    /// let mut count = 0;
278    /// while let Some(result) = stream.next().await {
279    ///     match result {
280    ///         Ok(node) => {
281    ///             // Process node without loading all into memory
282    ///             count += 1;
283    ///         }
284    ///         Err(e) => eprintln!("Error: {}", e),
285    ///     }
286    /// }
287    ///
288    /// println!("Processed {} nodes", count);
289    /// # Ok(())
290    /// # }
291    /// ```
292    pub fn execute_stream(&self) -> Pin<Box<dyn Stream<Item = Result<Node>> + Send + '_>> {
293        use futures::StreamExt;
294
295        let session_filter = self.session_filter;
296        let node_type_filter = self.node_type_filter.clone();
297        let time_range = self.time_range;
298        let limit = self.limit;
299        let offset = self.offset;
300
301        Box::pin(async_stream::stream! {
302            // Use storage-level streaming for better memory efficiency
303            let mut stream = if let Some(session_id) = session_filter {
304                self.storage.get_session_nodes_stream(&session_id)
305            } else {
306                // Empty stream if no session filter
307                Box::pin(futures::stream::empty()) as Pin<Box<dyn Stream<Item = Result<Node>> + Send + '_>>
308            };
309
310            // Apply filters and stream results
311            let mut skipped = 0;
312            let mut emitted = 0;
313
314            while let Some(result) = stream.next().await {
315                let node = match result {
316                    Ok(n) => n,
317                    Err(e) => {
318                        yield Err(e);
319                        continue;
320                    }
321                };
322
323                // Apply node type filter
324                if let Some(ref nt) = node_type_filter {
325                    if node.node_type() != *nt {
326                        continue;
327                    }
328                }
329
330                // Apply time range filter
331                if let Some((start, end)) = time_range {
332                    let timestamp = match &node {
333                        Node::Prompt(p) => p.timestamp,
334                        Node::Response(r) => r.timestamp,
335                        Node::Session(s) => s.created_at,
336                        Node::ToolInvocation(t) => t.timestamp,
337                        Node::Agent(a) => a.created_at,
338                        Node::Template(t) => t.created_at,
339                    };
340
341                    if timestamp < start || timestamp > end {
342                        continue;
343                    }
344                }
345
346                // Apply offset
347                if skipped < offset {
348                    skipped += 1;
349                    continue;
350                }
351
352                // Apply limit
353                if let Some(lim) = limit {
354                    if emitted >= lim {
355                        break;
356                    }
357                }
358
359                emitted += 1;
360                yield Ok(node);
361            }
362        })
363    }
364
365    /// Count the number of matching nodes without loading them
366    ///
367    /// This is more efficient than `execute().await?.len()` for large result sets
368    /// as it uses storage-level counting when possible.
369    ///
370    /// # Examples
371    ///
372    /// ```no_run
373    /// # use llm_memory_graph::query::AsyncQueryBuilder;
374    /// # use llm_memory_graph::types::NodeType;
375    /// # async fn example(builder: AsyncQueryBuilder) -> Result<(), Box<dyn std::error::Error>> {
376    /// let prompt_count = builder
377    ///     .node_type(NodeType::Prompt)
378    ///     .count()
379    ///     .await?;
380    ///
381    /// println!("Total prompts: {}", prompt_count);
382    /// # Ok(())
383    /// # }
384    /// ```
385    pub async fn count(&self) -> Result<usize> {
386        use futures::StreamExt;
387
388        // If we only have a session filter and no other filters, use efficient count
389        if self.session_filter.is_some()
390            && self.node_type_filter.is_none()
391            && self.time_range.is_none()
392            && self.offset == 0
393            && self.limit.is_none()
394        {
395            return self
396                .storage
397                .count_session_nodes(&self.session_filter.unwrap())
398                .await;
399        }
400
401        // Otherwise, stream and count to avoid loading all into memory
402        let mut stream = self.execute_stream();
403        let mut count = 0;
404        while let Some(result) = stream.next().await {
405            result?;
406            count += 1;
407        }
408        Ok(count)
409    }
410}
411
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::AsyncSledBackend;
    use crate::types::{ConversationSession, PromptNode};
    use futures::stream::StreamExt;
    use std::path::Path;
    use tempfile::tempdir;

    /// Open a fresh sled backend at `path`, store one session node plus
    /// `prompt_count` prompt nodes in it, and return both. Shared fixture for
    /// the tests below (previously duplicated in every test body).
    async fn seeded_backend(
        path: &Path,
        prompt_count: usize,
    ) -> (
        Arc<dyn crate::storage::AsyncStorageBackend>,
        ConversationSession,
    ) {
        let backend = Arc::new(AsyncSledBackend::open(path).await.unwrap())
            as Arc<dyn crate::storage::AsyncStorageBackend>;

        let session = ConversationSession::new();
        backend
            .store_node(&Node::Session(session.clone()))
            .await
            .unwrap();

        for i in 0..prompt_count {
            let prompt = PromptNode::new(session.id, format!("Prompt {}", i));
            backend.store_node(&Node::Prompt(prompt)).await.unwrap();
        }

        (backend, session)
    }

    #[tokio::test]
    async fn test_query_builder_creation() {
        let dir = tempdir().unwrap();
        let backend = AsyncSledBackend::open(dir.path()).await.unwrap();
        let builder = AsyncQueryBuilder::new(
            Arc::new(backend) as Arc<dyn crate::storage::AsyncStorageBackend>
        );

        // Without a session filter, execute() currently returns nothing.
        let results = builder.execute().await.unwrap();
        assert_eq!(results.len(), 0);
    }

    #[tokio::test]
    async fn test_query_with_session_filter() {
        let dir = tempdir().unwrap();
        let (backend, session) = seeded_backend(dir.path(), 5).await;

        // Query with session filter
        let builder = AsyncQueryBuilder::new(backend);
        let results = builder.session(session.id).execute().await.unwrap();

        assert_eq!(results.len(), 6); // 1 session + 5 prompts
    }

    #[tokio::test]
    async fn test_query_with_node_type_filter() {
        let dir = tempdir().unwrap();
        let (backend, session) = seeded_backend(dir.path(), 3).await;

        // Query only prompts — the session node must be filtered out.
        let builder = AsyncQueryBuilder::new(backend);
        let results = builder
            .session(session.id)
            .node_type(NodeType::Prompt)
            .execute()
            .await
            .unwrap();

        assert_eq!(results.len(), 3);
        for node in results {
            assert!(matches!(node, Node::Prompt(_)));
        }
    }

    #[tokio::test]
    async fn test_query_with_limit_and_offset() {
        let dir = tempdir().unwrap();
        let (backend, session) = seeded_backend(dir.path(), 10).await;

        // Test limit
        let builder = AsyncQueryBuilder::new(Arc::clone(&backend));
        let results = builder
            .session(session.id)
            .node_type(NodeType::Prompt)
            .limit(5)
            .execute()
            .await
            .unwrap();
        assert_eq!(results.len(), 5);

        // Test offset + limit (pagination)
        let builder = AsyncQueryBuilder::new(backend);
        let results = builder
            .session(session.id)
            .node_type(NodeType::Prompt)
            .offset(5)
            .limit(3)
            .execute()
            .await
            .unwrap();
        assert_eq!(results.len(), 3);
    }

    #[tokio::test]
    async fn test_query_streaming() {
        let dir = tempdir().unwrap();
        let (backend, session) = seeded_backend(dir.path(), 10).await;

        // Stream results and count them; every item must be Ok.
        let query = AsyncQueryBuilder::new(backend)
            .session(session.id)
            .node_type(NodeType::Prompt);
        let mut stream = query.execute_stream();

        let mut count = 0;
        while let Some(result) = stream.next().await {
            result.unwrap();
            count += 1;
        }

        assert_eq!(count, 10);
    }

    #[tokio::test]
    async fn test_query_count() {
        let dir = tempdir().unwrap();
        let (backend, session) = seeded_backend(dir.path(), 7).await;

        // Count prompts (filtered path — exercises the streaming fallback).
        let builder = AsyncQueryBuilder::new(backend);
        let count = builder
            .session(session.id)
            .node_type(NodeType::Prompt)
            .count()
            .await
            .unwrap();

        assert_eq!(count, 7);
    }

    #[tokio::test]
    async fn test_streaming_with_limit() {
        let dir = tempdir().unwrap();
        let (backend, session) = seeded_backend(dir.path(), 20).await;

        // Stream with limit — the stream must stop after 5 items.
        let query = AsyncQueryBuilder::new(backend)
            .session(session.id)
            .node_type(NodeType::Prompt)
            .limit(5);
        let mut stream = query.execute_stream();

        let mut count = 0;
        while let Some(result) = stream.next().await {
            result.unwrap();
            count += 1;
        }

        assert_eq!(count, 5);
    }
}