// llm_memory_graph/engine/async_memory_graph.rs
//! Async interface for the memory graph, built on the Tokio runtime.
//!
//! This module exposes every graph operation as an `async fn`, so callers can run
//! many operations concurrently without blocking the executor on storage I/O.

6use crate::error::{Error, Result};
7use crate::observatory::{
8 EventPublisher, MemoryGraphEvent, MemoryGraphMetrics, NoOpPublisher, ObservatoryConfig,
9};
10use crate::storage::{AsyncSledBackend, AsyncStorageBackend, StorageCache};
11use crate::types::{
12 AgentId, AgentNode, Config, ConversationSession, Edge, EdgeType, Node, NodeId, PromptMetadata,
13 PromptNode, PromptTemplate, ResponseMetadata, ResponseNode, SessionId, TemplateId, TokenUsage,
14 ToolInvocation,
15};
16use chrono::Utc;
17use std::collections::HashMap;
18use std::sync::Arc;
19use std::time::Instant;
20use tokio::sync::RwLock;
21
22/// Type alias for batch conversation data: (SessionId, prompt_content), optional (response_content, TokenUsage)
23type ConversationBatchItem = ((SessionId, String), Option<(String, TokenUsage)>);
24
25/// Async interface for interacting with the memory graph
26///
27/// `AsyncMemoryGraph` provides a fully async, thread-safe API for managing conversation
28/// sessions, prompts, responses, agents, templates, and their relationships in a graph structure.
29///
30/// All operations are non-blocking and can be executed concurrently without performance degradation.
31///
32/// # Examples
33///
34/// ```no_run
35/// use llm_memory_graph::engine::AsyncMemoryGraph;
36/// use llm_memory_graph::Config;
37///
38/// #[tokio::main]
39/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
40/// let config = Config::new("./data/my_graph.db");
41/// let graph = AsyncMemoryGraph::open(config).await?;
42///
43/// let session = graph.create_session().await?;
44/// let prompt_id = graph.add_prompt(session.id, "What is Rust?".to_string(), None).await?;
45/// Ok(())
46/// }
47/// ```
48pub struct AsyncMemoryGraph {
49 backend: Arc<dyn AsyncStorageBackend>,
50 sessions: Arc<RwLock<HashMap<SessionId, ConversationSession>>>,
51 observatory: Option<Arc<dyn EventPublisher>>,
52 metrics: Option<Arc<MemoryGraphMetrics>>,
53 cache: StorageCache,
54}
55
56impl AsyncMemoryGraph {
57 /// Open or create an async memory graph with the given configuration
58 ///
59 /// This will create the database directory if it doesn't exist and initialize
60 /// all necessary storage trees. Operations use Tokio's async runtime.
61 ///
62 /// # Errors
63 ///
64 /// Returns an error if:
65 /// - The database path is invalid or inaccessible
66 /// - Storage initialization fails
67 /// - Existing data is corrupted
68 ///
69 /// # Examples
70 ///
71 /// ```no_run
72 /// use llm_memory_graph::engine::AsyncMemoryGraph;
73 /// use llm_memory_graph::Config;
74 ///
75 /// #[tokio::main]
76 /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
77 /// let config = Config::new("./data/graph.db");
78 /// let graph = AsyncMemoryGraph::open(config).await?;
79 /// Ok(())
80 /// }
81 /// ```
82 pub async fn open(config: Config) -> Result<Self> {
83 let backend = AsyncSledBackend::open(&config.path).await?;
84
85 // Convert cache size from MB to approximate entry count
86 // Assume ~1KB per node, so 100MB = ~100,000 nodes
87 let node_capacity = (config.cache_size_mb as u64) * 1000;
88 let edge_capacity = node_capacity * 5; // Edges are smaller, cache more
89
90 let cache = StorageCache::with_capacity(node_capacity, edge_capacity);
91
92 Ok(Self {
93 backend: Arc::new(backend),
94 sessions: Arc::new(RwLock::new(HashMap::new())),
95 observatory: None,
96 metrics: None,
97 cache,
98 })
99 }
100
101 /// Open graph with Observatory integration
102 ///
103 /// # Examples
104 ///
105 /// ```no_run
106 /// use llm_memory_graph::engine::AsyncMemoryGraph;
107 /// use llm_memory_graph::observatory::{ObservatoryConfig, InMemoryPublisher};
108 /// use llm_memory_graph::Config;
109 /// use std::sync::Arc;
110 ///
111 /// #[tokio::main]
112 /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
113 /// let config = Config::default();
114 /// let publisher = Arc::new(InMemoryPublisher::new());
115 /// let obs_config = ObservatoryConfig::new().enabled();
116 ///
117 /// let graph = AsyncMemoryGraph::with_observatory(
118 /// config,
119 /// Some(publisher),
120 /// obs_config
121 /// ).await?;
122 /// Ok(())
123 /// }
124 /// ```
125 pub async fn with_observatory(
126 config: Config,
127 publisher: Option<Arc<dyn EventPublisher>>,
128 obs_config: ObservatoryConfig,
129 ) -> Result<Self> {
130 let backend = AsyncSledBackend::open(&config.path).await?;
131
132 // Convert cache size from MB to approximate entry count
133 // Assume ~1KB per node, so 100MB = ~100,000 nodes
134 let node_capacity = (config.cache_size_mb as u64) * 1000;
135 let edge_capacity = node_capacity * 5; // Edges are smaller, cache more
136
137 let cache = StorageCache::with_capacity(node_capacity, edge_capacity);
138
139 let metrics = if obs_config.enable_metrics {
140 Some(Arc::new(MemoryGraphMetrics::new()))
141 } else {
142 None
143 };
144
145 let observatory = if obs_config.enabled {
146 publisher.or_else(|| Some(Arc::new(NoOpPublisher)))
147 } else {
148 None
149 };
150
151 Ok(Self {
152 backend: Arc::new(backend),
153 sessions: Arc::new(RwLock::new(HashMap::new())),
154 observatory,
155 metrics,
156 cache,
157 })
158 }
159
160 /// Get metrics snapshot
161 pub fn get_metrics(&self) -> Option<crate::observatory::MetricsSnapshot> {
162 self.metrics.as_ref().map(|m| m.snapshot())
163 }
164
165 /// Publish an event to Observatory (non-blocking)
166 fn publish_event(&self, event: MemoryGraphEvent) {
167 if let Some(obs) = &self.observatory {
168 let obs = Arc::clone(obs);
169 tokio::spawn(async move {
170 if let Err(e) = obs.publish(event).await {
171 tracing::warn!("Failed to publish Observatory event: {}", e);
172 }
173 });
174 }
175 }
176
177 // ===== Session Management =====
178
179 /// Create a new conversation session asynchronously
180 ///
181 /// Sessions are used to group related prompts and responses together.
182 /// Each session has a unique ID and can store custom metadata.
183 ///
184 /// # Errors
185 ///
186 /// Returns an error if the session cannot be persisted to storage.
187 ///
188 /// # Examples
189 ///
190 /// ```no_run
191 /// # use llm_memory_graph::engine::AsyncMemoryGraph;
192 /// # use llm_memory_graph::Config;
193 /// # #[tokio::main]
194 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
195 /// # let graph = AsyncMemoryGraph::open(Config::default()).await?;
196 /// let session = graph.create_session().await?;
197 /// println!("Created session: {}", session.id);
198 /// # Ok(())
199 /// # }
200 /// ```
201 pub async fn create_session(&self) -> Result<ConversationSession> {
202 let start = Instant::now();
203
204 let session = ConversationSession::new();
205 let node = Node::Session(session.clone());
206 self.backend.store_node(&node).await?;
207
208 // Cache the session in both session cache and node cache
209 self.sessions
210 .write()
211 .await
212 .insert(session.id, session.clone());
213 self.cache.insert_node(session.node_id, node).await;
214
215 // Record metrics
216 let latency_us = start.elapsed().as_micros() as u64;
217 if let Some(metrics) = &self.metrics {
218 metrics.record_node_created();
219 metrics.record_write_latency_us(latency_us);
220 }
221
222 // Publish event
223 self.publish_event(MemoryGraphEvent::NodeCreated {
224 node_id: session.node_id,
225 node_type: crate::types::NodeType::Session,
226 session_id: Some(session.id),
227 timestamp: Utc::now(),
228 metadata: session.metadata.clone(),
229 });
230
231 Ok(session)
232 }
233
234 /// Create a session with custom metadata asynchronously
235 ///
236 /// # Examples
237 ///
238 /// ```no_run
239 /// # use llm_memory_graph::engine::AsyncMemoryGraph;
240 /// # use llm_memory_graph::Config;
241 /// # use std::collections::HashMap;
242 /// # #[tokio::main]
243 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
244 /// # let graph = AsyncMemoryGraph::open(Config::default()).await?;
245 /// let mut metadata = HashMap::new();
246 /// metadata.insert("user_id".to_string(), "123".to_string());
247 /// let session = graph.create_session_with_metadata(metadata).await?;
248 /// # Ok(())
249 /// # }
250 /// ```
251 pub async fn create_session_with_metadata(
252 &self,
253 metadata: HashMap<String, String>,
254 ) -> Result<ConversationSession> {
255 let session = ConversationSession::with_metadata(metadata);
256 let node = Node::Session(session.clone());
257 self.backend.store_node(&node).await?;
258
259 // Cache the session in both session cache and node cache
260 self.sessions
261 .write()
262 .await
263 .insert(session.id, session.clone());
264 self.cache.insert_node(session.node_id, node).await;
265
266 Ok(session)
267 }
268
269 /// Get a session by ID asynchronously
270 ///
271 /// This will first check the in-memory cache, then fall back to storage.
272 ///
273 /// # Errors
274 ///
275 /// Returns an error if the session doesn't exist or storage retrieval fails.
276 pub async fn get_session(&self, session_id: SessionId) -> Result<ConversationSession> {
277 // Check cache first
278 {
279 let sessions = self.sessions.read().await;
280 if let Some(session) = sessions.get(&session_id) {
281 return Ok(session.clone());
282 }
283 }
284
285 // Fall back to storage
286 let session_nodes = self.backend.get_session_nodes(&session_id).await?;
287
288 for node in session_nodes {
289 if let Node::Session(session) = node {
290 if session.id == session_id {
291 // Update cache
292 self.sessions
293 .write()
294 .await
295 .insert(session_id, session.clone());
296 return Ok(session);
297 }
298 }
299 }
300
301 Err(Error::SessionNotFound(session_id.to_string()))
302 }
303
304 // ===== Prompt Operations =====
305
306 /// Add a prompt node to a session asynchronously
307 ///
308 /// # Examples
309 ///
310 /// ```no_run
311 /// # use llm_memory_graph::engine::AsyncMemoryGraph;
312 /// # use llm_memory_graph::Config;
313 /// # #[tokio::main]
314 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
315 /// # let graph = AsyncMemoryGraph::open(Config::default()).await?;
316 /// # let session = graph.create_session().await?;
317 /// let prompt_id = graph.add_prompt(
318 /// session.id,
319 /// "Explain async/await in Rust".to_string(),
320 /// None
321 /// ).await?;
322 /// # Ok(())
323 /// # }
324 /// ```
325 pub async fn add_prompt(
326 &self,
327 session_id: SessionId,
328 content: String,
329 metadata: Option<PromptMetadata>,
330 ) -> Result<NodeId> {
331 let start = Instant::now();
332
333 // Verify session exists
334 self.get_session(session_id).await?;
335
336 let prompt = PromptNode {
337 id: NodeId::new(),
338 session_id,
339 content: content.clone(),
340 metadata: metadata.clone().unwrap_or_default(),
341 timestamp: chrono::Utc::now(),
342 template_id: None,
343 variables: HashMap::new(),
344 };
345
346 let prompt_id = prompt.id;
347 let node = Node::Prompt(prompt.clone());
348 self.backend.store_node(&node).await?;
349
350 // Populate cache for immediate read performance
351 self.cache.insert_node(prompt_id, node).await;
352
353 // Create PartOf edge - get session node to get its NodeId
354 let session_nodes = self.backend.get_session_nodes(&session_id).await?;
355 if let Some(session_node) = session_nodes.iter().find(|n| matches!(n, Node::Session(_))) {
356 let edge = Edge::new(prompt_id, session_node.id(), EdgeType::PartOf);
357 self.backend.store_edge(&edge).await?;
358 // Cache the edge
359 self.cache.insert_edge(edge.id, edge).await;
360 }
361
362 // Record metrics
363 let latency_us = start.elapsed().as_micros() as u64;
364 if let Some(metrics) = &self.metrics {
365 metrics.record_node_created();
366 metrics.record_prompt_submitted();
367 metrics.record_write_latency_us(latency_us);
368 }
369
370 // Publish event
371 self.publish_event(MemoryGraphEvent::PromptSubmitted {
372 prompt_id,
373 session_id,
374 content_length: content.len(),
375 model: metadata.unwrap_or_default().model,
376 timestamp: Utc::now(),
377 });
378
379 Ok(prompt_id)
380 }
381
382 /// Add multiple prompts concurrently (batch operation)
383 ///
384 /// This method processes all prompts in parallel for maximum throughput.
385 ///
386 /// # Examples
387 ///
388 /// ```no_run
389 /// # use llm_memory_graph::engine::AsyncMemoryGraph;
390 /// # use llm_memory_graph::Config;
391 /// # #[tokio::main]
392 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
393 /// # let graph = AsyncMemoryGraph::open(Config::default()).await?;
394 /// # let session = graph.create_session().await?;
395 /// let prompts = vec![
396 /// (session.id, "First prompt".to_string()),
397 /// (session.id, "Second prompt".to_string()),
398 /// ];
399 /// let ids = graph.add_prompts_batch(prompts).await?;
400 /// # Ok(())
401 /// # }
402 /// ```
403 pub async fn add_prompts_batch(
404 &self,
405 prompts: Vec<(SessionId, String)>,
406 ) -> Result<Vec<NodeId>> {
407 let futures: Vec<_> = prompts
408 .into_iter()
409 .map(|(session_id, content)| self.add_prompt(session_id, content, None))
410 .collect();
411
412 futures::future::try_join_all(futures).await
413 }
414
415 // ===== Response Operations =====
416
417 /// Add a response node linked to a prompt asynchronously
418 ///
419 /// # Examples
420 ///
421 /// ```no_run
422 /// # use llm_memory_graph::engine::AsyncMemoryGraph;
423 /// # use llm_memory_graph::{Config, TokenUsage};
424 /// # #[tokio::main]
425 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
426 /// # let graph = AsyncMemoryGraph::open(Config::default()).await?;
427 /// # let session = graph.create_session().await?;
428 /// # let prompt_id = graph.add_prompt(session.id, "Hello".to_string(), None).await?;
429 /// let usage = TokenUsage::new(10, 50);
430 /// let response_id = graph.add_response(
431 /// prompt_id,
432 /// "Async operations are non-blocking!".to_string(),
433 /// usage,
434 /// None
435 /// ).await?;
436 /// # Ok(())
437 /// # }
438 /// ```
439 pub async fn add_response(
440 &self,
441 prompt_id: NodeId,
442 content: String,
443 token_usage: TokenUsage,
444 metadata: Option<ResponseMetadata>,
445 ) -> Result<NodeId> {
446 let start = Instant::now();
447
448 let response = ResponseNode {
449 id: NodeId::new(),
450 prompt_id,
451 timestamp: chrono::Utc::now(),
452 content: content.clone(),
453 usage: token_usage,
454 metadata: metadata.unwrap_or_default(),
455 };
456
457 let response_id = response.id;
458 let node = Node::Response(response.clone());
459 self.backend.store_node(&node).await?;
460
461 // Populate cache for immediate read performance
462 self.cache.insert_node(response_id, node).await;
463
464 // Create RespondsTo edge
465 let edge = Edge::new(response_id, prompt_id, EdgeType::RespondsTo);
466 self.backend.store_edge(&edge).await?;
467 // Cache the edge
468 self.cache.insert_edge(edge.id, edge).await;
469
470 // Record metrics
471 let latency_us = start.elapsed().as_micros() as u64;
472 if let Some(metrics) = &self.metrics {
473 metrics.record_node_created();
474 metrics.record_response_generated();
475 metrics.record_write_latency_us(latency_us);
476 }
477
478 // Publish event
479 let response_latency_ms = latency_us / 1000;
480 self.publish_event(MemoryGraphEvent::ResponseGenerated {
481 response_id,
482 prompt_id,
483 content_length: content.len(),
484 tokens_used: token_usage,
485 latency_ms: response_latency_ms,
486 timestamp: Utc::now(),
487 });
488
489 Ok(response_id)
490 }
491
492 // ===== Agent Operations =====
493
494 /// Add an agent node asynchronously
495 ///
496 /// # Examples
497 ///
498 /// ```no_run
499 /// # use llm_memory_graph::engine::AsyncMemoryGraph;
500 /// # use llm_memory_graph::{Config, AgentNode};
501 /// # #[tokio::main]
502 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
503 /// # let graph = AsyncMemoryGraph::open(Config::default()).await?;
504 /// let agent = AgentNode::new(
505 /// "CodeReviewer".to_string(),
506 /// "code-review".to_string(),
507 /// vec!["rust".to_string(), "python".to_string()]
508 /// );
509 /// let agent_id = graph.add_agent(agent).await?;
510 /// # Ok(())
511 /// # }
512 /// ```
513 pub async fn add_agent(&self, agent: AgentNode) -> Result<AgentId> {
514 let agent_id = agent.id;
515 let node_id = agent.node_id;
516 let node = Node::Agent(agent);
517 self.backend.store_node(&node).await?;
518
519 // Populate cache for immediate read performance
520 self.cache.insert_node(node_id, node).await;
521
522 Ok(agent_id)
523 }
524
525 /// Update an existing agent asynchronously
526 ///
527 /// This invalidates the cache entry for the agent to ensure consistency.
528 pub async fn update_agent(&self, agent: AgentNode) -> Result<()> {
529 let node_id = agent.node_id;
530 self.backend.store_node(&Node::Agent(agent)).await?;
531
532 // Invalidate cache to ensure consistency
533 self.cache.invalidate_node(&node_id).await;
534
535 Ok(())
536 }
537
538 /// Assign an agent to handle a prompt asynchronously
539 ///
540 /// Creates a HandledBy edge from the prompt to the agent.
541 pub async fn assign_agent_to_prompt(
542 &self,
543 prompt_id: NodeId,
544 agent_node_id: NodeId,
545 ) -> Result<()> {
546 let edge = Edge::new(prompt_id, agent_node_id, EdgeType::HandledBy);
547 self.backend.store_edge(&edge).await
548 }
549
550 /// Transfer from one agent to another asynchronously
551 ///
552 /// Creates a TransfersTo edge representing agent handoff.
553 pub async fn transfer_to_agent(
554 &self,
555 from_response: NodeId,
556 to_agent_node_id: NodeId,
557 ) -> Result<()> {
558 let edge = Edge::new(from_response, to_agent_node_id, EdgeType::TransfersTo);
559 self.backend.store_edge(&edge).await
560 }
561
562 // ===== Template Operations =====
563
564 /// Create a new prompt template asynchronously
565 ///
566 /// # Examples
567 ///
568 /// ```no_run
569 /// # use llm_memory_graph::engine::AsyncMemoryGraph;
570 /// # use llm_memory_graph::{Config, PromptTemplate, VariableSpec};
571 /// # #[tokio::main]
572 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
573 /// # let graph = AsyncMemoryGraph::open(Config::default()).await?;
574 /// let template = PromptTemplate::new(
575 /// "Greeting".to_string(),
576 /// "Hello {{name}}!".to_string(),
577 /// vec![]
578 /// );
579 /// let template_id = graph.create_template(template).await?;
580 /// # Ok(())
581 /// # }
582 /// ```
583 pub async fn create_template(&self, template: PromptTemplate) -> Result<TemplateId> {
584 let template_id = template.id;
585 let template_node_id = template.node_id;
586 let node = Node::Template(template);
587 self.backend.store_node(&node).await?;
588
589 // Populate cache for immediate read performance
590 self.cache.insert_node(template_node_id, node).await;
591
592 Ok(template_id)
593 }
594
595 /// Update an existing template asynchronously
596 ///
597 /// This invalidates the cache entry for the template to ensure consistency.
598 pub async fn update_template(&self, template: PromptTemplate) -> Result<()> {
599 let template_node_id = template.node_id;
600 self.backend.store_node(&Node::Template(template)).await?;
601
602 // Invalidate cache to ensure consistency
603 self.cache.invalidate_node(&template_node_id).await;
604
605 Ok(())
606 }
607
608 /// Get a template by its template ID asynchronously
609 pub async fn get_template(&self, template_id: TemplateId) -> Result<PromptTemplate> {
610 // Search through all nodes to find the template
611 // This is a simplified implementation - in production, you'd want an index
612 let all_sessions = self.backend.get_session_nodes(&SessionId::new()).await?;
613
614 for node in all_sessions {
615 if let Node::Template(template) = node {
616 if template.id == template_id {
617 return Ok(template);
618 }
619 }
620 }
621
622 Err(Error::NodeNotFound(format!("Template {}", template_id)))
623 }
624
625 /// Get a template by its node ID asynchronously
626 pub async fn get_template_by_node_id(&self, node_id: NodeId) -> Result<PromptTemplate> {
627 if let Some(Node::Template(template)) = self.backend.get_node(&node_id).await? {
628 return Ok(template);
629 }
630
631 Err(Error::NodeNotFound(node_id.to_string()))
632 }
633
634 /// Create template from parent (inheritance) asynchronously
635 pub async fn create_template_from_parent(
636 &self,
637 template: PromptTemplate,
638 parent_node_id: NodeId,
639 ) -> Result<TemplateId> {
640 let template_node_id = template.node_id;
641 let template_id = template.id;
642
643 // Store the new template
644 self.backend.store_node(&Node::Template(template)).await?;
645
646 // Create Inherits edge
647 let edge = Edge::new(template_node_id, parent_node_id, EdgeType::Inherits);
648 self.backend.store_edge(&edge).await?;
649
650 Ok(template_id)
651 }
652
653 /// Link a prompt to the template it was instantiated from
654 pub async fn link_prompt_to_template(
655 &self,
656 prompt_id: NodeId,
657 template_node_id: NodeId,
658 ) -> Result<()> {
659 let edge = Edge::new(prompt_id, template_node_id, EdgeType::Instantiates);
660 self.backend.store_edge(&edge).await
661 }
662
663 // ===== Tool Invocation Operations =====
664
665 /// Add a tool invocation node asynchronously
666 ///
667 /// # Examples
668 ///
669 /// ```no_run
670 /// # use llm_memory_graph::engine::AsyncMemoryGraph;
671 /// # use llm_memory_graph::{Config, ToolInvocation, NodeId};
672 /// # #[tokio::main]
673 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
674 /// # let graph = AsyncMemoryGraph::open(Config::default()).await?;
675 /// # let response_id = NodeId::new();
676 /// let tool = ToolInvocation::new(
677 /// response_id,
678 /// "calculator".to_string(),
679 /// serde_json::json!({"operation": "add", "a": 2, "b": 3})
680 /// );
681 /// let tool_id = graph.add_tool_invocation(tool).await?;
682 /// # Ok(())
683 /// # }
684 /// ```
685 pub async fn add_tool_invocation(&self, tool: ToolInvocation) -> Result<NodeId> {
686 let tool_id = tool.id;
687 let response_id = tool.response_id;
688
689 // Store the tool invocation node
690 let node = Node::ToolInvocation(tool);
691 self.backend.store_node(&node).await?;
692
693 // Populate cache for immediate read performance
694 self.cache.insert_node(tool_id, node).await;
695
696 // Create INVOKES edge from response to tool
697 let edge = Edge::new(response_id, tool_id, EdgeType::Invokes);
698 self.backend.store_edge(&edge).await?;
699 // Cache the edge
700 self.cache.insert_edge(edge.id, edge).await;
701
702 Ok(tool_id)
703 }
704
705 /// Update tool invocation with results asynchronously
706 ///
707 /// This invalidates the cache entry for the tool to ensure consistency.
708 pub async fn update_tool_invocation(&self, tool: ToolInvocation) -> Result<()> {
709 let tool_id = tool.id;
710 self.backend.store_node(&Node::ToolInvocation(tool)).await?;
711
712 // Invalidate cache to ensure consistency
713 self.cache.invalidate_node(&tool_id).await;
714
715 Ok(())
716 }
717
718 // ===== Edge and Traversal Operations =====
719
720 /// Get a node by ID asynchronously (cache-aware)
721 ///
722 /// This method first checks the cache for the node. If found in cache,
723 /// it returns immediately (< 1ms latency). Otherwise, it loads from
724 /// storage and populates the cache for future requests.
725 ///
726 /// # Performance
727 ///
728 /// - Cache hit: < 1ms latency
729 /// - Cache miss: 2-10ms latency (loads from storage)
730 ///
731 /// # Examples
732 ///
733 /// ```no_run
734 /// # use llm_memory_graph::engine::AsyncMemoryGraph;
735 /// # use llm_memory_graph::{Config, NodeId};
736 /// # #[tokio::main]
737 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
738 /// # let graph = AsyncMemoryGraph::open(Config::default()).await?;
739 /// # let node_id = NodeId::new();
740 /// let node = graph.get_node(&node_id).await?;
741 /// # Ok(())
742 /// # }
743 /// ```
744 pub async fn get_node(&self, id: &NodeId) -> Result<Option<Node>> {
745 let start = Instant::now();
746
747 // Check cache first
748 if let Some(node) = self.cache.get_node(id).await {
749 // Record cache hit in metrics
750 if let Some(metrics) = &self.metrics {
751 let latency_us = start.elapsed().as_micros() as u64;
752 metrics.record_read_latency_us(latency_us);
753 }
754 return Ok(Some(node));
755 }
756
757 // Cache miss - load from storage
758 if let Some(node) = self.backend.get_node(id).await? {
759 // Populate cache for future requests
760 self.cache.insert_node(*id, node.clone()).await;
761
762 // Record read latency
763 if let Some(metrics) = &self.metrics {
764 let latency_us = start.elapsed().as_micros() as u64;
765 metrics.record_read_latency_us(latency_us);
766 }
767
768 return Ok(Some(node));
769 }
770
771 Ok(None)
772 }
773
774 /// Get an edge by ID asynchronously (cache-aware)
775 ///
776 /// This method first checks the cache for the edge. If found in cache,
777 /// it returns immediately. Otherwise, it loads from storage and populates
778 /// the cache for future requests.
779 ///
780 /// # Performance
781 ///
782 /// - Cache hit: < 1ms latency
783 /// - Cache miss: 2-10ms latency (loads from storage)
784 pub async fn get_edge(&self, id: &crate::types::EdgeId) -> Result<Option<Edge>> {
785 let start = Instant::now();
786
787 // Check cache first
788 if let Some(edge) = self.cache.get_edge(id).await {
789 // Record cache hit in metrics
790 if let Some(metrics) = &self.metrics {
791 let latency_us = start.elapsed().as_micros() as u64;
792 metrics.record_read_latency_us(latency_us);
793 }
794 return Ok(Some(edge));
795 }
796
797 // Cache miss - load from storage
798 if let Some(edge) = self.backend.get_edge(id).await? {
799 // Populate cache for future requests
800 self.cache.insert_edge(*id, edge.clone()).await;
801
802 // Record read latency
803 if let Some(metrics) = &self.metrics {
804 let latency_us = start.elapsed().as_micros() as u64;
805 metrics.record_read_latency_us(latency_us);
806 }
807
808 return Ok(Some(edge));
809 }
810
811 Ok(None)
812 }
813
814 /// Add a custom edge asynchronously
815 pub async fn add_edge(&self, from: NodeId, to: NodeId, edge_type: EdgeType) -> Result<()> {
816 let edge = Edge::new(from, to, edge_type);
817 self.backend.store_edge(&edge).await
818 }
819
820 /// Get all outgoing edges from a node asynchronously
821 pub async fn get_outgoing_edges(&self, node_id: &NodeId) -> Result<Vec<Edge>> {
822 self.backend.get_outgoing_edges(node_id).await
823 }
824
825 /// Get all incoming edges to a node asynchronously
826 pub async fn get_incoming_edges(&self, node_id: &NodeId) -> Result<Vec<Edge>> {
827 self.backend.get_incoming_edges(node_id).await
828 }
829
830 /// Get all nodes in a session asynchronously
831 pub async fn get_session_nodes(&self, session_id: &SessionId) -> Result<Vec<Node>> {
832 self.backend.get_session_nodes(session_id).await
833 }
834
835 // ===== Batch Operations =====
836
837 /// Store multiple nodes concurrently asynchronously
838 ///
839 /// This method leverages async concurrency to store multiple nodes in parallel.
840 pub async fn store_nodes_batch(&self, nodes: Vec<Node>) -> Result<Vec<NodeId>> {
841 self.backend.store_nodes_batch(&nodes).await
842 }
843
844 /// Store multiple edges concurrently asynchronously
845 pub async fn store_edges_batch(&self, edges: Vec<Edge>) -> Result<()> {
846 self.backend.store_edges_batch(&edges).await?;
847 Ok(())
848 }
849
850 /// Add multiple responses concurrently (batch operation)
851 ///
852 /// This method processes all responses in parallel for maximum throughput.
853 /// Each response is linked to its corresponding prompt.
854 ///
855 /// # Examples
856 ///
857 /// ```no_run
858 /// # use llm_memory_graph::engine::AsyncMemoryGraph;
859 /// # use llm_memory_graph::{Config, TokenUsage};
860 /// # #[tokio::main]
861 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
862 /// # let graph = AsyncMemoryGraph::open(Config::default()).await?;
863 /// # let session = graph.create_session().await?;
864 /// # let prompt1 = graph.add_prompt(session.id, "Q1".to_string(), None).await?;
865 /// # let prompt2 = graph.add_prompt(session.id, "Q2".to_string(), None).await?;
866 /// let responses = vec![
867 /// (prompt1, "Answer 1".to_string(), TokenUsage::new(10, 50)),
868 /// (prompt2, "Answer 2".to_string(), TokenUsage::new(15, 60)),
869 /// ];
870 /// let ids = graph.add_responses_batch(responses).await?;
871 /// # Ok(())
872 /// # }
873 /// ```
874 pub async fn add_responses_batch(
875 &self,
876 responses: Vec<(NodeId, String, TokenUsage)>,
877 ) -> Result<Vec<NodeId>> {
878 let futures: Vec<_> = responses
879 .into_iter()
880 .map(|(prompt_id, content, usage)| self.add_response(prompt_id, content, usage, None))
881 .collect();
882
883 futures::future::try_join_all(futures).await
884 }
885
886 /// Create multiple sessions concurrently (batch operation)
887 ///
888 /// This method creates multiple sessions in parallel for maximum throughput.
889 ///
890 /// # Examples
891 ///
892 /// ```no_run
893 /// # use llm_memory_graph::engine::AsyncMemoryGraph;
894 /// # use llm_memory_graph::Config;
895 /// # #[tokio::main]
896 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
897 /// # let graph = AsyncMemoryGraph::open(Config::default()).await?;
898 /// let sessions = graph.create_sessions_batch(5).await?;
899 /// assert_eq!(sessions.len(), 5);
900 /// # Ok(())
901 /// # }
902 /// ```
903 pub async fn create_sessions_batch(&self, count: usize) -> Result<Vec<ConversationSession>> {
904 let futures: Vec<_> = (0..count).map(|_| self.create_session()).collect();
905
906 futures::future::try_join_all(futures).await
907 }
908
909 /// Retrieve multiple nodes concurrently (batch operation)
910 ///
911 /// This method fetches all nodes in parallel for maximum throughput.
912 /// Returns nodes in the same order as the input IDs. Missing nodes are None.
913 ///
914 /// # Examples
915 ///
916 /// ```no_run
917 /// # use llm_memory_graph::engine::AsyncMemoryGraph;
918 /// # use llm_memory_graph::Config;
919 /// # use llm_memory_graph::types::NodeId;
920 /// # #[tokio::main]
921 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
922 /// # let graph = AsyncMemoryGraph::open(Config::default()).await?;
923 /// # let session = graph.create_session().await?;
924 /// # let id1 = graph.add_prompt(session.id, "Q1".to_string(), None).await?;
925 /// # let id2 = graph.add_prompt(session.id, "Q2".to_string(), None).await?;
926 /// let ids = vec![id1, id2];
927 /// let nodes = graph.get_nodes_batch(ids).await?;
928 /// assert_eq!(nodes.len(), 2);
929 /// # Ok(())
930 /// # }
931 /// ```
932 pub async fn get_nodes_batch(&self, ids: Vec<NodeId>) -> Result<Vec<Option<Node>>> {
933 let futures: Vec<_> = ids.iter().map(|id| self.get_node(id)).collect();
934
935 futures::future::try_join_all(futures).await
936 }
937
938 /// Delete multiple nodes concurrently (batch operation)
939 ///
940 /// This method deletes all nodes in parallel for maximum throughput.
941 /// Note: This does not cascade delete related edges - you may want to
942 /// delete related edges separately.
943 ///
944 /// # Examples
945 ///
946 /// ```no_run
947 /// # use llm_memory_graph::engine::AsyncMemoryGraph;
948 /// # use llm_memory_graph::Config;
949 /// # #[tokio::main]
950 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
951 /// # let graph = AsyncMemoryGraph::open(Config::default()).await?;
952 /// # let session = graph.create_session().await?;
953 /// # let id1 = graph.add_prompt(session.id, "Q1".to_string(), None).await?;
954 /// # let id2 = graph.add_prompt(session.id, "Q2".to_string(), None).await?;
955 /// let ids = vec![id1, id2];
956 /// graph.delete_nodes_batch(ids).await?;
957 /// # Ok(())
958 /// # }
959 /// ```
960 pub async fn delete_nodes_batch(&self, ids: Vec<NodeId>) -> Result<()> {
961 let futures: Vec<_> = ids.iter().map(|id| self.backend.delete_node(id)).collect();
962
963 futures::future::try_join_all(futures).await?;
964 Ok(())
965 }
966
967 /// Process a mixed batch of prompts and responses concurrently
968 ///
969 /// This is an advanced operation that allows you to add prompts and their
970 /// responses in a single concurrent batch operation. This is useful for
971 /// bulk importing conversation data.
972 ///
973 /// # Examples
974 ///
975 /// ```no_run
976 /// # use llm_memory_graph::engine::AsyncMemoryGraph;
977 /// # use llm_memory_graph::{Config, TokenUsage};
978 /// # #[tokio::main]
979 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
980 /// # let graph = AsyncMemoryGraph::open(Config::default()).await?;
981 /// # let session = graph.create_session().await?;
982 /// let conversations = vec![
983 /// (
984 /// (session.id, "What is Rust?".to_string()),
985 /// Some(("Rust is a systems programming language".to_string(), TokenUsage::new(5, 30))),
986 /// ),
987 /// (
988 /// (session.id, "How does async work?".to_string()),
989 /// Some(("Async in Rust is zero-cost".to_string(), TokenUsage::new(6, 25))),
990 /// ),
991 /// ];
992 /// let results = graph.add_conversations_batch(conversations).await?;
993 /// # Ok(())
994 /// # }
995 /// ```
996 pub async fn add_conversations_batch(
997 &self,
998 conversations: Vec<ConversationBatchItem>,
999 ) -> Result<Vec<(NodeId, Option<NodeId>)>> {
1000 let futures: Vec<_> = conversations
1001 .into_iter()
1002 .map(|((session_id, prompt_content), response_data)| async move {
1003 // Add prompt
1004 let prompt_id = self.add_prompt(session_id, prompt_content, None).await?;
1005
1006 // Add response if provided
1007 let response_id = if let Some((response_content, usage)) = response_data {
1008 Some(
1009 self.add_response(prompt_id, response_content, usage, None)
1010 .await?,
1011 )
1012 } else {
1013 None
1014 };
1015
1016 Ok((prompt_id, response_id))
1017 })
1018 .collect();
1019
1020 futures::future::try_join_all(futures).await
1021 }
1022
1023 // ===== Utility Operations =====
1024
1025 /// Flush any pending writes asynchronously
1026 pub async fn flush(&self) -> Result<()> {
1027 self.backend.flush().await
1028 }
1029
1030 /// Get storage statistics asynchronously
1031 pub async fn stats(&self) -> Result<crate::storage::StorageStats> {
1032 self.backend.stats().await
1033 }
1034
1035 // ===== Query Operations =====
1036
1037 /// Create a new async query builder for querying the graph
1038 ///
1039 /// Returns an `AsyncQueryBuilder` that provides a fluent API for building
1040 /// and executing queries with filtering, pagination, and streaming support.
1041 ///
1042 /// # Examples
1043 ///
1044 /// ```no_run
1045 /// use llm_memory_graph::engine::AsyncMemoryGraph;
1046 /// use llm_memory_graph::types::NodeType;
1047 /// use llm_memory_graph::Config;
1048 ///
1049 /// #[tokio::main]
1050 /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
1051 /// let graph = AsyncMemoryGraph::open(Config::default()).await?;
1052 /// let session = graph.create_session().await?;
1053 ///
1054 /// // Query with fluent API
1055 /// let prompts = graph.query()
1056 /// .session(session.id)
1057 /// .node_type(NodeType::Prompt)
1058 /// .limit(10)
1059 /// .execute()
1060 /// .await?;
1061 ///
1062 /// println!("Found {} prompts", prompts.len());
1063 /// Ok(())
1064 /// }
1065 /// ```
1066 pub fn query(&self) -> crate::query::AsyncQueryBuilder {
1067 crate::query::AsyncQueryBuilder::new(Arc::clone(&self.backend))
1068 }
1069}
1070
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    // Each test opens its own graph inside a fresh temporary directory. `dir` is
    // declared before `graph`, so the graph is dropped before the directory is
    // removed at the end of the test.

    #[tokio::test]
    async fn test_async_graph_creation() {
        let dir = tempdir().unwrap();
        let config = Config::new(dir.path());
        let graph = AsyncMemoryGraph::open(config).await.unwrap();

        let stats = graph.stats().await.unwrap();
        assert_eq!(stats.node_count, 0);
    }

    #[tokio::test]
    async fn test_async_session_management() {
        let dir = tempdir().unwrap();
        let config = Config::new(dir.path());
        let graph = AsyncMemoryGraph::open(config).await.unwrap();

        // Create session
        let session = graph.create_session().await.unwrap();
        assert!(!session.id.to_string().is_empty());

        // Retrieve session
        let retrieved = graph.get_session(session.id).await.unwrap();
        assert_eq!(retrieved.id, session.id);
    }

    #[tokio::test]
    async fn test_async_prompt_and_response() {
        let dir = tempdir().unwrap();
        let config = Config::new(dir.path());
        let graph = AsyncMemoryGraph::open(config).await.unwrap();

        let session = graph.create_session().await.unwrap();
        let prompt_id = graph
            .add_prompt(session.id, "Test prompt".to_string(), None)
            .await
            .unwrap();

        let usage = TokenUsage::new(10, 20);
        let response_id = graph
            .add_response(prompt_id, "Test response".to_string(), usage, None)
            .await
            .unwrap();

        // The response has exactly one outgoing edge: RESPONDS_TO its prompt.
        let edges = graph.get_outgoing_edges(&response_id).await.unwrap();
        assert_eq!(edges.len(), 1);
        assert_eq!(edges[0].edge_type, EdgeType::RespondsTo);
        assert_eq!(edges[0].to, prompt_id);
    }

    #[tokio::test]
    async fn test_concurrent_prompts() {
        let dir = tempdir().unwrap();
        let config = Config::new(dir.path());
        let graph = Arc::new(AsyncMemoryGraph::open(config).await.unwrap());

        let session = graph.create_session().await.unwrap();

        // Create 100 prompts concurrently
        let mut handles = vec![];
        for i in 0..100 {
            let graph_clone = Arc::clone(&graph);
            let session_id = session.id;

            let handle = tokio::spawn(async move {
                graph_clone
                    .add_prompt(session_id, format!("Prompt {}", i), None)
                    .await
            });

            handles.push(handle);
        }

        // Wait for all to complete
        for handle in handles {
            handle.await.unwrap().unwrap();
        }

        // Verify all were stored
        let stats = graph.stats().await.unwrap();
        assert_eq!(stats.node_count, 101); // 1 session + 100 prompts
    }

    #[tokio::test]
    async fn test_batch_operations() {
        let dir = tempdir().unwrap();
        let config = Config::new(dir.path());
        let graph = AsyncMemoryGraph::open(config).await.unwrap();

        let session = graph.create_session().await.unwrap();

        // Batch add prompts
        let prompts = (0..10)
            .map(|i| (session.id, format!("Prompt {}", i)))
            .collect();

        let ids = graph.add_prompts_batch(prompts).await.unwrap();
        assert_eq!(ids.len(), 10);

        let stats = graph.stats().await.unwrap();
        assert_eq!(stats.node_count, 11); // 1 session + 10 prompts
    }

    #[tokio::test]
    async fn test_agent_operations() {
        let dir = tempdir().unwrap();
        let config = Config::new(dir.path());
        let graph = AsyncMemoryGraph::open(config).await.unwrap();

        let agent = AgentNode::new(
            "TestAgent".to_string(),
            "tester".to_string(),
            vec!["testing".to_string()],
        );

        let agent_id = graph.add_agent(agent).await.unwrap();
        assert!(!agent_id.to_string().is_empty());
    }

    #[tokio::test]
    async fn test_template_operations() {
        let dir = tempdir().unwrap();
        let config = Config::new(dir.path());
        let graph = AsyncMemoryGraph::open(config).await.unwrap();

        let template = PromptTemplate::new(
            "Test Template".to_string(),
            "Hello {{name}}!".to_string(),
            vec![],
        );

        let template_id = graph.create_template(template.clone()).await.unwrap();
        assert_eq!(template_id, template.id);
    }

    #[tokio::test]
    async fn test_tool_invocation() {
        let dir = tempdir().unwrap();
        let config = Config::new(dir.path());
        let graph = AsyncMemoryGraph::open(config).await.unwrap();

        let session = graph.create_session().await.unwrap();
        let prompt_id = graph
            .add_prompt(session.id, "Calculate 2+2".to_string(), None)
            .await
            .unwrap();

        let usage = TokenUsage::new(5, 10);
        let response_id = graph
            .add_response(prompt_id, "Using calculator...".to_string(), usage, None)
            .await
            .unwrap();

        let tool = ToolInvocation::new(
            response_id,
            "calculator".to_string(),
            serde_json::json!({"op": "add", "a": 2, "b": 2}),
        );

        let _tool_id = graph.add_tool_invocation(tool).await.unwrap();

        // Verify INVOKES edge was created
        let edges = graph.get_outgoing_edges(&response_id).await.unwrap();
        assert!(
            edges.iter().any(|e| e.edge_type == EdgeType::Invokes),
            "response should have an outgoing INVOKES edge"
        );
    }

    #[tokio::test]
    async fn test_add_responses_batch() {
        let dir = tempdir().unwrap();
        let config = Config::new(dir.path());
        let graph = AsyncMemoryGraph::open(config).await.unwrap();

        let session = graph.create_session().await.unwrap();

        // Create 5 prompts first
        let mut prompt_ids = vec![];
        for i in 0..5 {
            let id = graph
                .add_prompt(session.id, format!("Prompt {}", i), None)
                .await
                .unwrap();
            prompt_ids.push(id);
        }

        // Batch add responses
        let responses: Vec<_> = prompt_ids
            .iter()
            .enumerate()
            .map(|(i, &prompt_id)| {
                (
                    prompt_id,
                    format!("Response {}", i),
                    TokenUsage::new(10, 20),
                )
            })
            .collect();

        let response_ids = graph.add_responses_batch(responses).await.unwrap();
        assert_eq!(response_ids.len(), 5);

        // Each response must exist and point back at the prompt it answers.
        for (&response_id, &prompt_id) in response_ids.iter().zip(&prompt_ids) {
            let node = graph.get_node(&response_id).await.unwrap();
            assert!(matches!(node, Some(Node::Response(_))));

            let edges = graph.get_outgoing_edges(&response_id).await.unwrap();
            let responds_to = edges
                .iter()
                .find(|e| e.edge_type == EdgeType::RespondsTo)
                .expect("response should have a RESPONDS_TO edge");
            assert_eq!(responds_to.to, prompt_id);
        }

        let stats = graph.stats().await.unwrap();
        assert_eq!(stats.node_count, 11); // 1 session + 5 prompts + 5 responses
    }

    #[tokio::test]
    async fn test_create_sessions_batch() {
        let dir = tempdir().unwrap();
        let config = Config::new(dir.path());
        let graph = AsyncMemoryGraph::open(config).await.unwrap();

        // Create 10 sessions concurrently
        let sessions = graph.create_sessions_batch(10).await.unwrap();
        assert_eq!(sessions.len(), 10);

        // Verify all sessions have unique IDs
        let unique_ids: std::collections::HashSet<_> = sessions.iter().map(|s| s.id).collect();
        assert_eq!(unique_ids.len(), sessions.len());

        // Verify all can be retrieved
        for session in &sessions {
            let retrieved = graph.get_session(session.id).await.unwrap();
            assert_eq!(retrieved.id, session.id);
        }

        let stats = graph.stats().await.unwrap();
        assert_eq!(stats.node_count, 10);
        assert_eq!(stats.session_count, 10);
    }

    #[tokio::test]
    async fn test_get_nodes_batch() {
        let dir = tempdir().unwrap();
        let config = Config::new(dir.path());
        let graph = AsyncMemoryGraph::open(config).await.unwrap();

        let session = graph.create_session().await.unwrap();

        // Create 20 prompts
        let mut expected_ids = vec![];
        for i in 0..20 {
            let id = graph
                .add_prompt(session.id, format!("Prompt {}", i), None)
                .await
                .unwrap();
            expected_ids.push(id);
        }

        // Batch retrieve all nodes
        let nodes = graph.get_nodes_batch(expected_ids.clone()).await.unwrap();
        assert_eq!(nodes.len(), 20);

        // Results must come back in request order, each as the prompt we stored.
        for (i, (node_opt, expected_id)) in nodes.iter().zip(&expected_ids).enumerate() {
            let node = node_opt
                .as_ref()
                .expect("every requested prompt should be found");
            assert_eq!(node.id(), *expected_id);

            match node {
                Node::Prompt(prompt) => assert_eq!(prompt.content, format!("Prompt {}", i)),
                _ => panic!("Expected Prompt node"),
            }
        }
    }

    #[tokio::test]
    async fn test_get_nodes_batch_with_missing() {
        let dir = tempdir().unwrap();
        let config = Config::new(dir.path());
        let graph = AsyncMemoryGraph::open(config).await.unwrap();

        let session = graph.create_session().await.unwrap();

        // Create 3 prompts
        let mut ids = vec![];
        for i in 0..3 {
            let id = graph
                .add_prompt(session.id, format!("Prompt {}", i), None)
                .await
                .unwrap();
            ids.push(id);
        }

        // Add non-existent ID in the middle
        let fake_id = NodeId::new();
        ids.insert(1, fake_id);

        // Batch retrieve should return None for missing node
        let nodes = graph.get_nodes_batch(ids).await.unwrap();
        assert_eq!(nodes.len(), 4);
        assert!(nodes[0].is_some());
        assert!(nodes[1].is_none()); // Fake ID
        assert!(nodes[2].is_some());
        assert!(nodes[3].is_some());
    }

    #[tokio::test]
    async fn test_delete_nodes_batch() {
        let dir = tempdir().unwrap();
        let config = Config::new(dir.path());
        let graph = AsyncMemoryGraph::open(config).await.unwrap();

        let session = graph.create_session().await.unwrap();

        // Create 15 prompts
        let mut ids_to_delete = vec![];
        for i in 0..15 {
            let id = graph
                .add_prompt(session.id, format!("Prompt {}", i), None)
                .await
                .unwrap();
            ids_to_delete.push(id);
        }

        // Verify initial state
        let stats = graph.stats().await.unwrap();
        assert_eq!(stats.node_count, 16); // 1 session + 15 prompts

        // Batch delete all prompts
        graph
            .delete_nodes_batch(ids_to_delete.clone())
            .await
            .unwrap();

        // NOTE(review): only the success of the batch call is checked here.
        // `delete_nodes_batch` talks to the storage backend directly. If
        // `get_node` serves nodes from a cache, deleted prompts could still be
        // visible through it. A stricter read-after-delete assertion is left out
        // until that behaviour is confirmed.
    }

    #[tokio::test]
    async fn test_add_conversations_batch() {
        let dir = tempdir().unwrap();
        let config = Config::new(dir.path());
        let graph = AsyncMemoryGraph::open(config).await.unwrap();

        let session = graph.create_session().await.unwrap();

        // Create mixed batch: some with responses, some without
        let conversations = vec![
            (
                (session.id, "Prompt 1".to_string()),
                Some(("Response 1".to_string(), TokenUsage::new(10, 20))),
            ),
            (
                (session.id, "Prompt 2".to_string()),
                None, // No response
            ),
            (
                (session.id, "Prompt 3".to_string()),
                Some(("Response 3".to_string(), TokenUsage::new(15, 25))),
            ),
            (
                (session.id, "Prompt 4".to_string()),
                Some(("Response 4".to_string(), TokenUsage::new(12, 22))),
            ),
            (
                (session.id, "Prompt 5".to_string()),
                None, // No response
            ),
        ];

        let results = graph.add_conversations_batch(conversations).await.unwrap();
        assert_eq!(results.len(), 5);

        // Results are in input order: prompts 1, 3 and 4 carried a response.
        let has_response: Vec<bool> = results.iter().map(|(_, r)| r.is_some()).collect();
        assert_eq!(has_response, vec![true, false, true, true, false]);

        // Verify all prompts exist
        for (prompt_id, _) in &results {
            let node = graph.get_node(prompt_id).await.unwrap();
            assert!(matches!(node, Some(Node::Prompt(_))));
        }

        // Verify responses exist and have proper edges
        for (prompt_id, response_id_opt) in &results {
            if let Some(response_id) = response_id_opt {
                let node = graph.get_node(response_id).await.unwrap();
                assert!(matches!(node, Some(Node::Response(_))));

                let edges = graph.get_outgoing_edges(response_id).await.unwrap();
                let responds_to = edges
                    .iter()
                    .find(|e| e.edge_type == EdgeType::RespondsTo)
                    .expect("response should have a RESPONDS_TO edge");
                assert_eq!(responds_to.to, *prompt_id);
            }
        }

        let stats = graph.stats().await.unwrap();
        assert_eq!(stats.node_count, 9); // 1 session + 5 prompts + 3 responses
    }

    #[tokio::test]
    async fn test_empty_batch_operations() {
        let dir = tempdir().unwrap();
        let config = Config::new(dir.path());
        let graph = AsyncMemoryGraph::open(config).await.unwrap();

        // Test empty batches
        let sessions = graph.create_sessions_batch(0).await.unwrap();
        assert_eq!(sessions.len(), 0);

        let nodes = graph.get_nodes_batch(vec![]).await.unwrap();
        assert_eq!(nodes.len(), 0);

        graph.delete_nodes_batch(vec![]).await.unwrap();

        let prompts = graph.add_prompts_batch(vec![]).await.unwrap();
        assert_eq!(prompts.len(), 0);

        let responses = graph.add_responses_batch(vec![]).await.unwrap();
        assert_eq!(responses.len(), 0);

        let conversations = graph.add_conversations_batch(vec![]).await.unwrap();
        assert_eq!(conversations.len(), 0);
    }

    #[tokio::test]
    async fn test_large_batch_operations() {
        let dir = tempdir().unwrap();
        let config = Config::new(dir.path());
        let graph = AsyncMemoryGraph::open(config).await.unwrap();

        let session = graph.create_session().await.unwrap();

        // Create 100 prompts in batch
        let prompts: Vec<_> = (0..100)
            .map(|i| (session.id, format!("Prompt {}", i)))
            .collect();

        let prompt_ids = graph.add_prompts_batch(prompts).await.unwrap();
        assert_eq!(prompt_ids.len(), 100);

        // Create 100 responses in batch
        let responses: Vec<_> = prompt_ids
            .iter()
            .enumerate()
            .map(|(i, &id)| (id, format!("Response {}", i), TokenUsage::new(10, 20)))
            .collect();

        let response_ids = graph.add_responses_batch(responses).await.unwrap();
        assert_eq!(response_ids.len(), 100);

        // Batch retrieve all prompts
        let nodes = graph.get_nodes_batch(prompt_ids.clone()).await.unwrap();
        assert_eq!(nodes.len(), 100);
        assert!(nodes.iter().all(|n| n.is_some()));

        // Batch retrieve all responses and check they really are response nodes
        let response_nodes = graph.get_nodes_batch(response_ids).await.unwrap();
        assert_eq!(response_nodes.len(), 100);
        assert!(response_nodes
            .iter()
            .all(|n| matches!(n, Some(Node::Response(_)))));

        let stats = graph.stats().await.unwrap();
        assert_eq!(stats.node_count, 201); // 1 session + 100 prompts + 100 responses
    }

    #[tokio::test]
    async fn test_batch_concurrent_execution() {
        let dir = tempdir().unwrap();
        let config = Config::new(dir.path());
        let graph = Arc::new(AsyncMemoryGraph::open(config).await.unwrap());

        // Test that batch operations can be called concurrently from multiple tasks
        let mut handles = vec![];

        for i in 0..5 {
            let graph_clone = Arc::clone(&graph);
            let handle = tokio::spawn(async move {
                // Each task creates its own session and prompts
                let sessions = graph_clone.create_sessions_batch(2).await.unwrap();
                let prompts = vec![
                    (sessions[0].id, format!("Task {} Prompt 1", i)),
                    (sessions[1].id, format!("Task {} Prompt 2", i)),
                ];
                graph_clone.add_prompts_batch(prompts).await.unwrap();
            });
            handles.push(handle);
        }

        // Wait for all concurrent operations
        for handle in handles {
            handle.await.unwrap();
        }

        // Verify all operations succeeded
        let stats = graph.stats().await.unwrap();
        assert_eq!(stats.session_count, 10); // 5 tasks × 2 sessions each
        assert_eq!(stats.node_count, 20); // 10 sessions + 10 prompts
    }
}