use std::sync::{Arc, Mutex};
use async_trait::async_trait;
use chrono::Utc;
use log;
use uuid::Uuid;
use crate::billing::collector::BillingCollector;
use crate::billing::events::{TranscriptEntry, TranscriptRole};
use crate::context::LLMContext;
use crate::error::Result;
use crate::frames::{
DataFrame, Frame, FrameDirection, FrameHandler, FrameInner, FrameProcessor, SystemFrame,
};
/// What [`LLMUserAggregator`] does with a transcription that arrives while no
/// user turn is open (i.e. outside a VAD start/stop window).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LateTranscriptPolicy {
    /// Keep the text and prepend it to the next flushed user turn.
    Defer,
    /// Drop the text (it is logged at `warn` level).
    Discard,
}

impl Default for LateTranscriptPolicy {
    /// Returns [`LateTranscriptPolicy::Defer`], the policy used by the
    /// aggregator constructors that do not take an explicit policy.
    fn default() -> Self {
        LateTranscriptPolicy::Defer
    }
}
/// Mutable per-turn bookkeeping kept behind the aggregator's `state` mutex.
struct State {
    /// Set when the user starts speaking (VAD start), cleared on VAD stop.
    turn_open: bool,
    /// Text gathered for the current turn: bundled VAD-stop transcripts plus
    /// transcriptions received while the turn is open.
    aggregation: String,
    /// Late transcriptions held back for the next flush
    /// (only filled under `LateTranscriptPolicy::Defer`).
    deferred: String,
}
/// Aggregates user speech into one user message per turn for the LLM.
///
/// Instances are created through the `LLMUserAggregator::new` family of
/// constructors, which wrap the aggregator in a `FrameProcessor`.
/// Note: field declaration order also determines drop order.
pub struct LLMUserAggregator {
/// Shared conversation context that each flushed user turn is appended to.
context: Arc<Mutex<LLMContext>>,
/// When set, each flushed user turn is also recorded as a billing transcript entry.
billing: Option<Arc<dyn BillingCollector>>,
/// Id of the active user turn, used to tag billing entries; when it is `None`
/// a fresh random id is generated for the entry.
active_user_turn_id: Arc<Mutex<Option<Uuid>>>,
/// What to do with transcriptions that arrive while no turn is open.
policy: LateTranscriptPolicy,
/// Per-turn text buffers and the open-turn flag.
state: Mutex<State>,
}
impl LLMUserAggregator {
    /// Creates an aggregator with no billing collector, using
    /// [`LateTranscriptPolicy::Defer`] for late transcripts.
    pub fn new(context: Arc<Mutex<LLMContext>>) -> FrameProcessor {
        Self::build(context, None, Arc::new(Mutex::new(None)), LateTranscriptPolicy::Defer)
    }

    /// Creates an aggregator with no billing collector and an explicit
    /// late-transcript `policy`.
    pub fn new_with_policy(
        context: Arc<Mutex<LLMContext>>,
        policy: LateTranscriptPolicy,
    ) -> FrameProcessor {
        Self::build(context, None, Arc::new(Mutex::new(None)), policy)
    }

    /// Creates an aggregator that records every flushed user turn to `billing`,
    /// tagged with the id currently held in `active_user_turn_id`. Late
    /// transcripts use [`LateTranscriptPolicy::Defer`].
    pub fn with_billing(
        context: Arc<Mutex<LLMContext>>,
        billing: Arc<dyn BillingCollector>,
        active_user_turn_id: Arc<Mutex<Option<Uuid>>>,
    ) -> FrameProcessor {
        Self::build(context, Some(billing), active_user_turn_id, LateTranscriptPolicy::Defer)
    }

    /// Same as [`Self::with_billing`], but with an explicit late-transcript `policy`.
    pub fn with_billing_and_policy(
        context: Arc<Mutex<LLMContext>>,
        billing: Arc<dyn BillingCollector>,
        active_user_turn_id: Arc<Mutex<Option<Uuid>>>,
        policy: LateTranscriptPolicy,
    ) -> FrameProcessor {
        Self::build(context, Some(billing), active_user_turn_id, policy)
    }

    /// Shared constructor body. It wraps a fresh aggregator (empty buffers, no
    /// open turn) in a [`FrameProcessor`] named `"LLMUserAggregator"`.
    fn build(
        context: Arc<Mutex<LLMContext>>,
        billing: Option<Arc<dyn BillingCollector>>,
        active_user_turn_id: Arc<Mutex<Option<Uuid>>>,
        policy: LateTranscriptPolicy,
    ) -> FrameProcessor {
        FrameProcessor::new(
            "LLMUserAggregator",
            Box::new(Self {
                context,
                billing,
                active_user_turn_id,
                policy,
                state: Mutex::new(State {
                    aggregation: String::new(),
                    deferred: String::new(),
                    turn_open: false,
                }),
            }),
            // NOTE(review): meaning of this flag is defined by FrameProcessor::new — confirm.
            false,
        )
    }

    /// Drains the deferred and current-turn text into a single user message.
    ///
    /// When the combined text is non-empty, this method:
    /// - appends it to the LLM context;
    /// - records it with the billing collector, if one is configured;
    /// - pushes an LLM context frame downstream.
    ///
    /// Both buffers are cleared in every case.
    ///
    /// Returns `Ok(false)` when there was nothing to send, and `Ok(true)` once
    /// the context frame has been pushed.
    ///
    /// # Errors
    /// Propagates any error from pushing the context frame downstream.
    ///
    /// # Panics
    /// Panics if the state, context or active-turn-id mutex is poisoned.
    async fn flush_and_trigger(&self, processor: &FrameProcessor) -> Result<bool> {
        // Take the text inside a short scope so the std mutex guard is dropped
        // before any `.await` below.
        let text = {
            let mut state = self.state.lock().unwrap();
            let combined = combine(&state.deferred, &state.aggregation);
            state.deferred.clear();
            state.aggregation.clear();
            combined
        };
        if text.is_empty() {
            log::debug!("LLMUserAggregator: empty turn, skipping LLM trigger");
            return Ok(false);
        }
        log::info!("LLMUserAggregator: user said: '{}'", text);
        self.context.lock().unwrap().add_user_message(&text);
        if let Some(billing) = &self.billing {
            // Fall back to a fresh id when no user turn id is set, so the
            // transcript entry is still recorded.
            let turn_id = self
                .active_user_turn_id
                .lock()
                .unwrap()
                .unwrap_or_else(Uuid::new_v4);
            billing.record_transcript(TranscriptEntry {
                turn_id,
                session_id: billing.session_id(),
                role: TranscriptRole::User,
                // `text` is not used after this point, so move it instead of cloning.
                text,
                language: None,
                interrupted: false,
                occurred_at: Utc::now(),
            });
        }
        processor
            .push_frame(
                Frame::llm_context(Arc::clone(&self.context)),
                FrameDirection::Downstream,
            )
            .await?;
        Ok(true)
    }
}
/// Joins deferred text and current-turn text into one utterance.
///
/// Each side is trimmed first. Empty sides are skipped, and the remaining
/// parts are separated by exactly one space. Returns an empty string when
/// both sides are blank.
fn combine(deferred: &str, aggregation: &str) -> String {
    let mut joined = String::new();
    for part in [deferred.trim(), aggregation.trim()].iter().copied() {
        if part.is_empty() {
            continue;
        }
        if !joined.is_empty() {
            joined.push(' ');
        }
        joined.push_str(part);
    }
    joined
}
/// Appends `segment` to `buf`, inserting a single space separator when `buf`
/// already holds text. Empty segments are ignored.
fn append_segment(buf: &mut String, segment: &str) {
    if segment.is_empty() {
        return;
    }
    if !buf.is_empty() {
        buf.push(' ');
    }
    buf.push_str(segment);
}

#[async_trait]
impl FrameHandler for LLMUserAggregator {
    /// Drives the per-turn state machine:
    ///
    /// - `VADUserStartedSpeaking`: opens a turn and forwards the frame. It then
    ///   broadcasts an interruption if the processor allows interruptions.
    /// - `VADUserStoppedSpeaking`: folds any bundled transcript into the turn,
    ///   closes the turn and forwards the frame. It then flushes when the turn
    ///   was open or any current or deferred text is pending.
    /// - `Transcription`: consumed, not forwarded. Non-empty text joins the open
    ///   turn. With no turn open, the text is deferred or discarded according
    ///   to [`LateTranscriptPolicy`].
    /// - Any other frame is forwarded unchanged.
    async fn on_process_frame(
        &self,
        processor: &FrameProcessor,
        frame: Frame,
        direction: FrameDirection,
    ) -> Result<()> {
        match &frame.inner {
            FrameInner::System(SystemFrame::VADUserStartedSpeaking { .. }) => {
                self.state.lock().unwrap().turn_open = true;
                processor.push_frame(frame, direction).await?;
                if processor.interruptions_allowed() {
                    processor.broadcast_interruption().await?;
                }
            }
            FrameInner::System(SystemFrame::VADUserStoppedSpeaking { transcript, .. }) => {
                // Update the state inside a short scope so the std mutex guard is
                // released before any `.await` below.
                let should_flush = {
                    let mut state = self.state.lock().unwrap();
                    if let Some(td) = transcript {
                        append_segment(&mut state.aggregation, td.text.trim());
                    }
                    let was_open = std::mem::replace(&mut state.turn_open, false);
                    was_open || !state.aggregation.is_empty() || !state.deferred.is_empty()
                };
                processor.push_frame(frame, direction).await?;
                if should_flush {
                    self.flush_and_trigger(processor).await?;
                }
            }
            FrameInner::Data(DataFrame::Transcription(t)) => {
                let text = t.text.trim();
                if text.is_empty() {
                    return Ok(());
                }
                // No `.await` in this arm, so holding the guard to the end is fine.
                let mut state = self.state.lock().unwrap();
                if state.turn_open {
                    append_segment(&mut state.aggregation, text);
                } else {
                    match self.policy {
                        LateTranscriptPolicy::Defer => {
                            log::info!(
                                "LLMUserAggregator: late transcript deferred to next \
                                 turn: '{}'",
                                text
                            );
                            append_segment(&mut state.deferred, text);
                        }
                        LateTranscriptPolicy::Discard => {
                            log::warn!(
                                "LLMUserAggregator: transcript with no open turn \
                                 discarded: '{}'",
                                text
                            );
                        }
                    }
                }
            }
            _ => {
                processor.push_frame(frame, direction).await?;
            }
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::frames::{StartFrameData, TranscriptionData};

    /// `combine` trims both halves and joins the non-empty ones with one space.
    #[test]
    fn combine_joins_deferred_and_current() {
        let cases = [
            (("", ""), ""),
            (("cancel cheyyanam", ""), "cancel cheyyanam"),
            (("", "what is the status"), "what is the status"),
            (("cancel cheyyanam", "of my order"), "cancel cheyyanam of my order"),
            ((" padded ", " text "), "padded text"),
        ];
        for ((deferred, current), expected) in cases.iter() {
            assert_eq!(combine(deferred, current), *expected);
        }
    }

    /// A VAD stop that carries a transcript must end with an LLM context frame
    /// being pushed downstream.
    #[tokio::test]
    async fn vad_stop_with_bundled_transcript_triggers_llm() {
        let context = Arc::new(Mutex::new(LLMContext::new(None)));
        let processor = LLMUserAggregator::new(context);

        let pushed = Arc::new(Mutex::new(Vec::new()));
        {
            let sink = Arc::clone(&pushed);
            processor.on_after_push_frame(move |f| {
                sink.lock().unwrap().push(f.clone());
            });
        }

        let bundled = TranscriptionData::new("hello world", "user", "2024-01-01T00:00:00Z");
        let script = vec![
            Frame::start(StartFrameData::default()),
            Frame::vad_user_started_speaking(0.0, 0.0),
            Frame::vad_user_stopped_speaking(0.0, 0.0).with_vad_stop_transcript(bundled),
        ];
        for next in script {
            let _ = processor.process_frame(next, FrameDirection::Downstream).await;
        }

        let saw_context = pushed
            .lock()
            .unwrap()
            .iter()
            .any(|f| matches!(f.inner, FrameInner::Data(DataFrame::LLMContextFrame(_))));
        assert!(
            saw_context,
            "expected LLMContextFrame after VadStop with bundled transcript"
        );
    }
}