use kameo::prelude::*;
use super::events::{
CompressStrategy, NotifyCompressedForReply, RequestAppend, RequestClear,
RequestCompress, RequestCompressed, RequestEstimateTokens, RequestExtend, RequestFindByRole,
RequestImportIncremental, RequestGet, RequestImmutable, RequestIncremental, RequestInsert,
RequestIsEmpty, RequestLen, RequestMessages, RequestPop, RequestRemove, RequestRetain,
RequestSend, RequestSendStream, RequestSubscribeCompressed, RequestUnsubscribeCompressed,
RequestExportIncremental, RequestExportAll, RequestUpdate,
};
use super::stream::AgentSendStream;
use super::types::ContextBackend;
use crate::error::AgentError;
/// Pair of message lists exchanged with the compression subscriber.
/// In this file it is always used as `(summary, kept)`.
type CompressSubscriberReply<M> = (Vec<M>, Vec<M>);
/// Handle to the subscriber that is asked with a `NotifyCompressedForReply<M>`
/// and answers with a `CompressSubscriberReply<M>`.
type CompressSubscriberRecipient<M> = ReplyRecipient<NotifyCompressedForReply<M>, CompressSubscriberReply<M>>;
use crate::message::ContextMessage;
use crate::readonly::ReadOnly;
/// Actor state for a conversation context made of three message segments.
///
/// Wherever this file assembles the full context, the segments are
/// concatenated in the order `immutable`, `compressed`, `incremental`.
#[derive(Actor)]
pub struct AgentContext<B: ContextBackend> {
/// Backend used in this file to send requests, estimate tokens, build system
/// messages and convert messages to and from JSON.
backend: B,
/// Messages supplied at construction; no code in this file modifies them.
immutable: ReadOnly<B::Message>,
/// Summary messages produced by compression; replaced on each successful compression.
compressed: Vec<B::Message>,
/// Mutable tail of the conversation; the target of the append/insert/remove style requests.
incremental: Vec<B::Message>,
/// Optional subscriber that is asked with `(summary, kept)` after a compression;
/// its reply is what actually gets stored.
on_compressed: Option<CompressSubscriberRecipient<B::Message>>,
}
impl<B: ContextBackend> AgentContext<B> {
    /// Creates a context with the given backend and immutable leading messages.
    ///
    /// The compressed and incremental segments start empty and no compression
    /// subscriber is registered.
    pub fn new(backend: B, immutable: Vec<B::Message>) -> Self {
        Self {
            backend,
            immutable: ReadOnly::from(immutable),
            compressed: Vec::new(),
            incremental: Vec::new(),
            on_compressed: None,
        }
    }

    /// Returns the instruction text placed ahead of the history when no custom
    /// summary prompt is supplied.
    fn default_summary_prompt() -> String {
        "请将以下对话历史压缩为简洁摘要,保留关键信息、决策和上下文。输出一条 system 消息。"
            .to_string()
    }

    /// Gives the registered subscriber, if any, the chance to replace the
    /// compression result.
    ///
    /// With a subscriber, it is asked with a `NotifyCompressedForReply` and its
    /// reply is returned. Without one, `(summary, kept)` is returned unchanged.
    ///
    /// # Errors
    /// A failed ask is reported as `AgentError::Context` carrying the error text.
    async fn notify_compressed_subscriber(
        &self,
        summary: Vec<B::Message>,
        kept: Vec<B::Message>,
    ) -> Result<(Vec<B::Message>, Vec<B::Message>), AgentError> {
        if let Some(subscriber) = &self.on_compressed {
            subscriber
                .ask(NotifyCompressedForReply { summary, kept })
                .send()
                .await
                .map_err(|e| AgentError::Context(e.to_string()))
        } else {
            Ok((summary, kept))
        }
    }

    /// Compresses the older half of the incremental segment when the estimated
    /// token count has reached `context_window`.
    ///
    /// The estimate covers `immutable + compressed + incremental` plus the
    /// scratch content (as a system message) when `opts` carries one. A failed
    /// estimate is treated as a full context.
    ///
    /// The stored segments are only replaced after the backend request, the
    /// reply conversion and the subscriber notification have all succeeded, so
    /// a failure leaves `compressed` and `incremental` exactly as they were.
    ///
    /// # Errors
    /// Returns `AgentError::Context` when the context is full but
    /// `auto_compress` is disabled, and propagates backend, conversion and
    /// subscriber errors.
    async fn compress_if_full(&mut self, opts: &B::Opts) -> Result<(), AgentError> {
        let common = opts.as_ref();
        let mut all: Vec<B::Message> = self
            .immutable
            .iter()
            .chain(self.compressed.iter())
            .chain(self.incremental.iter())
            .cloned()
            .collect();
        if let Some(ref scratch) = common.scratch {
            all.push(self.backend.system_message(scratch.clone()));
        }
        let tokens = self
            .backend
            .estimate_tokens(&all)
            .await
            .unwrap_or(usize::MAX);
        if tokens < common.context_window {
            return Ok(());
        }
        if !common.auto_compress {
            return Err(AgentError::Context("上下文已满且未启用自动压缩".into()));
        }

        let total = self.incremental.len();
        if total == 0 {
            // Nothing in the incremental segment to summarize.
            return Ok(());
        }
        // Keep the newest half (rounded down); summarize everything before it.
        let split = total - total / 2;

        // Build the request from clones so that a failure below cannot lose
        // any stored message.
        let mut summary_messages: Vec<B::Message> =
            Vec::with_capacity(1 + self.compressed.len() + split);
        summary_messages.push(self.backend.system_message(Self::default_summary_prompt()));
        summary_messages.extend(self.compressed.iter().cloned());
        summary_messages.extend(self.incremental[..split].iter().cloned());

        let response = self.backend.send(&summary_messages, opts).await?;
        let raw_msgs = self
            .backend
            .extract_messages(std::slice::from_ref(&response))?;
        let request_msgs = self.backend.to_request_messages(raw_msgs)?;
        let summary: Vec<B::Message> = request_msgs
            .into_iter()
            .map(|m| self.backend.to_system_message(m))
            .collect();
        let kept: Vec<B::Message> = self.incremental[split..].to_vec();

        let (final_summary, final_kept) =
            self.notify_compressed_subscriber(summary, kept).await?;

        // Commit only after every fallible step has succeeded.
        self.compressed = final_summary;
        self.incremental = final_kept;
        Ok(())
    }
}
/// Appends a single message to the end of the incremental segment.
impl<B: ContextBackend> Message<RequestAppend<B::Message>> for AgentContext<B> {
    type Reply = ();

    async fn handle(
        &mut self,
        msg: RequestAppend<B::Message>,
        _ctx: &mut Context<Self, Self::Reply>,
    ) -> Self::Reply {
        let appended = msg.message;
        self.incremental.push(appended);
    }
}
/// Appends every message of the request, in order, to the incremental segment.
impl<B: ContextBackend> Message<RequestExtend<B::Message>> for AgentContext<B> {
    type Reply = ();

    async fn handle(
        &mut self,
        msg: RequestExtend<B::Message>,
        _ctx: &mut Context<Self, Self::Reply>,
    ) -> Self::Reply {
        // `extend` replaces the manual push loop and can reserve space up front.
        self.incremental.extend(msg.messages);
    }
}
/// Replaces the incremental message at `index` with the supplied message.
///
/// # Errors
/// Returns `AgentError::Context` when `index` is not a valid position in the
/// incremental segment.
impl<B: ContextBackend> Message<RequestUpdate<B::Message>> for AgentContext<B> {
    type Reply = Result<(), AgentError>;

    async fn handle(
        &mut self,
        msg: RequestUpdate<B::Message>,
        _ctx: &mut Context<Self, Self::Reply>,
    ) -> Self::Reply {
        match self.incremental.get_mut(msg.index) {
            Some(slot) => {
                *slot = msg.message;
                Ok(())
            }
            None => Err(AgentError::Context("索引越界".into())),
        }
    }
}
/// Inserts a message into the incremental segment at `index`.
///
/// `index` may equal the current length, which appends at the end.
///
/// # Errors
/// Returns `AgentError::Context` when `index` is greater than the length of
/// the incremental segment.
impl<B: ContextBackend> Message<RequestInsert<B::Message>> for AgentContext<B> {
    type Reply = Result<(), AgentError>;

    async fn handle(
        &mut self,
        msg: RequestInsert<B::Message>,
        _ctx: &mut Context<Self, Self::Reply>,
    ) -> Self::Reply {
        let len = self.incremental.len();
        if msg.index <= len {
            self.incremental.insert(msg.index, msg.message);
            Ok(())
        } else {
            Err(AgentError::Context("索引越界".into()))
        }
    }
}
/// Removes the incremental message at `index`, discarding it.
///
/// # Errors
/// Returns `AgentError::Context` when `index` is not a valid position in the
/// incremental segment.
impl<B: ContextBackend> Message<RequestRemove> for AgentContext<B> {
    type Reply = Result<(), AgentError>;

    async fn handle(
        &mut self,
        msg: RequestRemove,
        _ctx: &mut Context<Self, Self::Reply>,
    ) -> Self::Reply {
        if msg.index < self.incremental.len() {
            // The removed message is not returned to the caller.
            drop(self.incremental.remove(msg.index));
            Ok(())
        } else {
            Err(AgentError::Context("索引越界".into()))
        }
    }
}
/// Removes and returns the last incremental message, or `None` when the
/// incremental segment is empty.
impl<B: ContextBackend> Message<RequestPop> for AgentContext<B> {
    type Reply = Option<B::Message>;

    async fn handle(
        &mut self,
        _msg: RequestPop,
        _ctx: &mut Context<Self, Self::Reply>,
    ) -> Self::Reply {
        Vec::pop(&mut self.incremental)
    }
}
/// Keeps only the incremental messages whose role equals the requested role;
/// all other incremental messages are dropped.
impl<B: ContextBackend> Message<RequestRetain> for AgentContext<B> {
    type Reply = ();

    async fn handle(
        &mut self,
        msg: RequestRetain,
        _ctx: &mut Context<Self, Self::Reply>,
    ) -> Self::Reply {
        let role = msg.role;
        // Removed messages were previously cloned into a vector that was never
        // read; they are now simply dropped by `retain`.
        self.incremental.retain(|m| m.role() == role);
    }
}
/// Empties the incremental segment; the immutable and compressed segments are
/// left untouched.
impl<B: ContextBackend> Message<RequestClear> for AgentContext<B> {
    type Reply = ();

    async fn handle(
        &mut self,
        _msg: RequestClear,
        _ctx: &mut Context<Self, Self::Reply>,
    ) -> Self::Reply {
        self.incremental.truncate(0);
    }
}
/// Reports the total number of messages across the immutable, compressed and
/// incremental segments.
impl<B: ContextBackend> Message<RequestLen> for AgentContext<B> {
    type Reply = usize;

    async fn handle(
        &mut self,
        _msg: RequestLen,
        _ctx: &mut Context<Self, Self::Reply>,
    ) -> Self::Reply {
        let segment_lens = [
            self.immutable.len(),
            self.compressed.len(),
            self.incremental.len(),
        ];
        segment_lens.iter().sum::<usize>()
    }
}
/// Reports whether all three segments (immutable, compressed, incremental)
/// are empty.
impl<B: ContextBackend> Message<RequestIsEmpty> for AgentContext<B> {
    type Reply = bool;

    async fn handle(
        &mut self,
        _msg: RequestIsEmpty,
        _ctx: &mut Context<Self, Self::Reply>,
    ) -> Self::Reply {
        if !self.immutable.is_empty() {
            return false;
        }
        if !self.compressed.is_empty() {
            return false;
        }
        self.incremental.is_empty()
    }
}
/// Returns a clone of the message at a global index, or `None` when the index
/// is past the end.
///
/// The global index counts through the immutable segment first, then the
/// compressed segment, then the incremental segment.
impl<B: ContextBackend> Message<RequestGet> for AgentContext<B> {
    type Reply = Option<B::Message>;

    async fn handle(
        &mut self,
        msg: RequestGet,
        _ctx: &mut Context<Self, Self::Reply>,
    ) -> Self::Reply {
        // `offset` is rebased to the start of each segment in turn.
        let mut offset = msg.0;

        let immutable_len = self.immutable.len();
        if offset < immutable_len {
            return Some(self.immutable[offset].clone());
        }
        offset -= immutable_len;

        let compressed_len = self.compressed.len();
        if offset < compressed_len {
            return Some(self.compressed[offset].clone());
        }
        offset -= compressed_len;

        self.incremental.get(offset).cloned()
    }
}
/// Returns clones of every message, ordered immutable, then compressed, then
/// incremental.
impl<B: ContextBackend> Message<RequestMessages> for AgentContext<B> {
    type Reply = Vec<B::Message>;

    async fn handle(
        &mut self,
        _msg: RequestMessages,
        _ctx: &mut Context<Self, Self::Reply>,
    ) -> Self::Reply {
        let mut everything: Vec<B::Message> = Vec::with_capacity(
            self.immutable.len() + self.compressed.len() + self.incremental.len(),
        );
        everything.extend(self.immutable.iter().cloned());
        everything.extend(self.compressed.iter().cloned());
        everything.extend(self.incremental.iter().cloned());
        everything
    }
}
/// Returns a copy of the immutable segment.
impl<B: ContextBackend> Message<RequestImmutable> for AgentContext<B> {
    type Reply = Vec<B::Message>;

    async fn handle(
        &mut self,
        _msg: RequestImmutable,
        _ctx: &mut Context<Self, Self::Reply>,
    ) -> Self::Reply {
        self.immutable.iter().cloned().collect()
    }
}
/// Returns a copy of the compressed (summary) segment.
impl<B: ContextBackend> Message<RequestCompressed> for AgentContext<B> {
    type Reply = Vec<B::Message>;

    async fn handle(
        &mut self,
        _msg: RequestCompressed,
        _ctx: &mut Context<Self, Self::Reply>,
    ) -> Self::Reply {
        self.compressed.to_vec()
    }
}
/// Returns a copy of the incremental segment.
impl<B: ContextBackend> Message<RequestIncremental> for AgentContext<B> {
    type Reply = Vec<B::Message>;

    async fn handle(
        &mut self,
        _msg: RequestIncremental,
        _ctx: &mut Context<Self, Self::Reply>,
    ) -> Self::Reply {
        self.incremental.iter().cloned().collect()
    }
}
/// Returns clones of all messages, across the three segments, whose role
/// equals the requested role. Segment order (immutable, compressed,
/// incremental) is preserved in the result.
impl<B: ContextBackend> Message<RequestFindByRole> for AgentContext<B> {
    type Reply = Vec<B::Message>;

    async fn handle(
        &mut self,
        msg: RequestFindByRole,
        _ctx: &mut Context<Self, Self::Reply>,
    ) -> Self::Reply {
        let mut matches = Vec::new();
        let every_message = self
            .immutable
            .iter()
            .chain(self.compressed.iter())
            .chain(self.incremental.iter());
        for candidate in every_message {
            if candidate.role() == msg.0 {
                matches.push(candidate.clone());
            }
        }
        matches
    }
}
/// Sends the whole context to the backend and records the reply.
///
/// Steps: run `compress_if_full`, assemble `immutable + compressed +
/// incremental` (plus the scratch content as a trailing system message when
/// `opts` carries one), send, then append the reply's messages to the
/// incremental segment. The scratch message itself is not stored.
///
/// # Errors
/// Propagates errors from compression, from `backend.send`, and from
/// extracting or converting the reply messages. In those cases this handler
/// appends nothing.
impl<B: ContextBackend> Message<RequestSend<B::Opts>> for AgentContext<B> {
    type Reply = Result<B::Response, AgentError>;

    async fn handle(
        &mut self,
        msg: RequestSend<B::Opts>,
        _ctx: &mut Context<Self, Self::Reply>,
    ) -> Self::Reply {
        self.compress_if_full(&msg.opts).await?;
        let scratch = msg.opts.as_ref().scratch.clone();
        let mut all_messages: Vec<B::Message> = self
            .immutable
            .iter()
            .chain(self.compressed.iter())
            .chain(self.incremental.iter())
            .cloned()
            .collect();
        if let Some(content) = scratch {
            all_messages.push(self.backend.system_message(content));
        }
        let response = self.backend.send(&all_messages, &msg.opts).await?;
        let raw_msgs = self
            .backend
            .extract_messages(std::slice::from_ref(&response))?;
        let request_msgs = self.backend.to_request_messages(raw_msgs)?;
        // The converted messages are owned and not used afterwards, so move
        // them in instead of cloning each one.
        self.incremental.extend(request_msgs);
        Ok(response)
    }
}
/// Starts a streaming request over the whole context.
///
/// Runs `compress_if_full` first, then hands `immutable + compressed +
/// incremental` (plus the scratch content as a trailing system message when
/// `opts` carries one) to `backend.send_stream` and wraps the result in an
/// `AgentSendStream`. This handler does not itself append anything to the
/// incremental segment.
///
/// # Errors
/// Propagates errors from `compress_if_full`.
impl<B: ContextBackend + Clone> Message<RequestSendStream<B::Opts>> for AgentContext<B> {
    type Reply = Result<AgentSendStream<B>, AgentError>;

    async fn handle(
        &mut self,
        msg: RequestSendStream<B::Opts>,
        _ctx: &mut Context<Self, Self::Reply>,
    ) -> Self::Reply {
        self.compress_if_full(&msg.opts).await?;

        let mut outgoing: Vec<B::Message> = Vec::new();
        outgoing.extend(self.immutable.iter().cloned());
        outgoing.extend(self.compressed.iter().cloned());
        outgoing.extend(self.incremental.iter().cloned());

        match msg.opts.as_ref().scratch.clone() {
            Some(content) => outgoing.push(self.backend.system_message(content)),
            None => {}
        }

        Ok(AgentSendStream::new(
            self.backend.send_stream(outgoing, msg.opts),
        ))
    }
}
/// Explicitly compresses the incremental segment using the requested strategy.
///
/// `Summarize { keep, prompt }`: when more than `keep` incremental messages
/// exist, all but the last `keep` are sent to the backend together with the
/// existing compressed messages, preceded by a system message holding `prompt`
/// (or the default summary prompt). The reply's messages, converted with
/// `to_system_message`, become the new summary. The registered subscriber, if
/// any, may replace `(summary, kept)` before they are stored. With `keep` or
/// fewer incremental messages nothing happens.
///
/// The stored segments are only replaced after every fallible step has
/// succeeded, so an error leaves `compressed` and `incremental` unchanged.
///
/// # Errors
/// Propagates backend, conversion and subscriber errors.
impl<B: ContextBackend> Message<RequestCompress<B::Opts>> for AgentContext<B> {
    type Reply = Result<(), AgentError>;

    async fn handle(
        &mut self,
        msg: RequestCompress<B::Opts>,
        _ctx: &mut Context<Self, Self::Reply>,
    ) -> Self::Reply {
        match msg.strategy {
            CompressStrategy::Summarize { keep, prompt } => {
                let total = self.incremental.len();
                if total <= keep {
                    return Ok(());
                }
                let split = total - keep;

                // Build the request from clones so that a failure below cannot
                // lose any stored message.
                let summary_prompt = prompt.unwrap_or_else(Self::default_summary_prompt);
                let mut summary_messages: Vec<B::Message> =
                    Vec::with_capacity(1 + self.compressed.len() + split);
                summary_messages.push(self.backend.system_message(summary_prompt));
                summary_messages.extend(self.compressed.iter().cloned());
                summary_messages.extend(self.incremental[..split].iter().cloned());

                let response = self.backend.send(&summary_messages, &msg.opts).await?;
                let raw_msgs = self
                    .backend
                    .extract_messages(std::slice::from_ref(&response))?;
                let request_msgs = self.backend.to_request_messages(raw_msgs)?;
                let summary: Vec<B::Message> = request_msgs
                    .into_iter()
                    .map(|m| self.backend.to_system_message(m))
                    .collect();
                let kept: Vec<B::Message> = self.incremental[split..].to_vec();

                let (final_summary, final_kept) =
                    self.notify_compressed_subscriber(summary, kept).await?;

                // Commit only after every fallible step has succeeded.
                self.compressed = final_summary;
                self.incremental = final_kept;
                Ok(())
            }
        }
    }
}
/// Registers the compression subscriber, replacing any previously registered
/// one.
impl<B: ContextBackend> Message<RequestSubscribeCompressed<B::Message>> for AgentContext<B> {
    type Reply = ();

    async fn handle(
        &mut self,
        msg: RequestSubscribeCompressed<B::Message>,
        _ctx: &mut Context<Self, Self::Reply>,
    ) -> Self::Reply {
        // The previous subscriber, if any, is dropped here.
        drop(self.on_compressed.replace(msg.recipient));
    }
}
/// Removes the compression subscriber, if one is registered.
impl<B: ContextBackend> Message<RequestUnsubscribeCompressed> for AgentContext<B> {
    type Reply = ();

    async fn handle(
        &mut self,
        _msg: RequestUnsubscribeCompressed,
        _ctx: &mut Context<Self, Self::Reply>,
    ) -> Self::Reply {
        drop(self.on_compressed.take());
    }
}
/// Asks the backend for a token estimate of the whole context
/// (immutable, compressed, incremental; no scratch content).
///
/// Returns `0` when the backend yields no estimate.
impl<B: ContextBackend> Message<RequestEstimateTokens> for AgentContext<B> {
    type Reply = usize;

    async fn handle(
        &mut self,
        _msg: RequestEstimateTokens,
        _ctx: &mut Context<Self, Self::Reply>,
    ) -> Self::Reply {
        let mut snapshot: Vec<B::Message> = Vec::with_capacity(
            self.immutable.len() + self.compressed.len() + self.incremental.len(),
        );
        snapshot.extend(self.immutable.iter().cloned());
        snapshot.extend(self.compressed.iter().cloned());
        snapshot.extend(self.incremental.iter().cloned());

        let estimate = self.backend.estimate_tokens(&snapshot).await;
        estimate.unwrap_or(0)
    }
}
/// Serializes every message (immutable, compressed, incremental) with
/// `backend.message_to_json` and joins the results with `\n`, one message per
/// line.
///
/// # Errors
/// Stops at, and returns, the first serialization error.
impl<B: ContextBackend> Message<RequestExportAll> for AgentContext<B> {
    type Reply = Result<String, AgentError>;

    async fn handle(
        &mut self,
        _msg: RequestExportAll,
        _ctx: &mut Context<Self, Self::Reply>,
    ) -> Self::Reply {
        let every_message = self
            .immutable
            .iter()
            .chain(self.compressed.iter())
            .chain(self.incremental.iter());

        let mut lines: Vec<String> = Vec::new();
        for message in every_message {
            lines.push(self.backend.message_to_json(message)?);
        }
        Ok(lines.join("\n"))
    }
}
/// Serializes only the incremental segment with `backend.message_to_json` and
/// joins the results with `\n`, one message per line.
///
/// # Errors
/// Stops at, and returns, the first serialization error.
impl<B: ContextBackend> Message<RequestExportIncremental> for AgentContext<B> {
    type Reply = Result<String, AgentError>;

    async fn handle(
        &mut self,
        _msg: RequestExportIncremental,
        _ctx: &mut Context<Self, Self::Reply>,
    ) -> Self::Reply {
        let mut lines: Vec<String> = Vec::with_capacity(self.incremental.len());
        for message in &self.incremental {
            lines.push(self.backend.message_to_json(message)?);
        }
        Ok(lines.join("\n"))
    }
}
/// Replaces the incremental segment with messages parsed from line-delimited
/// JSON.
///
/// Each line is trimmed; blank lines are skipped; every remaining line is
/// parsed with `backend.message_from_json`.
///
/// # Errors
/// Returns the first parse error. The incremental segment is left unchanged
/// in that case because it is only replaced after all lines have parsed.
impl<B: ContextBackend> Message<RequestImportIncremental> for AgentContext<B> {
    type Reply = Result<(), AgentError>;

    async fn handle(
        &mut self,
        msg: RequestImportIncremental,
        _ctx: &mut Context<Self, Self::Reply>,
    ) -> Self::Reply {
        let mut messages = Vec::new();
        for line in msg.json.lines() {
            let line = line.trim();
            if line.is_empty() {
                continue;
            }
            messages.push(self.backend.message_from_json(line)?);
        }
        // Assigning by move drops the old contents; no `clear()` or clone needed.
        self.incremental = messages;
        Ok(())
    }
}