use std::sync::Arc;
use kameo::prelude::*;
use super::event::{ChangeEvent, CompressStrategy};
use super::stream::AgentSendStream;
use super::types::{ContextBackend, ScratchOpts};
use crate::error::AgentError;
use crate::message::ContextMessage;
use crate::readonly::ReadOnly;
use crate::role::Role;
/// Observer invoked whenever a handler in this module reports a [`ChangeEvent`].
type OnChangeFn<M> = Arc<dyn Fn(ChangeEvent<M>) + Send + Sync>;

/// Post-compression hook: receives `(summary, kept)` and returns the pair that
/// is actually stored as `(compressed, incremental)`.
type OnCompressedFn<M> = Arc<dyn Fn(Vec<M>, Vec<M>) -> (Vec<M>, Vec<M>) + Send + Sync>;

/// Actor holding a conversation context split into three ordered segments:
/// `immutable` (fixed at construction), `compressed` (summaries produced by
/// [`CompressMsg`]) and `incremental` (messages added afterwards).
///
/// Read-style queries observe the segments concatenated in that order.
#[derive(Actor)]
pub struct AgentContext<B: ContextBackend> {
    /// Backend used to send requests and to convert / (de)serialize messages.
    backend: B,
    /// Messages supplied at construction, wrapped in `ReadOnly`.
    immutable: ReadOnly<B::Message>,
    /// Summary messages that replaced older incremental history.
    compressed: Vec<B::Message>,
    /// Messages appended since construction or since the last compression.
    incremental: Vec<B::Message>,
    /// Optional observer notified when the incremental segment changes.
    on_change: Option<OnChangeFn<B::Message>>,
    /// Optional hook allowed to rewrite the result of a compression pass.
    on_compressed: Option<OnCompressedFn<B::Message>>,
}
impl<B: ContextBackend> AgentContext<B> {
    /// Creates a context over `backend`, seeding the read-only `immutable`
    /// segment with `immutable`.
    ///
    /// The compressed and incremental segments start empty and no callbacks
    /// are registered.
    pub fn new(backend: B, immutable: Vec<B::Message>) -> Self {
        Self {
            backend,
            immutable: ReadOnly::from(immutable),
            compressed: Vec::new(),
            incremental: Vec::new(),
            on_change: None,
            on_compressed: None,
        }
    }

    /// Registers `f` as the change observer, replacing any previous one.
    ///
    /// `f` receives a [`ChangeEvent`] after handlers mutate the incremental
    /// segment (append, update, insert, remove, pop, retain, clear, send, ...).
    /// The builder consumes `self`, so ignoring the return value discards the
    /// whole context.
    #[must_use = "the builder returns the configured context; dropping it discards it"]
    pub fn with_on_change(
        mut self,
        f: impl Fn(ChangeEvent<B::Message>) + Send + Sync + 'static,
    ) -> Self {
        self.on_change = Some(Arc::new(f));
        self
    }

    /// Registers `f` as the post-compression hook, replacing any previous one.
    ///
    /// After a successful summarization `f(summary, kept)` is invoked and the
    /// returned pair becomes the new `(compressed, incremental)` segments.
    #[must_use = "the builder returns the configured context; dropping it discards it"]
    pub fn with_on_compressed(
        mut self,
        f: impl Fn(Vec<B::Message>, Vec<B::Message>) -> (Vec<B::Message>, Vec<B::Message>)
            + Send
            + Sync
            + 'static,
    ) -> Self {
        self.on_compressed = Some(Arc::new(f));
        self
    }

    /// Default summarization instruction, used when
    /// `CompressStrategy::Summarize` carries no explicit prompt.
    fn default_summary_prompt() -> String {
        "请将以下对话历史压缩为简洁摘要,保留关键信息、决策和上下文。输出一条 system 消息。"
            .to_string()
    }
}
/// Command: append one message to the incremental segment and notify
/// `on_change` with `ChangeEvent::Appended`.
pub struct AppendMsg<M> {
    /// Message to append.
    pub message: M,
}

impl<B: ContextBackend> Message<AppendMsg<B::Message>> for AgentContext<B> {
    type Reply = ();

    async fn handle(
        &mut self,
        msg: AppendMsg<B::Message>,
        _ctx: &mut Context<Self, Self::Reply>,
    ) -> Self::Reply {
        let AppendMsg { message } = msg;
        match &self.on_change {
            Some(notify) => {
                // The observer gets its own copy; the original is stored.
                let appended = message.clone();
                self.incremental.push(message);
                notify(ChangeEvent::Appended(appended));
            }
            None => self.incremental.push(message),
        }
    }
}
/// Query: total number of messages across all three segments.
pub struct Len;

impl<B: ContextBackend> Message<Len> for AgentContext<B> {
    type Reply = usize;

    async fn handle(&mut self, _msg: Len, _ctx: &mut Context<Self, Self::Reply>) -> Self::Reply {
        let segment_lengths = [
            self.immutable.len(),
            self.compressed.len(),
            self.incremental.len(),
        ];
        segment_lengths.iter().sum()
    }
}
/// Query: `true` only when the immutable, compressed and incremental segments
/// are all empty.
pub struct IsEmpty;

impl<B: ContextBackend> Message<IsEmpty> for AgentContext<B> {
    type Reply = bool;

    async fn handle(
        &mut self,
        _msg: IsEmpty,
        _ctx: &mut Context<Self, Self::Reply>,
    ) -> Self::Reply {
        let any_populated = !self.immutable.is_empty()
            || !self.compressed.is_empty()
            || !self.incremental.is_empty();
        !any_populated
    }
}
/// Command: append several messages, in order, to the incremental segment.
/// Each appended message produces its own `ChangeEvent::Appended`.
pub struct ExtendMsg<M> {
    /// Messages to append, in order.
    pub messages: Vec<M>,
}

impl<B: ContextBackend> Message<ExtendMsg<B::Message>> for AgentContext<B> {
    type Reply = ();

    async fn handle(
        &mut self,
        msg: ExtendMsg<B::Message>,
        _ctx: &mut Context<Self, Self::Reply>,
    ) -> Self::Reply {
        let ExtendMsg { messages } = msg;
        for message in messages {
            if let Some(notify) = &self.on_change {
                let appended = message.clone();
                self.incremental.push(message);
                notify(ChangeEvent::Appended(appended));
            } else {
                self.incremental.push(message);
            }
        }
    }
}
/// Command: append one message to the incremental segment WITHOUT notifying
/// `on_change`.
pub struct SilentAppendMsg<M> {
    /// Message to append.
    pub message: M,
}

impl<B: ContextBackend> Message<SilentAppendMsg<B::Message>> for AgentContext<B> {
    type Reply = ();

    async fn handle(
        &mut self,
        msg: SilentAppendMsg<B::Message>,
        _ctx: &mut Context<Self, Self::Reply>,
    ) -> Self::Reply {
        let SilentAppendMsg { message } = msg;
        self.incremental.push(message);
    }
}
/// Query: clone of the message at a position in the concatenated
/// `immutable ++ compressed ++ incremental` view, or `None` if out of range.
pub struct Get(pub usize);

impl<B: ContextBackend> Message<Get> for AgentContext<B> {
    type Reply = Option<B::Message>;

    async fn handle(&mut self, msg: Get, _ctx: &mut Context<Self, Self::Reply>) -> Self::Reply {
        let Get(mut offset) = msg;
        // Walk the segments in order, rebasing the offset past each one.
        let immutable_len = self.immutable.len();
        if offset < immutable_len {
            return Some(self.immutable[offset].clone());
        }
        offset -= immutable_len;
        if let Some(found) = self.compressed.get(offset) {
            return Some(found.clone());
        }
        // `get` returned None, so offset >= compressed.len(): no underflow.
        offset -= self.compressed.len();
        self.incremental.get(offset).cloned()
    }
}
/// Query: clones of every message, in `immutable ++ compressed ++ incremental` order.
pub struct MessagesMsg;

impl<B: ContextBackend> Message<MessagesMsg> for AgentContext<B> {
    type Reply = Vec<B::Message>;

    async fn handle(
        &mut self,
        _msg: MessagesMsg,
        _ctx: &mut Context<Self, Self::Reply>,
    ) -> Self::Reply {
        let Self {
            immutable,
            compressed,
            incremental,
            ..
        } = &*self;
        immutable
            .iter()
            .chain(compressed)
            .chain(incremental)
            .cloned()
            .collect()
    }
}
/// Query: clones of the immutable (construction-time) segment.
pub struct ImmutableMsg;

impl<B: ContextBackend> Message<ImmutableMsg> for AgentContext<B> {
    type Reply = Vec<B::Message>;

    async fn handle(
        &mut self,
        _msg: ImmutableMsg,
        _ctx: &mut Context<Self, Self::Reply>,
    ) -> Self::Reply {
        self.immutable.iter().cloned().collect()
    }
}
/// Query: clones of the compressed (summary) segment.
pub struct CompressedMsg;

impl<B: ContextBackend> Message<CompressedMsg> for AgentContext<B> {
    type Reply = Vec<B::Message>;

    async fn handle(
        &mut self,
        _msg: CompressedMsg,
        _ctx: &mut Context<Self, Self::Reply>,
    ) -> Self::Reply {
        self.compressed.to_vec()
    }
}
/// Query: clones of the incremental segment.
pub struct IncrementalMsg;

impl<B: ContextBackend> Message<IncrementalMsg> for AgentContext<B> {
    type Reply = Vec<B::Message>;

    async fn handle(
        &mut self,
        _msg: IncrementalMsg,
        _ctx: &mut Context<Self, Self::Reply>,
    ) -> Self::Reply {
        self.incremental.to_vec()
    }
}
/// Query: clones of all messages (across every segment, in order) whose role
/// equals the given [`Role`].
pub struct FindByRoleMsg(pub Role);

impl<B: ContextBackend> Message<FindByRoleMsg> for AgentContext<B> {
    type Reply = Vec<B::Message>;

    async fn handle(
        &mut self,
        msg: FindByRoleMsg,
        _ctx: &mut Context<Self, Self::Reply>,
    ) -> Self::Reply {
        let FindByRoleMsg(role) = msg;
        let mut matches = Vec::new();
        for m in self
            .immutable
            .iter()
            .chain(&self.compressed)
            .chain(&self.incremental)
        {
            if m.role() == role {
                matches.push(m.clone());
            }
        }
        matches
    }
}
/// Command: replace the incremental message at `index`.
///
/// Unlike [`Get`], `index` is relative to the incremental segment only.
/// Emits `ChangeEvent::Updated` with the old and new values.
pub struct UpdateMsg<M> {
    /// Position within the incremental segment.
    pub index: usize,
    /// Replacement message.
    pub message: M,
}

impl<B: ContextBackend> Message<UpdateMsg<B::Message>> for AgentContext<B> {
    type Reply = Result<(), AgentError>;

    async fn handle(
        &mut self,
        msg: UpdateMsg<B::Message>,
        _ctx: &mut Context<Self, Self::Reply>,
    ) -> Self::Reply {
        let UpdateMsg { index, message } = msg;
        let Some(slot) = self.incremental.get_mut(index) else {
            return Err(AgentError::Context("索引越界".into()));
        };
        let old = std::mem::replace(slot, message);
        if let Some(notify) = &self.on_change {
            notify(ChangeEvent::Updated {
                index,
                old,
                new: slot.clone(),
            });
        }
        Ok(())
    }
}
/// Command: insert a message into the incremental segment at `index`
/// (`index == len` appends). Emits `ChangeEvent::Inserted`.
pub struct InsertMsg<M> {
    /// Position within the incremental segment; must be `<= len`.
    pub index: usize,
    /// Message to insert.
    pub message: M,
}

impl<B: ContextBackend> Message<InsertMsg<B::Message>> for AgentContext<B> {
    type Reply = Result<(), AgentError>;

    async fn handle(
        &mut self,
        msg: InsertMsg<B::Message>,
        _ctx: &mut Context<Self, Self::Reply>,
    ) -> Self::Reply {
        let InsertMsg { index, message } = msg;
        if index > self.incremental.len() {
            return Err(AgentError::Context("索引越界".into()));
        }
        // Take the observer's copy up front, only when someone is listening.
        let notification = self
            .on_change
            .as_ref()
            .map(|notify| (notify, message.clone()));
        self.incremental.insert(index, message);
        if let Some((notify, inserted)) = notification {
            notify(ChangeEvent::Inserted {
                index,
                message: inserted,
            });
        }
        Ok(())
    }
}
/// Command: remove the incremental message at `index` (relative to the
/// incremental segment). Emits `ChangeEvent::Removed` with the removed value.
pub struct RemoveMsg {
    /// Position within the incremental segment; must be `< len`.
    pub index: usize,
}

impl<B: ContextBackend> Message<RemoveMsg> for AgentContext<B> {
    type Reply = Result<(), AgentError>;

    async fn handle(
        &mut self,
        msg: RemoveMsg,
        _ctx: &mut Context<Self, Self::Reply>,
    ) -> Self::Reply {
        let RemoveMsg { index } = msg;
        if index < self.incremental.len() {
            let message = self.incremental.remove(index);
            if let Some(notify) = &self.on_change {
                notify(ChangeEvent::Removed { index, message });
            }
            Ok(())
        } else {
            Err(AgentError::Context("索引越界".into()))
        }
    }
}
/// Command: remove and return the last incremental message, if any.
/// Emits `ChangeEvent::Popped` only when something was removed.
pub struct PopMsg;

impl<B: ContextBackend> Message<PopMsg> for AgentContext<B> {
    type Reply = Option<B::Message>;

    async fn handle(&mut self, _msg: PopMsg, _ctx: &mut Context<Self, Self::Reply>) -> Self::Reply {
        let message = self.incremental.pop()?;
        if let Some(notify) = &self.on_change {
            notify(ChangeEvent::Popped(message.clone()));
        }
        Some(message)
    }
}
/// Command: keep only the incremental messages whose role equals `role`.
///
/// Every other incremental message is removed and reported, in original
/// order, via `ChangeEvent::Retained`. The event is emitted whenever an
/// observer is registered, even if nothing was removed.
pub struct RetainMsg {
    /// Role to keep.
    pub role: Role,
}

impl<B: ContextBackend> Message<RetainMsg> for AgentContext<B> {
    type Reply = ();

    async fn handle(
        &mut self,
        msg: RetainMsg,
        _ctx: &mut Context<Self, Self::Reply>,
    ) -> Self::Reply {
        let role = msg.role;
        // Move each message into exactly one side instead of cloning the
        // removed ones out of a `retain` closure; relative order is preserved.
        let (kept, removed): (Vec<B::Message>, Vec<B::Message>) =
            std::mem::take(&mut self.incremental)
                .into_iter()
                .partition(|m| m.role() == role);
        self.incremental = kept;
        if let Some(cb) = &self.on_change {
            cb(ChangeEvent::Retained { role, removed });
        }
    }
}
/// Command: remove all incremental messages.
///
/// A no-op (no event) when the incremental segment is already empty;
/// otherwise emits `ChangeEvent::Cleared` carrying the removed messages.
pub struct ClearMsg;

impl<B: ContextBackend> Message<ClearMsg> for AgentContext<B> {
    type Reply = ();

    async fn handle(
        &mut self,
        _msg: ClearMsg,
        _ctx: &mut Context<Self, Self::Reply>,
    ) -> Self::Reply {
        if self.incremental.is_empty() {
            return;
        }
        let removed = std::mem::take(&mut self.incremental);
        if let Some(notify) = &self.on_change {
            notify(ChangeEvent::Cleared { removed });
        }
    }
}
/// Command: compress older history according to `strategy`, using `opts`
/// for the backend request.
pub struct CompressMsg<O> {
    /// How to compress (currently `CompressStrategy::Summarize`).
    pub strategy: CompressStrategy,
    /// Backend options forwarded to the summarization request.
    pub opts: O,
}

impl<B: ContextBackend> Message<CompressMsg<B::Opts>> for AgentContext<B> {
    type Reply = ();

    /// `Summarize { keep, prompt }` sends `[system(prompt), compressed..,
    /// oldest (len - keep) incremental..]` to the backend. The converted reply
    /// becomes the new `compressed` segment, and the newest `keep` messages
    /// stay in `incremental`. If set, `on_compressed` may rewrite both.
    ///
    /// If any step fails, a warning is logged and the context is left exactly
    /// as it was. The request is built from clones, and the segments are only
    /// modified after the summary has been produced.
    async fn handle(
        &mut self,
        msg: CompressMsg<B::Opts>,
        _ctx: &mut Context<Self, Self::Reply>,
    ) -> Self::Reply {
        match msg.strategy {
            CompressStrategy::Summarize { keep, prompt } => {
                // Number of oldest incremental messages folded into the summary.
                let split = self.incremental.len().saturating_sub(keep);
                if split == 0 {
                    return;
                }

                let summary_prompt = prompt.unwrap_or_else(Self::default_summary_prompt);
                let mut request = Vec::with_capacity(1 + self.compressed.len() + split);
                request.push(self.backend.system_message(summary_prompt));
                request.extend(self.compressed.iter().cloned());
                request.extend(self.incremental[..split].iter().cloned());

                let response = match self.backend.send(&request, &msg.opts).await {
                    Ok(response) => response,
                    Err(_) => {
                        log::warn!("压缩摘要请求失败,已跳过");
                        return;
                    }
                };
                let Ok(raw_msgs) = self
                    .backend
                    .extract_messages_from_backend_response(std::slice::from_ref(&response))
                else {
                    log::warn!("压缩摘要提取消息失败,已跳过");
                    return;
                };
                let Ok(request_msgs) = self.backend.to_request_messages(raw_msgs) else {
                    log::warn!("压缩摘要转换请求格式失败,已跳过");
                    return;
                };
                let summary: Vec<B::Message> = request_msgs
                    .into_iter()
                    .map(|m| self.backend.to_system_message(m))
                    .collect();

                // Commit only now: drop the summarized prefix and keep the tail.
                self.incremental.drain(..split);
                let kept = std::mem::take(&mut self.incremental);
                let (final_summary, final_kept) = match &self.on_compressed {
                    Some(cb) => cb(summary, kept),
                    None => (summary, kept),
                };
                self.compressed = final_summary;
                self.incremental = final_kept;
            }
        }
    }
}
/// Command: send the whole context to the backend and record the reply.
pub struct SendMsg<O> {
    /// Backend options; an optional scratch text from them is appended to the
    /// request as a trailing system message but is not stored in the context.
    pub opts: O,
}

impl<B: ContextBackend> Message<SendMsg<B::Opts>> for AgentContext<B> {
    type Reply = Result<B::Response, AgentError>;

    /// Sends `immutable ++ compressed ++ incremental (++ scratch)` and appends
    /// the converted reply messages to the incremental segment. Each appended
    /// message produces `ChangeEvent::Appended`.
    ///
    /// # Errors
    /// Propagates failures from sending, extracting messages from the
    /// response, or converting them into request messages. Nothing is
    /// appended in those cases.
    async fn handle(
        &mut self,
        msg: SendMsg<B::Opts>,
        _ctx: &mut Context<Self, Self::Reply>,
    ) -> Self::Reply {
        let scratch = msg.opts.scratch().map(|s| s.to_string());
        let mut outgoing: Vec<B::Message> = self
            .immutable
            .iter()
            .chain(self.compressed.iter())
            .chain(self.incremental.iter())
            .cloned()
            .collect();
        if let Some(content) = scratch {
            outgoing.push(self.backend.system_message(content));
        }

        let response = self.backend.send(&outgoing, &msg.opts).await?;
        let raw_msgs = self
            .backend
            .extract_messages_from_backend_response(std::slice::from_ref(&response))?;
        let reply_msgs = self.backend.to_request_messages(raw_msgs)?;

        // Move each reply into the context; clone only when an observer needs a copy.
        for reply in reply_msgs {
            match &self.on_change {
                Some(cb) => {
                    let appended = reply.clone();
                    self.incremental.push(reply);
                    cb(ChangeEvent::Appended(appended));
                }
                None => self.incremental.push(reply),
            }
        }
        Ok(response)
    }
}
/// Command: start a streaming request over the whole context.
pub struct SendStreamMsg<O> {
    /// Backend options; an optional scratch text is appended to the request
    /// as a trailing system message.
    pub opts: O,
}

impl<B: ContextBackend + Clone> Message<SendStreamMsg<B::Opts>> for AgentContext<B> {
    type Reply = AgentSendStream<B>;

    /// Builds `immutable ++ compressed ++ incremental (++ scratch)` and hands the
    /// backend stream to an [`AgentSendStream`]. The stream also receives a clone
    /// of the backend, this actor's ref and the `on_change` callback (presumably
    /// so it can record the streamed reply later — see `AgentSendStream`).
    async fn handle(
        &mut self,
        msg: SendStreamMsg<B::Opts>,
        ctx: &mut Context<Self, Self::Reply>,
    ) -> Self::Reply {
        let SendStreamMsg { opts } = msg;
        let scratch = opts.scratch().map(|s| s.to_string());

        let mut outgoing: Vec<B::Message> = Vec::new();
        outgoing.extend(self.immutable.iter().cloned());
        outgoing.extend(self.compressed.iter().cloned());
        outgoing.extend(self.incremental.iter().cloned());
        outgoing.extend(scratch.map(|content| self.backend.system_message(content)));

        let stream = self.backend.send_stream(outgoing, opts);
        AgentSendStream::new(
            self.backend.clone(),
            stream,
            ctx.actor_ref().clone(),
            self.on_change.clone(),
        )
    }
}
/// Query: backend token estimate for the whole context.
///
/// Falls back to `0` when the backend cannot produce an estimate.
pub struct EstimateTokensMsg;

impl<B: ContextBackend> Message<EstimateTokensMsg> for AgentContext<B> {
    type Reply = usize;

    async fn handle(
        &mut self,
        _msg: EstimateTokensMsg,
        _ctx: &mut Context<Self, Self::Reply>,
    ) -> Self::Reply {
        let snapshot: Vec<B::Message> = self
            .immutable
            .iter()
            .chain(&self.compressed)
            .chain(&self.incremental)
            .cloned()
            .collect();
        let estimate = self.backend.estimate_tokens(&snapshot).await;
        estimate.unwrap_or(0)
    }
}
/// Query: whether the estimated token count reaches the backend's context window.
///
/// If the estimate is unavailable it is treated as `usize::MAX`, so the
/// context is reported as full (conservative, unlike [`EstimateTokensMsg`]'s 0).
pub struct IsFullMsg;

impl<B: ContextBackend> Message<IsFullMsg> for AgentContext<B> {
    type Reply = bool;

    async fn handle(
        &mut self,
        _msg: IsFullMsg,
        _ctx: &mut Context<Self, Self::Reply>,
    ) -> Self::Reply {
        let snapshot: Vec<B::Message> = self
            .immutable
            .iter()
            .chain(&self.compressed)
            .chain(&self.incremental)
            .cloned()
            .collect();
        let used = self
            .backend
            .estimate_tokens(&snapshot)
            .await
            .unwrap_or(usize::MAX);
        self.backend.context_window() <= used
    }
}
/// Query: serialize every message (all segments, in order) as JSON Lines,
/// one message per line, joined by `\n` without a trailing newline.
pub struct ToJsonlMsg;

impl<B: ContextBackend> Message<ToJsonlMsg> for AgentContext<B> {
    type Reply = Result<String, AgentError>;

    /// # Errors
    /// Returns the first serialization error reported by the backend.
    async fn handle(
        &mut self,
        _msg: ToJsonlMsg,
        _ctx: &mut Context<Self, Self::Reply>,
    ) -> Self::Reply {
        let mut jsonl = String::new();
        let all = self
            .immutable
            .iter()
            .chain(&self.compressed)
            .chain(&self.incremental);
        for (position, message) in all.enumerate() {
            if position > 0 {
                jsonl.push('\n');
            }
            jsonl.push_str(&self.backend.message_to_jsonl(message)?);
        }
        Ok(jsonl)
    }
}
/// Command: parse JSON Lines and append each message to the incremental segment.
pub struct FromJsonlMsg {
    /// JSON Lines text; blank (whitespace-only) lines are ignored.
    pub jsonl: String,
}

impl<B: ContextBackend> Message<FromJsonlMsg> for AgentContext<B> {
    type Reply = Result<(), AgentError>;

    /// Every non-blank line is parsed before anything is appended. A malformed
    /// line therefore leaves the context untouched and emits no events. On
    /// success each message is appended in order with its own
    /// `ChangeEvent::Appended`.
    ///
    /// # Errors
    /// Returns the first parse error reported by the backend.
    async fn handle(
        &mut self,
        msg: FromJsonlMsg,
        _ctx: &mut Context<Self, Self::Reply>,
    ) -> Self::Reply {
        let parsed = msg
            .jsonl
            .lines()
            .map(str::trim)
            .filter(|line| !line.is_empty())
            .map(|line| self.backend.message_from_jsonl(line))
            .collect::<Result<Vec<B::Message>, _>>()?;

        for message in parsed {
            match &self.on_change {
                Some(cb) => {
                    let appended = message.clone();
                    self.incremental.push(message);
                    cb(ChangeEvent::Appended(appended));
                }
                None => self.incremental.push(message),
            }
        }
        Ok(())
    }
}