pub mod logprobs;
use futures::Stream;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::{Duration, Instant};
use tokio::sync::oneshot;
use dynamo_runtime::engine::{
AsyncEngineContext, AsyncEngineContextProvider, AsyncEngineStream, Data, DataStream,
EngineStream, ResponseStream,
};
use std::sync::Arc;
/// Receiving half of the oneshot channel on which a finished [`RecordedStream`]
/// is delivered once the recorded stream has reached its end. If the recording
/// stream is dropped before that, nothing is ever sent on this channel.
pub type RecordedStreamReceiver<R> = oneshot::Receiver<RecordedStream<R>>;
/// Return type of the `record_*` helpers: the wrapping stream to consume, plus
/// the receiver that yields the recording after that stream has ended.
pub type RecordingResult<R> = (EngineStream<R>, RecordedStreamReceiver<R>);
/// A single recorded stream item together with the moment it was captured.
#[derive(Debug, Clone)]
pub struct TimestampedResponse<T> {
    /// The recorded item.
    pub response: T,
    /// When the item was captured.
    pub timestamp: Instant,
    /// Position of the item within its recording (the recorder assigns 0, 1, 2, ...).
    pub sequence_number: usize,
}

impl<T> TimestampedResponse<T> {
    /// Wraps `response`, stamping it with the current time from [`Instant::now`].
    pub fn new(response: T, sequence_number: usize) -> Self {
        Self {
            response,
            timestamp: Instant::now(),
            sequence_number,
        }
    }

    /// Borrows the recorded item.
    pub fn data(&self) -> &T {
        &self.response
    }

    /// Time elapsed between `start_time` and the moment this response was captured.
    ///
    /// Returns [`Duration::ZERO`] if `start_time` is later than the capture time.
    /// `saturating_duration_since` is used explicitly because the standard library
    /// documents that `duration_since` may reintroduce a panic in that case in
    /// future releases.
    pub fn elapsed_since(&self, start_time: Instant) -> Duration {
        self.timestamp.saturating_duration_since(start_time)
    }
}
/// Supplies an optional estimate of how many responses a request will produce.
///
/// `record_stream_with_request_hint` uses this estimate to pre-size the
/// recording buffer.
pub trait CapacityHint {
    /// Estimated number of responses, or `None` when no estimate is available.
    fn estimated_response_count(&self) -> Option<usize>;
}
/// Controls what a [`RecordingStream`] does with each item after recording it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RecordingMode {
    /// Record a clone of each item and forward the original to the consumer.
    Scan,
    /// Record each item without forwarding it; the consumer only observes
    /// end-of-stream.
    Sink,
}
/// A completed recording: every captured response plus the recording window.
#[derive(Debug, Clone)]
pub struct RecordedStream<T> {
    /// Captured responses, in the order they were recorded.
    responses: Vec<TimestampedResponse<T>>,
    /// When recording began.
    start_time: Instant,
    /// When recording finished.
    end_time: Instant,
}

impl<T> RecordedStream<T> {
    /// Builds a recording from already-captured responses and its time window.
    pub fn new(
        responses: Vec<TimestampedResponse<T>>,
        start_time: Instant,
        end_time: Instant,
    ) -> Self {
        Self {
            responses,
            start_time,
            end_time,
        }
    }

    /// Number of recorded responses.
    pub fn response_count(&self) -> usize {
        self.responses.len()
    }

    /// Length of the recording window (`end_time - start_time`).
    ///
    /// Returns [`Duration::ZERO`] if `end_time` precedes `start_time`. That can
    /// happen because both instants are caller-supplied through
    /// [`RecordedStream::new`]. The saturating form is used explicitly because the
    /// standard library documents that `duration_since` may panic in that
    /// situation in future releases.
    pub fn total_duration(&self) -> Duration {
        self.end_time.saturating_duration_since(self.start_time)
    }

    /// The recorded responses, in recording order.
    pub fn responses(&self) -> &[TimestampedResponse<T>] {
        &self.responses
    }

    /// When recording began.
    pub fn start_time(&self) -> &Instant {
        &self.start_time
    }

    /// When recording finished.
    pub fn end_time(&self) -> &Instant {
        &self.end_time
    }
}
/// Stream adapter that records every item produced by an inner stream.
///
/// Depending on its [`RecordingMode`], items are either forwarded to the
/// consumer (`Scan`) or swallowed (`Sink`). When the inner stream ends, the
/// accumulated [`RecordedStream`] is sent through `recorded_tx`.
pub struct RecordingStream<R: Data> {
    /// The wrapped stream whose items are recorded.
    stream: DataStream<R>,
    /// Engine context handed out through `AsyncEngineContextProvider::context`.
    ctx: Arc<dyn AsyncEngineContext>,
    /// Whether recorded items are forwarded or swallowed.
    mode: RecordingMode,
    /// Items recorded so far, in arrival order.
    responses: Vec<TimestampedResponse<R>>,
    /// Set when the recorder is constructed; becomes the recording's start time.
    start_time: Instant,
    /// `Some` until the recording has been sent, then taken and left as `None`.
    recorded_tx: Option<oneshot::Sender<RecordedStream<R>>>,
}
// Implemented by hand so that `poll_next` can call `Pin::get_mut`, which
// requires `Self: Unpin`. No field is structurally pinned: `stream` is
// re-wrapped with `Pin::new` on every poll (so it must itself be `Unpin`), and
// the other fields are plain data accessed through `&mut`.
// NOTE(review): presumably the auto impl is missing because `Data` does not
// imply `Unpin` for `R` (which the `Vec` of responses inherits). Confirm against
// `dynamo_runtime`.
impl<R: Data> Unpin for RecordingStream<R> {}
impl<R: Data + Clone> RecordingStream<R> {
    /// Wraps a raw data stream and its engine context in a recorder.
    ///
    /// * `stream` — the stream whose items are recorded.
    /// * `ctx` — the engine context exposed via [`AsyncEngineContextProvider`].
    /// * `mode` — whether items are forwarded ([`RecordingMode::Scan`]) or
    ///   swallowed ([`RecordingMode::Sink`]).
    /// * `capacity` — optional hint for how many responses to pre-allocate room for.
    /// * `recorded_tx` — receives the [`RecordedStream`] once `stream` is exhausted.
    ///
    /// The recording's start time is the moment this constructor runs.
    ///
    /// `capacity` is only a hint and may come from request data (see
    /// [`CapacityHint`]). If the up-front reservation cannot be satisfied, for
    /// example because the estimate is absurdly large, recording starts with an
    /// empty buffer and grows on demand. It does not panic on capacity overflow
    /// or abort on allocation failure.
    pub fn from_stream_and_context(
        stream: DataStream<R>,
        ctx: Arc<dyn AsyncEngineContext>,
        mode: RecordingMode,
        capacity: Option<usize>,
        recorded_tx: oneshot::Sender<RecordedStream<R>>,
    ) -> Self {
        let mut responses = Vec::new();
        if let Some(cap) = capacity {
            // A failed reservation is harmless: the Vec simply grows as items arrive.
            let _ = responses.try_reserve(cap);
        }
        Self {
            stream,
            ctx,
            mode,
            responses,
            start_time: Instant::now(),
            recorded_tx: Some(recorded_tx),
        }
    }

    /// Wraps an engine stream, reusing the engine context the stream carries.
    fn from_async_engine_stream(
        stream: EngineStream<R>,
        mode: RecordingMode,
        capacity: Option<usize>,
        recorded_tx: oneshot::Sender<RecordedStream<R>>,
    ) -> Self {
        let ctx = stream.context();
        Self::from_stream_and_context(stream, ctx, mode, capacity, recorded_tx)
    }

    /// Boxes and pins this recorder as an [`EngineStream`].
    pub fn into_async_engine_stream(self) -> EngineStream<R> {
        Box::pin(self)
    }
}
impl<R: Data + Clone> Stream for RecordingStream<R> {
    type Item = R;

    /// Polls the inner stream once and records whatever item it produces.
    ///
    /// * `Scan`: the item is cloned into the recording and the original is
    ///   handed to the caller.
    /// * `Sink`: the item is moved into the recording and nothing is handed
    ///   out. The task wakes itself and reports `Pending`, so the executor
    ///   comes back for the next item.
    ///
    /// When the inner stream signals end-of-stream, the accumulated responses
    /// are sent (at most once) on the oneshot channel, and then `None` is returned.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let me = self.get_mut();

        let item = match Pin::new(&mut me.stream).poll_next(cx) {
            Poll::Pending => return Poll::Pending,
            Poll::Ready(Some(item)) => item,
            Poll::Ready(None) => {
                if let Some(sender) = me.recorded_tx.take() {
                    let responses = std::mem::take(&mut me.responses);
                    let recording = RecordedStream::new(responses, me.start_time, Instant::now());
                    // The receiver may already be gone; there is nobody to report that to.
                    let _ = sender.send(recording);
                }
                return Poll::Ready(None);
            }
        };

        let timestamp = Instant::now();
        let sequence_number = me.responses.len();
        let (recorded, forwarded) = match me.mode {
            RecordingMode::Scan => (item.clone(), Some(item)),
            RecordingMode::Sink => (item, None),
        };
        me.responses.push(TimestampedResponse {
            response: recorded,
            timestamp,
            sequence_number,
        });

        match forwarded {
            Some(item) => Poll::Ready(Some(item)),
            None => {
                cx.waker().wake_by_ref();
                Poll::Pending
            }
        }
    }
}
// Marks `RecordingStream` as an engine stream, so it can be boxed into an
// `EngineStream` (as `into_async_engine_stream` and the `record_*` helpers do).
impl<R: Data + Clone> AsyncEngineStream<R> for RecordingStream<R> {}
impl<R: Data + Clone> AsyncEngineContextProvider for RecordingStream<R> {
    /// Returns a shared handle to the engine context captured at construction.
    fn context(&self) -> Arc<dyn AsyncEngineContext> {
        Arc::clone(&self.ctx)
    }
}
impl<R: Data + Clone> std::fmt::Debug for RecordingStream<R> {
    /// Summarises the recorder: its mode, how many items it holds, and its context.
    /// The recorded items themselves are not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let recorded_so_far = self.responses.len();
        let mut builder = f.debug_struct("RecordingStream");
        builder.field("mode", &self.mode);
        builder.field("responses_count", &recorded_so_far);
        builder.field("ctx", &self.ctx);
        builder.finish()
    }
}
/// Wraps an engine stream so that every item it yields is recorded.
///
/// The engine context is taken from `stream` itself, and no buffer capacity is
/// pre-reserved. Returns the recording stream together with a receiver that
/// yields the [`RecordedStream`] once the returned stream has run to completion.
pub fn record_stream<R: Data + Clone>(
    stream: EngineStream<R>,
    mode: RecordingMode,
) -> RecordingResult<R> {
    let (sender, receiver) = oneshot::channel();
    let recorder: EngineStream<R> =
        RecordingStream::from_async_engine_stream(stream, mode, None, sender)
            .into_async_engine_stream();
    (recorder, receiver)
}
/// Records a raw data stream, exposing `ctx` as its engine context.
///
/// No buffer capacity is pre-reserved. Returns the recording stream and a
/// receiver for the finished [`RecordedStream`].
pub fn record_stream_with_context<R: Data + Clone>(
    stream: DataStream<R>,
    ctx: Arc<dyn AsyncEngineContext>,
    mode: RecordingMode,
) -> RecordingResult<R> {
    let (sender, receiver) = oneshot::channel();
    let recorder = RecordingStream::from_stream_and_context(stream, ctx, mode, None, sender);
    (recorder.into_async_engine_stream(), receiver)
}
/// Like [`record_stream`], but pre-reserves room for `capacity` responses.
pub fn record_stream_with_capacity<R: Data + Clone>(
    stream: EngineStream<R>,
    mode: RecordingMode,
    capacity: usize,
) -> RecordingResult<R> {
    let (sender, receiver) = oneshot::channel();
    let capacity_hint = Some(capacity);
    let wrapped: EngineStream<R> = Box::pin(RecordingStream::from_async_engine_stream(
        stream,
        mode,
        capacity_hint,
        sender,
    ));
    (wrapped, receiver)
}
/// Records `stream`, sizing the buffer from the request's [`CapacityHint`].
///
/// Falls back to [`record_stream`] (no pre-reservation) when the request offers
/// no estimate.
pub fn record_stream_with_request_hint<R: Data + Clone, Req: CapacityHint>(
    stream: EngineStream<R>,
    mode: RecordingMode,
    request: &Req,
) -> RecordingResult<R> {
    if let Some(expected) = request.estimated_response_count() {
        return record_stream_with_capacity(stream, mode, expected);
    }
    record_stream(stream, mode)
}
/// Like [`record_stream_with_context`], but pre-reserves room for `capacity`
/// responses.
pub fn record_stream_with_context_and_capacity<R: Data + Clone>(
    stream: DataStream<R>,
    ctx: Arc<dyn AsyncEngineContext>,
    mode: RecordingMode,
    capacity: usize,
) -> RecordingResult<R> {
    let (sender, receiver) = oneshot::channel();
    let stream_out =
        RecordingStream::from_stream_and_context(stream, ctx, mode, Some(capacity), sender)
            .into_async_engine_stream();
    (stream_out, receiver)
}
pub fn record_response_stream<R: Data + Clone>(
response_stream: Pin<Box<ResponseStream<R>>>,
mode: RecordingMode,
) -> RecordingResult<R> {
record_stream(response_stream, mode)
}
#[cfg(test)]
pub mod tests {
    use super::*;
    use dynamo_runtime::engine::ResponseStream;
    use futures::stream;
    use std::time::Duration;

    /// Checks a finished recording against the items the source stream produced.
    ///
    /// Beyond comparing the recorded items, this verifies the recorder's
    /// bookkeeping: sequence numbers run 0, 1, 2, ..., and every timestamp falls
    /// inside the recording's `[start_time, end_time]` window. It deliberately
    /// does not require `total_duration() > 0`. `Instant` is only guaranteed to be
    /// non-decreasing, so on a coarse clock, draining a tiny in-memory stream can
    /// legitimately measure as zero time and make such an assertion flaky.
    fn assert_recorded_items<T: PartialEq + std::fmt::Debug>(
        recorded: &RecordedStream<T>,
        expected: &[T],
    ) {
        assert_eq!(recorded.response_count(), expected.len());
        for (index, (entry, want)) in recorded.responses().iter().zip(expected).enumerate() {
            assert_eq!(&entry.response, want);
            assert_eq!(entry.sequence_number, index);
            assert!(entry.timestamp >= *recorded.start_time());
            assert!(entry.timestamp <= *recorded.end_time());
        }
        assert!(recorded.end_time() >= recorded.start_time());
    }

    #[test]
    fn test_timestamped_response_creation() {
        let response = "test response";
        let timestamped = TimestampedResponse::new(response, 0);
        assert_eq!(timestamped.response, response);
        assert_eq!(timestamped.sequence_number, 0);
        assert_eq!(timestamped.data(), &response);
    }

    #[test]
    fn test_recorded_stream_analysis() {
        let start_time = Instant::now();
        let responses = vec![
            TimestampedResponse {
                response: "response1",
                timestamp: start_time,
                sequence_number: 0,
            },
            TimestampedResponse {
                response: "response2",
                timestamp: start_time + Duration::from_millis(100),
                sequence_number: 1,
            },
            TimestampedResponse {
                response: "response3",
                timestamp: start_time + Duration::from_millis(250),
                sequence_number: 2,
            },
        ];
        let end_time = start_time + Duration::from_millis(250);
        let recorded = RecordedStream::new(responses, start_time, end_time);
        assert_eq!(recorded.response_count(), 3);
        assert_eq!(recorded.total_duration(), Duration::from_millis(250));
    }

    /// Recording whose first response arrives after the recording started.
    #[test]
    fn test_performance_metrics_conversion() {
        let start_time = Instant::now();
        let responses = vec![
            TimestampedResponse {
                response: "test",
                timestamp: start_time + Duration::from_millis(50),
                sequence_number: 0,
            },
            TimestampedResponse {
                response: "test",
                timestamp: start_time + Duration::from_millis(150),
                sequence_number: 1,
            },
        ];
        let end_time = start_time + Duration::from_millis(150);
        let recorded = RecordedStream::new(responses, start_time, end_time);
        assert_eq!(recorded.response_count(), 2);
        assert_eq!(recorded.total_duration(), Duration::from_millis(150));
        assert_eq!(
            recorded.responses()[0].elapsed_since(start_time),
            Duration::from_millis(50)
        );
    }

    #[tokio::test]
    async fn test_recording_stream_scan_mode() {
        use futures::StreamExt;
        let test_data = vec!["token1", "token2", "token3"];
        let base_stream = stream::iter(test_data.clone());
        let ctx = Arc::new(MockContext::new());
        let (recorded_stream, recording_rx) =
            record_stream_with_context(Box::pin(base_stream), ctx, RecordingMode::Scan);
        let collected_responses: Vec<_> = recorded_stream.collect().await;
        assert_eq!(collected_responses, test_data);
        let recorded = recording_rx.await.unwrap();
        assert_recorded_items(&recorded, &test_data);
    }

    #[tokio::test]
    async fn test_recording_stream_sink_mode() {
        use futures::StreamExt;
        let test_data = vec!["token1", "token2", "token3"];
        let base_stream = stream::iter(test_data.clone());
        let ctx = Arc::new(MockContext::new());
        let (recorded_stream, recording_rx) =
            record_stream_with_context(Box::pin(base_stream), ctx, RecordingMode::Sink);
        let collected_responses: Vec<_> = recorded_stream.collect().await;
        // Sink mode swallows every item; the consumer only sees end-of-stream.
        assert_eq!(collected_responses, Vec::<&str>::new());
        let recorded = recording_rx.await.unwrap();
        assert_recorded_items(&recorded, &test_data);
    }

    #[tokio::test]
    async fn test_recording_stream_from_response_stream() {
        use futures::StreamExt;
        let test_data = vec!["token1", "token2", "token3"];
        let base_stream = stream::iter(test_data.clone());
        let ctx = Arc::new(MockContext::new());
        let response_stream = ResponseStream::new(Box::pin(base_stream), ctx);
        let (recorded_stream, recording_rx) =
            record_response_stream(response_stream, RecordingMode::Scan);
        let collected_responses: Vec<_> = recorded_stream.collect().await;
        assert_eq!(collected_responses, test_data);
        let recorded = recording_rx.await.unwrap();
        assert_recorded_items(&recorded, &test_data);
    }

    /// Inert engine context: never stopped or killed, and every control call is a no-op.
    #[derive(Debug)]
    struct MockContext {
        id: String,
    }

    impl MockContext {
        fn new() -> Self {
            Self {
                id: "test-context".to_string(),
            }
        }
    }

    #[async_trait::async_trait]
    impl AsyncEngineContext for MockContext {
        fn id(&self) -> &str {
            &self.id
        }
        fn stop(&self) {}
        fn stop_generating(&self) {}
        fn kill(&self) {}
        fn is_stopped(&self) -> bool {
            false
        }
        fn is_killed(&self) -> bool {
            false
        }
        async fn stopped(&self) {}
        async fn killed(&self) {}
        fn link_child(&self, _: Arc<dyn AsyncEngineContext>) {}
    }
}