// nucel_agent_core/session.rs
use std::path::PathBuf;
use std::pin::Pin;
use std::sync::Arc;

use async_trait::async_trait;
use chrono::{DateTime, Utc};
use futures::stream::{self, Stream, StreamExt};

use crate::error::{AgentError, Result};
use crate::types::{
    AgentCost, AgentResponse, ExecutorType, MessageEvent, ToolCall, ToolResult,
};

12pub type EventStream = Pin<Box<dyn Stream<Item = Result<MessageEvent>> + Send>>;
14
15#[derive(Debug, Clone)]
17pub struct SessionMetadata {
18 pub session_id: String,
19 pub executor_type: ExecutorType,
20 pub working_dir: PathBuf,
21 pub created_at: DateTime<Utc>,
22 pub model: Option<String>,
23}
24
25#[async_trait]
28pub trait SessionImpl: Send + Sync {
29 async fn query(&self, prompt: &str) -> Result<AgentResponse>;
30
31 async fn query_stream(&self, prompt: &str) -> Result<EventStream> {
41 let resp = self.query(prompt).await?;
42 let events = vec![
43 Ok(MessageEvent::TextChunk {
44 text: resp.content.clone(),
45 }),
46 Ok(MessageEvent::ResultDone {
47 cost: resp.cost.clone(),
48 content: resp.content,
49 is_error: false,
50 }),
51 ];
52 Ok(Box::pin(stream::iter(events)))
53 }
54
55 async fn total_cost(&self) -> Result<AgentCost>;
56 async fn close(&self) -> Result<()>;
57}
58
59pub struct AgentSession {
65 pub session_id: String,
67 pub executor_type: ExecutorType,
69 pub working_dir: PathBuf,
71 pub created_at: DateTime<Utc>,
73 pub model: Option<String>,
75
76 pub(crate) inner: Arc<dyn SessionImpl>,
77}
78
79impl std::fmt::Debug for AgentSession {
80 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
81 f.debug_struct("AgentSession")
82 .field("session_id", &self.session_id)
83 .field("executor_type", &self.executor_type)
84 .field("working_dir", &self.working_dir)
85 .field("created_at", &self.created_at)
86 .field("model", &self.model)
87 .finish_non_exhaustive()
88 }
89}
90
91impl AgentSession {
92 pub fn new(
94 session_id: impl Into<String>,
95 executor_type: ExecutorType,
96 working_dir: impl Into<PathBuf>,
97 model: Option<String>,
98 inner: Arc<dyn SessionImpl>,
99 ) -> Self {
100 Self {
101 session_id: session_id.into(),
102 executor_type,
103 working_dir: working_dir.into(),
104 created_at: Utc::now(),
105 model,
106 inner,
107 }
108 }
109
110 pub async fn query(&self, prompt: &str) -> Result<AgentResponse> {
112 self.inner.query(prompt).await
113 }
114
115 pub async fn query_stream(&self, prompt: &str) -> Result<EventStream> {
117 self.inner.query_stream(prompt).await
118 }
119
120 pub async fn collect_stream(mut stream: EventStream) -> Result<AgentResponse> {
122 let mut content = String::new();
123 let mut cost = AgentCost::default();
124 let mut final_content: Option<String> = None;
125 let mut tool_calls: Vec<crate::types::ToolCall> = Vec::new();
126 let mut pending_tool: Option<crate::types::ToolCall> = None;
127 while let Some(evt) = stream.next().await {
128 match evt? {
129 MessageEvent::TextChunk { text } => content.push_str(&text),
130 MessageEvent::ToolUse { name, input, .. } => {
131 pending_tool = Some(crate::types::ToolCall {
132 name,
133 args: input,
134 result: None,
135 });
136 }
137 MessageEvent::ToolResult { success, output, .. } => {
138 if let Some(mut t) = pending_tool.take() {
139 t.result = Some(crate::types::ToolResult { success, output });
140 tool_calls.push(t);
141 }
142 }
143 MessageEvent::ResultDone {
144 cost: c,
145 content: final_text,
146 is_error,
147 } => {
148 cost = c;
149 if is_error {
150 return Err(crate::error::AgentError::Provider {
151 provider: "stream".into(),
152 message: final_text,
153 });
154 }
155 final_content = Some(final_text);
156 break;
157 }
158 MessageEvent::Error { message } => {
159 return Err(crate::error::AgentError::Provider {
160 provider: "stream".into(),
161 message,
162 });
163 }
164 MessageEvent::RateLimit { message } => {
165 return Err(crate::error::AgentError::RateLimited { message });
166 }
167 _ => {}
168 }
169 }
170 if let Some(c) = final_content {
171 if !c.is_empty() {
172 content = c;
173 }
174 }
175 Ok(AgentResponse {
176 content,
177 cost,
178 confidence: None,
179 requests_escalation: false,
180 tool_calls,
181 })
182 }
183
184 pub async fn total_cost(&self) -> Result<AgentCost> {
186 self.inner.total_cost().await
187 }
188
189 pub async fn close(self) -> Result<()> {
191 self.inner.close().await
192 }
193
194 pub fn metadata(&self) -> SessionMetadata {
196 SessionMetadata {
197 session_id: self.session_id.clone(),
198 executor_type: self.executor_type,
199 working_dir: self.working_dir.clone(),
200 created_at: self.created_at,
201 model: self.model.clone(),
202 }
203 }
204}