Skip to main content

agent_context/context/
stream.rs

1//! 流式对话返回类型 [`AgentSendStream`]。
2//!
3//! 逐块播出后端流数据,内部累积完整响应,Drop 时自动将结果存入 [`AgentContext`](super::AgentContext)。
4
5use std::sync::Arc;
6
7use kameo::prelude::*;
8
9use super::actor::{AgentContext, SilentAppendMsg};
10use super::event::ChangeEvent;
11use super::types::ContextBackend;
12use crate::error::AgentError;
13
14// ---------------------------------------------------------------------------
15// AgentSendStream — 流式对话返回类型
16// ---------------------------------------------------------------------------
17
18/// 流式对话的 Stream 返回类型。
19///
20/// 两层职责:
21/// 1. **poll 时**:逐块播出后端流的数据,同时内部累积
22/// 2. **Drop 时**:将累积的响应通过 [`extract_messages_from_backend_response`](ContextBackend::extract_messages_from_backend_response)
23///    和 [`to_request_messages`](ContextBackend::to_request_messages) 转换后,自动存入 `AgentContext` 的 incremental 区
24///
25/// 无需手动处理 — 调用者消费 stream 直到结束,drop 时触发自动存储。
26pub struct AgentSendStream<B: ContextBackend> {
27    backend: B,
28    inner:
29        std::pin::Pin<Box<dyn futures_core::Stream<Item = Result<B::Response, AgentError>> + Send>>,
30    accumulated: Vec<B::Response>,
31    actor_ref: Option<ActorRef<AgentContext<B>>>,
32    #[expect(clippy::type_complexity, reason = "回调类型不可避免复杂")]
33    on_change: Option<Arc<dyn Fn(ChangeEvent<B::Message>) + Send + Sync>>,
34}
35
36impl<B: ContextBackend> AgentSendStream<B> {
37    #[expect(clippy::type_complexity, reason = "回调类型不可避免复杂")]
38    pub(crate) fn new(
39        backend: B,
40        inner: impl futures_core::Stream<Item = Result<B::Response, AgentError>> + Send + 'static,
41        actor_ref: ActorRef<AgentContext<B>>,
42        on_change: Option<Arc<dyn Fn(ChangeEvent<B::Message>) + Send + Sync>>,
43    ) -> Self {
44        Self {
45            backend,
46            inner: Box::pin(inner),
47            accumulated: Vec::new(),
48            actor_ref: Some(actor_ref),
49            on_change,
50        }
51    }
52
53    pub fn take_accumulated(&mut self) -> Vec<B::Response> {
54        std::mem::take(&mut self.accumulated)
55    }
56}
57
58impl<B: ContextBackend> futures_core::Stream for AgentSendStream<B> {
59    type Item = Result<B::Response, AgentError>;
60
61    fn poll_next(
62        mut self: std::pin::Pin<&mut Self>,
63        cx: &mut std::task::Context<'_>,
64    ) -> std::task::Poll<Option<Self::Item>> {
65        let this = unsafe { self.as_mut().get_unchecked_mut() };
66        match this.inner.as_mut().poll_next(cx) {
67            std::task::Poll::Ready(Some(Ok(resp))) => {
68                this.accumulated.push(resp.clone());
69                std::task::Poll::Ready(Some(Ok(resp)))
70            }
71            std::task::Poll::Ready(Some(Err(e))) => std::task::Poll::Ready(Some(Err(e))),
72            std::task::Poll::Ready(None) => std::task::Poll::Ready(None),
73            std::task::Poll::Pending => std::task::Poll::Pending,
74        }
75    }
76}
77
78impl<B: ContextBackend> Drop for AgentSendStream<B> {
79    fn drop(&mut self) {
80        let responses = std::mem::take(&mut self.accumulated);
81        if !responses.is_empty() {
82            let backend = self.backend.clone();
83            if let Some(actor_ref) = self.actor_ref.take() {
84                let on_change = self.on_change.clone();
85                tokio::spawn(async move {
86                    if let Ok(raw_msgs) = backend.extract_messages_from_backend_response(&responses)
87                    {
88                        if let Ok(request_msgs) = backend.to_request_messages(raw_msgs) {
89                            for m in request_msgs {
90                                if let Some(ref cb) = on_change {
91                                    cb(ChangeEvent::Appended(m.clone()));
92                                }
93                                if let Err(e) = actor_ref.tell(SilentAppendMsg { message: m }).await
94                                {
95                                    log::warn!("SilentAppendMsg 发送失败: {e:?}");
96                                }
97                            }
98                        } else {
99                            log::warn!("流式响应转换请求格式失败,已丢弃");
100                        }
101                    } else {
102                        log::warn!("流式响应提取消息失败,已丢弃");
103                    }
104                });
105            }
106        }
107    }
108}
109
110impl<B: ContextBackend> kameo::Reply for AgentSendStream<B> {
111    type Ok = Self;
112    type Error = std::convert::Infallible;
113    type Value = Self;
114    fn to_result(self) -> Result<Self::Ok, Self::Error> {
115        Ok(self)
116    }
117    fn into_any_err(self) -> Option<Box<dyn kameo::reply::ReplyError>> {
118        None
119    }
120    fn into_value(self) -> Self::Value {
121        self
122    }
123}