llm_memory_graph/observatory/
streaming.rs1use super::events::MemoryGraphEvent;
7use crate::error::Result;
8use async_trait::async_trait;
9use futures::stream::Stream;
10use std::pin::Pin;
11use std::sync::Arc;
12use tokio::sync::{broadcast, RwLock};
13
14#[async_trait]
16pub trait EventStream: Send + Sync {
17 async fn publish(&self, event: MemoryGraphEvent) -> Result<()>;
19
20 async fn publish_batch(&self, events: Vec<MemoryGraphEvent>) -> Result<()> {
22 for event in events {
23 self.publish(event).await?;
24 }
25 Ok(())
26 }
27
28 fn subscribe(&self) -> Pin<Box<dyn Stream<Item = MemoryGraphEvent> + Send + '_>>;
30}
31
32#[derive(Clone)]
34pub struct InMemoryEventStream {
35 sender: broadcast::Sender<MemoryGraphEvent>,
36 buffer: Arc<RwLock<Vec<MemoryGraphEvent>>>,
38 buffer_size: usize,
40}
41
42impl InMemoryEventStream {
43 pub fn new(capacity: usize, buffer_size: usize) -> Self {
58 let (sender, _) = broadcast::channel(capacity);
59 Self {
60 sender,
61 buffer: Arc::new(RwLock::new(Vec::new())),
62 buffer_size,
63 }
64 }
65
66 pub async fn get_buffered_events(&self) -> Vec<MemoryGraphEvent> {
68 self.buffer.read().await.clone()
69 }
70
71 pub async fn clear_buffer(&self) {
73 self.buffer.write().await.clear();
74 }
75
76 pub fn subscriber_count(&self) -> usize {
78 self.sender.receiver_count()
79 }
80}
81
82#[async_trait]
83impl EventStream for InMemoryEventStream {
84 async fn publish(&self, event: MemoryGraphEvent) -> Result<()> {
85 let mut buffer = self.buffer.write().await;
87 buffer.push(event.clone());
88
89 if buffer.len() > self.buffer_size {
91 let excess = buffer.len() - self.buffer_size;
92 buffer.drain(0..excess);
93 }
94 drop(buffer);
95
96 let _ = self.sender.send(event);
98
99 Ok(())
100 }
101
102 async fn publish_batch(&self, events: Vec<MemoryGraphEvent>) -> Result<()> {
103 let mut buffer = self.buffer.write().await;
105 buffer.extend(events.iter().cloned());
106
107 if buffer.len() > self.buffer_size {
109 let excess = buffer.len() - self.buffer_size;
110 buffer.drain(0..excess);
111 }
112 drop(buffer);
113
114 for event in events {
116 let _ = self.sender.send(event);
117 }
118
119 Ok(())
120 }
121
122 fn subscribe(&self) -> Pin<Box<dyn Stream<Item = MemoryGraphEvent> + Send + '_>> {
123 let receiver = self.sender.subscribe();
124 Box::pin(async_stream::stream! {
125 let mut rx = receiver;
126 while let Ok(event) = rx.recv().await {
127 yield event;
128 }
129 })
130 }
131}
132
133pub struct MultiEventStream {
135 streams: Vec<Arc<dyn EventStream>>,
136}
137
138impl MultiEventStream {
139 pub fn new(streams: Vec<Arc<dyn EventStream>>) -> Self {
141 Self { streams }
142 }
143
144 pub fn add_stream(&mut self, stream: Arc<dyn EventStream>) {
146 self.streams.push(stream);
147 }
148}
149
150#[async_trait]
151impl EventStream for MultiEventStream {
152 async fn publish(&self, event: MemoryGraphEvent) -> Result<()> {
153 let futures: Vec<_> = self
154 .streams
155 .iter()
156 .map(|stream| stream.publish(event.clone()))
157 .collect();
158
159 futures::future::try_join_all(futures).await?;
160 Ok(())
161 }
162
163 async fn publish_batch(&self, events: Vec<MemoryGraphEvent>) -> Result<()> {
164 let futures: Vec<_> = self
165 .streams
166 .iter()
167 .map(|stream| stream.publish_batch(events.clone()))
168 .collect();
169
170 futures::future::try_join_all(futures).await?;
171 Ok(())
172 }
173
174 fn subscribe(&self) -> Pin<Box<dyn Stream<Item = MemoryGraphEvent> + Send + '_>> {
175 if let Some(first) = self.streams.first() {
177 first.subscribe()
178 } else {
179 Box::pin(futures::stream::empty())
180 }
181 }
182}
183
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{NodeId, NodeType, SessionId};
    use chrono::Utc;
    use futures::StreamExt;
    use std::collections::HashMap;

    /// Builds a `NodeCreated` event with a fresh id, the current time and no metadata.
    fn node_created(node_type: NodeType, session_id: Option<SessionId>) -> MemoryGraphEvent {
        MemoryGraphEvent::NodeCreated {
            node_id: NodeId::new(),
            node_type,
            session_id,
            timestamp: Utc::now(),
            metadata: HashMap::new(),
        }
    }

    /// Prompt/response pair shared by the batch-publishing tests.
    fn prompt_response_batch() -> Vec<MemoryGraphEvent> {
        vec![
            node_created(NodeType::Prompt, None),
            node_created(NodeType::Response, None),
        ]
    }

    #[tokio::test]
    async fn test_in_memory_stream_creation() {
        let stream = InMemoryEventStream::new(100, 1000);
        assert_eq!(stream.subscriber_count(), 0);
    }

    #[tokio::test]
    async fn test_publish_and_subscribe() {
        let stream = InMemoryEventStream::new(100, 1000);
        let mut subscription = stream.subscribe();

        let sent = node_created(NodeType::Prompt, Some(SessionId::new()));
        stream.publish(sent.clone()).await.unwrap();

        let received = subscription.next().await.unwrap();
        assert_eq!(received.event_type(), sent.event_type());
    }

    #[tokio::test]
    async fn test_multiple_subscribers() {
        let stream = InMemoryEventStream::new(100, 1000);
        let mut subscriptions = [stream.subscribe(), stream.subscribe(), stream.subscribe()];
        assert_eq!(stream.subscriber_count(), 3);

        let query = MemoryGraphEvent::QueryExecuted {
            query_type: "test".to_string(),
            results_count: 42,
            duration_ms: 100,
            timestamp: Utc::now(),
        };
        stream.publish(query).await.unwrap();

        for subscription in subscriptions.iter_mut() {
            let received = subscription.next().await.unwrap();
            assert_eq!(received.event_type(), "query_executed");
        }
    }

    #[tokio::test]
    async fn test_event_buffer() {
        let stream = InMemoryEventStream::new(100, 10);

        for index in 0..5 {
            let metadata = HashMap::from([("index".to_string(), index.to_string())]);
            let event = MemoryGraphEvent::NodeCreated {
                node_id: NodeId::new(),
                node_type: NodeType::Prompt,
                session_id: None,
                timestamp: Utc::now(),
                metadata,
            };
            stream.publish(event).await.unwrap();
        }

        assert_eq!(stream.get_buffered_events().await.len(), 5);
    }

    #[tokio::test]
    async fn test_buffer_trimming() {
        let stream = InMemoryEventStream::new(100, 5);

        for n in 0..10 {
            stream
                .publish(MemoryGraphEvent::QueryExecuted {
                    query_type: format!("query_{}", n),
                    results_count: n,
                    duration_ms: 100,
                    timestamp: Utc::now(),
                })
                .await
                .unwrap();
        }

        let buffered = stream.get_buffered_events().await;
        assert_eq!(buffered.len(), 5);

        // The five oldest events were evicted, so the first survivor is event #5.
        match &buffered[0] {
            MemoryGraphEvent::QueryExecuted { results_count, .. } => {
                assert_eq!(*results_count, 5)
            }
            _ => panic!("Wrong event type"),
        }
    }

    #[tokio::test]
    async fn test_clear_buffer() {
        let stream = InMemoryEventStream::new(100, 100);

        for _ in 0..5 {
            stream
                .publish(node_created(NodeType::Prompt, None))
                .await
                .unwrap();
        }
        assert_eq!(stream.get_buffered_events().await.len(), 5);

        stream.clear_buffer().await;
        assert!(stream.get_buffered_events().await.is_empty());
    }

    #[tokio::test]
    async fn test_publish_batch() {
        let stream = InMemoryEventStream::new(100, 100);
        let mut subscription = stream.subscribe();

        stream.publish_batch(prompt_response_batch()).await.unwrap();

        for _ in 0..2 {
            let received = subscription.next().await.unwrap();
            assert_eq!(received.event_type(), "node_created");
        }

        assert_eq!(stream.get_buffered_events().await.len(), 2);
    }

    #[tokio::test]
    async fn test_multi_event_stream() {
        let stream1 = Arc::new(InMemoryEventStream::new(100, 100));
        let stream2 = Arc::new(InMemoryEventStream::new(100, 100));
        let multi = MultiEventStream::new(vec![stream1.clone(), stream2.clone()]);

        multi
            .publish(MemoryGraphEvent::QueryExecuted {
                query_type: "test".to_string(),
                results_count: 10,
                duration_ms: 50,
                timestamp: Utc::now(),
            })
            .await
            .unwrap();

        for sink in [&stream1, &stream2] {
            assert_eq!(sink.get_buffered_events().await.len(), 1);
        }
    }

    #[tokio::test]
    async fn test_multi_stream_batch() {
        let stream1 = Arc::new(InMemoryEventStream::new(100, 100));
        let stream2 = Arc::new(InMemoryEventStream::new(100, 100));
        let multi = MultiEventStream::new(vec![stream1.clone(), stream2.clone()]);

        multi.publish_batch(prompt_response_batch()).await.unwrap();

        for sink in [&stream1, &stream2] {
            assert_eq!(sink.get_buffered_events().await.len(), 2);
        }
    }

    #[tokio::test]
    async fn test_concurrent_publishing() {
        let stream = Arc::new(InMemoryEventStream::new(1000, 1000));

        // Spawn every publisher before awaiting any of them.
        let handles: Vec<_> = (0..10)
            .map(|i| {
                let publisher = Arc::clone(&stream);
                tokio::spawn(async move {
                    for j in 0..10 {
                        let event = MemoryGraphEvent::QueryExecuted {
                            query_type: format!("query_{}_{}", i, j),
                            results_count: j,
                            duration_ms: 100,
                            timestamp: Utc::now(),
                        };
                        publisher.publish(event).await.unwrap();
                    }
                })
            })
            .collect();

        for handle in handles {
            handle.await.unwrap();
        }

        assert_eq!(stream.get_buffered_events().await.len(), 100);
    }
}