use crate::error::Result;
use crate::storage::AsyncStorageBackend;
use crate::types::{Node, NodeType, SessionId};
use chrono::{DateTime, Utc};
use futures::stream::Stream;
use std::pin::Pin;
use std::sync::Arc;
/// Builder for filtered, paginated queries over an [`AsyncStorageBackend`].
///
/// Configure the query with the chaining methods (`session`, `node_type`,
/// `time_range`, `limit`, `offset`) and run it with `execute`,
/// `execute_stream` or `count`. The storage handle is shared through an
/// `Arc`, so cloning a partially configured builder does not copy the backend
/// and lets one base query be reused for several variations.
#[derive(Clone)]
pub struct AsyncQueryBuilder {
    /// Backend the query reads nodes from.
    storage: Arc<dyn AsyncStorageBackend>,
    /// Session whose nodes are queried. When `None`, the query yields no nodes.
    session_filter: Option<SessionId>,
    /// If set, only nodes of this type are kept.
    node_type_filter: Option<NodeType>,
    /// If set, only nodes whose timestamp lies in `start..=end` (inclusive on
    /// both ends) are kept.
    time_range: Option<(DateTime<Utc>, DateTime<Utc>)>,
    /// Maximum number of nodes to return; `None` means unlimited.
    limit: Option<usize>,
    /// Number of matching nodes to skip before results are returned.
    offset: usize,
}
impl AsyncQueryBuilder {
    /// Creates a query over `storage` with no filters, no limit and offset 0.
    ///
    /// A query without a session filter (see [`Self::session`]) currently
    /// matches nothing, so callers will almost always chain `.session(..)`.
    pub fn new(storage: Arc<dyn AsyncStorageBackend>) -> Self {
        Self {
            storage,
            session_filter: None,
            node_type_filter: None,
            time_range: None,
            limit: None,
            offset: 0,
        }
    }

    /// Restricts the query to nodes belonging to `session_id`.
    ///
    /// This is the only source of nodes: without it, [`Self::execute`],
    /// [`Self::execute_stream`] and [`Self::count`] return no results.
    #[must_use]
    pub fn session(mut self, session_id: SessionId) -> Self {
        self.session_filter = Some(session_id);
        self
    }

    /// Keeps only nodes whose `node_type()` equals `node_type`.
    #[must_use]
    pub fn node_type(mut self, node_type: NodeType) -> Self {
        self.node_type_filter = Some(node_type);
        self
    }

    /// Keeps only nodes whose timestamp lies in `start..=end` (both bounds
    /// inclusive).
    #[must_use]
    pub fn time_range(mut self, start: DateTime<Utc>, end: DateTime<Utc>) -> Self {
        self.time_range = Some((start, end));
        self
    }

    /// Returns at most `limit` nodes (applied after `offset`).
    #[must_use]
    pub fn limit(mut self, limit: usize) -> Self {
        self.limit = Some(limit);
        self
    }

    /// Skips the first `offset` matching nodes.
    #[must_use]
    pub fn offset(mut self, offset: usize) -> Self {
        self.offset = offset;
        self
    }

    /// Timestamp used for time-range filtering and ordering: `timestamp` for
    /// prompts, responses and tool invocations; `created_at` for sessions,
    /// agents and templates.
    fn node_timestamp(node: &Node) -> DateTime<Utc> {
        match node {
            Node::Prompt(p) => p.timestamp,
            Node::Response(r) => r.timestamp,
            Node::Session(s) => s.created_at,
            Node::ToolInvocation(t) => t.timestamp,
            Node::Agent(a) => a.created_at,
            Node::Template(t) => t.created_at,
        }
    }

    /// Returns `true` if `node` passes the node-type and time-range filters.
    fn matches_filters(&self, node: &Node) -> bool {
        if let Some(node_type) = &self.node_type_filter {
            if node.node_type() != *node_type {
                return false;
            }
        }
        if let Some((start, end)) = self.time_range {
            let timestamp = Self::node_timestamp(node);
            if timestamp < start || timestamp > end {
                return false;
            }
        }
        true
    }

    /// Runs the query and collects the matching nodes into a `Vec`.
    ///
    /// Nodes of the selected session are loaded with `get_session_nodes`,
    /// filtered by node type and time range, and sorted newest first. Then
    /// `offset` nodes are skipped and at most `limit` are returned.
    ///
    /// # Errors
    ///
    /// Returns any error produced by the storage backend.
    pub async fn execute(&self) -> Result<Vec<Node>> {
        let mut nodes = match &self.session_filter {
            Some(session_id) => self.storage.get_session_nodes(session_id).await?,
            None => Vec::new(),
        };
        nodes.retain(|node| self.matches_filters(node));
        // Newest first. `sort_by_key` is stable, so nodes with equal timestamps
        // keep the order in which storage returned them.
        nodes.sort_by_key(|node| std::cmp::Reverse(Self::node_timestamp(node)));
        let limit = self.limit.unwrap_or(usize::MAX);
        Ok(nodes.into_iter().skip(self.offset).take(limit).collect())
    }

    /// Streams the matching nodes instead of collecting them.
    ///
    /// Unlike [`Self::execute`], nodes are yielded in the order produced by the
    /// backend's `get_session_nodes_stream`. They are *not* sorted, so
    /// `offset` and `limit` apply to that order.
    ///
    /// Backend errors are yielded as `Err` items and iteration continues after
    /// them; errors do not count towards `offset` or `limit`. Once `limit`
    /// nodes have been yielded, the stream ends without reading further from
    /// storage.
    pub fn execute_stream(&self) -> Pin<Box<dyn Stream<Item = Result<Node>> + Send + '_>> {
        use futures::StreamExt;
        Box::pin(async_stream::stream! {
            let mut stream = if let Some(session_id) = self.session_filter {
                self.storage.get_session_nodes_stream(&session_id)
            } else {
                Box::pin(futures::stream::empty()) as Pin<Box<dyn Stream<Item = Result<Node>> + Send + '_>>
            };
            let mut skipped = 0;
            let mut emitted = 0;
            loop {
                // Check the limit *before* pulling another item. Once enough
                // nodes have been yielded, this avoids extra storage reads and
                // never surfaces an error that lies past the limit. It also
                // makes `limit(0)` read nothing at all.
                if let Some(lim) = self.limit {
                    if emitted >= lim {
                        break;
                    }
                }
                let item = stream.next().await;
                let node = match item {
                    Some(Ok(node)) => node,
                    Some(Err(e)) => {
                        yield Err(e);
                        continue;
                    }
                    None => break,
                };
                if !self.matches_filters(&node) {
                    continue;
                }
                if skipped < self.offset {
                    skipped += 1;
                    continue;
                }
                emitted += 1;
                yield Ok(node);
            }
        })
    }

    /// Counts the nodes the query would return, respecting `offset` and `limit`.
    ///
    /// When there is no node-type or time-range filter, the backend's
    /// `count_session_nodes` is used and clipped by `offset`/`limit`.
    /// Otherwise the matching nodes are streamed and counted.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by the storage backend.
    pub async fn count(&self) -> Result<usize> {
        use futures::StreamExt;
        if let Some(session_id) = &self.session_filter {
            if self.node_type_filter.is_none() && self.time_range.is_none() {
                // No per-node filters: offset/limit only clip the total.
                let total = self.storage.count_session_nodes(session_id).await?;
                let after_offset = total.saturating_sub(self.offset);
                return Ok(match self.limit {
                    Some(limit) => after_offset.min(limit),
                    None => after_offset,
                });
            }
        }
        let mut stream = self.execute_stream();
        let mut count = 0;
        while let Some(result) = stream.next().await {
            result?;
            count += 1;
        }
        Ok(count)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::AsyncSledBackend;
    use crate::types::{ConversationSession, PromptNode};
    use futures::stream::StreamExt;
    use tempfile::tempdir;

    /// Opens a fresh sled backend in a temporary directory and stores one
    /// session node followed by `prompt_count` prompt nodes for that session.
    ///
    /// The returned `TempDir` owns the on-disk data and must stay alive (bind
    /// it to a named variable such as `_dir`, not `_`) while the backend is used.
    async fn setup(
        prompt_count: usize,
    ) -> (
        tempfile::TempDir,
        Arc<dyn crate::storage::AsyncStorageBackend>,
        SessionId,
    ) {
        let dir = tempdir().unwrap();
        let backend = Arc::new(AsyncSledBackend::open(dir.path()).await.unwrap())
            as Arc<dyn crate::storage::AsyncStorageBackend>;
        let session = ConversationSession::new();
        backend
            .store_node(&Node::Session(session.clone()))
            .await
            .unwrap();
        for i in 0..prompt_count {
            let prompt = PromptNode::new(session.id, format!("Prompt {}", i));
            backend.store_node(&Node::Prompt(prompt)).await.unwrap();
        }
        (dir, backend, session.id)
    }

    /// Drains `query.execute_stream()`, panicking on any error, and returns
    /// the number of yielded nodes.
    async fn drain_count(query: &AsyncQueryBuilder) -> usize {
        let mut stream = query.execute_stream();
        let mut count = 0;
        while let Some(result) = stream.next().await {
            result.unwrap();
            count += 1;
        }
        count
    }

    #[tokio::test]
    async fn test_query_builder_creation() {
        let dir = tempdir().unwrap();
        let backend = AsyncSledBackend::open(dir.path()).await.unwrap();
        let builder = AsyncQueryBuilder::new(
            Arc::new(backend) as Arc<dyn crate::storage::AsyncStorageBackend>
        );
        let results = builder.execute().await.unwrap();
        assert_eq!(results.len(), 0);
    }

    #[tokio::test]
    async fn test_query_with_session_filter() {
        let (_dir, backend, session_id) = setup(5).await;
        let results = AsyncQueryBuilder::new(backend)
            .session(session_id)
            .execute()
            .await
            .unwrap();
        // One session node plus five prompts.
        assert_eq!(results.len(), 6);
    }

    #[tokio::test]
    async fn test_query_with_node_type_filter() {
        let (_dir, backend, session_id) = setup(3).await;
        let results = AsyncQueryBuilder::new(backend)
            .session(session_id)
            .node_type(NodeType::Prompt)
            .execute()
            .await
            .unwrap();
        assert_eq!(results.len(), 3);
        for node in results {
            assert!(matches!(node, Node::Prompt(_)));
        }
    }

    #[tokio::test]
    async fn test_query_with_limit_and_offset() {
        let (_dir, backend, session_id) = setup(10).await;
        let results = AsyncQueryBuilder::new(Arc::clone(&backend))
            .session(session_id)
            .node_type(NodeType::Prompt)
            .limit(5)
            .execute()
            .await
            .unwrap();
        assert_eq!(results.len(), 5);

        let results = AsyncQueryBuilder::new(backend)
            .session(session_id)
            .node_type(NodeType::Prompt)
            .offset(5)
            .limit(3)
            .execute()
            .await
            .unwrap();
        assert_eq!(results.len(), 3);
    }

    #[tokio::test]
    async fn test_execute_returns_newest_first() {
        let (_dir, backend, session_id) = setup(5).await;
        let results = AsyncQueryBuilder::new(backend)
            .session(session_id)
            .node_type(NodeType::Prompt)
            .execute()
            .await
            .unwrap();
        let timestamps: Vec<_> = results
            .iter()
            .map(|node| match node {
                Node::Prompt(p) => p.timestamp,
                _ => unreachable!("query is filtered to prompts"),
            })
            .collect();
        assert!(timestamps.windows(2).all(|pair| pair[0] >= pair[1]));
    }

    #[tokio::test]
    async fn test_query_streaming() {
        let (_dir, backend, session_id) = setup(10).await;
        let query = AsyncQueryBuilder::new(backend)
            .session(session_id)
            .node_type(NodeType::Prompt);
        assert_eq!(drain_count(&query).await, 10);
    }

    #[tokio::test]
    async fn test_query_count() {
        let (_dir, backend, session_id) = setup(7).await;
        let count = AsyncQueryBuilder::new(backend)
            .session(session_id)
            .node_type(NodeType::Prompt)
            .count()
            .await
            .unwrap();
        assert_eq!(count, 7);
    }

    #[tokio::test]
    async fn test_count_respects_offset_and_limit() {
        let (_dir, backend, session_id) = setup(7).await;
        let count = AsyncQueryBuilder::new(backend)
            .session(session_id)
            .node_type(NodeType::Prompt)
            .offset(5)
            .limit(10)
            .count()
            .await
            .unwrap();
        // 7 prompts, 5 skipped, limit larger than the remainder.
        assert_eq!(count, 2);
    }

    #[tokio::test]
    async fn test_streaming_with_limit() {
        let (_dir, backend, session_id) = setup(20).await;
        let query = AsyncQueryBuilder::new(backend)
            .session(session_id)
            .node_type(NodeType::Prompt)
            .limit(5);
        assert_eq!(drain_count(&query).await, 5);
    }

    #[tokio::test]
    async fn test_zero_limit_yields_nothing() {
        let (_dir, backend, session_id) = setup(3).await;
        let query = AsyncQueryBuilder::new(backend).session(session_id).limit(0);
        assert_eq!(drain_count(&query).await, 0);
        assert!(query.execute().await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_query_without_session_is_empty() {
        let (_dir, backend, _session_id) = setup(3).await;
        let query = AsyncQueryBuilder::new(backend).node_type(NodeType::Prompt);
        assert!(query.execute().await.unwrap().is_empty());
        assert_eq!(drain_count(&query).await, 0);
        assert_eq!(query.count().await.unwrap(), 0);
    }
}