1use kameo::prelude::*;
6
7use super::event::{
8 CompressStrategy, NotifyChange, NotifyCompressedForReply, RequestAppend, RequestClear,
9 RequestCompress, RequestCompressed, RequestEstimateTokens, RequestExtend, RequestFindByRole,
10 RequestFromJsonl, RequestGet, RequestImmutable, RequestIncremental, RequestInsert,
11 RequestIsEmpty, RequestLen, RequestMessages, RequestPop, RequestRemove, RequestRetain,
12 RequestSend, RequestSendStream, RequestToJsonl, RequestUpdate,
13};
14use super::stream::AgentSendStream;
15use super::types::ContextBackend;
16use crate::error::AgentError;
17
/// Reply produced by a compression editor: a `(summary, kept)` pair of message lists.
/// `AgentContext` stores the first list as its compressed segment and the second as
/// its incremental segment.
type CompressEditorReply<M> = (Vec<M>, Vec<M>);
/// Recipient that is asked with a `NotifyCompressedForReply` and answers with a
/// `CompressEditorReply`.
type CompressEditorRecipient<M> = ReplyRecipient<NotifyCompressedForReply<M>, CompressEditorReply<M>>;
20use crate::message::ContextMessage;
21use crate::readonly::ReadOnly;
22
/// Actor that owns a conversation context split into three ordered segments.
///
/// Whenever the full message list is needed it is assembled as `immutable`, then
/// `compressed`, then `incremental`.
#[derive(Actor)]
pub struct AgentContext<B: ContextBackend> {
    /// Backend used to send requests, estimate tokens and convert messages.
    backend: B,
    /// Messages supplied at construction time; none of the handlers visible here modify them.
    immutable: ReadOnly<B::Message>,
    /// Summary messages produced by the latest compression (empty until one happens).
    compressed: Vec<B::Message>,
    /// Live messages; the segment the editing requests (append, insert, remove, ...) operate on.
    incremental: Vec<B::Message>,
    /// Recipients that receive `NotifyChange` events emitted by the editing handlers.
    subscribers: Vec<Recipient<NotifyChange<B::Message>>>,
    /// Optional editor asked to adjust `(summary, kept)` before a compression result is stored.
    on_compressed: Option<CompressEditorRecipient<B::Message>>,
}
52
53impl<B: ContextBackend> AgentContext<B> {
54 pub fn new(backend: B, immutable: Vec<B::Message>) -> Self {
59 Self {
60 backend,
61 immutable: ReadOnly::from(immutable),
62 compressed: Vec::new(),
63 incremental: Vec::new(),
64 subscribers: Vec::new(),
65 on_compressed: None,
66 }
67 }
68
69 pub fn subscribe_change(mut self, recipient: Recipient<NotifyChange<B::Message>>) -> Self {
74 self.subscribers.push(recipient);
75 self
76 }
77
78 pub fn subscribe_compressed(mut self, recipient: CompressEditorRecipient<B::Message>) -> Self {
84 self.on_compressed = Some(recipient);
85 self
86 }
87
88 fn default_summary_prompt() -> String {
89 "请将以下对话历史压缩为简洁摘要,保留关键信息、决策和上下文。输出一条 system 消息。"
90 .to_string()
91 }
92
93 async fn notify_change(&self, event: NotifyChange<B::Message>) {
94 for subscriber in &self.subscribers {
95 if let Err(e) = subscriber.tell(event.clone()).send().await {
96 unreachable!("通知订阅者失败: {e:?}");
97 }
98 }
99 }
100
101 async fn compress_if_full(&mut self, opts: &B::Opts) -> Result<(), AgentError> {
103 let common = opts.as_ref();
104 let all: Vec<B::Message> = self
105 .immutable
106 .iter()
107 .chain(self.compressed.iter())
108 .chain(self.incremental.iter())
109 .cloned()
110 .collect();
111 let tokens = self
112 .backend
113 .estimate_tokens(&all)
114 .await
115 .unwrap_or(usize::MAX);
116 if tokens < common.context_window {
117 return Ok(());
118 }
119 if !common.auto_compress {
120 return Err(AgentError::Context("上下文已满且未启用自动压缩".into()));
121 }
122 let total = self.incremental.len();
123 let keep = total / 2;
124 if total <= keep {
125 return Ok(());
126 }
127 let split = total - keep;
128 let to_summarize: Vec<B::Message> = self.incremental.drain(..split).collect();
129 if to_summarize.is_empty() {
130 return Ok(());
131 }
132 let mut summary_messages =
133 vec![self.backend.system_message(Self::default_summary_prompt())];
134 summary_messages.append(&mut self.compressed);
135 summary_messages.extend(to_summarize);
136 let response = self.backend.send(&summary_messages, opts).await?;
137 let raw_msgs = self
138 .backend
139 .extract_messages(std::slice::from_ref(&response))?;
140 let request_msgs = self.backend.to_request_messages(raw_msgs)?;
141 let summary: Vec<B::Message> = request_msgs
142 .into_iter()
143 .map(|msg| self.backend.to_system_message(msg))
144 .collect();
145 let kept: Vec<B::Message> = self.incremental.drain(..).collect();
146 let (final_summary, final_kept) = if let Some(editor) = &self.on_compressed {
147 editor
148 .ask(NotifyCompressedForReply { summary, kept })
149 .send()
150 .await
151 .map_err(|e| AgentError::Context(e.to_string()))?
152 } else {
153 (summary, kept)
154 };
155 self.compressed = final_summary;
156 self.incremental = final_kept;
157 Ok(())
158 }
159}
160
161impl<B: ContextBackend> Message<RequestAppend<B::Message>> for AgentContext<B> {
166 type Reply = ();
167
168 async fn handle(
169 &mut self,
170 msg: RequestAppend<B::Message>,
171 _ctx: &mut Context<Self, Self::Reply>,
172 ) -> Self::Reply {
173 self.incremental.push(msg.message);
174 if let Some(last) = self.incremental.last().cloned() {
175 self.notify_change(NotifyChange::Appended(last)).await;
176 }
177 }
178}
179
180impl<B: ContextBackend> Message<RequestExtend<B::Message>> for AgentContext<B> {
181 type Reply = ();
182
183 async fn handle(
184 &mut self,
185 msg: RequestExtend<B::Message>,
186 _ctx: &mut Context<Self, Self::Reply>,
187 ) -> Self::Reply {
188 for m in msg.messages {
189 self.incremental.push(m);
190 if let Some(last) = self.incremental.last().cloned() {
191 self.notify_change(NotifyChange::Appended(last)).await;
192 }
193 }
194 }
195}
196
197impl<B: ContextBackend> Message<RequestUpdate<B::Message>> for AgentContext<B> {
198 type Reply = Result<(), AgentError>;
199
200 async fn handle(
201 &mut self,
202 msg: RequestUpdate<B::Message>,
203 _ctx: &mut Context<Self, Self::Reply>,
204 ) -> Self::Reply {
205 if msg.index >= self.incremental.len() {
206 return Err(AgentError::Context("索引越界".into()));
207 }
208 let old = std::mem::replace(&mut self.incremental[msg.index], msg.message);
209 self.notify_change(NotifyChange::Updated {
210 index: msg.index,
211 old,
212 new: self.incremental[msg.index].clone(),
213 })
214 .await;
215 Ok(())
216 }
217}
218
219impl<B: ContextBackend> Message<RequestInsert<B::Message>> for AgentContext<B> {
220 type Reply = Result<(), AgentError>;
221
222 async fn handle(
223 &mut self,
224 msg: RequestInsert<B::Message>,
225 _ctx: &mut Context<Self, Self::Reply>,
226 ) -> Self::Reply {
227 if msg.index > self.incremental.len() {
228 return Err(AgentError::Context("索引越界".into()));
229 }
230 self.incremental.insert(msg.index, msg.message);
231 self.notify_change(NotifyChange::Inserted {
232 index: msg.index,
233 message: self.incremental[msg.index].clone(),
234 })
235 .await;
236 Ok(())
237 }
238}
239
240impl<B: ContextBackend> Message<RequestRemove> for AgentContext<B> {
241 type Reply = Result<(), AgentError>;
242
243 async fn handle(
244 &mut self,
245 msg: RequestRemove,
246 _ctx: &mut Context<Self, Self::Reply>,
247 ) -> Self::Reply {
248 if msg.index >= self.incremental.len() {
249 return Err(AgentError::Context("索引越界".into()));
250 }
251 let removed = self.incremental.remove(msg.index);
252 self.notify_change(NotifyChange::Removed {
253 index: msg.index,
254 message: removed,
255 })
256 .await;
257 Ok(())
258 }
259}
260
261impl<B: ContextBackend> Message<RequestPop> for AgentContext<B> {
262 type Reply = Option<B::Message>;
263
264 async fn handle(
265 &mut self,
266 _msg: RequestPop,
267 _ctx: &mut Context<Self, Self::Reply>,
268 ) -> Self::Reply {
269 let popped = self.incremental.pop();
270 if let Some(ref msg) = popped {
271 self.notify_change(NotifyChange::Popped(msg.clone())).await;
272 }
273 popped
274 }
275}
276
277impl<B: ContextBackend> Message<RequestRetain> for AgentContext<B> {
278 type Reply = ();
279
280 async fn handle(
281 &mut self,
282 msg: RequestRetain,
283 _ctx: &mut Context<Self, Self::Reply>,
284 ) -> Self::Reply {
285 let mut removed = Vec::new();
286 let role = msg.role;
287 self.incremental.retain(|m| {
288 if m.role() == role {
289 true
290 } else {
291 removed.push(m.clone());
292 false
293 }
294 });
295 self.notify_change(NotifyChange::Retained { role, removed })
296 .await;
297 }
298}
299
300impl<B: ContextBackend> Message<RequestClear> for AgentContext<B> {
301 type Reply = ();
302
303 async fn handle(
304 &mut self,
305 _msg: RequestClear,
306 _ctx: &mut Context<Self, Self::Reply>,
307 ) -> Self::Reply {
308 if !self.incremental.is_empty() {
309 let removed = std::mem::take(&mut self.incremental);
310 self.notify_change(NotifyChange::Cleared { removed }).await;
311 }
312 }
313}
314
315impl<B: ContextBackend> Message<RequestLen> for AgentContext<B> {
320 type Reply = usize;
321
322 async fn handle(
323 &mut self,
324 _msg: RequestLen,
325 _ctx: &mut Context<Self, Self::Reply>,
326 ) -> Self::Reply {
327 self.immutable.len() + self.compressed.len() + self.incremental.len()
328 }
329}
330
331impl<B: ContextBackend> Message<RequestIsEmpty> for AgentContext<B> {
332 type Reply = bool;
333
334 async fn handle(
335 &mut self,
336 _msg: RequestIsEmpty,
337 _ctx: &mut Context<Self, Self::Reply>,
338 ) -> Self::Reply {
339 self.immutable.is_empty() && self.compressed.is_empty() && self.incremental.is_empty()
340 }
341}
342
343impl<B: ContextBackend> Message<RequestGet> for AgentContext<B> {
344 type Reply = Option<B::Message>;
345
346 async fn handle(
347 &mut self,
348 msg: RequestGet,
349 _ctx: &mut Context<Self, Self::Reply>,
350 ) -> Self::Reply {
351 let idx = msg.0;
352 let imm_len = self.immutable.len();
353 let comp_len = self.compressed.len();
354 if idx < imm_len {
355 Some(self.immutable[idx].clone())
356 } else if idx < imm_len + comp_len {
357 Some(self.compressed[idx - imm_len].clone())
358 } else if idx < imm_len + comp_len + self.incremental.len() {
359 Some(self.incremental[idx - imm_len - comp_len].clone())
360 } else {
361 None
362 }
363 }
364}
365
366impl<B: ContextBackend> Message<RequestMessages> for AgentContext<B> {
367 type Reply = Vec<B::Message>;
368
369 async fn handle(
370 &mut self,
371 _msg: RequestMessages,
372 _ctx: &mut Context<Self, Self::Reply>,
373 ) -> Self::Reply {
374 self.immutable
375 .iter()
376 .chain(self.compressed.iter())
377 .chain(self.incremental.iter())
378 .cloned()
379 .collect()
380 }
381}
382
383impl<B: ContextBackend> Message<RequestImmutable> for AgentContext<B> {
384 type Reply = Vec<B::Message>;
385
386 async fn handle(
387 &mut self,
388 _msg: RequestImmutable,
389 _ctx: &mut Context<Self, Self::Reply>,
390 ) -> Self::Reply {
391 self.immutable.to_vec()
392 }
393}
394
395impl<B: ContextBackend> Message<RequestCompressed> for AgentContext<B> {
396 type Reply = Vec<B::Message>;
397
398 async fn handle(
399 &mut self,
400 _msg: RequestCompressed,
401 _ctx: &mut Context<Self, Self::Reply>,
402 ) -> Self::Reply {
403 self.compressed.clone()
404 }
405}
406
407impl<B: ContextBackend> Message<RequestIncremental> for AgentContext<B> {
408 type Reply = Vec<B::Message>;
409
410 async fn handle(
411 &mut self,
412 _msg: RequestIncremental,
413 _ctx: &mut Context<Self, Self::Reply>,
414 ) -> Self::Reply {
415 self.incremental.clone()
416 }
417}
418
419impl<B: ContextBackend> Message<RequestFindByRole> for AgentContext<B> {
420 type Reply = Vec<B::Message>;
421
422 async fn handle(
423 &mut self,
424 msg: RequestFindByRole,
425 _ctx: &mut Context<Self, Self::Reply>,
426 ) -> Self::Reply {
427 self.immutable
428 .iter()
429 .chain(self.compressed.iter())
430 .chain(self.incremental.iter())
431 .filter(|m| m.role() == msg.0)
432 .cloned()
433 .collect()
434 }
435}
436
437impl<B: ContextBackend> Message<RequestSend<B::Opts>> for AgentContext<B> {
442 type Reply = Result<B::Response, AgentError>;
443
444 async fn handle(
445 &mut self,
446 msg: RequestSend<B::Opts>,
447 _ctx: &mut Context<Self, Self::Reply>,
448 ) -> Self::Reply {
449 self.compress_if_full(&msg.opts).await?;
450 let scratch = msg.opts.as_ref().scratch.clone();
451 let mut all_messages: Vec<B::Message> = self
452 .immutable
453 .iter()
454 .chain(self.compressed.iter())
455 .chain(self.incremental.iter())
456 .cloned()
457 .collect();
458 if let Some(content) = scratch {
459 all_messages.push(self.backend.system_message(content));
460 }
461 let response = self.backend.send(&all_messages, &msg.opts).await?;
462 let raw_msgs = self
463 .backend
464 .extract_messages(std::slice::from_ref(&response))?;
465 let request_msgs = self.backend.to_request_messages(raw_msgs)?;
466 for msg in &request_msgs {
467 self.incremental.push(msg.clone());
468 self.notify_change(NotifyChange::Appended(msg.clone()))
469 .await;
470 }
471 Ok(response)
472 }
473}
474
475impl<B: ContextBackend + Clone> Message<RequestSendStream<B::Opts>> for AgentContext<B> {
476 type Reply = Result<AgentSendStream<B>, AgentError>;
477
478 async fn handle(
479 &mut self,
480 msg: RequestSendStream<B::Opts>,
481 _ctx: &mut Context<Self, Self::Reply>,
482 ) -> Self::Reply {
483 self.compress_if_full(&msg.opts).await?;
484 let scratch = msg.opts.as_ref().scratch.clone();
485 let mut all_messages: Vec<B::Message> = self
486 .immutable
487 .iter()
488 .chain(self.compressed.iter())
489 .chain(self.incremental.iter())
490 .cloned()
491 .collect();
492 if let Some(content) = scratch {
493 all_messages.push(self.backend.system_message(content));
494 }
495 let stream = self.backend.send_stream(all_messages, msg.opts);
496 Ok(AgentSendStream::new(stream))
497 }
498}
499
500impl<B: ContextBackend> Message<RequestCompress<B::Opts>> for AgentContext<B> {
505 type Reply = Result<(), AgentError>;
506
507 async fn handle(
508 &mut self,
509 msg: RequestCompress<B::Opts>,
510 _ctx: &mut Context<Self, Self::Reply>,
511 ) -> Self::Reply {
512 match msg.strategy {
513 CompressStrategy::Summarize { keep, prompt } => {
514 let total = self.incremental.len();
515 if total > keep {
516 let split = total - keep;
517 let to_summarize: Vec<B::Message> = self.incremental.drain(..split).collect();
518 if !to_summarize.is_empty() {
519 let summary_prompt = prompt.unwrap_or_else(Self::default_summary_prompt);
520 let mut summary_messages =
521 vec![self.backend.system_message(summary_prompt)];
522 summary_messages.append(&mut self.compressed);
523 summary_messages.extend(to_summarize);
524 let response = self.backend.send(&summary_messages, &msg.opts).await?;
525 let raw_msgs = self
526 .backend
527 .extract_messages(std::slice::from_ref(&response))?;
528 let request_msgs = self.backend.to_request_messages(raw_msgs)?;
529 let summary: Vec<B::Message> = request_msgs
530 .into_iter()
531 .map(|msg| self.backend.to_system_message(msg))
532 .collect();
533 let kept: Vec<B::Message> = self.incremental.drain(..).collect();
534 let (final_summary, final_kept) = if let Some(editor) =
535 &self.on_compressed
536 {
537 editor
538 .ask(NotifyCompressedForReply { summary, kept })
539 .send()
540 .await
541 .map_err(|e| AgentError::Context(e.to_string()))?
542 } else {
543 (summary, kept)
544 };
545 self.compressed = final_summary;
546 self.incremental = final_kept;
547 }
548 }
549 Ok(())
550 }
551 }
552 }
553}
554
555impl<B: ContextBackend> Message<RequestEstimateTokens> for AgentContext<B> {
560 type Reply = usize;
561
562 async fn handle(
563 &mut self,
564 _msg: RequestEstimateTokens,
565 _ctx: &mut Context<Self, Self::Reply>,
566 ) -> Self::Reply {
567 let all: Vec<B::Message> = self
568 .immutable
569 .iter()
570 .chain(self.compressed.iter())
571 .chain(self.incremental.iter())
572 .cloned()
573 .collect();
574 self.backend.estimate_tokens(&all).await.unwrap_or(0)
575 }
576}
577
578impl<B: ContextBackend> Message<RequestToJsonl> for AgentContext<B> {
579 type Reply = Result<String, AgentError>;
580
581 async fn handle(
582 &mut self,
583 _msg: RequestToJsonl,
584 _ctx: &mut Context<Self, Self::Reply>,
585 ) -> Self::Reply {
586 let lines: Vec<String> = self
587 .immutable
588 .iter()
589 .chain(self.compressed.iter())
590 .chain(self.incremental.iter())
591 .map(|m| self.backend.message_to_jsonl(m))
592 .collect::<Result<_, _>>()?;
593 Ok(lines.join("\n"))
594 }
595}
596
597impl<B: ContextBackend> Message<RequestFromJsonl> for AgentContext<B> {
598 type Reply = Result<(), AgentError>;
599
600 async fn handle(
601 &mut self,
602 msg: RequestFromJsonl,
603 _ctx: &mut Context<Self, Self::Reply>,
604 ) -> Self::Reply {
605 for line in msg.jsonl.lines() {
606 let line = line.trim();
607 if line.is_empty() {
608 continue;
609 }
610 let message: B::Message = self.backend.message_from_jsonl(line)?;
611 self.incremental.push(message.clone());
612 self.notify_change(NotifyChange::Appended(message)).await;
613 }
614 Ok(())
615 }
616}