1use std::sync::Arc;
6
7use kameo::prelude::*;
8
9use super::event::{ChangeEvent, CompressStrategy};
10use super::stream::AgentSendStream;
11use super::types::ContextBackend;
12use crate::error::AgentError;
13use crate::message::ContextMessage;
14use crate::readonly::ReadOnly;
15use crate::role::Role;
16
/// Conversation-context actor that owns a message history and the backend used to talk
/// to the model.
///
/// The history is stored in three segments that are always read in this order:
/// `immutable`, then `compressed`, then `incremental`.
#[derive(Actor)]
pub struct AgentContext<B: ContextBackend> {
    /// Backend used to send requests, estimate tokens and convert/serialize messages.
    backend: B,
    /// Fixed prefix supplied at construction; no handler in this module mutates it.
    immutable: ReadOnly<B::Message>,
    /// Summary messages produced by compression; replaced wholesale on each compression.
    compressed: Vec<B::Message>,
    /// Messages added after the last compression; the append/edit handlers operate on
    /// this segment only.
    incremental: Vec<B::Message>,
    /// Optional hook invoked with a `ChangeEvent` when a handler edits `incremental`.
    /// Compression does not call it (see `on_compressed`).
    // The lint reason below says: "the callback type is unavoidably complex".
    #[expect(clippy::type_complexity, reason = "回调类型不可避免复杂")]
    on_change: Option<Arc<dyn Fn(ChangeEvent<B::Message>) + Send + Sync>>,
    /// Optional hook called after a successful summarization with `(summary, kept)`;
    /// the pair it returns is stored as `(compressed, incremental)`.
    // The lint reason below says: "the callback type is unavoidably complex".
    #[expect(clippy::type_complexity, reason = "回调类型不可避免复杂")]
    on_compressed: Option<
        Arc<
            dyn Fn(Vec<B::Message>, Vec<B::Message>) -> (Vec<B::Message>, Vec<B::Message>)
                + Send
                + Sync,
        >,
    >,
}
54
55impl<B: ContextBackend> AgentContext<B> {
56 pub fn new(backend: B, immutable: Vec<B::Message>) -> Self {
61 Self {
62 backend,
63 immutable: ReadOnly::from(immutable),
64 compressed: Vec::new(),
65 incremental: Vec::new(),
66 on_change: None,
67 on_compressed: None,
68 }
69 }
70
71 pub fn with_on_change(
76 mut self,
77 f: impl Fn(ChangeEvent<B::Message>) + Send + Sync + 'static,
78 ) -> Self {
79 self.on_change = Some(Arc::new(f));
80 self
81 }
82
83 pub fn with_on_compressed(
89 mut self,
90 f: impl Fn(Vec<B::Message>, Vec<B::Message>) -> (Vec<B::Message>, Vec<B::Message>)
91 + Send
92 + Sync
93 + 'static,
94 ) -> Self {
95 self.on_compressed = Some(Arc::new(f));
96 self
97 }
98
99 fn default_summary_prompt() -> String {
100 "请将以下对话历史压缩为简洁摘要,保留关键信息、决策和上下文。输出一条 system 消息。"
101 .to_string()
102 }
103
104 async fn compress_if_full(&mut self, opts: &B::Opts) -> Result<(), AgentError> {
106 let common = opts.as_ref();
107 let all: Vec<B::Message> = self
108 .immutable
109 .iter()
110 .chain(self.compressed.iter())
111 .chain(self.incremental.iter())
112 .cloned()
113 .collect();
114 let tokens = self
115 .backend
116 .estimate_tokens(&all)
117 .await
118 .unwrap_or(usize::MAX);
119 if tokens < common.context_window {
120 return Ok(());
121 }
122 if !common.auto_compress {
123 return Err(AgentError::Context("上下文已满且未启用自动压缩".into()));
124 }
125 let total = self.incremental.len();
126 let keep = total / 2;
127 if total <= keep {
128 return Ok(());
129 }
130 let split = total - keep;
131 let to_summarize: Vec<B::Message> = self.incremental.drain(..split).collect();
132 if to_summarize.is_empty() {
133 return Ok(());
134 }
135 let mut summary_messages = vec![self.backend.system_message(Self::default_summary_prompt())];
136 summary_messages.append(&mut self.compressed);
137 summary_messages.extend(to_summarize);
138 let response = self.backend.send(&summary_messages, opts).await?;
139 let raw_msgs = self
140 .backend
141 .extract_messages(std::slice::from_ref(&response))?;
142 let request_msgs = self.backend.to_request_messages(raw_msgs)?;
143 let summary: Vec<B::Message> = request_msgs
144 .into_iter()
145 .map(|msg| self.backend.to_system_message(msg))
146 .collect();
147 let kept: Vec<B::Message> = self.incremental.drain(..).collect();
148 let (final_summary, final_kept) =
149 if let Some(cb) = &self.on_compressed {
150 cb(summary, kept)
151 } else {
152 (summary, kept)
153 };
154 self.compressed = final_summary;
155 self.incremental = final_kept;
156 Ok(())
157 }
158}
159
160pub struct AppendMsg<M> {
168 pub message: M,
170}
171
172impl<B: ContextBackend> Message<AppendMsg<B::Message>> for AgentContext<B> {
173 type Reply = ();
174
175 async fn handle(
176 &mut self,
177 msg: AppendMsg<B::Message>,
178 _ctx: &mut Context<Self, Self::Reply>,
179 ) -> Self::Reply {
180 self.incremental.push(msg.message);
181 if let Some(cb) = &self.on_change
182 && let Some(last) = self.incremental.last().cloned()
183 {
184 cb(ChangeEvent::Appended(last));
185 }
186 }
187}
188
189pub struct Len;
191
192impl<B: ContextBackend> Message<Len> for AgentContext<B> {
193 type Reply = usize;
194
195 async fn handle(&mut self, _msg: Len, _ctx: &mut Context<Self, Self::Reply>) -> Self::Reply {
196 self.immutable.len() + self.compressed.len() + self.incremental.len()
197 }
198}
199
200pub struct IsEmpty;
202
203impl<B: ContextBackend> Message<IsEmpty> for AgentContext<B> {
204 type Reply = bool;
205
206 async fn handle(
207 &mut self,
208 _msg: IsEmpty,
209 _ctx: &mut Context<Self, Self::Reply>,
210 ) -> Self::Reply {
211 self.immutable.is_empty() && self.compressed.is_empty() && self.incremental.is_empty()
212 }
213}
214
215pub struct ExtendMsg<M> {
219 pub messages: Vec<M>,
221}
222
223impl<B: ContextBackend> Message<ExtendMsg<B::Message>> for AgentContext<B> {
224 type Reply = ();
225
226 async fn handle(
227 &mut self,
228 msg: ExtendMsg<B::Message>,
229 _ctx: &mut Context<Self, Self::Reply>,
230 ) -> Self::Reply {
231 for m in msg.messages {
232 self.incremental.push(m);
233 if let Some(cb) = &self.on_change
234 && let Some(last) = self.incremental.last().cloned()
235 {
236 cb(ChangeEvent::Appended(last));
237 }
238 }
239 }
240}
241
242pub struct Get(pub usize);
247
248impl<B: ContextBackend> Message<Get> for AgentContext<B> {
249 type Reply = Option<B::Message>;
250
251 async fn handle(&mut self, msg: Get, _ctx: &mut Context<Self, Self::Reply>) -> Self::Reply {
252 let idx = msg.0;
253 let imm_len = self.immutable.len();
254 let comp_len = self.compressed.len();
255 if idx < imm_len {
256 Some(self.immutable[idx].clone())
257 } else if idx < imm_len + comp_len {
258 Some(self.compressed[idx - imm_len].clone())
259 } else if idx < imm_len + comp_len + self.incremental.len() {
260 Some(self.incremental[idx - imm_len - comp_len].clone())
261 } else {
262 None
263 }
264 }
265}
266
267pub struct MessagesMsg;
271
272impl<B: ContextBackend> Message<MessagesMsg> for AgentContext<B> {
273 type Reply = Vec<B::Message>;
274
275 async fn handle(
276 &mut self,
277 _msg: MessagesMsg,
278 _ctx: &mut Context<Self, Self::Reply>,
279 ) -> Self::Reply {
280 self.immutable
281 .iter()
282 .chain(self.compressed.iter())
283 .chain(self.incremental.iter())
284 .cloned()
285 .collect()
286 }
287}
288
289pub struct ImmutableMsg;
291
292impl<B: ContextBackend> Message<ImmutableMsg> for AgentContext<B> {
293 type Reply = Vec<B::Message>;
294
295 async fn handle(
296 &mut self,
297 _msg: ImmutableMsg,
298 _ctx: &mut Context<Self, Self::Reply>,
299 ) -> Self::Reply {
300 self.immutable.to_vec()
301 }
302}
303
304pub struct CompressedMsg;
306
307impl<B: ContextBackend> Message<CompressedMsg> for AgentContext<B> {
308 type Reply = Vec<B::Message>;
309
310 async fn handle(
311 &mut self,
312 _msg: CompressedMsg,
313 _ctx: &mut Context<Self, Self::Reply>,
314 ) -> Self::Reply {
315 self.compressed.clone()
316 }
317}
318
319pub struct IncrementalMsg;
321
322impl<B: ContextBackend> Message<IncrementalMsg> for AgentContext<B> {
323 type Reply = Vec<B::Message>;
324
325 async fn handle(
326 &mut self,
327 _msg: IncrementalMsg,
328 _ctx: &mut Context<Self, Self::Reply>,
329 ) -> Self::Reply {
330 self.incremental.clone()
331 }
332}
333
334pub struct FindByRoleMsg(pub Role);
336
337impl<B: ContextBackend> Message<FindByRoleMsg> for AgentContext<B> {
338 type Reply = Vec<B::Message>;
339
340 async fn handle(
341 &mut self,
342 msg: FindByRoleMsg,
343 _ctx: &mut Context<Self, Self::Reply>,
344 ) -> Self::Reply {
345 self.immutable
346 .iter()
347 .chain(self.compressed.iter())
348 .chain(self.incremental.iter())
349 .filter(|m| m.role() == msg.0)
350 .cloned()
351 .collect()
352 }
353}
354
355pub struct UpdateMsg<M> {
359 pub index: usize,
361 pub message: M,
363}
364
365impl<B: ContextBackend> Message<UpdateMsg<B::Message>> for AgentContext<B> {
366 type Reply = Result<(), AgentError>;
367
368 async fn handle(
369 &mut self,
370 msg: UpdateMsg<B::Message>,
371 _ctx: &mut Context<Self, Self::Reply>,
372 ) -> Self::Reply {
373 if msg.index >= self.incremental.len() {
374 return Err(AgentError::Context("索引越界".into()));
375 }
376 let old = std::mem::replace(&mut self.incremental[msg.index], msg.message);
377 if let Some(cb) = &self.on_change {
378 cb(ChangeEvent::Updated {
379 index: msg.index,
380 old,
381 new: self.incremental[msg.index].clone(),
382 });
383 }
384 Ok(())
385 }
386}
387
388pub struct InsertMsg<M> {
392 pub index: usize,
394 pub message: M,
396}
397
398impl<B: ContextBackend> Message<InsertMsg<B::Message>> for AgentContext<B> {
399 type Reply = Result<(), AgentError>;
400
401 async fn handle(
402 &mut self,
403 msg: InsertMsg<B::Message>,
404 _ctx: &mut Context<Self, Self::Reply>,
405 ) -> Self::Reply {
406 if msg.index > self.incremental.len() {
407 return Err(AgentError::Context("索引越界".into()));
408 }
409 self.incremental.insert(msg.index, msg.message);
410 if let Some(cb) = &self.on_change {
411 cb(ChangeEvent::Inserted {
412 index: msg.index,
413 message: self.incremental[msg.index].clone(),
414 });
415 }
416 Ok(())
417 }
418}
419
420pub struct RemoveMsg {
424 pub index: usize,
426}
427
428impl<B: ContextBackend> Message<RemoveMsg> for AgentContext<B> {
429 type Reply = Result<(), AgentError>;
430
431 async fn handle(
432 &mut self,
433 msg: RemoveMsg,
434 _ctx: &mut Context<Self, Self::Reply>,
435 ) -> Self::Reply {
436 if msg.index >= self.incremental.len() {
437 return Err(AgentError::Context("索引越界".into()));
438 }
439 let removed = self.incremental.remove(msg.index);
440 if let Some(cb) = &self.on_change {
441 cb(ChangeEvent::Removed {
442 index: msg.index,
443 message: removed,
444 });
445 }
446 Ok(())
447 }
448}
449
450pub struct PopMsg;
454
455impl<B: ContextBackend> Message<PopMsg> for AgentContext<B> {
456 type Reply = Option<B::Message>;
457
458 async fn handle(&mut self, _msg: PopMsg, _ctx: &mut Context<Self, Self::Reply>) -> Self::Reply {
459 let popped = self.incremental.pop();
460 if let Some(ref msg) = popped
461 && let Some(cb) = &self.on_change
462 {
463 cb(ChangeEvent::Popped(msg.clone()));
464 }
465 popped
466 }
467}
468
469pub struct RetainMsg {
473 pub role: Role,
475}
476
477impl<B: ContextBackend> Message<RetainMsg> for AgentContext<B> {
478 type Reply = ();
479
480 async fn handle(
481 &mut self,
482 msg: RetainMsg,
483 _ctx: &mut Context<Self, Self::Reply>,
484 ) -> Self::Reply {
485 let mut removed = Vec::new();
486 let role = msg.role;
487 self.incremental.retain(|m| {
488 if m.role() == role {
489 true
490 } else {
491 removed.push(m.clone());
492 false
493 }
494 });
495 if let Some(cb) = &self.on_change {
496 cb(ChangeEvent::Retained { role, removed });
497 }
498 }
499}
500
501pub struct ClearMsg;
505
506impl<B: ContextBackend> Message<ClearMsg> for AgentContext<B> {
507 type Reply = ();
508
509 async fn handle(
510 &mut self,
511 _msg: ClearMsg,
512 _ctx: &mut Context<Self, Self::Reply>,
513 ) -> Self::Reply {
514 if !self.incremental.is_empty() {
515 let removed = std::mem::take(&mut self.incremental);
516 if let Some(cb) = &self.on_change {
517 cb(ChangeEvent::Cleared { removed });
518 }
519 }
520 }
521}
522
523pub struct CompressMsg<O> {
529 pub strategy: CompressStrategy,
531 pub opts: O,
533}
534
535impl<B: ContextBackend> Message<CompressMsg<B::Opts>> for AgentContext<B> {
536 type Reply = Result<(), AgentError>;
537
538 async fn handle(
539 &mut self,
540 msg: CompressMsg<B::Opts>,
541 _ctx: &mut Context<Self, Self::Reply>,
542 ) -> Self::Reply {
543 match msg.strategy {
544 CompressStrategy::Summarize { keep, prompt } => {
545 let total = self.incremental.len();
546 if total > keep {
547 let split = total - keep;
548 let to_summarize: Vec<B::Message> = self.incremental.drain(..split).collect();
549 if !to_summarize.is_empty() {
550 let summary_prompt = prompt.unwrap_or_else(Self::default_summary_prompt);
551 let mut summary_messages =
552 vec![self.backend.system_message(summary_prompt)];
553 summary_messages.append(&mut self.compressed);
554 summary_messages.extend(to_summarize);
555 let response = self.backend.send(&summary_messages, &msg.opts).await?;
556 let raw_msgs = self.backend.extract_messages(
557 std::slice::from_ref(&response),
558 )?;
559 let request_msgs = self.backend.to_request_messages(raw_msgs)?;
560 let summary: Vec<B::Message> = request_msgs
561 .into_iter()
562 .map(|msg| self.backend.to_system_message(msg))
563 .collect();
564 let kept: Vec<B::Message> =
565 self.incremental.drain(..).collect();
566 let (final_summary, final_kept) =
567 if let Some(cb) = &self.on_compressed {
568 cb(summary, kept)
569 } else {
570 (summary, kept)
571 };
572 self.compressed = final_summary;
573 self.incremental = final_kept;
574 }
575 }
576 Ok(())
577 }
578 }
579 }
580}
581
582pub struct SendMsg<O> {
588 pub opts: O,
590}
591
592impl<B: ContextBackend> Message<SendMsg<B::Opts>> for AgentContext<B> {
593 type Reply = Result<B::Response, AgentError>;
594
595 async fn handle(
596 &mut self,
597 msg: SendMsg<B::Opts>,
598 _ctx: &mut Context<Self, Self::Reply>,
599 ) -> Self::Reply {
600 self.compress_if_full(&msg.opts).await?;
601 let scratch = msg.opts.as_ref().scratch.clone();
602 let mut all_messages: Vec<B::Message> = self
603 .immutable
604 .iter()
605 .chain(self.compressed.iter())
606 .chain(self.incremental.iter())
607 .cloned()
608 .collect();
609 if let Some(content) = scratch {
610 all_messages.push(self.backend.system_message(content));
611 }
612 let response = self.backend.send(&all_messages, &msg.opts).await?;
613 let raw_msgs = self
614 .backend
615 .extract_messages(std::slice::from_ref(&response))?;
616 let request_msgs = self.backend.to_request_messages(raw_msgs)?;
617 for msg in &request_msgs {
618 self.incremental.push(msg.clone());
619 if let Some(cb) = &self.on_change {
620 cb(ChangeEvent::Appended(msg.clone()));
621 }
622 }
623 Ok(response)
624 }
625}
626
627pub struct SendStreamMsg<O> {
632 pub opts: O,
634}
635
636impl<B: ContextBackend + Clone> Message<SendStreamMsg<B::Opts>> for AgentContext<B> {
637 type Reply = Result<AgentSendStream<B>, AgentError>;
638
639 async fn handle(
640 &mut self,
641 msg: SendStreamMsg<B::Opts>,
642 _ctx: &mut Context<Self, Self::Reply>,
643 ) -> Self::Reply {
644 self.compress_if_full(&msg.opts).await?;
645 let scratch = msg.opts.as_ref().scratch.clone();
646 let mut all_messages: Vec<B::Message> = self
647 .immutable
648 .iter()
649 .chain(self.compressed.iter())
650 .chain(self.incremental.iter())
651 .cloned()
652 .collect();
653 if let Some(content) = scratch {
654 all_messages.push(self.backend.system_message(content));
655 }
656 let stream = self.backend.send_stream(all_messages, msg.opts);
657 Ok(AgentSendStream::new(stream))
658 }
659}
660
661pub struct EstimateTokensMsg;
666
667impl<B: ContextBackend> Message<EstimateTokensMsg> for AgentContext<B> {
668 type Reply = usize;
669
670 async fn handle(
671 &mut self,
672 _msg: EstimateTokensMsg,
673 _ctx: &mut Context<Self, Self::Reply>,
674 ) -> Self::Reply {
675 let all: Vec<B::Message> = self
676 .immutable
677 .iter()
678 .chain(self.compressed.iter())
679 .chain(self.incremental.iter())
680 .cloned()
681 .collect();
682 self.backend.estimate_tokens(&all).await.unwrap_or(0)
683 }
684}
685
686pub struct ToJsonlMsg;
690
691impl<B: ContextBackend> Message<ToJsonlMsg> for AgentContext<B> {
692 type Reply = Result<String, AgentError>;
693
694 async fn handle(
695 &mut self,
696 _msg: ToJsonlMsg,
697 _ctx: &mut Context<Self, Self::Reply>,
698 ) -> Self::Reply {
699 let lines: Vec<String> = self
700 .immutable
701 .iter()
702 .chain(self.compressed.iter())
703 .chain(self.incremental.iter())
704 .map(|m| self.backend.message_to_jsonl(m))
705 .collect::<Result<_, _>>()?;
706 Ok(lines.join("\n"))
707 }
708}
709
710pub struct FromJsonlMsg {
715 pub jsonl: String,
717}
718
719impl<B: ContextBackend> Message<FromJsonlMsg> for AgentContext<B> {
720 type Reply = Result<(), AgentError>;
721
722 async fn handle(
723 &mut self,
724 msg: FromJsonlMsg,
725 _ctx: &mut Context<Self, Self::Reply>,
726 ) -> Self::Reply {
727 for line in msg.jsonl.lines() {
728 let line = line.trim();
729 if line.is_empty() {
730 continue;
731 }
732 let message: B::Message = self.backend.message_from_jsonl(line)?;
733 self.incremental.push(message.clone());
734 if let Some(ref cb) = self.on_change {
735 cb(ChangeEvent::Appended(message));
736 }
737 }
738 Ok(())
739 }
740}