use std::sync::Arc;
use async_trait::async_trait;
use parking_lot::Mutex;
use rs_genai::prelude::{Content, FunctionResponse};
use rs_genai::session::{SessionError, SessionWriter};
/// Queue of [`Content`] items waiting to be sent, plus a one-shot "prompt" flag.
///
/// Each field sits behind its own mutex, so the queue can be filled through a
/// shared `&self` reference. `DeferredWriter::flush` in this file is the
/// consumer: it drains the queue and forwards the result to the wrapped writer.
pub struct PendingContext {
/// Content items queued by `push` / `extend` and removed by `drain`.
buffer: Mutex<Vec<Content>>,
/// Set by `set_prompt`, reset to `false` by `drain`.
prompt: Mutex<bool>,
}
impl PendingContext {
    /// Creates an empty queue with the prompt flag cleared.
    pub fn new() -> Self {
        Self {
            buffer: Mutex::new(Vec::new()),
            prompt: Mutex::new(false),
        }
    }

    /// Appends one content item to the end of the queue.
    pub fn push(&self, content: Content) {
        self.buffer.lock().push(content);
    }

    /// Appends every item of `contents` to the end of the queue.
    ///
    /// An empty vector is ignored without taking the lock.
    pub fn extend(&self, contents: Vec<Content>) {
        if !contents.is_empty() {
            self.buffer.lock().extend(contents);
        }
    }

    /// Raises the prompt flag; it stays raised until the next [`drain`](Self::drain).
    ///
    /// `DeferredWriter::flush` reacts to a raised flag by sending an empty,
    /// turn-completing client-content message after the queued content.
    pub fn set_prompt(&self) {
        *self.prompt.lock() = true;
    }

    /// Removes and returns all queued content together with the prompt flag,
    /// leaving the queue empty and the flag cleared.
    ///
    /// Both values are taken as one consistent snapshot: the buffer lock is
    /// held while the prompt flag is read. Without this, a producer calling
    /// `push(..)` followed by `set_prompt()` between the two reads could make
    /// this method return the prompt flag *without* the content it belongs to,
    /// so the prompt would be sent ahead of its context.
    pub fn drain(&self) -> (Vec<Content>, bool) {
        // Lock order is buffer -> prompt, the same order `is_empty` uses, so
        // the two methods cannot deadlock against each other.
        let mut buf = self.buffer.lock();
        let mut prompt = self.prompt.lock();
        (std::mem::take(&mut *buf), std::mem::take(&mut *prompt))
    }

    /// Returns `true` when no content is queued and the prompt flag is cleared.
    pub fn is_empty(&self) -> bool {
        // Same buffer -> prompt lock order as `drain`.
        let buf = self.buffer.lock();
        let prompt = self.prompt.lock();
        buf.is_empty() && !*prompt
    }
}
impl Default for PendingContext {
fn default() -> Self {
Self::new()
}
}
/// [`SessionWriter`] wrapper that sends queued context ahead of media and text.
///
/// Before forwarding audio, text, or video it drains the shared
/// [`PendingContext`] and sends the drained content to the wrapped writer.
/// `disconnect` also attempts that flush. The other trait methods are
/// forwarded without touching the queue.
pub struct DeferredWriter {
/// The writer that actually performs every send.
inner: Arc<dyn SessionWriter>,
/// Queue of deferred content, shared with whoever fills it.
pending: Arc<PendingContext>,
}
impl DeferredWriter {
pub fn new(inner: Arc<dyn SessionWriter>, pending: Arc<PendingContext>) -> Self {
Self { inner, pending }
}
async fn flush(&self) -> Result<(), SessionError> {
let (contents, prompt) = self.pending.drain();
if !contents.is_empty() {
self.inner.send_client_content(contents, false).await?;
}
if prompt {
self.inner.send_client_content(vec![], true).await?;
}
Ok(())
}
pub fn pending(&self) -> &Arc<PendingContext> {
&self.pending
}
}
#[async_trait]
impl SessionWriter for DeferredWriter {
    /// Flushes the pending queue, then forwards the audio data.
    async fn send_audio(&self, data: Vec<u8>) -> Result<(), SessionError> {
        self.flush().await?;
        self.inner.send_audio(data).await
    }

    /// Flushes the pending queue, then forwards the text.
    async fn send_text(&self, text: String) -> Result<(), SessionError> {
        self.flush().await?;
        self.inner.send_text(text).await
    }

    /// Forwards tool responses directly; the pending queue is left untouched.
    async fn send_tool_response(
        &self,
        responses: Vec<FunctionResponse>,
    ) -> Result<(), SessionError> {
        self.inner.send_tool_response(responses).await
    }

    /// Forwards explicit client content directly; the pending queue is left
    /// untouched.
    async fn send_client_content(
        &self,
        turns: Vec<Content>,
        turn_complete: bool,
    ) -> Result<(), SessionError> {
        self.inner.send_client_content(turns, turn_complete).await
    }

    /// Flushes the pending queue, then forwards the video frame.
    async fn send_video(&self, jpeg_data: Vec<u8>) -> Result<(), SessionError> {
        self.flush().await?;
        self.inner.send_video(jpeg_data).await
    }

    /// Forwards the instruction update directly, without flushing.
    async fn update_instruction(&self, instruction: String) -> Result<(), SessionError> {
        self.inner.update_instruction(instruction).await
    }

    /// Forwards the activity-start signal directly, without flushing.
    async fn signal_activity_start(&self) -> Result<(), SessionError> {
        self.inner.signal_activity_start().await
    }

    /// Forwards the activity-end signal directly, without flushing.
    async fn signal_activity_end(&self) -> Result<(), SessionError> {
        self.inner.signal_activity_end().await
    }

    /// Attempts a final flush, then disconnects the inner writer.
    ///
    /// The flush is best-effort: its error, if any, is discarded so that the
    /// disconnect is always attempted.
    async fn disconnect(&self) -> Result<(), SessionError> {
        let _ = self.flush().await;
        self.inner.disconnect().await
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Test double that counts how often each send method is invoked.
    #[derive(Default)]
    struct CountingWriter {
        audio_count: AtomicUsize,
        text_count: AtomicUsize,
        client_content_count: AtomicUsize,
        video_count: AtomicUsize,
    }

    impl CountingWriter {
        /// Creates a writer with every counter at zero.
        fn new() -> Self {
            Self::default()
        }
    }

    /// Increments `counter` by one.
    fn bump(counter: &AtomicUsize) {
        counter.fetch_add(1, Ordering::SeqCst);
    }

    /// Reads the current value of `counter`.
    fn seen(counter: &AtomicUsize) -> usize {
        counter.load(Ordering::SeqCst)
    }

    /// Builds a counting inner writer, an empty queue, and a `DeferredWriter`
    /// wired to both.
    fn harness() -> (Arc<CountingWriter>, Arc<PendingContext>, DeferredWriter) {
        let inner = Arc::new(CountingWriter::new());
        let pending = Arc::new(PendingContext::new());
        let writer = DeferredWriter::new(inner.clone(), pending.clone());
        (inner, pending, writer)
    }

    #[async_trait]
    impl SessionWriter for CountingWriter {
        async fn send_audio(&self, _: Vec<u8>) -> Result<(), SessionError> {
            bump(&self.audio_count);
            Ok(())
        }

        async fn send_text(&self, _: String) -> Result<(), SessionError> {
            bump(&self.text_count);
            Ok(())
        }

        async fn send_tool_response(&self, _: Vec<FunctionResponse>) -> Result<(), SessionError> {
            Ok(())
        }

        async fn send_client_content(&self, _: Vec<Content>, _: bool) -> Result<(), SessionError> {
            bump(&self.client_content_count);
            Ok(())
        }

        async fn send_video(&self, _: Vec<u8>) -> Result<(), SessionError> {
            bump(&self.video_count);
            Ok(())
        }

        async fn update_instruction(&self, _: String) -> Result<(), SessionError> {
            Ok(())
        }

        async fn signal_activity_start(&self) -> Result<(), SessionError> {
            Ok(())
        }

        async fn signal_activity_end(&self) -> Result<(), SessionError> {
            Ok(())
        }

        async fn disconnect(&self) -> Result<(), SessionError> {
            Ok(())
        }
    }

    // Pushed items come back from `drain`, which leaves the queue empty.
    #[test]
    fn pending_context_push_and_drain() {
        let queue = PendingContext::new();
        assert!(queue.is_empty());

        queue.push(Content::model("context 1"));
        queue.push(Content::model("context 2"));
        assert!(!queue.is_empty());

        let (drained, prompt_requested) = queue.drain();
        assert_eq!(drained.len(), 2);
        assert!(!prompt_requested);
        assert!(queue.is_empty());
    }

    // `extend` queues every element of the vector it is given.
    #[test]
    fn pending_context_extend() {
        let queue = PendingContext::new();
        let batch = vec![
            Content::model("a"),
            Content::model("b"),
            Content::model("c"),
        ];
        queue.extend(batch);

        let (drained, _) = queue.drain();
        assert_eq!(drained.len(), 3);
    }

    // The prompt flag counts as pending work and is reported by `drain`.
    #[test]
    fn pending_context_prompt_flag() {
        let queue = PendingContext::new();
        queue.push(Content::model("ctx"));
        queue.set_prompt();
        assert!(!queue.is_empty());

        let (drained, prompt_requested) = queue.drain();
        assert_eq!(drained.len(), 1);
        assert!(prompt_requested);
        assert!(queue.is_empty());
    }

    // A second `drain` observes nothing left over from the first.
    #[test]
    fn pending_context_drain_clears() {
        let queue = PendingContext::new();
        queue.push(Content::model("a"));
        queue.set_prompt();
        let _ = queue.drain();

        let (drained, prompt_requested) = queue.drain();
        assert!(drained.is_empty());
        assert!(!prompt_requested);
    }

    // Queued context is sent as one client-content message before the audio.
    #[tokio::test]
    async fn deferred_writer_flushes_on_send_audio() {
        let (inner, pending, writer) = harness();
        pending.push(Content::model("steering context"));
        pending.push(Content::model("phase instruction"));

        writer.send_audio(vec![0u8; 100]).await.unwrap();

        assert_eq!(seen(&inner.client_content_count), 1);
        assert_eq!(seen(&inner.audio_count), 1);
        assert!(pending.is_empty());
    }

    #[tokio::test]
    async fn deferred_writer_flushes_on_send_text() {
        let (inner, pending, writer) = harness();
        pending.push(Content::model("context"));

        writer.send_text(String::from("hello")).await.unwrap();

        assert_eq!(seen(&inner.client_content_count), 1);
        assert_eq!(seen(&inner.text_count), 1);
    }

    #[tokio::test]
    async fn deferred_writer_flushes_on_send_video() {
        let (inner, pending, writer) = harness();
        pending.push(Content::model("context"));

        writer.send_video(vec![0xFFu8; 50]).await.unwrap();

        assert_eq!(seen(&inner.client_content_count), 1);
        assert_eq!(seen(&inner.video_count), 1);
    }

    // With nothing queued, no client-content message is emitted.
    #[tokio::test]
    async fn deferred_writer_no_flush_when_empty() {
        let (inner, _pending, writer) = harness();

        writer.send_audio(vec![0u8; 100]).await.unwrap();

        assert_eq!(seen(&inner.client_content_count), 0);
        assert_eq!(seen(&inner.audio_count), 1);
    }

    // Context plus prompt flag produce two client-content messages.
    #[tokio::test]
    async fn deferred_writer_flushes_prompt_after_context() {
        let (inner, pending, writer) = harness();
        pending.push(Content::model("repair nudge"));
        pending.set_prompt();

        writer.send_audio(vec![0u8; 100]).await.unwrap();

        assert_eq!(seen(&inner.client_content_count), 2);
        assert_eq!(seen(&inner.audio_count), 1);
    }

    // Tool responses leave the queue untouched.
    #[tokio::test]
    async fn deferred_writer_does_not_flush_on_tool_response() {
        let (inner, pending, writer) = harness();
        pending.push(Content::model("context"));

        writer.send_tool_response(vec![]).await.unwrap();

        assert_eq!(seen(&inner.client_content_count), 0);
        assert!(!pending.is_empty());
    }

    // Explicit client content is forwarded as-is and does not trigger a flush.
    #[tokio::test]
    async fn deferred_writer_client_content_passes_through() {
        let (inner, pending, writer) = harness();
        pending.push(Content::model("queued context"));

        writer
            .send_client_content(vec![Content::user("explicit")], true)
            .await
            .unwrap();

        assert_eq!(seen(&inner.client_content_count), 1);
        assert!(!pending.is_empty());
    }
}