1pub mod logprobs;
10
11use futures::Stream;
12use std::pin::Pin;
13use std::task::{Context, Poll};
14use std::time::{Duration, Instant};
15use tokio::sync::oneshot;
16
17use dynamo_runtime::engine::{
19 AsyncEngineContext, AsyncEngineContextProvider, AsyncEngineStream, Data, DataStream,
20 EngineStream, ResponseStream,
21};
22use std::sync::Arc;
23
/// Receiving half of the oneshot channel on which a [`RecordingStream`] delivers its
/// finished [`RecordedStream`].
///
/// The recording is sent only when the recorded stream reaches its end; if the
/// recording stream is dropped first, the sender is dropped unsent and awaiting this
/// receiver yields an error.
pub type RecordedStreamReceiver<R> = oneshot::Receiver<RecordedStream<R>>;
26
/// Return type of the `record_*` helpers.
///
/// It pairs the recording stream, which should be consumed in place of the original
/// stream, with the receiver for the finished recording.
pub type RecordingResult<R> = (EngineStream<R>, RecordedStreamReceiver<R>);
29
/// A single recorded stream item, together with when it was recorded and its position.
#[derive(Debug, Clone)]
pub struct TimestampedResponse<T> {
    /// The recorded item.
    pub response: T,
    /// When the item was recorded.
    ///
    /// Both `TimestampedResponse::new` and `RecordingStream` set this from
    /// `Instant::now()` at the moment the item is recorded.
    pub timestamp: Instant,
    /// Zero-based index of the item within its recording.
    pub sequence_number: usize,
}
40
41impl<T> TimestampedResponse<T> {
42 pub fn new(response: T, sequence_number: usize) -> Self {
44 Self {
45 response,
46 timestamp: Instant::now(),
47 sequence_number,
48 }
49 }
50
51 pub fn data(&self) -> &T {
53 &self.response
54 }
55
56 pub fn elapsed_since(&self, start_time: Instant) -> Duration {
58 self.timestamp.duration_since(start_time)
59 }
60}
61
/// Lets a request suggest how many responses its stream is likely to produce.
///
/// `record_stream_with_request_hint` uses this estimate to pre-size the buffer that
/// holds recorded items.
pub trait CapacityHint {
    /// Returns the estimated number of responses, or `None` when no estimate is available.
    ///
    /// With `None`, the recording buffer starts empty and grows as needed.
    fn estimated_response_count(&self) -> Option<usize>;
}
69
/// How a `RecordingStream` treats the items it records.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RecordingMode {
    /// Record a clone of every item and pass the original through to the consumer.
    Scan,
    /// Record every item by value and yield nothing to the consumer.
    ///
    /// The recording stream simply ends once the inner stream is exhausted.
    Sink,
}
80
/// A finished recording: every recorded item with its timestamp, plus the instants
/// at which recording started and ended.
#[derive(Debug, Clone)]
pub struct RecordedStream<T> {
    /// Recorded items, in the order they were recorded.
    responses: Vec<TimestampedResponse<T>>,

    /// When recording began.
    ///
    /// For recordings produced by `RecordingStream`, this is when the recorder was
    /// constructed.
    start_time: Instant,

    /// When recording finished.
    ///
    /// For recordings produced by `RecordingStream`, this is when the inner stream
    /// reported its end.
    end_time: Instant,
}
94
95impl<T> RecordedStream<T> {
96 pub fn new(
98 responses: Vec<TimestampedResponse<T>>,
99 start_time: Instant,
100 end_time: Instant,
101 ) -> Self {
102 Self {
103 responses,
104 start_time,
105 end_time,
106 }
107 }
108
109 pub fn response_count(&self) -> usize {
111 self.responses.len()
112 }
113
114 pub fn total_duration(&self) -> Duration {
116 self.end_time.duration_since(self.start_time)
117 }
118
119 pub fn responses(&self) -> &[TimestampedResponse<T>] {
121 &self.responses
122 }
123
124 pub fn start_time(&self) -> &Instant {
126 &self.start_time
127 }
128
129 pub fn end_time(&self) -> &Instant {
131 &self.end_time
132 }
133}
134
/// Stream adapter that records every item of an inner stream, according to its
/// `RecordingMode`.
///
/// When the inner stream ends, it sends the resulting `RecordedStream` through a
/// oneshot channel.
pub struct RecordingStream<R: Data> {
    /// The wrapped stream whose items are recorded.
    stream: DataStream<R>,
    /// Engine context returned by `AsyncEngineContextProvider::context`.
    ctx: Arc<dyn AsyncEngineContext>,
    /// Whether items are passed through (`Scan`) or swallowed (`Sink`).
    mode: RecordingMode,
    /// Items recorded so far; moved into the `RecordedStream` when the inner stream ends.
    responses: Vec<TimestampedResponse<R>>,
    /// When this adapter was constructed; becomes the recording's start time.
    start_time: Instant,
    /// Channel for delivering the finished recording.
    ///
    /// It is `Some` until the inner stream ends, then taken and consumed.
    recorded_tx: Option<oneshot::Sender<RecordedStream<R>>>,
}
151
// `poll_next` obtains `&mut Self` through `Pin::get_mut`, which requires `Unpin`.
// The adapter never pins its fields in place: `stream` is re-pinned with `Pin::new`,
// which itself needs `DataStream<R>: Unpin`.
impl<R: Data> Unpin for RecordingStream<R> {}
153
154impl<R: Data + Clone> RecordingStream<R> {
155 pub fn from_stream_and_context(
157 stream: DataStream<R>,
158 ctx: Arc<dyn AsyncEngineContext>,
159 mode: RecordingMode,
160 capacity: Option<usize>,
161 recorded_tx: oneshot::Sender<RecordedStream<R>>,
162 ) -> Self {
163 let mut responses = Vec::new();
164 if let Some(cap) = capacity {
165 responses.reserve(cap);
166 }
167
168 Self {
169 stream,
170 ctx,
171 mode,
172 responses,
173 start_time: Instant::now(),
174 recorded_tx: Some(recorded_tx),
175 }
176 }
177
178 fn from_async_engine_stream(
180 stream: EngineStream<R>,
181 mode: RecordingMode,
182 capacity: Option<usize>,
183 recorded_tx: oneshot::Sender<RecordedStream<R>>,
184 ) -> Self {
185 let ctx = stream.context();
186 Self::from_stream_and_context(stream, ctx, mode, capacity, recorded_tx)
187 }
188
189 pub fn into_async_engine_stream(self) -> EngineStream<R> {
191 Box::pin(self)
192 }
193}
194
195impl<R: Data + Clone> Stream for RecordingStream<R> {
196 type Item = R;
197
198 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
199 let this = self.as_mut().get_mut();
200
201 match Pin::new(&mut this.stream).poll_next(cx) {
202 Poll::Ready(Some(item)) => {
203 let timestamp = Instant::now();
205 let sequence_number = this.responses.len();
206
207 match this.mode {
208 RecordingMode::Scan => {
209 let timestamped = TimestampedResponse {
211 response: item.clone(),
212 timestamp,
213 sequence_number,
214 };
215 this.responses.push(timestamped);
216 Poll::Ready(Some(item)) }
218 RecordingMode::Sink => {
219 let timestamped = TimestampedResponse {
221 response: item, timestamp,
223 sequence_number,
224 };
225 this.responses.push(timestamped);
226
227 cx.waker().wake_by_ref();
230 Poll::Pending
231 }
232 }
233 }
234 Poll::Ready(None) => {
235 if let Some(tx) = this.recorded_tx.take() {
237 let recorded = RecordedStream::new(
238 std::mem::take(&mut this.responses),
239 this.start_time,
240 Instant::now(),
241 );
242 let _ = tx.send(recorded); }
244
245 Poll::Ready(None)
246 }
247 Poll::Pending => Poll::Pending,
248 }
249 }
250}
251
// Marker impl with no methods of its own. It appears to be what lets
// `into_async_engine_stream` / `Box::pin(self)` coerce a `RecordingStream` into an
// `EngineStream` (confirm against the `EngineStream` alias in dynamo_runtime).
impl<R: Data + Clone> AsyncEngineStream<R> for RecordingStream<R> {}
253
254impl<R: Data + Clone> AsyncEngineContextProvider for RecordingStream<R> {
255 fn context(&self) -> Arc<dyn AsyncEngineContext> {
256 self.ctx.clone()
257 }
258}
259
260impl<R: Data + Clone> std::fmt::Debug for RecordingStream<R> {
261 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
262 f.debug_struct("RecordingStream")
263 .field("mode", &self.mode)
264 .field("responses_count", &self.responses.len())
265 .field("ctx", &self.ctx)
266 .finish()
267 }
268}
269
270pub fn record_stream<R: Data + Clone>(
273 stream: EngineStream<R>,
274 mode: RecordingMode,
275) -> RecordingResult<R> {
276 let (tx, rx) = oneshot::channel();
277 let recording_stream = RecordingStream::from_async_engine_stream(stream, mode, None, tx);
278 let boxed_stream = Box::pin(recording_stream);
279 (boxed_stream, rx)
280}
281
282pub fn record_stream_with_context<R: Data + Clone>(
285 stream: DataStream<R>,
286 ctx: Arc<dyn AsyncEngineContext>,
287 mode: RecordingMode,
288) -> RecordingResult<R> {
289 let (tx, rx) = oneshot::channel();
290 let recording_stream = RecordingStream::from_stream_and_context(stream, ctx, mode, None, tx);
291 let boxed_stream = Box::pin(recording_stream);
292 (boxed_stream, rx)
293}
294
295pub fn record_stream_with_capacity<R: Data + Clone>(
297 stream: EngineStream<R>,
298 mode: RecordingMode,
299 capacity: usize,
300) -> RecordingResult<R> {
301 let (tx, rx) = oneshot::channel();
302 let recording_stream =
303 RecordingStream::from_async_engine_stream(stream, mode, Some(capacity), tx);
304 let boxed_stream = Box::pin(recording_stream);
305 (boxed_stream, rx)
306}
307
308pub fn record_stream_with_request_hint<R: Data + Clone, Req: CapacityHint>(
310 stream: EngineStream<R>,
311 mode: RecordingMode,
312 request: &Req,
313) -> RecordingResult<R> {
314 let capacity = request.estimated_response_count();
315 match capacity {
316 Some(cap) => record_stream_with_capacity(stream, mode, cap),
317 None => record_stream(stream, mode),
318 }
319}
320
321pub fn record_stream_with_context_and_capacity<R: Data + Clone>(
323 stream: DataStream<R>,
324 ctx: Arc<dyn AsyncEngineContext>,
325 mode: RecordingMode,
326 capacity: usize,
327) -> RecordingResult<R> {
328 let (tx, rx) = oneshot::channel();
329 let recording_stream =
330 RecordingStream::from_stream_and_context(stream, ctx, mode, Some(capacity), tx);
331 let boxed_stream = Box::pin(recording_stream);
332 (boxed_stream, rx)
333}
334
335pub fn record_response_stream<R: Data + Clone>(
337 response_stream: Pin<Box<ResponseStream<R>>>,
338 mode: RecordingMode,
339) -> RecordingResult<R> {
340 record_stream(response_stream, mode)
341}
342
// Unit tests for the recording types and the `record_*` helpers.
#[cfg(test)]
pub mod tests {
    use super::*;
    use dynamo_runtime::engine::ResponseStream;
    use futures::stream;
    use std::time::Duration;

    // `TimestampedResponse::new` stores the item and sequence number, and `data()`
    // borrows the item back.
    #[test]
    fn test_timestamped_response_creation() {
        let response = "test response";
        let timestamped = TimestampedResponse::new(response, 0);

        assert_eq!(timestamped.response, response);
        assert_eq!(timestamped.sequence_number, 0);
        assert_eq!(timestamped.data(), &response);
    }

    // `RecordedStream` reports its item count and its start-to-end duration.
    #[test]
    fn test_recorded_stream_analysis() {
        let start_time = Instant::now();

        // Three items at 0 ms, 100 ms and 250 ms after the start.
        let responses = vec![
            TimestampedResponse {
                response: "response1",
                timestamp: start_time,
                sequence_number: 0,
            },
            TimestampedResponse {
                response: "response2",
                timestamp: start_time + Duration::from_millis(100),
                sequence_number: 1,
            },
            TimestampedResponse {
                response: "response3",
                timestamp: start_time + Duration::from_millis(250),
                sequence_number: 2,
            },
        ];

        let end_time = start_time + Duration::from_millis(250);
        let recorded = RecordedStream::new(responses, start_time, end_time);

        assert_eq!(recorded.response_count(), 3);
        assert_eq!(recorded.total_duration(), Duration::from_millis(250));
    }

    // Same accessors as above, but the first item is not recorded at the start
    // instant.
    #[test]
    fn test_performance_metrics_conversion() {
        let start_time = Instant::now();
        let responses = vec![
            TimestampedResponse {
                response: "test",
                timestamp: start_time + Duration::from_millis(50),
                sequence_number: 0,
            },
            TimestampedResponse {
                response: "test",
                timestamp: start_time + Duration::from_millis(150),
                sequence_number: 1,
            },
        ];

        let end_time = start_time + Duration::from_millis(150);
        let recorded = RecordedStream::new(responses, start_time, end_time);

        assert_eq!(recorded.response_count(), 2);
        assert_eq!(recorded.total_duration(), Duration::from_millis(150));
    }

    // Scan mode: the consumer sees every item, and the recording holds the same
    // items in order.
    #[tokio::test]
    async fn test_recording_stream_scan_mode() {
        use futures::StreamExt;

        let test_data = vec!["token1", "token2", "token3"];
        // A finite, always-ready source stream.
        let base_stream = stream::iter(test_data.clone());

        let ctx = Arc::new(MockContext::new());

        let (recorded_stream, recording_rx) =
            record_stream_with_context(Box::pin(base_stream), ctx, RecordingMode::Scan);

        // Drive the recorder to completion so the recording is sent.
        let collected_responses: Vec<_> = recorded_stream.collect().await;

        // Items pass through unchanged.
        assert_eq!(collected_responses, test_data);

        let recorded = recording_rx.await.unwrap();
        assert_eq!(recorded.response_count(), 3);
        assert_eq!(recorded.responses[0].response, "token1");
        assert_eq!(recorded.responses[1].response, "token2");
        assert_eq!(recorded.responses[2].response, "token3");

        // NOTE(review): assumes the platform clock resolves the elapsed time as
        // non-zero.
        assert!(recorded.total_duration() > Duration::from_nanos(0));
    }

    // Sink mode: the consumer sees nothing, but every item is still recorded.
    #[tokio::test]
    async fn test_recording_stream_sink_mode() {
        use futures::StreamExt;

        let test_data = vec!["token1", "token2", "token3"];
        // A finite, always-ready source stream.
        let base_stream = stream::iter(test_data.clone());

        let ctx = Arc::new(MockContext::new());

        let (recorded_stream, recording_rx) =
            record_stream_with_context(Box::pin(base_stream), ctx, RecordingMode::Sink);

        // Sink mode yields no items; collecting just drives it to the end.
        let collected_responses: Vec<_> = recorded_stream.collect().await;
        assert_eq!(collected_responses, Vec::<&str>::new());

        let recorded = recording_rx.await.unwrap();
        assert_eq!(recorded.response_count(), 3);
        assert_eq!(recorded.responses[0].response, "token1");
        assert_eq!(recorded.responses[1].response, "token2");
        assert_eq!(recorded.responses[2].response, "token3");

        // NOTE(review): assumes the platform clock resolves the elapsed time as
        // non-zero.
        assert!(recorded.total_duration() > Duration::from_nanos(0));
    }

    // `record_response_stream` accepts a boxed `ResponseStream` and records it in
    // scan mode.
    #[tokio::test]
    async fn test_recording_stream_from_response_stream() {
        use futures::StreamExt;

        let test_data = vec!["token1", "token2", "token3"];
        // A finite, always-ready source stream.
        let base_stream = stream::iter(test_data.clone());

        let ctx = Arc::new(MockContext::new());
        // Attach the context to the stream so the recorder can inherit it.
        let response_stream = ResponseStream::new(Box::pin(base_stream), ctx);

        let (recorded_stream, recording_rx) =
            record_response_stream(response_stream, RecordingMode::Scan);

        // Drive the recorder to completion so the recording is sent.
        let collected_responses: Vec<_> = recorded_stream.collect().await;

        // Items pass through unchanged.
        assert_eq!(collected_responses, test_data);

        let recorded = recording_rx.await.unwrap();
        assert_eq!(recorded.response_count(), 3);
        assert_eq!(recorded.responses[0].response, "token1");
        assert_eq!(recorded.responses[1].response, "token2");
        assert_eq!(recorded.responses[2].response, "token3");

        // NOTE(review): assumes the platform clock resolves the elapsed time as
        // non-zero.
        assert!(recorded.total_duration() > Duration::from_nanos(0));
    }

    // Minimal `AsyncEngineContext` for tests.
    // It has a fixed id, never reports stopped or killed, and its control methods
    // do nothing.
    #[derive(Debug)]
    struct MockContext {
        id: String,
    }

    impl MockContext {
        fn new() -> Self {
            Self {
                id: "test-context".to_string(),
            }
        }
    }

    #[async_trait::async_trait]
    impl AsyncEngineContext for MockContext {
        fn id(&self) -> &str {
            &self.id
        }

        // No-op: nothing to stop in tests.
        fn stop(&self) {
        }

        // No-op: nothing to stop in tests.
        fn stop_generating(&self) {
        }

        // No-op: nothing to kill in tests.
        fn kill(&self) {
        }

        fn is_stopped(&self) -> bool {
            false
        }

        fn is_killed(&self) -> bool {
            false
        }

        // Resolves immediately; no stop signal is modelled.
        async fn stopped(&self) {
        }

        // Resolves immediately; no kill signal is modelled.
        async fn killed(&self) {
        }

        // Child contexts are ignored.
        fn link_child(&self, _: Arc<dyn AsyncEngineContext>) {
        }
    }
}