use std::sync::{Arc, Mutex};
use async_trait::async_trait;
use log;
use crate::context::LLMContext;
use crate::error::Result;
use crate::frames::{
ControlFrame, DataFrame, Frame, FrameDirection, FrameHandler, FrameInner, FrameProcessor,
SystemFrame,
};
/// Mutable bookkeeping for the response that is currently being streamed.
struct State {
    /// Text accumulated so far for the in-flight assistant response.
    aggregation: String,
    /// `true` between a response-start and the matching response-end (or an
    /// interruption); text is only accumulated while this is set.
    in_response: bool,
}
/// Frame handler that collects streamed LLM text into a single assistant
/// message and records it in a shared [`LLMContext`].
pub struct LLMAssistantAggregator {
    /// Conversation context shared with other holders of the same `Arc`;
    /// completed assistant messages are appended to it.
    context: Arc<Mutex<LLMContext>>,
    /// Per-aggregator streaming state, behind a mutex because frames are
    /// handled through `&self`.
    state: Mutex<State>,
}
impl LLMAssistantAggregator {
    /// Builds an aggregator around `context` and wraps it in a
    /// [`FrameProcessor`] named `"LLMAssistantAggregator"`.
    ///
    /// The aggregator starts idle: no response in progress and an empty
    /// text buffer. Note that this returns the wrapping processor rather
    /// than `Self`.
    pub fn new(context: Arc<Mutex<LLMContext>>) -> FrameProcessor {
        let initial_state = State {
            aggregation: String::new(),
            in_response: false,
        };
        let handler = Self {
            context,
            state: Mutex::new(initial_state),
        };
        // NOTE(review): the meaning of the trailing `false` flag is defined by
        // `FrameProcessor::new`, which is not visible here — confirm there.
        FrameProcessor::new("LLMAssistantAggregator", Box::new(handler), false)
    }
}
#[async_trait]
impl FrameHandler for LLMAssistantAggregator {
    /// Updates the aggregation state for `frame`, then forwards the frame
    /// unchanged in the same `direction`.
    ///
    /// Handling per frame kind:
    /// - `LLMFullResponseStart`: marks a response as in progress and empties
    ///   the buffer.
    /// - `LLMText`: appends the text to the buffer, but only while a response
    ///   is in progress.
    /// - `LLMFullResponseEnd`: ends the response; if the trimmed buffer is
    ///   non-empty it is logged and added to the shared context as an
    ///   assistant message. The buffer is emptied either way.
    /// - `Interruption`: discards any partial text and ends the response.
    /// - anything else: no state change.
    ///
    /// # Errors
    ///
    /// Returns the error from `processor.push_frame` if forwarding fails.
    /// State changes (including the context update) have already been applied
    /// at that point.
    ///
    /// # Panics
    ///
    /// Panics if the state mutex or the shared context mutex is poisoned.
    async fn on_process_frame(
        &self,
        processor: &FrameProcessor,
        frame: Frame,
        direction: FrameDirection,
    ) -> Result<()> {
        // Every lock guard below lives entirely inside its match arm, so none
        // is held when the frame is forwarded (and awaited) after the match.
        match &frame.inner {
            FrameInner::Control(ControlFrame::LLMFullResponseStart) => {
                let mut state = self.state.lock().unwrap();
                state.in_response = true;
                state.aggregation.clear();
            }
            FrameInner::Data(DataFrame::LLMText(text)) => {
                // Append straight from the borrowed payload: the borrow ends
                // before `frame` is moved below, so no per-chunk clone is needed.
                let mut state = self.state.lock().unwrap();
                if state.in_response {
                    state.aggregation.push_str(text);
                }
            }
            FrameInner::Control(ControlFrame::LLMFullResponseEnd) => {
                // Take the finished text and release the state lock before
                // touching the context lock, so the two are never held together.
                let aggregation = {
                    let mut state = self.state.lock().unwrap();
                    state.in_response = false;
                    let text = state.aggregation.trim().to_string();
                    state.aggregation.clear();
                    text
                };
                if !aggregation.is_empty() {
                    log::info!(
                        "LLMAssistantAggregator: assistant said: '{}'",
                        aggregation
                    );
                    self.context
                        .lock()
                        .unwrap()
                        .add_assistant_message(&aggregation);
                }
            }
            FrameInner::System(SystemFrame::Interruption) => {
                // Reset under the lock, then log after releasing it.
                let was_in_response = {
                    let mut state = self.state.lock().unwrap();
                    state.aggregation.clear();
                    std::mem::replace(&mut state.in_response, false)
                };
                if was_in_response {
                    log::debug!(
                        "LLMAssistantAggregator: interrupted, discarding partial response"
                    );
                }
            }
            _ => {}
        }

        // All frames, handled or not, continue in their original direction.
        processor.push_frame(frame, direction).await?;
        Ok(())
    }
}