use tokio::sync::mpsc;
/// Tuning knobs for a token stream.
#[derive(Clone, Debug)]
pub struct StreamConfig {
    /// Capacity requested for the stream's channel.
    pub buffer_size: usize,
    /// Whether timing information should be attached to tokens.
    /// NOTE(review): not read by any code visible here — confirm the consumer.
    pub include_timing: bool,
    /// Rate cap in tokens per second. Defaults to 0.
    /// NOTE(review): 0 presumably means "unlimited" — confirm with the consumer.
    pub max_tokens_per_sec: u32,
}

impl Default for StreamConfig {
    /// Returns a config with a 64-slot buffer, timing disabled and `max_tokens_per_sec` of 0.
    fn default() -> Self {
        const DEFAULT_BUFFER_SIZE: usize = 64;
        Self {
            max_tokens_per_sec: 0,
            include_timing: false,
            buffer_size: DEFAULT_BUFFER_SIZE,
        }
    }
}
/// A single chunk of streamed output, tagged with its position in the stream.
#[derive(Clone, Debug)]
pub struct StreamToken {
    /// Text payload of the token; empty for tokens built with [`StreamToken::invalid`].
    pub content: String,
    /// Position of the token in its stream, as supplied by the producer.
    pub sequence: u32,
    /// Set to `true` by [`StreamToken::new`] and `false` by [`StreamToken::invalid`].
    pub is_valid: bool,
    /// Optional timestamp in milliseconds, attached via [`StreamToken::with_timing`].
    /// The reference point is chosen by the caller.
    pub timestamp_ms: Option<u64>,
}

impl StreamToken {
    /// Builds a valid token carrying `content` at position `sequence`, with no timestamp.
    pub fn new(content: String, sequence: u32) -> Self {
        Self {
            timestamp_ms: None,
            is_valid: true,
            sequence,
            content,
        }
    }

    /// Builds an empty placeholder token at position `sequence`, flagged as invalid.
    pub fn invalid(sequence: u32) -> Self {
        let mut placeholder = Self::new(String::new(), sequence);
        placeholder.is_valid = false;
        placeholder
    }

    /// Consumes the token and returns it with `timestamp_ms` attached,
    /// replacing any timestamp it already had.
    pub fn with_timing(self, timestamp_ms: u64) -> Self {
        Self {
            timestamp_ms: Some(timestamp_ms),
            ..self
        }
    }
}
pub type StreamReceiver = mpsc::Receiver<StreamToken>;
pub type StreamSender = mpsc::Sender<StreamToken>;
/// Creates a bounded channel for streaming [`StreamToken`]s.
///
/// The channel capacity is `config.buffer_size`, clamped to a minimum of 1:
/// `tokio::sync::mpsc::channel` panics on a zero capacity, and `buffer_size` is a
/// public field that callers may set to 0. Only `buffer_size` is consulted here;
/// `include_timing` and `max_tokens_per_sec` are ignored by this function.
///
/// # Panics
///
/// Panics if `buffer_size` exceeds tokio's maximum channel capacity
/// (see the `tokio::sync::mpsc::channel` documentation).
pub fn create_stream_channel(config: StreamConfig) -> (StreamSender, StreamReceiver) {
    // Clamp instead of letting tokio panic on a zero-sized buffer.
    mpsc::channel(config.buffer_size.max(1))
}
/// Running counters for a token stream, plus a derived throughput figure.
#[derive(Clone, Debug, Default)]
pub struct StreamStats {
    /// Number of tokens recorded via [`StreamStats::record_token`], valid or not.
    /// Saturates at `u32::MAX`.
    pub total_tokens: u32,
    /// Number of recorded tokens flagged valid. Saturates at `u32::MAX`.
    pub valid_tokens: u32,
    /// Number of recorded tokens flagged invalid. Saturates at `u32::MAX`.
    pub invalid_tokens: u32,
    /// Elapsed stream time in milliseconds; set by the caller before [`StreamStats::finalize`].
    pub duration_ms: u64,
    /// Throughput computed by [`StreamStats::finalize`]: `total_tokens / (duration_ms / 1000)`.
    pub tokens_per_second: f32,
}

impl StreamStats {
    /// Creates an all-zero stats record.
    pub fn new() -> Self {
        Self::default()
    }

    /// Counts one token, bumping `total_tokens` and either `valid_tokens` or
    /// `invalid_tokens` depending on `is_valid`.
    ///
    /// Counters saturate rather than overflow. A plain `+= 1` would panic in debug
    /// builds and silently wrap in release, desynchronising the counters.
    pub fn record_token(&mut self, is_valid: bool) {
        self.total_tokens = self.total_tokens.saturating_add(1);
        let bucket = if is_valid {
            &mut self.valid_tokens
        } else {
            &mut self.invalid_tokens
        };
        *bucket = bucket.saturating_add(1);
    }

    /// Computes `tokens_per_second` from `total_tokens` and `duration_ms`.
    ///
    /// With a zero duration there is no meaningful rate, so the field is reset to 0.0
    /// rather than keeping a stale value from an earlier call.
    pub fn finalize(&mut self) {
        self.tokens_per_second = if self.duration_ms > 0 {
            // Compute in f64 to limit precision loss before narrowing to the public f32 field.
            (f64::from(self.total_tokens) / (self.duration_ms as f64 / 1000.0)) as f32
        } else {
            0.0
        };
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_stream_token_creation() {
        let token = StreamToken::new("hello".to_string(), 0);
        assert_eq!(token.content, "hello");
        assert_eq!(token.sequence, 0);
        assert!(token.is_valid);
        // Fresh tokens carry no timing until `with_timing` is applied.
        assert_eq!(token.timestamp_ms, None);
    }

    #[test]
    fn test_stream_token_invalid() {
        let token = StreamToken::invalid(5);
        assert_eq!(token.sequence, 5);
        assert!(!token.is_valid);
        assert!(token.content.is_empty());
        assert_eq!(token.timestamp_ms, None);
    }

    #[test]
    fn test_stream_token_with_timing() {
        let token = StreamToken::new("test".to_string(), 0).with_timing(100);
        assert_eq!(token.timestamp_ms, Some(100));
        // `with_timing` must leave the other fields untouched.
        assert_eq!(token.content, "test");
        assert!(token.is_valid);
    }

    #[test]
    fn test_stream_stats() {
        let mut stats = StreamStats::new();
        stats.record_token(true);
        stats.record_token(true);
        stats.record_token(false);
        stats.duration_ms = 1000;
        stats.finalize();
        assert_eq!(stats.total_tokens, 3);
        assert_eq!(stats.valid_tokens, 2);
        assert_eq!(stats.invalid_tokens, 1);
        assert!((stats.tokens_per_second - 3.0).abs() < 0.01);
    }

    #[test]
    fn test_stream_stats_zero_duration() {
        let mut stats = StreamStats::new();
        stats.record_token(true);
        stats.finalize();
        // With no elapsed time there is no meaningful rate; it must stay at zero.
        assert!(stats.tokens_per_second.abs() < f32::EPSILON);
    }

    #[tokio::test]
    async fn test_create_stream_channel() {
        let config = StreamConfig {
            buffer_size: 10,
            ..Default::default()
        };
        let (tx, mut rx) = create_stream_channel(config);
        let token = StreamToken::new("hello".to_string(), 0);
        tx.send(token).await.unwrap();
        let received = rx.recv().await.unwrap();
        assert_eq!(received.content, "hello");
        assert_eq!(received.sequence, 0);
    }

    #[tokio::test]
    async fn test_stream_channel_buffer() {
        let config = StreamConfig {
            buffer_size: 2,
            ..Default::default()
        };
        let (tx, mut rx) = create_stream_channel(config);
        tx.send(StreamToken::new("1".to_string(), 0)).await.unwrap();
        tx.send(StreamToken::new("2".to_string(), 1)).await.unwrap();
        // The channel is bounded: a third message is rejected while the buffer is full.
        assert!(matches!(
            tx.try_send(StreamToken::new("3".to_string(), 2)),
            Err(tokio::sync::mpsc::error::TrySendError::Full(_))
        ));
        assert_eq!(rx.recv().await.unwrap().content, "1");
        assert_eq!(rx.recv().await.unwrap().content, "2");
    }
}